deerflow2/src/agents/tool_interceptor.py

206 lines
7.0 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import logging
from typing import Any, Callable, List, Optional
from langchain_core.tools import BaseTool
from langgraph.types import interrupt
logger = logging.getLogger(__name__)
class ToolInterceptor:
    """Intercepts tool calls and triggers interrupts for specified tools.

    A tool whose name is listed in ``interrupt_before_tools`` is paused via
    ``interrupt(...)`` before it runs, and only executes when the feedback
    returned by the interrupt is parsed as an approval.
    """

    # Whole words (compared case-insensitively) that count as approval.
    # Punctuation is stripped before matching, so "[approved]" is covered by
    # "approved".
    _APPROVAL_KEYWORDS = frozenset(
        {
            "approved",
            "approve",
            "yes",
            "proceed",
            "continue",
            "ok",
            "okay",
            "accepted",
            "accept",
        }
    )

    def __init__(self, interrupt_before_tools: Optional[List[str]] = None):
        """Initialize the interceptor with list of tools to interrupt before.

        Args:
            interrupt_before_tools: List of tool names to interrupt before
                execution. If None or empty, no interrupts are triggered.
        """
        self.interrupt_before_tools = interrupt_before_tools or []
        logger.info(
            f"ToolInterceptor initialized with interrupt_before_tools: {self.interrupt_before_tools}"
        )

    def should_interrupt(self, tool_name: str) -> bool:
        """Check if execution should be interrupted before this tool.

        Args:
            tool_name: Name of the tool being called.

        Returns:
            bool: True if the tool should trigger an interrupt, False otherwise.
        """
        should_interrupt = tool_name in self.interrupt_before_tools
        if should_interrupt:
            logger.info(f"Tool '{tool_name}' marked for interrupt")
        return should_interrupt

    @staticmethod
    def _format_tool_input(tool_input: Any) -> str:
        """Format tool input for display in interrupt messages.

        Containers are rendered as indented JSON for readability; strings are
        returned unchanged; everything else falls back to ``str()``.

        Args:
            tool_input: The tool input to format.

        Returns:
            str: Formatted representation of the tool input.
        """
        if tool_input is None:
            return "No input"
        if isinstance(tool_input, str):
            return tool_input
        if isinstance(tool_input, (dict, list, tuple)):
            try:
                # ensure_ascii=False keeps non-ASCII text (e.g. CJK queries)
                # human-readable instead of emitting \uXXXX escapes.
                # default=str stringifies values JSON cannot encode natively.
                return json.dumps(
                    tool_input, indent=2, default=str, ensure_ascii=False
                )
            except (TypeError, ValueError):
                # Unsupported dict keys (TypeError) or circular references
                # (ValueError): fall back to the plain representation.
                return str(tool_input)
        return str(tool_input)

    @staticmethod
    def wrap_tool(
        tool: "BaseTool", interceptor: "ToolInterceptor"
    ) -> "BaseTool":
        """Wrap a tool to add interrupt logic by replacing its ``func``.

        The tool is modified in place and the same object is returned.

        Args:
            tool: The tool to wrap. Must expose a callable ``func`` attribute.
            interceptor: The ToolInterceptor instance deciding which tools
                require approval.

        Returns:
            BaseTool: The same tool, now with interrupt capability.

        Raises:
            AttributeError: If the tool has no ``func`` attribute.
            TypeError: If the tool's ``func`` is not callable.
        """
        original_func = tool.func
        if not callable(original_func):
            # Fail at wrap time rather than after the user has already
            # approved the call and the wrapper tries to invoke a non-callable.
            raise TypeError(
                f"Tool '{tool.name}' has no callable 'func' to wrap "
                f"(got {type(original_func).__name__})"
            )
        # NOTE(review): only the synchronous ``func`` is replaced here. If the
        # tool also exposes an async implementation, that path is presumably
        # not gated by this wrapper -- confirm against how tools are invoked.

        def intercepted_func(*args: Any, **kwargs: Any) -> Any:
            """Execute the tool with interrupt check."""
            tool_name = tool.name

            if interceptor.should_interrupt(tool_name):
                # Only format the input when it is actually going to be shown.
                tool_input = args[0] if args else kwargs
                tool_input_repr = ToolInterceptor._format_tool_input(tool_input)
                logger.info(
                    f"Interrupting before tool '{tool_name}' with input: {tool_input_repr}"
                )
                # Trigger interrupt and wait for user feedback. This stays
                # outside the try/except below so the interrupt itself is
                # never logged as a tool execution error.
                feedback = interrupt(
                    f"About to execute tool: '{tool_name}'\n\nInput:\n{tool_input_repr}\n\nApprove execution?"
                )
                logger.info(f"Interrupt feedback for '{tool_name}': {feedback}")

                if not ToolInterceptor._parse_approval(feedback):
                    logger.warning(f"User rejected execution of tool '{tool_name}'")
                    return {
                        "error": "Tool execution rejected by user",
                        "tool": tool_name,
                        "status": "rejected",
                    }
                logger.info(f"User approved execution of tool '{tool_name}'")

            # Execute the original tool
            try:
                result = original_func(*args, **kwargs)
            except Exception as e:
                logger.error(f"Error executing tool '{tool_name}': {e}")
                raise
            logger.debug(f"Tool '{tool_name}' execution completed")
            return result

        # Replace the function and update the tool.
        # Use object.__setattr__ to bypass Pydantic validation.
        object.__setattr__(tool, "func", intercepted_func)
        return tool

    @staticmethod
    def _has_approval_keyword(text: str) -> bool:
        """Return True if ``text`` contains an approval keyword as a whole word.

        Matching is done on whole alphabetic words rather than substrings, so
        "broken" or "token" do not count as "ok" and "unapproved" does not
        count as "approved". Negation is NOT analysed: "do not proceed" still
        contains the word "proceed".

        Args:
            text: The feedback text to scan.

        Returns:
            bool: True if any approval keyword appears as a standalone word.
        """
        # Turn every non-letter into a separator, then split into words.
        normalized = "".join(ch if ch.isalpha() else " " for ch in text.lower())
        return not ToolInterceptor._APPROVAL_KEYWORDS.isdisjoint(normalized.split())

    @staticmethod
    def _parse_approval(feedback: Any) -> bool:
        """Parse user feedback to determine if tool execution was approved.

        Args:
            feedback: The value returned by the interrupt. Normally a string;
                ``True`` is accepted as approval, and any other non-string
                value is treated as a rejection.

        Returns:
            bool: True if feedback indicates approval, False otherwise.
        """
        if not feedback:
            logger.warning("Empty feedback received, treating as rejection")
            return False
        if isinstance(feedback, bool):
            # Only True can reach this point; False was handled above.
            return True
        if not isinstance(feedback, str):
            # Fail closed: do not guess at the meaning of structured feedback.
            logger.warning(
                f"Unsupported feedback type {type(feedback).__name__}. Treating as rejection."
            )
            return False

        if ToolInterceptor._has_approval_keyword(feedback):
            return True

        # Default to rejection if no approval keywords found
        logger.warning(
            f"No approval keywords found in feedback: {feedback}. Treating as rejection."
        )
        return False
def wrap_tools_with_interceptor(
    tools: List[BaseTool], interrupt_before_tools: Optional[List[str]] = None
) -> List[BaseTool]:
    """Wrap multiple tools with interrupt logic.

    Every tool is passed through ``ToolInterceptor.wrap_tool`` with a single
    shared interceptor. Wrapping is best-effort: a tool that cannot be wrapped
    is returned unchanged so the tool list stays complete.

    Args:
        tools: List of tools to wrap.
        interrupt_before_tools: List of tool names to interrupt before. If
            None or empty, ``tools`` is returned as-is without wrapping.

    Returns:
        List[BaseTool]: The tools in their original order, wrapped where
        wrapping succeeded.
    """
    if not interrupt_before_tools:
        logger.debug("No tool interrupts configured, returning tools as-is")
        return tools

    logger.info(
        f"Wrapping {len(tools)} tools with interrupt logic for: {interrupt_before_tools}"
    )
    interceptor = ToolInterceptor(interrupt_before_tools)

    wrapped_tools = []
    failed_count = 0
    for tool in tools:
        # Resolve the name once, defensively, so the error path below cannot
        # itself fail on an object that has no ``name``.
        tool_name = getattr(tool, "name", repr(tool))
        try:
            wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)
        except Exception as e:
            failed_count += 1
            logger.error(f"Failed to wrap tool {tool_name}: {str(e)}")
            if tool_name in interrupt_before_tools:
                # The fallback below means this tool runs with no approval
                # step even though one was requested -- make that visible.
                logger.warning(
                    f"Tool '{tool_name}' is configured for interrupt but could not be "
                    "wrapped; it will execute WITHOUT requesting approval"
                )
            # Add original tool if wrapping fails
            wrapped_tools.append(tool)
        else:
            # Kept out of the try body so a logging problem can never cause
            # the same tool to be appended a second time by the handler.
            wrapped_tools.append(wrapped_tool)
            logger.debug(f"Wrapped tool: {tool_name}")

    # Report only the tools that were actually wrapped.
    logger.info(f"Successfully wrapped {len(wrapped_tools) - failed_count} tools")
    if failed_count:
        logger.warning(
            f"{failed_count} of {len(wrapped_tools)} tools could not be wrapped and were left unchanged"
        )
    return wrapped_tools