refactor: refine teh background check logic (#306)
This commit is contained in:
parent
446901ec0b
commit
310f08076a
|
|
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def handoff_to_planner(
|
def handoff_to_planner(
|
||||||
task_title: Annotated[str, "The title of the task to be handed off."],
|
research_topic: Annotated[str, "The topic of the research task to be handed off."],
|
||||||
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
|
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
|
||||||
):
|
):
|
||||||
"""Handoff to planner agent to do plan."""
|
"""Handoff to planner agent to do plan."""
|
||||||
|
|
@ -48,7 +48,7 @@ def handoff_to_planner(
|
||||||
def background_investigation_node(state: State, config: RunnableConfig):
|
def background_investigation_node(state: State, config: RunnableConfig):
|
||||||
logger.info("background investigation node is running.")
|
logger.info("background investigation node is running.")
|
||||||
configurable = Configuration.from_runnable_config(config)
|
configurable = Configuration.from_runnable_config(config)
|
||||||
query = state["messages"][-1].content
|
query = state.get("research_topic")
|
||||||
background_investigation_results = None
|
background_investigation_results = None
|
||||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||||
searched_content = LoggedTavilySearch(
|
searched_content = LoggedTavilySearch(
|
||||||
|
|
@ -87,10 +87,8 @@ def planner_node(
|
||||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||||
messages = apply_prompt_template("planner", state, configurable)
|
messages = apply_prompt_template("planner", state, configurable)
|
||||||
|
|
||||||
if (
|
if state.get("enable_background_investigation") and state.get(
|
||||||
plan_iterations == 0
|
"background_investigation_results"
|
||||||
and state.get("enable_background_investigation")
|
|
||||||
and state.get("background_investigation_results")
|
|
||||||
):
|
):
|
||||||
messages += [
|
messages += [
|
||||||
{
|
{
|
||||||
|
|
@ -221,6 +219,7 @@ def coordinator_node(
|
||||||
|
|
||||||
goto = "__end__"
|
goto = "__end__"
|
||||||
locale = state.get("locale", "en-US") # Default locale if not specified
|
locale = state.get("locale", "en-US") # Default locale if not specified
|
||||||
|
research_topic = state.get("research_topic", "")
|
||||||
|
|
||||||
if len(response.tool_calls) > 0:
|
if len(response.tool_calls) > 0:
|
||||||
goto = "planner"
|
goto = "planner"
|
||||||
|
|
@ -231,8 +230,11 @@ def coordinator_node(
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
if tool_call.get("name", "") != "handoff_to_planner":
|
if tool_call.get("name", "") != "handoff_to_planner":
|
||||||
continue
|
continue
|
||||||
if tool_locale := tool_call.get("args", {}).get("locale"):
|
if tool_call.get("args", {}).get("locale") and tool_call.get(
|
||||||
locale = tool_locale
|
"args", {}
|
||||||
|
).get("research_topic"):
|
||||||
|
locale = tool_call.get("args", {}).get("locale")
|
||||||
|
research_topic = tool_call.get("args", {}).get("research_topic")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing tool calls: {e}")
|
logger.error(f"Error processing tool calls: {e}")
|
||||||
|
|
@ -243,7 +245,11 @@ def coordinator_node(
|
||||||
logger.debug(f"Coordinator response: {response}")
|
logger.debug(f"Coordinator response: {response}")
|
||||||
|
|
||||||
return Command(
|
return Command(
|
||||||
update={"locale": locale, "resources": configurable.resources},
|
update={
|
||||||
|
"locale": locale,
|
||||||
|
"research_topic": research_topic,
|
||||||
|
"resources": configurable.resources,
|
||||||
|
},
|
||||||
goto=goto,
|
goto=goto,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ class State(MessagesState):
|
||||||
|
|
||||||
# Runtime Variables
|
# Runtime Variables
|
||||||
locale: str = "en-US"
|
locale: str = "en-US"
|
||||||
|
research_topic: str = ""
|
||||||
observations: list[str] = []
|
observations: list[str] = []
|
||||||
resources: list[Resource] = []
|
resources: list[Resource] = []
|
||||||
plan_iterations: int = 0
|
plan_iterations: int = 0
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ async def chat_stream(request: ChatRequest):
|
||||||
|
|
||||||
|
|
||||||
async def _astream_workflow_generator(
|
async def _astream_workflow_generator(
|
||||||
messages: List[ChatMessage],
|
messages: List[dict],
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
resources: List[Resource],
|
resources: List[Resource],
|
||||||
max_plan_iterations: int,
|
max_plan_iterations: int,
|
||||||
|
|
@ -107,6 +107,7 @@ async def _astream_workflow_generator(
|
||||||
"observations": [],
|
"observations": [],
|
||||||
"auto_accepted_plan": auto_accepted_plan,
|
"auto_accepted_plan": auto_accepted_plan,
|
||||||
"enable_background_investigation": enable_background_investigation,
|
"enable_background_investigation": enable_background_investigation,
|
||||||
|
"research_topic": messages[-1]["content"] if messages else "",
|
||||||
}
|
}
|
||||||
if not auto_accepted_plan and interrupt_feedback:
|
if not auto_accepted_plan and interrupt_feedback:
|
||||||
resume_msg = f"[{interrupt_feedback}]"
|
resume_msg = f"[{interrupt_feedback}]"
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ MOCK_SEARCH_RESULTS = [
|
||||||
def mock_state():
|
def mock_state():
|
||||||
return {
|
return {
|
||||||
"messages": [HumanMessage(content="test query")],
|
"messages": [HumanMessage(content="test query")],
|
||||||
|
"research_topic": "test query",
|
||||||
"background_investigation_results": None,
|
"background_investigation_results": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue