diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py
index 46628f3a..7d17488f 100644
--- a/backend/app/gateway/routers/runs.py
+++ b/backend/app/gateway/routers/runs.py
@@ -51,6 +51,7 @@ async def stateless_stream(body: RunCreateRequest, request: Request) -> Streamin
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
             "X-Accel-Buffering": "no",
+            "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
         },
     )
 
diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py
index d29786ed..105fc9ca 100644
--- a/backend/app/gateway/routers/thread_runs.py
+++ b/backend/app/gateway/routers/thread_runs.py
@@ -118,8 +118,9 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
             "Connection": "keep-alive",
             "X-Accel-Buffering": "no",
             # LangGraph Platform includes run metadata in this header.
-            # The SDK's _get_run_metadata_from_response() parses it.
-            "Content-Location": (f"/api/threads/{thread_id}/runs/{record.run_id}/stream?thread_id={thread_id}&run_id={record.run_id}"),
+            # The SDK uses a greedy regex to extract the run id from this path,
+            # so it must point at the canonical run resource without extra suffixes.
+            "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
         },
     )
 
diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py
index 272801b6..172e2781 100644
--- a/backend/app/gateway/services.py
+++ b/backend/app/gateway/services.py
@@ -345,8 +345,9 @@ async def sse_consumer(
     - ``cancel``: abort the background task on client disconnect.
     - ``continue``: let the task run; events are discarded.
     """
+    last_event_id = request.headers.get("Last-Event-ID")
     try:
-        async for entry in bridge.subscribe(record.run_id):
+        async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
             if await request.is_disconnected():
                 break
 
diff --git a/backend/packages/harness/deerflow/runtime/stream_bridge/memory.py b/backend/packages/harness/deerflow/runtime/stream_bridge/memory.py
index 45aff134..cb5b8d1f 100644
--- a/backend/packages/harness/deerflow/runtime/stream_bridge/memory.py
+++ b/backend/packages/harness/deerflow/runtime/stream_bridge/memory.py
@@ -1,4 +1,4 @@
-"""In-memory stream bridge backed by :class:`asyncio.Queue`."""
+"""In-memory stream bridge backed by an in-process event log."""
 
 from __future__ import annotations
 
@@ -6,35 +6,49 @@ import asyncio
 import logging
 import time
 from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
 from typing import Any
 
 from .base import END_SENTINEL, HEARTBEAT_SENTINEL, StreamBridge, StreamEvent
 
 logger = logging.getLogger(__name__)
 
-_PUBLISH_TIMEOUT = 30.0  # seconds to wait when queue is full
+
+@dataclass
+class _RunStream:
+    events: list[StreamEvent] = field(default_factory=list)
+    condition: asyncio.Condition = field(default_factory=asyncio.Condition)
+    ended: bool = False
+    start_offset: int = 0
 
 
 class MemoryStreamBridge(StreamBridge):
-    """Per-run ``asyncio.Queue`` implementation.
+    """Per-run in-memory event log implementation.
 
-    Each *run_id* gets its own queue on first :meth:`publish` call.
+    The most recent ``queue_maxsize`` events are retained per run so late
+    subscribers and reconnecting clients can replay from ``Last-Event-ID``.
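+
+    A minimal usage sketch (the run id and payload below are illustrative)::
+
+        bridge = MemoryStreamBridge(queue_maxsize=256)
+        await bridge.publish("run-1", "values", {"step": 1})
+        await bridge.publish_end("run-1")
+        async for event in bridge.subscribe("run-1"):
+            ...  # yields the retained event, then END_SENTINEL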
""" def __init__(self, *, queue_maxsize: int = 256) -> None: self._maxsize = queue_maxsize - self._queues: dict[str, asyncio.Queue[StreamEvent]] = {} + self._streams: dict[str, _RunStream] = {} self._counters: dict[str, int] = {} - self._dropped_counts: dict[str, int] = {} # -- helpers --------------------------------------------------------------- - def _get_or_create_queue(self, run_id: str) -> asyncio.Queue[StreamEvent]: - if run_id not in self._queues: - self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize) + def _get_or_create_stream(self, run_id: str) -> _RunStream: + if run_id not in self._streams: + self._streams[run_id] = _RunStream() self._counters[run_id] = 0 - self._dropped_counts[run_id] = 0 - return self._queues[run_id] + return self._streams[run_id] def _next_id(self, run_id: str) -> str: self._counters[run_id] = self._counters.get(run_id, 0) + 1 @@ -42,49 +48,39 @@ class MemoryStreamBridge(StreamBridge): seq = self._counters[run_id] - 1 return f"{ts}-{seq}" + def _resolve_start_offset(self, stream: _RunStream, last_event_id: str | None) -> int: + if last_event_id is None: + return stream.start_offset + + for index, entry in enumerate(stream.events): + if entry.id == last_event_id: + return stream.start_offset + index + 1 + + if stream.events: + logger.warning( + "last_event_id=%s not found in retained buffer; replaying from earliest retained event", + last_event_id, + ) + return stream.start_offset + # -- StreamBridge API ------------------------------------------------------ async def publish(self, run_id: str, event: str, data: Any) -> None: - queue = self._get_or_create_queue(run_id) + stream = self._get_or_create_stream(run_id) entry = StreamEvent(id=self._next_id(run_id), event=event, data=data) - try: - await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT) - except TimeoutError: - self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1 - logger.warning( - "Stream bridge queue full for run %s — dropping event %s (total dropped: %d)", - run_id, - event, - self._dropped_counts[run_id], - ) + async with stream.condition: + stream.events.append(entry) + if len(stream.events) > self._maxsize: + overflow = len(stream.events) - self._maxsize + del stream.events[:overflow] + stream.start_offset += overflow + stream.condition.notify_all() async def publish_end(self, run_id: str) -> None: - queue = self._get_or_create_queue(run_id) - - # END sentinel is critical — it is the only signal that allows - # subscribers to terminate. If the queue is full we evict the - # oldest *regular* events to make room rather than dropping END, - # which would cause the SSE connection to hang forever and leak - # the queue/counter resources for this run_id. - if queue.full(): - evicted = 0 - while queue.full(): - try: - queue.get_nowait() - evicted += 1 - except asyncio.QueueEmpty: - break # pragma: no cover – defensive - if evicted: - logger.warning( - "Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery", - run_id, - evicted, - ) - - # After eviction the queue is guaranteed to have space, so a - # simple non-blocking put is safe. We still use put() (which - # blocks until space is available) as a defensive measure. 
+        stream = self._get_or_create_stream(run_id)
+        async with stream.condition:
+            stream.ended = True
+            stream.condition.notify_all()
 
     async def subscribe(
         self,
@@ -93,16 +100,36 @@
         run_id: str,
         *,
         last_event_id: str | None = None,
         heartbeat_interval: float = 15.0,
     ) -> AsyncIterator[StreamEvent]:
-        if last_event_id is not None:
-            logger.debug("last_event_id=%s accepted but ignored (memory bridge has no replay)", last_event_id)
+        stream = self._get_or_create_stream(run_id)
+        async with stream.condition:
+            next_offset = self._resolve_start_offset(stream, last_event_id)
 
-        queue = self._get_or_create_queue(run_id)
         while True:
-            try:
-                entry = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
-            except TimeoutError:
-                yield HEARTBEAT_SENTINEL
-                continue
+            async with stream.condition:
+                if next_offset < stream.start_offset:
+                    logger.warning(
+                        "subscriber for run %s fell behind retained buffer; resuming from offset %s",
+                        run_id,
+                        stream.start_offset,
+                    )
+                    next_offset = stream.start_offset
+
+                local_index = next_offset - stream.start_offset
+                if 0 <= local_index < len(stream.events):
+                    entry = stream.events[local_index]
+                    next_offset += 1
+                elif stream.ended:
+                    entry = END_SENTINEL
+                else:
+                    # No buffered event and the run has not ended: wait for
+                    # a new event, yielding a heartbeat on timeout.
+                    try:
+                        await asyncio.wait_for(stream.condition.wait(), timeout=heartbeat_interval)
+                    except TimeoutError:
+                        entry = HEARTBEAT_SENTINEL
+                    else:
+                        continue
+
             if entry is END_SENTINEL:
                 yield END_SENTINEL
                 return
 
@@ -111,20 +138,9 @@
     async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
         if delay > 0:
             await asyncio.sleep(delay)
-        self._queues.pop(run_id, None)
+        self._streams.pop(run_id, None)
         self._counters.pop(run_id, None)
-        self._dropped_counts.pop(run_id, None)
 
     async def close(self) -> None:
-        self._queues.clear()
+        self._streams.clear()
         self._counters.clear()
-        self._dropped_counts.clear()
-
-    def dropped_count(self, run_id: str) -> int:
-        """Return the number of events dropped for *run_id*."""
-        return self._dropped_counts.get(run_id, 0)
-
-    @property
-    def dropped_total(self) -> int:
-        """Return the total number of events dropped across all runs."""
-        return sum(self._dropped_counts.values())
diff --git a/backend/tests/test_stream_bridge.py b/backend/tests/test_stream_bridge.py
index f9aee486..efd5e792 100644
--- a/backend/tests/test_stream_bridge.py
+++ b/backend/tests/test_stream_bridge.py
@@ -3,6 +3,7 @@
 import asyncio
 import re
 
+import anyio
 import pytest
 
 from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
@@ -44,7 +45,7 @@ async def test_publish_subscribe(bridge: MemoryStreamBridge):
 async def test_heartbeat(bridge: MemoryStreamBridge):
     """When no events arrive within the heartbeat interval, yield a heartbeat."""
     run_id = "run-heartbeat"
-    bridge._get_or_create_queue(run_id)  # ensure queue exists
+    bridge._get_or_create_stream(run_id)  # ensure stream exists
 
     received = []
 
@@ -61,37 +62,35 @@
 @pytest.mark.anyio
 async def test_cleanup(bridge: MemoryStreamBridge):
-    """After cleanup, the run's queue is removed."""
+    """After cleanup, the run's stream/event log is removed."""
     run_id = "run-cleanup"
     await bridge.publish(run_id, "test", {})
-    assert run_id in bridge._queues
+    assert run_id in bridge._streams
 
     await bridge.cleanup(run_id)
-    assert run_id not in bridge._queues
+    assert run_id not in bridge._streams
     assert run_id not in bridge._counters
 
 
 @pytest.mark.anyio
-async def test_backpressure():
-    """With maxsize=1, publish should not block forever."""
+async def test_history_is_bounded():
+    """Retained history should be bounded by queue_maxsize."""
     bridge = MemoryStreamBridge(queue_maxsize=1)
     run_id = "run-bp"
 
     await bridge.publish(run_id, "first", {})
+    await bridge.publish(run_id, "second", {})
+    await bridge.publish_end(run_id)
 
-    # Second publish should either succeed after queue drains or warn+drop
-    # It should not hang indefinitely
-    async def publish_second():
-        await bridge.publish(run_id, "second", {})
+    received = []
+    async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
+        received.append(entry)
+        if entry is END_SENTINEL:
+            break
 
-    # Give it a generous timeout — the publish timeout is 30s but we don't
-    # want to wait that long in tests. Instead, drain the queue first.
-    async def drain():
-        await asyncio.sleep(0.05)
-        bridge._queues[run_id].get_nowait()
-
-    await asyncio.gather(publish_second(), drain())
-    assert bridge._queues[run_id].qsize() == 1
+    assert len(received) == 2
+    assert received[0].event == "second"
+    assert received[1] is END_SENTINEL
 
 
 @pytest.mark.anyio
@@ -140,54 +139,116 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
     assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
 
 
+@pytest.mark.anyio
+async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge):
+    """Reconnect should replay buffered events after the provided Last-Event-ID."""
+    run_id = "run-replay"
+    await bridge.publish(run_id, "metadata", {"run_id": run_id})
+    await bridge.publish(run_id, "values", {"step": 1})
+    await bridge.publish(run_id, "updates", {"step": 2})
+    await bridge.publish_end(run_id)
+
+    first_pass = []
+    async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
+        first_pass.append(entry)
+        if entry is END_SENTINEL:
+            break
+
+    received = []
+    async for entry in bridge.subscribe(
+        run_id,
+        last_event_id=first_pass[0].id,
+        heartbeat_interval=1.0,
+    ):
+        received.append(entry)
+        if entry is END_SENTINEL:
+            break
+
+    assert [entry.event for entry in received[:-1]] == ["values", "updates"]
+    assert received[-1] is END_SENTINEL
+
+
+@pytest.mark.anyio
+async def test_slow_subscriber_does_not_skip_after_buffer_trim():
+    """A slow subscriber should continue from the correct absolute offset."""
+    bridge = MemoryStreamBridge(queue_maxsize=2)
+    run_id = "run-slow-subscriber"
+    await bridge.publish(run_id, "e1", {"step": 1})
+    await bridge.publish(run_id, "e2", {"step": 2})
+
+    stream = bridge._streams[run_id]
+    e1_id = stream.events[0].id
+    assert stream.start_offset == 0
+
+    await bridge.publish(run_id, "e3", {"step": 3})  # trims e1
+    assert stream.start_offset == 1
+    assert [entry.event for entry in stream.events] == ["e2", "e3"]
+
+    resumed_after_e1 = []
+    async for entry in bridge.subscribe(
+        run_id,
+        last_event_id=e1_id,
+        heartbeat_interval=1.0,
+    ):
+        resumed_after_e1.append(entry)
+        if len(resumed_after_e1) == 2:
+            break
+
+    assert [entry.event for entry in resumed_after_e1] == ["e2", "e3"]
+    e2_id = resumed_after_e1[0].id
+
+    await bridge.publish_end(run_id)
+
+    received = []
+    async for entry in bridge.subscribe(
+        run_id,
+        last_event_id=e2_id,
+        heartbeat_interval=1.0,
+    ):
+        received.append(entry)
+        if entry is END_SENTINEL:
+            break
+
+    assert [entry.event for entry in received[:-1]] == ["e3"]
+    assert received[-1] is END_SENTINEL
+
+
 # ---------------------------------------------------------------------------
-# END sentinel guarantee tests
+# Stream termination tests
 # ---------------------------------------------------------------------------
 
 
 @pytest.mark.anyio
-async def test_end_sentinel_delivered_when_queue_full():
-    """END sentinel must always be delivered, even when the queue is completely full.
-
-    This is the critical regression test for the bug where publish_end()
-    would silently drop the END sentinel when the queue was full, causing
-    subscribe() to hang forever and leaking resources.
-    """
+async def test_publish_end_terminates_even_when_history_is_full():
+    """publish_end() should terminate subscribers without mutating retained history."""
     bridge = MemoryStreamBridge(queue_maxsize=2)
-    run_id = "run-end-full"
+    run_id = "run-end-history-full"
 
-    # Fill the queue to capacity
     await bridge.publish(run_id, "event-1", {"n": 1})
     await bridge.publish(run_id, "event-2", {"n": 2})
-    assert bridge._queues[run_id].full()
+    stream = bridge._streams[run_id]
+    assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
 
-    # publish_end should succeed by evicting old events
     await bridge.publish_end(run_id)
+    assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
 
-    # Subscriber must receive END_SENTINEL
     events = []
     async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
         events.append(entry)
         if entry is END_SENTINEL:
             break
 
-    assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
+    assert [entry.event for entry in events[:-1]] == ["event-1", "event-2"]
+    assert events[-1] is END_SENTINEL
 
 
 @pytest.mark.anyio
-async def test_end_sentinel_evicts_oldest_events():
-    """When queue is full, publish_end evicts the oldest events to make room."""
-    bridge = MemoryStreamBridge(queue_maxsize=1)
-    run_id = "run-evict"
-
-    # Fill queue with one event
-    await bridge.publish(run_id, "will-be-evicted", {})
-    assert bridge._queues[run_id].full()
-
-    # publish_end must succeed
+async def test_publish_end_without_history_yields_end_immediately():
+    """Subscribers should still receive END when a run completes without events."""
+    bridge = MemoryStreamBridge(queue_maxsize=2)
+    run_id = "run-end-empty"
     await bridge.publish_end(run_id)
 
-    # The only event we should get is END_SENTINEL (the regular event was evicted)
     events = []
     async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
         events.append(entry)
@@ -199,8 +260,8 @@
 
 
 @pytest.mark.anyio
-async def test_end_sentinel_no_eviction_when_space_available():
-    """When queue has space, publish_end should not evict anything."""
+async def test_publish_end_preserves_history_when_space_available():
+    """When history has spare capacity, publish_end should preserve prior events."""
     bridge = MemoryStreamBridge(queue_maxsize=10)
     run_id = "run-no-evict"
 
@@ -244,87 +305,25 @@ async def test_concurrent_tasks_end_sentinel():
                 return events
         return events  # pragma: no cover
 
-    # Run producers and consumers concurrently
     run_ids = [f"concurrent-{i}" for i in range(num_runs)]
-    producers = [producer(rid) for rid in run_ids]
-    consumers = [consumer(rid) for rid in run_ids]
+    results: dict[str, list] = {}
 
-    # Start consumers first, then producers
-    consumer_tasks = [asyncio.create_task(c) for c in consumers]
-    await asyncio.gather(*producers)
+    async def consume_into(run_id: str) -> None:
+        results[run_id] = await consumer(run_id)
 
-    results = await asyncio.wait_for(
-        asyncio.gather(*consumer_tasks),
-        timeout=10.0,
-    )
+    with anyio.fail_after(10):
+        async with anyio.create_task_group() as task_group:
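+            # Start the consumers first and yield once so they can begin
+            # subscribing; replay makes the ordering safe either way.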
+            for run_id in run_ids:
+                task_group.start_soon(consume_into, run_id)
+            await anyio.sleep(0)
+            for run_id in run_ids:
+                task_group.start_soon(producer, run_id)
 
-    for i, events in enumerate(results):
-        assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
-
-
-# ---------------------------------------------------------------------------
-# Drop counter tests
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.anyio
-async def test_dropped_count_tracking():
-    """Dropped events should be tracked per run_id."""
-    bridge = MemoryStreamBridge(queue_maxsize=1)
-    run_id = "run-drop-count"
-
-    # Fill the queue
-    await bridge.publish(run_id, "first", {})
-
-    # This publish will time out and be dropped (we patch timeout to be instant)
-    # Instead, we verify the counter after publish_end eviction
-    await bridge.publish_end(run_id)
-
-    # dropped_count tracks publish() drops, not publish_end evictions
-    assert bridge.dropped_count(run_id) == 0
-
-    # cleanup should also clear the counter
-    await bridge.cleanup(run_id)
-    assert bridge.dropped_count(run_id) == 0
-
-
-@pytest.mark.anyio
-async def test_dropped_total():
-    """dropped_total should sum across all runs."""
-    bridge = MemoryStreamBridge(queue_maxsize=256)
-
-    # No drops yet
-    assert bridge.dropped_total == 0
-
-    # Manually set some counts to verify the property
-    bridge._dropped_counts["run-a"] = 3
-    bridge._dropped_counts["run-b"] = 7
-    assert bridge.dropped_total == 10
-
-
-@pytest.mark.anyio
-async def test_cleanup_clears_dropped_counts():
-    """cleanup() should clear the dropped counter for the run."""
-    bridge = MemoryStreamBridge(queue_maxsize=256)
-    run_id = "run-cleanup-drops"
-
-    bridge._get_or_create_queue(run_id)
-    bridge._dropped_counts[run_id] = 5
-
-    await bridge.cleanup(run_id)
-    assert run_id not in bridge._dropped_counts
-
-
-@pytest.mark.anyio
-async def test_close_clears_dropped_counts():
-    """close() should clear all dropped counters."""
-    bridge = MemoryStreamBridge(queue_maxsize=256)
-    bridge._dropped_counts["run-x"] = 10
-    bridge._dropped_counts["run-y"] = 20
-
-    await bridge.close()
-    assert bridge.dropped_total == 0
-    assert len(bridge._dropped_counts) == 0
+    for run_id in run_ids:
+        events = results[run_id]
+        assert events[-1] is END_SENTINEL, f"Run {run_id} did not receive END sentinel"
 
 
 # ---------------------------------------------------------------------------
diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts
index fbcce030..395f1560 100644
--- a/frontend/src/core/threads/hooks.ts
+++ b/frontend/src/core/threads/hooks.ts
@@ -36,6 +36,83 @@ type SendMessageOptions = {
   additionalKwargs?: Record<string, unknown>;
 };
 
+function normalizeStoredRunId(runId: string | null): string | null {
+  if (!runId) {
+    return null;
+  }
+
+  const trimmed = runId.trim();
+  if (!trimmed) {
+    return null;
+  }
+
+  const queryIndex = trimmed.indexOf("?");
+  if (queryIndex >= 0) {
+    const params = new URLSearchParams(trimmed.slice(queryIndex + 1));
+    const queryRunId = params.get("run_id")?.trim();
+    if (queryRunId) {
+      return queryRunId;
+    }
+  }
+
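+  // No run_id query parameter: fall back to the path, where the run id is
+  // the segment immediately after the last "/runs/".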
""; + if (!pathWithoutQueryOrHash) { + return null; + } + + const runsMarker = "/runs/"; + const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker); + if (runsIndex >= 0) { + const runIdAfterMarker = pathWithoutQueryOrHash + .slice(runsIndex + runsMarker.length) + .split("/", 1)[0] + ?.trim(); + if (runIdAfterMarker) { + return runIdAfterMarker; + } + return null; + } + + const segments = pathWithoutQueryOrHash + .split("/") + .map((segment) => segment.trim()) + .filter(Boolean); + return segments.at(-1) ?? null; +} + +function getRunMetadataStorage(): { + getItem(key: `lg:stream:${string}`): string | null; + setItem(key: `lg:stream:${string}`, value: string): void; + removeItem(key: `lg:stream:${string}`): void; +} { + return { + getItem(key) { + const normalized = normalizeStoredRunId( + window.sessionStorage.getItem(key), + ); + if (normalized) { + window.sessionStorage.setItem(key, normalized); + return normalized; + } + window.sessionStorage.removeItem(key); + return null; + }, + setItem(key, value) { + const normalized = normalizeStoredRunId(value); + if (normalized) { + window.sessionStorage.setItem(key, normalized); + return; + } + window.sessionStorage.removeItem(key); + }, + removeItem(key) { + window.sessionStorage.removeItem(key); + }, + }; +} + function getStreamErrorMessage(error: unknown): string { if (typeof error === "string" && error.trim()) { return error; @@ -113,12 +188,24 @@ export function useThreadStream({ const queryClient = useQueryClient(); const updateSubtask = useUpdateSubtask(); + const runMetadataStorageRef = useRef< + ReturnType | undefined + >(undefined); + + if ( + typeof window !== "undefined" && + runMetadataStorageRef.current === undefined + ) { + runMetadataStorageRef.current = getRunMetadataStorage(); + } const thread = useStream({ client: getAPIClient(isMock), assistantId: "lead_agent", threadId: onStreamThreadId, - reconnectOnMount: true, + reconnectOnMount: runMetadataStorageRef.current + ? () => runMetadataStorageRef.current! + : false, fetchStateHistory: { limit: 1 }, onCreated(meta) { handleStreamStart(meta.thread_id);