feat(client): add `available_skills` parameter to DeerFlowClient (#1779)
* feat(client): add `available_skills` parameter to DeerFlowClient for dynamic runtime skill filtering * Update backend/packages/harness/deerflow/client.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(client): include `agent_name` and `available_skills` in agent config cache key --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
a6f0712732
commit
4e188b41d9
|
|
@ -117,6 +117,7 @@ class DeerFlowClient:
|
||||||
subagent_enabled: bool = False,
|
subagent_enabled: bool = False,
|
||||||
plan_mode: bool = False,
|
plan_mode: bool = False,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
available_skills: set[str] | None = None,
|
||||||
middlewares: Sequence[AgentMiddleware] | None = None,
|
middlewares: Sequence[AgentMiddleware] | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the client.
|
"""Initialize the client.
|
||||||
|
|
@ -133,6 +134,7 @@ class DeerFlowClient:
|
||||||
subagent_enabled: Enable subagent delegation.
|
subagent_enabled: Enable subagent delegation.
|
||||||
plan_mode: Enable TodoList middleware for plan mode.
|
plan_mode: Enable TodoList middleware for plan mode.
|
||||||
agent_name: Name of the agent to use.
|
agent_name: Name of the agent to use.
|
||||||
|
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
|
||||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||||
"""
|
"""
|
||||||
if config_path is not None:
|
if config_path is not None:
|
||||||
|
|
@ -148,6 +150,7 @@ class DeerFlowClient:
|
||||||
self._subagent_enabled = subagent_enabled
|
self._subagent_enabled = subagent_enabled
|
||||||
self._plan_mode = plan_mode
|
self._plan_mode = plan_mode
|
||||||
self._agent_name = agent_name
|
self._agent_name = agent_name
|
||||||
|
self._available_skills = set(available_skills) if available_skills is not None else None
|
||||||
self._middlewares = list(middlewares) if middlewares else []
|
self._middlewares = list(middlewares) if middlewares else []
|
||||||
|
|
||||||
# Lazy agent — created on first call, recreated when config changes.
|
# Lazy agent — created on first call, recreated when config changes.
|
||||||
|
|
@ -208,6 +211,8 @@ class DeerFlowClient:
|
||||||
cfg.get("thinking_enabled"),
|
cfg.get("thinking_enabled"),
|
||||||
cfg.get("is_plan_mode"),
|
cfg.get("is_plan_mode"),
|
||||||
cfg.get("subagent_enabled"),
|
cfg.get("subagent_enabled"),
|
||||||
|
self._agent_name,
|
||||||
|
frozenset(self._available_skills) if self._available_skills is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._agent is not None and self._agent_config_key == key:
|
if self._agent is not None and self._agent_config_key == key:
|
||||||
|
|
@ -226,6 +231,7 @@ class DeerFlowClient:
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
agent_name=self._agent_name,
|
agent_name=self._agent_name,
|
||||||
|
available_skills=self._available_skills,
|
||||||
),
|
),
|
||||||
"state_schema": ThreadState,
|
"state_schema": ThreadState,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,18 +59,20 @@ class TestClientInit:
|
||||||
assert client._subagent_enabled is False
|
assert client._subagent_enabled is False
|
||||||
assert client._plan_mode is False
|
assert client._plan_mode is False
|
||||||
assert client._agent_name is None
|
assert client._agent_name is None
|
||||||
|
assert client._available_skills is None
|
||||||
assert client._checkpointer is None
|
assert client._checkpointer is None
|
||||||
assert client._agent is None
|
assert client._agent is None
|
||||||
|
|
||||||
def test_custom_params(self, mock_app_config):
|
def test_custom_params(self, mock_app_config):
|
||||||
mock_middleware = MagicMock()
|
mock_middleware = MagicMock()
|
||||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
|
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||||
assert c._model_name == "gpt-4"
|
assert c._model_name == "gpt-4"
|
||||||
assert c._thinking_enabled is False
|
assert c._thinking_enabled is False
|
||||||
assert c._subagent_enabled is True
|
assert c._subagent_enabled is True
|
||||||
assert c._plan_mode is True
|
assert c._plan_mode is True
|
||||||
assert c._agent_name == "test-agent"
|
assert c._agent_name == "test-agent"
|
||||||
|
assert c._available_skills == {"skill1", "skill2"}
|
||||||
assert c._middlewares == [mock_middleware]
|
assert c._middlewares == [mock_middleware]
|
||||||
|
|
||||||
def test_invalid_agent_name(self, mock_app_config):
|
def test_invalid_agent_name(self, mock_app_config):
|
||||||
|
|
@ -394,8 +396,10 @@ class TestEnsureAgent:
|
||||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._agent_name = "custom-agent"
|
client._agent_name = "custom-agent"
|
||||||
|
client._available_skills = {"test_skill"}
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
assert client._agent is mock_agent
|
assert client._agent is mock_agent
|
||||||
|
|
@ -404,6 +408,7 @@ class TestEnsureAgent:
|
||||||
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
|
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||||
mock_apply_prompt.assert_called_once()
|
mock_apply_prompt.assert_called_once()
|
||||||
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
|
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||||
|
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
|
||||||
|
|
||||||
def test_uses_default_checkpointer_when_available(self, client):
|
def test_uses_default_checkpointer_when_available(self, client):
|
||||||
mock_agent = MagicMock()
|
mock_agent = MagicMock()
|
||||||
|
|
@ -441,6 +446,7 @@ class TestEnsureAgent:
|
||||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
|
|
@ -469,7 +475,7 @@ class TestEnsureAgent:
|
||||||
"""_ensure_agent does not recreate if config key unchanged."""
|
"""_ensure_agent does not recreate if config key unchanged."""
|
||||||
mock_agent = MagicMock()
|
mock_agent = MagicMock()
|
||||||
client._agent = mock_agent
|
client._agent = mock_agent
|
||||||
client._agent_config_key = (None, True, False, False)
|
client._agent_config_key = (None, True, False, False, None, None)
|
||||||
|
|
||||||
config = client._get_runnable_config("t1")
|
config = client._get_runnable_config("t1")
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
|
||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config_a)
|
client._ensure_agent(config_a)
|
||||||
first_agent = client._agent
|
first_agent = client._agent
|
||||||
|
|
@ -1303,6 +1310,7 @@ class TestScenarioAgentRecreation:
|
||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
@ -1327,6 +1335,7 @@ class TestScenarioAgentRecreation:
|
||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client.reset_agent()
|
client.reset_agent()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue