From 17a810438456daf6b30e961efc5a60e78abc9da5 Mon Sep 17 00:00:00 2001
From: Titan
Date: Tue, 14 Apr 2026 18:04:53 +0800
Subject: [PATCH] feat(billing): refactor run_id extraction and enhance
 logging in middleware

---
 .../agents/middlewares/billing_middleware.py  | 48 ++++++++-----------
 .../harness/deerflow/models/patched_openai.py |  5 ++
 backend/tests/test_billing_middleware.py      | 38 +++++++++++++++
 3 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py
index 779fd58e..65cef099 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/billing_middleware.py
@@ -116,27 +116,6 @@ def _reserve_payload(request: ModelRequest) -> tuple[dict[str, Any], str | None,
     question = _extract_latest_question(request.messages)
     call_id = run_id or str(uuid4())
 
-    if not run_id:
-        runtime = getattr(request, "runtime", None)
-        runtime_context = getattr(runtime, "context", None)
-        runtime_config = getattr(runtime, "config", None)
-        context_keys = sorted(runtime_context.keys()) if isinstance(runtime_context, dict) else []
-        config_keys = sorted(runtime_config.keys()) if isinstance(runtime_config, dict) else []
-        logger.warning(
-            "[BillingMiddleware] run_id missing in runtime; fallback callId=%s context_type=%s config_type=%s context_keys=%s config_keys=%s",
-            call_id,
-            type(runtime_context).__name__ if runtime_context is not None else "None",
-            type(runtime_config).__name__ if runtime_config is not None else "None",
-            context_keys,
-            config_keys,
-        )
-    logger.info(
-        "[BillingMiddleware] id mapping: thread_id=%s run_id=%s call_id=%s model_name=%s",
-        session_id,
-        run_id,
-        call_id,
-        model_name,
-    )
     expire_at = datetime.now() + timedelta(seconds=cfg.default_expire_seconds)
     payload: dict[str, Any] = {
         "sessionId": session_id,
@@ -152,19 +131,30 @@ def _extract_run_id(request: ModelRequest) -> str | None:  # noqa: ARG001
-    # Primary: LangGraph injects run_id into the top-level RunnableConfig
-    # (langgraph_api/stream.py:218) and propagates it via var_child_runnable_config
-    # throughout graph node execution.
+    # Primary: use LangGraph's public runtime API to access the current RunnableConfig.
+    # This matches the official guidance for code that needs config inside runtime-bound
+    # execution, while middleware itself only receives ModelRequest(runtime=Runtime).
     try:
-        from langchain_core.runnables.config import var_child_runnable_config
+        from langgraph.config import get_config
 
-        lc_config = var_child_runnable_config.get()
-        if isinstance(lc_config, dict):
-            run_id = lc_config.get("run_id")
+        config = get_config()
+        if isinstance(config, dict):
+            # Depending on LangGraph API variant, run_id may live at different levels.
+            run_id = config.get("run_id")
+            if run_id is None:
+                metadata = config.get("metadata")
+                if isinstance(metadata, dict):
+                    run_id = metadata.get("run_id")
+            if run_id is None:
+                configurable = config.get("configurable")
+                if isinstance(configurable, dict):
+                    run_id = configurable.get("run_id")
             if run_id is not None:
                 return str(run_id)
-    except Exception:
+    except RuntimeError:
         pass
+    except Exception as exc:
+        logger.warning("[BillingMiddleware] failed to read run_id from get_config(): %s", exc)
 
     # Fallback: LangGraph API worker sets run_id via set_logging_context() before
     # astream_state, storing it in worker_config ContextVar (langgraph_api/worker.py:139).
diff --git a/backend/packages/harness/deerflow/models/patched_openai.py b/backend/packages/harness/deerflow/models/patched_openai.py
index 9a7801f4..7225fba0 100644
--- a/backend/packages/harness/deerflow/models/patched_openai.py
+++ b/backend/packages/harness/deerflow/models/patched_openai.py
@@ -21,12 +21,15 @@ message that originally carried them.
 
 from __future__ import annotations
 
+import logging
 from typing import Any
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import AIMessage
 from langchain_openai import ChatOpenAI
 
+logger = logging.getLogger(__name__)
+
 
 class PatchedChatOpenAI(ChatOpenAI):
     """ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
@@ -75,6 +78,8 @@ class PatchedChatOpenAI(ChatOpenAI):
         # Obtain the base payload from the parent implementation.
         payload = super()._get_request_payload(input_, stop=stop, **kwargs)
 
+        logger.debug("LLM request payload messages: %s", payload.get("messages"))
+
         payload_messages = payload.get("messages", [])
 
         if len(payload_messages) == len(original_messages):
diff --git a/backend/tests/test_billing_middleware.py b/backend/tests/test_billing_middleware.py
index 3d577eef..553ea2b6 100644
--- a/backend/tests/test_billing_middleware.py
+++ b/backend/tests/test_billing_middleware.py
@@ -242,6 +242,44 @@ async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
     assert reserve_payload["callId"] == "run-from-worker"
 
 
+@pytest.mark.anyio
+async def test_awrap_model_call_uses_nested_run_id_from_runnable_config(monkeypatch):
+    from langchain_core.runnables.config import var_child_runnable_config
+
+    from deerflow.agents.middlewares import billing_middleware as bm
+
+    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
+
+    seen_payloads = []
+
+    async def fake_post(url, headers, payload, timeout_seconds):
+        seen_payloads.append((url, headers, payload, timeout_seconds))
+        if url.endswith("/frozen"):
+            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
+        return {"status": 1000, "message": "ok", "data": {}}
+
+    monkeypatch.setattr(bm, "_post_async", fake_post)
+
+    middleware = BillingMiddleware()
+    request = _request_with_latest_user_text("hello world")
+    handler = AsyncMock(return_value=AIMessage(content="ok", usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}))
+
+    token = var_child_runnable_config.set(
+        {
+            "metadata": {"run_id": "run-from-metadata"},
+            "configurable": {"run_id": "run-from-configurable"},
+        }
+    )
+    try:
+        result = await middleware.awrap_model_call(request, handler)
+    finally:
+        var_child_runnable_config.reset(token)
+
+    assert isinstance(result, AIMessage)
+    reserve_payload = seen_payloads[0][2]
+    assert reserve_payload["callId"] == "run-from-metadata"
+
+
 @pytest.mark.anyio
 async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch):
     from langchain_core.runnables.config import var_child_runnable_config
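
A minimal, self-contained sketch (not part of the patch) of the lookup precedence that the reworked primary path and the new test rely on; the helper name _pick_run_id is hypothetical and introduced here only for illustration:

from typing import Any


def _pick_run_id(config: dict[str, Any]) -> str | None:
    """Illustrative only: check the top-level run_id first, then metadata, then configurable."""
    run_id = config.get("run_id")
    if run_id is None and isinstance(config.get("metadata"), dict):
        run_id = config["metadata"].get("run_id")
    if run_id is None and isinstance(config.get("configurable"), dict):
        run_id = config["configurable"].get("run_id")
    return str(run_id) if run_id is not None else None


# Mirrors test_awrap_model_call_uses_nested_run_id_from_runnable_config:
# with no top-level run_id, the metadata value wins over the configurable one.
assert _pick_run_id(
    {
        "metadata": {"run_id": "run-from-metadata"},
        "configurable": {"run_id": "run-from-configurable"},
    }
) == "run-from-metadata"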