"""Unit tests for the third-party proxy module.""" from __future__ import annotations from app.gateway.third_party_proxy.ledger import CallLedger from app.gateway.routers.third_party import ( _extract_usage_tokens, _extract_usage_tokens_from_submit_stream, _resolve_final_amount, ) from app.gateway.third_party_proxy.proxy import ( API_KEY_MARKER, _path_matches, _replace_api_key_marker_in_body, _replace_api_key_marker_in_headers, jsonpath_get, match_query_route, match_submit_route, ) from deerflow.config.third_party_proxy_config import ( QueryRouteConfig, SubmitRouteConfig, ThirdPartyProviderConfig, ) # --------------------------------------------------------------------------- # _path_matches # --------------------------------------------------------------------------- class TestPathMatches: def test_exact_match(self): assert _path_matches("/openapi/v2/query", "/openapi/v2/query") def test_exact_no_match(self): assert not _path_matches("/openapi/v2/query", "/openapi/v2/submit") def test_glob_matches_prefix(self): assert _path_matches("/openapi/v2/vidu/submit", "/openapi/v2/**") def test_glob_matches_prefix_itself(self): assert _path_matches("/openapi/v2", "/openapi/v2/**") def test_glob_no_match_different_prefix(self): assert not _path_matches("/other/v2/submit", "/openapi/v2/**") def test_trailing_slashes_normalised(self): assert _path_matches("/openapi/v2/query/", "/openapi/v2/query") def test_glob_excludes_sibling_prefix(self): # /openapi/v2/** should not match /openapi/v2extra/foo assert not _path_matches("/openapi/v2extra/foo", "/openapi/v2/**") # --------------------------------------------------------------------------- # jsonpath_get # --------------------------------------------------------------------------- class TestJsonpathGet: def test_single_key(self): assert jsonpath_get({"taskId": "abc"}, "taskId") == "abc" def test_nested_key(self): data = {"usage": {"thirdPartyConsumeMoney": 1.23}} assert jsonpath_get(data, "usage.thirdPartyConsumeMoney") == 1.23 def test_missing_key_returns_none(self): assert jsonpath_get({"foo": "bar"}, "taskId") is None def test_rejects_dollar_prefixed_path(self): assert jsonpath_get({"taskId": "abc"}, "$.taskId") is None def test_short_path_supported(self): assert jsonpath_get({"x": 1}, "x") == 1 def test_non_dict_intermediate(self): data = {"usage": "not-a-dict"} assert jsonpath_get(data, "usage.something") is None def test_none_input(self): assert jsonpath_get(None, "x") is None # --------------------------------------------------------------------------- # match_submit_route / match_query_route # --------------------------------------------------------------------------- _PROVIDER_CFG = ThirdPartyProviderConfig( base_url="https://example.com", api_key_env="TEST_API_KEY", submit_routes=[ SubmitRouteConfig( method="POST", path_pattern="/openapi/v2/**", exclude_path_pattern="/openapi/v2/query", task_id_jsonpath="taskId", ) ], query_routes=[ QueryRouteConfig( method="POST", path_pattern="/openapi/v2/query", request_task_id_jsonpath="taskId", status_jsonpath="status", success_values=["SUCCESS"], failure_values=["FAILED", "CANCELLED"], usage_jsonpath="usage.thirdPartyConsumeMoney", usage_jsonpaths=["usage.thirdPartyConsumeMoney", "usage.consumeMoney"], ) ], ) class TestMatchRoutes: def test_submit_matches_non_query_path(self): result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/vidu/submit") assert result is not None assert result.task_id_jsonpath == "taskId" def test_submit_excluded_by_exclude_pattern(self): result = match_submit_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") assert result is None def test_submit_wrong_method(self): result = match_submit_route(_PROVIDER_CFG, "GET", "/openapi/v2/vidu/submit") assert result is None def test_query_matches(self): result = match_query_route(_PROVIDER_CFG, "POST", "/openapi/v2/query") assert result is not None assert result.status_jsonpath == "status" def test_query_wrong_method(self): result = match_query_route(_PROVIDER_CFG, "GET", "/openapi/v2/query") assert result is None # --------------------------------------------------------------------------- # CallLedger # --------------------------------------------------------------------------- class TestCallLedger: def _make_ledger(self) -> CallLedger: return CallLedger() def test_create_and_get(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) assert rec.provider == "prov" found = ledger.get(rec.proxy_call_id) assert found is not None assert found.proxy_call_id == rec.proxy_call_id def test_set_reserved(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) ledger.set_reserved(rec.proxy_call_id, "frozen-123") found = ledger.get(rec.proxy_call_id) assert found.frozen_id == "frozen-123" assert found.billing_state == "RESERVED" def test_set_running(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) ledger.set_running(rec.proxy_call_id, "task-abc") found = ledger.get_by_task_id("prov", "task-abc") assert found is not None assert found.proxy_call_id == rec.proxy_call_id def test_try_claim_finalize_once(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) # First claim should succeed assert ledger.try_claim_finalize(rec.proxy_call_id) is True # Second claim should fail — already in progress/done assert ledger.try_claim_finalize(rec.proxy_call_id) is False def test_is_finalized(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) assert ledger.is_finalized(rec.proxy_call_id) is False ledger.try_claim_finalize(rec.proxy_call_id) ledger.set_finalized(rec.proxy_call_id, "SUCCESS") assert ledger.is_finalized(rec.proxy_call_id) is True def test_idempotency_key_dedup(self): ledger = self._make_ledger() rec1 = ledger.create("prov", "tid", "idem-key-1") rec2 = ledger.get_by_idempotency_key("prov", "idem-key-1") assert rec2 is not None assert rec2.proxy_call_id == rec1.proxy_call_id def test_update_response(self): ledger = self._make_ledger() rec = ledger.create("prov", "tid", None) ledger.update_response(rec.proxy_call_id, {"result": "ok"}) found = ledger.get(rec.proxy_call_id) assert found.last_response == {"result": "ok"} class TestResolveFinalAmount: def test_sum_multiple_usage_paths(self): route = QueryRouteConfig( path_pattern="/openapi/v2/query", request_task_id_jsonpath="taskId", status_jsonpath="status", success_values=["SUCCESS"], failure_values=["FAILED"], usage_jsonpaths=["usage.thirdPartyConsumeMoney", "usage.consumeMoney"], ) resp_json = { "usage": { "thirdPartyConsumeMoney": None, "consumeMoney": "0.099", } } amount = _resolve_final_amount(resp_json, route) assert amount == 0.099 def test_fallback_to_legacy_single_usage_path(self): route = QueryRouteConfig( path_pattern="/openapi/v2/query", request_task_id_jsonpath="taskId", status_jsonpath="status", success_values=["SUCCESS"], failure_values=["FAILED"], usage_jsonpath="usage.thirdPartyConsumeMoney", ) resp_json = {"usage": {"thirdPartyConsumeMoney": "1.5"}} amount = _resolve_final_amount(resp_json, route) assert amount == 1.5 class TestExtractUsageTokens: def test_prefers_openai_usage_keys(self): resp_json = { "usage": { "prompt_tokens": 123, "completion_tokens": 45, } } input_tokens, output_tokens = _extract_usage_tokens(resp_json) assert input_tokens == 123 assert output_tokens == 45 def test_supports_generic_usage_keys(self): resp_json = { "usage": { "input_tokens": "88", "output_tokens": "12", } } input_tokens, output_tokens = _extract_usage_tokens(resp_json) assert input_tokens == 88 assert output_tokens == 12 class TestExtractUsageTokensFromSubmitStream: def test_extracts_usage_from_final_sse_chunk(self): body = ( b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n' b'data: {"id":"x","choices":[],"usage":{"prompt_tokens":22,"completion_tokens":17}}\n\n' b'data: [DONE]\n\n' ) input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body) assert input_tokens == 22 assert output_tokens == 17 def test_returns_zero_when_no_usage_found(self): body = b'data: {"id":"x","choices":[{"delta":{"content":"hello"}}]}\n\n' input_tokens, output_tokens = _extract_usage_tokens_from_submit_stream(body) assert input_tokens == 0 assert output_tokens == 0 class TestApiKeyMarkerReplacement: def test_replace_marker_in_headers(self): headers = {"Authorization": f"Bearer {API_KEY_MARKER}", "Content-Type": "application/json"} replaced = _replace_api_key_marker_in_headers(headers, "real-key") assert replaced["Authorization"] == "Bearer real-key" def test_replace_marker_in_json_body(self): headers = {"Content-Type": "application/json"} body = ( b'{"apiKey":"__API_KEY_MARKER__","nested":{"token":"Bearer __API_KEY_MARKER__"}}' ) replaced = _replace_api_key_marker_in_body(headers, body, "real-key") assert b'"apiKey":"real-key"' in replaced assert b'"token":"Bearer real-key"' in replaced