"""In-memory call state ledger for the third-party proxy. Tracks each proxied call from reserve → submit → query → finalize, enforcing idempotency and ensuring billing finalize runs exactly once. """ from __future__ import annotations import logging import threading import time from dataclasses import dataclass, field from typing import Any, Literal from uuid import uuid4 logger = logging.getLogger(__name__) BillingState = Literal["UNRESERVED", "RESERVED", "FINALIZED", "FINALIZE_FAILED"] TaskState = Literal["PENDING", "RUNNING", "SUCCESS", "FAILED", "UNKNOWN"] @dataclass class CallRecord: proxy_call_id: str provider: str thread_id: str | None # call_id is sent to the billing platform (callId in reserve payload) call_id: str frozen_id: str | None = None provider_task_id: str | None = None billing_state: BillingState = "UNRESERVED" task_state: TaskState = "PENDING" created_at: float = field(default_factory=time.time) finalized_at: float | None = None error: str | None = None idempotency_key: str | None = None # Cached last provider response — returned for repeat queries after finalization last_response: dict[str, Any] | None = None class CallLedger: """Thread-safe in-memory ledger for third-party proxy call records.""" def __init__(self) -> None: self._records: dict[str, CallRecord] = {} # proxy_call_id → record self._task_index: dict[str, str] = {} # "{provider}:{provider_task_id}" → proxy_call_id self._idem_index: dict[str, str] = {} # "{provider}:{idem_key}" → proxy_call_id self._lock = threading.Lock() def create( self, provider: str, thread_id: str | None, idempotency_key: str | None = None, ) -> CallRecord: """Create a new call record, or return the existing one if idempotency key matches.""" with self._lock: if idempotency_key: existing = self._get_by_idem_key_locked(provider, idempotency_key) if existing is not None: logger.info( "[ThirdPartyProxy][Ledger] idempotent hit: provider=%s proxy_call_id=%s idem_key=%s", provider, existing.proxy_call_id, idempotency_key, ) # logger.debug( # "[ThirdPartyProxy][Ledger] existing record reused: call_id=%s task_id=%s billing_state=%s task_state=%s", # existing.call_id, # existing.provider_task_id, # existing.billing_state, # existing.task_state, # ) return existing record = CallRecord( proxy_call_id=str(uuid4()), provider=provider, thread_id=thread_id, call_id=str(uuid4()), idempotency_key=idempotency_key, ) self._records[record.proxy_call_id] = record if idempotency_key: self._idem_index[f"{provider}:{idempotency_key}"] = record.proxy_call_id logger.info( "[ThirdPartyProxy][Ledger] created record: provider=%s proxy_call_id=%s call_id=%s thread_id=%s", provider, record.proxy_call_id, record.call_id, thread_id, ) # logger.debug( # "[ThirdPartyProxy][Ledger] create details: idem_key=%s billing_state=%s task_state=%s", # idempotency_key, # record.billing_state, # record.task_state, # ) return record def get(self, proxy_call_id: str) -> CallRecord | None: return self._records.get(proxy_call_id) def get_by_task_id(self, provider: str, provider_task_id: str) -> CallRecord | None: key = f"{provider}:{provider_task_id}" proxy_call_id = self._task_index.get(key) return self._records.get(proxy_call_id) if proxy_call_id else None def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None: return self._get_by_idem_key_locked(provider, idempotency_key) def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.frozen_id = frozen_id record.billing_state = "RESERVED" logger.info( "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s", proxy_call_id, frozen_id, ) # logger.debug( # "[ThirdPartyProxy][Ledger] reserve state: call_id=%s provider=%s task_state=%s", # record.call_id, # record.provider, # record.task_state, # ) else: logger.debug( "[ThirdPartyProxy][Ledger] set_reserved ignored for missing record: proxy_call_id=%s", proxy_call_id, ) def set_running(self, proxy_call_id: str, provider_task_id: str) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.provider_task_id = provider_task_id record.task_state = "RUNNING" self._task_index[f"{record.provider}:{provider_task_id}"] = proxy_call_id logger.info( "[ThirdPartyProxy][Ledger] running: proxy_call_id=%s provider_task_id=%s", proxy_call_id, provider_task_id, ) # logger.debug( # "[ThirdPartyProxy][Ledger] running state: provider=%s call_id=%s billing_state=%s", # record.provider, # record.call_id, # record.billing_state, # ) else: logger.debug( "[ThirdPartyProxy][Ledger] set_running ignored for missing record: proxy_call_id=%s provider_task_id=%s", proxy_call_id, provider_task_id, ) def try_claim_finalize(self, proxy_call_id: str) -> bool: """Atomically claim finalization rights. Returns True only once per record.""" with self._lock: record = self._records.get(proxy_call_id) if record is None: logger.debug( "[ThirdPartyProxy][Ledger] finalize claim denied: missing record proxy_call_id=%s", proxy_call_id, ) return False if record.billing_state in ("FINALIZED", "FINALIZE_FAILED"): logger.debug( "[ThirdPartyProxy][Ledger] finalize claim denied: proxy_call_id=%s billing_state=%s", proxy_call_id, record.billing_state, ) return False # Mark as finalized immediately to prevent concurrent finalize record.billing_state = "FINALIZED" logger.info( "[ThirdPartyProxy][Ledger] finalize claimed: proxy_call_id=%s", proxy_call_id, ) logger.debug( "[ThirdPartyProxy][Ledger] finalize claim state: call_id=%s provider=%s task_state=%s frozen_id=%s", record.call_id, record.provider, record.task_state, record.frozen_id, ) return True def set_finalized(self, proxy_call_id: str, task_state: TaskState) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.task_state = task_state record.billing_state = "FINALIZED" record.finalized_at = time.time() logger.info( "[ThirdPartyProxy][Ledger] finalized: proxy_call_id=%s task_state=%s", proxy_call_id, task_state, ) logger.debug( "[ThirdPartyProxy][Ledger] finalized state: provider=%s call_id=%s frozen_id=%s finalized_at=%s", record.provider, record.call_id, record.frozen_id, record.finalized_at, ) else: logger.debug( "[ThirdPartyProxy][Ledger] set_finalized ignored for missing record: proxy_call_id=%s task_state=%s", proxy_call_id, task_state, ) def set_finalize_failed(self, proxy_call_id: str, task_state: TaskState) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.task_state = task_state record.billing_state = "FINALIZE_FAILED" record.finalized_at = time.time() logger.info( "[ThirdPartyProxy][Ledger] finalize failed: proxy_call_id=%s task_state=%s", proxy_call_id, task_state, ) logger.debug( "[ThirdPartyProxy][Ledger] finalize failure state: provider=%s call_id=%s frozen_id=%s finalized_at=%s", record.provider, record.call_id, record.frozen_id, record.finalized_at, ) else: logger.debug( "[ThirdPartyProxy][Ledger] set_finalize_failed ignored for missing record: proxy_call_id=%s task_state=%s", proxy_call_id, task_state, ) def update_response(self, proxy_call_id: str, response: dict[str, Any]) -> None: with self._lock: record = self._records.get(proxy_call_id) if record: record.last_response = response logger.debug( "[ThirdPartyProxy][Ledger] cached response: proxy_call_id=%s keys=%s", proxy_call_id, sorted(response.keys()), ) else: logger.debug( "[ThirdPartyProxy][Ledger] update_response ignored for missing record: proxy_call_id=%s", proxy_call_id, ) def is_finalized(self, proxy_call_id: str) -> bool: record = self._records.get(proxy_call_id) return record is not None and record.billing_state in ("FINALIZED", "FINALIZE_FAILED") # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ def _get_by_idem_key_locked(self, provider: str, idempotency_key: str) -> CallRecord | None: key = f"{provider}:{idempotency_key}" proxy_call_id = self._idem_index.get(key) return self._records.get(proxy_call_id) if proxy_call_id else None # --------------------------------------------------------------------------- # Module-level singleton # --------------------------------------------------------------------------- _ledger: CallLedger | None = None _ledger_lock = threading.Lock() def get_ledger() -> CallLedger: global _ledger if _ledger is None: with _ledger_lock: if _ledger is None: _ledger = CallLedger() logger.info("[ThirdPartyProxy][Ledger] singleton initialized") return _ledger