"""In-memory call state ledger for the third-party proxy.

Tracks each proxied call from reserve → submit → query → finalize,
enforcing idempotency and ensuring the billing finalize step is claimed
at most once per call. All state is held in process memory only.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Literal
|
|
from uuid import uuid4
|
|
|
|
# Module-level logger shared by the ledger and the singleton accessor.
logger = logging.getLogger(__name__)

# Billing lifecycle values stored in CallRecord.billing_state.
BillingState = Literal[
    "UNRESERVED",
    "RESERVED",
    "FINALIZED",
    "FINALIZE_FAILED",
]

# Provider task lifecycle values stored in CallRecord.task_state.
TaskState = Literal[
    "PENDING",
    "RUNNING",
    "SUCCESS",
    "FAILED",
    "UNKNOWN",
]
|
|
|
|
|
|
@dataclass
class CallRecord:
    """Mutable state of one proxied third-party call.

    Created by ``CallLedger.create`` and updated in place by the ledger's
    ``set_*`` / ``try_claim_finalize`` / ``update_response`` methods.
    """

    # Ledger-local identifier; CallLedger.create fills it with a uuid4 string.
    proxy_call_id: str
    # Provider name; the ledger also uses it to scope its task/idempotency indexes.
    provider: str
    # Opaque identifier passed to CallLedger.create (may be None); not interpreted here.
    thread_id: str | None
    # call_id is sent to the billing platform (callId in reserve payload)
    call_id: str
    # Set by CallLedger.set_reserved; presumably the billing hold/freeze id — TODO confirm.
    frozen_id: str | None = None
    # Provider-side task id; set (and indexed) by CallLedger.set_running.
    provider_task_id: str | None = None
    # Billing lifecycle; starts UNRESERVED.
    billing_state: BillingState = "UNRESERVED"
    # Provider task lifecycle; starts PENDING.
    task_state: TaskState = "PENDING"
    # time.time() (epoch seconds) captured when the record is constructed.
    created_at: float = field(default_factory=time.time)
    # time.time() recorded by set_finalized / set_finalize_failed; None until then.
    finalized_at: float | None = None
    # Not written by CallLedger; presumably set by callers on failure — TODO confirm.
    error: str | None = None
    # Client-supplied idempotency key given to CallLedger.create, if any.
    idempotency_key: str | None = None
    # Cached last provider response — returned for repeat queries after finalization
    last_response: dict[str, Any] | None = None
|
|
|
|
|
|
class CallLedger:
    """Thread-safe in-memory ledger for third-party proxy call records.

    Every public method holds ``self._lock`` (a non-reentrant
    ``threading.Lock``) while it touches the internal dictionaries or a
    record's state, so lookups and state transitions are atomic with respect
    to each other. Internal helpers suffixed ``_locked`` expect the caller to
    already hold the lock.

    Records are kept only in process memory and are never removed by this
    class, so the ledger grows for the lifetime of the process.

    The returned ``CallRecord`` objects are the live, shared instances. Reads
    of their fields made by callers outside this class are not synchronized
    with the mutators below.
    """

    # billing_state values after which finalize can no longer be claimed.
    _TERMINAL_BILLING_STATES = ("FINALIZED", "FINALIZE_FAILED")

    def __init__(self) -> None:
        self._records: dict[str, CallRecord] = {}  # proxy_call_id → record
        self._task_index: dict[str, str] = {}  # "{provider}:{provider_task_id}" → proxy_call_id
        self._idem_index: dict[str, str] = {}  # "{provider}:{idem_key}" → proxy_call_id
        # Non-reentrant: code holding it must only call the *_locked helpers.
        self._lock = threading.Lock()

    def create(
        self,
        provider: str,
        thread_id: str | None,
        idempotency_key: str | None = None,
    ) -> CallRecord:
        """Create a new call record, or return the existing one if idempotency key matches.

        Args:
            provider: Provider name; scopes the idempotency key.
            thread_id: Opaque identifier stored on the record (may be None).
            idempotency_key: Optional key. A truthy key already seen for the
                same provider returns the previously created record unchanged.
                ``None`` or ``""`` always creates a new record.

        Returns:
            The existing record on an idempotent hit. Otherwise a new record
            with fresh uuid4 ``proxy_call_id`` and ``call_id``, in its default
            states (UNRESERVED / PENDING).
        """
        with self._lock:
            if idempotency_key:
                existing = self._get_by_idem_key_locked(provider, idempotency_key)
                if existing is not None:
                    logger.info(
                        "[ThirdPartyProxy][Ledger] idempotent hit: provider=%s proxy_call_id=%s idem_key=%s",
                        provider,
                        existing.proxy_call_id,
                        idempotency_key,
                    )
                    return existing

            record = CallRecord(
                proxy_call_id=str(uuid4()),
                provider=provider,
                thread_id=thread_id,
                call_id=str(uuid4()),
                idempotency_key=idempotency_key,
            )
            self._records[record.proxy_call_id] = record
            if idempotency_key:
                self._idem_index[self._index_key(provider, idempotency_key)] = record.proxy_call_id
            logger.info(
                "[ThirdPartyProxy][Ledger] created record: provider=%s proxy_call_id=%s call_id=%s thread_id=%s",
                provider,
                record.proxy_call_id,
                record.call_id,
                thread_id,
            )
            return record

    def get(self, proxy_call_id: str) -> CallRecord | None:
        """Return the record for ``proxy_call_id``, or None if it is unknown."""
        with self._lock:
            return self._records.get(proxy_call_id)

    def get_by_task_id(self, provider: str, provider_task_id: str) -> CallRecord | None:
        """Return the record registered by ``set_running`` for this provider task, or None."""
        with self._lock:
            proxy_call_id = self._task_index.get(self._index_key(provider, provider_task_id))
            return self._records.get(proxy_call_id) if proxy_call_id else None

    def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
        """Return the record created with this provider/idempotency key, or None."""
        # The helper requires the lock to be held by the caller.
        with self._lock:
            return self._get_by_idem_key_locked(provider, idempotency_key)

    def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None:
        """Store ``frozen_id`` and move billing_state to RESERVED. Unknown ids are ignored."""
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_reserved ignored for missing record: proxy_call_id=%s",
                    proxy_call_id,
                )
                return
            record.frozen_id = frozen_id
            record.billing_state = "RESERVED"
            logger.info(
                "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s",
                proxy_call_id,
                frozen_id,
            )

    def set_running(self, proxy_call_id: str, provider_task_id: str) -> None:
        """Record the provider task id, mark the task RUNNING and index it.

        After this call, ``get_by_task_id(record.provider, provider_task_id)``
        resolves to the record. Unknown ids are ignored.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_running ignored for missing record: proxy_call_id=%s provider_task_id=%s",
                    proxy_call_id,
                    provider_task_id,
                )
                return
            record.provider_task_id = provider_task_id
            record.task_state = "RUNNING"
            self._task_index[self._index_key(record.provider, provider_task_id)] = proxy_call_id
            logger.info(
                "[ThirdPartyProxy][Ledger] running: proxy_call_id=%s provider_task_id=%s",
                proxy_call_id,
                provider_task_id,
            )

    def try_claim_finalize(self, proxy_call_id: str) -> bool:
        """Atomically claim finalization rights. Returns True only once per record.

        The first call on a record whose billing_state is not FINALIZED or
        FINALIZE_FAILED wins. Records that are still UNRESERVED can be claimed
        as well (NOTE(review): confirm callers handle a missing frozen_id).

        The claim itself sets billing_state to FINALIZED without setting
        ``finalized_at``. As a result:

        * ``is_finalized`` reports True while the finalize is still in flight.
        * If the claimant never calls ``set_finalized`` or
          ``set_finalize_failed``, the record stays FINALIZED with
          ``finalized_at`` set to None.

        A record that ends in FINALIZE_FAILED cannot be claimed again.

        Returns:
            True if this caller now owns finalization. False if the record is
            unknown or already in a terminal billing state.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] finalize claim denied: missing record proxy_call_id=%s",
                    proxy_call_id,
                )
                return False
            if record.billing_state in self._TERMINAL_BILLING_STATES:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] finalize claim denied: proxy_call_id=%s billing_state=%s",
                    proxy_call_id,
                    record.billing_state,
                )
                return False
            # Mark as finalized immediately to prevent concurrent finalize
            record.billing_state = "FINALIZED"
            logger.info(
                "[ThirdPartyProxy][Ledger] finalize claimed: proxy_call_id=%s",
                proxy_call_id,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalize claim state: call_id=%s provider=%s task_state=%s frozen_id=%s",
                record.call_id,
                record.provider,
                record.task_state,
                record.frozen_id,
            )
            return True

    def set_finalized(self, proxy_call_id: str, task_state: TaskState) -> None:
        """Mark billing FINALIZED, store ``task_state`` and stamp ``finalized_at``.

        Does not check for a prior ``try_claim_finalize``. Unknown ids are ignored.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_finalized ignored for missing record: proxy_call_id=%s task_state=%s",
                    proxy_call_id,
                    task_state,
                )
                return
            record.task_state = task_state
            record.billing_state = "FINALIZED"
            record.finalized_at = time.time()
            logger.info(
                "[ThirdPartyProxy][Ledger] finalized: proxy_call_id=%s task_state=%s",
                proxy_call_id,
                task_state,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalized state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
                record.provider,
                record.call_id,
                record.frozen_id,
                record.finalized_at,
            )

    def set_finalize_failed(self, proxy_call_id: str, task_state: TaskState) -> None:
        """Mark billing FINALIZE_FAILED, store ``task_state`` and stamp ``finalized_at``.

        FINALIZE_FAILED is terminal: ``try_claim_finalize`` refuses further
        claims. Unknown ids are ignored.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_finalize_failed ignored for missing record: proxy_call_id=%s task_state=%s",
                    proxy_call_id,
                    task_state,
                )
                return
            record.task_state = task_state
            record.billing_state = "FINALIZE_FAILED"
            record.finalized_at = time.time()
            logger.info(
                "[ThirdPartyProxy][Ledger] finalize failed: proxy_call_id=%s task_state=%s",
                proxy_call_id,
                task_state,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalize failure state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
                record.provider,
                record.call_id,
                record.frozen_id,
                record.finalized_at,
            )

    def update_response(self, proxy_call_id: str, response: dict[str, Any]) -> None:
        """Cache ``response`` as the record's ``last_response``. Unknown ids are ignored.

        The dict is stored by reference, not copied.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] update_response ignored for missing record: proxy_call_id=%s",
                    proxy_call_id,
                )
                return
            record.last_response = response
            logger.debug(
                "[ThirdPartyProxy][Ledger] cached response: proxy_call_id=%s keys=%s",
                proxy_call_id,
                sorted(response.keys()),
            )

    def is_finalized(self, proxy_call_id: str) -> bool:
        """Return True if the record exists and its billing_state is terminal.

        This is already True right after a successful ``try_claim_finalize``,
        i.e. before the finalize outcome has been recorded.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            return record is not None and record.billing_state in self._TERMINAL_BILLING_STATES

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _index_key(provider: str, key: str) -> str:
        """Build the composite ``"{provider}:{key}"`` key used by both lookup indexes."""
        return f"{provider}:{key}"

    def _get_by_idem_key_locked(self, provider: str, idempotency_key: str) -> CallRecord | None:
        """Look up a record by provider/idempotency key. Caller must hold ``self._lock``."""
        proxy_call_id = self._idem_index.get(self._index_key(provider, idempotency_key))
        return self._records.get(proxy_call_id) if proxy_call_id else None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level singleton
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Lazily created process-wide ledger; guarded by _ledger_lock on first creation.
_ledger: CallLedger | None = None
_ledger_lock = threading.Lock()


def get_ledger() -> CallLedger:
    """Return the shared ``CallLedger``, creating it on first use.

    The unlocked fast path returns an already-built instance. Creation is
    re-checked under ``_ledger_lock`` so that concurrent first callers all
    receive the same instance.
    """
    global _ledger
    current = _ledger
    if current is not None:
        return current
    with _ledger_lock:
        # Another thread may have created it while we waited for the lock.
        if _ledger is None:
            _ledger = CallLedger()
            logger.info("[ThirdPartyProxy][Ledger] singleton initialized")
        return _ledger
|