# Tests for deerflow.agents.middlewares.billing_middleware.BillingMiddleware.
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
from deerflow.agents.middlewares.billing_middleware import BillingMiddleware
|
|
|
|
|
|
def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
|
|
billing = SimpleNamespace(
|
|
enabled=enabled,
|
|
include_subagents=include_subagents,
|
|
fail_closed=True,
|
|
block_only_specific_reserve_codes=True,
|
|
blocking_reserve_codes=[-1104, -1106],
|
|
frozen_type=1,
|
|
reserve_url="http://billing.local/accountFrozen/frozen",
|
|
finalize_url="http://billing.local/accountFrozen/release",
|
|
headers={"Authorization": "Bearer x"},
|
|
timeout_seconds=3.0,
|
|
default_expire_seconds=1800,
|
|
default_estimated_output_tokens=None,
|
|
)
|
|
|
|
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
|
|
return SimpleNamespace(
|
|
billing=billing,
|
|
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
|
|
)
|
|
|
|
|
|
def _request_with_latest_user_text(text: str):
    """Create a mock model-call request whose most recent user message is *text*."""
    runtime = SimpleNamespace(
        config={"configurable": {"thread_id": "thread-1", "model_name": "gpt-4"}},
        context={"thread_id": "thread-1"},
    )
    request = MagicMock()
    request.messages = [HumanMessage(content="old"), HumanMessage(content=text)]
    request.model_settings = {}
    request.runtime = runtime
    return request
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_uses_estimated_tokens_and_finalizes(monkeypatch):
    """Happy path: a reserve call precedes the model, a finalize call follows it,
    and the reserve payload carries estimates derived from the prompt and the
    model config's max_tokens."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    calls = []

    async def fake_post(url, headers, payload, timeout_seconds):
        calls.append((url, headers, payload, timeout_seconds))
        data = {"frozenId": "frozen-123"} if url.endswith("/frozen") else {}
        return {"status": 1000, "message": "ok", "data": data}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    response = AIMessage(
        content="ok",
        usage_metadata={"input_tokens": 11, "output_tokens": 22, "total_tokens": 33},
    )
    handler = AsyncMock(return_value=response)
    request = _request_with_latest_user_text("hello world")
    middleware = BillingMiddleware()

    # run_id comes from the runnable-config contextvar during execution.
    token = var_child_runnable_config.set({"run_id": "run-1"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    assert len(calls) == 2

    reserve_payload = calls[0][2]
    assert reserve_payload["callId"] == "run-1"
    assert reserve_payload["frozenType"] == 1
    assert reserve_payload["estimatedInputTokens"] == len("hello world")
    assert reserve_payload["estimatedOutputTokens"] == 4096
    assert "frozenAmount" not in reserve_payload

    finalize_payload = calls[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["finalAmount"] == 0
    assert finalize_payload["usageInputTokens"] == 11
    assert finalize_payload["usageOutputTokens"] == 22
    assert finalize_payload["usageTotalTokens"] == 33
    assert finalize_payload["finalizeReason"] == "success"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_fail_closed_on_insufficient_balance(monkeypatch):
    """With fail_closed=True, a blocking reserve code must short-circuit:
    the model handler is never awaited and the user gets an error message."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        # -1106 is listed in blocking_reserve_codes in _fake_app_config.
        return {"status": -1106, "message": "insufficient balance", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    handler = AsyncMock(return_value=AIMessage(content="should not run"))
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert "insufficient" in str(result.content).lower()
    handler.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_finalize_uses_state_messages_usage_when_response_missing_usage(monkeypatch):
    """When the handler's response lacks usage_metadata, the finalize payload
    should fall back to the usage recorded on the last AIMessage in
    request.state["messages"]."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    calls = []

    async def fake_post(url, headers, payload, timeout_seconds):
        calls.append((url, headers, payload, timeout_seconds))
        data = {"frozenId": "frozen-123"} if url.endswith("/frozen") else {}
        return {"status": 1000, "message": "ok", "data": data}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")
    state_ai_message = AIMessage(
        content="ok",
        usage_metadata={"input_tokens": 101, "output_tokens": 202, "total_tokens": 303},
    )
    request.state = {
        "messages": [HumanMessage(content="hello world"), state_ai_message]
    }
    # The response itself carries no usage_metadata.
    handler = AsyncMock(return_value=AIMessage(content="ok"))

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert len(calls) == 2

    finalize_payload = calls[1][2]
    assert finalize_payload["frozenId"] == "frozen-123"
    assert finalize_payload["usageInputTokens"] == 101
    assert finalize_payload["usageOutputTokens"] == 202
    assert finalize_payload["usageTotalTokens"] == 303
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_does_not_block_on_non_blocking_reserve_code(monkeypatch):
    """A reserve status outside blocking_reserve_codes must not block the call:
    the model handler still runs and its result is returned unchanged."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    async def fake_post(url, headers, payload, timeout_seconds):
        if url.endswith("/frozen"):
            # 5001 is not in blocking_reserve_codes, so it must be ignored.
            return {"status": 5001, "message": "platform busy", "data": {}}
        return {"status": 1000, "message": "ok", "data": {}}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    handler = AsyncMock(return_value=AIMessage(content="model-ran"))
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("question")

    result = await middleware.awrap_model_call(request, handler)

    assert isinstance(result, AIMessage)
    assert result.content == "model-ran"
    handler.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_uses_runnable_config_run_id(monkeypatch):
    """run_id is sourced from var_child_runnable_config, which LangGraph populates
    via langgraph_api/stream.py during graph node execution."""
    from langchain_core.runnables.config import var_child_runnable_config

    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    calls = []

    async def fake_post(url, headers, payload, timeout_seconds):
        calls.append((url, headers, payload, timeout_seconds))
        data = {"frozenId": "frozen-123"} if url.endswith("/frozen") else {}
        return {"status": 1000, "message": "ok", "data": data}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    response = AIMessage(
        content="ok",
        usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
    )
    handler = AsyncMock(return_value=response)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")

    token = var_child_runnable_config.set({"run_id": "run-from-ctx"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        var_child_runnable_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = calls[0][2]
    assert reserve_payload["callId"] == "run-from-ctx"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_model_call_uses_worker_config_fallback_run_id(monkeypatch):
    """Fallback: run_id from langgraph_api.logging.worker_config when var_child_runnable_config is unset."""
    from deerflow.agents.middlewares import billing_middleware as bm

    monkeypatch.setattr(bm, "get_app_config", lambda: _fake_app_config())

    calls = []

    async def fake_post(url, headers, payload, timeout_seconds):
        calls.append((url, headers, payload, timeout_seconds))
        data = {"frozenId": "frozen-123"} if url.endswith("/frozen") else {}
        return {"status": 1000, "message": "ok", "data": data}

    monkeypatch.setattr(bm, "_post_async", fake_post)

    import langgraph_api.logging as lg_logging

    response = AIMessage(
        content="ok",
        usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3},
    )
    handler = AsyncMock(return_value=response)
    middleware = BillingMiddleware()
    request = _request_with_latest_user_text("hello world")

    # No var_child_runnable_config here: only the worker_config contextvar is set.
    token = lg_logging.worker_config.set({"run_id": "run-from-worker"})
    try:
        result = await middleware.awrap_model_call(request, handler)
    finally:
        lg_logging.worker_config.reset(token)

    assert isinstance(result, AIMessage)
    reserve_payload = calls[0][2]
    assert reserve_payload["callId"] == "run-from-worker"
|