refactor: refine the background check logic (#306)

This commit is contained in:
DanielWalnut 2025-06-11 11:10:02 +08:00 committed by GitHub
parent 446901ec0b
commit 310f08076a
4 changed files with 19 additions and 10 deletions

View File

@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
@tool @tool
def handoff_to_planner( def handoff_to_planner(
task_title: Annotated[str, "The title of the task to be handed off."], research_topic: Annotated[str, "The topic of the research task to be handed off."],
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."], locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
): ):
"""Handoff to planner agent to do plan.""" """Handoff to planner agent to do plan."""
@@ -48,7 +48,7 @@ def handoff_to_planner(
def background_investigation_node(state: State, config: RunnableConfig): def background_investigation_node(state: State, config: RunnableConfig):
logger.info("background investigation node is running.") logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content query = state.get("research_topic")
background_investigation_results = None background_investigation_results = None
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch( searched_content = LoggedTavilySearch(
@@ -87,10 +87,8 @@ def planner_node(
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
messages = apply_prompt_template("planner", state, configurable) messages = apply_prompt_template("planner", state, configurable)
if ( if state.get("enable_background_investigation") and state.get(
plan_iterations == 0 "background_investigation_results"
and state.get("enable_background_investigation")
and state.get("background_investigation_results")
): ):
messages += [ messages += [
{ {
@@ -221,6 +219,7 @@ def coordinator_node(
goto = "__end__" goto = "__end__"
locale = state.get("locale", "en-US") # Default locale if not specified locale = state.get("locale", "en-US") # Default locale if not specified
research_topic = state.get("research_topic", "")
if len(response.tool_calls) > 0: if len(response.tool_calls) > 0:
goto = "planner" goto = "planner"
@@ -231,8 +230,11 @@ def coordinator_node(
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
if tool_call.get("name", "") != "handoff_to_planner": if tool_call.get("name", "") != "handoff_to_planner":
continue continue
if tool_locale := tool_call.get("args", {}).get("locale"): if tool_call.get("args", {}).get("locale") and tool_call.get(
locale = tool_locale "args", {}
).get("research_topic"):
locale = tool_call.get("args", {}).get("locale")
research_topic = tool_call.get("args", {}).get("research_topic")
break break
except Exception as e: except Exception as e:
logger.error(f"Error processing tool calls: {e}") logger.error(f"Error processing tool calls: {e}")
@@ -243,7 +245,11 @@ def coordinator_node(
logger.debug(f"Coordinator response: {response}") logger.debug(f"Coordinator response: {response}")
return Command( return Command(
update={"locale": locale, "resources": configurable.resources}, update={
"locale": locale,
"research_topic": research_topic,
"resources": configurable.resources,
},
goto=goto, goto=goto,
) )

View File

@@ -12,6 +12,7 @@ class State(MessagesState):
# Runtime Variables # Runtime Variables
locale: str = "en-US" locale: str = "en-US"
research_topic: str = ""
observations: list[str] = [] observations: list[str] = []
resources: list[Resource] = [] resources: list[Resource] = []
plan_iterations: int = 0 plan_iterations: int = 0

View File

@@ -87,7 +87,7 @@ async def chat_stream(request: ChatRequest):
async def _astream_workflow_generator( async def _astream_workflow_generator(
messages: List[ChatMessage], messages: List[dict],
thread_id: str, thread_id: str,
resources: List[Resource], resources: List[Resource],
max_plan_iterations: int, max_plan_iterations: int,
@@ -107,6 +107,7 @@ async def _astream_workflow_generator(
"observations": [], "observations": [],
"auto_accepted_plan": auto_accepted_plan, "auto_accepted_plan": auto_accepted_plan,
"enable_background_investigation": enable_background_investigation, "enable_background_investigation": enable_background_investigation,
"research_topic": messages[-1]["content"] if messages else "",
} }
if not auto_accepted_plan and interrupt_feedback: if not auto_accepted_plan and interrupt_feedback:
resume_msg = f"[{interrupt_feedback}]" resume_msg = f"[{interrupt_feedback}]"

View File

@@ -20,6 +20,7 @@ MOCK_SEARCH_RESULTS = [
def mock_state(): def mock_state():
return { return {
"messages": [HumanMessage(content="test query")], "messages": [HumanMessage(content="test query")],
"research_topic": "test query",
"background_investigation_results": None, "background_investigation_results": None,
} }