fix: the validation Error with qwen-max-latest Model (#706)

* fix: the validation Error with qwen-max-latest Model

    - Added comprehensive unit tests in tests/unit/graph/test_nodes.py for the new extract_plan_content function
    - Tests cover various input types: string, AIMessage, dictionary, other types
    - Includes a specific test case for issue #703 with the qwen-max-latest model
    - All tests pass successfully, confirming the function handles different input types correctly

* feat: address the code review concerns
This commit is contained in:
Willem Jiang 2025-11-24 21:13:15 +08:00 committed by GitHub
parent baf66cd3c7
commit af2547b089
2 changed files with 187 additions and 4 deletions

View File

@ -5,7 +5,7 @@ import json
import logging import logging
import os import os
from functools import partial from functools import partial
from typing import Annotated, Literal from typing import Any, Annotated, Literal
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@ -361,6 +361,34 @@ def planner_node(
) )
def extract_plan_content(plan_data: str | dict | Any) -> str:
"""
Safely extract plan content from different types of plan data.
Args:
plan_data: The plan data which can be a string, AIMessage, or dict
Returns:
str: The plan content as a string (JSON string for dict inputs, or
extracted/original string for other types)
"""
if isinstance(plan_data, str):
# If it's already a string, return as is
return plan_data
elif hasattr(plan_data, 'content') and isinstance(plan_data.content, str):
# If it's an AIMessage or similar object with a content attribute
logger.debug(f"Extracting plan content from message object of type {type(plan_data).__name__}")
return plan_data.content
elif isinstance(plan_data, dict):
# If it's already a dictionary, convert to JSON string
logger.debug("Converting plan dictionary to JSON string")
return json.dumps(plan_data)
else:
# For any other type, try to convert to string
logger.warning(f"Unexpected plan data type {type(plan_data).__name__}, attempting to convert to string")
return str(plan_data)
def human_feedback_node( def human_feedback_node(
state: State, config: RunnableConfig state: State, config: RunnableConfig
) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]: ) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]:
@ -406,7 +434,13 @@ def human_feedback_node(
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
goto = "research_team" goto = "research_team"
try: try:
current_plan = repair_json_output(current_plan) # Safely extract plan content from different types (string, AIMessage, dict)
original_plan = current_plan
current_plan_content = extract_plan_content(current_plan)
logger.debug(f"Extracted plan content type: {type(current_plan_content).__name__}")
# Repair the JSON output
current_plan = repair_json_output(current_plan_content)
# increment the plan iterations # increment the plan iterations
plan_iterations += 1 plan_iterations += 1
# parse the plan # parse the plan
@ -414,8 +448,10 @@ def human_feedback_node(
# Validate and fix plan to ensure web search requirements are met # Validate and fix plan to ensure web search requirements are met
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search) new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search)
except json.JSONDecodeError: except (json.JSONDecodeError, AttributeError) as e:
logger.warning("Planner response is not a valid JSON") logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}")
if isinstance(current_plan, dict) and "content" in original_plan:
logger.warning(f"Plan appears to be an AIMessage object with content field")
if plan_iterations > 1: # the plan_iterations is increased before this check if plan_iterations > 1: # the plan_iterations is increased before this check
return Command( return Command(
update=preserve_state_meta_fields(state), update=preserve_state_meta_fields(state),

View File

@ -12,8 +12,155 @@ from src.graph.nodes import (
planner_node, planner_node,
reporter_node, reporter_node,
researcher_node, researcher_node,
extract_plan_content,
) )
class TestExtractPlanContent:
"""Test cases for the extract_plan_content function."""
def test_extract_plan_content_with_string(self):
"""Test that extract_plan_content returns the input string as-is."""
plan_json_str = '{"locale": "en-US", "has_enough_context": false, "title": "Test Plan"}'
result = extract_plan_content(plan_json_str)
assert result == plan_json_str
def test_extract_plan_content_with_ai_message(self):
"""Test that extract_plan_content extracts content from an AIMessage-like object."""
# Create a mock AIMessage object
class MockAIMessage:
def __init__(self, content):
self.content = content
plan_content = '{"locale": "zh-CN", "has_enough_context": false, "title": "测试计划"}'
plan_message = MockAIMessage(plan_content)
result = extract_plan_content(plan_message)
assert result == plan_content
def test_extract_plan_content_with_dict(self):
"""Test that extract_plan_content converts a dictionary to JSON string."""
plan_dict = {
"locale": "fr-FR",
"has_enough_context": True,
"title": "Plan de test",
"steps": []
}
expected_json = json.dumps(plan_dict)
result = extract_plan_content(plan_dict)
assert result == expected_json
def test_extract_plan_content_with_other_type(self):
"""Test that extract_plan_content converts other types to string."""
plan_value = 12345
expected_string = "12345"
result = extract_plan_content(plan_value)
assert result == expected_string
def test_extract_plan_content_with_complex_dict(self):
"""Test that extract_plan_content handles complex nested dictionaries."""
plan_dict = {
"locale": "zh-CN",
"has_enough_context": False,
"title": "埃菲尔铁塔与世界最高建筑高度比较研究计划",
"thought": "要回答埃菲尔铁塔比世界最高建筑高多少倍的问题,我们需要知道埃菲尔铁塔的高度以及当前世界最高建筑的高度。",
"steps": [
{
"need_search": True,
"title": "收集埃菲尔铁塔和世界最高建筑的高度数据",
"description": "从可靠来源检索埃菲尔铁塔的确切高度以及目前被公认为世界最高建筑的建筑物及其高度数据。",
"step_type": "research"
},
{
"need_search": True,
"title": "查找其他超高建筑作为对比基准",
"description": "获取其他具有代表性的超高建筑的高度数据,以提供更全面的比较背景。",
"step_type": "research"
}
]
}
result = extract_plan_content(plan_dict)
# Verify the result can be parsed back to a dictionary
parsed_result = json.loads(result)
assert parsed_result == plan_dict
def test_extract_plan_content_with_non_string_content(self):
"""Test that extract_plan_content handles AIMessage with non-string content."""
class MockAIMessageWithNonStringContent:
def __init__(self, content):
self.content = content
# Test with non-string content (should not be extracted)
plan_content = 12345
plan_message = MockAIMessageWithNonStringContent(plan_content)
result = extract_plan_content(plan_message)
# Should convert the entire object to string since content is not a string
assert isinstance(result, str)
assert "MockAIMessageWithNonStringContent" in result
def test_extract_plan_content_with_empty_string(self):
"""Test that extract_plan_content handles empty strings."""
empty_string = ""
result = extract_plan_content(empty_string)
assert result == ""
def test_extract_plan_content_with_empty_dict(self):
"""Test that extract_plan_content handles empty dictionaries."""
empty_dict = {}
expected_json = "{}"
result = extract_plan_content(empty_dict)
assert result == expected_json
def test_extract_plan_content_issue_703_case(self):
"""Test that extract_plan_content handles the specific case from issue #703."""
# This is the exact structure that was causing the error in issue #703
class MockAIMessageFromIssue703:
def __init__(self, content):
self.content = content
self.additional_kwargs = {}
self.response_metadata = {'finish_reason': 'stop', 'model_name': 'qwen-max-latest'}
self.type = 'ai'
self.id = 'run--ebc626af-3845-472b-aeee-acddebf5a4ea'
self.example = False
self.tool_calls = []
self.invalid_tool_calls = []
plan_content = '''{
"locale": "zh-CN",
"has_enough_context": false,
"thought": "要回答埃菲尔铁塔比世界最高建筑高多少倍的问题,我们需要知道埃菲尔铁塔的高度以及当前世界最高建筑的高度。",
"title": "埃菲尔铁塔与世界最高建筑高度比较研究计划",
"steps": [
{
"need_search": true,
"title": "收集埃菲尔铁塔和世界最高建筑的高度数据",
"description": "从可靠来源检索埃菲尔铁塔的确切高度以及目前被公认为世界最高建筑的建筑物及其高度数据。",
"step_type": "research"
}
]
}'''
plan_message = MockAIMessageFromIssue703(plan_content)
# Extract the content
result = extract_plan_content(plan_message)
# Verify the extracted content is the same as the original
assert result == plan_content
# Verify the extracted content can be parsed as JSON
parsed_result = json.loads(result)
assert parsed_result["locale"] == "zh-CN"
assert parsed_result["title"] == "埃菲尔铁塔与世界最高建筑高度比较研究计划"
assert len(parsed_result["steps"]) == 1
assert parsed_result["steps"][0]["title"] == "收集埃菲尔铁塔和世界最高建筑的高度数据"
# 在这里 mock 掉 get_llm_by_type避免 ValueError # 在这里 mock 掉 get_llm_by_type避免 ValueError
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()): with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage