293 lines
10 KiB
Python
293 lines
10 KiB
Python
"""Unit tests for the third-party proxy module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from app.gateway.third_party_proxy.ledger import CallLedger
|
|
from app.gateway.routers.third_party import (
|
|
_extract_usage_tokens,
|
|
_extract_usage_tokens_from_submit_stream,
|
|
_resolve_final_amount,
|
|
)
|
|
from app.gateway.third_party_proxy.proxy import (
|
|
API_KEY_MARKER,
|
|
_path_matches,
|
|
_replace_api_key_marker_in_body,
|
|
_replace_api_key_marker_in_headers,
|
|
jsonpath_get,
|
|
match_query_route,
|
|
match_submit_route,
|
|
)
|
|
from deerflow.config.third_party_proxy_config import (
|
|
QueryRouteConfig,
|
|
SubmitRouteConfig,
|
|
ThirdPartyProviderConfig,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _path_matches
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPathMatches:
    """Exact and ``/**`` glob path matching performed by ``_path_matches``."""

    V2_GLOB = "/openapi/v2/**"

    def test_exact_match(self):
        target = "/openapi/v2/query"
        assert _path_matches(target, target)

    def test_exact_no_match(self):
        pattern = "/openapi/v2/submit"
        assert not _path_matches("/openapi/v2/query", pattern)

    def test_glob_matches_prefix(self):
        assert _path_matches("/openapi/v2/vidu/submit", self.V2_GLOB)

    def test_glob_matches_prefix_itself(self):
        # The bare prefix, with no trailing segment, is still covered by the glob.
        assert _path_matches("/openapi/v2", self.V2_GLOB)

    def test_glob_no_match_different_prefix(self):
        assert not _path_matches("/other/v2/submit", self.V2_GLOB)

    def test_trailing_slashes_normalised(self):
        # A trailing slash on the request path must not defeat an exact match.
        request_path = "/openapi/v2/query/"
        assert _path_matches(request_path, "/openapi/v2/query")

    def test_glob_excludes_sibling_prefix(self):
        # The glob must respect segment boundaries: "/openapi/v2extra" only
        # shares a string prefix with "/openapi/v2" and must not match.
        assert not _path_matches("/openapi/v2extra/foo", self.V2_GLOB)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# jsonpath_get
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestJsonpathGet:
    """Dot-separated key lookups performed by ``jsonpath_get``."""

    def test_single_key(self):
        payload = {"taskId": "abc"}
        assert jsonpath_get(payload, "taskId") == "abc"

    def test_nested_key(self):
        payload = {"usage": {"thirdPartyConsumeMoney": 1.23}}
        value = jsonpath_get(payload, "usage.thirdPartyConsumeMoney")
        assert value == 1.23

    def test_missing_key_returns_none(self):
        payload = {"foo": "bar"}
        assert jsonpath_get(payload, "taskId") is None

    def test_rejects_dollar_prefixed_path(self):
        # A JSONPath-style "$." root is not understood and resolves to None.
        payload = {"taskId": "abc"}
        assert jsonpath_get(payload, "$.taskId") is None

    def test_short_path_supported(self):
        assert jsonpath_get({"x": 1}, "x") == 1

    def test_non_dict_intermediate(self):
        # Descending through a non-mapping value yields None rather than raising.
        payload = {"usage": "not-a-dict"}
        assert jsonpath_get(payload, "usage.something") is None

    def test_none_input(self):
        assert jsonpath_get(None, "x") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# match_submit_route / match_query_route
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Shared provider fixture:
# - one POST submit route covering everything under /openapi/v2/**, with
#   /openapi/v2/query excluded;
# - one POST query route on /openapi/v2/query.
_PROVIDER_CFG = ThirdPartyProviderConfig(
    api_key_env="TEST_API_KEY",
    base_url="https://example.com",
    submit_routes=[
        SubmitRouteConfig(
            method="POST",
            path_pattern="/openapi/v2/**",
            exclude_path_pattern="/openapi/v2/query",
            task_id_jsonpath="taskId",
        ),
    ],
    query_routes=[
        QueryRouteConfig(
            method="POST",
            path_pattern="/openapi/v2/query",
            request_task_id_jsonpath="taskId",
            status_jsonpath="status",
            success_values=["SUCCESS"],
            failure_values=["FAILED", "CANCELLED"],
            # Sets both the single ``usage_jsonpath`` and the
            # ``usage_jsonpaths`` list.
            usage_jsonpath="usage.thirdPartyConsumeMoney",
            usage_jsonpaths=[
                "usage.thirdPartyConsumeMoney",
                "usage.consumeMoney",
            ],
        ),
    ],
)
|
|
|
|
|
|
class TestMatchRoutes:
    """Route selection against ``_PROVIDER_CFG`` by HTTP method and path."""

    def test_submit_matches_non_query_path(self):
        route = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit")
        assert route is not None
        assert route.task_id_jsonpath == "taskId"

    def test_submit_excluded_by_exclude_pattern(self):
        # The query endpoint falls under the submit glob but is carved out by
        # exclude_path_pattern, so no submit route applies to it.
        assert match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") is None

    def test_submit_wrong_method(self):
        assert match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit") is None

    def test_query_matches(self):
        route = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
        assert route is not None
        assert route.status_jsonpath == "status"

    def test_query_wrong_method(self):
        assert match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CallLedger
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCallLedger:
    """Record lifecycle and lookups on ``CallLedger``."""

    def _make_ledger(self) -> CallLedger:
        # A new instance per test keeps cases independent (assuming CallLedger
        # holds no class-level shared state — TODO confirm).
        return CallLedger()

    def test_create_and_get(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert rec.provider == "prov"
        found = ledger.get(rec.proxy_call_id)
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_set_reserved(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_reserved(rec.proxy_call_id, "frozen-123")
        found = ledger.get(rec.proxy_call_id)
        # Guard first so a missing record fails as an assertion, not AttributeError.
        assert found is not None
        assert found.frozen_id == "frozen-123"
        assert found.billing_state == "RESERVED"

    def test_set_running(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_running(rec.proxy_call_id, "task-abc")
        found = ledger.get_by_task_id("prov", "task-abc")
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_try_claim_finalize_once(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        # First claim should succeed
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        # Second claim should fail — already in progress/done
        assert ledger.try_claim_finalize(rec.proxy_call_id) is False

    def test_is_finalized(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert ledger.is_finalized(rec.proxy_call_id) is False
        # The claim must succeed, otherwise the finalize step below is not
        # exercising the normal claim -> finalize path.
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        ledger.set_finalized(rec.proxy_call_id, "SUCCESS")
        assert ledger.is_finalized(rec.proxy_call_id) is True

    def test_idempotency_key_dedup(self):
        ledger = self._make_ledger()
        rec1 = ledger.create("prov", "tid", "idem-key-1")
        rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1")
        assert rec2 is not None
        assert rec2.proxy_call_id == rec1.proxy_call_id

    def test_update_response(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.update_response(rec.proxy_call_id, {"result": "ok"})
        found = ledger.get(rec.proxy_call_id)
        assert found is not None
        assert found.last_response == {"result": "ok"}
|
|
|
|
|
|
class TestResolveFinalAmount:
    """Final billing amount extraction performed by ``_resolve_final_amount``."""

    def test_sum_multiple_usage_paths(self):
        import math

        route = QueryRouteConfig(
            path_pattern="/openapi/v2/query",
            request_task_id_jsonpath="taskId",
            status_jsonpath="status",
            success_values=["SUCCESS"],
            failure_values=["FAILED"],
            usage_jsonpaths=["usage.thirdPartyConsumeMoney", "usage.consumeMoney"],
        )
        resp_json = {
            "usage": {
                "thirdPartyConsumeMoney": None,
                "consumeMoney": "0.099",
            }
        }
        amount = _resolve_final_amount(resp_json, route)
        # Expected: the None entry contributes nothing and the string is parsed.
        # Compare with a tolerance — exact == on a summed/parsed amount breaks
        # on float rounding or if the helper returns a Decimal.
        assert amount is not None
        assert math.isclose(amount, 0.099)

    def test_fallback_to_legacy_single_usage_path(self):
        import math

        route = QueryRouteConfig(
            path_pattern="/openapi/v2/query",
            request_task_id_jsonpath="taskId",
            status_jsonpath="status",
            success_values=["SUCCESS"],
            failure_values=["FAILED"],
            usage_jsonpath="usage.thirdPartyConsumeMoney",
        )
        resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}}
        amount = _resolve_final_amount(resp_json, route)
        assert amount is not None
        assert math.isclose(amount, 1.5)
|
|
|
|
|
|
class TestExtractUsageTokens:
    """Input/output token counts read from a JSON response by ``_extract_usage_tokens``."""

    def test_prefers_openai_usage_keys(self):
        usage = {"prompt_tokens": 123, "completion_tokens": 45}
        tokens_in, tokens_out = _extract_usage_tokens({"usage": usage})
        assert tokens_in == 123
        assert tokens_out == 45

    def test_supports_generic_usage_keys(self):
        # String-valued counts are expected to come back as ints.
        usage = {"input_tokens": "88", "output_tokens": "12"}
        tokens_in, tokens_out = _extract_usage_tokens({"usage": usage})
        assert tokens_in == 88
        assert tokens_out == 12
|
|
|
|
|
|
class TestExtractUsageTokensFromSubmitStream:
    """Token usage recovered from a raw SSE body by ``_extract_usage_tokens_from_submit_stream``."""

    def test_extracts_usage_from_final_sse_chunk(self):
        frames = [
            b'{"id":"x","choices":[{"delta":{"content":"hello"}}]}',
            b'{"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}',
            b"[DONE]",
        ]
        # Each frame is framed as an SSE "data:" event terminated by a blank line.
        body = b"".join(b"data: " + frame + b"\n\n" for frame in frames)
        tokens_in, tokens_out = _extract_usage_tokens_from_submit_stream(body)
        assert tokens_in == 22
        assert tokens_out == 17

    def test_returns_zero_when_no_usage_found(self):
        body = b"data: " + b'{"id":"x","choices":[{"delta":{"content":"hello"}}]}' + b"\n\n"
        tokens_in, tokens_out = _extract_usage_tokens_from_submit_stream(body)
        assert (tokens_in, tokens_out) == (0, 0)
|
|
|
|
|
|
class TestApiKeyMarkerReplacement:
    """Substitution of ``API_KEY_MARKER`` with the real key in headers and JSON bodies."""

    def test_replace_marker_in_headers(self):
        headers = {"Authorization": f"Bearer {API_KEY_MARKER}", "Content-Type": "application/json"}
        replaced = _replace_api_key_marker_in_headers(headers, "real-key")
        assert replaced["Authorization"] == "Bearer real-key"

    def test_replace_marker_in_json_body(self):
        import json

        headers = {"Content-Type": "application/json"}
        # Build the payload from the imported constant instead of hard-coding
        # its current value, so the test keeps tracking the real marker.
        # Compact separators match the byte-level assertions below.
        payload = {
            "apiKey": API_KEY_MARKER,
            "nested": {"token": f"Bearer {API_KEY_MARKER}"},
        }
        body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
        replaced = _replace_api_key_marker_in_body(headers, body, "real-key")
        assert b'"apiKey":"real-key"' in replaced
        assert b'"token":"Bearer real-key"' in replaced
        # Every occurrence, including the nested one, must have been replaced.
        assert API_KEY_MARKER.encode("utf-8") not in replaced
|