153 lines
4.8 KiB
Python
153 lines
4.8 KiB
Python
"""Tests for the in-memory StreamBridge implementation."""
|
|
|
|
import asyncio
|
|
import re
|
|
|
|
import pytest
|
|
|
|
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests for MemoryStreamBridge
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def bridge() -> MemoryStreamBridge:
    """Provide a fresh MemoryStreamBridge constructed with queue_maxsize=256."""
    stream_bridge = MemoryStreamBridge(queue_maxsize=256)
    return stream_bridge
|
|
|
|
|
|
@pytest.mark.anyio
async def test_publish_subscribe(bridge: MemoryStreamBridge):
    """Published events come back from subscribe() in order, followed by the end sentinel."""
    run_id = "run-1"
    published = (
        ("metadata", {"run_id": run_id}),
        ("values", {"messages": []}),
        ("updates", {"step": 1}),
    )

    # Publish everything up front, then mark the run as finished.
    for event_name, payload in published:
        await bridge.publish(run_id, event_name, payload)
    await bridge.publish_end(run_id)

    received = []
    stream = bridge.subscribe(run_id, heartbeat_interval=1.0)
    async for item in stream:
        received.append(item)
        if item is END_SENTINEL:
            break

    # Three events plus the terminating sentinel.
    assert len(received) == 4
    for position, expected_name in enumerate(("metadata", "values", "updates")):
        assert received[position].event == expected_name
    assert received[3] is END_SENTINEL
|
|
|
|
|
|
@pytest.mark.anyio
async def test_heartbeat(bridge: MemoryStreamBridge):
    """A subscription with nothing published yields a heartbeat after the interval."""
    run_id = "run-heartbeat"
    # Make sure the run's queue exists without publishing anything to it.
    bridge._get_or_create_queue(run_id)

    seen = []

    async def collect_until_heartbeat():
        # Stop at the first heartbeat; anything yielded before it is recorded too.
        async for item in bridge.subscribe(run_id, heartbeat_interval=0.1):
            seen.append(item)
            if item is HEARTBEAT_SENTINEL:
                return

    # Bound the wait so a missing heartbeat fails the test instead of hanging.
    await asyncio.wait_for(collect_until_heartbeat(), timeout=2.0)

    assert len(seen) == 1
    assert seen[0] is HEARTBEAT_SENTINEL
|
|
|
|
|
|
@pytest.mark.anyio
async def test_cleanup(bridge: MemoryStreamBridge):
    """cleanup() removes the run's entries from the bridge's internal maps."""
    run_id = "run-cleanup"

    # After one publish the run must be present in the queue map.
    await bridge.publish(run_id, "test", {})
    assert run_id in bridge._queues

    await bridge.cleanup(run_id)

    # Neither the queue map nor the counter map may still reference the run.
    for registry_name in ("_queues", "_counters"):
        assert run_id not in getattr(bridge, registry_name)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_backpressure():
    """With maxsize=1, publish should not block forever.

    The first publish fills the single-slot queue. A second publish then runs
    concurrently with a task that removes the first entry shortly afterwards.
    The whole exchange is wrapped in a short timeout so that a publish which
    never returns fails this test quickly instead of stalling the suite.
    """
    bridge = MemoryStreamBridge(queue_maxsize=1)
    run_id = "run-bp"

    await bridge.publish(run_id, "first", {})

    async def publish_second():
        # Should either succeed after the queue drains or warn+drop;
        # it must not hang indefinitely.
        await bridge.publish(run_id, "second", {})

    async def drain():
        # Give the second publish a moment to start before freeing the slot.
        await asyncio.sleep(0.05)
        bridge._queues[run_id].get_nowait()

    # The bridge's own publish timeout is reportedly 30s, which is far too long
    # to wait in a test, so enforce a much tighter bound here.
    await asyncio.wait_for(
        asyncio.gather(publish_second(), drain()),
        timeout=2.0,
    )

    assert bridge._queues[run_id].qsize() == 1
|
|
|
|
|
|
@pytest.mark.anyio
async def test_multiple_runs(bridge: MemoryStreamBridge):
    """Two different run_ids should not interfere with each other.

    Each run gets one event and an end marker. Each subscription must then
    yield exactly its own event followed by END_SENTINEL.
    """

    async def collect_until_end(run_id: str) -> list:
        """Return every entry yielded for *run_id*, up to and including END_SENTINEL."""
        entries = []
        async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
            entries.append(entry)
            if entry is END_SENTINEL:
                break
        return entries

    # Interleave the publishes so the runs are populated side by side.
    await bridge.publish("run-a", "event-a", {"a": 1})
    await bridge.publish("run-b", "event-b", {"b": 2})
    await bridge.publish_end("run-a")
    await bridge.publish_end("run-b")

    events_a = await collect_until_end("run-a")
    events_b = await collect_until_end("run-b")

    assert len(events_a) == 2
    assert events_a[0].event == "event-a"
    assert events_a[0].data == {"a": 1}
    assert events_a[1] is END_SENTINEL

    assert len(events_b) == 2
    assert events_b[0].event == "event-b"
    assert events_b[0].data == {"b": 2}
    assert events_b[1] is END_SENTINEL
|
|
|
|
|
|
@pytest.mark.anyio
async def test_event_id_format(bridge: MemoryStreamBridge):
    """Event IDs should use timestamp-sequence format."""
    run_id = "run-id-format"
    await bridge.publish(run_id, "test", {"key": "value"})
    await bridge.publish_end(run_id)

    received = []
    async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
        received.append(entry)
        if entry is END_SENTINEL:
            break

    event = received[0]
    # Fail with a clear message if the stream ended before delivering the event.
    assert event is not END_SENTINEL, "Expected the published event before the end sentinel"
    # fullmatch is used because `$` in re.match would also accept an id with a
    # trailing newline; the whole id must be "<digits>-<digits>".
    assert re.fullmatch(r"\d+-\d+", event.id), f"Expected timestamp-seq format, got {event.id}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Factory tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.anyio
async def test_make_stream_bridge_defaults():
    """Calling make_stream_bridge() with no arguments provides a MemoryStreamBridge."""
    async with make_stream_bridge() as default_bridge:
        produced_memory_bridge = isinstance(default_bridge, MemoryStreamBridge)
        assert produced_memory_bridge
|