Fix(#1702): stream resume run (#1858)

* fix: repair stream resume run metadata

# Conflicts:
#	backend/packages/harness/deerflow/runtime/stream_bridge/memory.py
#	frontend/src/core/threads/hooks.ts

* fix(stream): repair resumable replay validation

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
2026-04-06 14:51:10 +08:00 committed by GitHub
parent 29575c32f9
commit 7c68dd4ad4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 285 additions and 195 deletions

View File

@ -51,6 +51,7 @@ async def stateless_stream(body: RunCreateRequest, request: Request) -> Streamin
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)

View File

@ -118,8 +118,9 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
# LangGraph Platform includes run metadata in this header.
# The SDK's _get_run_metadata_from_response() parses it.
"Content-Location": (f"/api/threads/{thread_id}/runs/{record.run_id}/stream?thread_id={thread_id}&run_id={record.run_id}"),
# The SDK uses a greedy regex to extract the run id from this path,
# so it must point at the canonical run resource without extra suffixes.
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)

View File

@ -345,8 +345,9 @@ async def sse_consumer(
- ``cancel``: abort the background task on client disconnect.
- ``continue``: let the task run; events are discarded.
"""
last_event_id = request.headers.get("Last-Event-ID")
try:
async for entry in bridge.subscribe(record.run_id):
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
if await request.is_disconnected():
break

View File

@ -1,4 +1,4 @@
"""In-memory stream bridge backed by :class:`asyncio.Queue`."""
"""In-memory stream bridge backed by an in-process event log."""
from __future__ import annotations
@ -6,35 +6,41 @@ import asyncio
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from .base import END_SENTINEL, HEARTBEAT_SENTINEL, StreamBridge, StreamEvent
logger = logging.getLogger(__name__)
_PUBLISH_TIMEOUT = 30.0 # seconds to wait when queue is full
@dataclass
class _RunStream:
events: list[StreamEvent] = field(default_factory=list)
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
ended: bool = False
start_offset: int = 0
class MemoryStreamBridge(StreamBridge):
"""Per-run ``asyncio.Queue`` implementation.
"""Per-run in-memory event log implementation.
Each *run_id* gets its own queue on first :meth:`publish` call.
Events are retained for a bounded time window per run so late subscribers
and reconnecting clients can replay buffered events from ``Last-Event-ID``.
"""
def __init__(self, *, queue_maxsize: int = 256) -> None:
self._maxsize = queue_maxsize
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
self._streams: dict[str, _RunStream] = {}
self._counters: dict[str, int] = {}
self._dropped_counts: dict[str, int] = {}
# -- helpers ---------------------------------------------------------------
def _get_or_create_queue(self, run_id: str) -> asyncio.Queue[StreamEvent]:
if run_id not in self._queues:
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
def _get_or_create_stream(self, run_id: str) -> _RunStream:
if run_id not in self._streams:
self._streams[run_id] = _RunStream()
self._counters[run_id] = 0
self._dropped_counts[run_id] = 0
return self._queues[run_id]
return self._streams[run_id]
def _next_id(self, run_id: str) -> str:
self._counters[run_id] = self._counters.get(run_id, 0) + 1
@ -42,49 +48,39 @@ class MemoryStreamBridge(StreamBridge):
seq = self._counters[run_id] - 1
return f"{ts}-{seq}"
def _resolve_start_offset(self, stream: _RunStream, last_event_id: str | None) -> int:
if last_event_id is None:
return stream.start_offset
for index, entry in enumerate(stream.events):
if entry.id == last_event_id:
return stream.start_offset + index + 1
if stream.events:
logger.warning(
"last_event_id=%s not found in retained buffer; replaying from earliest retained event",
last_event_id,
)
return stream.start_offset
# -- StreamBridge API ------------------------------------------------------
async def publish(self, run_id: str, event: str, data: Any) -> None:
queue = self._get_or_create_queue(run_id)
stream = self._get_or_create_stream(run_id)
entry = StreamEvent(id=self._next_id(run_id), event=event, data=data)
try:
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
logger.warning(
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
run_id,
event,
self._dropped_counts[run_id],
)
async with stream.condition:
stream.events.append(entry)
if len(stream.events) > self._maxsize:
overflow = len(stream.events) - self._maxsize
del stream.events[:overflow]
stream.start_offset += overflow
stream.condition.notify_all()
async def publish_end(self, run_id: str) -> None:
queue = self._get_or_create_queue(run_id)
# END sentinel is critical — it is the only signal that allows
# subscribers to terminate. If the queue is full we evict the
# oldest *regular* events to make room rather than dropping END,
# which would cause the SSE connection to hang forever and leak
# the queue/counter resources for this run_id.
if queue.full():
evicted = 0
while queue.full():
try:
queue.get_nowait()
evicted += 1
except asyncio.QueueEmpty:
break # pragma: no cover defensive
if evicted:
logger.warning(
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
run_id,
evicted,
)
# After eviction the queue is guaranteed to have space, so a
# simple non-blocking put is safe. We still use put() (which
# blocks until space is available) as a defensive measure.
await queue.put(END_SENTINEL)
stream = self._get_or_create_stream(run_id)
async with stream.condition:
stream.ended = True
stream.condition.notify_all()
async def subscribe(
self,
@ -93,16 +89,34 @@ class MemoryStreamBridge(StreamBridge):
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
if last_event_id is not None:
logger.debug("last_event_id=%s accepted but ignored (memory bridge has no replay)", last_event_id)
stream = self._get_or_create_stream(run_id)
async with stream.condition:
next_offset = self._resolve_start_offset(stream, last_event_id)
queue = self._get_or_create_queue(run_id)
while True:
try:
entry = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
except TimeoutError:
yield HEARTBEAT_SENTINEL
continue
async with stream.condition:
if next_offset < stream.start_offset:
logger.warning(
"subscriber for run %s fell behind retained buffer; resuming from offset %s",
run_id,
stream.start_offset,
)
next_offset = stream.start_offset
local_index = next_offset - stream.start_offset
if 0 <= local_index < len(stream.events):
entry = stream.events[local_index]
next_offset += 1
elif stream.ended:
entry = END_SENTINEL
else:
try:
await asyncio.wait_for(stream.condition.wait(), timeout=heartbeat_interval)
except TimeoutError:
entry = HEARTBEAT_SENTINEL
else:
continue
if entry is END_SENTINEL:
yield END_SENTINEL
return
@ -111,20 +125,9 @@ class MemoryStreamBridge(StreamBridge):
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
if delay > 0:
await asyncio.sleep(delay)
self._queues.pop(run_id, None)
self._streams.pop(run_id, None)
self._counters.pop(run_id, None)
self._dropped_counts.pop(run_id, None)
async def close(self) -> None:
self._queues.clear()
self._streams.clear()
self._counters.clear()
self._dropped_counts.clear()
def dropped_count(self, run_id: str) -> int:
"""Return the number of events dropped for *run_id*."""
return self._dropped_counts.get(run_id, 0)
@property
def dropped_total(self) -> int:
"""Return the total number of events dropped across all runs."""
return sum(self._dropped_counts.values())

View File

@ -3,6 +3,7 @@
import asyncio
import re
import anyio
import pytest
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
@ -44,7 +45,7 @@ async def test_publish_subscribe(bridge: MemoryStreamBridge):
async def test_heartbeat(bridge: MemoryStreamBridge):
"""When no events arrive within the heartbeat interval, yield a heartbeat."""
run_id = "run-heartbeat"
bridge._get_or_create_queue(run_id) # ensure queue exists
bridge._get_or_create_stream(run_id) # ensure stream exists
received = []
@ -61,37 +62,35 @@ async def test_heartbeat(bridge: MemoryStreamBridge):
@pytest.mark.anyio
async def test_cleanup(bridge: MemoryStreamBridge):
"""After cleanup, the run's queue is removed."""
"""After cleanup, the run's stream/event log is removed."""
run_id = "run-cleanup"
await bridge.publish(run_id, "test", {})
assert run_id in bridge._queues
assert run_id in bridge._streams
await bridge.cleanup(run_id)
assert run_id not in bridge._queues
assert run_id not in bridge._streams
assert run_id not in bridge._counters
@pytest.mark.anyio
async def test_backpressure():
"""With maxsize=1, publish should not block forever."""
async def test_history_is_bounded():
"""Retained history should be bounded by queue_maxsize."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-bp"
await bridge.publish(run_id, "first", {})
await bridge.publish(run_id, "second", {})
await bridge.publish_end(run_id)
# Second publish should either succeed after queue drains or warn+drop
# It should not hang indefinitely
async def publish_second():
await bridge.publish(run_id, "second", {})
received = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
received.append(entry)
if entry is END_SENTINEL:
break
# Give it a generous timeout — the publish timeout is 30s but we don't
# want to wait that long in tests. Instead, drain the queue first.
async def drain():
await asyncio.sleep(0.05)
bridge._queues[run_id].get_nowait()
await asyncio.gather(publish_second(), drain())
assert bridge._queues[run_id].qsize() == 1
assert len(received) == 2
assert received[0].event == "second"
assert received[1] is END_SENTINEL
@pytest.mark.anyio
@ -140,54 +139,116 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
@pytest.mark.anyio
async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge):
"""Reconnect should replay buffered events after the provided Last-Event-ID."""
run_id = "run-replay"
await bridge.publish(run_id, "metadata", {"run_id": run_id})
await bridge.publish(run_id, "values", {"step": 1})
await bridge.publish(run_id, "updates", {"step": 2})
await bridge.publish_end(run_id)
first_pass = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
first_pass.append(entry)
if entry is END_SENTINEL:
break
received = []
async for entry in bridge.subscribe(
run_id,
last_event_id=first_pass[0].id,
heartbeat_interval=1.0,
):
received.append(entry)
if entry is END_SENTINEL:
break
assert [entry.event for entry in received[:-1]] == ["values", "updates"]
assert received[-1] is END_SENTINEL
@pytest.mark.anyio
async def test_slow_subscriber_does_not_skip_after_buffer_trim():
"""A slow subscriber should continue from the correct absolute offset."""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-slow-subscriber"
await bridge.publish(run_id, "e1", {"step": 1})
await bridge.publish(run_id, "e2", {"step": 2})
stream = bridge._streams[run_id]
e1_id = stream.events[0].id
assert stream.start_offset == 0
await bridge.publish(run_id, "e3", {"step": 3}) # trims e1
assert stream.start_offset == 1
assert [entry.event for entry in stream.events] == ["e2", "e3"]
resumed_after_e1 = []
async for entry in bridge.subscribe(
run_id,
last_event_id=e1_id,
heartbeat_interval=1.0,
):
resumed_after_e1.append(entry)
if len(resumed_after_e1) == 2:
break
assert [entry.event for entry in resumed_after_e1] == ["e2", "e3"]
e2_id = resumed_after_e1[0].id
await bridge.publish_end(run_id)
received = []
async for entry in bridge.subscribe(
run_id,
last_event_id=e2_id,
heartbeat_interval=1.0,
):
received.append(entry)
if entry is END_SENTINEL:
break
assert [entry.event for entry in received[:-1]] == ["e3"]
assert received[-1] is END_SENTINEL
# ---------------------------------------------------------------------------
# END sentinel guarantee tests
# Stream termination tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_end_sentinel_delivered_when_queue_full():
"""END sentinel must always be delivered, even when the queue is completely full.
This is the critical regression test for the bug where publish_end()
would silently drop the END sentinel when the queue was full, causing
subscribe() to hang forever and leaking resources.
"""
async def test_publish_end_terminates_even_when_history_is_full():
"""publish_end() should terminate subscribers without mutating retained history."""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-end-full"
run_id = "run-end-history-full"
# Fill the queue to capacity
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
assert bridge._queues[run_id].full()
stream = bridge._streams[run_id]
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
# publish_end should succeed by evicting old events
await bridge.publish_end(run_id)
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
# Subscriber must receive END_SENTINEL
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
assert [entry.event for entry in events[:-1]] == ["event-1", "event-2"]
assert events[-1] is END_SENTINEL
@pytest.mark.anyio
async def test_end_sentinel_evicts_oldest_events():
"""When queue is full, publish_end evicts the oldest events to make room."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-evict"
# Fill queue with one event
await bridge.publish(run_id, "will-be-evicted", {})
assert bridge._queues[run_id].full()
# publish_end must succeed
async def test_publish_end_without_history_yields_end_immediately():
"""Subscribers should still receive END when a run completes without events."""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-end-empty"
await bridge.publish_end(run_id)
# The only event we should get is END_SENTINEL (the regular event was evicted)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
@ -199,8 +260,8 @@ async def test_end_sentinel_evicts_oldest_events():
@pytest.mark.anyio
async def test_end_sentinel_no_eviction_when_space_available():
"""When queue has space, publish_end should not evict anything."""
async def test_publish_end_preserves_history_when_space_available():
"""When history has spare capacity, publish_end should preserve prior events."""
bridge = MemoryStreamBridge(queue_maxsize=10)
run_id = "run-no-evict"
@ -244,87 +305,23 @@ async def test_concurrent_tasks_end_sentinel():
return events
return events # pragma: no cover
# Run producers and consumers concurrently
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
producers = [producer(rid) for rid in run_ids]
consumers = [consumer(rid) for rid in run_ids]
results: dict[str, list] = {}
# Start consumers first, then producers
consumer_tasks = [asyncio.create_task(c) for c in consumers]
await asyncio.gather(*producers)
async def consume_into(run_id: str) -> None:
results[run_id] = await consumer(run_id)
results = await asyncio.wait_for(
asyncio.gather(*consumer_tasks),
timeout=10.0,
)
with anyio.fail_after(10):
async with anyio.create_task_group() as task_group:
for run_id in run_ids:
task_group.start_soon(consume_into, run_id)
await anyio.sleep(0)
for run_id in run_ids:
task_group.start_soon(producer, run_id)
for i, events in enumerate(results):
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
# ---------------------------------------------------------------------------
# Drop counter tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_dropped_count_tracking():
"""Dropped events should be tracked per run_id."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-drop-count"
# Fill the queue
await bridge.publish(run_id, "first", {})
# This publish will time out and be dropped (we patch timeout to be instant)
# Instead, we verify the counter after publish_end eviction
await bridge.publish_end(run_id)
# dropped_count tracks publish() drops, not publish_end evictions
assert bridge.dropped_count(run_id) == 0
# cleanup should also clear the counter
await bridge.cleanup(run_id)
assert bridge.dropped_count(run_id) == 0
@pytest.mark.anyio
async def test_dropped_total():
"""dropped_total should sum across all runs."""
bridge = MemoryStreamBridge(queue_maxsize=256)
# No drops yet
assert bridge.dropped_total == 0
# Manually set some counts to verify the property
bridge._dropped_counts["run-a"] = 3
bridge._dropped_counts["run-b"] = 7
assert bridge.dropped_total == 10
@pytest.mark.anyio
async def test_cleanup_clears_dropped_counts():
"""cleanup() should clear the dropped counter for the run."""
bridge = MemoryStreamBridge(queue_maxsize=256)
run_id = "run-cleanup-drops"
bridge._get_or_create_queue(run_id)
bridge._dropped_counts[run_id] = 5
await bridge.cleanup(run_id)
assert run_id not in bridge._dropped_counts
@pytest.mark.anyio
async def test_close_clears_dropped_counts():
"""close() should clear all dropped counters."""
bridge = MemoryStreamBridge(queue_maxsize=256)
bridge._dropped_counts["run-x"] = 10
bridge._dropped_counts["run-y"] = 20
await bridge.close()
assert bridge.dropped_total == 0
assert len(bridge._dropped_counts) == 0
for run_id in run_ids:
events = results[run_id]
assert events[-1] is END_SENTINEL, f"Run {run_id} did not receive END sentinel"
# ---------------------------------------------------------------------------

View File

@ -36,6 +36,81 @@ type SendMessageOptions = {
additionalKwargs?: Record<string, unknown>;
};
function normalizeStoredRunId(runId: string | null): string | null {
if (!runId) {
return null;
}
const trimmed = runId.trim();
if (!trimmed) {
return null;
}
const queryIndex = trimmed.indexOf("?");
if (queryIndex >= 0) {
const params = new URLSearchParams(trimmed.slice(queryIndex + 1));
const queryRunId = params.get("run_id")?.trim();
if (queryRunId) {
return queryRunId;
}
}
const pathWithoutQueryOrHash = trimmed.split(/[?#]/, 1)[0]?.trim() ?? "";
if (!pathWithoutQueryOrHash) {
return null;
}
const runsMarker = "/runs/";
const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker);
if (runsIndex >= 0) {
const runIdAfterMarker = pathWithoutQueryOrHash
.slice(runsIndex + runsMarker.length)
.split("/", 1)[0]
?.trim();
if (runIdAfterMarker) {
return runIdAfterMarker;
}
return null;
}
const segments = pathWithoutQueryOrHash
.split("/")
.map((segment) => segment.trim())
.filter(Boolean);
return segments.at(-1) ?? null;
}
function getRunMetadataStorage(): {
getItem(key: `lg:stream:${string}`): string | null;
setItem(key: `lg:stream:${string}`, value: string): void;
removeItem(key: `lg:stream:${string}`): void;
} {
return {
getItem(key) {
const normalized = normalizeStoredRunId(
window.sessionStorage.getItem(key),
);
if (normalized) {
window.sessionStorage.setItem(key, normalized);
return normalized;
}
window.sessionStorage.removeItem(key);
return null;
},
setItem(key, value) {
const normalized = normalizeStoredRunId(value);
if (normalized) {
window.sessionStorage.setItem(key, normalized);
return;
}
window.sessionStorage.removeItem(key);
},
removeItem(key) {
window.sessionStorage.removeItem(key);
},
};
}
function getStreamErrorMessage(error: unknown): string {
if (typeof error === "string" && error.trim()) {
return error;
@ -113,12 +188,24 @@ export function useThreadStream({
const queryClient = useQueryClient();
const updateSubtask = useUpdateSubtask();
const runMetadataStorageRef = useRef<
ReturnType<typeof getRunMetadataStorage> | undefined
>(undefined);
if (
typeof window !== "undefined" &&
runMetadataStorageRef.current === undefined
) {
runMetadataStorageRef.current = getRunMetadataStorage();
}
const thread = useStream<AgentThreadState>({
client: getAPIClient(isMock),
assistantId: "lead_agent",
threadId: onStreamThreadId,
reconnectOnMount: true,
reconnectOnMount: runMetadataStorageRef.current
? () => runMetadataStorageRef.current!
: false,
fetchStateHistory: { limit: 1 },
onCreated(meta) {
handleStreamStart(meta.thread_id);