Compare commits
No commits in common. "f3558d6bb2a2c5e1bfab90bca0912ed8efb3b3c7" and "ab9555255a40ff1ab59e782e3ebb64fd12636336" have entirely different histories.
f3558d6bb2
...
ab9555255a
@ -112,34 +112,6 @@ guardrails:
|
||||
3. Ask the agent: "Use bash to run echo hello"
|
||||
4. The agent sees: `Guardrail denied: tool 'bash' was blocked (oap.tool_not_allowed)`
|
||||
|
||||
### Option 1.5: Built-in SensitiveDataProvider (Strict Secret Blocking)
|
||||
|
||||
For secret-leak prevention, DeerFlow also ships `SensitiveDataProvider`, which
|
||||
blocks tool calls targeting sensitive file patterns (for example `.env`,
|
||||
`.env.*`, `*.pem`, `*.key`, `id_rsa*`, `secrets.*`, `credentials.*`).
|
||||
|
||||
This provider is strict: it blocks matching access even when the user
|
||||
explicitly asks the agent to reveal those files.
|
||||
|
||||
**config.yaml:**
|
||||
```yaml
|
||||
guardrails:
|
||||
enabled: true
|
||||
fail_closed: true
|
||||
provider:
|
||||
use: deerflow.guardrails.builtin:SensitiveDataProvider
|
||||
config:
|
||||
protected_tools: ["read_file", "write_file", "str_replace", "ls", "glob", "grep", "bash"]
|
||||
deny_basenames: [".env"]
|
||||
deny_globs: [".env.*", "*.pem", "*.key", "id_rsa*", "secrets.*", "credentials.*"]
|
||||
block_skills_env: true
|
||||
```
|
||||
|
||||
**Behavior summary:**
|
||||
- `read_file / ls / glob / grep / bash` attempting sensitive path access are denied
|
||||
- `write_file / str_replace` touching sensitive targets are denied
|
||||
- Denials are logged as structured audit events (tool/reason/thread/timestamp)
|
||||
|
||||
### Option 2: OAP Passport Provider (Policy-Based)
|
||||
|
||||
For policy enforcement based on the [Open Agent Passport (OAP)](https://github.com/aporthq/aport-spec) open standard. An OAP passport is a JSON document that declares an agent's identity, capabilities, and operational limits. Any provider that reads an OAP passport and returns OAP-compliant decisions works with DeerFlow.
|
||||
|
||||
@ -7,14 +7,3 @@ __all__ = [
|
||||
"checkpointer_context",
|
||||
"make_checkpointer",
|
||||
]
|
||||
|
||||
# Lazy-import shallow savers so the module is still importable without
|
||||
# langgraph-checkpoint-sqlite installed.
|
||||
def __getattr__(name: str):
|
||||
if name == "AsyncShallowSqliteSaver":
|
||||
from .shallow_sqlite import _make_async_shallow_saver
|
||||
return _make_async_shallow_saver()
|
||||
if name == "ShallowSqliteSaver":
|
||||
from .shallow_sqlite import _make_sync_shallow_saver
|
||||
return _make_sync_shallow_saver()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -55,18 +55,6 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
|
||||
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
|
||||
if getattr(config, "sqlite_mode", "full") == "shallow":
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
|
||||
|
||||
ShallowSaver = _make_async_shallow_saver()
|
||||
async with ShallowSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
logger.info("Checkpointer: using AsyncShallowSqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
|
||||
@ -67,18 +67,6 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
|
||||
# Shallow mode: use custom saver that keeps only the latest checkpoint per thread
|
||||
if getattr(config, "sqlite_mode", "full") == "shallow":
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
|
||||
|
||||
ShallowSaver = _make_sync_shallow_saver()
|
||||
with ShallowSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using ShallowSqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
"""Shallow persistence savers for LangGraph SQLite checkpointing.
|
||||
|
||||
Provides shallow (single-checkpoint-per-thread) variants of the LangGraph
|
||||
SQLite savers that automatically delete old checkpoints and writes for the
|
||||
same thread before each write, keeping only the latest state.
|
||||
|
||||
This prevents unbounded growth of ``checkpoints.db`` while preserving
|
||||
multi-turn conversation continuity.
|
||||
|
||||
Implements:
|
||||
- ``AsyncShallowSqliteSaver`` — async shallow variant
|
||||
- ``ShallowSqliteSaver`` — sync shallow variant
|
||||
|
||||
Usage is transparent through the existing checkpointer factory when
|
||||
``sqlite_mode: shallow`` is set in ``config.yaml``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.base import ChannelVersions, Checkpoint, CheckpointMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async shallow saver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AsyncShallowSqliteSaver:
|
||||
"""Async SQLite checkpointer that keeps only the latest checkpoint per thread.
|
||||
|
||||
Extends :class:`langgraph.checkpoint.sqlite.aio.AsyncSqliteSaver` and
|
||||
overrides :meth:`aput` to delete all existing checkpoints and writes for
|
||||
the same ``thread_id`` before inserting the new one.
|
||||
|
||||
Each conversation thread stores exactly one checkpoint at any time,
|
||||
preventing unbounded database growth.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
# Allow extension by subclasses without forcing late-bound import here.
|
||||
# The concrete class is built below via the _make_async_shallow factory.
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
def _make_async_shallow_saver() -> type:
|
||||
"""Build and return the ``AsyncShallowSqliteSaver`` class.
|
||||
|
||||
Import is deferred so that the module is importable even when
|
||||
``langgraph-checkpoint-sqlite`` is not installed.
|
||||
"""
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
class AsyncShallowSqliteSaver(AsyncSqliteSaver):
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Delete old checkpoints/writes for this thread, then save the new one."""
|
||||
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
await self.setup()
|
||||
|
||||
# Delete all existing checkpoints and writes for this thread
|
||||
# before inserting the new checkpoint — keeps only the latest.
|
||||
async with self.lock:
|
||||
await self.conn.execute(
|
||||
"DELETE FROM checkpoints WHERE thread_id = ?",
|
||||
(str(thread_id),),
|
||||
)
|
||||
await self.conn.execute(
|
||||
"DELETE FROM writes WHERE thread_id = ?",
|
||||
(str(thread_id),),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
return await super().aput(config, checkpoint, metadata, new_versions)
|
||||
|
||||
return AsyncShallowSqliteSaver
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync shallow saver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sync_shallow_saver() -> type:
|
||||
"""Build and return the ``ShallowSqliteSaver`` class.
|
||||
|
||||
Import is deferred so that the module is importable even when
|
||||
``langgraph-checkpoint-sqlite`` is not installed.
|
||||
"""
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
class ShallowSqliteSaver(SqliteSaver):
|
||||
def put(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Delete old checkpoints/writes for this thread, then save the new one."""
|
||||
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
|
||||
# Delete all existing checkpoints and writes for this thread
|
||||
# before inserting the new checkpoint — keeps only the latest.
|
||||
with self.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM checkpoints WHERE thread_id = ?",
|
||||
(str(thread_id),),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM writes WHERE thread_id = ?",
|
||||
(str(thread_id),),
|
||||
)
|
||||
|
||||
return super().put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
return ShallowSqliteSaver
|
||||
@ -174,6 +174,7 @@ def _extract_run_id(request: ModelRequest) -> str | None: # noqa: ARG001
|
||||
|
||||
def _reserve_failure_message(status_code: int | None) -> str:
|
||||
if status_code in _blocking_reserve_code_set():
|
||||
# TODO: 将账单错误文案迁移到国际化资源中,按语言返回提示。
|
||||
return "The account balance is insufficient for this model call."
|
||||
return "Billing reservation failed. Please try again later."
|
||||
|
||||
|
||||
@ -1,64 +0,0 @@
|
||||
"""Redact sensitive values from tool outputs before they re-enter the model context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
_SIMPLE_REDACTION = "[REDACTED]"
|
||||
|
||||
_PATTERNS = [
|
||||
re.compile(r"(?im)\b([A-Z][A-Z0-9_]{1,64})\s*=\s*([^\s\"'`]{6,})"),
|
||||
re.compile(r"(?i)\b(bearer\s+)[A-Za-z0-9._\-+/=]{8,}"),
|
||||
re.compile(r"(?i)\b(sk-[A-Za-z0-9]{12,})\b"),
|
||||
re.compile(r"(?i)\b(eyJ[A-Za-z0-9_\-]{10,}\.[A-Za-z0-9_\-]{10,}\.[A-Za-z0-9_\-]{10,})\b"),
|
||||
re.compile(r"(?is)-----BEGIN [A-Z ]*PRIVATE KEY-----.*?-----END [A-Z ]*PRIVATE KEY-----"),
|
||||
]
|
||||
|
||||
|
||||
def _redact_text(text: str) -> str:
|
||||
value = text
|
||||
# Preserve var name for KEY=VALUE style output.
|
||||
value = _PATTERNS[0].sub(lambda m: f"{m.group(1)}={_SIMPLE_REDACTION}", value)
|
||||
value = _PATTERNS[1].sub(lambda m: f"{m.group(1)}{_SIMPLE_REDACTION}", value)
|
||||
for pattern in _PATTERNS[2:]:
|
||||
value = pattern.sub(_SIMPLE_REDACTION, value)
|
||||
return value
|
||||
|
||||
|
||||
class SensitiveOutputRedactionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Redact secrets from tool outputs."""
|
||||
|
||||
def _redact_tool_result(self, result: ToolMessage | Command) -> ToolMessage | Command:
|
||||
if not isinstance(result, ToolMessage):
|
||||
return result
|
||||
if isinstance(result.content, str):
|
||||
redacted = _redact_text(result.content)
|
||||
if redacted != result.content:
|
||||
return ToolMessage(content=redacted, tool_call_id=result.tool_call_id, name=result.name, status=result.status)
|
||||
return result
|
||||
return result
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
return self._redact_tool_result(handler(request))
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
return self._redact_tool_result(await handler(request))
|
||||
|
||||
@ -127,10 +127,8 @@ def _build_runtime_middlewares(
|
||||
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
|
||||
|
||||
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
|
||||
from deerflow.agents.middlewares.sensitive_output_redaction_middleware import SensitiveOutputRedactionMiddleware
|
||||
|
||||
middlewares.append(SandboxAuditMiddleware())
|
||||
middlewares.append(SensitiveOutputRedactionMiddleware())
|
||||
middlewares.append(ToolErrorHandlingMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
@ -6,14 +6,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
CheckpointerType = Literal["memory", "sqlite", "postgres"]
|
||||
|
||||
SqliteMode = Literal["full", "shallow"]
|
||||
"""Persistence mode for the SQLite checkpointer.
|
||||
|
||||
- ``full`` — retain all checkpoint history (default, backward-compatible).
|
||||
- ``shallow`` — keep only the latest checkpoint per thread, deleting old
|
||||
records before each write to prevent unbounded database growth.
|
||||
"""
|
||||
|
||||
|
||||
class CheckpointerConfig(BaseModel):
|
||||
"""Configuration for LangGraph state persistence checkpointer."""
|
||||
@ -31,13 +23,6 @@ class CheckpointerConfig(BaseModel):
|
||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||
)
|
||||
sqlite_mode: SqliteMode = Field(
|
||||
default="full",
|
||||
description="SQLite persistence mode. "
|
||||
"'full' retains all checkpoint history (default). "
|
||||
"'shallow' keeps only the latest checkpoint per thread, "
|
||||
"deleting old records before each write.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance — None means no checkpointer is configured.
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
"""Pre-tool-call authorization middleware."""
|
||||
|
||||
from deerflow.guardrails.builtin import AllowlistProvider, SensitiveDataProvider
|
||||
from deerflow.guardrails.builtin import AllowlistProvider
|
||||
from deerflow.guardrails.middleware import GuardrailMiddleware
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
|
||||
|
||||
__all__ = [
|
||||
"AllowlistProvider",
|
||||
"SensitiveDataProvider",
|
||||
"GuardrailDecision",
|
||||
"GuardrailMiddleware",
|
||||
"GuardrailProvider",
|
||||
|
||||
@ -1,20 +1,7 @@
|
||||
"""Built-in guardrail providers that ship with DeerFlow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shlex
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AllowlistProvider:
|
||||
"""Simple allowlist/denylist provider. No external dependencies."""
|
||||
@ -34,138 +21,3 @@ class AllowlistProvider:
|
||||
|
||||
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
return self.evaluate(request)
|
||||
|
||||
|
||||
class SensitiveDataProvider:
|
||||
"""Block tool calls that may access sensitive files such as .env and keys."""
|
||||
|
||||
name = "sensitive-data"
|
||||
|
||||
_DEFAULT_PROTECTED_TOOLS = {
|
||||
"read_file",
|
||||
"write_file",
|
||||
"str_replace",
|
||||
"ls",
|
||||
"glob",
|
||||
"grep",
|
||||
"bash",
|
||||
}
|
||||
_DEFAULT_DENY_BASENAMES = {".env"}
|
||||
_DEFAULT_DENY_GLOBS = {
|
||||
".env.*",
|
||||
"*.pem",
|
||||
"*.key",
|
||||
"id_rsa*",
|
||||
"secrets.*",
|
||||
"credentials.*",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
protected_tools: list[str] | None = None,
|
||||
deny_basenames: list[str] | None = None,
|
||||
deny_globs: list[str] | None = None,
|
||||
block_skills_env: bool = True,
|
||||
**_: Any,
|
||||
):
|
||||
self._protected_tools = {t.lower() for t in (protected_tools or list(self._DEFAULT_PROTECTED_TOOLS))}
|
||||
self._deny_basenames = {n.lower() for n in (deny_basenames or list(self._DEFAULT_DENY_BASENAMES))}
|
||||
self._deny_globs = {p.lower() for p in (deny_globs or list(self._DEFAULT_DENY_GLOBS))}
|
||||
self._block_skills_env = block_skills_env
|
||||
|
||||
def _normalize_candidate(self, raw: str | None) -> str:
|
||||
if not raw:
|
||||
return ""
|
||||
return str(raw).strip().strip("\"'")
|
||||
|
||||
def _looks_sensitive_path(self, raw_path: str) -> bool:
|
||||
value = self._normalize_candidate(raw_path)
|
||||
if not value:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if self._block_skills_env and "/mnt/skills/" in lowered:
|
||||
basename = PurePosixPath(lowered).name
|
||||
if basename == ".env" or basename.startswith(".env."):
|
||||
return True
|
||||
basename = PurePosixPath(lowered).name
|
||||
if basename in self._deny_basenames:
|
||||
return True
|
||||
return any(fnmatch.fnmatch(basename, pat) for pat in self._deny_globs)
|
||||
|
||||
def _extract_bash_candidates(self, command: str) -> list[str]:
|
||||
candidates: list[str] = []
|
||||
if not command:
|
||||
return candidates
|
||||
try:
|
||||
tokens = shlex.split(command)
|
||||
except ValueError:
|
||||
tokens = command.split()
|
||||
for token in tokens:
|
||||
t = token.strip()
|
||||
if not t:
|
||||
continue
|
||||
# Path-like tokens
|
||||
if "/" in t or t.startswith("."):
|
||||
candidates.append(t)
|
||||
# file.env style arguments may not contain slash
|
||||
if t.lower().startswith(".env"):
|
||||
candidates.append(t)
|
||||
return candidates
|
||||
|
||||
def _collect_candidates(self, request: GuardrailRequest) -> list[str]:
|
||||
args = request.tool_input if isinstance(request.tool_input, dict) else {}
|
||||
tool = request.tool_name
|
||||
candidates: list[str] = []
|
||||
if tool in {"read_file", "write_file", "str_replace", "ls"}:
|
||||
path = args.get("path")
|
||||
if isinstance(path, str):
|
||||
candidates.append(path)
|
||||
elif tool in {"glob", "grep"}:
|
||||
path = args.get("path")
|
||||
if isinstance(path, str):
|
||||
candidates.append(path)
|
||||
glob_pat = args.get("glob")
|
||||
if isinstance(glob_pat, str):
|
||||
candidates.append(glob_pat)
|
||||
elif tool == "bash":
|
||||
command = str(args.get("command") or "")
|
||||
candidates.extend(self._extract_bash_candidates(command))
|
||||
# Fast-path for common secret exposure commands
|
||||
if re.search(r"\b(printenv|env)\b", command, flags=re.IGNORECASE):
|
||||
candidates.append(".env")
|
||||
return candidates
|
||||
|
||||
def _audit(self, request: GuardrailRequest, decision: GuardrailDecision) -> None:
|
||||
if decision.allow:
|
||||
return
|
||||
code = decision.reasons[0].code if decision.reasons else "oap.blocked_pattern"
|
||||
rec = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"provider": self.name,
|
||||
"tool_name": request.tool_name,
|
||||
"reason_code": code,
|
||||
"thread_id": request.thread_id,
|
||||
"agent_id": request.agent_id,
|
||||
}
|
||||
logger.warning("[SensitiveDataGuardrail] %s", json.dumps(rec, ensure_ascii=False))
|
||||
|
||||
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
tool = (request.tool_name or "").lower()
|
||||
if tool not in self._protected_tools:
|
||||
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
|
||||
|
||||
candidates = self._collect_candidates(request)
|
||||
if any(self._looks_sensitive_path(c) for c in candidates):
|
||||
decision = GuardrailDecision(
|
||||
allow=False,
|
||||
reasons=[GuardrailReason(code="oap.blocked_pattern", message="sensitive path access is blocked by policy")],
|
||||
policy_id="sensitive-data.v1",
|
||||
)
|
||||
self._audit(request, decision)
|
||||
return decision
|
||||
|
||||
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
|
||||
|
||||
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
return self.evaluate(request)
|
||||
|
||||
@ -23,7 +23,7 @@ def _fake_app_config(*, enabled: bool = True, include_subagents: bool = True):
|
||||
default_estimated_output_tokens=None,
|
||||
)
|
||||
|
||||
model_cfg = SimpleNamespace(display_name="GPT-4", model="gpt-4", model_extra={"max_tokens": 4096})
|
||||
model_cfg = SimpleNamespace(display_name="GPT-4", model_extra={"max_tokens": 4096})
|
||||
return SimpleNamespace(
|
||||
billing=billing,
|
||||
get_model_config=lambda name: model_cfg if name == "gpt-4" else None,
|
||||
|
||||
@ -1,305 +0,0 @@
|
||||
"""Tests for shallow SQLite checkpoint savers (single-checkpoint-per-thread mode).
|
||||
|
||||
Uses in-memory SQLite (``:memory:``) — no filesystem dependency.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AsyncShallowSqliteSaver tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAsyncShallowSqliteSaver:
|
||||
"""Tests for ``AsyncShallowSqliteSaver`` — async shallow persistence."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aput_deletes_old_checkpoints_before_insert(self):
|
||||
"""After two aput calls for the same thread, only 1 checkpoint remains."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
|
||||
|
||||
ShallowSaver = _make_async_shallow_saver()
|
||||
|
||||
async with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
await saver.setup()
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test-thread-1", "checkpoint_ns": ""}}
|
||||
checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-1", "channel_values": {"x": 1}}
|
||||
checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-2", "channel_values": {"x": 2}}
|
||||
|
||||
# Write first checkpoint
|
||||
await saver.aput(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
|
||||
# Write second checkpoint — should delete the first
|
||||
await saver.aput(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
|
||||
|
||||
# Verify only 1 checkpoint remains
|
||||
results = []
|
||||
async for ckpt in saver.alist(thread_config):
|
||||
results.append(ckpt)
|
||||
assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
|
||||
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_different_threads_do_not_interfere(self):
|
||||
"""Checkpoints for different thread_ids are independent."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
|
||||
|
||||
ShallowSaver = _make_async_shallow_saver()
|
||||
|
||||
async with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
await saver.setup()
|
||||
|
||||
t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
|
||||
t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
|
||||
|
||||
await saver.aput(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
await saver.aput(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
await saver.aput(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
|
||||
|
||||
# Thread A: only ckpt a2
|
||||
a_results = [c async for c in saver.alist(t1)]
|
||||
assert len(a_results) == 1
|
||||
assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
|
||||
|
||||
# Thread B: still has b1
|
||||
b_results = [c async for c in saver.alist(t2)]
|
||||
assert len(b_results) == 1
|
||||
assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_writes_table_also_cleaned(self):
|
||||
"""aput_writes entries from old checkpoints are also deleted."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_async_shallow_saver
|
||||
|
||||
ShallowSaver = _make_async_shallow_saver()
|
||||
|
||||
async with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
await saver.setup()
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test-writes", "checkpoint_ns": ""}}
|
||||
|
||||
# Write checkpoint 1 with associated writes
|
||||
ckpt1_config = await saver.aput(
|
||||
thread_config,
|
||||
{"ts": "Z", "id": "ckpt-w1", "channel_values": {}},
|
||||
{"source": "input", "step": 1, "writes": {}},
|
||||
{},
|
||||
)
|
||||
await saver.aput_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
|
||||
|
||||
# Write checkpoint 2 — should delete ckpt1 + its writes
|
||||
ckpt2_config = await saver.aput(
|
||||
thread_config,
|
||||
{"ts": "Z", "id": "ckpt-w2", "channel_values": {}},
|
||||
{"source": "loop", "step": 2, "writes": {}},
|
||||
{},
|
||||
)
|
||||
await saver.aput_writes(ckpt2_config, [("messages", "world")], "task-2", "")
|
||||
|
||||
# Verify only 1 checkpoint remains
|
||||
results = [c async for c in saver.alist(thread_config)]
|
||||
assert len(results) == 1
|
||||
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-w2"
|
||||
|
||||
# Verify only the latest writes exist (get_tuple returns writes)
|
||||
latest = await saver.aget_tuple(ckpt2_config)
|
||||
assert latest is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_full_mode_retains_all_checkpoints(self):
|
||||
"""In full mode (default), all checkpoints are preserved."""
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
|
||||
await saver.setup()
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test-full", "checkpoint_ns": ""}}
|
||||
|
||||
await saver.aput(thread_config, {"ts": "Z", "id": "f1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
await saver.aput(thread_config, {"ts": "Z", "id": "f2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
|
||||
|
||||
results = [c async for c in saver.alist(thread_config)]
|
||||
assert len(results) == 2, "Full mode should retain all checkpoints"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ShallowSqliteSaver (sync) tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShallowSqliteSaver:
|
||||
"""Tests for ``ShallowSqliteSaver`` — sync shallow persistence."""
|
||||
|
||||
def test_put_deletes_old_checkpoints_before_insert(self):
|
||||
"""After two put calls for the same thread, only 1 checkpoint remains."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
|
||||
|
||||
ShallowSaver = _make_sync_shallow_saver()
|
||||
|
||||
with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
saver.setup()
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test-sync-1", "checkpoint_ns": ""}}
|
||||
checkpoint_1 = {"ts": "2024-01-01T00:00:00Z", "id": "ckpt-s1", "channel_values": {"x": 1}}
|
||||
checkpoint_2 = {"ts": "2024-01-01T00:01:00Z", "id": "ckpt-s2", "channel_values": {"x": 2}}
|
||||
|
||||
saver.put(thread_config, checkpoint_1, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
saver.put(thread_config, checkpoint_2, {"source": "loop", "step": 2, "writes": {}}, {})
|
||||
|
||||
results = list(saver.list(thread_config))
|
||||
assert len(results) == 1, f"Expected 1 checkpoint, got {len(results)}"
|
||||
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-s2"
|
||||
|
||||
def test_different_threads_do_not_interfere_sync(self):
|
||||
"""Checkpoints for different thread_ids are independent (sync)."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
|
||||
|
||||
ShallowSaver = _make_sync_shallow_saver()
|
||||
|
||||
with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
saver.setup()
|
||||
|
||||
t1 = {"configurable": {"thread_id": "thread-A", "checkpoint_ns": ""}}
|
||||
t2 = {"configurable": {"thread_id": "thread-B", "checkpoint_ns": ""}}
|
||||
|
||||
saver.put(t1, {"ts": "Z", "id": "a1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
saver.put(t2, {"ts": "Z", "id": "b1", "channel_values": {}}, {"source": "input", "step": 1, "writes": {}}, {})
|
||||
saver.put(t1, {"ts": "Z", "id": "a2", "channel_values": {}}, {"source": "loop", "step": 2, "writes": {}}, {})
|
||||
|
||||
a_results = list(saver.list(t1))
|
||||
assert len(a_results) == 1
|
||||
assert a_results[0].config["configurable"]["checkpoint_id"] == "a2"
|
||||
|
||||
b_results = list(saver.list(t2))
|
||||
assert len(b_results) == 1
|
||||
assert b_results[0].config["configurable"]["checkpoint_id"] == "b1"
|
||||
|
||||
def test_writes_table_also_cleaned_sync(self):
|
||||
"""put_writes entries from old checkpoints are also deleted (sync)."""
|
||||
from deerflow.agents.checkpointer.shallow_sqlite import _make_sync_shallow_saver
|
||||
|
||||
ShallowSaver = _make_sync_shallow_saver()
|
||||
|
||||
with ShallowSaver.from_conn_string(":memory:") as saver:
|
||||
saver.setup()
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test-sync-writes", "checkpoint_ns": ""}}
|
||||
|
||||
ckpt1_config = saver.put(
|
||||
thread_config,
|
||||
{"ts": "Z", "id": "ckpt-sw1", "channel_values": {}},
|
||||
{"source": "input", "step": 1, "writes": {}},
|
||||
{},
|
||||
)
|
||||
saver.put_writes(ckpt1_config, [("messages", "hello")], "task-1", "")
|
||||
|
||||
ckpt2_config = saver.put(
|
||||
thread_config,
|
||||
{"ts": "Z", "id": "ckpt-sw2", "channel_values": {}},
|
||||
{"source": "loop", "step": 2, "writes": {}},
|
||||
{},
|
||||
)
|
||||
saver.put_writes(ckpt2_config, [("messages", "world")], "task-2", "")
|
||||
|
||||
results = list(saver.list(thread_config))
|
||||
assert len(results) == 1
|
||||
assert results[0].config["configurable"]["checkpoint_id"] == "ckpt-sw2"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShallowConfig:
|
||||
"""Tests for configuration and factory integration."""
|
||||
|
||||
def test_sqlite_mode_defaults_to_full(self):
|
||||
"""sqlite_mode defaults to 'full' when not specified."""
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
|
||||
config = CheckpointerConfig(type="sqlite", connection_string="test.db")
|
||||
assert config.sqlite_mode == "full"
|
||||
|
||||
def test_sqlite_mode_shallow_accepted(self):
|
||||
"""sqlite_mode can be set to 'shallow'."""
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
|
||||
config = CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="shallow")
|
||||
assert config.sqlite_mode == "shallow"
|
||||
|
||||
def test_load_sqlite_config_with_shallow_mode(self):
|
||||
"""load_checkpointer_config_from_dict accepts sqlite_mode."""
|
||||
from deerflow.config.checkpointer_config import (
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({
|
||||
"type": "sqlite",
|
||||
"connection_string": "/tmp/test.db",
|
||||
"sqlite_mode": "shallow",
|
||||
})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.sqlite_mode == "shallow"
|
||||
|
||||
def test_load_sqlite_config_defaults_sqlite_mode(self):
|
||||
"""load_checkpointer_config_from_dict defaults sqlite_mode to 'full' when omitted."""
|
||||
from deerflow.config.checkpointer_config import (
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({
|
||||
"type": "sqlite",
|
||||
"connection_string": "/tmp/test.db",
|
||||
})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.sqlite_mode == "full"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_factory_uses_shallow_saver(self):
|
||||
"""When sqlite_mode=shallow, async factory returns AsyncShallowSqliteSaver."""
|
||||
from deerflow.agents.checkpointer.async_provider import _async_checkpointer
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, set_checkpointer_config
|
||||
|
||||
set_checkpointer_config(CheckpointerConfig(
|
||||
type="sqlite",
|
||||
connection_string=":memory:",
|
||||
sqlite_mode="shallow",
|
||||
))
|
||||
config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
|
||||
|
||||
async with _async_checkpointer(config) as saver:
|
||||
# Should be an instance of the shallow saver
|
||||
cls_name = type(saver).__name__
|
||||
assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
|
||||
|
||||
set_checkpointer_config(None)
|
||||
|
||||
def test_sync_factory_uses_shallow_saver(self):
|
||||
"""When sqlite_mode=shallow, sync factory returns ShallowSqliteSaver."""
|
||||
from deerflow.agents.checkpointer.provider import _sync_checkpointer_cm
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
|
||||
config = CheckpointerConfig(type="sqlite", connection_string=":memory:", sqlite_mode="shallow")
|
||||
|
||||
with _sync_checkpointer_cm(config) as saver:
|
||||
cls_name = type(saver).__name__
|
||||
assert "Shallow" in cls_name, f"Expected shallow saver, got {cls_name}"
|
||||
|
||||
def test_invalid_sqlite_mode_raises(self):
|
||||
"""Invalid sqlite_mode value raises validation error."""
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
|
||||
with pytest.raises(Exception):
|
||||
CheckpointerConfig(type="sqlite", connection_string="test.db", sqlite_mode="unknown")
|
||||
@ -8,7 +8,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
from deerflow.guardrails.builtin import AllowlistProvider, SensitiveDataProvider
|
||||
from deerflow.guardrails.builtin import AllowlistProvider
|
||||
from deerflow.guardrails.middleware import GuardrailMiddleware
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
|
||||
|
||||
@ -105,46 +105,6 @@ class TestAllowlistProvider:
|
||||
assert decision.allow is False
|
||||
|
||||
|
||||
class TestSensitiveDataProvider:
|
||||
def test_denies_reading_env_file(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/tmp/.env"})
|
||||
decision = provider.evaluate(req)
|
||||
assert decision.allow is False
|
||||
assert decision.reasons[0].code == "oap.blocked_pattern"
|
||||
|
||||
def test_allows_normal_source_file(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/app/main.py"})
|
||||
decision = provider.evaluate(req)
|
||||
assert decision.allow is True
|
||||
|
||||
def test_denies_skills_env(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req = GuardrailRequest(tool_name="read_file", tool_input={"path": "/mnt/skills/public/foo/.env.local"})
|
||||
decision = provider.evaluate(req)
|
||||
assert decision.allow is False
|
||||
|
||||
def test_denies_glob_or_grep_targeting_env(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req = GuardrailRequest(tool_name="grep", tool_input={"path": "/workspace", "glob": "**/.env.*"})
|
||||
decision = provider.evaluate(req)
|
||||
assert decision.allow is False
|
||||
|
||||
def test_denies_bash_cat_env(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req = GuardrailRequest(tool_name="bash", tool_input={"command": "cat /workspace/.env"})
|
||||
decision = provider.evaluate(req)
|
||||
assert decision.allow is False
|
||||
|
||||
def test_denies_case_variant_and_key_material(self):
|
||||
provider = SensitiveDataProvider()
|
||||
req1 = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/.ENV"})
|
||||
req2 = GuardrailRequest(tool_name="read_file", tool_input={"path": "/workspace/id_rsa.pub"})
|
||||
assert provider.evaluate(req1).allow is False
|
||||
assert provider.evaluate(req2).allow is False
|
||||
|
||||
|
||||
# --- GuardrailMiddleware tests ---
|
||||
|
||||
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from deerflow.agents.middlewares.sensitive_output_redaction_middleware import SensitiveOutputRedactionMiddleware
|
||||
|
||||
|
||||
def _request(name: str = "bash", args: dict | None = None):
|
||||
req = MagicMock()
|
||||
req.tool_call = {"name": name, "args": args or {}, "id": "call_1"}
|
||||
return req
|
||||
|
||||
|
||||
def test_redacts_key_value_and_bearer():
|
||||
mw = SensitiveOutputRedactionMiddleware()
|
||||
req = _request()
|
||||
handler = MagicMock(
|
||||
return_value=ToolMessage(
|
||||
content="OPENAI_API_KEY=sk-abc123456789\nAuthorization: Bearer abcdefghijklmnop",
|
||||
tool_call_id="call_1",
|
||||
name="bash",
|
||||
status="success",
|
||||
)
|
||||
)
|
||||
result = mw.wrap_tool_call(req, handler)
|
||||
assert "OPENAI_API_KEY=[REDACTED]" in result.content
|
||||
assert "Bearer [REDACTED]" in result.content
|
||||
|
||||
|
||||
def test_redacts_private_key_block():
|
||||
mw = SensitiveOutputRedactionMiddleware()
|
||||
req = _request()
|
||||
handler = MagicMock(
|
||||
return_value=ToolMessage(
|
||||
content="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----",
|
||||
tool_call_id="call_1",
|
||||
name="bash",
|
||||
status="success",
|
||||
)
|
||||
)
|
||||
result = mw.wrap_tool_call(req, handler)
|
||||
assert result.content == "[REDACTED]"
|
||||
|
||||
|
||||
def test_async_path_redacts_jwt():
|
||||
mw = SensitiveOutputRedactionMiddleware()
|
||||
req = _request()
|
||||
|
||||
async def handler(_):
|
||||
return ToolMessage(
|
||||
content="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.abcde12345.zyxwv98765",
|
||||
tool_call_id="call_1",
|
||||
name="bash",
|
||||
status="success",
|
||||
)
|
||||
|
||||
result = asyncio.run(mw.awrap_tool_call(req, handler))
|
||||
assert result.content == "[REDACTED]"
|
||||
|
||||
@ -709,10 +709,6 @@ memory:
|
||||
# memory - In-process only. State is lost when the process exits. (default)
|
||||
# sqlite - File-based SQLite persistence. Survives restarts.
|
||||
# Requires: uv add langgraph-checkpoint-sqlite
|
||||
# sqlite_mode: full (default) retains all history.
|
||||
# sqlite_mode: shallow keeps only the latest checkpoint per
|
||||
# thread, deleting old records before each write to prevent
|
||||
# unbounded database growth.
|
||||
# postgres - PostgreSQL persistence. Suitable for multi-process deployments.
|
||||
# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool
|
||||
#
|
||||
@ -726,8 +722,6 @@ memory:
|
||||
checkpointer:
|
||||
type: sqlite
|
||||
connection_string: checkpoints.db
|
||||
# sqlite_mode: full # default — keep all checkpoint history
|
||||
# sqlite_mode: shallow # keep only latest checkpoint per thread (prevents DB bloat)
|
||||
#
|
||||
# PostgreSQL (multi-process, production):
|
||||
# checkpointer:
|
||||
@ -838,15 +832,3 @@ checkpointer:
|
||||
# use: my_package:MyGuardrailProvider
|
||||
# config:
|
||||
# key: value
|
||||
|
||||
# --- Option 4: SensitiveDataProvider (strict secret-file blocking) ---
|
||||
# guardrails:
|
||||
# enabled: true
|
||||
# fail_closed: true
|
||||
# provider:
|
||||
# use: deerflow.guardrails.builtin:SensitiveDataProvider
|
||||
# config:
|
||||
# protected_tools: ["read_file", "write_file", "str_replace", "ls", "glob", "grep", "bash"]
|
||||
# deny_basenames: [".env"]
|
||||
# deny_globs: [".env.*", "*.pem", "*.key", "id_rsa*", "secrets.*", "credentials.*"]
|
||||
# block_skills_env: true
|
||||
|
||||
@ -166,6 +166,7 @@ export default function AgentChatPage() {
|
||||
threadId={threadId}
|
||||
autoFocus={isNewThread}
|
||||
showWelcomeStyle={isNewThread}
|
||||
hasSubmitted={!isNewThread}
|
||||
status={
|
||||
thread.error
|
||||
? "error"
|
||||
|
||||
@ -233,6 +233,7 @@ export default function ChatPage() {
|
||||
[thread.values.artifacts],
|
||||
);
|
||||
|
||||
const [hasSubmitted, setHasSubmitted] = useState(false);
|
||||
const [historyCutoff, setHistoryCutoff] = useState<number | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
@ -240,6 +241,7 @@ export default function ChatPage() {
|
||||
setHistoryCutoff(null);
|
||||
return;
|
||||
}
|
||||
if (hasSubmitted) return;
|
||||
// Welcome 态下、未提交前,把当前已有消息都当作“历史”切掉。
|
||||
// 这样即使历史消息是后续异步补齐,也不会重新露出。
|
||||
setHistoryCutoff((prev) => {
|
||||
@ -248,6 +250,7 @@ export default function ChatPage() {
|
||||
return next > prev ? next : prev;
|
||||
});
|
||||
}, [
|
||||
hasSubmitted,
|
||||
historyCutoff,
|
||||
shouldRenderHistory,
|
||||
thread.isThreadLoading,
|
||||
@ -319,6 +322,7 @@ export default function ChatPage() {
|
||||
if (isNewThread && safeThreadId && !isThreadInitReady) {
|
||||
return;
|
||||
}
|
||||
setHasSubmitted(true);
|
||||
if (safeThreadId && (isNewThread || showWelcomeStyle)) {
|
||||
router.replace(`/workspace/chats/${safeThreadId}?is_chatting=true`);
|
||||
}
|
||||
@ -339,6 +343,22 @@ export default function ChatPage() {
|
||||
await thread.stop();
|
||||
}, [thread]);
|
||||
|
||||
const resetNewSessionState = useCallback(() => {
|
||||
setIsNewThread(true);
|
||||
setHasSubmitted(false);
|
||||
setHistoryCutoff(null);
|
||||
setArtifacts([]);
|
||||
deselectArtifact();
|
||||
setArtifactsOpen(false);
|
||||
setArtifactsFullscreen(false);
|
||||
}, [
|
||||
deselectArtifact,
|
||||
setArtifacts,
|
||||
setArtifactsFullscreen,
|
||||
setArtifactsOpen,
|
||||
setIsNewThread,
|
||||
]);
|
||||
|
||||
return (
|
||||
<ThreadContext.Provider value={{ threadId, thread }}>
|
||||
<div
|
||||
@ -359,7 +379,7 @@ export default function ChatPage() {
|
||||
<header
|
||||
className={cn(
|
||||
"bg-background absolute top-0 right-0 left-0 z-30 mx-4 grid h-[58px] shrink-0 grid-cols-3 items-center border-b transition-all duration-300 ease-in-out",
|
||||
showWelcomeStyle ? "hidden" : "",
|
||||
showWelcomeStyle && !hasSubmitted ? "hidden" : "",
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-start overflow-hidden text-sm font-medium">
|
||||
@ -368,14 +388,7 @@ export default function ChatPage() {
|
||||
variant="ghost"
|
||||
className="px-[10px] py-[5px] text-sm font-medium text-ws-base-1 hover:text-ws-base-1/80"
|
||||
disabled={isStreaming}
|
||||
onClick={() => {
|
||||
sendToParent({
|
||||
type: POST_MESSAGE_TYPES.IS_CHATTING,
|
||||
isChatting: false,
|
||||
});
|
||||
router.replace(`/workspace/chats/${threadId}?is_chatting=false`)
|
||||
}
|
||||
}
|
||||
onClick={() => setShowExitDialog(true)}
|
||||
>
|
||||
<svg
|
||||
width="20"
|
||||
@ -431,20 +444,6 @@ export default function ChatPage() {
|
||||
</Button>
|
||||
}
|
||||
/> */}
|
||||
<Tooltip content={t.common.resetThread}>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
className="h-full px-[10px] py-[5px] text-sm font-medium text-ws-base-1 hover:text-ws-base-1"
|
||||
disabled={isStreaming}
|
||||
onClick={() => setShowExitDialog(true)}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
|
||||
<path d="M2 4H6M16 4H12M6 4H12M6 4C6 2.89543 6.89543 2 8 2H10C11.1046 2 12 2.89543 12 4M4 6V14C4 15.1046 4.89543 16 6 16H12C13.1046 16 14 15.1046 14 14V6M7 8V13M11 8V13" stroke="#150033" stroke-linecap="round" />
|
||||
</svg>
|
||||
{t.common.resetThread}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
|
||||
{artifacts?.length > 0 && !artifactsOpen && (
|
||||
<Tooltip content={t.chatPage.viewArtifactsTooltip}>
|
||||
@ -457,12 +456,7 @@ export default function ChatPage() {
|
||||
setSidebarOpen(false);
|
||||
}}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
|
||||
<path d="M16 7V4C16 2.89543 15.1046 2 14 2H4C2.89543 2 2 2.89543 2 4V14C2 15.1046 2.89543 16 4 16H9" stroke="#150033" stroke-linecap="round" />
|
||||
<path d="M5 5H9M5 8H7" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" />
|
||||
<circle cx="11.5" cy="10.5" r="3" stroke="#150033" />
|
||||
<path d="M15.5 14.5L14 13" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" />
|
||||
</svg>
|
||||
<FilesIcon />
|
||||
{t.common.artifacts}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
@ -472,14 +466,16 @@ export default function ChatPage() {
|
||||
<main
|
||||
className={cn(
|
||||
"flex min-h-0 max-w-full grow flex-col",
|
||||
showWelcomeStyle ? "bg-ws-surface-base" : "bg-background",
|
||||
showWelcomeStyle && !hasSubmitted
|
||||
? "bg-ws-surface-base"
|
||||
: "bg-background",
|
||||
)}
|
||||
>
|
||||
<div className="flex size-full justify-center">
|
||||
<MessageList
|
||||
className={cn(
|
||||
"size-full",
|
||||
!showWelcomeStyle && "pt-[58px]",
|
||||
(!showWelcomeStyle || hasSubmitted) && "pt-[58px]",
|
||||
)}
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
@ -512,7 +508,7 @@ export default function ChatPage() {
|
||||
<div
|
||||
className={cn(
|
||||
"h-full w-full transition-transform duration-300 ease-in-out",
|
||||
showWelcomeStyle ? "translate-x-0" : "",
|
||||
showWelcomeStyle && !hasSubmitted ? "translate-x-0" : "",
|
||||
artifactPanelOpen ? "translate-x-0" : "translate-x-full",
|
||||
)}
|
||||
>
|
||||
@ -574,7 +570,9 @@ export default function ChatPage() {
|
||||
<div
|
||||
className={cn(
|
||||
"pointer-events-auto relative w-full max-w-[720px]",
|
||||
showWelcomeStyle && "-translate-y-[calc(50vh-96px)]",
|
||||
showWelcomeStyle &&
|
||||
!hasSubmitted &&
|
||||
"-translate-y-[calc(50vh-96px)]",
|
||||
)}
|
||||
>
|
||||
{!(showWelcomeStyle && thread.isThreadLoading) ? (
|
||||
@ -583,6 +581,7 @@ export default function ChatPage() {
|
||||
className={cn("w-full rounded-[20px] bg-ws-surface-elevated")}
|
||||
threadId={threadId}
|
||||
showWelcomeStyle={showWelcomeStyle}
|
||||
hasSubmitted={hasSubmitted}
|
||||
autoFocus={showWelcomeStyle}
|
||||
status={
|
||||
thread.error
|
||||
@ -594,7 +593,9 @@ export default function ChatPage() {
|
||||
context={settings.context}
|
||||
extraHeader={
|
||||
<div className="flex flex-col gap-4">
|
||||
{showWelcomeStyle && <Welcome mode={settings.context.mode} />}
|
||||
{showWelcomeStyle && !hasSubmitted && (
|
||||
<Welcome mode={settings.context.mode} />
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
disabled={
|
||||
|
||||
@ -219,6 +219,7 @@ export function InputBox({
|
||||
context,
|
||||
extraHeader,
|
||||
showWelcomeStyle,
|
||||
hasSubmitted,
|
||||
initialValue,
|
||||
onContextChange,
|
||||
onSubmit,
|
||||
@ -237,6 +238,7 @@ export function InputBox({
|
||||
};
|
||||
extraHeader?: React.ReactNode;
|
||||
showWelcomeStyle: boolean;
|
||||
hasSubmitted?: boolean;
|
||||
initialValue?: string;
|
||||
onContextChange?: (
|
||||
context: Omit<
|
||||
@ -292,13 +294,14 @@ export function InputBox({
|
||||
const [isInputToolsTourReady, setIsInputToolsTourReady] = useState(false);
|
||||
const { data: referenceFilesData } = useReferenceFiles(threadIdFromProps);
|
||||
|
||||
// Welcome 态下禁用收缩,始终保持展开
|
||||
const effectiveIsFocused = (showWelcomeStyle ?? false) || isFocused;
|
||||
// isNewThread 时禁用收缩,始终保持展开(除非已提交消息)
|
||||
const effectiveIsFocused =
|
||||
((showWelcomeStyle ?? false) && !hasSubmitted) || isFocused;
|
||||
const shouldShowSuggestionList =
|
||||
showWelcomeStyle && searchParams.get("mode") !== "skill";
|
||||
showWelcomeStyle && !hasSubmitted && searchParams.get("mode") !== "skill";
|
||||
|
||||
useEffect(() => {
|
||||
if (!showWelcomeStyle) {
|
||||
if (!showWelcomeStyle || hasSubmitted) {
|
||||
setIsInputToolsTourReady(false);
|
||||
return;
|
||||
}
|
||||
@ -315,13 +318,14 @@ export function InputBox({
|
||||
return () => window.cancelAnimationFrame(frameId);
|
||||
}, [
|
||||
showWelcomeStyle,
|
||||
hasSubmitted,
|
||||
shouldShowSuggestionList,
|
||||
iframeSkill.isBootstrapping,
|
||||
iframeSkill.selectedSkills.length,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!showWelcomeStyle || !isInputToolsTourReady) {
|
||||
if (!showWelcomeStyle || hasSubmitted || !isInputToolsTourReady) {
|
||||
setIsInputToolsTourOpen(false);
|
||||
return;
|
||||
}
|
||||
@ -333,7 +337,7 @@ export function InputBox({
|
||||
if (!hasSeenTourForCurrentThread) {
|
||||
setIsInputToolsTourOpen(true);
|
||||
}
|
||||
}, [showWelcomeStyle, isInputToolsTourReady, threadId]);
|
||||
}, [showWelcomeStyle, hasSubmitted, isInputToolsTourReady, threadId]);
|
||||
|
||||
const finishInputToolsTour = useCallback(() => {
|
||||
const seenState = parseInputToolsTourSeenState(
|
||||
@ -813,6 +817,7 @@ export function InputBox({
|
||||
"border-0 rounded-[20px] backdrop-blur-sm",
|
||||
"transition-[height] duration-300 ease-out shadow-none ",
|
||||
!showWelcomeStyle && "h-[200px] shadow-[0_0_20px_0_rgba(0,0,0,0.10)]",
|
||||
hasSubmitted && "shadow-[0_0_20px_0_rgba(0,0,0,0.10)]!",
|
||||
effectiveIsFocused ? "h-[200px]" : "h-[80px]",
|
||||
)}
|
||||
disabled={isInputDisabled}
|
||||
@ -965,14 +970,14 @@ export function InputBox({
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{/* {!showWelcomeStyle && (
|
||||
{!showWelcomeStyle && (
|
||||
<div className="shrink-0 h-full">
|
||||
<ExitChattingButton
|
||||
router={router}
|
||||
threadId={threadIdFromProps}
|
||||
/>
|
||||
</div>
|
||||
)} */}
|
||||
)}
|
||||
<div ref={attachmentsButtonTourRef} className="shrink-0 h-full">
|
||||
<AddAttachmentsButton />
|
||||
</div>
|
||||
@ -1266,12 +1271,8 @@ function HistoryButton({
|
||||
<Tooltip content={t.inputBox.history}>
|
||||
<WorkspaceToolButton
|
||||
className={cn("text-ws-base-1 hover:text-ws-interactive-primary", className)}
|
||||
onClick={() =>{
|
||||
sendToParent({
|
||||
type: POST_MESSAGE_TYPES.IS_CHATTING,
|
||||
isChatting: true,
|
||||
});
|
||||
router.replace(`/workspace/chats/${threadId}?is_chatting=true`)}
|
||||
onClick={() =>
|
||||
router.replace(`/workspace/chats/${threadId}?is_chatting=true`)
|
||||
}
|
||||
>
|
||||
<svg
|
||||
@ -1317,13 +1318,9 @@ function ExitChattingButton({
|
||||
"text-ws-base-1 hover:text-ws-interactive-primary",
|
||||
className,
|
||||
)}
|
||||
onClick={() => {
|
||||
sendToParent({
|
||||
type: POST_MESSAGE_TYPES.IS_CHATTING,
|
||||
isChatting: false,
|
||||
});
|
||||
router.replace(`/workspace/chats/${threadId}?is_chatting=false`);
|
||||
}}
|
||||
onClick={() =>
|
||||
router.replace(`/workspace/chats/${threadId}?is_chatting=false`)
|
||||
}
|
||||
>
|
||||
<svg
|
||||
className="transition-[color] duration-200"
|
||||
|
||||
@ -46,13 +46,6 @@ import { CopyButton } from "../copy-button";
|
||||
|
||||
import { MarkdownContent } from "./markdown-content";
|
||||
|
||||
function localizeAssistantFixedCopy(content: string, localized: string): string {
|
||||
if (content.includes("The account balance is insufficient for this model call.")) {
|
||||
return localized;
|
||||
}
|
||||
return content;
|
||||
}
|
||||
|
||||
export function MessageListItem({
|
||||
className,
|
||||
message,
|
||||
@ -183,14 +176,8 @@ function MessageContent_({
|
||||
const cleaned = stripPriorityHintSuffix(stripUploadedFilesTag(rawContent));
|
||||
return normalizeHumanMessageDisplayText(cleaned);
|
||||
}
|
||||
if (!rawContent) {
|
||||
return "";
|
||||
}
|
||||
return localizeAssistantFixedCopy(
|
||||
rawContent,
|
||||
t.threads.billingInsufficientBalance,
|
||||
);
|
||||
}, [rawContent, isHuman, t.threads.billingInsufficientBalance]);
|
||||
return rawContent ?? "";
|
||||
}, [rawContent, isHuman]);
|
||||
const isSummaryMessage = useMemo(
|
||||
() => isHuman && isSummaryTemplateMessage(message),
|
||||
[isHuman, message],
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import { MessageSquarePlus } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import {
|
||||
SidebarMenu,
|
||||
@ -12,34 +11,14 @@ import {
|
||||
SidebarTrigger,
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { useThreadChat } from "@/components/workspace/chats";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages";
|
||||
import { env } from "@/env";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { copyToClipboard } from "@/lib/utils";
|
||||
|
||||
export function WorkspaceHeader({ className }: { className?: string }) {
|
||||
const { t } = useI18n();
|
||||
const { state } = useSidebar();
|
||||
const pathname = usePathname();
|
||||
const { threadId } = useThreadChat();
|
||||
const threadUrl = threadId ? `/workspace/chats/${threadId}` : "";
|
||||
|
||||
const handleCopyThreadId = async () => {
|
||||
if (!threadId) return;
|
||||
sendToParent({
|
||||
type: POST_MESSAGE_TYPES.COPY_TO_CLIPBOARD,
|
||||
text: threadId,
|
||||
});
|
||||
try {
|
||||
await copyToClipboard(threadId);
|
||||
toast.success(t.clipboard.copiedToClipboard);
|
||||
} catch {
|
||||
toast.error(t.clipboard.failedToCopyToClipboard);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
@ -64,33 +43,7 @@ export function WorkspaceHeader({ className }: { className?: string }) {
|
||||
) : (
|
||||
<div className="text-primary ml-2 cursor-default font-serif">
|
||||
{/* TODO: 测试标识 */}
|
||||
XClaw{" "}
|
||||
<span className="text-sm text-ws-text-subtle-strong">v3.2.9 </span>{" "}
|
||||
<span
|
||||
className={cn(
|
||||
"text-xs font-mono",
|
||||
threadId
|
||||
? "cursor-pointer underline decoration-dotted underline-offset-4"
|
||||
: "text-ws-text-subtle-strong",
|
||||
)}
|
||||
onClick={() => {
|
||||
void handleCopyThreadId();
|
||||
}}
|
||||
title={threadId ? t.clipboard.copyToClipboard : undefined}
|
||||
>
|
||||
id:{threadId ? threadId.slice(0, 5) : "-"}
|
||||
</span>
|
||||
{" "}
|
||||
{threadId && (
|
||||
<a
|
||||
href={threadUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-xs underline decoration-dotted underline-offset-4"
|
||||
>
|
||||
打开
|
||||
</a>
|
||||
)}
|
||||
XClaw <span className="text-sm text-ws-text-subtle-strong">v3.2.8</span>
|
||||
</div>
|
||||
)}
|
||||
<SidebarTrigger />
|
||||
|
||||
@ -53,7 +53,6 @@ export const enUS: Translations = {
|
||||
exportSuccess: "Conversation exported",
|
||||
removeAttachment: "Remove attachment",
|
||||
reference: "Reference",
|
||||
resetThread: "Reset thread",
|
||||
},
|
||||
|
||||
// Welcome
|
||||
@ -320,8 +319,6 @@ export const enUS: Translations = {
|
||||
|
||||
threads: {
|
||||
streamError: "Something went wrong.",
|
||||
billingInsufficientBalance:
|
||||
"The account balance is insufficient for this model call.",
|
||||
invalidThreadId: "Invalid thread id 'new'. Please refresh and retry.",
|
||||
staleReferencesRemoved:
|
||||
"Some referenced files were invalid and were removed automatically.",
|
||||
|
||||
@ -48,7 +48,6 @@ export interface Translations {
|
||||
exportSuccess: string;
|
||||
removeAttachment: string;
|
||||
reference: string;
|
||||
resetThread: string;
|
||||
};
|
||||
|
||||
// Welcome
|
||||
@ -246,7 +245,6 @@ export interface Translations {
|
||||
|
||||
threads: {
|
||||
streamError: string;
|
||||
billingInsufficientBalance: string;
|
||||
invalidThreadId: string;
|
||||
staleReferencesRemoved: string;
|
||||
uploadFailed: string;
|
||||
|
||||
@ -55,7 +55,6 @@ export const zhCN: Translations = {
|
||||
exportSuccess: "对话已导出",
|
||||
removeAttachment: "移除附件",
|
||||
reference: "引用",
|
||||
resetThread: "重置会话",
|
||||
},
|
||||
|
||||
// Welcome
|
||||
@ -151,10 +150,10 @@ export const zhCN: Translations = {
|
||||
children: [{ id: "17", name: "Excel处理" }],
|
||||
},
|
||||
{
|
||||
suggestion: "微信文章撰写",
|
||||
suggestion: "营销策划",
|
||||
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
|
||||
icon: ShapesIcon,
|
||||
children: [{ id: "6134", name: "微信文章撰写" }],
|
||||
children: [{ id: "217", name: "产品营销背景" }],
|
||||
},
|
||||
],
|
||||
suggestionsCreate: [
|
||||
@ -307,7 +306,6 @@ export const zhCN: Translations = {
|
||||
|
||||
threads: {
|
||||
streamError: "出现了某些错误。",
|
||||
billingInsufficientBalance: "账户余额不足,无法完成本次模型调用。",
|
||||
invalidThreadId: "线程 ID 无效(new),请刷新后重试。",
|
||||
staleReferencesRemoved: "部分引用文件已失效,已自动移除并继续发送。",
|
||||
uploadFailed: "文件上传失败。",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user