Compare commits

..

9 Commits

24 changed files with 976 additions and 63 deletions

View File

@ -112,6 +112,34 @@ guardrails:
3. Ask the agent: "Use bash to run echo hello" 3. Ask the agent: "Use bash to run echo hello"
4. The agent sees: `Guardrail denied: tool 'bash' was blocked (oap.tool_not_allowed)` 4. The agent sees: `Guardrail denied: tool 'bash' was blocked (oap.tool_not_allowed)`
### Option 1.5: Built-in SensitiveDataProvider (Strict Secret Blocking)
For secret-leak prevention, DeerFlow also ships `SensitiveDataProvider`, which
blocks tool calls targeting sensitive file patterns (for example `.env`,
`.env.*`, `*.pem`, `*.key`, `id_rsa*`, `secrets.*`, `credentials.*`).
This provider is strict: it blocks matching access even when the user
explicitly asks the agent to reveal those files.
**config.yaml:**
```yaml
guardrails:
enabled: true
fail_closed: true
provider:
use: deerflow.guardrails.builtin:SensitiveDataProvider
config:
protected_tools: ["read_file", "write_file", "str_replace", "ls", "glob", "grep", "bash"]
deny_basenames: [".env"]
deny_globs: [".env.*", "*.pem", "*.key", "id_rsa*", "secrets.*", "credentials.*"]
block_skills_env: true
```
**Behavior summary:**
- `read_file / ls / glob / grep / bash` attempting sensitive path access are denied
- `write_file / str_replace` touching sensitive targets are denied
- Denials are logged as structured audit events (tool/reason/thread/timestamp)
### Option 2: OAP Passport Provider (Policy-Based) ### Option 2: OAP Passport Provider (Policy-Based)
For policy enforcement based on the [Open Agent Passport (OAP)](https://github.com/aporthq/aport-spec) open standard. An OAP passport is a JSON document that declares an agent's identity, capabilities, and operational limits. Any provider that reads an OAP passport and returns OAP-compliant decisions works with DeerFlow. For policy enforcement based on the [Open Agent Passport (OAP)](https://github.com/aporthq/aport-spec) open standard. An OAP passport is a JSON document that declares an agent's identity, capabilities, and operational limits. Any provider that reads an OAP passport and returns OAP-compliant decisions works with DeerFlow.

View File

@ -7,3 +7,14 @@ __all__ = [
"checkpointer_context", "checkpointer_context",
"make_checkpointer", "make_checkpointer",
] ]
# Lazy-import shallow savers so the module is still importable without
# langgraph-checkpoint-sqlite installed.
def __getattr__(name: str):
if name == "AsyncShallowSqliteSaver":
from .shallow_sqlite import _make_async_shallow_saver
return _make_async_shallow_saver()
if name == "ShallowSqliteSaver":
from .shallow_sqlite import _make_sync_shallow_saver
return _make_sync_shallow_saver()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -55,6 +55,18 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str) ensure_sqlite_parent_dir(conn_str)
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
if getattr(config, "sqlite_mode", "full") == "shallow":
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
ShallowSaver = _make_async_shallow_saver()
async with ShallowSaver.from_conn_string(conn_str) as saver:
await saver.setup()
logger.info("Checkpointer: using AsyncShallowSqliteSaver (%s)", conn_str)
yield saver
return
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver: async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup() await saver.setup()
yield saver yield saver

View File

@ -67,6 +67,18 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
if getattr(config, "sqlite_mode", "full") == "shallow":
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
ShallowSaver = _make_sync_shallow_saver()
with ShallowSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using ShallowSqliteSaver (%s)", conn_str)
yield saver
return
with SqliteSaver.from_conn_string(conn_str) as saver: with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup() saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)

View File

@ -0,0 +1,128 @@
"""Shallow persistence savers for LangGraph SQLite checkpointing.
Provides shallow (single-checkpoint-per-thread) variants of the LangGraph
SQLite savers that automatically delete old checkpoints and writes for the
same thread before each write, keeping only the latest state.
This prevents unbounded growth of ``checkpoints.db`` while preserving
multi-turn conversation continuity.
Implements:
- ``AsyncShallowSqliteSaver`` async shallow variant
- ``ShallowSqliteSaver`` sync shallow variant
Usage is transparent through the existing checkpointer factory when
``sqlite_mode: shallow`` is set in ``config.yaml``.
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Async shallow saver
# ---------------------------------------------------------------------------
class AsyncShallowSqliteSaver:
"""Async SQLite checkpointer that keeps only the latest checkpoint per thread.
Extends :class:`langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver` and
overrides :meth:`aput` to delete all existing checkpoints and writes for
the same ``thread_id`` before inserting the new one.
Each conversation thread stores exactly one checkpoint at any time,
preventing unbounded database growth.
"""
def __init_subclass__(cls, **kwargs: Any) -> None:
# Allow extension by subclasses without forcing late-bound import here.
# The concrete class is built below via the _make_async_shallow factory.
super().__init_subclass__(**kwargs)
def _make_async_shallow_saver() -> type:
"""Build and return the ``AsyncShallowSqliteSaver`` class.
Import is deferred so that the module is importable even when
``langgraph-checkpoint-sqlite`` is not installed.
"""
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
class AsyncShallowSqliteSaver(AsyncSqliteSaver):
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Delete old checkpoints/writes for this thread, then save the new one."""
thread_id = config["configurable"]["thread_id"]
await self.setup()
# Delete all existing checkpoints and writes for this thread
# before inserting the new checkpoint — keeps only the latest.
async with self.lock:
await self.conn.execute(
"DELETE FROM checkpoints WHERE thread_id = ?",
(str(thread_id),),
)
await self.conn.execute(
"DELETE FROM writes WHERE thread_id = ?",
(str(thread_id),),
)
await self.conn.commit()
return await super().aput(config, checkpoint, metadata, new_versions)
return AsyncShallowSqliteSaver
# ---------------------------------------------------------------------------
# Sync shallow saver
# ---------------------------------------------------------------------------
def _make_sync_shallow_saver() -> type:
"""Build and return the ``ShallowSqliteSaver`` class.
Import is deferred so that the module is importable even when
``langgraph-checkpoint-sqlite`` is not installed.
"""
from langgraph.checkpoint.sqlite import SqliteSaver
class ShallowSqliteSaver(SqliteSaver):
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Delete old checkpoints/writes for this thread, then save the new one."""
thread_id = config["configurable"]["thread_id"]
# Delete all existing checkpoints and writes for this thread
# before inserting the new checkpoint — keeps only the latest.
with self.cursor() as cur:
cur.execute(
"DELETE FROM checkpoints WHERE thread_id = ?",
(str(thread_id),),
)
cur.execute(
"DELETE FROM writes WHERE thread_id = ?",
(str(thread_id),),
)
return super().put(config, checkpoint, metadata, new_versions)
return ShallowSqliteSaver

View File

@ -174,7 +174,6 @@ def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001
def _reserve_failure_message(status_code: int | None) -> str: def _reserve_failure_message(status_code: int | None) -> str:
if status_code in _blocking_reserve_code_set(): if status_code in _blocking_reserve_code_set():
# TODO: 将账单错误文案迁移到国际化资源中,按语言返回提示。
return "The account balance is insufficient for this model call." return "The account balance is insufficient for this model call."
return "Billing reservation failed. Please try again later." return "Billing reservation failed. Please try again later."

View File

@ -0,0 +1,64 @@
"""Redact sensitive values from tool outputs before they re-enter the model context."""
from __future__ import annotations
import re
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
_SIMPLE_REDACTION = "[REDACTED]"
_PATTERNS = [
re.compile(r"(?im)\b([A-Z][A-Z0-9_]{1,64})\s*=\s*([^\s\"'`]{6,})"),
re.compile(r"(?i)\b(bearer\s+)[A-Za-z0-9._\-+/=]{8,}"),
re.compile(r"(?i)\b(sk-[A-Za-z0-9]{12,})\b"),
re.compile(r"(?i)\b(eyJ[A-Za-z0-9_\-]{10,}\.[A-Za-z0-9_\-]{10,}\.[A-Za-z0-9_\-]{10,})\b"),
re.compile(r"(?is)-----BEGIN [A-Z ]*PRIVATE KEY-----.*?-----END [A-Z ]*PRIVATE KEY-----"),
]
def _redact_text(text: str) -> str:
value = text
# Preserve var name for KEY=VALUE style output.
value = _PATTERNS[0].sub(lambda m: f"{m.group(1)}={_SIMPLE_REDACTION}", value)
value = _PATTERNS[1].sub(lambda m: f"{m.group(1)}{_SIMPLE_REDACTION}", value)
for pattern in _PATTERNS[2:]:
value = pattern.sub(_SIMPLE_REDACTION, value)
return value
class SensitiveOutputRedactionMiddleware(AgentMiddleware[AgentState]):
"""Redact secrets from tool outputs."""
def _redact_tool_result(self, result: ToolMessage | Command) -> ToolMessage | Command:
if not isinstance(result, ToolMessage):
return result
if isinstance(result.content, str):
redacted = _redact_text(result.content)
if redacted != result.content:
return ToolMessage(content=redacted, tool_call_id=result.tool_call_id, name=result.name, status=result.status)
return result
return result
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
return self._redact_tool_result(handler(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
return self._redact_tool_result(await handler(request))

View File

@ -127,8 +127,10 @@ def _build_runtime_middlewares(
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport)) middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
from deerflow.agents.middlewares.sensitive_output_redaction_middleware import SensitiveOutputRedactionMiddleware
middlewares.append(SandboxAuditMiddleware()) middlewares.append(SandboxAuditMiddleware())
middlewares.append(SensitiveOutputRedactionMiddleware())
middlewares.append(ToolErrorHandlingMiddleware()) middlewares.append(ToolErrorHandlingMiddleware())
return middlewares return middlewares

View File

@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"] CheckpointerType = Literal["memory", "sqlite", "postgres"]
SqliteMode = Literal["full", "shallow"]
"""Persistence mode for the SQLite checkpointer.
- ``full`` retain all checkpoint history (default, backward-compatible).
- ``shallow`` keep only the latest checkpoint per thread, deleting old
records before each write to prevent unbounded database growth.
"""
class CheckpointerConfig(BaseModel): class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer.""" """Configuration for LangGraph state persistence checkpointer."""
@ -23,6 +31,13 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
) )
sqlite_mode: SqliteMode = Field(
default="full",
description="SQLite persistence mode. "
"'full' retains all checkpoint history (default). "
"'shallow' keeps only the latest checkpoint per thread, "
"deleting old records before each write.",
)
# Global configuration instance — None means no checkpointer is configured. # Global configuration instance — None means no checkpointer is configured.

View File

@ -1,11 +1,12 @@
"""Pre-tool-call authorization middleware.""" """Pre-tool-call authorization middleware."""
from deerflow.guardrails.builtin import AllowlistProvider from deerflow.guardrails.builtin import AllowlistProvider, SensitiveDataProvider
from deerflow.guardrails.middleware import GuardrailMiddleware from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
__all__ = [ __all__ = [
"AllowlistProvider", "AllowlistProvider",
"SensitiveDataProvider",
"GuardrailDecision", "GuardrailDecision",
"GuardrailMiddleware", "GuardrailMiddleware",
"GuardrailProvider", "GuardrailProvider",

View File

@ -1,7 +1,20 @@
"""Built-in guardrail providers that ship with DeerFlow.""" """Built-in guardrail providers that ship with DeerFlow."""
from __future__ import annotations
import fnmatch
import json
import logging
import re
import shlex
from datetime import UTC, datetime
from pathlib import PurePosixPath
from typing import Any
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
logger = logging.getLogger(__name__)
class AllowlistProvider: class AllowlistProvider:
"""Simple allowlist/denylist provider. No external dependencies.""" """Simple allowlist/denylist provider. No external dependencies."""
@ -21,3 +34,138 @@ class AllowlistProvider:
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision: async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
return self.evaluate(request) return self.evaluate(request)
class SensitiveDataProvider:
"""Block tool calls that may access sensitive files such as .env and keys."""
name = "sensitive-data"
_DEFAULT_PROTECTED_TOOLS = {
"read_file",
"write_file",
"str_replace",
"ls",
"glob",
"grep",
"bash",
}
_DEFAULT_DENY_BASENAMES = {".env"}
_DEFAULT_DENY_GLOBS = {
".env.*",
"*.pem",
"*.key",
"id_rsa*",
"secrets.*",
"credentials.*",
}
def __init__(
self,
*,
protected_tools: list[str] | None = None,
deny_basenames: list[str] | None = None,
deny_globs: list[str] | None = None,
block_skills_env: bool = True,
**_: Any,
):
self._protected_tools = {t.lower() for t in (protected_tools or list(self._DEFAULT_PROTECTED_TOOLS))}
self._deny_basenames = {n.lower() for n in (deny_basenames or list(self._DEFAULT_DENY_BASENAMES))}
self._deny_globs = {p.lower() for p in (deny_globs or list(self._DEFAULT_DENY_GLOBS))}
self._block_skills_env = block_skills_env
def _normalize_candidate(self, raw: str | None) -> str:
if not raw:
return ""
return str(raw).strip().strip("\"'")
def _looks_sensitive_path(self, raw_path: str) -> bool:
value = self._normalize_candidate(raw_path)
if not value:
return False
lowered = value.lower()
if self._block_skills_env and "/mnt/skills/" in lowered:
basename = PurePosixPath(lowered).name
if basename == ".env" or basename.startswith(".env."):
return True
basename = PurePosixPath(lowered).name
if basename in self._deny_basenames:
return True
return any(fnmatch.fnmatch(basename, pat) for pat in self._deny_globs)
def _extract_bash_candidates(self, command: str) -> list[str]:
candidates: list[str] = []
if not command:
return candidates
try:
tokens = shlex.split(command)
except ValueError:
tokens = command.split()
for token in tokens:
t = token.strip()
if not t:
continue
# Path-like tokens
if "/" in t or t.startswith("."):
candidates.append(t)
# file.env style arguments may not contain slash
if t.lower().startswith(".env"):
candidates.append(t)
return candidates
def _collect_candidates(self, request: GuardrailRequest) -> list[str]:
args = request.tool_input if isinstance(request.tool_input, dict) else {}
tool = request.tool_name
candidates: list[str] = []
if tool in {"read_file", "write_file", "str_replace", "ls"}:
path = args.get("path")
if isinstance(path, str):
candidates.append(path)
elif tool in {"glob", "grep"}:
path = args.get("path")
if isinstance(path, str):
candidates.append(path)
glob_pat = args.get("glob")
if isinstance(glob_pat, str):
candidates.append(glob_pat)
elif tool == "bash":
command = str(args.get("command") or "")
candidates.extend(self._extract_bash_candidates(command))
# Fast-path for common secret exposure commands
if re.search(r"\b(printenv|env)\b", command, flags=re.IGNORECASE):
candidates.append(".env")
return candidates
def _audit(self, request: GuardrailRequest, decision: GuardrailDecision) -> None:
if decision.allow:
return
code = decision.reasons[0].code if decision.reasons else "oap.blocked_pattern"
rec = {
"timestamp": datetime.now(UTC).isoformat(),
"provider": self.name,
"tool_name": request.tool_name,
"reason_code": code,
"thread_id": request.thread_id,
"agent_id": request.agent_id,
}
logger.warning("[SensitiveDataGuardrail] %s", json.dumps(rec, ensure_ascii=False))
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
tool = (request.tool_name or "").lower()
if tool not in self._protected_tools:
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
candidates = self._collect_candidates(request)
if any(self._looks_sensitive_path(c) for c in candidates):
decision = GuardrailDecision(
allow=False,
reasons=[GuardrailReason(code="oap.blocked_pattern", message="sensitive path access is blocked by policy")],
policy_id="sensitive-data.v1",
)
self._audit(request, decision)
return decision
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
return self.evaluate(request)

View File

@ -23,7 +23,7 @@ def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
default_estimated_output_tokens=None, default_estimated_output_tokens=None,
) )
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096}) model_cfg = SimpleNamespace(display_name="GPT-4", model="gpt-4", model_extra={"max_tokens": 4096})
return SimpleNamespace( return SimpleNamespace(
billing=billing, billing=billing,
get_model_config=lambda name: model_cfg if name == "gpt-4" else None, get_model_config=lambda name: model_cfg if name == "gpt-4" else None,

View File

@ -0,0 +1,305 @@
"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).
Uses in-memory SQLite (``:memory:``) no filesystem dependency.
"""
from __future__ import annotations
import pytest
# ---------------------------------------------------------------------------
# AsyncShallowSqliteSaver tests
# ---------------------------------------------------------------------------
class TestAsyncShallowSqliteSaver:
"""Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""
@pytest.mark.anyio
async def test_aput_deletes_old_checkpoints_before_insert(self):
"""After two aput calls for the same thread, only 1 checkpoint remains."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
ShallowSaver = _make_async_shallow_saver()
async with ShallowSaver.from_conn_string(":memory:") as saver:
await saver.setup()
thread_config = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}
# Write first checkpoint
await saver.aput(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
# Write second checkpoint — should delete the first
await saver.aput(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
# Verify only 1 checkpoint remains
results = []
async for ckpt in saver.alist(thread_config):
results.append(ckpt)
assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"
@pytest.mark.anyio
async def test_different_threads_do_not_interfere(self):
"""Checkpoints for different thread_ids are independent."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
ShallowSaver = _make_async_shallow_saver()
async with ShallowSaver.from_conn_string(":memory:") as saver:
await saver.setup()
t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
await saver.aput(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
await saver.aput(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
await saver.aput(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
# Thread A: only ckpt a2
a_results = [c async for c in saver.alist(t1)]
assert len(a_results) == 1
assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
# Thread B: still has b1
b_results = [c async for c in saver.alist(t2)]
assert len(b_results) == 1
assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"
@pytest.mark.anyio
async def test_writes_table_also_cleaned(self):
"""aput_writes entries from old checkpoints are also deleted."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
ShallowSaver = _make_async_shallow_saver()
async with ShallowSaver.from_conn_string(":memory:") as saver:
await saver.setup()
thread_config = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}
# Write checkpoint 1 with associated writes
ckpt1_config = await saver.aput(
thread_config,
{"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
{"source": "input", "step": 1, "writes": {}},
{},
)
await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
# Write checkpoint 2 — should delete ckpt1 + its writes
ckpt2_config = await saver.aput(
thread_config,
{"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
{"source": "loop", "step": 2, "writes": {}},
{},
)
await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")
# Verify only 1 checkpoint remains
results = [c async for c in saver.alist(thread_config)]
assert len(results) == 1
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"
# Verify only the latest writes exist (get_tuple returns writes)
latest = await saver.aget_tuple(ckpt2_config)
assert latest is not None
@pytest.mark.anyio
async def test_full_mode_retains_all_checkpoints(self):
"""In full mode (default), all checkpoints are preserved."""
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
await saver.setup()
thread_config = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}
await saver.aput(thread_config, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
await saver.aput(thread_config, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
results = [c async for c in saver.alist(thread_config)]
assert len(results) == 2, "Full mode should retain all checkpoints"
# ---------------------------------------------------------------------------
# ShallowSqliteSaver (sync) tests
# ---------------------------------------------------------------------------
class TestShallowSqliteSaver:
"""Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""
def test_put_deletes_old_checkpoints_before_insert(self):
"""After two put calls for the same thread, only 1 checkpoint remains."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
ShallowSaver = _make_sync_shallow_saver()
with ShallowSaver.from_conn_string(":memory:") as saver:
saver.setup()
thread_config = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}
saver.put(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
saver.put(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
results = list(saver.list(thread_config))
assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"
def test_different_threads_do_not_interfere_sync(self):
"""Checkpoints for different thread_ids are independent (sync)."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
ShallowSaver = _make_sync_shallow_saver()
with ShallowSaver.from_conn_string(":memory:") as saver:
saver.setup()
t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
saver.put(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
saver.put(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
saver.put(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
a_results = list(saver.list(t1))
assert len(a_results) == 1
assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
b_results = list(saver.list(t2))
assert len(b_results) == 1
assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"
def test_writes_table_also_cleaned_sync(self):
"""put_writes entries from old checkpoints are also deleted (sync)."""
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
ShallowSaver = _make_sync_shallow_saver()
with ShallowSaver.from_conn_string(":memory:") as saver:
saver.setup()
thread_config = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}
ckpt1_config = saver.put(
thread_config,
{"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
{"source": "input", "step": 1, "writes": {}},
{},
)
saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
ckpt2_config = saver.put(
thread_config,
{"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
{"source": "loop", "step": 2, "writes": {}},
{},
)
saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")
results = list(saver.list(thread_config))
assert len(results) == 1
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"
# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------
class TestShallowConfig:
"""Tests for configuration and factory integration."""
def test_sqlite_mode_defaults_to_full(self):
"""sqlite_mode defaults to 'full' when not specified."""
from deerflow.config.checkpointer_config import CheckpointerConfig
config = CheckpointerConfig(type="sqlite", connection_string="test.db")
assert config.sqlite_mode == "full"
def test_sqlite_mode_shallow_accepted(self):
"""sqlite_mode can be set to 'shallow'."""
from deerflow.config.checkpointer_config import CheckpointerConfig
config = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
assert config.sqlite_mode == "shallow"
def test_load_sqlite_config_with_shallow_mode(self):
"""load_checkpointer_config_from_dict accepts sqlite_mode."""
from deerflow.config.checkpointer_config import (
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
set_checkpointer_config(None)
load_checkpointer_config_from_dict({
"type": "sqlite",
"connection_string": "/tmp/test.db",
"sqlite_mode": "shallow",
})
config = get_checkpointer_config()
assert config is not None
assert config.sqlite_mode == "shallow"
def test_load_sqlite_config_defaults_sqlite_mode(self):
"""load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
from deerflow.config.checkpointer_config import (
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
set_checkpointer_config(None)
load_checkpointer_config_from_dict({
"type": "sqlite",
"connection_string": "/tmp/test.db",
})
config = get_checkpointer_config()
assert config is not None
assert config.sqlite_mode == "full"
@pytest.mark.anyio
async def test_async_factory_uses_shallow_saver(self):
"""When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
from deerflow.agents.checkpointer.async_provider import _async_checkpointer
from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config
set_checkpointer_config(CheckpointerConfig(
type="sqlite",
connection_string=":memory:",
sqlite_mode="shallow",
))
config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
async with _async_checkpointer(config) as saver:
# Should be an instance of the shallow saver
cls_name = type(saver).__name__
assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
set_checkpointer_config(None)
def test_sync_factory_uses_shallow_saver(self):
"""When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
from deerflow.config.checkpointer_config import CheckpointerConfig
config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
with _sync_checkpointer_cm(config) as saver:
cls_name = type(saver).__name__
assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
def test_invalid_sqlite_mode_raises(self):
"""Invalid sqlite_mode value raises validation error."""
from deerflow.config.checkpointer_config import CheckpointerConfig
with pytest.raises(Exception):
CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")

View File

@ -8,7 +8,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from langgraph.errors import GraphBubbleUp from langgraph.errors import GraphBubbleUp
from deerflow.guardrails.builtin import AllowlistProvider from deerflow.guardrails.builtin import AllowlistProvider, SensitiveDataProvider
from deerflow.guardrails.middleware import GuardrailMiddleware from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
@ -105,6 +105,46 @@ class TestAllowlistProvider:
assert decision.allow is False assert decision.allow is False
class TestSensitiveDataProvider:
def test_denies_reading_env_file(self):
provider = SensitiveDataProvider()
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/tmp/.env"})
decision = provider.evaluate(req)
assert decision.allow is False
assert decision.reasons[0].code == "oap.blocked_pattern"
def test_allows_normal_source_file(self):
provider = SensitiveDataProvider()
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/app/main.py"})
decision = provider.evaluate(req)
assert decision.allow is True
def test_denies_skills_env(self):
provider = SensitiveDataProvider()
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/mnt/skills/public/foo/.env.local"})
decision = provider.evaluate(req)
assert decision.allow is False
def test_denies_glob_or_grep_targeting_env(self):
provider = SensitiveDataProvider()
req = GuardrailRequest(tool_name="grep", tool_input={"path": "/workspace", "glob": "**/.env.*"})
decision = provider.evaluate(req)
assert decision.allow is False
def test_denies_bash_cat_env(self):
provider = SensitiveDataProvider()
req = GuardrailRequest(tool_name="bash", tool_input={"command": "cat /workspace/.env"})
decision = provider.evaluate(req)
assert decision.allow is False
def test_denies_case_variant_and_key_material(self):
provider = SensitiveDataProvider()
req1 = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/.ENV"})
req2 = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/id_rsa.pub"})
assert provider.evaluate(req1).allow is False
assert provider.evaluate(req2).allow is False
# --- GuardrailMiddleware tests --- # --- GuardrailMiddleware tests ---

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock
from langchain_core.messages import ToolMessage
from deerflow.agents.middlewares.sensitive_output_redaction_middleware import SensitiveOutputRedactionMiddleware
def _request(name: str = "bash", args: dict | None = None):
req = MagicMock()
req.tool_call = {"name": name, "args": args or {}, "id": "call_1"}
return req
def test_redacts_key_value_and_bearer():
mw = SensitiveOutputRedactionMiddleware()
req = _request()
handler = MagicMock(
return_value=ToolMessage(
content="OPENAI_API_KEY=sk-abc123456789\nAuthorization: Bearer abcdefghijklmnop",
tool_call_id="call_1",
name="bash",
status="success",
)
)
result = mw.wrap_tool_call(req, handler)
assert "OPENAI_API_KEY=[REDACTED]" in result.content
assert "Bearer [REDACTED]" in result.content
def test_redacts_private_key_block():
mw = SensitiveOutputRedactionMiddleware()
req = _request()
handler = MagicMock(
return_value=ToolMessage(
content="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----",
tool_call_id="call_1",
name="bash",
status="success",
)
)
result = mw.wrap_tool_call(req, handler)
assert result.content == "[REDACTED]"
def test_async_path_redacts_jwt():
mw = SensitiveOutputRedactionMiddleware()
req = _request()
async def handler(_):
return ToolMessage(
content="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.abcde12345.zyxwv98765",
tool_call_id="call_1",
name="bash",
status="success",
)
result = asyncio.run(mw.awrap_tool_call(req, handler))
assert result.content == "[REDACTED]"

View File

@ -709,6 +709,10 @@ memory:
# memory - In-process only. State is lost when the process exits. (default) # memory - In-process only. State is lost when the process exits. (default)
# sqlite - File-based SQLite persistence. Survives restarts. # sqlite - File-based SQLite persistence. Survives restarts.
# Requires: uv add langgraph-checkpoint-sqlite # Requires: uv add langgraph-checkpoint-sqlite
# sqlite_mode: full (default) retains all history.
# sqlite_mode: shallow keeps only the latest checkpoint per
# thread, deleting old records before each write to prevent
# unbounded database growth.
# postgres - PostgreSQL persistence. Suitable for multi-process deployments. # postgres - PostgreSQL persistence. Suitable for multi-process deployments.
# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool # Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool
# #
@ -722,6 +726,8 @@ memory:
checkpointer: checkpointer:
type: sqlite type: sqlite
connection_string: checkpoints.db connection_string: checkpoints.db
# sqlite_mode: full # default — keep all checkpoint history
# sqlite_mode: shallow # keep only latest checkpoint per thread (prevents DB bloat)
# #
# PostgreSQL (multi-process, production): # PostgreSQL (multi-process, production):
# checkpointer: # checkpointer:
@ -832,3 +838,15 @@ checkpointer:
# use: my_package:MyGuardrailProvider # use: my_package:MyGuardrailProvider
# config: # config:
# key: value # key: value
# --- Option 4: SensitiveDataProvider (strict secret-file blocking) ---
# guardrails:
# enabled: true
# fail_closed: true
# provider:
# use: deerflow.guardrails.builtin:SensitiveDataProvider
# config:
# protected_tools: ["read_file", "write_file", "str_replace", "ls", "glob", "grep", "bash"]
# deny_basenames: [".env"]
# deny_globs: [".env.*", "*.pem", "*.key", "id_rsa*", "secrets.*", "credentials.*"]
# block_skills_env: true

View File

@ -166,7 +166,6 @@ export default function AgentChatPage() {
threadId={threadId} threadId={threadId}
autoFocus={isNewThread} autoFocus={isNewThread}
showWelcomeStyle={isNewThread} showWelcomeStyle={isNewThread}
hasSubmitted={!isNewThread}
status={ status={
thread.error thread.error
? "error" ? "error"

View File

@ -233,7 +233,6 @@ export default function ChatPage() {
[thread.values.artifacts], [thread.values.artifacts],
); );
const [hasSubmitted, setHasSubmitted] = useState(false);
const [historyCutoff, setHistoryCutoff] = useState<number | null>(null); const [historyCutoff, setHistoryCutoff] = useState<number | null>(null);
useEffect(() => { useEffect(() => {
@ -241,7 +240,6 @@ export default function ChatPage() {
setHistoryCutoff(null); setHistoryCutoff(null);
return; return;
} }
if (hasSubmitted) return;
// Welcome 态下、未提交前,把当前已有消息都当作“历史”切掉。 // Welcome 态下、未提交前,把当前已有消息都当作“历史”切掉。
// 这样即使历史消息是后续异步补齐,也不会重新露出。 // 这样即使历史消息是后续异步补齐,也不会重新露出。
setHistoryCutoff((prev) => { setHistoryCutoff((prev) => {
@ -250,7 +248,6 @@ export default function ChatPage() {
return next > prev ? next : prev; return next > prev ? next : prev;
}); });
}, [ }, [
hasSubmitted,
historyCutoff, historyCutoff,
shouldRenderHistory, shouldRenderHistory,
thread.isThreadLoading, thread.isThreadLoading,
@ -322,7 +319,6 @@ export default function ChatPage() {
if (isNewThread && safeThreadId && !isThreadInitReady) { if (isNewThread && safeThreadId && !isThreadInitReady) {
return; return;
} }
setHasSubmitted(true);
if (safeThreadId && (isNewThread || showWelcomeStyle)) { if (safeThreadId && (isNewThread || showWelcomeStyle)) {
router.replace(`/workspace/chats/${safeThreadId}?is_chatting=true`); router.replace(`/workspace/chats/${safeThreadId}?is_chatting=true`);
} }
@ -343,22 +339,6 @@ export default function ChatPage() {
await thread.stop(); await thread.stop();
}, [thread]); }, [thread]);
const resetNewSessionState = useCallback(() => {
setIsNewThread(true);
setHasSubmitted(false);
setHistoryCutoff(null);
setArtifacts([]);
deselectArtifact();
setArtifactsOpen(false);
setArtifactsFullscreen(false);
}, [
deselectArtifact,
setArtifacts,
setArtifactsFullscreen,
setArtifactsOpen,
setIsNewThread,
]);
return ( return (
<ThreadContext.Provider value={{ threadId, thread }}> <ThreadContext.Provider value={{ threadId, thread }}>
<div <div
@ -379,7 +359,7 @@ export default function ChatPage() {
<header <header
className={cn( className={cn(
"bg-background absolute top-0 right-0 left-0 z-30 mx-4 grid h-[58px] shrink-0 grid-cols-3 items-center border-b transition-all duration-300 ease-in-out", "bg-background absolute top-0 right-0 left-0 z-30 mx-4 grid h-[58px] shrink-0 grid-cols-3 items-center border-b transition-all duration-300 ease-in-out",
showWelcomeStyle && !hasSubmitted ? "hidden" : "", showWelcomeStyle ? "hidden" : "",
)} )}
> >
<div className="flex items-center justify-start overflow-hidden text-sm font-medium"> <div className="flex items-center justify-start overflow-hidden text-sm font-medium">
@ -388,7 +368,14 @@ export default function ChatPage() {
variant="ghost" variant="ghost"
className="px-[10px] py-[5px] text-sm font-medium text-ws-base-1 hover:text-ws-base-1/80" className="px-[10px] py-[5px] text-sm font-medium text-ws-base-1 hover:text-ws-base-1/80"
disabled={isStreaming} disabled={isStreaming}
onClick={() => setShowExitDialog(true)} onClick={() => {
sendToParent({
type: POST_MESSAGE_TYPES.IS_CHATTING,
isChatting: false,
});
router.replace(`/workspace/chats/${threadId}?is_chatting=false`)
}
}
> >
<svg <svg
width="20" width="20"
@ -444,6 +431,20 @@ export default function ChatPage() {
</Button> </Button>
} }
/> */} /> */}
<Tooltip content={t.common.resetThread}>
<Button
size="sm"
variant="ghost"
className="h-full px-[10px] py-[5px] text-sm font-medium text-ws-base-1 hover:text-ws-base-1"
disabled={isStreaming}
onClick={() => setShowExitDialog(true)}
>
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
<path d="M2 4H6M16 4H12M6 4H12M6 4C6 2.89543 6.89543 2 8 2H10C11.1046 2 12 2.89543 12 4M4 6V14C4 15.1046 4.89543 16 6 16H12C13.1046 16 14 15.1046 14 14V6M7 8V13M11 8V13" stroke="#150033" stroke-linecap="round" />
</svg>
{t.common.resetThread}
</Button>
</Tooltip>
{artifacts?.length > 0 && !artifactsOpen && ( {artifacts?.length > 0 && !artifactsOpen && (
<Tooltip content={t.chatPage.viewArtifactsTooltip}> <Tooltip content={t.chatPage.viewArtifactsTooltip}>
@ -456,7 +457,12 @@ export default function ChatPage() {
setSidebarOpen(false); setSidebarOpen(false);
}} }}
> >
<FilesIcon /> <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
<path d="M16 7V4C16 2.89543 15.1046 2 14 2H4C2.89543 2 2 2.89543 2 4V14C2 15.1046 2.89543 16 4 16H9" stroke="#150033" stroke-linecap="round" />
<path d="M5 5H9M5 8H7" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" />
<circle cx="11.5" cy="10.5" r="3" stroke="#150033" />
<path d="M15.5 14.5L14 13" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" />
</svg>
{t.common.artifacts} {t.common.artifacts}
</Button> </Button>
</Tooltip> </Tooltip>
@ -466,16 +472,14 @@ export default function ChatPage() {
<main <main
className={cn( className={cn(
"flex min-h-0 max-w-full grow flex-col", "flex min-h-0 max-w-full grow flex-col",
showWelcomeStyle && !hasSubmitted showWelcomeStyle ? "bg-ws-surface-base" : "bg-background",
? "bg-ws-surface-base"
: "bg-background",
)} )}
> >
<div className="flex size-full justify-center"> <div className="flex size-full justify-center">
<MessageList <MessageList
className={cn( className={cn(
"size-full", "size-full",
(!showWelcomeStyle || hasSubmitted) && "pt-[58px]", !showWelcomeStyle && "pt-[58px]",
)} )}
threadId={threadId} threadId={threadId}
thread={thread} thread={thread}
@ -508,7 +512,7 @@ export default function ChatPage() {
<div <div
className={cn( className={cn(
"h-full w-full transition-transform duration-300 ease-in-out", "h-full w-full transition-transform duration-300 ease-in-out",
showWelcomeStyle && !hasSubmitted ? "translate-x-0" : "", showWelcomeStyle ? "translate-x-0" : "",
artifactPanelOpen ? "translate-x-0" : "translate-x-full", artifactPanelOpen ? "translate-x-0" : "translate-x-full",
)} )}
> >
@ -570,9 +574,7 @@ export default function ChatPage() {
<div <div
className={cn( className={cn(
"pointer-events-auto relative w-full max-w-[720px]", "pointer-events-auto relative w-full max-w-[720px]",
showWelcomeStyle && showWelcomeStyle && "-translate-y-[calc(50vh-96px)]",
!hasSubmitted &&
"-translate-y-[calc(50vh-96px)]",
)} )}
> >
{!(showWelcomeStyle && thread.isThreadLoading) ? ( {!(showWelcomeStyle && thread.isThreadLoading) ? (
@ -581,7 +583,6 @@ export default function ChatPage() {
className={cn("w-full rounded-[20px] bg-ws-surface-elevated")} className={cn("w-full rounded-[20px] bg-ws-surface-elevated")}
threadId={threadId} threadId={threadId}
showWelcomeStyle={showWelcomeStyle} showWelcomeStyle={showWelcomeStyle}
hasSubmitted={hasSubmitted}
autoFocus={showWelcomeStyle} autoFocus={showWelcomeStyle}
status={ status={
thread.error thread.error
@ -593,9 +594,7 @@ export default function ChatPage() {
context={settings.context} context={settings.context}
extraHeader={ extraHeader={
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-4">
{showWelcomeStyle && !hasSubmitted && ( {showWelcomeStyle && <Welcome mode={settings.context.mode} />}
<Welcome mode={settings.context.mode} />
)}
</div> </div>
} }
disabled={ disabled={

View File

@ -219,7 +219,6 @@ export function InputBox({
context, context,
extraHeader, extraHeader,
showWelcomeStyle, showWelcomeStyle,
hasSubmitted,
initialValue, initialValue,
onContextChange, onContextChange,
onSubmit, onSubmit,
@ -238,7 +237,6 @@ export function InputBox({
}; };
extraHeader?: React.ReactNode; extraHeader?: React.ReactNode;
showWelcomeStyle: boolean; showWelcomeStyle: boolean;
hasSubmitted?: boolean;
initialValue?: string; initialValue?: string;
onContextChange?: ( onContextChange?: (
context: Omit< context: Omit<
@ -294,14 +292,13 @@ export function InputBox({
const [isInputToolsTourReady, setIsInputToolsTourReady] = useState(false); const [isInputToolsTourReady, setIsInputToolsTourReady] = useState(false);
const { data: referenceFilesData } = useReferenceFiles(threadIdFromProps); const { data: referenceFilesData } = useReferenceFiles(threadIdFromProps);
// isNewThread 时禁用收缩,始终保持展开(除非已提交消息) // Welcome 态下禁用收缩,始终保持展开
const effectiveIsFocused = const effectiveIsFocused = (showWelcomeStyle ?? false) || isFocused;
((showWelcomeStyle ?? false) && !hasSubmitted) || isFocused;
const shouldShowSuggestionList = const shouldShowSuggestionList =
showWelcomeStyle && !hasSubmitted && searchParams.get("mode") !== "skill"; showWelcomeStyle && searchParams.get("mode") !== "skill";
useEffect(() => { useEffect(() => {
if (!showWelcomeStyle || hasSubmitted) { if (!showWelcomeStyle) {
setIsInputToolsTourReady(false); setIsInputToolsTourReady(false);
return; return;
} }
@ -318,14 +315,13 @@ export function InputBox({
return () => window.cancelAnimationFrame(frameId); return () => window.cancelAnimationFrame(frameId);
}, [ }, [
showWelcomeStyle, showWelcomeStyle,
hasSubmitted,
shouldShowSuggestionList, shouldShowSuggestionList,
iframeSkill.isBootstrapping, iframeSkill.isBootstrapping,
iframeSkill.selectedSkills.length, iframeSkill.selectedSkills.length,
]); ]);
useEffect(() => { useEffect(() => {
if (!showWelcomeStyle || hasSubmitted || !isInputToolsTourReady) { if (!showWelcomeStyle || !isInputToolsTourReady) {
setIsInputToolsTourOpen(false); setIsInputToolsTourOpen(false);
return; return;
} }
@ -337,7 +333,7 @@ export function InputBox({
if (!hasSeenTourForCurrentThread) { if (!hasSeenTourForCurrentThread) {
setIsInputToolsTourOpen(true); setIsInputToolsTourOpen(true);
} }
}, [showWelcomeStyle, hasSubmitted, isInputToolsTourReady, threadId]); }, [showWelcomeStyle, isInputToolsTourReady, threadId]);
const finishInputToolsTour = useCallback(() => { const finishInputToolsTour = useCallback(() => {
const seenState = parseInputToolsTourSeenState( const seenState = parseInputToolsTourSeenState(
@ -817,7 +813,6 @@ export function InputBox({
"border-0 rounded-[20px] backdrop-blur-sm", "border-0 rounded-[20px] backdrop-blur-sm",
"transition-[height] duration-300 ease-out shadow-none ", "transition-[height] duration-300 ease-out shadow-none ",
!showWelcomeStyle && "h-[200px] shadow-[0_0_20px_0_rgba(0,0,0,0.10)]", !showWelcomeStyle && "h-[200px] shadow-[0_0_20px_0_rgba(0,0,0,0.10)]",
hasSubmitted && "shadow-[0_0_20px_0_rgba(0,0,0,0.10)]!",
effectiveIsFocused ? "h-[200px]" : "h-[80px]", effectiveIsFocused ? "h-[200px]" : "h-[80px]",
)} )}
disabled={isInputDisabled} disabled={isInputDisabled}
@ -970,14 +965,14 @@ export function InputBox({
/> />
</div> </div>
)} )}
{!showWelcomeStyle && ( {/* {!showWelcomeStyle && (
<div className="shrink-0 h-full"> <div className="shrink-0 h-full">
<ExitChattingButton <ExitChattingButton
router={router} router={router}
threadId={threadIdFromProps} threadId={threadIdFromProps}
/> />
</div> </div>
)} )} */}
<div ref={attachmentsButtonTourRef} className="shrink-0 h-full"> <div ref={attachmentsButtonTourRef} className="shrink-0 h-full">
<AddAttachmentsButton /> <AddAttachmentsButton />
</div> </div>
@ -1271,8 +1266,12 @@ function HistoryButton({
<Tooltip content={t.inputBox.history}> <Tooltip content={t.inputBox.history}>
<WorkspaceToolButton <WorkspaceToolButton
className={cn("text-ws-base-1 hover:text-ws-interactive-primary", className)} className={cn("text-ws-base-1 hover:text-ws-interactive-primary", className)}
onClick={() => onClick={() =>{
router.replace(`/workspace/chats/${threadId}?is_chatting=true`) sendToParent({
type: POST_MESSAGE_TYPES.IS_CHATTING,
isChatting: true,
});
router.replace(`/workspace/chats/${threadId}?is_chatting=true`)}
} }
> >
<svg <svg
@ -1318,9 +1317,13 @@ function ExitChattingButton({
"text-ws-base-1 hover:text-ws-interactive-primary", "text-ws-base-1 hover:text-ws-interactive-primary",
className, className,
)} )}
onClick={() => onClick={() => {
router.replace(`/workspace/chats/${threadId}?is_chatting=false`) sendToParent({
} type: POST_MESSAGE_TYPES.IS_CHATTING,
isChatting: false,
});
router.replace(`/workspace/chats/${threadId}?is_chatting=false`);
}}
> >
<svg <svg
className="transition-[color] duration-200" className="transition-[color] duration-200"

View File

@ -46,6 +46,13 @@ import { CopyButton } from "../copy-button";
import { MarkdownContent } from "./markdown-content"; import { MarkdownContent } from "./markdown-content";
function localizeAssistantFixedCopy(content: string, localized: string): string {
if (content.includes("The account balance is insufficient for this model call.")) {
return localized;
}
return content;
}
export function MessageListItem({ export function MessageListItem({
className, className,
message, message,
@ -176,8 +183,14 @@ function MessageContent_({
const cleaned = stripPriorityHintSuffix(stripUploadedFilesTag(rawContent)); const cleaned = stripPriorityHintSuffix(stripUploadedFilesTag(rawContent));
return normalizeHumanMessageDisplayText(cleaned); return normalizeHumanMessageDisplayText(cleaned);
} }
return rawContent ?? ""; if (!rawContent) {
}, [rawContent, isHuman]); return "";
}
return localizeAssistantFixedCopy(
rawContent,
t.threads.billingInsufficientBalance,
);
}, [rawContent, isHuman, t.threads.billingInsufficientBalance]);
const isSummaryMessage = useMemo( const isSummaryMessage = useMemo(
() => isHuman && isSummaryTemplateMessage(message), () => isHuman && isSummaryTemplateMessage(message),
[isHuman, message], [isHuman, message],

View File

@ -3,6 +3,7 @@
import { MessageSquarePlus } from "lucide-react"; import { MessageSquarePlus } from "lucide-react";
import Link from "next/link"; import Link from "next/link";
import { usePathname } from "next/navigation"; import { usePathname } from "next/navigation";
import { toast } from "sonner";
import { import {
SidebarMenu, SidebarMenu,
@ -11,14 +12,34 @@ import {
SidebarTrigger, SidebarTrigger,
useSidebar, useSidebar,
} from "@/components/ui/sidebar"; } from "@/components/ui/sidebar";
import { useThreadChat } from "@/components/workspace/chats";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages";
import { env } from "@/env"; import { env } from "@/env";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { copyToClipboard } from "@/lib/utils";
export function WorkspaceHeader({ className }: { className?: string }) { export function WorkspaceHeader({ className }: { className?: string }) {
const { t } = useI18n(); const { t } = useI18n();
const { state } = useSidebar(); const { state } = useSidebar();
const pathname = usePathname(); const pathname = usePathname();
const { threadId } = useThreadChat();
const threadUrl = threadId ? `/workspace/chats/${threadId}` : "";
const handleCopyThreadId = async () => {
if (!threadId) return;
sendToParent({
type: POST_MESSAGE_TYPES.COPY_TO_CLIPBOARD,
text: threadId,
});
try {
await copyToClipboard(threadId);
toast.success(t.clipboard.copiedToClipboard);
} catch {
toast.error(t.clipboard.failedToCopyToClipboard);
}
};
return ( return (
<> <>
<div <div
@ -43,7 +64,33 @@ export function WorkspaceHeader({ className }: { className?: string }) {
) : ( ) : (
<div className="text-primary ml-2 cursor-default font-serif"> <div className="text-primary ml-2 cursor-default font-serif">
{/* TODO: 测试标识 */} {/* TODO: 测试标识 */}
XClaw <span className="text-sm text-ws-text-subtle-strong">v3.2.8</span> XClaw{" "}
<span className="text-sm text-ws-text-subtle-strong">v3.2.9 </span>{" "}
<span
className={cn(
"text-xs font-mono",
threadId
? "cursor-pointer underline decoration-dotted underline-offset-4"
: "text-ws-text-subtle-strong",
)}
onClick={() => {
void handleCopyThreadId();
}}
title={threadId ? t.clipboard.copyToClipboard : undefined}
>
id:{threadId ? threadId.slice(0, 5) : "-"}
</span>
{" "}
{threadId && (
<a
href={threadUrl}
target="_blank"
rel="noopener noreferrer"
className="text-xs underline decoration-dotted underline-offset-4"
>
</a>
)}
</div> </div>
)} )}
<SidebarTrigger /> <SidebarTrigger />

View File

@ -53,6 +53,7 @@ export const enUS: Translations = {
exportSuccess: "Conversation exported", exportSuccess: "Conversation exported",
removeAttachment: "Remove attachment", removeAttachment: "Remove attachment",
reference: "Reference", reference: "Reference",
resetThread: "Reset thread",
}, },
// Welcome // Welcome
@ -319,6 +320,8 @@ export const enUS: Translations = {
threads: { threads: {
streamError: "Something went wrong.", streamError: "Something went wrong.",
billingInsufficientBalance:
"The account balance is insufficient for this model call.",
invalidThreadId: "Invalid thread id 'new'. Please refresh and retry.", invalidThreadId: "Invalid thread id 'new'. Please refresh and retry.",
staleReferencesRemoved: staleReferencesRemoved:
"Some referenced files were invalid and were removed automatically.", "Some referenced files were invalid and were removed automatically.",

View File

@ -48,6 +48,7 @@ export interface Translations {
exportSuccess: string; exportSuccess: string;
removeAttachment: string; removeAttachment: string;
reference: string; reference: string;
resetThread: string;
}; };
// Welcome // Welcome
@ -245,6 +246,7 @@ export interface Translations {
threads: { threads: {
streamError: string; streamError: string;
billingInsufficientBalance: string;
invalidThreadId: string; invalidThreadId: string;
staleReferencesRemoved: string; staleReferencesRemoved: string;
uploadFailed: string; uploadFailed: string;

View File

@ -55,6 +55,7 @@ export const zhCN: Translations = {
exportSuccess: "对话已导出", exportSuccess: "对话已导出",
removeAttachment: "移除附件", removeAttachment: "移除附件",
reference: "引用", reference: "引用",
resetThread: "重置会话",
}, },
// Welcome // Welcome
@ -150,10 +151,10 @@ export const zhCN: Translations = {
children: [{ id: "17", name: "Excel处理" }], children: [{ id: "17", name: "Excel处理" }],
}, },
{ {
suggestion: "营销策划", suggestion: "微信文章撰写",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。", prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon, icon: ShapesIcon,
children: [{ id: "217", name: "产品营销背景" }], children: [{ id: "6134", name: "微信文章撰写" }],
}, },
], ],
suggestionsCreate: [ suggestionsCreate: [
@ -306,6 +307,7 @@ export const zhCN: Translations = {
threads: { threads: {
streamError: "出现了某些错误。", streamError: "出现了某些错误。",
billingInsufficientBalance: "账户余额不足,无法完成本次模型调用。",
invalidThreadId: "线程 ID 无效new请刷新后重试。", invalidThreadId: "线程 ID 无效new请刷新后重试。",
staleReferencesRemoved: "部分引用文件已失效,已自动移除并继续发送。", staleReferencesRemoved: "部分引用文件已失效,已自动移除并继续发送。",
uploadFailed: "文件上传失败。", uploadFailed: "文件上传失败。",