# deerflow2/backend/tests/test_billing_middleware.py
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
billing = SimpleNamespace(
enabled=enabled,
include_subagents=include_subagents,
fail_closed=True,
block_only_specific_reserve_codes=True,
blocking_reserve_codes=[-1104, -1106],
frozen_type=1,
reserve_url="http://billing.local/accountFrozen/frozen",
finalize_url="http://billing.local/accountFrozen/release",
headers={"Authorization": "Bearer x"},
timeout_seconds=3.0,
default_expire_seconds=1800,
default_estimated_output_tokens=None,
)
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
return SimpleNamespace(
billing=billing,
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
)
def _request_with_latest_user_text(text: str):
    """Build a mock model-call request whose latest human message is ``text``.

    The request carries two HumanMessages so the middleware's
    "latest user text" extraction is exercised (it must pick the second one),
    empty ``model_settings``, and a runtime whose config/context identify
    thread ``thread-1`` and model ``gpt-4`` (the only name the fake app
    config resolves).
    """
    request = MagicMock()
    request.messages = [HumanMessage(content="old"), HumanMessage(content=text)]
    request.model_settings = {}
    request.runtime = SimpleNamespace(
        config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}},
        context={"thread_id": "thread-1"},
    )
    return request
@pytest.mark.anyio
async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch):
    """Happy path: reserve with estimated tokens, then finalize with real usage.

    Checks the reserve payload (callId sourced from var_child_runnable_config,
    frozenType from config, question text, estimated input/output tokens, no
    explicit frozenAmount) and the finalize payload (frozenId round-trip plus
    usage taken from the response's usage_metadata).
    """
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33},
        )
    )
    # The middleware reads run_id from the runnable-config context var.
    token = var_child_runnable_config.set({"run_id": "run-1"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    assert len(seen_payloads) == 2  # one reserve call, one finalize call

    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-1"
    assert reserve_payload["frozenType"] == 1
    assert reserve_payload["question"] == "hello world"
    # Estimated input tokens fall back to character count of the question;
    # estimated output tokens come from the model config's max_tokens.
    assert reserve_payload["estimatedInputTokens"] == len("hello world")
    assert reserve_payload["estimatedOutputTokens"] == 4096
    assert "frozenAmount" not in reserve_payload

    finalize_payload = seen_payloads[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["finalAmount"] == 0
    assert finalize_payload["usageInputTokens"] == 11
    assert finalize_payload["usageOutputTokens"] == 22
    assert finalize_payload["usageTotalTokens"] == 33
    assert finalize_payload["finalizeReason"] == "success"
@pytest.mark.anyio
async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch):
    """A blocking reserve code (-1106) must short-circuit the model call.

    With fail_closed=True and -1106 in blocking_reserve_codes, the middleware
    returns an AIMessage describing the insufficient balance and never awaits
    the wrapped handler.
    """
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        return {"status": -1106, "message": "insufficient balance", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="should not run"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert "insufficient" in str(result.content).lower()
    handler.assert_not_awaited()
@pytest.mark.anyio
async def test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch):
    """When the response lacks usage_metadata, usage comes from state messages.

    The handler returns an AIMessage with no usage_metadata; the finalize
    payload must instead report the usage found on the AIMessage inside
    ``request.state["messages"]``.
    """
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    request.state = {
        "messages": [
            HumanMessage(content="hello world"),
            AIMessage(
                content="ok",
                usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303},
            ),
        ]
    }
    # Response deliberately carries no usage_metadata.
    handler = AsyncMock(return_value=AIMessage(content="ok"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert len(seen_payloads) == 2
    finalize_payload = seen_payloads[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["usageInputTokens"] == 101
    assert finalize_payload["usageOutputTokens"] == 202
    assert finalize_payload["usageTotalTokens"] == 303
@pytest.mark.anyio
async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch):
    """A reserve failure code outside blocking_reserve_codes must not block.

    Status 5001 is not in [-1104, -1106], so with
    block_only_specific_reserve_codes=True the model call proceeds and the
    handler's response is returned unchanged.
    """
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        if url.endswith("/frozen"):
            return {"status": 5001, "message": "platform busy", "data": {}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")
    handler = AsyncMock(return_value=AIMessage(content="model-ran"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert result.content == "model-ran"
    handler.assert_awaited_once()
@pytest.mark.anyio
async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch):
    """run_id is sourced from var_child_runnable_config, which LangGraph populates
    via langgraph_api/stream.py during graph node execution."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set({"run_id": "run-from-ctx"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-ctx"
@pytest.mark.anyio
async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
    """Fallback: run_id from langgraph_api.logging.worker_config when var_child_runnable_config is unset."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    import langgraph_api.logging as lg_logging

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    # var_child_runnable_config is deliberately left unset so the middleware
    # must fall back to the worker_config context var.
    token = lg_logging.worker_config.set({"run_id": "run-from-worker"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        lg_logging.worker_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-worker"
@pytest.mark.anyio
async def test_awrap_model_call_uses_nested_run_id_from_runnable_config(monkeypatch):
    """Nested run_id lookup: metadata takes precedence over configurable.

    When the runnable config carries run_id in both ``metadata`` and
    ``configurable`` (but not at top level), the middleware must use the
    ``metadata`` value.
    """
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set(
        {
            "metadata": {"run_id": "run-from-metadata"},
            "configurable": {"run_id": "run-from-configurable"},
        }
    )
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["callId"] == "run-from-metadata"
@pytest.mark.anyio
async def test_awrap_model_call_truncates_question_like_token_usage_middleware(monkeypatch):
    """Long questions are truncated in the reserve payload.

    A 36-character question must be cut to its first 27 characters with a
    "。。。" (fullwidth ellipsis) suffix, matching the token-usage middleware's
    truncation convention.
    """
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())
    seen_payloads = []

    async def fake_post(url, headers, payload, timeout_seconds):
        seen_payloads.append((url, headers, payload, timeout_seconds))
        if url.endswith("/frozen"):
            return {"status": 1000, "message": "ok", "data": {"frozenId": "frozen-123"}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)
    middleware = BillingMiddleware()
    long_question = "abcdefghijklmnopqrstuvwxyz1234567890"
    request = _request_with_latest_user_text(long_question)
    handler = AsyncMock(
        return_value=AIMessage(
            content="ok",
            usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
        )
    )
    token = var_child_runnable_config.set({"run_id": "run-question-truncate"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = seen_payloads[0][2]
    assert reserve_payload["question"] == "abcdefghijklmnopqrstuvwxyz1。。。"