Compare commits

4 commits: 6367cf013c ... 17a8104384

| Author | SHA1 | Date |
|---|---|---|
|  | 17a8104384 |  |
|  | 14cb4b3c33 |  |
|  | 369f3af384 |  |
|  | 7ddc3a1742 |  |
```diff
@@ -113,34 +113,15 @@ def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None,
     estimated_input_tokens = _estimate_input_tokens(request.messages)
     estimated_output_tokens = _resolve_estimated_output_tokens(request, model_key)
+    question = _extract_latest_question(request.messages)
 
     call_id = run_id or str(uuid4())
-    if not run_id:
-        runtime = getattr(request, "runtime", None)
-        runtime_context = getattr(runtime, "context", None)
-        runtime_config = getattr(runtime, "config", None)
-        context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
-        config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
-        logger.warning(
-            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
-            call_id,
-            type(runtime_context).__name__ if runtime_context is not None else "None",
-            type(runtime_config).__name__ if runtime_config is not None else "None",
-            context_keys,
-            config_keys,
-        )
-    logger.info(
-        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
-        session_id,
-        run_id,
-        call_id,
-        model_name,
-    )
 
     expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
     payload: dict[str, Any] = {
         "sessionId": session_id,
         "callId": call_id,
         "modelName": model_name,
+        "question": question,
         "frozenType": cfg.frozen_type,
         "estimatedInputTokens": estimated_input_tokens,
         "estimatedOutputTokens": estimated_output_tokens,
```
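For reference, a minimal sketch of the reserve payload this hunk now builds. The values below are illustrative, not taken from a real run; only the `question` field is new in this change:

```python
# Illustrative shape of the reserve payload after this change. Values are
# made up for the example; "question" is the field added by this hunk.
reserve_payload = {
    "sessionId": "thread-abc",         # conversation/thread id
    "callId": "run-1",                 # run_id, or a uuid4 fallback when missing
    "modelName": "gpt-4o",             # hypothetical model name
    "question": "hello world",         # latest user text, truncated past 27 chars
    "frozenType": 1,
    "estimatedInputTokens": 11,
    "estimatedOutputTokens": 4096,
}
```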
```diff
@@ -150,19 +131,30 @@ def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None,
 
 
 def _extract_run_id(request: ModelRequest) -> str | None:  # noqa: ARG001
-    # Primary: LangGraph injects run_id into the top-level RunnableConfig
-    # (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config
-    # throughout graph node execution.
+    # Primary: use LangGraph's public runtime API to access the current RunnableConfig.
+    # This matches the official guidance for code that needs config inside runtime-bound
+    # execution, while middleware itself only receives ModelRequest(runtime=Runtime).
     try:
-        from langchain_core.runnables.config import var_child_runnable_config
+        from langgraph.config import get_config
 
-        lc_config = var_child_runnable_config.get()
-        if isinstance(lc_config, dict):
-            run_id = lc_config.get("run_id")
+        config = get_config()
+        if isinstance(config, dict):
+            # Depending on LangGraph API variant, run_id may live at different levels.
+            run_id = config.get("run_id")
+            if run_id is None:
+                metadata = config.get("metadata")
+                if isinstance(metadata, dict):
+                    run_id = metadata.get("run_id")
+            if run_id is None:
+                configurable = config.get("configurable")
+                if isinstance(configurable, dict):
+                    run_id = configurable.get("run_id")
             if run_id is not None:
                 return str(run_id)
-    except Exception:
+    except RuntimeError:
         pass
+    except Exception as exc:
+        logger.warning("[BillingMiddleware] failed to read run_id from get_config(): %s", exc)
 
     # Fallback: LangGraph API worker sets run_id via set_logging_context() before
     # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).
```
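The lookup order the new primary path implements can be isolated into a standalone sketch. It assumes only that `config` is the dict returned by `langgraph.config.get_config()`; note that `get_config()` raises `RuntimeError` when called outside a runnable context, which is why the hunk now handles that case separately from unexpected failures:

```python
def resolve_run_id(config: dict) -> str | None:
    """Mirror the lookup order above: top level, then metadata, then configurable."""
    run_id = config.get("run_id")
    if run_id is None and isinstance(config.get("metadata"), dict):
        run_id = config["metadata"].get("run_id")
    if run_id is None and isinstance(config.get("configurable"), dict):
        run_id = config["configurable"].get("run_id")
    return str(run_id) if run_id is not None else None

# Earlier levels win: metadata beats configurable when both carry a run_id.
assert resolve_run_id({"metadata": {"run_id": "m"}, "configurable": {"run_id": "c"}}) == "m"
assert resolve_run_id({"configurable": {}}) is None
```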
```diff
@@ -504,6 +496,13 @@ def _extract_latest_user_text(messages: list[Any]) -> str:
     return ""
 
 
+def _extract_latest_question(messages: list[Any]) -> str:
+    question = _extract_latest_user_text(messages)
+    if isinstance(question, str) and len(question) > 27:
+        return question[:27] + "。。。"
+    return question
+
+
 def _extract_usage(request: ModelRequest, response: ModelCallResult | None) -> dict[str, int] | None:
     if response is None:
         usage = None
```
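The truncation rule keeps the first 27 characters and appends a CJK ellipsis ("。。。"). A quick self-contained check of the behavior, re-implementing the rule on a raw string since the real helper takes a message list:

```python
def truncate_question(question: str) -> str:
    # Same rule as _extract_latest_question above, applied to a raw string.
    if isinstance(question, str) and len(question) > 27:
        return question[:27] + "。。。"
    return question

assert truncate_question("hi") == "hi"
assert truncate_question("abcdefghijklmnopqrstuvwxyz1234567890") == "abcdefghijklmnopqrstuvwxyz1。。。"
```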
```diff
@@ -21,12 +21,15 @@ message that originally carried them.
 
 from __future__ import annotations
 
+import logging
 from typing import Any
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import AIMessage
 from langchain_openai import ChatOpenAI
 
+logger = logging.getLogger(__name__)
+
 
 class PatchedChatOpenAI(ChatOpenAI):
     """ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
```
```diff
@@ -75,6 +78,8 @@ class PatchedChatOpenAI(ChatOpenAI):
         # Obtain the base payload from the parent implementation.
         payload = super()._get_request_payload(input_, stop=stop, **kwargs)
 
+        logger.debug("LLM request payload messages: %s", payload.get("messages"))
+
         payload_messages = payload.get("messages", [])
 
         if len(payload_messages) == len(original_messages):
```
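The new `logger.debug` line stays silent unless the module's logger is set to DEBUG. A sketch of how a caller might enable it; the module path here is an assumption for illustration, substitute wherever `PatchedChatOpenAI` actually lives:

```python
import logging

logging.basicConfig(level=logging.INFO)
# Hypothetical module path; point this at the module defining PatchedChatOpenAI.
logging.getLogger("deerflow.models.patched_chat_openai").setLevel(logging.DEBUG)
```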
```diff
@@ -75,6 +75,7 @@ async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch)
     reserve_payload = seen_payloads[0][2]
     assert reserve_payload["callId"] == "run-1"
     assert reserve_payload["frozenType"] == 1
+    assert reserve_payload["question"] == "hello world"
     assert reserve_payload["estimatedInputTokens"] == len("hello world")
     assert reserve_payload["estimatedOutputTokens"] == 4096
     assert "frozenAmount" not in reserve_payload
```
```diff
@@ -239,3 +240,75 @@ async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
     assert isinstance(result, AIMessage)
     reserve_payload = seen_payloads[0][2]
     assert reserve_payload["callId"] == "run-from-worker"
+
+
+@pytest.mark.anyio
+async def test_awrap_model_call_uses_nested_run_id_from_runnable_config(monkeypatch):
+    from langchain_core.runnables.config import var_child_runnable_config
+
+    from deerflow.agents.middlewares import billing_middleware as bm
+
+    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
+
+    seen_payloads = []
+
+    async def fake_post(url, headers, payload, timeout_seconds):
+        seen_payloads.append((url, headers, payload, timeout_seconds))
+        if url.endswith("/frozen"):
+            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
+        return {"status": 1000, "message": "ok", "data": {}}
+
+    monkeypatch.setattr(bm, "_post_async", fake_post)
+
+    middleware = BillingMiddleware()
+    request = _request_with_latest_user_text("hello world")
+    handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}))
+
+    token = var_child_runnable_config.set(
+        {
+            "metadata": {"run_id": "run-from-metadata"},
+            "configurable": {"run_id": "run-from-configurable"},
+        }
+    )
+    try:
+        result = await middleware.awrap_model_call(request, handler)
+    finally:
+        var_child_runnable_config.reset(token)
+
+    assert isinstance(result, AIMessage)
+    reserve_payload = seen_payloads[0][2]
+    assert reserve_payload["callId"] == "run-from-metadata"
+
+
+@pytest.mark.anyio
+async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch):
+    from langchain_core.runnables.config import var_child_runnable_config
+
+    from deerflow.agents.middlewares import billing_middleware as bm
+
+    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
+
+    seen_payloads = []
+
+    async def fake_post(url, headers, payload, timeout_seconds):
+        seen_payloads.append((url, headers, payload, timeout_seconds))
+        if url.endswith("/frozen"):
+            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
+        return {"status": 1000, "message": "ok", "data": {}}
+
+    monkeypatch.setattr(bm, "_post_async", fake_post)
+
+    middleware = BillingMiddleware()
+    long_question = "abcdefghijklmnopqrstuvwxyz1234567890"
+    request = _request_with_latest_user_text(long_question)
+    handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}))
+
+    token = var_child_runnable_config.set({"run_id": "run-question-truncate"})
+    try:
+        result = await middleware.awrap_model_call(request, handler)
+    finally:
+        var_child_runnable_config.reset(token)
+
+    assert isinstance(result, AIMessage)
+    reserve_payload = seen_payloads[0][2]
+    assert reserve_payload["question"] == "abcdefghijklmnopqrstuvwxyz1。。。"
```
```diff
@@ -115,7 +115,7 @@ export function InputBox({
     mode: "flash" | "thinking" | "pro" | "ultra" | undefined;
   };
   extraHeader?: React.ReactNode;
-  showWelcomeStyle?: boolean;
+  showWelcomeStyle: boolean;
   hasSubmitted?: boolean;
   initialValue?: string;
   onContextChange?: (
```
```diff
@@ -376,11 +376,11 @@ export function InputBox({
               />
             </PromptInputActionMenuContent>
           </PromptInputActionMenu> */}
-          <HistoryButton
+          {showWelcomeStyle && <HistoryButton
             className="px-2!"
             router={router}
             threadId={threadIdFromProps}
-          />
+          />}
          <AddAttachmentsButton className="px-2!" />
          <IframeSkillDialogButton
            className="px-2!"
```
```diff
@@ -668,9 +668,26 @@ function HistoryButton({
         router.replace(`/workspace/chats/${threadId}?is_chatting=true`)
       }
     >
-      <svg width="18" height="18" viewBox="0 0 18 18" fill="none" xmlns="http://www.w3.org/2000/svg">
-        <circle cx="9" cy="9" r="8.5" stroke="#150033" />
-        <path d="M9 6V10H12" stroke="#150033" strokeLinecap="round" strokeLinejoin="round" />
+      <svg
+        className="transition-[stroke] duration-200"
+        width="18"
+        height="18"
+        viewBox="0 0 18 18"
+        fill="none"
+        xmlns="http://www.w3.org/2000/svg"
+      >
+        <circle
+          className="stroke-[#150033] transition-[stroke] duration-200 group-hover:stroke-[#8E47F0]"
+          cx="9"
+          cy="9"
+          r="8.5"
+        />
+        <path
+          className="stroke-[#150033] transition-[stroke] duration-200 group-hover:stroke-[#8E47F0]"
+          d="M9 6V10H12"
+          strokeLinecap="round"
+          strokeLinejoin="round"
+        />
       </svg>
 
     </PromptInputButton>
```