deerflow2/backend/app/gateway/third_party_proxy/ledger.py

290 lines
12 KiB
Python

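"""In-memory call state ledger for the third-party proxy.

Tracks each proxied call through reserve → submit → query → finalize,
enforces idempotency keys, and guarantees that the right to run the billing
finalize step is granted at most once per call. State lives only in process
memory and is not persisted.
"""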
from __future__ import annotations
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Literal
from uuid import uuid4
logger = logging.getLogger(name=__name__)

# Billing lifecycle of a proxied call: UNRESERVED at creation, RESERVED once a
# frozen_id has been recorded, then one of the two terminal states after the
# finalize step has been claimed/completed or has failed.
BillingState = Literal[
    "UNRESERVED",
    "RESERVED",
    "FINALIZED",
    "FINALIZE_FAILED",
]

# Provider-side task lifecycle as last recorded by the proxy: PENDING at
# creation, RUNNING once a provider task id is known, and a caller-supplied
# value when the call is finalized.
TaskState = Literal[
    "PENDING",
    "RUNNING",
    "SUCCESS",
    "FAILED",
    "UNKNOWN",
]
@dataclass
class CallRecord:
    """Mutable state of a single proxied third-party call.

    Only the four identifying fields are required at construction; every
    lifecycle field starts at its initial value and is updated in place as
    the call progresses.
    """

    # Ledger-internal identifier; the primary key of the ledger.
    proxy_call_id: str
    # Provider name; scopes the task-id and idempotency-key lookups.
    provider: str
    # Optional thread the call belongs to, stored verbatim.
    thread_id: str | None
    # call_id is sent to the billing platform (callId in reserve payload)
    call_id: str
    # Reservation handle recorded once billing has reserved the call.
    frozen_id: str | None = field(default=None)
    # Task identifier returned by the provider after submission.
    provider_task_id: str | None = field(default=None)
    billing_state: BillingState = field(default="UNRESERVED")
    task_state: TaskState = field(default="PENDING")
    # Wall-clock creation time (seconds since the epoch).
    created_at: float = field(default_factory=time.time)
    # Wall-clock time at which the call reached a terminal billing state.
    finalized_at: float | None = field(default=None)
    # Optional free-form error description.
    error: str | None = field(default=None)
    idempotency_key: str | None = field(default=None)
    # Cached last provider response — returned for repeat queries after finalization
    last_response: dict[str, Any] | None = field(default=None)
class CallLedger:
    """Thread-safe in-memory ledger for third-party proxy call records.

    Records are keyed by ``proxy_call_id`` and additionally indexed by
    ``"{provider}:{provider_task_id}"`` and ``"{provider}:{idempotency_key}"``.
    Every read and write of these three dicts happens under ``self._lock`` so
    a record and its index entries are always observed consistently.

    NOTE(review): records are never evicted, so memory grows with the number
    of calls for the lifetime of the process.
    """

    # Billing states after which finalization must not be claimed again.
    _TERMINAL_BILLING_STATES = ("FINALIZED", "FINALIZE_FAILED")

    def __init__(self) -> None:
        self._records: dict[str, CallRecord] = {}  # proxy_call_id → record
        self._task_index: dict[str, str] = {}  # "{provider}:{provider_task_id}" → proxy_call_id
        self._idem_index: dict[str, str] = {}  # "{provider}:{idem_key}" → proxy_call_id
        # threading.Lock is not reentrant: methods holding it must only call
        # *_locked helpers, never other public methods.
        self._lock = threading.Lock()

    def create(
        self,
        provider: str,
        thread_id: str | None,
        idempotency_key: str | None = None,
    ) -> CallRecord:
        """Create a new call record, or return the existing one if idempotency key matches.

        Args:
            provider: Provider name; scopes both secondary indexes.
            thread_id: Optional thread the call belongs to, stored as-is.
            idempotency_key: When truthy, a repeated ``create`` with the same
                ``(provider, idempotency_key)`` returns the original record
                instead of creating a new one.

        Returns:
            The new or previously created ``CallRecord``. The return value does
            not flag whether this was an idempotent hit.
        """
        with self._lock:
            if idempotency_key:
                existing = self._get_by_idem_key_locked(provider, idempotency_key)
                if existing is not None:
                    logger.info(
                        "[ThirdPartyProxy][Ledger] idempotent hit: provider=%s proxy_call_id=%s idem_key=%s",
                        provider,
                        existing.proxy_call_id,
                        idempotency_key,
                    )
                    return existing
            record = CallRecord(
                proxy_call_id=str(uuid4()),
                provider=provider,
                thread_id=thread_id,
                call_id=str(uuid4()),
                idempotency_key=idempotency_key,
            )
            self._records[record.proxy_call_id] = record
            if idempotency_key:
                self._idem_index[self._index_key(provider, idempotency_key)] = record.proxy_call_id
            logger.info(
                "[ThirdPartyProxy][Ledger] created record: provider=%s proxy_call_id=%s call_id=%s thread_id=%s",
                provider,
                record.proxy_call_id,
                record.call_id,
                thread_id,
            )
            return record

    def get(self, proxy_call_id: str) -> CallRecord | None:
        """Return the record for ``proxy_call_id``, or ``None`` if unknown."""
        with self._lock:
            return self._records.get(proxy_call_id)

    def get_by_task_id(self, provider: str, provider_task_id: str) -> CallRecord | None:
        """Return the record registered for a provider task id, or ``None``."""
        with self._lock:
            # Both dict reads happen under one lock acquisition so the index
            # entry and the record it points to are read consistently.
            proxy_call_id = self._task_index.get(self._index_key(provider, provider_task_id))
            return self._records.get(proxy_call_id) if proxy_call_id else None

    def get_by_idempotency_key(self, provider: str, idempotency_key: str) -> CallRecord | None:
        """Return the record created with ``(provider, idempotency_key)``, or ``None``."""
        with self._lock:
            return self._get_by_idem_key_locked(provider, idempotency_key)

    def set_reserved(self, proxy_call_id: str, frozen_id: str) -> None:
        """Record a billing reservation and move the record to ``RESERVED``.

        Unknown ids are ignored. A record that is already in a terminal
        billing state is left untouched: moving it back to ``RESERVED``
        would let ``try_claim_finalize`` succeed a second time.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_reserved ignored for missing record: proxy_call_id=%s",
                    proxy_call_id,
                )
                return
            if record.billing_state in self._TERMINAL_BILLING_STATES:
                logger.warning(
                    "[ThirdPartyProxy][Ledger] set_reserved ignored for finalized record: proxy_call_id=%s billing_state=%s frozen_id=%s",
                    proxy_call_id,
                    record.billing_state,
                    frozen_id,
                )
                return
            record.frozen_id = frozen_id
            record.billing_state = "RESERVED"
            logger.info(
                "[ThirdPartyProxy][Ledger] reserved: proxy_call_id=%s frozen_id=%s",
                proxy_call_id,
                frozen_id,
            )

    def set_running(self, proxy_call_id: str, provider_task_id: str) -> None:
        """Attach the provider task id, index it, and mark the task ``RUNNING``.

        Unknown ids are ignored (logged at debug level).
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_running ignored for missing record: proxy_call_id=%s provider_task_id=%s",
                    proxy_call_id,
                    provider_task_id,
                )
                return
            record.provider_task_id = provider_task_id
            record.task_state = "RUNNING"
            self._task_index[self._index_key(record.provider, provider_task_id)] = proxy_call_id
            logger.info(
                "[ThirdPartyProxy][Ledger] running: proxy_call_id=%s provider_task_id=%s",
                proxy_call_id,
                provider_task_id,
            )

    def try_claim_finalize(self, proxy_call_id: str) -> bool:
        """Atomically claim finalization rights. Returns True only once per record.

        On success the record is immediately marked ``FINALIZED`` so that
        concurrent or later claims are denied; the caller is expected to follow
        up with ``set_finalized`` or ``set_finalize_failed``. Missing records
        and records already in a terminal billing state are denied.
        """
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] finalize claim denied: missing record proxy_call_id=%s",
                    proxy_call_id,
                )
                return False
            if record.billing_state in self._TERMINAL_BILLING_STATES:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] finalize claim denied: proxy_call_id=%s billing_state=%s",
                    proxy_call_id,
                    record.billing_state,
                )
                return False
            # Mark as finalized immediately to prevent concurrent finalize
            record.billing_state = "FINALIZED"
            logger.info(
                "[ThirdPartyProxy][Ledger] finalize claimed: proxy_call_id=%s",
                proxy_call_id,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalize claim state: call_id=%s provider=%s task_state=%s frozen_id=%s",
                record.call_id,
                record.provider,
                record.task_state,
                record.frozen_id,
            )
            return True

    def set_finalized(self, proxy_call_id: str, task_state: TaskState) -> None:
        """Record a successful finalize: set task state, ``FINALIZED``, and timestamp."""
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_finalized ignored for missing record: proxy_call_id=%s task_state=%s",
                    proxy_call_id,
                    task_state,
                )
                return
            record.task_state = task_state
            record.billing_state = "FINALIZED"
            record.finalized_at = time.time()
            logger.info(
                "[ThirdPartyProxy][Ledger] finalized: proxy_call_id=%s task_state=%s",
                proxy_call_id,
                task_state,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalized state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
                record.provider,
                record.call_id,
                record.frozen_id,
                record.finalized_at,
            )

    def set_finalize_failed(self, proxy_call_id: str, task_state: TaskState) -> None:
        """Record a failed finalize: set task state, ``FINALIZE_FAILED``, and timestamp."""
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] set_finalize_failed ignored for missing record: proxy_call_id=%s task_state=%s",
                    proxy_call_id,
                    task_state,
                )
                return
            record.task_state = task_state
            record.billing_state = "FINALIZE_FAILED"
            record.finalized_at = time.time()
            logger.info(
                "[ThirdPartyProxy][Ledger] finalize failed: proxy_call_id=%s task_state=%s",
                proxy_call_id,
                task_state,
            )
            logger.debug(
                "[ThirdPartyProxy][Ledger] finalize failure state: provider=%s call_id=%s frozen_id=%s finalized_at=%s",
                record.provider,
                record.call_id,
                record.frozen_id,
                record.finalized_at,
            )

    def update_response(self, proxy_call_id: str, response: dict[str, Any]) -> None:
        """Cache the latest provider response on the record (stored by reference)."""
        with self._lock:
            record = self._records.get(proxy_call_id)
            if record is None:
                logger.debug(
                    "[ThirdPartyProxy][Ledger] update_response ignored for missing record: proxy_call_id=%s",
                    proxy_call_id,
                )
                return
            record.last_response = response
            # Sorting the keys is only worth doing when the message is emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "[ThirdPartyProxy][Ledger] cached response: proxy_call_id=%s keys=%s",
                    proxy_call_id,
                    sorted(response.keys()),
                )

    def is_finalized(self, proxy_call_id: str) -> bool:
        """Return True once finalization has been claimed or completed (either outcome)."""
        with self._lock:
            record = self._records.get(proxy_call_id)
            return record is not None and record.billing_state in self._TERMINAL_BILLING_STATES

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _index_key(provider: str, key: str) -> str:
        """Build the provider-scoped key used by both secondary indexes."""
        return f"{provider}:{key}"

    def _get_by_idem_key_locked(self, provider: str, idempotency_key: str) -> CallRecord | None:
        """Look up a record by idempotency key. Caller must hold ``self._lock``."""
        proxy_call_id = self._idem_index.get(self._index_key(provider, idempotency_key))
        return self._records.get(proxy_call_id) if proxy_call_id else None
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_ledger: CallLedger | None = None
_ledger_lock = threading.Lock()


def get_ledger() -> CallLedger:
    """Return the process-wide ``CallLedger``, creating it lazily on first use.

    The unlocked read is a fast path for the common case where the ledger
    already exists; construction happens under ``_ledger_lock`` with a
    re-check, so only one instance is ever created.
    """
    global _ledger
    current = _ledger
    if current is not None:
        return current
    with _ledger_lock:
        if _ledger is None:
            _ledger = CallLedger()
            logger.info("[ThirdPartyProxy][Ledger] singleton initialized")
        return _ledger