# deerflow2/backend/tests/test_third_party_proxy.py
# (original listing metadata: 293 lines, 10 KiB, Python)

"""Unit tests for the third-party proxy module."""
from __future__ import annotations

import json
import math

from app.gateway.routers.third_party import (
    _extract_usage_tokens,
    _extract_usage_tokens_from_submit_stream,
    _resolve_final_amount,
)
from app.gateway.third_party_proxy.ledger import CallLedger
from app.gateway.third_party_proxy.proxy import (
    API_KEY_MARKER,
    _path_matches,
    _replace_api_key_marker_in_body,
    _replace_api_key_marker_in_headers,
    jsonpath_get,
    match_query_route,
    match_submit_route,
)
from deerflow.config.third_party_proxy_config import (
    QueryRouteConfig,
    SubmitRouteConfig,
    ThirdPartyProviderConfig,
)


# ---------------------------------------------------------------------------
# _path_matches
# ---------------------------------------------------------------------------
class TestPathMatches:
    """Exact and ``/**`` glob matching performed by ``_path_matches``."""

    # Glob pattern shared by the prefix-matching cases below.
    _GLOB = "/openapi/v2/**"

    def test_exact_match(self):
        query_path = "/openapi/v2/query"
        assert _path_matches(query_path, query_path)

    def test_exact_no_match(self):
        matched = _path_matches("/openapi/v2/query", "/openapi/v2/submit")
        assert not matched

    def test_glob_matches_prefix(self):
        assert _path_matches("/openapi/v2/vidu/submit", self._GLOB)

    def test_glob_matches_prefix_itself(self):
        # The bare prefix, with no trailing segment, is still covered by "/**".
        assert _path_matches("/openapi/v2", self._GLOB)

    def test_glob_no_match_different_prefix(self):
        assert not _path_matches("/other/v2/submit", self._GLOB)

    def test_trailing_slashes_normalised(self):
        assert _path_matches("/openapi/v2/query/", "/openapi/v2/query")

    def test_glob_excludes_sibling_prefix(self):
        # "/openapi/v2/**" must respect path-segment boundaries, so a sibling
        # prefix such as "/openapi/v2extra/..." is not a match.
        assert not _path_matches("/openapi/v2extra/foo", self._GLOB)


# ---------------------------------------------------------------------------
# jsonpath_get
# ---------------------------------------------------------------------------
class TestJsonpathGet:
    """Dotted-path lookups performed by ``jsonpath_get``."""

    def test_single_key(self):
        payload = {"taskId": "abc"}
        assert jsonpath_get(payload, "taskId") == "abc"

    def test_nested_key(self):
        payload = {"usage": {"thirdPartyConsumeMoney": 1.23}}
        value = jsonpath_get(payload, "usage.thirdPartyConsumeMoney")
        assert value == 1.23

    def test_missing_key_returns_none(self):
        assert jsonpath_get({"foo": "bar"}, "taskId") is None

    def test_rejects_dollar_prefixed_path(self):
        # Only bare dotted paths are accepted; a "$."-rooted path yields None.
        assert jsonpath_get({"taskId": "abc"}, "$.taskId") is None

    def test_short_path_supported(self):
        assert jsonpath_get({"x": 1}, "x") == 1

    def test_non_dict_intermediate(self):
        # Descending through a non-dict value ends the lookup with None.
        payload = {"usage": "not-a-dict"}
        assert jsonpath_get(payload, "usage.something") is None

    def test_none_input(self):
        assert jsonpath_get(None, "x") is None


# ---------------------------------------------------------------------------
# match_submit_route / match_query_route
# ---------------------------------------------------------------------------
def _build_provider_cfg() -> ThirdPartyProviderConfig:
    """Assemble the provider fixture shared by the route-matching tests.

    It holds one catch-all ``POST /openapi/v2/**`` submit route that excludes
    the query endpoint, plus one ``POST /openapi/v2/query`` query route.
    """
    submit_route = SubmitRouteConfig(
        method="POST",
        path_pattern="/openapi/v2/**",
        exclude_path_pattern="/openapi/v2/query",
        task_id_jsonpath="taskId",
    )
    query_route = QueryRouteConfig(
        method="POST",
        path_pattern="/openapi/v2/query",
        request_task_id_jsonpath="taskId",
        status_jsonpath="status",
        success_values=["SUCCESS"],
        failure_values=["FAILED", "CANCELLED"],
        usage_jsonpath="usage.thirdPartyConsumeMoney",
        usage_jsonpaths=["usage.thirdPartyConsumeMoney", "usage.consumeMoney"],
    )
    return ThirdPartyProviderConfig(
        base_url="https://example.com",
        api_key_env="TEST_API_KEY",
        submit_routes=[submit_route],
        query_routes=[query_route],
    )


_PROVIDER_CFG = _build_provider_cfg()


class TestMatchRoutes:
    """Submit/query route lookup against the shared ``_PROVIDER_CFG`` fixture."""

    def test_submit_matches_non_query_path(self):
        route = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit")
        assert route is not None
        assert route.task_id_jsonpath == "taskId"

    def test_submit_excluded_by_exclude_pattern(self):
        # The query path falls under the submit glob but is explicitly excluded.
        route = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
        assert route is None

    def test_submit_wrong_method(self):
        route = match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit")
        assert route is None

    def test_query_matches(self):
        route = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
        assert route is not None
        assert route.status_jsonpath == "status"

    def test_query_wrong_method(self):
        route = match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query")
        assert route is None


# ---------------------------------------------------------------------------
# CallLedger
# ---------------------------------------------------------------------------
class TestCallLedger:
    """State transitions and lookups on ``CallLedger``."""

    def _make_ledger(self) -> CallLedger:
        return CallLedger()

    def test_create_and_get(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert rec.provider == "prov"
        found = ledger.get(rec.proxy_call_id)
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_set_reserved(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_reserved(rec.proxy_call_id, "frozen-123")
        found = ledger.get(rec.proxy_call_id)
        # Fail with a clear assertion rather than an AttributeError on None.
        assert found is not None
        assert found.frozen_id == "frozen-123"
        assert found.billing_state == "RESERVED"

    def test_set_running(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_running(rec.proxy_call_id, "task-abc")
        found = ledger.get_by_task_id("prov", "task-abc")
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_try_claim_finalize_once(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        # First claim should succeed
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        # Second claim should fail — already in progress/done
        assert ledger.try_claim_finalize(rec.proxy_call_id) is False

    def test_is_finalized(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert ledger.is_finalized(rec.proxy_call_id) is False
        # A fresh record's first claim succeeds; assert it so a broken claim
        # is reported here instead of being hidden behind the finalize check.
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        ledger.set_finalized(rec.proxy_call_id, "SUCCESS")
        assert ledger.is_finalized(rec.proxy_call_id) is True

    def test_idempotency_key_dedup(self):
        ledger = self._make_ledger()
        rec1 = ledger.create("prov", "tid", "idem-key-1")
        rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1")
        assert rec2 is not None
        assert rec2.proxy_call_id == rec1.proxy_call_id

    def test_update_response(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.update_response(rec.proxy_call_id, {"result": "ok"})
        found = ledger.get(rec.proxy_call_id)
        # Fail with a clear assertion rather than an AttributeError on None.
        assert found is not None
        assert found.last_response == {"result": "ok"}


class TestResolveFinalAmount:
    """``_resolve_final_amount`` reads the billed amount from a query response."""

    @staticmethod
    def _route(**usage_config) -> QueryRouteConfig:
        """Build a query route that differs only in its usage-path settings."""
        return QueryRouteConfig(
            path_pattern="/openapi/v2/query",
            request_task_id_jsonpath="taskId",
            status_jsonpath="status",
            success_values=["SUCCESS"],
            failure_values=["FAILED"],
            **usage_config,
        )

    def test_sum_multiple_usage_paths(self):
        route = self._route(
            usage_jsonpaths=["usage.thirdPartyConsumeMoney", "usage.consumeMoney"],
        )
        # Only one of the two configured paths carries a value here, so this
        # pins that a None entry is skipped and a string amount is parsed; it
        # does not exercise adding two non-null amounts together.
        resp_json = {
            "usage": {
                "thirdPartyConsumeMoney": None,
                "consumeMoney": "0.099",
            }
        }
        amount = _resolve_final_amount(resp_json, route)
        # Tolerant comparison: the amount is parsed from a decimal string, so
        # exact float equality is brittle.
        assert math.isclose(amount, 0.099)

    def test_fallback_to_legacy_single_usage_path(self):
        route = self._route(usage_jsonpath="usage.thirdPartyConsumeMoney")
        resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}}
        amount = _resolve_final_amount(resp_json, route)
        assert math.isclose(amount, 1.5)


class TestExtractUsageTokens:
    """Token counts pulled from a response body's ``usage`` object."""

    def test_prefers_openai_usage_keys(self):
        usage = {"prompt_tokens": 123, "completion_tokens": 45}
        prompt, completion = _extract_usage_tokens({"usage": usage})
        assert prompt == 123
        assert completion == 45

    def test_supports_generic_usage_keys(self):
        # Counts given as strings are expected back as integers.
        usage = {"input_tokens": "88", "output_tokens": "12"}
        prompt, completion = _extract_usage_tokens({"usage": usage})
        assert prompt == 88
        assert completion == 12


class TestExtractUsageTokensFromSubmitStream:
    """Token counts recovered from a buffered SSE (``data: ...``) stream body."""

    @staticmethod
    def _sse(*events: bytes) -> bytes:
        """Terminate each ``data:`` line with the blank-line event separator."""
        return b"".join(event + b"\n\n" for event in events)

    def test_extracts_usage_from_final_sse_chunk(self):
        body = self._sse(
            b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}',
            b'data: {"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}',
            b"data: [DONE]",
        )
        prompt, completion = _extract_usage_tokens_from_submit_stream(body)
        assert prompt == 22
        assert completion == 17

    def test_returns_zero_when_no_usage_found(self):
        body = self._sse(b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}')
        prompt, completion = _extract_usage_tokens_from_submit_stream(body)
        assert prompt == 0
        assert completion == 0


class TestApiKeyMarkerReplacement:
    """The proxy swaps ``API_KEY_MARKER`` for the real key in headers and bodies."""

    def test_replace_marker_in_headers(self):
        headers = {"Authorization": f"Bearer {API_KEY_MARKER}", "Content-Type": "application/json"}
        replaced = _replace_api_key_marker_in_headers(headers, "real-key")
        assert replaced["Authorization"] == "Bearer real-key"

    def test_replace_marker_in_json_body(self):
        headers = {"Content-Type": "application/json"}
        # Build the payload from the imported constant (as the header test does)
        # instead of a hard-coded literal, so the test tracks the marker's real
        # value. Compact separators match the substrings asserted below.
        body = json.dumps(
            {"apiKey": API_KEY_MARKER, "nested": {"token": f"Bearer {API_KEY_MARKER}"}},
            separators=(",", ":"),
        ).encode("utf-8")
        replaced = _replace_api_key_marker_in_body(headers, body, "real-key")
        assert b'"apiKey":"real-key"' in replaced
        assert b'"token":"Bearer real-key"' in replaced
        # Both occurrences were replaced, so no marker may remain.
        assert API_KEY_MARKER.encode("utf-8") not in replaced