"""Tests for BillingMiddleware.awrap_model_call: reserve/finalize flow and blocking."""

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from deerflow.agents.middlewares.billing_middleware import BillingMiddleware


def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
    """Build a minimal stand-in for the app config read by BillingMiddleware.

    Only the attributes the middleware touches are provided; `get_model_config`
    resolves solely the "gpt-4" name used by `_request_with_latest_user_text`.
    """
    billing = SimpleNamespace(
        enabled=enabled,
        include_subagents=include_subagents,
        fail_closed=True,
        block_only_specific_reserve_codes=True,
        blocking_reserve_codes=[-1104, -1106],
        frozen_type=1,
        reserve_url="http://billing.local/accountFrozen/frozen",
        finalize_url="http://billing.local/accountFrozen/release",
        headers={"Authorization": "Bearer x"},
        timeout_seconds=3.0,
        default_expire_seconds=1800,
        default_estimated_output_tokens=None,
    )
    model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
    return SimpleNamespace(
        billing=billing,
        get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
    )


def _request_with_latest_user_text(text: str):
    """Build a mock model request whose most recent human message carries *text*."""
    request = MagicMock()
    request.messages = [HumanMessage(content="old"), HumanMessage(content=text)]
    request.model_settings = {}
    request.runtime = SimpleNamespace(
        config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}},
        context={"thread_id": "thread-1"},
    )
    return request


@pytest.mark.anyio
async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch):
    """Happy path: reserve is posted with estimated tokens, then finalized with usage."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33},
        )
    )
    token = var_child_runnable_config.set({"run_id": "run-1"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    assert len(seen_payloads) == 2

    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-1"
    assert reserve_payload["frozenType"] == 1
    assert reserve_payload["question"] == "hello world"
    assert reserve_payload["estimatedInputTokens"] == len("hello world")
    assert reserve_payload["estimatedOutputTokens"] == 4096
    assert "frozenAmount" not in reserve_payload

    finalize_payload = seen_payloads[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["finalAmount"] == 0
    assert finalize_payload["usageInputTokens"] == 11
    assert finalize_payload["usageOutputTokens"] == 22
    assert finalize_payload["usageTotalTokens"] == 33
    assert finalize_payload["finalizeReason"] == "success"


@pytest.mark.anyio
async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch):
    """A blocking reserve code (-1106) must short-circuit before the model handler runs."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        return {"status": -1106, "message": "insufficient balance", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="should not run"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert "insufficient" in str(result.content).lower()
    handler.assert_not_awaited()


@pytest.mark.anyio
async def test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch):
    """If the model response lacks usage_metadata, usage is read from request.state messages."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    request.state = {
        "messages": [
            HumanMessage(content="hello world"),
            AIMessage(
                content="ok",
                usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303},
            ),
        ]
    }
    handler = AsyncMock(return_value=AIMessage(content="ok"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert len(seen_payloads) == 2
    finalize_payload = seen_payloads[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["usageInputTokens"] == 101
    assert finalize_payload["usageOutputTokens"] == 202
    assert finalize_payload["usageTotalTokens"] == 303


@pytest.mark.anyio
async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch):
    """A reserve code outside blocking_reserve_codes must let the model handler run."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        if url.endswith("/frozen"):
            return {"status": 5001, "message": "platform busy", "data": {}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="model-ran"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert result.content == "model-ran"
    handler.assert_awaited_once()
@pytest.mark.anyio
async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch):
    """run_id is sourced from var_child_runnable_config, which LangGraph populates
    via langgraph_api/stream.py during graph node execution."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set({"run_id": "run-from-ctx"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-ctx"


@pytest.mark.anyio
async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
    """Fallback: run_id from langgraph_api.logging.worker_config when
    var_child_runnable_config is unset."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    import langgraph_api.logging as lg_logging

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = lg_logging.worker_config.set({"run_id": "run-from-worker"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        lg_logging.worker_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-worker"


@pytest.mark.anyio
async def test_awrap_model_call_uses_nested_run_id_from_runnable_config(monkeypatch):
    """When run_id appears in both metadata and configurable, metadata wins."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set(
        {
            "metadata": {"run_id": "run-from-metadata"},
            "configurable": {"run_id": "run-from-configurable"},
        }
    )
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-metadata"


@pytest.mark.anyio
async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch):
    """Long questions are truncated in the reserve payload with a CJK ellipsis suffix."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    long_question = "abcdefghijklmnopqrstuvwxyz1234567890"
    request = _request_with_latest_user_text(long_question)
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set({"run_id": "run-question-truncate"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["question"] == "abcdefghijklmnopqrstuvwxyz1。。。"