116 lines
4.1 KiB
Python
116 lines
4.1 KiB
Python
"""Tool error handling middleware and shared runtime middleware builders."""
|
|
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.errors import GraphBubbleUp
|
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
|
from langgraph.types import Command
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Placeholder tool_call_id put on the error ToolMessage when the failing
# tool call has no (or an empty) "id".
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
|
|
|
|
|
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
    """Turn exceptions raised by tool handlers into error-status ToolMessages.

    Both the sync and async tool-call wrappers delegate to the wrapped
    handler. ``GraphBubbleUp`` is re-raised untouched; any other
    ``Exception`` is logged with its traceback and replaced by a
    ``ToolMessage`` with ``status="error"``.
    """

    def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
        """Build the error ToolMessage describing a failed tool call.

        Args:
            request: The tool call request whose handler raised.
            exc: The exception raised by the handler.

        Returns:
            A ``ToolMessage`` with ``status="error"`` carrying the tool name,
            the tool call id, and a description of the failure.
        """
        call = request.tool_call
        # Fall back to placeholders when the call has no (or an empty) name/id.
        name = str(call.get("name") or "unknown_tool")
        call_id = str(call.get("id") or _MISSING_TOOL_CALL_ID)

        # Use the exception text; if it is blank, use the exception class name.
        summary = str(exc).strip()
        if not summary:
            summary = exc.__class__.__name__
        # Cap the detail at 500 characters, ellipsis included.
        if len(summary) > 500:
            summary = f"{summary[:497]}..."

        text = (
            f"Error: Tool '{name}' failed with {exc.__class__.__name__}: {summary}. "
            "Continue with available context, or choose an alternative tool."
        )
        return ToolMessage(content=text, tool_call_id=call_id, name=name, status="error")

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Run ``handler`` synchronously, converting failures to an error message.

        Args:
            request: The tool call request passed through to ``handler``.
            handler: Callable that executes the tool call.

        Returns:
            The handler's result, or an error ``ToolMessage`` if it raised.

        Raises:
            GraphBubbleUp: Re-raised unchanged.
        """
        try:
            outcome = handler(request)
        except GraphBubbleUp:
            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
            raise
        except Exception as exc:
            logger.exception(
                "Tool execution failed (sync): name=%s id=%s",
                request.tool_call.get("name"),
                request.tool_call.get("id"),
            )
            return self._build_error_message(request, exc)
        else:
            return outcome

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Await ``handler``, converting failures to an error message.

        Args:
            request: The tool call request passed through to ``handler``.
            handler: Async callable that executes the tool call.

        Returns:
            The handler's result, or an error ``ToolMessage`` if it raised.

        Raises:
            GraphBubbleUp: Re-raised unchanged.
        """
        try:
            outcome = await handler(request)
        except GraphBubbleUp:
            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
            raise
        except Exception as exc:
            logger.exception(
                "Tool execution failed (async): name=%s id=%s",
                request.tool_call.get("name"),
                request.tool_call.get("id"),
            )
            return self._build_error_message(request, exc)
        else:
            return outcome
|
|
|
|
|
|
def _build_runtime_middlewares(
    *,
    include_uploads: bool,
    include_dangling_tool_call_patch: bool,
    lazy_init: bool = True,
) -> list[AgentMiddleware]:
    """Assemble the base middleware list used for agent execution.

    The resulting order is: thread data, uploads (optional), sandbox,
    dangling tool call patch (optional), tool error handling.

    Args:
        include_uploads: Insert ``UploadsMiddleware`` right after the thread
            data middleware.
        include_dangling_tool_call_patch: Append ``DanglingToolCallMiddleware``
            before the tool error handling middleware.
        lazy_init: Forwarded to ``ThreadDataMiddleware`` and
            ``SandboxMiddleware``.

    Returns:
        The ordered list of middleware instances.
    """
    # Imports are kept function-local, as in the rest of this builder.
    from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
    from src.sandbox.middleware import SandboxMiddleware

    thread_data = ThreadDataMiddleware(lazy_init=lazy_init)
    sandbox = SandboxMiddleware(lazy_init=lazy_init)
    stack: list[AgentMiddleware] = [thread_data, sandbox]

    if include_uploads:
        from src.agents.middlewares.uploads_middleware import UploadsMiddleware

        # Slot uploads in at index 1: after thread data, ahead of the sandbox.
        stack[1:1] = [UploadsMiddleware()]

    if include_dangling_tool_call_patch:
        from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware

        stack.append(DanglingToolCallMiddleware())

    # Tool error handling always comes last in the base list.
    stack.append(ToolErrorHandlingMiddleware())
    return stack
|
|
|
|
|
|
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
    """Return the base middlewares for the lead agent runtime.

    Enables both the uploads middleware and the dangling tool call patch.

    Args:
        lazy_init: Passed through to ``_build_runtime_middlewares``.

    Returns:
        The list built by ``_build_runtime_middlewares``.
    """
    lead_middlewares = _build_runtime_middlewares(
        lazy_init=lazy_init,
        include_uploads=True,
        include_dangling_tool_call_patch=True,
    )
    return lead_middlewares
|
|
|
|
|
|
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
    """Return the base middlewares for the subagent runtime.

    Leaves out both the uploads middleware and the dangling tool call patch.

    Args:
        lazy_init: Passed through to ``_build_runtime_middlewares``.

    Returns:
        The list built by ``_build_runtime_middlewares``.
    """
    subagent_middlewares = _build_runtime_middlewares(
        lazy_init=lazy_init,
        include_uploads=False,
        include_dangling_tool_call_patch=False,
    )
    return subagent_middlewares
|