fix: resolve missing serialized kwargs in PatchedChatDeepSeek (#2025)
* add tests * fix ci * fix ci
This commit is contained in:
parent
823f3af98c
commit
1b74d84590
|
|
@ -48,6 +48,10 @@ class CodexChatModel(BaseChatModel):
|
|||
|
||||
# Pydantic v2 model config: permit field types that Pydantic cannot
# validate (arbitrary classes) on this model.
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@classmethod
def is_lc_serializable(cls) -> bool:
    """Opt this class into LangChain's constructor-based serialization."""
    return True
||||
|
||||
@property
def _llm_type(self) -> str:
    """Identifier LangChain uses to tag this chat-model implementation."""
    return "codex-responses"
||||
|
|
|
|||
|
|
@ -23,6 +23,14 @@ class PatchedChatDeepSeek(ChatDeepSeek):
|
|||
request payload.
|
||||
"""
|
||||
|
||||
@classmethod
def is_lc_serializable(cls) -> bool:
    """Opt this class into LangChain's constructor-based serialization."""
    return True
|
||||
|
||||
@property
def lc_secrets(self) -> dict[str, str]:
    """Map secret constructor kwargs to their env-var names.

    Declaring both the ``api_key`` alias and the underlying
    ``openai_api_key`` field keeps the key out of serialized output.
    """
    return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,246 @@
|
|||
"""Tests for deerflow.models.openai_codex_provider.CodexChatModel.
|
||||
|
||||
Covers:
|
||||
- LangChain serialization: is_lc_serializable, to_json kwargs, no token leakage
|
||||
- _parse_response: text content, tool calls, reasoning_content
|
||||
- _convert_messages: SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||||
- _parse_sse_data_line: valid data, [DONE], non-JSON, non-data lines
|
||||
- _parse_tool_call_arguments: valid JSON, invalid JSON, non-dict JSON
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from deerflow.models.credential_loader import CodexCliCredential
|
||||
|
||||
|
||||
def _make_model(**kwargs):
    """Build a CodexChatModel with the CLI credential loader stubbed out."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    fake_credential = CodexCliCredential(access_token="tok-test", account_id="acc-test")
    loader_patch = patch(
        "deerflow.models.openai_codex_provider.load_codex_cli_credential",
        return_value=fake_credential,
    )
    with loader_patch:
        return CodexChatModel(model="gpt-5.4", reasoning_effort="medium", **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Serialization protocol
# ---------------------------------------------------------------------------


def test_is_lc_serializable_returns_true():
    """The model class opts into LangChain constructor serialization."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    assert CodexChatModel.is_lc_serializable() is True


def test_to_json_produces_constructor_type():
    """to_json emits a constructor-style envelope carrying kwargs."""
    serialized = _make_model().to_json()
    assert serialized["type"] == "constructor"
    assert "kwargs" in serialized


def test_to_json_contains_model_and_reasoning_effort():
    """Declared fields round-trip into the serialized kwargs."""
    kwargs = _make_model().to_json()["kwargs"]
    assert kwargs["model"] == "gpt-5.4"
    assert kwargs["reasoning_effort"] == "medium"


def test_to_json_does_not_leak_access_token():
    """_access_token is not a Pydantic field and must not appear in serialized kwargs."""
    dumped = json.dumps(_make_model().to_json()["kwargs"])
    for secret in ("tok-test", "_access_token", "_account_id"):
        assert secret not in dumped
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _parse_response
# ---------------------------------------------------------------------------


def test_parse_response_text_content():
    """A plain message item becomes the AIMessage content."""
    raw = {
        "output": [
            {
                "type": "message",
                "content": [{"type": "output_text", "text": "Hello world"}],
            }
        ],
        "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
        "model": "gpt-5.4",
    }
    parsed = _make_model()._parse_response(raw)
    assert parsed.generations[0].message.content == "Hello world"


def test_parse_response_reasoning_content():
    """A reasoning item is surfaced via additional_kwargs['reasoning_content']."""
    raw = {
        "output": [
            {
                "type": "reasoning",
                "summary": [{"type": "summary_text", "text": "I reasoned about this."}],
            },
            {
                "type": "message",
                "content": [{"type": "output_text", "text": "Answer"}],
            },
        ],
        "usage": {},
    }
    message = _make_model()._parse_response(raw).generations[0].message
    assert message.content == "Answer"
    assert message.additional_kwargs["reasoning_content"] == "I reasoned about this."


def test_parse_response_tool_call():
    """A function_call item with valid JSON args becomes a structured tool call."""
    raw = {
        "output": [
            {
                "type": "function_call",
                "name": "web_search",
                "arguments": '{"query": "test"}',
                "call_id": "call_abc",
            }
        ],
        "usage": {},
    }
    calls = _make_model()._parse_response(raw).generations[0].message.tool_calls
    assert len(calls) == 1
    first = calls[0]
    assert first["name"] == "web_search"
    assert first["args"] == {"query": "test"}
    assert first["id"] == "call_abc"


def test_parse_response_invalid_tool_call_arguments():
    """Unparseable arguments land in invalid_tool_calls, not tool_calls."""
    raw = {
        "output": [
            {
                "type": "function_call",
                "name": "bad_tool",
                "arguments": "not-json",
                "call_id": "call_bad",
            }
        ],
        "usage": {},
    }
    message = _make_model()._parse_response(raw).generations[0].message
    assert len(message.tool_calls) == 0
    assert len(message.invalid_tool_calls) == 1
    assert message.invalid_tool_calls[0]["name"] == "bad_tool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _convert_messages
# ---------------------------------------------------------------------------


def test_convert_messages_human():
    """A HumanMessage maps to a plain user item."""
    _, converted = _make_model()._convert_messages([HumanMessage(content="Hello")])
    assert converted == [{"role": "user", "content": "Hello"}]


def test_convert_messages_system_becomes_instructions():
    """System content is lifted into the instructions string, not the item list."""
    instructions, converted = _make_model()._convert_messages(
        [SystemMessage(content="You are helpful.")]
    )
    assert "You are helpful." in instructions
    assert converted == []


def test_convert_messages_ai_with_tool_calls():
    """AIMessage tool calls are emitted as function_call items."""
    assistant = AIMessage(
        content="",
        tool_calls=[{"name": "search", "args": {"q": "foo"}, "id": "tc1", "type": "tool_call"}],
    )
    _, converted = _make_model()._convert_messages([assistant])
    assert any(
        entry.get("type") == "function_call" and entry["name"] == "search"
        for entry in converted
    )


def test_convert_messages_tool_message():
    """A ToolMessage maps to a function_call_output item."""
    _, converted = _make_model()._convert_messages(
        [ToolMessage(content="result data", tool_call_id="tc1")]
    )
    first = converted[0]
    assert first["type"] == "function_call_output"
    assert first["call_id"] == "tc1"
    assert first["output"] == "result data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _parse_sse_data_line
# ---------------------------------------------------------------------------


def test_parse_sse_data_line_valid():
    """A well-formed data line round-trips through JSON."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    event = {"type": "response.completed", "response": {}}
    assert CodexChatModel._parse_sse_data_line("data: " + json.dumps(event)) == event


def test_parse_sse_data_line_done_returns_none():
    """The [DONE] sentinel yields None."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    assert CodexChatModel._parse_sse_data_line("data: [DONE]") is None


def test_parse_sse_data_line_non_data_returns_none():
    """Lines other than data: lines are ignored."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    assert CodexChatModel._parse_sse_data_line("event: ping") is None


def test_parse_sse_data_line_invalid_json_returns_none():
    """Malformed JSON after the data: prefix is ignored."""
    from deerflow.models.openai_codex_provider import CodexChatModel

    assert CodexChatModel._parse_sse_data_line("data: {bad json}") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _parse_tool_call_arguments
# ---------------------------------------------------------------------------


def test_parse_tool_call_arguments_valid_string():
    """A JSON-object string parses into a dict with no error."""
    call = {"arguments": '{"key": "val"}', "name": "t", "call_id": "c"}
    parsed, error = _make_model()._parse_tool_call_arguments(call)
    assert parsed == {"key": "val"}
    assert error is None


def test_parse_tool_call_arguments_already_dict():
    """Arguments already supplied as a dict pass through unchanged."""
    call = {"arguments": {"key": "val"}, "name": "t", "call_id": "c"}
    parsed, error = _make_model()._parse_tool_call_arguments(call)
    assert parsed == {"key": "val"}
    assert error is None


def test_parse_tool_call_arguments_invalid_json():
    """Non-JSON arguments produce an error record instead of a dict."""
    call = {"arguments": "not-json", "name": "t", "call_id": "c"}
    parsed, error = _make_model()._parse_tool_call_arguments(call)
    assert parsed is None
    assert error is not None
    assert "Failed to parse" in error["error"]


def test_parse_tool_call_arguments_non_dict_json():
    """Valid JSON that is not an object is rejected."""
    call = {"arguments": '["list", "not", "dict"]', "name": "t", "call_id": "c"}
    parsed, error = _make_model()._parse_tool_call_arguments(call)
    assert parsed is None
    assert error is not None
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
"""Tests for deerflow.models.patched_deepseek.PatchedChatDeepSeek.
|
||||
|
||||
Covers:
|
||||
- LangChain serialization protocol: is_lc_serializable, lc_secrets, to_json
|
||||
- reasoning_content restoration in _get_request_payload (single and multi-turn)
|
||||
- Positional fallback when message counts differ
|
||||
- No-op when no reasoning_content present
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
|
||||
def _make_model(**kwargs):
    """Construct a PatchedChatDeepSeek configured with a throwaway API key."""
    from deerflow.models.patched_deepseek import PatchedChatDeepSeek

    return PatchedChatDeepSeek(model="deepseek-reasoner", api_key="test-key", **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Serialization protocol
# ---------------------------------------------------------------------------


def test_is_lc_serializable_returns_true():
    """The patched class opts into LangChain constructor serialization."""
    from deerflow.models.patched_deepseek import PatchedChatDeepSeek

    assert PatchedChatDeepSeek.is_lc_serializable() is True


def test_lc_secrets_contains_api_key_mapping():
    """Both API-key aliases are declared as secrets."""
    secret_map = _make_model().lc_secrets
    assert "api_key" in secret_map
    assert secret_map["api_key"] == "DEEPSEEK_API_KEY"
    assert "openai_api_key" in secret_map


def test_to_json_produces_constructor_type():
    """to_json emits a constructor-style envelope carrying kwargs."""
    serialized = _make_model().to_json()
    assert serialized["type"] == "constructor"
    assert "kwargs" in serialized


def test_to_json_kwargs_contains_model():
    """Model name and default API base survive serialization."""
    kwargs = _make_model().to_json()["kwargs"]
    assert kwargs["model_name"] == "deepseek-reasoner"
    assert kwargs["api_base"] == "https://api.deepseek.com/v1"


def test_to_json_kwargs_contains_custom_api_base():
    """A caller-supplied api_base is preserved in serialized kwargs."""
    serialized = _make_model(api_base="https://ark.cn-beijing.volces.com/api/v3").to_json()
    assert serialized["kwargs"]["api_base"] == "https://ark.cn-beijing.volces.com/api/v3"


def test_to_json_api_key_is_masked():
    """api_key must not appear as plain text in the serialized output."""
    kwargs = _make_model().to_json()["kwargs"]
    api_key_value = kwargs.get("api_key") or kwargs.get("openai_api_key")
    assert api_key_value is None or isinstance(api_key_value, dict), f"API key must not be plain text, got: {api_key_value!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reasoning_content preservation in _get_request_payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_payload_message(role: str, content: str | None = None, tool_calls: list | None = None) -> dict:
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if tool_calls is not None:
|
||||
msg["tool_calls"] = tool_calls
|
||||
return msg
|
||||
|
||||
|
||||
def test_reasoning_content_injected_into_assistant_message():
    """reasoning_content from additional_kwargs is restored in the payload."""
    model = _make_model()

    question = HumanMessage(content="What is 2+2?")
    answer = AIMessage(
        content="4",
        additional_kwargs={"reasoning_content": "Let me think: 2+2=4"},
    )
    stub_payload = {
        "messages": [
            _make_payload_message("user", "What is 2+2?"),
            _make_payload_message("assistant", "4"),
        ]
    }

    parent_cls = type(model).__bases__[0]
    with patch.object(parent_cls, "_get_request_payload", return_value=stub_payload), \
            patch.object(model, "_convert_input") as convert_stub:
        convert_stub.return_value = MagicMock(to_messages=lambda: [question, answer])
        result = model._get_request_payload([question, answer])

    restored = next(m for m in result["messages"] if m["role"] == "assistant")
    assert restored["reasoning_content"] == "Let me think: 2+2=4"
|
||||
|
||||
|
||||
def test_no_reasoning_content_is_noop():
    """Messages without reasoning_content are left unchanged."""
    model = _make_model()

    user_msg = HumanMessage(content="hello")
    assistant = AIMessage(content="hi", additional_kwargs={})
    stub_payload = {
        "messages": [
            _make_payload_message("user", "hello"),
            _make_payload_message("assistant", "hi"),
        ]
    }

    parent_cls = type(model).__bases__[0]
    with patch.object(parent_cls, "_get_request_payload", return_value=stub_payload), \
            patch.object(model, "_convert_input") as convert_stub:
        convert_stub.return_value = MagicMock(to_messages=lambda: [user_msg, assistant])
        result = model._get_request_payload([user_msg, assistant])

    restored = next(m for m in result["messages"] if m["role"] == "assistant")
    assert "reasoning_content" not in restored
|
||||
|
||||
|
||||
def test_reasoning_content_multi_turn():
    """All assistant turns each get their own reasoning_content."""
    model = _make_model()

    turns = [
        HumanMessage(content="Step 1?"),
        AIMessage(content="A1", additional_kwargs={"reasoning_content": "Thought1"}),
        HumanMessage(content="Step 2?"),
        AIMessage(content="A2", additional_kwargs={"reasoning_content": "Thought2"}),
    ]
    stub_payload = {
        "messages": [
            _make_payload_message("user", "Step 1?"),
            _make_payload_message("assistant", "A1"),
            _make_payload_message("user", "Step 2?"),
            _make_payload_message("assistant", "A2"),
        ]
    }

    parent_cls = type(model).__bases__[0]
    with patch.object(parent_cls, "_get_request_payload", return_value=stub_payload), \
            patch.object(model, "_convert_input") as convert_stub:
        convert_stub.return_value = MagicMock(to_messages=lambda: list(turns))
        result = model._get_request_payload(turns)

    assistants = [m for m in result["messages"] if m["role"] == "assistant"]
    assert assistants[0]["reasoning_content"] == "Thought1"
    assert assistants[1]["reasoning_content"] == "Thought2"
|
||||
|
||||
|
||||
def test_positional_fallback_when_count_differs():
    """Falls back to positional matching when payload/original message counts differ."""
    model = _make_model()

    user_msg = HumanMessage(content="hi")
    assistant = AIMessage(content="hello", additional_kwargs={"reasoning_content": "My reasoning"})

    # Payload carries one extra system message, so counts (3 vs 2) do not line up.
    stub_payload = {
        "messages": [
            _make_payload_message("system", "You are helpful."),
            _make_payload_message("user", "hi"),
            _make_payload_message("assistant", "hello"),
        ]
    }

    parent_cls = type(model).__bases__[0]
    with patch.object(parent_cls, "_get_request_payload", return_value=stub_payload), \
            patch.object(model, "_convert_input") as convert_stub:
        convert_stub.return_value = MagicMock(to_messages=lambda: [user_msg, assistant])
        result = model._get_request_payload([user_msg, assistant])

    restored = next(m for m in result["messages"] if m["role"] == "assistant")
    assert restored["reasoning_content"] == "My reasoning"
|
||||
Loading…
Reference in New Issue