193 lines
6.7 KiB
Python
193 lines
6.7 KiB
Python
"""Unit tests for the third-party proxy module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from app.gateway.third_party_proxy.ledger import CallLedger
|
|
from app.gateway.third_party_proxy.proxy import (
|
|
_path_matches,
|
|
jsonpath_get,
|
|
match_query_route,
|
|
match_submit_route,
|
|
)
|
|
from deerflow.config.third_party_proxy_config import (
|
|
QueryRouteConfig,
|
|
SubmitRouteConfig,
|
|
ThirdPartyProviderConfig,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _path_matches
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPathMatches:
    """Exact and ``/**`` glob matching performed by ``_path_matches``."""

    # Shared fixtures: a concrete endpoint path and a prefix glob covering it.
    _QUERY_PATH = "/openapi/v2/query"
    _V2_GLOB = "/openapi/v2/**"

    def test_exact_match(self):
        path = self._QUERY_PATH
        assert _path_matches(path, path)

    def test_exact_no_match(self):
        matched = _path_matches(self._QUERY_PATH, "/openapi/v2/submit")
        assert not matched

    def test_glob_matches_prefix(self):
        assert _path_matches("/openapi/v2/vidu/submit", self._V2_GLOB)

    def test_glob_matches_prefix_itself(self):
        # The bare prefix, with no further segment, also satisfies the glob.
        assert _path_matches("/openapi/v2", self._V2_GLOB)

    def test_glob_no_match_different_prefix(self):
        assert not _path_matches("/other/v2/submit", self._V2_GLOB)

    def test_trailing_slashes_normalised(self):
        # A trailing slash on the request path must not defeat an exact pattern.
        assert _path_matches(self._QUERY_PATH + "/", self._QUERY_PATH)

    def test_glob_excludes_sibling_prefix(self):
        # "/openapi/v2/**" must not match "/openapi/v2extra/foo", even though
        # the two share a raw string prefix.
        assert not _path_matches("/openapi/v2extra/foo", self._V2_GLOB)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# jsonpath_get
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestJsonpathGet:
    """Dotted-key lookups performed by ``jsonpath_get``."""

    def test_single_key(self):
        payload = {"taskId": "abc"}
        assert jsonpath_get(payload, "taskId") == "abc"

    def test_nested_key(self):
        payload = {"usage": {"thirdPartyConsumeMoney": 1.23}}
        value = jsonpath_get(payload, "usage.thirdPartyConsumeMoney")
        assert value == 1.23

    def test_missing_key_returns_none(self):
        assert jsonpath_get({"foo": "bar"}, "taskId") is None

    def test_rejects_dollar_prefixed_path(self):
        # A root-anchored "$." path resolves to None; the prefix is not stripped.
        assert jsonpath_get({"taskId": "abc"}, "$.taskId") is None

    def test_short_path_supported(self):
        assert jsonpath_get({"x": 1}, "x") == 1

    def test_non_dict_intermediate(self):
        # Descending through a non-mapping value yields None instead of raising.
        payload = {"usage": "not-a-dict"}
        assert jsonpath_get(payload, "usage.something") is None

    def test_none_input(self):
        assert jsonpath_get(None, "x") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# match_submit_route / match_query_route
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Shared provider fixture for the routing tests below.
# Every POST under /openapi/v2/ is a submit route except /openapi/v2/query,
# which is carved out by exclude_path_pattern and configured as the query route.
# Both routes read the task id from the top-level "taskId" key.
_PROVIDER_CFG = ThirdPartyProviderConfig(
    base_url="https://example.com",
    api_key_env="TEST_API_KEY",
    submit_routes=[
        SubmitRouteConfig(
            method="POST",
            path_pattern="/openapi/v2/**",
            exclude_path_pattern="/openapi/v2/query",
            task_id_jsonpath="taskId",
        ),
    ],
    query_routes=[
        QueryRouteConfig(
            method="POST",
            path_pattern="/openapi/v2/query",
            request_task_id_jsonpath="taskId",
            status_jsonpath="status",
            success_values=["SUCCESS"],
            failure_values=["FAILED", "CANCELLED"],
            usage_jsonpath="usage.thirdPartyConsumeMoney",
        ),
    ],
)
|
|
|
|
|
|
class TestMatchRoutes:
    """Submit/query route selection against the shared ``_PROVIDER_CFG``."""

    def test_submit_matches_non_query_path(self):
        route = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit")
        assert route is not None
        assert route.task_id_jsonpath == "taskId"

    def test_submit_excluded_by_exclude_pattern(self):
        # The query path sits under the submit glob but is excluded explicitly.
        assert match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") is None

    def test_submit_wrong_method(self):
        assert match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit") is None

    def test_query_matches(self):
        route = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query")
        assert route is not None
        assert route.status_jsonpath == "status"

    def test_query_wrong_method(self):
        assert match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CallLedger
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCallLedger:
    """Record creation, state transitions and lookups on ``CallLedger``."""

    def _make_ledger(self) -> CallLedger:
        """Return a new ``CallLedger`` so each test works on its own instance."""
        return CallLedger()

    def test_create_and_get(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert rec.provider == "prov"
        found = ledger.get(rec.proxy_call_id)
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_set_reserved(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_reserved(rec.proxy_call_id, "frozen-123")
        found = ledger.get(rec.proxy_call_id)
        # Guard first so a missing record fails as an assertion rather than
        # an AttributeError on None.
        assert found is not None
        assert found.frozen_id == "frozen-123"
        assert found.billing_state == "RESERVED"

    def test_set_running(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.set_running(rec.proxy_call_id, "task-abc")
        found = ledger.get_by_task_id("prov", "task-abc")
        assert found is not None
        assert found.proxy_call_id == rec.proxy_call_id

    def test_try_claim_finalize_once(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        # First claim should succeed
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        # Second claim should fail — already in progress/done
        assert ledger.try_claim_finalize(rec.proxy_call_id) is False

    def test_is_finalized(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        assert ledger.is_finalized(rec.proxy_call_id) is False
        # Assert the claim succeeds rather than silently discarding its result;
        # otherwise the set_finalized call below is exercised on an unknown state.
        assert ledger.try_claim_finalize(rec.proxy_call_id) is True
        ledger.set_finalized(rec.proxy_call_id, "SUCCESS")
        assert ledger.is_finalized(rec.proxy_call_id) is True

    def test_idempotency_key_dedup(self):
        ledger = self._make_ledger()
        rec1 = ledger.create("prov", "tid", "idem-key-1")
        rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1")
        assert rec2 is not None
        assert rec2.proxy_call_id == rec1.proxy_call_id

    def test_update_response(self):
        ledger = self._make_ledger()
        rec = ledger.create("prov", "tid", None)
        ledger.update_response(rec.proxy_call_id, {"result": "ok"})
        found = ledger.get(rec.proxy_call_id)
        # Same guard as test_set_reserved: fail clearly if the record vanished.
        assert found is not None
        assert found.last_response == {"result": "ok"}
|