feat(billing): refactor run_id extraction and enhance logging in middleware

This commit is contained in:
Titan 2026-04-14 18:04:53 +08:00
parent 14cb4b3c33
commit 17a8104384
3 changed files with 62 additions and 29 deletions

View File

@ -116,27 +116,6 @@ def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None,
question = _extract_latest_question(request.messages) question = _extract_latest_question(request.messages)
call_id = run_id or str(uuid4()) call_id = run_id or str(uuid4())
if not run_id:
runtime = getattr(request, "runtime", None)
runtime_context = getattr(runtime, "context", None)
runtime_config = getattr(runtime, "config", None)
context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
logger.warning(
"[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
call_id,
type(runtime_context).__name__ if runtime_context is not None else "None",
type(runtime_config).__name__ if runtime_config is not None else "None",
context_keys,
config_keys,
)
logger.info(
"[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
session_id,
run_id,
call_id,
model_name,
)
expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds) expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
payload: dict[str, Any] = { payload: dict[str, Any] = {
"sessionId": session_id, "sessionId": session_id,
@ -152,19 +131,30 @@ def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None,
def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001 def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001
# Primary: LangGraph injects run_id into the top-level RunnableConfig # Primary: use LangGraph's public runtime API to access the current RunnableConfig.
# (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config # This matches the official guidance for code that needs config inside runtime-bound
# throughout graph node execution. # execution, while middleware itself only receives ModelRequest(runtime=Runtime).
try: try:
from langchain_core.runnables.config import var_child_runnable_config from langgraph.config import get_config
lc_config = var_child_runnable_config.get() config = get_config()
if isinstance(lc_config, dict): if isinstance(config, dict):
run_id = lc_config.get("run_id") # Depending on LangGraph API variant, run_id may live at different levels.
run_id = config.get("run_id")
if run_id is None:
metadata = config.get("metadata")
if isinstance(metadata, dict):
run_id = metadata.get("run_id")
if run_id is None:
configurable = config.get("configurable")
if isinstance(configurable, dict):
run_id = configurable.get("run_id")
if run_id is not None: if run_id is not None:
return str(run_id) return str(run_id)
except Exception: except RuntimeError:
pass pass
except Exception as exc:
logger.warning("[BillingMiddleware] failed to read run_id from get_config(): %s", exc)
# Fallback: LangGraph API worker sets run_id via set_logging_context() before # Fallback: LangGraph API worker sets run_id via set_logging_context() before
# astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139). # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).

View File

@ -21,12 +21,15 @@ message that originally carried them.
from __future__ import annotations from __future__ import annotations
import logging
from typing import Any from typing import Any
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
logger = logging.getLogger(__name__)
class PatchedChatOpenAI(ChatOpenAI): class PatchedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway. """ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
@ -75,6 +78,8 @@ class PatchedChatOpenAI(ChatOpenAI):
# Obtain the base payload from the parent implementation. # Obtain the base payload from the parent implementation.
payload = super()._get_request_payload(input_, stop=stop, **kwargs) payload = super()._get_request_payload(input_, stop=stop, **kwargs)
logger.debug("LLM request payload messages: %s", payload.get("messages"))
payload_messages = payload.get("messages", []) payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages): if len(payload_messages) == len(original_messages):

View File

@ -242,6 +242,44 @@ async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
assert reserve_payload["callId"] == "run-from-worker" assert reserve_payload["callId"] == "run-from-worker"
@pytest.mark.anyio
async def test_awrap_model_call_uses_nested_run_id_from_runnable_config(monkeypatch):
    """Expect the ``callId`` to come from a run_id nested under ``metadata``.

    The runnable config installed for the call has no top-level ``run_id``. It
    carries one value under ``metadata`` and a different one under
    ``configurable``; the payload of the first recorded billing call must use
    the ``metadata`` value.
    """
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    recorded_calls = []

    async def fake_post(url, headers, payload, timeout_seconds):
        # Stand-in for bm._post_async: remember the call and answer with a
        # canned success body (only "/frozen" URLs get a frozenId back).
        recorded_calls.append((url, headers, payload, timeout_seconds))
        data = {"frozenId": "frozen-123"} if url.endswith("/frozen") else {}
        return {"status": 1000, "message": "ok", "data": data}

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    monkeypatch.setattr(bm, "_post_async", fake_post)

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    usage = {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}
    handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata=usage))

    # Two competing nested run_ids, and deliberately no top-level "run_id" key.
    nested_config = {
        "metadata": {"run_id": "run-from-metadata"},
        "configurable": {"run_id": "run-from-configurable"},
    }
    token = var_child_runnable_config.set(nested_config)
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        # Always restore the context variable so other tests are unaffected.
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    _url, _headers, reserve_payload, _timeout = recorded_calls[0]
    assert reserve_payload["callId"] == "run-from-metadata"
@pytest.mark.anyio @pytest.mark.anyio
async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch): async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch):
from langchain_core.runnables.config import var_child_runnable_config from langchain_core.runnables.config import var_child_runnable_config