fix: expose custom events from DeerFlowClient.stream() (#1827)
* fix: expose custom client stream events Signed-off-by: suyua9 <1521777066@qq.com> * fix(client): normalize streamed custom mode values * test(client): satisfy backend ruff import ordering --------- Signed-off-by: suyua9 <1521777066@qq.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
ed90a2ee9d
commit
29575c32f9
|
|
@ -345,6 +345,7 @@ class DeerFlowClient:
|
||||||
Yields:
|
Yields:
|
||||||
StreamEvent with one of:
|
StreamEvent with one of:
|
||||||
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
||||||
|
- type="custom" data={...}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}}
|
- type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||||
|
|
@ -365,7 +366,22 @@ class DeerFlowClient:
|
||||||
seen_ids: set[str] = set()
|
seen_ids: set[str] = set()
|
||||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||||
|
|
||||||
for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"):
|
for item in self._agent.stream(
|
||||||
|
state,
|
||||||
|
config=config,
|
||||||
|
context=context,
|
||||||
|
stream_mode=["values", "custom"],
|
||||||
|
):
|
||||||
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
|
mode, chunk = item
|
||||||
|
mode = str(mode)
|
||||||
|
else:
|
||||||
|
mode, chunk = "values", item
|
||||||
|
|
||||||
|
if mode == "custom":
|
||||||
|
yield StreamEvent(type="custom", data=chunk)
|
||||||
|
continue
|
||||||
|
|
||||||
messages = chunk.get("messages", [])
|
messages = chunk.get("messages", [])
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
@ -205,6 +206,33 @@ class TestStream:
|
||||||
msg_events = _ai_events(events)
|
msg_events = _ai_events(events)
|
||||||
assert msg_events[0].data["content"] == "Hello!"
|
assert msg_events[0].data["content"] == "Hello!"
|
||||||
|
|
||||||
|
def test_custom_events_are_forwarded(self, client):
|
||||||
|
"""stream() forwards custom stream events alongside normal values output."""
|
||||||
|
ai = AIMessage(content="Hello!", id="ai-1")
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.stream.return_value = iter(
|
||||||
|
[
|
||||||
|
("custom", {"type": "task_started", "task_id": "task-1"}),
|
||||||
|
("values", {"messages": [HumanMessage(content="hi", id="h-1"), ai]}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_ensure_agent"),
|
||||||
|
patch.object(client, "_agent", agent),
|
||||||
|
):
|
||||||
|
events = list(client.stream("hi", thread_id="t-custom"))
|
||||||
|
|
||||||
|
agent.stream.assert_called_once()
|
||||||
|
call_kwargs = agent.stream.call_args.kwargs
|
||||||
|
assert call_kwargs["stream_mode"] == ["values", "custom"]
|
||||||
|
|
||||||
|
assert events[0].type == "custom"
|
||||||
|
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
||||||
|
assert any(event.type == "messages-tuple" and event.data["content"] == "Hello!" for event in events)
|
||||||
|
assert any(event.type == "values" for event in events)
|
||||||
|
assert events[-1].type == "end"
|
||||||
|
|
||||||
def test_context_propagation(self, client):
|
def test_context_propagation(self, client):
|
||||||
"""stream() passes agent_name to the context."""
|
"""stream() passes agent_name to the context."""
|
||||||
agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])
|
agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])
|
||||||
|
|
@ -222,6 +250,33 @@ class TestStream:
|
||||||
assert call_kwargs["context"]["thread_id"] == "t1"
|
assert call_kwargs["context"]["thread_id"] == "t1"
|
||||||
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
||||||
|
|
||||||
|
def test_custom_mode_is_normalized_to_string(self, client):
|
||||||
|
"""stream() forwards custom events even when the mode is not a plain string."""
|
||||||
|
|
||||||
|
class StreamMode(Enum):
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
agent = _make_agent_mock(
|
||||||
|
[
|
||||||
|
(StreamMode.CUSTOM, {"type": "task_started", "task_id": "task-1"}),
|
||||||
|
{"messages": [AIMessage(content="Hello!", id="ai-1")]},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_ensure_agent"),
|
||||||
|
patch.object(client, "_agent", agent),
|
||||||
|
):
|
||||||
|
events = list(client.stream("hi", thread_id="t-custom-enum"))
|
||||||
|
|
||||||
|
assert events[0].type == "custom"
|
||||||
|
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
||||||
|
assert any(event.type == "messages-tuple" and event.data["content"] == "Hello!" for event in events)
|
||||||
|
assert events[-1].type == "end"
|
||||||
|
|
||||||
def test_tool_call_and_result(self, client):
|
def test_tool_call_and_result(self, client):
|
||||||
"""stream() emits messages-tuple events for tool calls and results."""
|
"""stream() emits messages-tuple events for tool calls and results."""
|
||||||
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
|
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue