91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
"""In-memory stream bridge backed by :class:`asyncio.Queue`."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
from .base import END_SENTINEL, HEARTBEAT_SENTINEL, StreamBridge, StreamEvent
|
|
|
|
# Module-level logger; MemoryStreamBridge uses it for queue-full warnings and
# debug notes about ignored replay requests.
logger = logging.getLogger(__name__)
|
|
|
|
# Seconds publish()/publish_end() wait for free space in a full per-run queue
# before treating the put as timed out.
_PUBLISH_TIMEOUT = 30.0
|
|
|
|
|
|
class MemoryStreamBridge(StreamBridge):
    """Per-run ``asyncio.Queue`` implementation of :class:`StreamBridge`.

    Each *run_id* gets its own bounded queue the first time it is used by
    :meth:`publish`, :meth:`publish_end` or :meth:`subscribe`.  All
    subscribers of one run read from that single queue, so each event is
    delivered to at most one of them.  Events live only in memory; there is
    no history, so ``last_event_id`` replay is not supported.
    """

    def __init__(self, *, queue_maxsize: int = 256) -> None:
        """Create an empty bridge.

        Args:
            queue_maxsize: Capacity of each per-run queue, passed straight to
                ``asyncio.Queue`` (a value <= 0 means unbounded).
        """
        self._maxsize = queue_maxsize
        # run_id -> pending events for that run.
        self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
        # run_id -> next 0-based sequence number used by _next_id().
        self._counters: dict[str, int] = {}

    # -- helpers ---------------------------------------------------------------

    def _get_or_create_queue(self, run_id: str) -> asyncio.Queue[StreamEvent]:
        """Return the queue for *run_id*, creating it (and resetting its id counter) if absent."""
        queue = self._queues.get(run_id)
        if queue is None:
            queue = asyncio.Queue(maxsize=self._maxsize)
            self._queues[run_id] = queue
            self._counters[run_id] = 0
        return queue

    def _next_id(self, run_id: str) -> str:
        """Return a new event id of the form ``"<epoch-ms>-<seq>"``.

        *seq* is a 0-based per-run counter, so ids stay distinct even when
        several events of the same run share a millisecond timestamp.
        """
        seq = self._counters.get(run_id, 0)
        self._counters[run_id] = seq + 1
        ts = int(time.time() * 1000)
        return f"{ts}-{seq}"

    # -- StreamBridge API ------------------------------------------------------

    async def publish(self, run_id: str, event: str, data: Any) -> None:
        """Enqueue an event for *run_id*.

        Waits while the queue is full, for at most ``_PUBLISH_TIMEOUT``
        seconds; after that the event is dropped with a warning (its id has
        already been consumed, leaving a gap in the sequence).
        """
        queue = self._get_or_create_queue(run_id)
        entry = StreamEvent(id=self._next_id(run_id), event=event, data=data)
        try:
            await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
        except TimeoutError:
            logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)

    async def publish_end(self, run_id: str) -> None:
        """Enqueue ``END_SENTINEL`` for *run_id* so its subscriber terminates.

        Waits up to ``_PUBLISH_TIMEOUT`` seconds for free space.  If the queue
        is still full after that, the oldest pending event is evicted to make
        room.  Dropping the sentinel instead would leave a subscriber that
        later drains the queue yielding heartbeats forever, because
        :meth:`subscribe` only returns on ``END_SENTINEL``.
        """
        queue = self._get_or_create_queue(run_id)
        try:
            await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
        except TimeoutError:
            # No await between freeing a slot and filling it, so no other
            # coroutine can take the slot in between.
            if queue.full():
                queue.get_nowait()
                logger.warning(
                    "Stream bridge queue full for run %s — evicted oldest event to deliver END sentinel",
                    run_id,
                )
            queue.put_nowait(END_SENTINEL)

    async def subscribe(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
        heartbeat_interval: float = 15.0,
    ) -> AsyncIterator[StreamEvent]:
        """Yield events for *run_id* until ``END_SENTINEL`` arrives.

        Yields ``HEARTBEAT_SENTINEL`` whenever no event arrives within
        *heartbeat_interval* seconds.  ``END_SENTINEL`` itself is yielded
        last, then the generator returns.  *last_event_id* is accepted but
        ignored: this bridge keeps no history to replay from.
        """
        if last_event_id is not None:
            logger.debug("last_event_id=%s accepted but ignored (memory bridge has no replay)", last_event_id)

        queue = self._get_or_create_queue(run_id)
        while True:
            try:
                entry = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
            except TimeoutError:
                yield HEARTBEAT_SENTINEL
                continue
            if entry is END_SENTINEL:
                yield END_SENTINEL
                return
            yield entry

    async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
        """Forget *run_id*'s queue and id counter, optionally after *delay* seconds.

        A later publish/subscribe for the same *run_id* starts with a fresh
        queue.  Subscribers that are already iterating keep their reference
        to the old queue.
        """
        if delay > 0:
            await asyncio.sleep(delay)
        self._queues.pop(run_id, None)
        self._counters.pop(run_id, None)

    async def close(self) -> None:
        """Drop all queues and counters.

        NOTE: active subscribers are not signalled.  A subscriber waiting on a
        queue keeps yielding heartbeats until its consumer stops iterating or
        cancels it.
        """
        self._queues.clear()
        self._counters.clear()
|