fix(middleware): fix present_files thread id fallback (#2181)
* fix present files thread id fallback * fix: resolve present_files thread id from runtime config
This commit is contained in:
parent
1df389b9d0
commit
f4c17c66ce
|
|
@ -3,6 +3,7 @@ from typing import Annotated
|
||||||
|
|
||||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.config import get_config
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from langgraph.typing import ContextT
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
|
|
@ -12,6 +13,23 @@ from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||||
|
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
||||||
|
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||||
|
if thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
runtime_config = getattr(runtime, "config", None) or {}
|
||||||
|
thread_id = runtime_config.get("configurable", {}).get("thread_id")
|
||||||
|
if thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return get_config().get("configurable", {}).get("thread_id")
|
||||||
|
except RuntimeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _normalize_presented_filepath(
|
def _normalize_presented_filepath(
|
||||||
runtime: ToolRuntime[ContextT, ThreadState],
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
filepath: str,
|
filepath: str,
|
||||||
|
|
@ -33,9 +51,9 @@ def _normalize_presented_filepath(
|
||||||
if runtime.state is None:
|
if runtime.state is None:
|
||||||
raise ValueError("Thread runtime state is not available")
|
raise ValueError("Thread runtime state is not available")
|
||||||
|
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
thread_id = _get_thread_id(runtime)
|
||||||
if not thread_id:
|
if not thread_id:
|
||||||
raise ValueError("Thread ID is not available in runtime context")
|
raise ValueError("Thread ID is not available in runtime context or runtime config")
|
||||||
|
|
||||||
thread_data = runtime.state.get("thread_data") or {}
|
thread_data = runtime.state.get("thread_data") or {}
|
||||||
outputs_path = thread_data.get("outputs_path")
|
outputs_path = thread_data.get("outputs_path")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ def _make_runtime(outputs_path: str) -> SimpleNamespace:
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
state={"thread_data": {"outputs_path": outputs_path}},
|
state={"thread_data": {"outputs_path": outputs_path}},
|
||||||
context={"thread_id": "thread-1"},
|
context={"thread_id": "thread-1"},
|
||||||
|
config={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -50,6 +51,34 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||||
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_present_files_uses_config_thread_id_when_context_missing(tmp_path, monkeypatch):
|
||||||
|
outputs_dir = tmp_path / "threads" / "thread-from-config" / "user-data" / "outputs"
|
||||||
|
outputs_dir.mkdir(parents=True)
|
||||||
|
artifact_path = outputs_dir / "summary.json"
|
||||||
|
artifact_path.write_text("{}")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
present_file_tool_module,
|
||||||
|
"get_paths",
|
||||||
|
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
state={"thread_data": {"outputs_path": str(outputs_dir)}},
|
||||||
|
context={},
|
||||||
|
config={"configurable": {"thread_id": "thread-from-config"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = present_file_tool_module.present_file_tool.func(
|
||||||
|
runtime=runtime,
|
||||||
|
filepaths=["/mnt/user-data/outputs/summary.json"],
|
||||||
|
tool_call_id="tc-config",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.update["artifacts"] == ["/mnt/user-data/outputs/summary.json"]
|
||||||
|
assert result.update["messages"][0].content == "Successfully presented files"
|
||||||
|
|
||||||
|
|
||||||
def test_present_files_rejects_paths_outside_outputs(tmp_path):
|
def test_present_files_rejects_paths_outside_outputs(tmp_path):
|
||||||
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
|
outputs_dir = tmp_path / "threads" / "thread-1" / "user-data" / "outputs"
|
||||||
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
|
workspace_dir = tmp_path / "threads" / "thread-1" / "user-data" / "workspace"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue