Compare commits

...

31 Commits

Author SHA1 Message Date
mt
c17ba298fb fix(ui): 发送按钮 SVG 属性改为 JSX 驼峰格式
- stroke-width → strokeWidth, stroke-linecap → strokeLinecap, stroke-linejoin → strokeLinejoin
2026-06-11 09:50:32 +08:00
mt
7d5e25e325 feat(input): 附件引用弹窗新增搜索过滤框
- DropdownMenu 内新增 Input 搜索框,autoFocus
- filterMentionCandidates 同时受 mentionQuery 和 mentionSearchText 双重过滤
- 搜索时重置高亮索引避免越界
- 上/下箭头将焦点交还给候选列表复用键盘导航
- 所有关闭路径统一重置搜索文字
- 弹窗打开时自动 refetch 最新文件列表
2026-06-11 09:50:29 +08:00
mt
f3c160f103 feat(artifact): artifact markdown 表格复用 CopyButton
- ArtifactFilePreview 中 Streamdown 的 table 组件覆盖为 MarkdownTable
- artifact 区表格复制/下载行为与聊天区一致
2026-06-11 09:50:22 +08:00
mt
407618baf0 refactor(table): 表格复制按钮复用 CopyButton,下载改为 markdown+BOM
- MarkdownTable 导出为公共组件
- 复制按钮直接复用 CopyButton,行为与 iframe 复制一致
- 表格数据通过 tableRef 在 render 阶段同步计算
- useLayoutEffect 确保首次渲染后即可获取正确数据
- 下载按钮改为 markdown 格式 (.md),UTF-8 with BOM
- 移除废弃的 escapeCsvCell / toCsvTable
2026-06-11 09:50:19 +08:00
mt
1637a0e71c fix(copy): copyToClipboard 始终发送 postMessage
- 移除 copyToClipboard 内独立的 isInIframe 判断
- 改为始终调用 sendToParent,由 sendToParent 内部统一判断 iframe 环境
- 与 openSkillDialog 等其他 iframe 通信保持一致
2026-06-11 09:50:15 +08:00
03ff3ece7f fix(brand): brand更新不及时的问题 2026-06-10 18:01:34 +08:00
mt
c45bc4d521 style(input): 调整发送按钮为圆形图标样式并更新主题色
- prompt-input.tsx: 发送按钮改为 36x36 圆形,使用 SVG 箭头/方块图标替代文字
- input-box.tsx: 占位区域尺寸同步调整为 36x36
- globals.css: 新增 brand-default/brand-sxwz 品牌 CSS class,新增 ws-interactive-hover token,主题色 #8e47f0 → #150033
2026-06-10 17:52:02 +08:00
mt
9eb494b1b4 feat(brand): 聊天页 sxwz 模式下输入框左移 172px
- ChatPage 接入 useBrand,brand === 'sxwz' 时主容器和输入框 translate-x-[-172px]
- 退出对话回欢迎页时同步关闭 artifacts 面板
2026-06-10 17:51:53 +08:00
mt
0bd9b9bdcb feat(brand): workspace 组件接入品牌文案和 Logo 切换
- layout.tsx: 包裹 BrandProvider + BrandSessionInitializer,SidebarProvider 注入 rootClassName
- welcome.tsx: copy.productLabel 替代硬编码,appLogoSrc 条件渲染 Image/文字
- workspace-header.tsx: 侧边栏折叠时显示品牌缩写,展开时显示 Logo 或 appName
2026-06-10 17:51:46 +08:00
mt
62fd2e6f06 feat(brand): 新增品牌切换系统核心模块
- 定义 Brand 类型、BrandCopy 文案映射、BRAND_COPY 配置
- BrandProvider + useBrand hook 提供 brand/copy/rootClassName
- BrandSessionInitializer 从 URL ?isSxwz= 初始化品牌会话
- sessionStorage 持久化 + URL 参数优先级解析
- parseBrandFromSearchParams 区分为 true/false/无参数三种情况
- 新增 default 品牌 Logo (coxwork.png)
2026-06-10 17:51:34 +08:00
63563ce6a3 feat: 重置会话时新增checkbox,清除当前会话的memory 2026-06-02 10:21:33 +08:00
dd98337a92 fix: 修复错误的id 2026-06-01 17:57:29 +08:00
0fdeb27e06 fix(workspace): 优化欢迎建议布局并完善输入框提交判断 2026-06-01 17:55:03 +08:00
f0f7b8df4d script: 添加git hook,每次push之前都rebase git-main 2026-05-18 16:31:05 +08:00
ae2cfa2386 chore:快捷Skill按钮维护 2026-05-18 16:03:53 +08:00
15b295f45e dev:版本推进 2026-05-18 16:03:53 +08:00
453ef0d4da chore: 快捷Skill按钮维护 2026-05-18 16:03:53 +08:00
d2d7d0fc99 chore:隐藏管理记忆的入口 2026-05-18 16:03:53 +08:00
41ac04f9f9 feat: 本地清理旧用户遗留的localstorage 2026-05-18 16:03:53 +08:00
92b6bcc5fb feat(ThreadMemoryPanel): 新增会话记忆下拉面板并完成 i18n 接入 2026-05-18 16:03:53 +08:00
fc9a30c784 fix: 修复图标属性名为小驼峰 2026-05-18 16:03:53 +08:00
e338fa90d6 refactor(memory): 切换线程记忆为纯 memory_json 存储
移除 thread_memory 对 memory_md/Markdown 解析的运行时依赖,仅保留 memory_json 读写路径。\n同步更新 SQLite/MySQL 存储实现与测试基线,并补充迁移文档的最终状态说明。
2026-05-18 16:03:53 +08:00
86a1460d5e fix(workspace): 修复复制消息时误带隐藏上下文内容 2026-05-18 16:03:53 +08:00
88732e58c4 feat: 使用大模型美观输出,等待用户输入之后,大模型输出规范json,再反序列化存入数据库。 2026-05-18 16:03:53 +08:00
1c14be0c33 fix(thread-memory): 修复语言识别与队列健壮性 2026-05-18 16:03:53 +08:00
cba81112fd feat: 对齐df的注入模式 2026-05-18 16:03:53 +08:00
03aa9dd8f8 feat:写入跟用户相同的语言的记忆 2026-05-18 16:03:53 +08:00
31daed1887 feat: 数据结构向df的memory.json对齐 2026-05-18 16:03:53 +08:00
7db468aa6f feat: 增加MD列 2026-05-18 16:03:53 +08:00
b49e838980 feat:json会话记忆 2026-05-18 16:03:53 +08:00
6197a1c14d feat: 工具调用的description使用用户的语言 2026-05-18 16:03:53 +08:00
57 changed files with 3629 additions and 232 deletions

7
.githooks/pre-push Executable file
View File

@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
REPO_ROOT="$(git rev-parse --show-toplevel)"
cd "$REPO_ROOT"
"$REPO_ROOT/scripts/git/pre-push-rebase.sh"

3
.gitignore vendored
View File

@ -42,9 +42,6 @@ skills/
logs/ logs/
log/ log/
# Local git hooks (keep only on this machine, do not push)
.githooks/
# pnpm # pnpm
.pnpm-store .pnpm-store
sandbox_image_cache.tar sandbox_image_cache.tar

View File

@ -1,6 +1,6 @@
# DeerFlow - Unified Development Environment # DeerFlow - Unified Development Environment
.PHONY: help config config-upgrade check install dev dev-daemon start stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway .PHONY: help config config-upgrade check install hooks-install dev dev-daemon start stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
BASH ?= bash BASH ?= bash
@ -18,6 +18,7 @@ help:
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml" @echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
@echo " make check - Check if all required tools are installed" @echo " make check - Check if all required tools are installed"
@echo " make install - Install all dependencies (frontend + backend)" @echo " make install - Install all dependencies (frontend + backend)"
@echo " make hooks-install - Install local Git hooks for this repo"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)" @echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)" @echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-daemon - Start all services in background (daemon mode)" @echo " make dev-daemon - Start all services in background (daemon mode)"
@ -63,6 +64,10 @@ install:
@echo " make setup-sandbox" @echo " make setup-sandbox"
@echo "" @echo ""
# Install repository-local Git hooks
hooks-install:
@./scripts/git/install-hooks.sh
# Pre-pull sandbox Docker image (optional but recommended) # Pre-pull sandbox Docker image (optional but recommended)
setup-sandbox: setup-sandbox:
@echo "==========================================" @echo "=========================================="

View File

@ -192,6 +192,14 @@ make down # 停止并移除容器
make install # 安装 backend + frontend 依赖 make install # 安装 backend + frontend 依赖
``` ```
如果你希望每次 `git push` 之前自动把当前分支 rebase 到 `origin/git-main`,可以再执行:
```bash
make hooks-install
```
这会把仓库的 `core.hooksPath` 指向 `.githooks/`,启用 `pre-push` hook。
3. **(可选)预拉取 sandbox 镜像** 3. **(可选)预拉取 sandbox 镜像**
```bash ```bash
# 如果使用 Docker / Container sandbox建议先执行 # 如果使用 Docker / Container sandbox建议先执行

View File

@ -21,7 +21,13 @@ from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_store from app.gateway.deps import get_checkpointer, get_store
from deerflow.agents.memory.thread_summary import (
ThreadMemoryConflictError,
apply_thread_memory_summary,
render_thread_memory_summary,
)
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.agents.memory.thread_storage import delete_thread_memory_data
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -121,6 +127,27 @@ class ThreadHistoryRequest(BaseModel):
before: str | None = Field(default=None, description="Cursor for pagination") before: str | None = Field(default=None, description="Cursor for pagination")
class ThreadMemorySummaryResponse(BaseModel):
threadId: str
memoryVersion: int
summary: str
class ThreadMemorySummaryUpdateRequest(BaseModel):
summary: str = Field(..., min_length=1, description="User-edited natural language memory summary")
memoryVersion: int = Field(..., ge=0, description="Expected memory version for CAS update")
class ThreadMemoryRecordResponse(BaseModel):
threadId: str
ownerId: str | None = None
user: dict[str, Any] = Field(default_factory=dict)
history: dict[str, Any] = Field(default_factory=dict)
facts: list[dict[str, Any]] = Field(default_factory=list)
memoryVersion: int = 0
lastUpdated: str = ""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -244,6 +271,17 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
return response return response
@router.delete("/{thread_id}/memory", response_model=ThreadDeleteResponse)
async def delete_thread_memory(thread_id: str) -> ThreadDeleteResponse:
"""Delete per-thread memory only (explicit trigger)."""
try:
delete_thread_memory_data(thread_id)
except Exception as exc:
logger.exception("Failed to delete thread memory for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to delete thread memory.") from exc
return ThreadDeleteResponse(success=True, message=f"Deleted thread memory for {thread_id}")
@router.post("", response_model=ThreadResponse) @router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse: async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread. """Create a new thread.
@ -680,3 +718,27 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
raise HTTPException(status_code=500, detail="Failed to get thread history") raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries return entries
@router.get("/{thread_id}/memory-summary", response_model=ThreadMemorySummaryResponse)
async def get_thread_memory_summary(thread_id: str) -> ThreadMemorySummaryResponse:
"""Render per-thread memory as human-readable text for user inspection/editing."""
try:
payload = render_thread_memory_summary(thread_id)
except Exception as exc:
logger.exception("Failed to render thread memory summary for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to render thread memory summary.") from exc
return ThreadMemorySummaryResponse(**payload)
@router.post("/{thread_id}/memory-summary", response_model=ThreadMemoryRecordResponse)
async def update_thread_memory_summary(thread_id: str, body: ThreadMemorySummaryUpdateRequest) -> ThreadMemoryRecordResponse:
"""Apply edited natural-language summary back into structured thread memory."""
try:
payload = apply_thread_memory_summary(thread_id, body.summary, body.memoryVersion)
except ThreadMemoryConflictError as exc:
raise HTTPException(status_code=409, detail="Thread memory changed; refresh and retry.") from exc
except Exception as exc:
logger.exception("Failed to apply thread memory summary for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to apply thread memory summary.") from exc
return ThreadMemoryRecordResponse(**payload)

View File

@ -0,0 +1,65 @@
# Thread Memory Storage Migration: `memory_md` -> `memory_json`
## Summary
Per-thread memory now uses `thread_memory.memory_json` as the primary storage format.
- New writes persist structured JSON into `memory_json`.
- Reads prefer `memory_json`.
- Runtime no longer depends on `memory_md`.
## Why
`memory_md` stores structured state inside Markdown fenced blocks. This is readable for humans, but costly for:
- querying and analytics
- schema evolution
- migration reliability
`memory_json` keeps the same logical payload while making storage machine-friendly.
## Runtime behavior
- Read path uses `memory_json` only.
- Write path uses `memory_json` only.
## Auto migration behavior
- SQLite: on startup, adds `memory_json` column when missing.
- MySQL: on startup, adds `memory_json` column when missing.
No destructive migration is required for existing data.
## One-shot operational backfill (legacy command)
For faster cleanup in production, run:
```bash
cd backend
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/backfill_thread_memory_json.py --dry-run
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/backfill_thread_memory_json.py
```
Current codebase keeps this command for compatibility. In fully migrated environments it returns zero legacy rows.
## Final cleanup: drop `memory_md` column
After confirming all environments are migrated, run:
```bash
cd backend
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/drop_thread_memory_md_column.py --dry-run
UV_CACHE_DIR=/tmp/uv-cache uv run python scripts/drop_thread_memory_md_column.py
```
Notes:
- SQLite migration rebuilds `thread_memory` table and preserves data.
- MySQL migration runs `ALTER TABLE ... DROP COLUMN memory_md`.
## Follow-up (optional)
After all active environments have fully migrated and no legacy rows remain:
1. backfill any remaining rows that still rely on `memory_md`
2. remove `memory_md` column from schema
3. remove Markdown parsing fallback code

View File

@ -391,9 +391,34 @@ def _get_memory_context(agent_name: str | None = None) -> str:
""" """
try: try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.agents.memory.thread_prompt import format_thread_memory_for_injection
from deerflow.agents.memory.thread_storage import get_thread_memory_data
from deerflow.config.memory_config import get_memory_config from deerflow.config.memory_config import get_memory_config
from deerflow.config.thread_memory_config import get_thread_memory_config
from langgraph.config import get_config
config = get_memory_config() config = get_memory_config()
thread_config = get_thread_memory_config()
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if thread_config.enabled and thread_config.injection_enabled and thread_id:
thread_memory = get_thread_memory_data(thread_id)
if thread_memory is not None:
thread_content = format_thread_memory_for_injection(
{
"user": thread_memory.get("user", {}),
"history": thread_memory.get("history", {}),
"facts": thread_memory.get("facts", []),
},
max_tokens=thread_config.max_injection_tokens,
)
if thread_content.strip():
return f"""<memory>
{thread_content}
</memory>
"""
if not config.enabled or not config.injection_enabled: if not config.enabled or not config.injection_enabled:
return "" return ""

View File

@ -0,0 +1,143 @@
"""Prompt and formatting helpers for per-thread memory."""
from __future__ import annotations
import json
import re
from typing import Any
from langchain_core.messages import HumanMessage
from deerflow.agents.memory.prompt import format_conversation_for_update, format_memory_for_injection
THREAD_MEMORY_UPDATE_PROMPT = """You are a user profile memory system.
Current per-thread memory:
<existing_memory>
{existing_memory}
</existing_memory>
Conversation:
<conversation>
{conversation}
</conversation>
Preferred memory language: {preferred_language}
Return JSON only with this schema:
{{
"user": {{
"workContext": {{"summary": string, "updatedAt": string}},
"personalContext": {{"summary": string, "updatedAt": string}},
"topOfMind": {{"summary": string, "updatedAt": string}}
}},
"history": {{
"recentMonths": {{"summary": string, "updatedAt": string}},
"earlierContext": {{"summary": string, "updatedAt": string}},
"longTermBackground": {{"summary": string, "updatedAt": string}}
}},
"facts": [
{{
"content": string,
"category": "tech_stack"|"preference"|"personal"|"context"|"goal",
"confidence": number
}}
]
}}
Rules:
- Keep only stable and useful user profile facts.
- Do not store sensitive personal data (phone/email/address/password/token/id/bank).
- Deduplicate and keep high-confidence facts.
- Write all human-readable text fields (`summary`, `content`, and similar prose) in the preferred memory language.
- Return valid JSON only.
"""
def create_empty_thread_memory() -> dict[str, Any]:
return {
"user": {
"workContext": {"summary": "", "updatedAt": ""},
"personalContext": {"summary": "", "updatedAt": ""},
"topOfMind": {"summary": "", "updatedAt": ""},
},
"history": {
"recentMonths": {"summary": "", "updatedAt": ""},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "", "updatedAt": ""},
},
"facts": [],
}
def _extract_human_text(content: Any) -> str:
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
chunks: list[str] = []
for item in content:
if isinstance(item, str):
stripped = item.strip()
if stripped:
chunks.append(stripped)
elif isinstance(item, dict):
text_val = item.get("text")
if isinstance(text_val, str):
stripped = text_val.strip()
if stripped:
chunks.append(stripped)
return "\n".join(chunks).strip()
return ""
def _infer_preferred_memory_language(messages: list[Any]) -> str:
user_texts: list[str] = []
for msg in messages:
if isinstance(msg, HumanMessage):
extracted = _extract_human_text(getattr(msg, "content", None))
if extracted:
user_texts.append(extracted)
if not user_texts:
return "same as the user's latest message"
# Prioritize the latest user message; fallback to a short recent window.
recent_window = user_texts[-3:]
language_sample = "\n".join(recent_window)
# If user explicitly provides locale hints, prefer them.
locale_match = re.search(r"\b([a-z]{2}-[A-Z]{2})\b", language_sample)
if locale_match:
return locale_match.group(1)
# Script-based heuristic (dynamic, not hard-coded to two languages).
script_patterns = {
"zh-Hans": r"[\u4e00-\u9fff]",
"ja-JP": r"[\u3040-\u30ff]",
"ko-KR": r"[\uac00-\ud7af]",
"ru-RU": r"[\u0400-\u04FF]",
"ar": r"[\u0600-\u06FF]",
"hi-IN": r"[\u0900-\u097F]",
"th-TH": r"[\u0E00-\u0E7F]",
"he-IL": r"[\u0590-\u05FF]",
"el-GR": r"[\u0370-\u03FF]",
}
counts = {lang: len(re.findall(pattern, language_sample)) for lang, pattern in script_patterns.items()}
best_lang, best_count = max(counts.items(), key=lambda item: item[1])
if best_count > 0:
return best_lang
# Latin-script fallback: ask model to keep same language as the user's latest message.
return "same as the user's latest message"
def format_thread_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
return format_memory_for_injection(memory_data, max_tokens=max_tokens)
def build_thread_memory_prompt(existing_memory: dict[str, Any], messages: list[Any]) -> str:
return THREAD_MEMORY_UPDATE_PROMPT.format(
existing_memory=json.dumps(existing_memory, ensure_ascii=False, indent=2),
conversation=format_conversation_for_update(messages),
preferred_language=_infer_preferred_memory_language(messages),
)

View File

@ -0,0 +1,76 @@
"""Debounced queue for per-thread memory updates."""
from __future__ import annotations
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from deerflow.config.thread_memory_config import get_thread_memory_config
@dataclass
class ThreadConversationContext:
thread_id: str
messages: list[Any]
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
class ThreadMemoryUpdateQueue:
def __init__(self):
self._queue_by_thread: dict[str, ThreadConversationContext] = {}
self._lock = threading.Lock()
self._timers: dict[str, threading.Timer] = {}
self._processing_threads: set[str] = set()
def add(self, thread_id: str, messages: list[Any]) -> None:
config = get_thread_memory_config()
if not config.enabled:
return
with self._lock:
self._queue_by_thread[thread_id] = ThreadConversationContext(thread_id=thread_id, messages=messages)
self._reset_timer(thread_id)
def _reset_timer(self, thread_id: str) -> None:
config = get_thread_memory_config()
timer = self._timers.get(thread_id)
if timer is not None:
timer.cancel()
timer = threading.Timer(config.debounce_seconds, self._process_thread, args=(thread_id,))
timer.daemon = True
self._timers[thread_id] = timer
timer.start()
def _process_thread(self, thread_id: str) -> None:
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
with self._lock:
if thread_id in self._processing_threads:
self._reset_timer(thread_id)
return
context = self._queue_by_thread.pop(thread_id, None)
if context is None:
self._timers.pop(thread_id, None)
return
self._processing_threads.add(thread_id)
self._timers.pop(thread_id, None)
try:
updater = ThreadMemoryUpdater()
updater.update_memory(context.messages, context.thread_id)
finally:
with self._lock:
self._processing_threads.discard(thread_id)
_thread_queue: ThreadMemoryUpdateQueue | None = None
_lock = threading.Lock()
def get_thread_memory_queue() -> ThreadMemoryUpdateQueue:
global _thread_queue
with _lock:
if _thread_queue is None:
_thread_queue = ThreadMemoryUpdateQueue()
return _thread_queue

View File

@ -0,0 +1,279 @@
"""Storage providers for per-thread memory."""
from __future__ import annotations
import abc
import json
import logging
import sqlite3
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
from deerflow.config.paths import get_paths
from deerflow.config.thread_memory_config import get_thread_memory_config
logger = logging.getLogger(__name__)
class ThreadMemoryStorage(abc.ABC):
@abc.abstractmethod
def load(self, thread_id: str) -> dict[str, Any] | None:
pass
@abc.abstractmethod
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
pass
@abc.abstractmethod
def delete(self, thread_id: str) -> bool:
pass
def _row_to_memory(row: tuple[Any, ...]) -> dict[str, Any]:
thread_id, owner_id_col, memory_json_raw, memory_version, last_updated = row
decoded: dict[str, Any] = {}
if isinstance(memory_json_raw, str) and memory_json_raw.strip():
try:
parsed_json = json.loads(memory_json_raw)
if isinstance(parsed_json, dict):
decoded = parsed_json
except Exception:
decoded = {}
owner_id = decoded.get("ownerId")
if owner_id is None:
owner_id = owner_id_col
user = decoded.get("user", create_empty_thread_memory()["user"])
history = decoded.get("history", create_empty_thread_memory()["history"])
facts = decoded.get("facts", [])
return {
"threadId": thread_id,
"ownerId": owner_id,
"user": user,
"history": history,
"facts": facts,
"memoryVersion": int(memory_version),
"lastUpdated": str(last_updated),
}
class SqliteThreadMemoryStorage(ThreadMemoryStorage):
def __init__(self, db_path: str):
path = Path(db_path)
if not path.is_absolute():
path = get_paths().base_dir / path
path.parent.mkdir(parents=True, exist_ok=True)
self._conn = sqlite3.connect(str(path), check_same_thread=False)
self._lock = threading.Lock()
with self._lock:
self._conn.execute(
"""
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY,
owner_id TEXT NULL,
memory_json TEXT NOT NULL DEFAULT '',
memory_version INTEGER NOT NULL DEFAULT 0,
last_updated TEXT NOT NULL DEFAULT (datetime('now'))
)
"""
)
self._ensure_memory_json_column()
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id)")
self._conn.commit()
def _ensure_memory_json_column(self) -> None:
columns = self._conn.execute("PRAGMA table_info(thread_memory)").fetchall()
has_memory_json = any(col[1] == "memory_json" for col in columns)
if not has_memory_json:
self._conn.execute("ALTER TABLE thread_memory ADD COLUMN memory_json TEXT NOT NULL DEFAULT ''")
def load(self, thread_id: str) -> dict[str, Any] | None:
with self._lock:
row = self._conn.execute(
"SELECT thread_id, owner_id, memory_json, memory_version, last_updated "
"FROM thread_memory WHERE thread_id = ?",
(thread_id,),
).fetchone()
if row is None:
return None
return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
now = datetime.now(UTC).isoformat().replace("+00:00", "Z")
owner_id = data.get("ownerId")
if expected_version is None:
expected_version = 0
with self._lock:
cur = self._conn.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version, last_updated)
VALUES (?, ?, ?, 0, ?)
ON CONFLICT(thread_id) DO NOTHING
""",
(
thread_id,
owner_id,
json.dumps(data, ensure_ascii=False),
now,
),
)
if cur.rowcount == 1:
self._conn.commit()
return True
cur = self._conn.execute(
"""
UPDATE thread_memory
SET owner_id = ?, memory_json = ?, memory_version = memory_version + 1, last_updated = ?
WHERE thread_id = ? AND memory_version = ?
""",
(
owner_id,
json.dumps(data, ensure_ascii=False),
now,
thread_id,
expected_version,
),
)
self._conn.commit()
return cur.rowcount == 1
def delete(self, thread_id: str) -> bool:
with self._lock:
self._conn.execute("DELETE FROM thread_memory WHERE thread_id = ?", (thread_id,))
self._conn.commit()
return True
def count_legacy_rows(self) -> int:
return 0
def backfill_legacy_rows(self, *, limit: int | None = None) -> dict[str, int]:
_ = limit
return {"scanned": 0, "updated": 0, "skipped": 0, "failed": 0}
class MysqlThreadMemoryStorage(ThreadMemoryStorage):
def __init__(self, host: str, port: int, user: str, password: str, database: str):
import pymysql
self._conn = pymysql.connect(host=host, port=port, user=user, password=password, database=database, charset="utf8mb4")
with self._conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY,
owner_id VARCHAR(64) NULL,
memory_json LONGTEXT NOT NULL,
memory_version INT NOT NULL DEFAULT 0,
last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_owner_id (owner_id)
)
"""
)
cur.execute("SHOW COLUMNS FROM thread_memory LIKE 'memory_json'")
if cur.fetchone() is None:
cur.execute("ALTER TABLE thread_memory ADD COLUMN memory_json LONGTEXT NOT NULL DEFAULT ''")
self._conn.commit()
def load(self, thread_id: str) -> dict[str, Any] | None:
with self._conn.cursor() as cur:
cur.execute(
"SELECT thread_id, owner_id, memory_json, memory_version, last_updated FROM thread_memory WHERE thread_id = %s",
(thread_id,),
)
row = cur.fetchone()
if row is None:
return None
return _row_to_memory(row)
def save(self, thread_id: str, data: dict[str, Any], expected_version: int | None = None) -> bool:
if expected_version is None:
expected_version = 0
owner_id = data.get("ownerId")
with self._conn.cursor() as cur:
cur.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version)
VALUES (%s, %s, %s, 0)
ON DUPLICATE KEY UPDATE thread_id = thread_id
""",
(
thread_id,
owner_id,
json.dumps(data, ensure_ascii=False),
),
)
if cur.rowcount == 1:
self._conn.commit()
return True
cur.execute(
"""
UPDATE thread_memory
SET owner_id = %s, memory_json = %s, memory_version = memory_version + 1
WHERE thread_id = %s AND memory_version = %s
""",
(
owner_id,
json.dumps(data, ensure_ascii=False),
thread_id,
expected_version,
),
)
self._conn.commit()
return cur.rowcount == 1
def delete(self, thread_id: str) -> bool:
with self._conn.cursor() as cur:
cur.execute("DELETE FROM thread_memory WHERE thread_id = %s", (thread_id,))
self._conn.commit()
return True
def count_legacy_rows(self) -> int:
return 0
def backfill_legacy_rows(self, *, limit: int | None = None) -> dict[str, int]:
_ = limit
return {"scanned": 0, "updated": 0, "skipped": 0, "failed": 0}
_thread_storage: ThreadMemoryStorage | None = None
_thread_storage_lock = threading.Lock()
def get_thread_memory_storage() -> ThreadMemoryStorage:
global _thread_storage
if _thread_storage is not None:
return _thread_storage
with _thread_storage_lock:
if _thread_storage is not None:
return _thread_storage
config = get_thread_memory_config()
if config.database.type == "mysql":
mysql = config.database.mysql
_thread_storage = MysqlThreadMemoryStorage(
host=mysql.host,
port=mysql.port,
user=mysql.user,
password=mysql.password,
database=mysql.database,
)
else:
_thread_storage = SqliteThreadMemoryStorage(config.database.sqlite.path)
return _thread_storage
def get_thread_memory_data(thread_id: str) -> dict[str, Any] | None:
return get_thread_memory_storage().load(thread_id)
def delete_thread_memory_data(thread_id: str) -> bool:
return get_thread_memory_storage().delete(thread_id)
def initial_thread_memory_record() -> dict[str, Any]:
return {"ownerId": None, **create_empty_thread_memory()}

View File

@ -0,0 +1,300 @@
"""Thread memory summary generation and application helpers."""
from __future__ import annotations
import json
import logging
import re
import hashlib
from typing import Any
from deerflow.agents.memory.thread_prompt import create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
from deerflow.agents.memory.updater import _extract_text
from deerflow.config.thread_memory_config import get_thread_memory_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
SUMMARY_RENDER_PROMPT = """You are an assistant that renders thread memory into natural language.
Thread memory JSON:
<memory_json>
{memory_json}
</memory_json>
Task:
- Output a concise, human-friendly editable profile summary.
- Keep the original language of the memory content where possible.
- Cover user profile, history, and key facts.
- Return plain text only (no markdown code fences).
"""
SUMMARY_PARSE_PROMPT = """You convert user-edited natural-language memory into a structured patch JSON.
Current thread memory JSON:
<current_memory_json>
{current_memory_json}
</current_memory_json>
Edited summary text:
<edited_summary>
{edited_summary}
</edited_summary>
Return JSON only with this schema (all fields optional):
{{
"user": {{
"workContext": {{"summary": string}},
"personalContext": {{"summary": string}},
"topOfMind": {{"summary": string}}
}},
"history": {{
"recentMonths": {{"summary": string}},
"earlierContext": {{"summary": string}},
"longTermBackground": {{"summary": string}}
}},
"facts": [
{{
"content": string,
"category": "preference"|"knowledge"|"context"|"behavior"|"goal"|"correction",
"confidence": number
}}
]
}}
"""
class ThreadMemoryConflictError(RuntimeError):
"""Raised when compare-and-swap save fails due to version mismatch."""
def _get_summary_model():
config = get_thread_memory_config()
return create_chat_model(name=config.model_name, thinking_enabled=False, stream_usage=False)
def _strip_code_fence(text: str) -> str:
cleaned = text.strip()
if not cleaned.startswith("```"):
return cleaned
lines = cleaned.split("\n")
return "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]).strip()
def _extract_json_object(text: str) -> dict[str, Any] | None:
cleaned = _strip_code_fence(text)
try:
parsed = json.loads(cleaned)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
repaired = _escape_inner_quotes_in_json_strings(cleaned)
if repaired != cleaned:
try:
parsed = json.loads(repaired)
if isinstance(parsed, dict):
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=full_text")
return parsed
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
candidate = match.group(0)
repaired = _escape_inner_quotes_in_json_strings(candidate)
if repaired != candidate:
try:
parsed = json.loads(repaired)
if isinstance(parsed, dict):
logger.warning("THREAD_SUMMARY_DEBUG parse_repaired mode=regex_object")
return parsed
except json.JSONDecodeError:
return None
return None
def _escape_inner_quotes_in_json_strings(text: str) -> str:
"""Heuristically repair unescaped inner double quotes inside JSON strings.
If a quote appears while inside a string but the next non-space character is
not a valid string terminator (comma, object/array close, or key colon), it is
treated as content and escaped.
"""
out: list[str] = []
in_string = False
escape = False
n = len(text)
i = 0
while i < n:
ch = text[i]
if not in_string:
out.append(ch)
if ch == '"':
in_string = True
i += 1
continue
if escape:
out.append(ch)
escape = False
i += 1
continue
if ch == "\\":
out.append(ch)
escape = True
i += 1
continue
if ch == '"':
j = i + 1
while j < n and text[j].isspace():
j += 1
next_char = text[j] if j < n else ""
# Valid JSON string terminators in context:
# - key string: :
# - value string: , } ]
if next_char in {":", ",", "}", "]", ""}:
out.append(ch)
in_string = False
else:
out.append('\\"')
i += 1
continue
out.append(ch)
i += 1
return "".join(out)
def _merge_summary_patch(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
merged = {"ownerId": base.get("ownerId"), **create_empty_thread_memory()}
merged["user"] = dict(base.get("user", {})) if isinstance(base.get("user"), dict) else merged["user"]
merged["history"] = dict(base.get("history", {})) if isinstance(base.get("history"), dict) else merged["history"]
merged["facts"] = list(base.get("facts", [])) if isinstance(base.get("facts"), list) else []
for section_name in ("user", "history"):
section_patch = patch.get(section_name, {})
if not isinstance(section_patch, dict):
continue
for key, value in section_patch.items():
if key not in merged[section_name] or not isinstance(value, dict):
continue
summary = value.get("summary")
if isinstance(summary, str):
merged[section_name][key]["summary"] = summary
facts_patch = patch.get("facts")
if isinstance(facts_patch, list):
merged["facts"] = facts_patch
return merged
def render_thread_memory_summary(thread_id: str) -> dict[str, Any]:
storage = get_thread_memory_storage()
current = storage.load(thread_id)
memory = {"ownerId": None, **create_empty_thread_memory()} if current is None else current
memory_payload = {
"user": memory.get("user", {}),
"history": memory.get("history", {}),
"facts": memory.get("facts", []),
}
prompt = SUMMARY_RENDER_PROMPT.format(memory_json=json.dumps(memory_payload, ensure_ascii=False, indent=2))
response = _get_summary_model().invoke(prompt)
text = _strip_code_fence(_extract_text(response.content))
return {
"threadId": thread_id,
"memoryVersion": int(memory.get("memoryVersion", 0)),
"summary": text,
}
def apply_thread_memory_summary(thread_id: str, edited_summary: str, expected_version: int) -> dict[str, Any]:
storage = get_thread_memory_storage()
current = storage.load(thread_id)
base = {"ownerId": None, **create_empty_thread_memory()} if current is None else current
memory_payload = {
"user": base.get("user", {}),
"history": base.get("history", {}),
"facts": base.get("facts", []),
}
prompt = SUMMARY_PARSE_PROMPT.format(
current_memory_json=json.dumps(memory_payload, ensure_ascii=False, indent=2),
edited_summary=edited_summary,
)
response = _get_summary_model().invoke(prompt)
raw = _extract_text(response.content)
raw_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest()
logger.warning(
"THREAD_SUMMARY_DEBUG parse_raw_meta thread=%s raw_length=%d raw_sha256=%s",
thread_id,
len(raw),
raw_hash,
)
patch = _extract_json_object(raw)
if patch is None:
cleaned = _strip_code_fence(raw)
decode_error = None
try:
json.loads(cleaned)
except json.JSONDecodeError as exc:
decode_error = exc
if decode_error is not None:
logger.warning(
"THREAD_SUMMARY_DEBUG parse_error thread=%s msg=%s line=%d col=%d pos=%d snippet=%r",
thread_id,
decode_error.msg,
decode_error.lineno,
decode_error.colno,
decode_error.pos,
cleaned[max(0, decode_error.pos - 80): decode_error.pos + 80],
)
else:
logger.warning(
"THREAD_SUMMARY_DEBUG parse_error thread=%s msg=no_json_object_extracted raw_head=%r",
thread_id,
cleaned[:200],
)
logger.warning("THREAD_SUMMARY_DEBUG parse_fallback thread=%s", thread_id)
patch = {
"user": {
"topOfMind": {
"summary": edited_summary.strip(),
}
}
}
else:
logger.warning(
"THREAD_SUMMARY_DEBUG parse_success thread=%s patch=%s",
thread_id,
json.dumps(patch, ensure_ascii=False)[:2000],
)
merged = _merge_summary_patch(base, patch if isinstance(patch, dict) else {})
cleaned = ThreadMemoryUpdater()._scrub_sensitive(merged, thread_id)
cleaned["ownerId"] = base.get("ownerId")
logger.warning(
"THREAD_SUMMARY_DEBUG apply_cleaned thread=%s cleaned=%s",
thread_id,
json.dumps(
{
"user": cleaned.get("user", {}),
"history": cleaned.get("history", {}),
"facts_count": len(cleaned.get("facts", []) if isinstance(cleaned.get("facts"), list) else []),
},
ensure_ascii=False,
)[:2000],
)
if not storage.save(thread_id, cleaned, expected_version=expected_version):
raise ThreadMemoryConflictError(f"Thread memory version conflict for {thread_id}")
latest = storage.load(thread_id)
return latest if latest is not None else {"threadId": thread_id, "memoryVersion": expected_version, **cleaned}

View File

@ -0,0 +1,148 @@
"""Per-thread memory updater."""
from __future__ import annotations
import json
import logging
import re
import uuid
from datetime import UTC, datetime
from typing import Any
from deerflow.agents.memory.updater import _extract_text
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, create_empty_thread_memory
from deerflow.agents.memory.thread_storage import get_thread_memory_storage
from deerflow.config.thread_memory_config import get_thread_memory_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
_SENSITIVE_PATTERNS = (
re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
re.compile(r"\b(?:\+?\d[\d -]{7,}\d)\b"),
re.compile(r"\b(?:api[_-]?key|token|password|passwd|secret)\b", re.IGNORECASE),
re.compile(r"\b\d{15,19}\b"), # bank-card like
)
class ThreadMemoryUpdater:
def __init__(self, model_name: str | None = None):
self._model_name = model_name
def _get_model(self):
config = get_thread_memory_config()
# Non-stream invoke path: some OpenAI-compatible gateways reject
# stream_options when stream=false, so force stream_usage off here.
return create_chat_model(
name=self._model_name or config.model_name,
thinking_enabled=False,
stream_usage=False,
)
def _scrub_sensitive(self, data: dict[str, Any], thread_id: str) -> dict[str, Any]:
def safe_confidence(val: Any, default: float = 0.5) -> float:
try:
parsed = float(val)
except (TypeError, ValueError):
return default
return max(0.0, min(1.0, parsed))
def safe_text(val: Any) -> str | None:
if not isinstance(val, str):
return None
text = val.strip()
if not text:
return None
if any(p.search(text) for p in _SENSITIVE_PATTERNS):
logger.info("thread_memory sensitive value dropped for thread=%s", thread_id)
return None
return text
user = data.get("user", {})
history = data.get("history", {})
facts = data.get("facts", [])
cleaned = create_empty_thread_memory()
def copy_summary_section(target_parent: dict[str, Any], target_key: str, source_parent: Any):
if not isinstance(source_parent, dict):
return
source_section = source_parent.get(target_key)
if not isinstance(source_section, dict):
return
summary = safe_text(source_section.get("summary"))
updated_at = safe_text(source_section.get("updatedAt"))
if summary:
target_parent[target_key]["summary"] = summary
if updated_at:
target_parent[target_key]["updatedAt"] = updated_at
elif summary:
target_parent[target_key]["updatedAt"] = datetime.now(UTC).isoformat().replace("+00:00", "Z")
copy_summary_section(cleaned["user"], "workContext", user)
copy_summary_section(cleaned["user"], "personalContext", user)
copy_summary_section(cleaned["user"], "topOfMind", user)
copy_summary_section(cleaned["history"], "recentMonths", history)
copy_summary_section(cleaned["history"], "earlierContext", history)
copy_summary_section(cleaned["history"], "longTermBackground", history)
seen: set[str] = set()
for fact in facts if isinstance(facts, list) else []:
if not isinstance(fact, dict):
continue
content = safe_text(fact.get("content"))
if not content:
continue
key = content.casefold()
if key in seen:
continue
seen.add(key)
confidence = safe_confidence(fact.get("confidence", 0.5))
cleaned["facts"].append(
{
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": content,
"category": str(fact.get("category", "context")).strip() or "context",
"confidence": confidence,
"createdAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"source": thread_id,
}
)
return cleaned
def update_memory(self, messages: list[Any], thread_id: str) -> bool:
config = get_thread_memory_config()
if not config.enabled or not messages or not thread_id:
return False
storage = get_thread_memory_storage()
current = storage.load(thread_id)
base_memory = create_empty_thread_memory() if current is None else {
"user": current.get("user", {}),
"history": current.get("history", {}),
"facts": current.get("facts", []),
}
prompt = build_thread_memory_prompt(base_memory, messages)
if not prompt.strip():
return False
try:
response = self._get_model().invoke(prompt)
response_text = _extract_text(response.content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
parsed = json.loads(response_text)
cleaned = self._scrub_sensitive(parsed, thread_id)
expected_version = 0 if current is None else int(current.get("memoryVersion", 0))
if storage.save(thread_id, cleaned, expected_version=expected_version):
return True
# conflict retry once
latest = storage.load(thread_id)
latest_version = 0 if latest is None else int(latest.get("memoryVersion", 0))
logger.info("thread_memory conflict detected, retrying once: thread=%s version=%s", thread_id, latest_version)
return storage.save(thread_id, cleaned, expected_version=latest_version)
except Exception:
logger.exception("Thread memory update failed for thread=%s", thread_id)
return False

View File

@ -10,7 +10,9 @@ from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.memory.thread_queue import get_thread_memory_queue
from deerflow.config.memory_config import get_memory_config from deerflow.config.memory_config import get_memory_config
from deerflow.config.thread_memory_config import get_thread_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -206,8 +208,9 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns: Returns:
None (no state changes needed from this middleware). None (no state changes needed from this middleware).
""" """
config = get_memory_config() global_config = get_memory_config()
if not config.enabled: thread_config = get_thread_memory_config()
if not global_config.enabled and not thread_config.enabled:
return None return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata # Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
@ -239,13 +242,19 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update # Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages) correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue() if global_config.enabled:
queue.add( queue = get_memory_queue()
thread_id=thread_id, queue.add(
messages=filtered_messages, thread_id=thread_id,
agent_name=self._agent_name, messages=filtered_messages,
correction_detected=correction_detected, agent_name=self._agent_name,
reinforcement_detected=reinforcement_detected, correction_detected=correction_detected,
) reinforcement_detected=reinforcement_detected,
)
if thread_config.enabled:
get_thread_memory_queue().add(
thread_id=thread_id,
messages=filtered_messages,
)
return None return None

View File

@ -2,6 +2,7 @@ from .app_config import get_app_config
from .billing_config import BillingConfig from .billing_config import BillingConfig
from .extensions_config import ExtensionsConfig, get_extensions_config from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config from .memory_config import MemoryConfig, get_memory_config
from .thread_memory_config import ThreadMemoryConfig, get_thread_memory_config
from .paths import Paths, get_paths from .paths import Paths, get_paths
from .skills_config import SkillsConfig from .skills_config import SkillsConfig
from .tracing_config import ( from .tracing_config import (
@ -22,6 +23,8 @@ __all__ = [
"get_extensions_config", "get_extensions_config",
"MemoryConfig", "MemoryConfig",
"get_memory_config", "get_memory_config",
"ThreadMemoryConfig",
"get_thread_memory_config",
"get_tracing_config", "get_tracing_config",
"get_explicitly_enabled_tracing_providers", "get_explicitly_enabled_tracing_providers",
"get_enabled_tracing_providers", "get_enabled_tracing_providers",

View File

@ -25,6 +25,7 @@ from deerflow.config.title_config import TitleConfig, load_title_config_from_dic
from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.thread_memory_config import ThreadMemoryConfig, load_thread_memory_config_from_dict
load_dotenv() load_dotenv()
@ -55,6 +56,7 @@ class AppConfig(BaseModel):
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration") title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
thread_memory: ThreadMemoryConfig = Field(default_factory=ThreadMemoryConfig, description="Per-thread memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False) model_config = ConfigDict(extra="allow", frozen=False)
@ -118,6 +120,8 @@ class AppConfig(BaseModel):
# Load memory config if present # Load memory config if present
if "memory" in config_data: if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"]) load_memory_config_from_dict(config_data["memory"])
if "thread_memory" in config_data:
load_thread_memory_config_from_dict(config_data["thread_memory"])
# Load subagents config if present # Load subagents config if present
if "subagents" in config_data: if "subagents" in config_data:

View File

@ -0,0 +1,50 @@
"""Configuration for per-thread memory mechanism."""
from pydantic import BaseModel, Field
class ThreadMemorySqliteConfig(BaseModel):
path: str = Field(default="thread_memory.db", description="SQLite database file path")
class ThreadMemoryMysqlConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=3306)
user: str = Field(default="root")
password: str = Field(default="")
database: str = Field(default="deerflow")
class ThreadMemoryDatabaseConfig(BaseModel):
type: str = Field(default="sqlite", description="Database type: sqlite or mysql")
sqlite: ThreadMemorySqliteConfig = Field(default_factory=ThreadMemorySqliteConfig)
mysql: ThreadMemoryMysqlConfig = Field(default_factory=ThreadMemoryMysqlConfig)
class ThreadMemoryConfig(BaseModel):
enabled: bool = Field(default=True)
debounce_seconds: int = Field(default=30, ge=1, le=300)
model_name: str | None = Field(default=None)
max_facts: int = Field(default=100, ge=10, le=500)
fact_confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
injection_enabled: bool = Field(default=True)
max_injection_tokens: int = Field(default=2000, ge=100, le=8000)
bootstrap_from_global: bool = Field(default=False)
database: ThreadMemoryDatabaseConfig = Field(default_factory=ThreadMemoryDatabaseConfig)
_thread_memory_config: ThreadMemoryConfig = ThreadMemoryConfig()
def get_thread_memory_config() -> ThreadMemoryConfig:
return _thread_memory_config
def set_thread_memory_config(config: ThreadMemoryConfig) -> None:
global _thread_memory_config
_thread_memory_config = config
def load_thread_memory_config_from_dict(config_dict: dict) -> None:
global _thread_memory_config
_thread_memory_config = ThreadMemoryConfig(**config_dict)

View File

@ -88,18 +88,24 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
if not has_stream_usage: if not has_stream_usage:
model_settings_from_config["stream_usage"] = True model_settings_from_config["stream_usage"] = True
effective_stream_usage = kwargs.get("stream_usage", model_settings_from_config.get("stream_usage"))
# Some OpenAI-compatible providers only return usage in streaming mode # Some OpenAI-compatible providers only return usage in streaming mode
# when stream_options.include_usage is explicitly enabled. # when stream_options.include_usage is explicitly enabled.
stream_options_source = "kwargs" if "stream_options" in kwargs else "config" if effective_stream_usage:
stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options") stream_options_source = "kwargs" if "stream_options" in kwargs else "config"
if stream_options is None: stream_options = kwargs.get("stream_options") if stream_options_source == "kwargs" else model_settings_from_config.get("stream_options")
model_settings_from_config["stream_options"] = {"include_usage": True} if stream_options is None:
elif isinstance(stream_options, dict) and "include_usage" not in stream_options: model_settings_from_config["stream_options"] = {"include_usage": True}
patched_stream_options = {**stream_options, "include_usage": True} elif isinstance(stream_options, dict) and "include_usage" not in stream_options:
if stream_options_source == "kwargs": patched_stream_options = {**stream_options, "include_usage": True}
kwargs["stream_options"] = patched_stream_options if stream_options_source == "kwargs":
else: kwargs["stream_options"] = patched_stream_options
model_settings_from_config["stream_options"] = patched_stream_options else:
model_settings_from_config["stream_options"] = patched_stream_options
else:
# Some OpenAI-compatible endpoints reject stream_options when stream is false.
model_settings_from_config.pop("stream_options", None)
kwargs.pop("stream_options", None)
except Exception: except Exception:
# Keep model creation robust when langchain_openai isn't available. # Keep model creation robust when langchain_openai isn't available.
pass pass

View File

@ -973,7 +973,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
- Use `python -m pip` (inside the virtual environment) to install Python packages. - Use `python -m pip` (inside the virtual environment) to install Python packages.
Args: Args:
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are running this command in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
command: The bash command to execute. Always use absolute paths for files and directories. command: The bash command to execute. Always use absolute paths for files and directories.
""" """
try: try:
@ -1017,7 +1017,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
"""List the contents of a directory up to 2 levels deep in tree format. """List the contents of a directory up to 2 levels deep in tree format.
Args: Args:
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are listing this directory in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the directory to list. path: The **absolute** path to the directory to list.
""" """
try: try:
@ -1060,7 +1060,7 @@ def glob_tool(
"""Find files or directories that match a glob pattern under a root directory. """Find files or directories that match a glob pattern under a root directory.
Args: Args:
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are searching for these paths in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The glob pattern to match relative to the root path, for example `**/*.py`. pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
path: The **absolute** root directory to search under. path: The **absolute** root directory to search under.
include_dirs: Whether matching directories should also be returned. Default is False. include_dirs: Whether matching directories should also be returned. Default is False.
@ -1112,7 +1112,7 @@ def grep_tool(
"""Search for matching lines inside text files under a root directory. """Search for matching lines inside text files under a root directory.
Args: Args:
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are searching file contents in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The string or regex pattern to search for. pattern: The string or regex pattern to search for.
path: The **absolute** root directory to search under. path: The **absolute** root directory to search under.
glob: Optional glob filter for candidate files, for example `**/*.py`. glob: Optional glob filter for candidate files, for example `**/*.py`.
@ -1179,7 +1179,7 @@ def read_file_tool(
"""Read the contents of a text file. Use this to examine source code, configuration files, logs, or any text-based file. """Read the contents of a text file. Use this to examine source code, configuration files, logs, or any text-based file.
Args: Args:
description: Explain why you are reading this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are reading this file in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the file to read. path: The **absolute** path to the file to read.
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range. start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range. end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
@ -1234,7 +1234,7 @@ def write_file_tool(
"""Write text content to a file. """Write text content to a file.
Args: Args:
description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are writing to this file in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND. path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD. content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
""" """
@ -1276,7 +1276,7 @@ def str_replace_tool(
If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file. If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file.
Args: Args:
description: Explain why you are replacing the substring in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are replacing the substring in short words in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the file to replace the substring in. ALWAYS PROVIDE THIS PARAMETER SECOND. path: The **absolute** path to the file to replace the substring in. ALWAYS PROVIDE THIS PARAMETER SECOND.
old_str: The substring to replace. ALWAYS PROVIDE THIS PARAMETER THIRD. old_str: The substring to replace. ALWAYS PROVIDE THIS PARAMETER THIRD.
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH. new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.

View File

@ -54,7 +54,7 @@ async def task_tool(
- Tasks requiring user interaction or clarification - Tasks requiring user interaction or clarification
Args: Args:
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST. description: A short (3-5 word) description of the task for logging/display, in the user's language. ALWAYS PROVIDE THIS PARAMETER FIRST.
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND. prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max. max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.

View File

@ -0,0 +1,31 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from app.gateway.routers import threads
@pytest.mark.anyio
async def test_delete_thread_does_not_delete_thread_memory():
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(checkpointer=None, store=None)))
with (
patch("app.gateway.routers.threads._delete_thread_data", return_value=threads.ThreadDeleteResponse(success=True, message="ok")),
patch("app.gateway.routers.threads.get_store", return_value=None),
patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory,
):
response = await threads.delete_thread_data("thread-1", request)
assert response.success is True
delete_memory.assert_not_called()
@pytest.mark.anyio
async def test_delete_thread_memory_endpoint_calls_cleanup():
with patch("app.gateway.routers.threads.delete_thread_memory_data") as delete_memory:
response = await threads.delete_thread_memory("thread-1")
assert response.success is True
assert response.message == "Deleted thread memory for thread-1"
delete_memory.assert_called_once_with("thread-1")

View File

@ -0,0 +1,32 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.thread_memory_config import ThreadMemoryConfig
def test_thread_memory_queue_runs_even_if_global_memory_disabled():
middleware = MemoryMiddleware()
state = {"messages": [HumanMessage(content="My name is Alice"), AIMessage(content="Nice to meet you")]}
runtime = SimpleNamespace(context={"thread_id": "thread-test"})
mock_global_queue = MagicMock()
mock_thread_queue = MagicMock()
with (
patch("deerflow.agents.middlewares.memory_middleware.get_memory_config", return_value=MemoryConfig(enabled=False)),
patch(
"deerflow.agents.middlewares.memory_middleware.get_thread_memory_config",
return_value=ThreadMemoryConfig(enabled=True),
),
patch("deerflow.agents.middlewares.memory_middleware.get_memory_queue", return_value=mock_global_queue),
patch("deerflow.agents.middlewares.memory_middleware.get_thread_memory_queue", return_value=mock_thread_queue),
):
middleware.after_agent(state, runtime)
mock_global_queue.add.assert_not_called()
mock_thread_queue.add.assert_called_once()

View File

@ -0,0 +1,81 @@
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.memory.thread_prompt import build_thread_memory_prompt, format_thread_memory_for_injection
def test_thread_memory_injection_keeps_profile_and_preferences_under_small_budget(monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
memory = {
"user": {
"workContext": {"summary": "Building APIs", "updatedAt": "2026-05-08T00:00:00Z"},
"personalContext": {"summary": "Engineer using Python and React", "updatedAt": "2026-05-08T00:00:00Z"},
"topOfMind": {"summary": "Improving thread memory", "updatedAt": "2026-05-08T00:00:00Z"},
},
"history": {
"recentMonths": {"summary": "Shipped memory features", "updatedAt": "2026-05-08T00:00:00Z"},
"earlierContext": {"summary": "Started from TS projects", "updatedAt": "2026-05-08T00:00:00Z"},
"longTermBackground": {"summary": "Frontend developer", "updatedAt": "2026-05-08T00:00:00Z"},
},
"facts": [
{"content": "Fact one that might be trimmed", "category": "context", "confidence": 0.9},
{"content": "Fact two that might be trimmed", "category": "context", "confidence": 0.8},
],
}
result = format_thread_memory_for_injection(memory, max_tokens=140)
assert "User Context:" in result
assert "History:" in result
def test_build_thread_memory_prompt_does_not_raise_format_key_error():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[HumanMessage(content="My name is Alice.")],
)
assert "Current per-thread memory" in prompt
assert '"user"' in prompt
assert "Preferred memory language: same as the user's latest message" in prompt
def test_build_thread_memory_prompt_prefers_chinese_for_chinese_conversation():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[HumanMessage(content="我叫小明,我更喜欢中文交流。")],
)
assert "Preferred memory language: zh-Hans" in prompt
def test_build_thread_memory_prompt_prefers_japanese_for_japanese_conversation():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[HumanMessage(content="私は日本語で会話したいです。")],
)
assert "Preferred memory language: ja-JP" in prompt
def test_build_thread_memory_prompt_uses_user_messages_only_for_language_inference():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[
HumanMessage(content="请用中文记录记忆"),
AIMessage(content="Sure, I will answer in English with many many words."),
AIMessage(content="More English content that should not change language inference."),
],
)
assert "Preferred memory language: zh-Hans" in prompt
def test_build_thread_memory_prompt_handles_structured_human_content():
prompt = build_thread_memory_prompt(
{"user": {}, "history": {}, "facts": []},
[
HumanMessage(
content=[
{"type": "text", "text": "我希望记忆使用中文。"},
{"type": "text", "text": "请继续。"},
]
),
AIMessage(content="I can also reply in English."),
],
)
assert "Preferred memory language: zh-Hans" in prompt

View File

@ -0,0 +1,33 @@
from unittest.mock import patch
from deerflow.agents.memory.thread_queue import ThreadMemoryUpdateQueue
def test_thread_queue_keeps_latest_message_per_thread():
queue = ThreadMemoryUpdateQueue()
with patch.object(queue, "_reset_timer"):
queue.add("thread-a", ["msg-1"])
queue.add("thread-b", ["msg-2"])
queue.add("thread-a", ["msg-3"])
assert set(queue._queue_by_thread.keys()) == {"thread-a", "thread-b"}
assert queue._queue_by_thread["thread-a"].messages == ["msg-3"]
def test_thread_queue_processes_single_thread_without_affecting_others():
queue = ThreadMemoryUpdateQueue()
with patch.object(queue, "_reset_timer"):
queue.add("thread-a", ["a-msg"])
queue.add("thread-b", ["b-msg"])
updater_calls: list[tuple[list[str], str]] = []
class _FakeUpdater:
def update_memory(self, messages, thread_id):
updater_calls.append((messages, thread_id))
with patch("deerflow.agents.memory.thread_updater.ThreadMemoryUpdater", _FakeUpdater):
queue._process_thread("thread-a")
assert updater_calls == [(["a-msg"], "thread-a")]
assert "thread-b" in queue._queue_by_thread

View File

@ -0,0 +1,99 @@
import json
from deerflow.agents.memory.thread_storage import SqliteThreadMemoryStorage
def _payload():
return {
"ownerId": None,
"user": {
"workContext": {"summary": "Frontend engineer", "updatedAt": "2026-05-08T00:00:00Z"},
"personalContext": {"summary": "Prefers Chinese", "updatedAt": "2026-05-08T00:00:00Z"},
"topOfMind": {"summary": "Thread memory migration", "updatedAt": "2026-05-08T00:00:00Z"},
},
"history": {
"recentMonths": {"summary": "Worked on memory features", "updatedAt": "2026-05-08T00:00:00Z"},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "Builds web products", "updatedAt": "2026-05-08T00:00:00Z"},
},
"facts": [],
}
def test_sqlite_thread_memory_compare_and_swap(tmp_path):
storage = SqliteThreadMemoryStorage(str(tmp_path / "thread-memory.db"))
thread_id = "thread-1"
assert storage.save(thread_id, _payload(), expected_version=0) is True
loaded = storage.load(thread_id)
assert loaded is not None
assert loaded["memoryVersion"] == 0
# wrong expected version should fail
assert storage.save(thread_id, _payload(), expected_version=9) is False
# correct version should pass and increment
assert storage.save(thread_id, _payload(), expected_version=0) is True
loaded2 = storage.load(thread_id)
assert loaded2 is not None
assert loaded2["memoryVersion"] == 1
def test_sqlite_thread_memory_saves_json_payload(tmp_path):
db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-md"
assert storage.save(thread_id, _payload(), expected_version=0) is True
with storage._lock:
row = storage._conn.execute("SELECT memory_json FROM thread_memory WHERE thread_id = ?", (thread_id,)).fetchone()
assert row is not None
assert isinstance(row[0], str)
parsed = json.loads(row[0])
assert parsed["user"]["workContext"]["summary"] == "Frontend engineer"
def test_sqlite_thread_memory_uses_owner_id_column_when_json_missing_owner(tmp_path):
db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
thread_id = "thread-load"
payload = _payload()
with storage._lock:
storage._conn.execute(
"""
INSERT INTO thread_memory (thread_id, owner_id, memory_json, memory_version, last_updated)
VALUES (?, ?, ?, 0, datetime('now'))
""",
(
thread_id,
"owner-1",
json.dumps(
{
"user": payload["user"],
"history": payload["history"],
"facts": [],
},
ensure_ascii=False,
),
),
)
storage._conn.commit()
loaded = storage.load(thread_id)
assert loaded is not None
assert loaded["ownerId"] == "owner-1"
assert loaded["user"]["workContext"]["summary"] == "Frontend engineer"
assert loaded["facts"] == []
def test_sqlite_thread_memory_backfill_is_noop_after_migration(tmp_path):
db_path = tmp_path / "thread-memory.db"
storage = SqliteThreadMemoryStorage(str(db_path))
assert storage.count_legacy_rows() == 0
stats = storage.backfill_legacy_rows()
assert stats["scanned"] == 0
assert stats["updated"] == 0
assert stats["failed"] == 0
assert storage.count_legacy_rows() == 0

View File

@ -0,0 +1,103 @@
from unittest.mock import patch
import pytest
from deerflow.agents.memory.thread_summary import (
ThreadMemoryConflictError,
_extract_json_object,
apply_thread_memory_summary,
render_thread_memory_summary,
)
def test_render_thread_memory_summary_returns_text():
fake_storage = type(
"S",
(),
{"load": lambda self, tid: {"threadId": tid, "user": {}, "history": {}, "facts": [], "memoryVersion": 2}},
)()
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "用户总结"})()})()
with (
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=fake_storage),
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
):
result = render_thread_memory_summary("t1")
assert result["threadId"] == "t1"
assert result["memoryVersion"] == 2
assert result["summary"] == "用户总结"
def test_apply_thread_memory_summary_raises_conflict_on_cas_failure():
class _Storage:
def load(self, _tid):
return {"threadId": "t1", "ownerId": None, "user": {}, "history": {}, "facts": [], "memoryVersion": 1}
def save(self, _tid, _data, expected_version=None):
return False
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "{}"})()})()
fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})()
with (
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=_Storage()),
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater),
):
with pytest.raises(ThreadMemoryConflictError):
apply_thread_memory_summary("t1", "更新内容", 1)
def test_apply_thread_memory_summary_falls_back_when_model_output_is_not_json():
class _Storage:
def __init__(self):
self.saved = None
def load(self, _tid):
if self.saved is not None:
return {"threadId": "t1", "memoryVersion": 2, **self.saved}
return {
"threadId": "t1",
"ownerId": None,
"user": {"topOfMind": {"summary": ""}},
"history": {},
"facts": [],
"memoryVersion": 1,
}
def save(self, _tid, data, expected_version=None):
self.saved = data
return True
storage = _Storage()
fake_model = type("M", (), {"invoke": lambda self, prompt: type("R", (), {"content": "这是自然语言不是JSON"})()})()
fake_updater = type("U", (), {"_scrub_sensitive": lambda self, data, _thread_id: data})()
with (
patch("deerflow.agents.memory.thread_summary.get_thread_memory_storage", return_value=storage),
patch("deerflow.agents.memory.thread_summary._get_summary_model", return_value=fake_model),
patch("deerflow.agents.memory.thread_summary.ThreadMemoryUpdater", return_value=fake_updater),
):
result = apply_thread_memory_summary("t1", "我最近在做线程记忆功能", 1)
assert storage.saved is not None
assert storage.saved["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能"
assert result["user"]["topOfMind"]["summary"] == "我最近在做线程记忆功能"
def test_extract_json_object_repairs_inner_unescaped_quotes():
raw = """
{
"user": {
"topOfMind": {
"summary": "反感“作为 AI"这种句式认为回答不用寒暄直接说重点"
}
},
"history": {},
"facts": []
}
""".strip()
parsed = _extract_json_object(raw)
assert parsed is not None
assert parsed["user"]["topOfMind"]["summary"].startswith("反感“作为 AI")

View File

@ -0,0 +1,20 @@
from deerflow.agents.memory.thread_updater import ThreadMemoryUpdater
def test_scrub_sensitive_tolerates_non_numeric_confidence():
updater = ThreadMemoryUpdater()
cleaned = updater._scrub_sensitive(
{
"user": {},
"history": {},
"facts": [
{"content": "Uses React", "category": "knowledge", "confidence": "high"},
{"content": "Uses TypeScript", "category": "knowledge", "confidence": None},
],
},
"thread-test",
)
assert len(cleaned["facts"]) == 2
assert cleaned["facts"][0]["confidence"] == 0.5
assert cleaned["facts"][1]["confidence"] == 0.5

View File

@ -0,0 +1,760 @@
# Per-Thread Memory Brainstorm
Date: 2026-05-07
## Background
Deerflow 现有的记忆功能是单租户的——不同会话都属于同一个用户,所有对话共享一份全局 `memory.json`
要做一个新的记忆功能:不同对话属于不同用户,每个会话都有一个长期记忆,内容包括用户的使用习惯、个人信息、个人喜好和偏好语气。
## 现有记忆系统
- **存储**:单一全局 `backend/.deer-flow/memory.json`,所有会话共享
- **认证**没有用户认证没有用户隔离better-auth 已搭建但未启用)
- **结构**
- `user`: workContext / personalContext / topOfMind
- `history`: recentMonths / earlierContext / longTermBackground
- `facts[]`: id, content, category, confidence, source
- **读路径**system prompt 生成时注入 `<memory>...</memory>` XML 标签
- **写路径**MemoryMiddleware 在对话后过滤消息 → MemoryUpdateQueue debounce 30s → MemoryUpdater 调 LLM 提取更新 → 原子写入
- **配置**`config.yaml > memory`enabled, debounce_seconds, max_facts, max_injection_tokens 等)
---
## 决策记录
### 存储方式: 数据库
~~文件存储 `threads/{thread_id}/profile-memory.json`~~ → **改为数据库表**,通过 `thread_id` 区分用户。
### 数据库: SQLite本地/测试) + MySQL生产环境
### 表结构: 单表 + JSON 列Option A
### 依赖: 最小化,不引入 SQLAlchemy
SQLite 用标准库 `sqlite3`MySQL 用 `pymysql`(纯 Python轻量
### 与全局记忆关系: 策略 Bfallback
Per-thread 有记忆就用 per-thread 的,没有就 fallback 到全局记忆。
### 首次对话: 不主动询问用户偏好
---
## 1. 数据库表设计
```sql
-- SQLite
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY,
profile TEXT NOT NULL DEFAULT '{}',
preferences TEXT NOT NULL DEFAULT '{}',
facts TEXT NOT NULL DEFAULT '[]',
last_updated TEXT NOT NULL DEFAULT (datetime('now'))
);
-- MySQL
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY,
profile JSON NOT NULL,
preferences JSON NOT NULL,
facts JSON NOT NULL,
last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
```
**profile** ({})
| 字段 | 类型 | 说明 |
|------|------|------|
| `name` | `string \| null` | 用户称呼 |
| `role` | `string \| null` | 职业/角色 |
| `expertise` | `string[]` | 技术栈/专业领域 |
| `language` | `"zh-CN" \| "en-US" \| null` | 使用的语言 |
| `context` | `string \| null` | 其他上下文(自由文本) |
**preferences** ({})
| 字段 | 类型 | 说明 |
|------|------|------|
| `tone` | `"casual" \| "formal" \| "technical" \| "friendly" \| null` | 语气偏好 |
| `verbosity` | `"concise" \| "detailed" \| null` | 回答详细程度 |
| `codeStyle` | `string \| null` | 代码风格偏好 |
| `other` | `string \| null` | 其他偏好(自由文本) |
**facts** ([]):复用现有全局记忆的 fact 结构
```json
{
"id": "fact_abc123",
"content": "用户在使用 React + TypeScript",
"category": "tech_stack | preference | personal | context | goal",
"confidence": 0.9,
"createdAt": "2026-05-07T...",
"source": "thread_id"
}
```
**说明**:三个 JSON 字段在 SQLite 中存为 TEXTsqlite3 标准库没有原生 JSON 类型),在 MySQL 中存为 JSON。代码层面读写时做 `json.dumps` / `json.loads`,对上层透明。
## 2. config.yaml 新增配置段
```yaml
thread_memory:
enabled: true
debounce_seconds: 30
model_name: null # null = 使用默认模型
max_facts: 100
fact_confidence_threshold: 0.7
injection_enabled: true
max_injection_tokens: 2000
database:
type: sqlite # sqlite | mysql
sqlite:
path: "thread_memory.db"
mysql:
host: "localhost"
port: 3306
user: "root"
password: "$MYSQL_PASSWORD"
database: "deerflow"
```
大部分字段和现有 `memory` 配置段语义相同,可以在两个配置段之间复用。`database` 段按 type 取子段,工厂函数只读自己需要的部分。
## 3. 存储层设计
### 3.1 抽象接口
```python
# deerflow/agents/memory/thread_storage.py
import abc
import json
import sqlite3
from datetime import datetime
from typing import Any
class ThreadMemoryStorage(abc.ABC):
@abc.abstractmethod
def load(self, thread_id: str) -> dict[str, Any] | None:
"""加载指定 thread 的记忆,不存在返回 None。"""
...
@abc.abstractmethod
def save(self, thread_id: str, data: dict[str, Any]) -> bool:
"""保存指定 thread 的记忆upsert。"""
...
@abc.abstractmethod
def delete(self, thread_id: str) -> bool:
"""删除指定 thread 的记忆thread 被删除时联动)。"""
...
def _create_empty_memory() -> dict[str, Any]:
"""Per-thread 记忆的初始空结构。"""
return {
"profile": {
"name": None,
"role": None,
"expertise": [],
"language": None,
"context": None,
},
"preferences": {
"tone": None,
"verbosity": None,
"codeStyle": None,
"other": None,
},
"facts": [],
}
def _row_to_memory(row: tuple) -> dict[str, Any]:
"""将数据库行转为 memory dict。SQLite 的 JSON 列存的是 TEXT需要 parse。"""
return {
"threadId": row[0],
"profile": json.loads(row[1]),
"preferences": json.loads(row[2]),
"facts": json.loads(row[3]),
"lastUpdated": row[4],
}
```
### 3.2 SQLite 实现(本地测试)
```python
class SqliteThreadMemoryStorage(ThreadMemoryStorage):
def __init__(self, db_path: str):
self._conn = sqlite3.connect(db_path)
self._conn.execute("""
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY,
profile TEXT NOT NULL DEFAULT '{}',
preferences TEXT NOT NULL DEFAULT '{}',
facts TEXT NOT NULL DEFAULT '[]',
last_updated TEXT NOT NULL DEFAULT (datetime('now'))
)
""")
self._conn.commit()
def load(self, thread_id: str) -> dict | None:
row = self._conn.execute(
"SELECT thread_id, profile, preferences, facts, last_updated "
"FROM thread_memory WHERE thread_id = ?",
(thread_id,)
).fetchone()
return _row_to_memory(row) if row else None
def save(self, thread_id: str, data: dict) -> bool:
now = datetime.utcnow().isoformat() + "Z"
self._conn.execute("""
INSERT INTO thread_memory (thread_id, profile, preferences, facts, last_updated)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(thread_id) DO UPDATE SET
profile = excluded.profile,
preferences = excluded.preferences,
facts = excluded.facts,
last_updated = excluded.last_updated
""", (
thread_id,
json.dumps(data["profile"], ensure_ascii=False),
json.dumps(data["preferences"], ensure_ascii=False),
json.dumps(data["facts"], ensure_ascii=False),
now,
))
self._conn.commit()
return True
def delete(self, thread_id: str) -> bool:
self._conn.execute("DELETE FROM thread_memory WHERE thread_id = ?", (thread_id,))
self._conn.commit()
return True
```
### 3.3 MySQL 实现(生产环境)
```python
class MysqlThreadMemoryStorage(ThreadMemoryStorage):
def __init__(self, host: str, port: int, user: str, password: str, database: str):
import pymysql
self._conn = pymysql.connect(
host=host, port=port, user=user, password=password, database=database,
charset="utf8mb4",
)
with self._conn.cursor() as cur:
cur.execute("""
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY,
profile JSON NOT NULL,
preferences JSON NOT NULL,
facts JSON NOT NULL,
last_updated TIMESTAMP NOT NULL
DEFAULT CURRENT_TIMESTAMP
ON UPDATE CURRENT_TIMESTAMP
)
""")
self._conn.commit()
def load(self, thread_id: str) -> dict | None:
with self._conn.cursor() as cur:
cur.execute(
"SELECT thread_id, profile, preferences, facts, last_updated "
"FROM thread_memory WHERE thread_id = %s",
(thread_id,)
)
row = cur.fetchone()
return _row_to_memory(row) if row else None
def save(self, thread_id: str, data: dict) -> bool:
now = datetime.utcnow()
with self._conn.cursor() as cur:
cur.execute("""
INSERT INTO thread_memory (thread_id, profile, preferences, facts, last_updated)
VALUES (%s, %s, %s, %s, %s)
ON DUPLICATE KEY UPDATE
profile = VALUES(profile),
preferences = VALUES(preferences),
facts = VALUES(facts),
last_updated = VALUES(last_updated)
""", (
thread_id,
json.dumps(data["profile"], ensure_ascii=False),
json.dumps(data["preferences"], ensure_ascii=False),
json.dumps(data["facts"], ensure_ascii=False),
now,
))
self._conn.commit()
return True
def delete(self, thread_id: str) -> bool:
with self._conn.cursor() as cur:
cur.execute("DELETE FROM thread_memory WHERE thread_id = %s", (thread_id,))
self._conn.commit()
return True
```
### 3.4 工厂函数
```python
def get_thread_memory_storage() -> ThreadMemoryStorage:
"""从 config 读取 database 配置,构建对应的 storage 实例(单例)。"""
config = get_thread_memory_config()
db = config.database
if db.type == "sqlite":
return SqliteThreadMemoryStorage(db.sqlite.path)
elif db.type == "mysql":
return MysqlThreadMemoryStorage(
host=db.mysql.host,
port=db.mysql.port,
user=db.mysql.user,
password=db.mysql.password,
database=db.mysql.database,
)
else:
raise ValueError(f"Unknown thread_memory database type: {db.type}")
```
### 3.5 注意事项
- **JSON 在 SQLite 中存为 TEXT**`sqlite3` 标准库没有 JSON 类型,用 TEXT 存储 `json.dumps` 的结果。读写时做序列化/反序列化。MySQL 用原生 JSON 列,`pymysql` 自动处理。
- **upsert 语法差异**SQLite 用 `ON CONFLICT ... DO UPDATE SET`MySQL 用 `ON DUPLICATE KEY UPDATE`,语义等价。
- **连接管理**:两个实现都在 `__init__` 创建连接并持有。单线程场景没问题。如果将来需要并发,可以加连接池或改为每次操作创建连接。
---
## 4. upsert 语义:全量替换 vs 合并更新
### 两种模式
**模式 A — 增量合并**LLM 出 delta应用层合并
```
LLM 输入: 现有记忆 + 新对话
LLM 输出: { profile: { name: "新值", shouldUpdate: true }, newFacts: [...], factsToRemove: [...] }
应用层: 读取现有记忆 → 按 delta 逐字段合并 → 写入
```
现有全局记忆用的就是这个模式。LLM 输出里带 `shouldUpdate` 标记和 `factsToRemove` 列表,应用代码做合并。
**模式 B — 全量替换**LLM 出完整状态,应用层直接覆盖):
```
LLM 输入: 现有记忆 + 新对话
LLM 输出: { profile: { name: "...", role: "...", ... }, preferences: {...}, facts: [...] }
应用层: INSERT ... ON CONFLICT DO UPDATE整行覆盖
```
### 选择模式 B 的理由
1. **profile 和 preferences 本身很小**。每个对象 5-6 个字段,全部输出最多几十个 token增量节省的 token 可以忽略。
2. **去重和淘汰由 LLM 负责,应用层零逻辑**。LLM 看到了完整的现有记忆,在 prompt 中就能决定哪些 facts 要保留、哪些过时了要删、哪些要合并。应用代码只需要 `json.dumps` + upsert。
3. **避免字段删除的尴尬**。如果 LLM 想把 `profile.context``"前端开发者"` 改成 `null`(表示不再确定这个信息),增量模式需要额外表达"显式置 null"还是"不变",全量替换没有歧义。
4. **和现有全局记忆的模式不同是合理的**。全局记忆的 `history` 有大量的对话摘要文本不适合全量替换。Per-thread 记忆的 profile/preferences 是结构化的元数据,全量输出成本低。
### 具体流程
```
用户对话结束
MemoryMiddleware.after_agent() 提取 user + final AI 消息
queue.add(thread_id, messages) # debounce 30s
ThreadMemoryUpdater.update()
1. 从 DB 读取现有记忆(不存在就用 _create_empty_memory()
2. 构建 prompt: "以下是用户的现有画像和偏好:{existing_memory},以下是新的对话:{conversation},请更新用户画像。"
3. LLM 返回完整的 profile + preferences + facts
4. storage.save(thread_id, data) # upsert 整行覆盖
```
**关键点**LLM 的 prompt 里放了**现有记忆**LLM 看到之后自己决定:
- 保留哪些 facts
- 更新哪些 profile 字段
- 新增什么偏好
- 删除过时的信息(不输出就是删除)
应用代码不做任何合并判断,只负责把 LLM 输出写入数据库。
---
## 5. 更新路径
### 5.1 MemoryMiddleware 改造(最小改动)
在现有 `MemoryMiddleware.after_agent()` 中加一段逻辑,当 `thread_id` 存在时,同时向 per-thread 记忆的 queue 推一条:
```python
# 现有逻辑:全局记忆
queue = get_memory_queue()
queue.add(thread_id=thread_id, messages=filtered_messages, ...)
# 新增per-thread 记忆
if thread_id:
thread_queue = get_thread_memory_queue()
thread_queue.add(thread_id=thread_id, messages=filtered_messages)
```
### 5.2 ThreadMemoryUpdater
新类,结构类似现有的 `MemoryUpdater`,但使用不同的 prompt 和存储后端:
```python
class ThreadMemoryUpdater:
def update(self, messages, thread_id):
storage = get_thread_memory_storage()
existing = storage.load(thread_id) or _create_empty_memory()
prompt = THREAD_MEMORY_UPDATE_PROMPT.format(
existing_memory=json.dumps(existing, ensure_ascii=False),
conversation=format_conversation(messages),
)
response = model.invoke(prompt)
new_memory = parse_llm_output(response) # { profile, preferences, facts }
storage.save(thread_id, new_memory)
```
### 5.3 Prompt 设计要点
与全局记忆 prompt 的关键区别:
| | 全局记忆 prompt | Per-thread 记忆 prompt |
|---|---|---|
| **目标** | "对话中发生了什么" | "这个人是谁、喜欢什么" |
| **输出** | user context 摘要 + history 摘要 + facts | profile + preferences + facts |
| **侧重** | 保留对话内容的事实性信息 | 推断用户的身份、偏好、风格 |
| **语气影响** | 无 | 输出 `preferences.tone` 直接影响后续回复风格 |
---
## 6. 读取路径(注入 System Prompt
```python
def inject_thread_memory(system_prompt: str, thread_id: str) -> str:
storage = get_thread_memory_storage()
memory = storage.load(thread_id)
if memory is None:
# fallback 到全局记忆
return inject_global_memory(system_prompt)
# 生成 <memory profile="..."> 标签注入 system prompt
profile_xml = _format_profile_xml(memory)
return system_prompt + "\n" + profile_xml
```
注入内容的 XML 结构示例:
```xml
<memory>
<profile>
<name>张三</name>
<role>全栈工程师</role>
<expertise>React, TypeScript, Python</expertise>
<language>zh-CN</language>
<context>在做一个电商项目</context>
</profile>
<preferences>
<tone>casual</tone>
<verbosity>detailed</verbosity>
<codeStyle>prefers functional components with hooks</codeStyle>
</preferences>
</memory>
```
语气偏好(`preferences.tone`)不直接改 system prompt 模板,而是放在 `<preferences>` XML 里让 LLM 自己理解。方式简单,不用维护 prompt 模板的分支逻辑。如果发现 LLM 不遵循,再考虑动态改写 prompt 模板。
---
## 7. Thread 删除时的联动
Gateway 已有 `DELETE /api/threads/{id}`。在现有 handler 中加一行:
```python
# app/gateway/routers/threads.py
@router.delete("/api/threads/{thread_id}")
async def delete_thread(thread_id: str):
# ... 现有清理逻辑 ...
# 新增:删除 per-thread 记忆
get_thread_memory_storage().delete(thread_id)
```
---
## 8. 实施步骤
1. **新增配置模型**`thread_memory_config.py`(参考现有 `memory_config.py`
2. **新增存储层**`thread_storage.py``ThreadMemoryStorage` + `SqliteThreadMemoryStorage` + `MysqlThreadMemoryStorage`
3. **新增 prompt**`thread_memory_prompt.py`(用于 LLM 提取用户画像)
4. **新增 updater** — 或扩展现有 `MemoryUpdater`,根据 `thread_id` 参数路由到不同逻辑
5. **改造 middleware**`MemoryMiddleware` 中加 per-thread 记忆的 queue 逻辑
6. **改造注入** — system prompt 生成时注入 `<memory>` 标签
7. **扩展 thread 删除 handler** — 联动删除 DB 记录
8. **写入测试**`test_thread_memory_storage.py`, `test_thread_memory_updater.py`
## 9. 待确认事项
- [ ] pymysql 作为新依赖是否 OK
- [ ] `database` 配置段结构是否合适?
- [ ] upsert 使用全量替换模式(模式 B是否认同
## 10. 第二轮脑暴(风险前置)
下面这轮不是改大方向,而是把容易在落地时踩坑的点先钉住。
### 10.1 隔离键:`thread_id` 是否足够?
当前设计用 `thread_id` 作为主键隔离用户记忆,简单可行。但有一个隐含前提:
- 一个 thread 永远只对应一个真实用户
如果未来支持“同一用户多 thread 共享画像”或“thread 可能转移 owner”只用 `thread_id` 会限制扩展。
可选路径:
- 路径 A维持现状推荐短期主键 `thread_id`,最快上线。
- 路径 B兼容未来增加 `owner_id`(可空),并加索引 `(owner_id, thread_id)`
建议:
- 第一版继续 `thread_id`,但在表里预留 `owner_id` nullable 字段,避免后续大迁移。
### 10.2 并发一致性:同一 thread 的并发写覆盖问题
场景:同一 thread 在短时间内触发多次 update后到达的旧结果可能覆盖先到达的新结果。
可选保护:
- 方案 A`last_updated` 乐观锁(更新时带 where 条件)
- 方案 B`memory_version` 整数版本号(推荐)
- 方案 C严格串行队列单 thread 单 worker
建议:
- 加 `memory_version`(默认 0。`save` 时做 compare-and-swap 语义:
- 读取 version = n
- 写入时要求 version 仍为 n成功后 version = n+1
- 失败则重试一次(重新 load + merge prompt 再写)
这样不需要分布式锁,也能规避“旧结果回写”。
### 10.3 记忆质量控制:防止噪声和幻觉固化
LLM 抽取用户画像时,最大风险是把一次性表达当长期偏好。
建议加三道门:
1. 事实类别阈值
- `preference` 类阈值可略低(如 0.7
- `personal` 类阈值更高(如 0.85
2. 稳定性规则
- 同类偏好至少被 2 次独立对话支持,才提升为 profile/preference 的强字段
3. 冲突降级
- 新旧事实冲突时,不立刻删旧值
- 先把旧值降权并标记 `supersededBy`,下一轮再淘汰
### 10.4 隐私与合规:先定义“不能记”的边界
建议在 prompt 与代码都加 denylist双保险
- 默认不写入:身份证号、手机号、邮箱、住址、银行卡、密码/API Key 等敏感信息
- 允许写入:技术偏好、工作语境、沟通风格、项目目标
实现上:
- 在 `ThreadMemoryUpdater` parse 后做一次 server-side scrub
- 命中敏感模式就丢弃并打审计日志(不落库原文)
### 10.5 注入预算:避免 memory 挤爆上下文
当前有 `max_injection_tokens`,但还缺“裁剪策略”。
建议固定优先级:
1. profile最高
2. preferences
3. facts按 confidence + recency 排序后截断)
当超预算时:
- 永远保留 profile/preference
- 只裁剪 facts
### 10.6 可观测性:上线后如何判断有效
建议最小指标集:
- `thread_memory_update_total{status=ok|error}`
- `thread_memory_injection_tokens`
- `thread_memory_fact_count`
- `thread_memory_update_latency_ms`
- `thread_memory_conflict_retry_total`
加两条抽样日志:
- 更新前后摘要 diff脱敏后
- 注入片段长度与截断原因
### 10.7 迁移与回滚策略(从全局记忆过渡)
你已选 fallback 策略,这很好。建议再补两个机制:
- 冷启动导入(可选)
- 首次访问 thread 且无 per-thread 记录时,从全局记忆抽取一份“弱画像”写入
- 打 `bootstrapped_from_global=true`
- 一键回滚
- 配置开关 `thread_memory.injection_enabled=false` 时,立刻只走全局注入
- 更新链路可继续跑,便于回滚期间保留数据
### 10.8 API 语义建议(便于后续运维)
即使第一版 UI 不暴露,也建议预留内部接口:
- `GET /internal/thread-memory/{thread_id}`(脱敏视图)
- `DELETE /internal/thread-memory/{thread_id}`
- `POST /internal/thread-memory/{thread_id}/rebuild`
这样排障时不用直接查库。
---
## 11. 第三轮决策清单(进入实现前最后拍板)
- [ ] 表结构是否预留 `owner_id``memory_version`
- [ ] 是否采用 `memory_version` 方案处理并发覆盖?
- [ ] 敏感信息 denylist 范围是否按 10.4 执行?
- [ ] 注入裁剪优先级是否固定为 profile > preferences > facts
- [ ] 是否需要“冷启动导入”全局记忆到 per-thread
- [ ] 是否要在首版就加内部运维接口?
如果以上 6 项确定,基本就能把实现风险压到可控范围内。
## 12. 默认拍板方案(建议直接采用)
目标:在不显著增加复杂度的前提下,拿到“可上线 + 可回滚 + 可演进”的第一版。
### 12.1 表结构默认值
采用:**预留 `owner_id` + 引入 `memory_version`**。
SQLite
```sql
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id TEXT PRIMARY KEY,
owner_id TEXT NULL,
profile TEXT NOT NULL DEFAULT '{}',
preferences TEXT NOT NULL DEFAULT '{}',
facts TEXT NOT NULL DEFAULT '[]',
memory_version INTEGER NOT NULL DEFAULT 0,
last_updated TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_thread_memory_owner_id ON thread_memory(owner_id);
```
MySQL
```sql
CREATE TABLE IF NOT EXISTS thread_memory (
thread_id VARCHAR(64) PRIMARY KEY,
owner_id VARCHAR(64) NULL,
profile JSON NOT NULL,
preferences JSON NOT NULL,
facts JSON NOT NULL,
memory_version INT NOT NULL DEFAULT 0,
last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_owner_id (owner_id)
);
```
### 12.2 并发一致性默认值
采用:**`memory_version` 乐观并发控制 + 失败重试 1 次**。
保存逻辑:
- `load()` 读出 `memory_version=n`
- `save()` 时执行条件更新(`WHERE thread_id=? AND memory_version=n`
- 成功则 `memory_version=n+1`
- 如果受影响行数为 0说明被并发写抢先重读并重试一次
这能防止“旧更新覆盖新更新”,同时实现复杂度可控。
### 12.3 隐私策略默认值
采用:**默认拒绝敏感信息入库(代码层 hard filter**。
默认 denylist
- 手机号
- 邮箱
- 身份证号/护照号
- 银行卡号
- 密码/API Key/Token
- 详细住址
规则:
- 命中则从 `profile/preferences/facts` 中删除该片段
- 仅记录脱敏审计信息(类型 + 时间 + thread_id不记录原文
### 12.4 注入裁剪默认值
采用固定优先级:**`profile > preferences > facts`**。
当超过 `max_injection_tokens`
- 必保留:`profile`、`preferences`
- 裁剪:`facts`(按 `confidence DESC, createdAt DESC` 排序后截断)
这能保证人格与风格信息稳定注入,不被历史 facts 挤掉。
### 12.5 冷启动策略默认值
采用:**首版不开启自动冷启动导入**`bootstrap_from_global=false`)。
理由:
- 降低“全局脏数据复制到 thread”风险
- 逻辑更清晰,便于观察 per-thread 记忆真实质量
补充:
- 保留 fallback你当前已定
- 后续若需要可加后台任务做可控回填
### 12.6 内部运维接口默认值
采用:**首版只加读接口,写接口延后**。
第一版建议:
- `GET /internal/thread-memory/{thread_id}`(脱敏后返回)
暂不做:
- `DELETE /internal/thread-memory/{thread_id}`(已有 thread delete 联动可覆盖主场景)
- `POST /internal/thread-memory/{thread_id}/rebuild`(二期再加)
这样可以先满足排障可见性,避免过早扩大运维面。
---
## 13. 实施前冻结版 Checklist可直接转开发
- [ ] DDL 按 12.1 落地(含 `owner_id`, `memory_version`, index
- [ ] Storage `save()` 改为 compare-and-swap 语义
- [ ] Updater 增加一次冲突重试
- [ ] parse 后执行敏感信息 scrub
- [ ] 注入模块按 `profile > preferences > facts` 裁剪
- [ ] fallback 保持开启,冷启动导入保持关闭
- [ ] 增加最小指标与脱敏 diff 日志
- [ ] 增加内部只读排障接口
到这一步,方案已经可以进入实现,不需要再做大改。

View File

@ -0,0 +1,213 @@
# Thread Memory 手动测试清单
日期:`2026-05-08`
测试人:`__________`
---
## 0. 前置检查
- [ ] 已拉取包含以下修复的最新代码并重启后端进程
- `memory.enabled=false` 时仍允许 `thread_memory` 更新
- `thread_prompt` 的 JSON 模板转义修复(避免 `KeyError: "profile"`
- `thread_updater` 使用非流式安全参数(避免 `stream_options` 400
- [ ] `config.yaml` 中已启用 `thread_memory.enabled: true`
- [ ] 确认使用的是预期配置文件(当前项目根目录 `config.yaml`
---
## 1. 基础写入与读取
前置条件:
- 选择一个新的 `thread_id`(例:`1f571481-e3ae-42b5-a513-945bf8f1cbef`
步骤:
1. 在该线程发送 2-3 轮消息,包含姓名、角色、偏好语气等信息
2. 等待 `debounce_seconds`(默认 30 秒)
3. 查询 `thread_memory`
期望:
- 出现该 `thread_id` 记录
- `profile/preferences/facts` 有对应内容
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 2. Per-Thread 隔离
前置条件:
- 准备两个线程 `thread_A`、`thread_B`
步骤:
1. 在 A 中输入“前端背景”信息
2. 在 B 中输入“后端背景”信息
3. 分别等待写入完成后查看两条记录
期望:
- A 仅保存 A 的画像B 仅保存 B 的画像
- 两个线程不串数据
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 3. 全局记忆 Fallback
前置条件:
- 全局 memory 有内容
- 新建一个尚无 per-thread 记录的线程
步骤:
1. 先在该新线程发一轮普通消息
2. 观察回复是否体现全局记忆
3. 再继续对话触发 per-thread 写入后观察注入变化
期望:
- 无 per-thread 时可 fallback 到全局
- 有 per-thread 后优先使用 per-thread
结果:
- [ ] 通过
- [ ] 失败(备注:`未执行N/A当前环境 memory.enabled=false全局记忆关闭本用例不适用`
---
## 4. 注入裁剪优先级Profile > Preferences > Facts
前置条件:
- 某线程已有大量 facts
步骤:
1. 人为积累 facts 到接近/超过注入预算
2. 保持 profile/preferences 有值
3. 观察注入后的表现
期望:
- 超预算时保留 profile + preferences
- 优先裁剪 facts
结果:
- [1 ] 通过
- [ ] 失败(备注:`________________`
---
## 5. 敏感信息过滤
步骤:
1. 在对话中输入邮箱、手机号、token/password 等敏感样例
2. 等待写入后查库
期望:
- 敏感信息不应落入 `profile/preferences/facts`
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 6. 并发覆盖保护CAS + version
步骤:
1. 同一 `thread_id` 短时间内触发两次更新(尽量并发)
2. 观察最终数据与日志
期望:
- 不出现明显“旧数据覆盖新数据”
- 冲突时可见重试行为(日志)
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 7. Debounce 生效
步骤:
1. 在 30 秒内连续发送多条消息
2. 观察写库频率
期望:
- 多条输入被合并处理,不是每条都立即写库
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 8. 线程删除联动清理
步骤:
1. 对已有 per-thread 记录的线程调用 `DELETE /api/threads/{thread_id}`
2. 查询 `thread_memory`
期望:
- 对应 `thread_id` 记录被删除
结果:
- [ ] 通过
- [ ] 失败(备注:`未执行:当前产品决策不接受“删线程即删记忆”,需改为用户显式触发清除后再复测`
---
## 9. SQLite 自动建表与路径
步骤:
1. 删除现有 `thread_memory.db`(测试环境)
2. 重启服务并触发一轮写入
3. 检查 DB 文件和表结构
期望:
- 自动创建 DB 文件与 `thread_memory`
- 索引 `idx_thread_memory_owner_id` 存在
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 10. 配置开关验证
步骤:
1. 关闭 `thread_memory.enabled`,重启并测试写入
2. 开启 `thread_memory.enabled`,关闭 `thread_memory.injection_enabled`,重启并测试注入
期望:
- `enabled=false`:不更新 per-thread
- `injection_enabled=false`:不注入 per-thread可 fallback
结果:
- [1] 通过
- [ ] 失败(备注:`________________`
---
## 11. 已知错误回归验证
### 11.1 `KeyError: "profile"` 回归
- [ 1] 未再出现 `thread_prompt.py``KeyError` 报错
### 11.2 `stream_options` 400 回归
- [ 1] 未再出现 `"'stream_options' only set this when you set stream: true"` 报错
备注:`________________`
---
## 测试总结
- 总用例数:`11`
- 通过数:`____`
- 失败数:`____`
- 结论:
- [ ] 可上线
- [ ] 需修复后复测

BIN
frontend/public/coxwork.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

View File

@ -1,7 +1,7 @@
"use client"; "use client";
import { Ticker } from "@tombcato/smart-ticker"; import { Ticker } from "@tombcato/smart-ticker";
import { FilesIcon, ListTodoIcon, XIcon } from "lucide-react"; import { FilesIcon, XIcon } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
@ -23,20 +23,21 @@ import {
} from "@/components/workspace/artifacts"; } from "@/components/workspace/artifacts";
import { useThreadChat } from "@/components/workspace/chats"; import { useThreadChat } from "@/components/workspace/chats";
// import { DevTodoList } from "@/components/workspace/dev-todo-list"; // import { DevTodoList } from "@/components/workspace/dev-todo-list";
import { IframeTestPanel } from "@/components/workspace/iframe-test-panel";
import { InputBox } from "@/components/workspace/input-box"; import { InputBox } from "@/components/workspace/input-box";
import { MessageList } from "@/components/workspace/messages"; import { MessageList } from "@/components/workspace/messages";
import { ThreadContext } from "@/components/workspace/messages/context"; import { ThreadContext } from "@/components/workspace/messages/context";
import { ThreadTitle } from "@/components/workspace/thread-title";
import { Tooltip } from "@/components/workspace/tooltip"; import { Tooltip } from "@/components/workspace/tooltip";
import { useSpecificChatMode } from "@/components/workspace/use-chat-mode"; import { useSpecificChatMode } from "@/components/workspace/use-chat-mode";
import { useBrand } from "@/core/brand/provider";
import { Welcome } from "@/components/workspace/welcome"; import { Welcome } from "@/components/workspace/welcome";
import { getAPIClient } from "@/core/api"; import { getAPIClient } from "@/core/api";
import { sanitizeArtifactPaths } from "@/core/artifacts/utils"; import { sanitizeArtifactPaths } from "@/core/artifacts/utils";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages"; import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages";
import { useNotification } from "@/core/notification/hooks"; import { useNotification } from "@/core/notification/hooks";
import { useLocalSettings } from "@/core/settings"; import { useLocalSettings } from "@/core/settings";
import { clearThreadMemoryOnExit } from "@/core/threads/exit-thread-memory";
import { useThreadStream } from "@/core/threads/hooks"; import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils"; import { textOfMessage } from "@/core/threads/utils";
import { env } from "@/env"; import { env } from "@/env";
@ -48,6 +49,7 @@ import motivationSlogans from "./motivation-slogans.json";
export default function ChatPage() { export default function ChatPage() {
const { t } = useI18n(); const { t } = useI18n();
const { brand } = useBrand();
useSpecificChatMode(); useSpecificChatMode();
const [sloganIndex, setSloganIndex] = useState(0); const [sloganIndex, setSloganIndex] = useState(0);
const [settings, setSettings] = useLocalSettings(); const [settings, setSettings] = useLocalSettings();
@ -60,8 +62,6 @@ export default function ChatPage() {
setArtifacts, setArtifacts,
select: selectArtifact, select: selectArtifact,
selectedArtifact, selectedArtifact,
deselect: deselectArtifact,
setFullscreen: setArtifactsFullscreen,
fullscreen, fullscreen,
} = useArtifacts(); } = useArtifacts();
const { threadId, isNewThread, setIsNewThread, isMock, showWelcomeStyle } = const { threadId, isNewThread, setIsNewThread, isMock, showWelcomeStyle } =
@ -303,6 +303,8 @@ export default function ChatPage() {
const todoListCollapsed = true; const todoListCollapsed = true;
const [showExitDialog, setShowExitDialog] = useState(false); const [showExitDialog, setShowExitDialog] = useState(false);
const [clearMemoryOnExit, setClearMemoryOnExit] = useState(false);
const [isConfirmingExit, setIsConfirmingExit] = useState(false);
const isStreaming = isUploading || thread.isLoading; const isStreaming = isUploading || thread.isLoading;
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (message: Parameters<typeof sendMessage>[1]) => { async (message: Parameters<typeof sendMessage>[1]) => {
@ -345,6 +347,7 @@ export default function ChatPage() {
className={cn( className={cn(
"m-auto flex h-screen min-h-svh overflow-hidden rounded-t-[20px] transition-[width] duration-300 ease-in-out", "m-auto flex h-screen min-h-svh overflow-hidden rounded-t-[20px] transition-[width] duration-300 ease-in-out",
artifactsOpen ? "w-full" : "w-[70%]", artifactsOpen ? "w-full" : "w-[70%]",
brand === "sxwz" && artifactsOpen === false && "translate-x-[-172px]",
)} )}
> >
<div className="relative flex size-full min-h-0 justify-between rounded-t-[20px]"> <div className="relative flex size-full min-h-0 justify-between rounded-t-[20px]">
@ -374,6 +377,7 @@ export default function ChatPage() {
isChatting: false, isChatting: false,
}); });
router.replace(`/workspace/chats/${threadId}?is_chatting=false`) router.replace(`/workspace/chats/${threadId}?is_chatting=false`)
setArtifactsOpen(false);
} }
} }
> >
@ -440,7 +444,7 @@ export default function ChatPage() {
onClick={() => setShowExitDialog(true)} onClick={() => setShowExitDialog(true)}
> >
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none"> <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
<path d="M2 4H6M16 4H12M6 4H12M6 4C6 2.89543 6.89543 2 8 2H10C11.1046 2 12 2.89543 12 4M4 6V14C4 15.1046 4.89543 16 6 16H12C13.1046 16 14 15.1046 14 14V6M7 8V13M11 8V13" stroke="#150033" stroke-linecap="round" /> <path d="M2 4H6M16 4H12M6 4H12M6 4C6 2.89543 6.89543 2 8 2H10C11.1046 2 12 2.89543 12 4M4 6V14C4 15.1046 4.89543 16 6 16H12C13.1046 16 14 15.1046 14 14V6M7 8V13M11 8V13" stroke="#150033" strokeLinecap="round" />
</svg> </svg>
{t.common.resetThread} {t.common.resetThread}
</Button> </Button>
@ -458,10 +462,10 @@ export default function ChatPage() {
}} }}
> >
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none"> <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
<path d="M16 7V4C16 2.89543 15.1046 2 14 2H4C2.89543 2 2 2.89543 2 4V14C2 15.1046 2.89543 16 4 16H9" stroke="#150033" stroke-linecap="round" /> <path d="M16 7V4C16 2.89543 15.1046 2 14 2H4C2.89543 2 2 2.89543 2 4V14C2 15.1046 2.89543 16 4 16H9" stroke="#150033" strokeLinecap="round" />
<path d="M5 5H9M5 8H7" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" /> <path d="M5 5H9M5 8H7" stroke="#150033" strokeLinecap="round" strokeLinejoin="round" />
<circle cx="11.5" cy="10.5" r="3" stroke="#150033" /> <circle cx="11.5" cy="10.5" r="3" stroke="#150033" />
<path d="M15.5 14.5L14 13" stroke="#150033" stroke-linecap="round" stroke-linejoin="round" /> <path d="M15.5 14.5L14 13" stroke="#150033" strokeLinecap="round" strokeLinejoin="round" />
</svg> </svg>
{t.common.artifacts} {t.common.artifacts}
</Button> </Button>
@ -575,6 +579,7 @@ export default function ChatPage() {
className={cn( className={cn(
"pointer-events-auto relative w-full max-w-[720px]", "pointer-events-auto relative w-full max-w-[720px]",
showWelcomeStyle && "-translate-y-[calc(50vh-96px)]", showWelcomeStyle && "-translate-y-[calc(50vh-96px)]",
brand === "sxwz" && "-translate-x-[172px]"
)} )}
> >
{!(showWelcomeStyle && thread.isThreadLoading) ? ( {!(showWelcomeStyle && thread.isThreadLoading) ? (
@ -627,7 +632,16 @@ export default function ChatPage() {
</div> </div>
{/* 退出确认对话框 */} {/* 退出确认对话框 */}
<DevDialog open={showExitDialog} onOpenChange={setShowExitDialog}> <DevDialog
open={showExitDialog}
onOpenChange={(open) => {
setShowExitDialog(open);
if (!open) {
setClearMemoryOnExit(false);
setIsConfirmingExit(false);
}
}}
>
<DevDialogContent> <DevDialogContent>
<DevDialogHeader> <DevDialogHeader>
<DevDialogTitle>{t.chatPage.exitDialogTitle}</DevDialogTitle> <DevDialogTitle>{t.chatPage.exitDialogTitle}</DevDialogTitle>
@ -635,11 +649,22 @@ export default function ChatPage() {
<p className="text-muted-foreground text-sm"> <p className="text-muted-foreground text-sm">
{t.chatPage.exitDialogDescription} {t.chatPage.exitDialogDescription}
</p> </p>
<label className="flex cursor-pointer items-center gap-2 text-sm text-ws-fg-primary">
<input
type="checkbox"
className="h-4 w-4 rounded border-ws-divider accent-ws-interactive-primary"
checked={clearMemoryOnExit}
onChange={(e) => setClearMemoryOnExit(e.target.checked)}
disabled={isConfirmingExit}
/>
<span>{t.chatPage.exitDialogClearMemory}</span>
</label>
<DevDialogFooter> <DevDialogFooter>
<Button <Button
className="w-full bg-ws-surface-subtle hover:bg-ws-interactive-primary hover:text-primary-foreground" className="w-full bg-ws-surface-subtle hover:bg-ws-interactive-primary hover:text-primary-foreground"
variant="ghost" variant="ghost"
onClick={() => setShowExitDialog(false)} onClick={() => setShowExitDialog(false)}
disabled={isConfirmingExit}
> >
{t.common.cancel} {t.common.cancel}
</Button> </Button>
@ -647,25 +672,31 @@ export default function ChatPage() {
className="w-full bg-ws-surface-subtle hover:bg-ws-interactive-primary hover:text-primary-foreground" className="w-full bg-ws-surface-subtle hover:bg-ws-interactive-primary hover:text-primary-foreground"
variant="ghost" variant="ghost"
onClick={async () => { onClick={async () => {
// 如果正在生成,先终止再退出 setIsConfirmingExit(true);
if (thread.isLoading) { try {
await handleStop(); if (thread.isLoading) {
await handleStop();
}
await clearThreadMemoryOnExit({
backendBaseURL: getBackendBaseURL(),
threadId: safeThreadId,
shouldClearMemory: clearMemoryOnExit,
});
setShowExitDialog(false);
sendToParent({
type: POST_MESSAGE_TYPES.IS_CHATTING,
isChatting: false,
});
router.replace(`/workspace/chats/new?thread_id=${threadId}`);
} catch {
toast.error(t.threadMemoryPanel.toastDeleteFailed);
} finally {
setIsConfirmingExit(false);
} }
setShowExitDialog(false);
sendToParent({
type: POST_MESSAGE_TYPES.IS_CHATTING,
isChatting: false,
});
// 始终复用 query 中的 thread_id。
const nextQuery = new URLSearchParams();
if (threadId && threadId !== "new") {
nextQuery.set("thread_id", threadId);
}
// /workspace/chats/${threadId}?is_chatting=false
router.replace(
`/workspace/chats/new?thread_id=${threadId}`,
);
}} }}
disabled={isConfirmingExit}
> >
{t.chatPage.exitDialogConfirm} {t.chatPage.exitDialogConfirm}
</Button> </Button>

View File

@ -14,12 +14,25 @@ import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
import { Toaster } from "@/components/ui/sonner"; import { Toaster } from "@/components/ui/sonner";
import { CommandPalette } from "@/components/workspace/command-palette"; import { CommandPalette } from "@/components/workspace/command-palette";
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar"; import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
import { BrandProvider, useBrand } from "@/core/brand/provider";
import { BrandSessionInitializer } from "@/core/brand/provider-client";
import { getLocalSettings, useLocalSettings } from "@/core/settings"; import { getLocalSettings, useLocalSettings } from "@/core/settings";
import { cn } from "@/lib/utils";
const queryClient = new QueryClient(); const queryClient = new QueryClient();
export default function WorkspaceLayout({ export default function WorkspaceLayout({
children, children,
}: Readonly<{ children: React.ReactNode }>) {
return (
<BrandProvider>
<WorkspaceBrandShell>{children}</WorkspaceBrandShell>
</BrandProvider>
);
}
function WorkspaceBrandShell({
children,
}: Readonly<{ children: React.ReactNode }>) { }: Readonly<{ children: React.ReactNode }>) {
const [settings, setSettings] = useLocalSettings(); const [settings, setSettings] = useLocalSettings();
const [open, setOpen] = useState(false); // SSR default: open (matches server render) const [open, setOpen] = useState(false); // SSR default: open (matches server render)
@ -27,6 +40,7 @@ export default function WorkspaceLayout({
const pressedKeysRef = useRef<Set<string>>(new Set()); const pressedKeysRef = useRef<Set<string>>(new Set());
const comboTriggeredRef = useRef(false); const comboTriggeredRef = useRef(false);
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const { rootClassName } = useBrand();
// iframe 技能模式mode=skill时隐藏侧边栏 // iframe 技能模式mode=skill时隐藏侧边栏
const isSkillMode = searchParams.get("mode") === "skill"; const isSkillMode = searchParams.get("mode") === "skill";
@ -110,8 +124,9 @@ export default function WorkspaceLayout({
); );
return ( return (
<QueryClientProvider client={queryClient}> <QueryClientProvider client={queryClient}>
<BrandSessionInitializer />
<SidebarProvider <SidebarProvider
className="h-screen" className={cn("h-screen", rootClassName)}
open={open} open={open}
onOpenChange={handleOpenChange} onOpenChange={handleOpenChange}
> >

View File

@ -1157,7 +1157,11 @@ export const PromptInputSubmit = ({
? !!disabled ? !!disabled
: disabled || !hasContent || isSubmitted; : disabled || !hasContent || isSubmitted;
let Icon = <ArrowUpIcon className="size-4" />; // let Icon = <ArrowUpIcon className="size-4" />;
let Icon = <svg width="12" height="16" viewBox="0 0 12 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.75 14.75V0.75M0.75 5.75L5.75 0.75L10.75 5.75" stroke="white" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round"/>
</svg>;
let text: string = t.inputBox.submit; let text: string = t.inputBox.submit;
@ -1165,11 +1169,13 @@ export const PromptInputSubmit = ({
Icon = <Loader2Icon className="size-4 animate-spin" />; Icon = <Loader2Icon className="size-4 animate-spin" />;
text = t.inputBox.submitting; text = t.inputBox.submitting;
} else if (status === "streaming") { } else if (status === "streaming") {
Icon = <SquareIcon className="size-4" />; Icon = <svg className="!w-[12px] !h-[12px]" width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="12" height="12" rx="2" fill="white"/>
</svg>;
text = t.inputBox.stop; text = t.inputBox.stop;
} else if (status === "error") { } else if (status === "error") {
// 没有报错状态先用error状态代替 // 没有报错状态先用error状态代替
Icon = <XIcon className="size-4" />; // Icon = <XIcon className="size-4" />;
// MARK: 这里后端没有返回错误信息,先写死一个文本 // MARK: 这里后端没有返回错误信息,先写死一个文本
text = t.inputBox.submit; text = t.inputBox.submit;
} }
@ -1180,10 +1186,10 @@ export const PromptInputSubmit = ({
aria-label="Submit" aria-label="Submit"
// 被button{bgc:#fff}覆盖了,只能加"!" // 被button{bgc:#fff}覆盖了,只能加"!"
className={cn( className={cn(
"h-[40px] w-[140px] rounded-[10px] border-0 font-bold transition-all", "h-[36px] w-[36px] rounded-[50%] border-0 font-bold transition-all ",
isDisabled isDisabled
? "cursor-not-allowed !bg-gray-200 text-gray-400" ? "cursor-not-allowed !bg-[#15003399] text-gray-400"
: "!bg-[#F0E8FB] text-[#8E47F0] hover:!bg-[#8E47F0] hover:text-[#FFFFFF]", : "!bg-[#150033] text-[#8E47F0] hover:text-[#FFFFFF]",
className, className,
)} )}
size={size} size={size}
@ -1192,8 +1198,8 @@ export const PromptInputSubmit = ({
disabled={isDisabled} disabled={isDisabled}
{...props} {...props}
> >
{/* {children ?? Icon} */} {children ?? Icon}
{text} {/* {text} */}
</InputGroupButton> </InputGroupButton>
</Tooltip> </Tooltip>
); );

View File

@ -17,8 +17,13 @@ export const Suggestions = ({
children, children,
...props ...props
}: SuggestionsProps) => ( }: SuggestionsProps) => (
<ScrollArea className="overflow-x-auto whitespace-nowrap" {...props}> <ScrollArea className="overflow-x-auto" {...props}>
<div className={cn("flex w-max flex-nowrap items-center gap-2", className)}> <div
className={cn(
"flex w-max flex-nowrap items-center gap-2 whitespace-nowrap",
className,
)}
>
{Children.map(children, (child, index) => {Children.map(children, (child, index) =>
child != null ? ( child != null ? (
<span <span
@ -61,7 +66,7 @@ export const Suggestion = ({
return ( return (
<Button <Button
className={cn( className={cn(
"cursor-pointer rounded-full px-[20px] py-[15px] text-sm font-normal", "cursor-pointer w-[216px] rounded-full px-[20px] py-[15px] text-sm font-normal",
"border-none bg-ws-surface-subtle text-ws-text-muted", "border-none bg-ws-surface-subtle text-ws-text-muted",
"hover:bg-ws-surface-elevated hover:text-ws-base-1", "hover:bg-ws-surface-elevated hover:text-ws-base-1",
className, className,

View File

@ -14,6 +14,7 @@ import {
useState, useState,
type CSSProperties, type CSSProperties,
type ComponentProps, type ComponentProps,
type ComponentPropsWithoutRef,
type HTMLAttributes, type HTMLAttributes,
} from "react"; } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
@ -40,6 +41,7 @@ import { CodeEditor } from "@/components/workspace/code-editor";
import { useArtifactContent } from "@/core/artifacts/hooks"; import { useArtifactContent } from "@/core/artifacts/hooks";
import { resolveArtifactURL, urlOfArtifact } from "@/core/artifacts/utils"; import { resolveArtifactURL, urlOfArtifact } from "@/core/artifacts/utils";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { MarkdownTable } from "@/components/workspace/messages/markdown-content";
import { streamdownPlugins } from "@/core/streamdown"; import { streamdownPlugins } from "@/core/streamdown";
import { checkCodeFile, getFileName } from "@/core/utils/files"; import { checkCodeFile, getFileName } from "@/core/utils/files";
import { useMarkdownDownload } from "@/core/utils/markdown-download"; import { useMarkdownDownload } from "@/core/utils/markdown-download";
@ -909,11 +911,26 @@ export function ArtifactFilePreview({
threadId: string; threadId: string;
filepath?: string; filepath?: string;
}) { }) {
const { t } = useI18n();
const zoomScale = zoom / 100; const zoomScale = zoom / 100;
const normalizedContent = useMemo(() => { const normalizedContent = useMemo(() => {
return rewriteArtifactImagePaths(content ?? "", threadId, filepath); return rewriteArtifactImagePaths(content ?? "", threadId, filepath);
}, [content, threadId, filepath]); }, [content, threadId, filepath]);
const streamdownComponents = useMemo(
() => ({
a: CitationLink,
table: (props: ComponentPropsWithoutRef<"table">) => (
<MarkdownTable
copyLabel={t.clipboard.copyToClipboard}
downloadLabel={t.common.download}
{...props}
/>
),
}),
[t.clipboard.copyToClipboard, t.common.download],
);
if (language === "markdown") { if (language === "markdown") {
return ( return (
<div <div
@ -923,7 +940,7 @@ export function ArtifactFilePreview({
<Streamdown <Streamdown
className="w-full" className="w-full"
{...streamdownPlugins} {...streamdownPlugins}
components={{ a: CitationLink }} components={streamdownComponents}
> >
{normalizedContent} {normalizedContent}
</Streamdown> </Streamdown>

View File

@ -0,0 +1,39 @@
import assert from "node:assert/strict";
import test from "node:test";
const { canSubmitInputBoxMessage } = await import(
new URL("./input-box-submit.ts", import.meta.url).href
);
void test("rejects empty submits without new or existing attachments", () => {
assert.equal(
canSubmitInputBoxMessage({
text: " ",
attachmentCount: 0,
referenceCount: 0,
}),
false,
);
});
void test("allows empty-text submits when new attachments are present", () => {
assert.equal(
canSubmitInputBoxMessage({
text: " ",
attachmentCount: 1,
referenceCount: 0,
}),
true,
);
});
void test("allows empty-text submits when existing references are present", () => {
assert.equal(
canSubmitInputBoxMessage({
text: " ",
attachmentCount: 0,
referenceCount: 1,
}),
true,
);
});

View File

@ -0,0 +1,15 @@
type CanSubmitInputBoxMessageOptions = {
text: string;
attachmentCount: number;
referenceCount: number;
};
export function canSubmitInputBoxMessage({
text,
attachmentCount,
referenceCount,
}: CanSubmitInputBoxMessageOptions) {
return (
text.trim().length > 0 || attachmentCount > 0 || referenceCount > 0
);
}

View File

@ -4,6 +4,7 @@ import type { ChatStatus } from "ai";
import { Tour } from "antd"; import { Tour } from "antd";
import { import {
CheckIcon, CheckIcon,
BrainIcon,
GraduationCapIcon, GraduationCapIcon,
LightbulbIcon, LightbulbIcon,
Loader2Icon, Loader2Icon,
@ -69,6 +70,7 @@ import {
DropdownMenuSeparator, DropdownMenuSeparator,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { Input } from "@/components/ui/input";
import { Tag } from "@/components/ui/tag"; import { Tag } from "@/components/ui/tag";
import { useReferenceFiles } from "@/core/artifacts/references"; import { useReferenceFiles } from "@/core/artifacts/references";
import { urlOfArtifact } from "@/core/artifacts/utils"; import { urlOfArtifact } from "@/core/artifacts/utils";
@ -97,7 +99,9 @@ import { Suggestion, Suggestions } from "../ai-elements/suggestion";
import { ScrollArea } from "../ui/scroll-area"; import { ScrollArea } from "../ui/scroll-area";
import { ModeHoverGuide } from "./mode-hover-guide"; import { ModeHoverGuide } from "./mode-hover-guide";
import { ThreadMemoryPanel } from "./thread-memory-panel";
import { Tooltip } from "./tooltip"; import { Tooltip } from "./tooltip";
import { canSubmitInputBoxMessage } from "./input-box-submit";
const MAX_REFERENCES_PER_MESSAGE = 10; const MAX_REFERENCES_PER_MESSAGE = 10;
@ -148,7 +152,7 @@ function WorkspaceToolButton({
return ( return (
<PromptInputButton <PromptInputButton
className={cn( className={cn(
"group h-full rounded-[10px] p-[10px]! hover:bg-ws-surface-subtle hover:text-ws-interactive-primary", "group h-full rounded-[10px] p-[10px]! hover:bg-ws-interactive-hover hover:text-ws-interactive-primary",
className, className,
)} )}
{...props} {...props}
@ -280,8 +284,10 @@ export function InputBox({
null, null,
); );
const [isFocused, setIsFocused] = useState(false); const [isFocused, setIsFocused] = useState(false);
const [memoryPanelOpen, setMemoryPanelOpen] = useState(false);
const [references, setReferences] = useState<PromptInputReference[]>([]); const [references, setReferences] = useState<PromptInputReference[]>([]);
const [mentionQuery, setMentionQuery] = useState(""); const [mentionQuery, setMentionQuery] = useState("");
const [mentionSearchText, setMentionSearchText] = useState("");
const [mentionOpen, setMentionOpen] = useState(false); const [mentionOpen, setMentionOpen] = useState(false);
const [activeMentionIndex, setActiveMentionIndex] = useState(0); const [activeMentionIndex, setActiveMentionIndex] = useState(0);
const [mentionRange, setMentionRange] = useState<{ const [mentionRange, setMentionRange] = useState<{
@ -290,10 +296,19 @@ export function InputBox({
} | null>(null); } | null>(null);
const [isInputToolsTourOpen, setIsInputToolsTourOpen] = useState(false); const [isInputToolsTourOpen, setIsInputToolsTourOpen] = useState(false);
const [isInputToolsTourReady, setIsInputToolsTourReady] = useState(false); const [isInputToolsTourReady, setIsInputToolsTourReady] = useState(false);
const { data: referenceFilesData } = useReferenceFiles(threadIdFromProps); const { data: referenceFilesData, refetch: refetchReferenceFiles } =
useReferenceFiles(threadIdFromProps);
// 打开附件引用弹窗时刷新数据
useEffect(() => {
if (mentionOpen) {
refetchReferenceFiles();
}
}, [mentionOpen, refetchReferenceFiles]);
// Welcome 态下禁用收缩,始终保持展开 // Welcome 态下禁用收缩,始终保持展开
const effectiveIsFocused = (showWelcomeStyle ?? false) || isFocused; const effectiveIsFocused =
(showWelcomeStyle ?? false) || isFocused || memoryPanelOpen;
const shouldShowSuggestionList = const shouldShowSuggestionList =
showWelcomeStyle && searchParams.get("mode") !== "skill"; showWelcomeStyle && searchParams.get("mode") !== "skill";
@ -473,15 +488,24 @@ export function InputBox({
const filteredMentionCandidates = useMemo(() => { const filteredMentionCandidates = useMemo(() => {
const query = mentionQuery.trim().toLowerCase(); const query = mentionQuery.trim().toLowerCase();
if (!query) { const search = mentionSearchText.trim().toLowerCase();
return mentionCandidates; let result = mentionCandidates;
if (query) {
result = result.filter((candidate) =>
`${candidate.filename} ${candidate.typeLabel} ${candidate.pathTail}`
.toLowerCase()
.includes(query),
);
} }
return mentionCandidates.filter((candidate) => if (search) {
`${candidate.filename} ${candidate.typeLabel} ${candidate.pathTail}` result = result.filter((candidate) =>
.toLowerCase() `${candidate.filename} ${candidate.typeLabel} ${candidate.pathTail}`
.includes(query), .toLowerCase()
); .includes(search),
}, [mentionCandidates, mentionQuery]); );
}
return result;
}, [mentionCandidates, mentionQuery, mentionSearchText]);
const handleModelSelect = useCallback( const handleModelSelect = useCallback(
(model_name: string) => { (model_name: string) => {
onContextChange?.({ onContextChange?.({
@ -507,7 +531,13 @@ export function InputBox({
onStop?.(); onStop?.();
return; return;
} }
if (!message.text && references.length === 0) { if (
!canSubmitInputBoxMessage({
text: message.text,
attachmentCount: message.files?.length ?? 0,
referenceCount: references.length,
})
) {
return; return;
} }
setIsFocused(false); setIsFocused(false);
@ -583,6 +613,7 @@ export function InputBox({
}); });
} }
setMentionQuery(""); setMentionQuery("");
setMentionSearchText("");
setMentionOpen(false); setMentionOpen(false);
setActiveMentionIndex(0); setActiveMentionIndex(0);
setMentionRange(null); setMentionRange(null);
@ -623,6 +654,7 @@ export function InputBox({
if (!token) { if (!token) {
setMentionOpen(false); setMentionOpen(false);
setMentionQuery(""); setMentionQuery("");
setMentionSearchText("");
setActiveMentionIndex(0); setActiveMentionIndex(0);
setMentionRange(null); setMentionRange(null);
return; return;
@ -672,6 +704,7 @@ export function InputBox({
} }
} else if (event.key === "Escape") { } else if (event.key === "Escape") {
event.preventDefault(); event.preventDefault();
setMentionSearchText("");
setMentionOpen(false); setMentionOpen(false);
setMentionRange(null); setMentionRange(null);
} }
@ -850,6 +883,7 @@ export function InputBox({
onOpenChange={(open) => { onOpenChange={(open) => {
setMentionOpen(open); setMentionOpen(open);
if (!open) { if (!open) {
setMentionSearchText("");
setMentionRange(null); setMentionRange(null);
} }
}} }}
@ -877,7 +911,30 @@ export function InputBox({
<DropdownMenuLabel className="p-0 text-sm text-ws-fg-primary"> <DropdownMenuLabel className="p-0 text-sm text-ws-fg-primary">
{t.inputBox.addReference} {t.inputBox.addReference}
</DropdownMenuLabel> </DropdownMenuLabel>
<DropdownMenuSeparator className="mx-0 mt-[20px] mb-0" /> <Input
className="mt-3 h-8 text-sm"
placeholder="搜索文件..."
value={mentionSearchText}
autoFocus
onChange={(e) => {
setMentionSearchText(e.target.value);
setActiveMentionIndex(0);
}}
onKeyDown={(e) => {
if (e.key === "ArrowDown" || e.key === "ArrowUp") {
e.preventDefault();
// 将焦点交还给 dropdown让现有的键盘导航逻辑处理
const items =
document.querySelectorAll<HTMLElement>(
'[data-testid="mention-candidate-item"]',
);
if (items.length > 0) {
(items[0] as HTMLElement).focus();
}
}
}}
/>
<DropdownMenuSeparator className="mx-0 mt-3 mb-0" />
<DropdownMenuGroup className="flex min-h-0 flex-col gap-[10px] px-0"> <DropdownMenuGroup className="flex min-h-0 flex-col gap-[10px] px-0">
<ScrollArea className="h-[320px] pt-[20px]" hideScrollbar={false}> <ScrollArea className="h-[320px] pt-[20px]" hideScrollbar={false}>
{filteredMentionCandidates.map((candidate, index) => { {filteredMentionCandidates.map((candidate, index) => {
@ -965,6 +1022,7 @@ export function InputBox({
/> />
</div> </div>
)} )}
{/* {!showWelcomeStyle && ( {/* {!showWelcomeStyle && (
<div className="shrink-0 h-full"> <div className="shrink-0 h-full">
<ExitChattingButton <ExitChattingButton
@ -976,6 +1034,22 @@ export function InputBox({
<div ref={attachmentsButtonTourRef} className="shrink-0 h-full"> <div ref={attachmentsButtonTourRef} className="shrink-0 h-full">
<AddAttachmentsButton /> <AddAttachmentsButton />
</div> </div>
{/* 记忆按钮 */}
{/* <div className="shrink-0 h-full">
<DropdownMenu open={memoryPanelOpen} onOpenChange={setMemoryPanelOpen}>
<DropdownMenuTrigger asChild>
<WorkspaceToolButton
className="h-full"
disabled={!threadIdFromProps || threadIdFromProps === "new"}
>
<BrainIcon className="size-4" />
</WorkspaceToolButton>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="w-auto p-0">
<ThreadMemoryPanel threadId={threadIdFromProps} />
</DropdownMenuContent>
</DropdownMenu>
</div> */}
<div className="min-w-0 grow basis-0 h-full"> <div className="min-w-0 grow basis-0 h-full">
<IframeSkillDialogButton <IframeSkillDialogButton
skillButtonRef={skillButtonTourRef} skillButtonRef={skillButtonTourRef}
@ -1022,7 +1096,7 @@ export function InputBox({
</ModelSelector> */} </ModelSelector> */}
<PromptInputTools> <PromptInputTools>
{/* 占位符 */} {/* 占位符 */}
<div className="w-[150px] h-[40px]"></div> <div className="w-[36px] h-[36px]"></div>
</PromptInputTools> </PromptInputTools>
</PromptInputFooter> </PromptInputFooter>
<PromptInputSubmit <PromptInputSubmit
@ -1206,7 +1280,7 @@ function SuggestionList({
); );
return ( return (
<Suggestions <Suggestions
className="w-fit items-start" className="mx-auto grid w-fit grid-cols-2 justify-center gap-[16px] whitespace-normal md:grid-cols-3"
data-testid="welcome-suggestions" data-testid="welcome-suggestions"
> >
{promptSuggestions.map((suggestion) => ( {promptSuggestions.map((suggestion) => (

View File

@ -1,7 +1,7 @@
"use client"; "use client";
import { CheckIcon, CopyIcon, DownloadIcon } from "lucide-react"; import { DownloadIcon } from "lucide-react";
import { useCallback, useMemo, useState, type MouseEvent } from "react"; import { useCallback, useLayoutEffect, useMemo, useRef, useState } from "react";
import type { import type {
AnchorHTMLAttributes, AnchorHTMLAttributes,
ComponentPropsWithoutRef, ComponentPropsWithoutRef,
@ -14,7 +14,9 @@ import {
} from "@/components/ai-elements/message"; } from "@/components/ai-elements/message";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { streamdownPlugins } from "@/core/streamdown"; import { streamdownPlugins } from "@/core/streamdown";
import { cn, copyToClipboard } from "@/lib/utils"; import { CopyButton } from "@/components/workspace/copy-button";
import { Tooltip } from "@/components/workspace/tooltip";
import { cn } from "@/lib/utils";
import { CitationLink } from "../citations/citation-link"; import { CitationLink } from "../citations/citation-link";
@ -56,21 +58,9 @@ function toMarkdownTable(data: TableData): string {
return [headerLine, dividerLine, ...rowLines].join("\n"); return [headerLine, dividerLine, ...rowLines].join("\n");
} }
function escapeCsvCell(value: string): string { function downloadMarkdownFile(content: string, filename: string) {
if (!/[",\n\r]/.test(value)) return value;
return `"${value.replaceAll('"', '""')}"`;
}
function toCsvTable(data: TableData): string {
if (data.headers.length === 0) return "";
return [data.headers, ...data.rows]
.map((row) => row.map(escapeCsvCell).join(","))
.join("\n");
}
function downloadCsvFile(content: string, filename: string) {
const blob = new Blob(["\uFEFF", content], { const blob = new Blob(["\uFEFF", content], {
type: "text/csv;charset=utf-8", type: "text/markdown;charset=utf-8",
}); });
const url = URL.createObjectURL(blob); const url = URL.createObjectURL(blob);
const anchor = document.createElement("a"); const anchor = document.createElement("a");
@ -80,58 +70,43 @@ function downloadCsvFile(content: string, filename: string) {
URL.revokeObjectURL(url); URL.revokeObjectURL(url);
} }
function MarkdownTable({ export function MarkdownTable({
className, className,
children, children,
copyLabel, copyLabel: _copyLabel,
downloadLabel, downloadLabel,
...props ...props
}: ComponentPropsWithoutRef<"table"> & { }: ComponentPropsWithoutRef<"table"> & {
copyLabel: string; copyLabel: string;
downloadLabel: string; downloadLabel: string;
}) { }) {
const [copied, setCopied] = useState(false); const tableRef = useRef<HTMLTableElement>(null);
const [, forceUpdate] = useState(0);
const getTableData = useCallback((event: MouseEvent<HTMLButtonElement>) => { // 首次 mount 后 tableRef 才被赋值,用 useLayoutEffect 在 paint 前强制刷新
const wrapper = event.currentTarget.closest( useLayoutEffect(() => {
'[data-streamdown="table-wrapper"]', forceUpdate((n) => n + 1);
);
const table = wrapper?.querySelector("table");
if (!(table instanceof HTMLTableElement)) return null;
return parseTableData(table);
}, []); }, []);
const handleCopy = useCallback( // 在 render 阶段直接从 DOM ref 计算,不依赖 effect 异步更新
async (event: MouseEvent<HTMLButtonElement>) => { // tableRef 在上一次渲染的 commit 阶段已设置,本次渲染可用
const data = getTableData(event); const clipboardData = (() => {
if (!data) return; const table = tableRef.current;
if (!table) return "";
const data = parseTableData(table);
if (!data) return "";
return toMarkdownTable(data);
})();
const markdown = toMarkdownTable(data); const handleDownload = useCallback(() => {
if (!markdown) return; const table = tableRef.current;
if (!table) return;
try { const data = parseTableData(table);
await copyToClipboard(markdown); if (!data) return;
setCopied(true); const markdown = toMarkdownTable(data);
window.setTimeout(() => setCopied(false), 2000); if (!markdown) return;
} catch { downloadMarkdownFile(markdown, "table.md");
// no-op }, []);
}
},
[getTableData],
);
const handleDownload = useCallback(
(event: MouseEvent<HTMLButtonElement>) => {
const data = getTableData(event);
if (!data) return;
const csv = toCsvTable(data);
if (!csv) return;
downloadCsvFile(csv, "table.csv");
},
[getTableData],
);
return ( return (
<div <div
@ -139,25 +114,20 @@ function MarkdownTable({
data-streamdown="table-wrapper" data-streamdown="table-wrapper"
> >
<div className="flex items-center justify-end gap-1"> <div className="flex items-center justify-end gap-1">
<button <CopyButton className="text-muted-foreground hover:bg-transparent hover:text-foreground cursor-pointer p-1 transition-all" clipboardData={clipboardData} />
className="text-muted-foreground hover:text-foreground cursor-pointer p-1 transition-all" <Tooltip content={downloadLabel}>
onClick={handleCopy} <button
title={copyLabel} className="h-[32px] w-[32px] text-muted-foreground hover:text-foreground cursor-pointer p-1 transition-all"
type="button" onClick={handleDownload}
> type="button"
{copied ? <CheckIcon size={14} /> : <CopyIcon size={14} />} >
</button> <DownloadIcon size={16} />
<button </button>
className="text-muted-foreground hover:text-foreground cursor-pointer p-1 transition-all" </Tooltip>
onClick={handleDownload}
title={downloadLabel}
type="button"
>
<DownloadIcon size={14} />
</button>
</div> </div>
<div className="overflow-x-auto"> <div className="overflow-x-auto">
<table <table
ref={tableRef}
className={cn( className={cn(
"border-border w-full border-collapse border", "border-border w-full border-collapse border",
className, className,

View File

@ -53,6 +53,22 @@ function localizeAssistantFixedCopy(content: string, localized: string): string
return content; return content;
} }
function buildClipboardData(message: Message): string {
const raw =
extractContentFromMessage(message) ??
extractReasoningContentFromMessage(message) ??
"";
if (!raw) {
return "";
}
const cleaned = stripUploadedFilesTag(raw);
if (message.type === "human") {
return normalizeHumanMessageDisplayText(stripPriorityHintSuffix(cleaned));
}
return cleaned;
}
export function MessageListItem({ export function MessageListItem({
className, className,
message, message,
@ -90,13 +106,7 @@ export function MessageListItem({
)} )}
> >
<div className="flex gap-1"> <div className="flex gap-1">
<CopyButton <CopyButton clipboardData={buildClipboardData(message)} />
clipboardData={
extractContentFromMessage(message) ??
extractReasoningContentFromMessage(message) ??
""
}
/>
</div> </div>
</MessageToolbar> </MessageToolbar>
)} )}

View File

@ -0,0 +1,138 @@
"use client";
import { useState } from "react";
import { toast } from "sonner";
import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks";
type ThreadMemoryPanelProps = {
threadId?: string;
};
export function ThreadMemoryPanel({ threadId }: ThreadMemoryPanelProps) {
const [memorySummary, setMemorySummary] = useState("");
const [memoryVersion, setMemoryVersion] = useState<number | null>(null);
const [loadingSummary, setLoadingSummary] = useState(false);
const [savingSummary, setSavingSummary] = useState(false);
const [deletingMemory, setDeletingMemory] = useState(false);
const { t } = useI18n();
if (!threadId || threadId === "new") return null;
const handleLoadMemorySummary = async () => {
setLoadingSummary(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`,
);
if (!res.ok) throw new Error(`HTTP ${res.status}`);
const data = (await res.json()) as { summary: string; memoryVersion: number };
setMemorySummary(data.summary ?? "");
setMemoryVersion(data.memoryVersion ?? 0);
toast.success(t.threadMemoryPanel.toastLoadSuccess);
} catch {
toast.error(t.threadMemoryPanel.toastLoadFailed);
} finally {
setLoadingSummary(false);
}
};
const handleSaveMemorySummary = async () => {
if (memoryVersion == null) return;
setSavingSummary(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory-summary`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ summary: memorySummary, memoryVersion }),
},
);
if (res.status === 409) {
toast.error(t.threadMemoryPanel.toastConflict);
return;
}
if (!res.ok) throw new Error(`HTTP ${res.status}`);
const data = (await res.json()) as { memoryVersion?: number };
if (typeof data.memoryVersion === "number") setMemoryVersion(data.memoryVersion);
toast.success(t.threadMemoryPanel.toastSaveSuccess);
} catch {
toast.error(t.threadMemoryPanel.toastSaveFailed);
} finally {
setSavingSummary(false);
}
};
const handleDeleteMemory = async () => {
setDeletingMemory(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/memory`,
{ method: "DELETE" },
);
if (!res.ok) throw new Error(`HTTP ${res.status}`);
setMemorySummary("");
setMemoryVersion(0);
toast.success(t.threadMemoryPanel.toastDeleteSuccess);
} catch {
toast.error(t.threadMemoryPanel.toastDeleteFailed);
} finally {
setDeletingMemory(false);
}
};
return (
<div className="w-[380px] space-y-2 rounded-lg border border-ws-divider bg-ws-surface-elevated p-3 shadow-lg">
<div className="text-sm font-semibold">
<span className="hidden sm:inline">{t.threadMemoryPanel.title}</span>
</div>
<div className="space-y-2">
<div className="flex items-center gap-2">
<Button
size="sm"
variant="outline"
onClick={() => {
void handleLoadMemorySummary();
}}
disabled={loadingSummary}
>
{loadingSummary ? t.threadMemoryPanel.loading : t.threadMemoryPanel.load}
</Button>
<Button
size="sm"
onClick={() => {
void handleSaveMemorySummary();
}}
disabled={savingSummary || memoryVersion == null}
>
{savingSummary ? t.threadMemoryPanel.saving : t.threadMemoryPanel.save}
</Button>
<Button
size="sm"
variant="destructive"
onClick={() => {
void handleDeleteMemory();
}}
disabled={deletingMemory}
>
{deletingMemory ? t.threadMemoryPanel.removing : t.threadMemoryPanel.remove}
</Button>
</div>
<div className="text-xs text-ws-text-subtle-strong">
{t.threadMemoryPanel.threadId}: {threadId.slice(0, 8)}... |{" "}
{t.threadMemoryPanel.version}:{" "}
{memoryVersion == null ? t.threadMemoryPanel.unavailableVersion : memoryVersion}
</div>
<Textarea
value={memorySummary}
onChange={(e) => setMemorySummary(e.target.value)}
placeholder={t.threadMemoryPanel.summaryPlaceholder}
className="min-h-32 bg-white/80"
/>
</div>
</div>
);
}

View File

@ -1,8 +1,10 @@
"use client"; "use client";
import Image from "next/image";
import { useSearchParams } from "next/navigation"; import { useSearchParams } from "next/navigation";
import { useMemo } from "react"; import { useMemo } from "react";
import { useBrand } from "@/core/brand/provider";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
@ -16,6 +18,7 @@ export function Welcome({
mode?: "ultra" | "pro" | "thinking" | "flash"; mode?: "ultra" | "pro" | "thinking" | "flash";
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
const { copy } = useBrand();
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const isUltra = useMemo(() => mode === "ultra", [mode]); const isUltra = useMemo(() => mode === "ultra", [mode]);
const colors = useMemo(() => { const colors = useMemo(() => {
@ -39,12 +42,37 @@ export function Welcome({
className="flex items-center gap-2" className="flex items-center gap-2"
style={{ fontFamily: '"Microsoft YaHei"' }} style={{ fontFamily: '"Microsoft YaHei"' }}
> >
<AuroraText {/* <AuroraText
className="text-center text-[18px] leading-normal font-normal" className="text-center text-[18px] leading-normal font-normal"
colors={colors} colors={colors}
> >
{t.welcome.greeting} {copy.productLabel}
</AuroraText> </AuroraText> */}
<span className="text-[18px] font-normal text-foreground/70">
{copy.productLabel}
</span>
<span className="text-[18px] font-normal text-foreground/70">
·
</span>
{copy.appLogoSrc ? (
<Image
src={copy.appLogoSrc}
alt={copy.appLogoAlt ?? copy.appName}
width={104}
height={16}
draggable={false}
// className="h-[16px] w-auto"
priority
/>
) : (
<AuroraText
className="text-center text-[18px] leading-normal font-normal"
colors={colors}
>
{copy.appName}
</AuroraText>
)}
</div> </div>
)} )}
</div> </div>
@ -59,7 +87,8 @@ export function Welcome({
)} )}
</div> </div>
) : ( ) : (
<div> </div> // <div> </div>
<></>
)} )}
</div> </div>
); );

View File

@ -1,6 +1,7 @@
"use client"; "use client";
import { MessageSquarePlus } from "lucide-react"; import { MessageSquarePlus } from "lucide-react";
import Image from "next/image";
import Link from "next/link"; import Link from "next/link";
import { usePathname } from "next/navigation"; import { usePathname } from "next/navigation";
import { toast } from "sonner"; import { toast } from "sonner";
@ -13,6 +14,7 @@ import {
useSidebar, useSidebar,
} from "@/components/ui/sidebar"; } from "@/components/ui/sidebar";
import { useThreadChat } from "@/components/workspace/chats"; import { useThreadChat } from "@/components/workspace/chats";
import { useBrand } from "@/core/brand/provider";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages"; import { POST_MESSAGE_TYPES, sendToParent } from "@/core/iframe-messages";
import { env } from "@/env"; import { env } from "@/env";
@ -21,10 +23,14 @@ import { copyToClipboard } from "@/lib/utils";
export function WorkspaceHeader({ className }: { className?: string }) { export function WorkspaceHeader({ className }: { className?: string }) {
const { t } = useI18n(); const { t } = useI18n();
const { copy } = useBrand();
const { state } = useSidebar(); const { state } = useSidebar();
const pathname = usePathname(); const pathname = usePathname();
const { threadId } = useThreadChat(); const { threadId } = useThreadChat();
const threadUrl = threadId ? `/workspace/chats/${threadId}` : ""; const threadUrl = threadId ? `/workspace/chats/${threadId}` : "";
const compactAppName = copy.appName.startsWith("cox")
? `C${copy.appName.charAt(3).toUpperCase()}`
: copy.appName.slice(0, 2).toUpperCase();
const handleCopyThreadId = async () => { const handleCopyThreadId = async () => {
if (!threadId) return; if (!threadId) return;
@ -51,7 +57,7 @@ export function WorkspaceHeader({ className }: { className?: string }) {
{state === "collapsed" ? ( {state === "collapsed" ? (
<div className="group-has-data-[collapsible=icon]/sidebar-wrapper:-translate-y flex w-full cursor-pointer items-center justify-center"> <div className="group-has-data-[collapsible=icon]/sidebar-wrapper:-translate-y flex w-full cursor-pointer items-center justify-center">
<div className="text-primary block pt-1 font-serif group-hover/workspace-header:hidden"> <div className="text-primary block pt-1 font-serif group-hover/workspace-header:hidden">
XC {compactAppName}
</div> </div>
<SidebarTrigger className="hidden pl-2 group-hover/workspace-header:block" /> <SidebarTrigger className="hidden pl-2 group-hover/workspace-header:block" />
</div> </div>
@ -62,10 +68,20 @@ export function WorkspaceHeader({ className }: { className?: string }) {
{t.workspaceHeader.sidebarTitle} {t.workspaceHeader.sidebarTitle}
</Link> </Link>
) : ( ) : (
<div className="text-primary ml-2 cursor-default font-serif"> <div className="text-primary ml-2 flex cursor-default items-center gap-2 font-serif">
{/* TODO: 测试标识 */} {copy.appLogoSrc ? (
XClaw{" "} <Image
<span className="text-sm text-ws-text-subtle-strong">v3.2.9 </span>{" "} src={copy.appLogoSrc}
alt={copy.appLogoAlt ?? copy.appName}
width={104}
height={16}
className="h-4 w-auto"
priority
/>
) : (
copy.appName
)}
<span className="text-sm text-ws-text-subtle-strong">v3.3.0 </span>{" "}
<span <span
className={cn( className={cn(
"text-xs font-mono", "text-xs font-mono",
@ -80,7 +96,6 @@ export function WorkspaceHeader({ className }: { className?: string }) {
> >
id:{threadId ? threadId.slice(0, 5) : "-"} id:{threadId ? threadId.slice(0, 5) : "-"}
</span> </span>
{" "}
{threadId && ( {threadId && (
<a <a
href={threadUrl} href={threadUrl}

View File

@ -0,0 +1,39 @@
import assert from "node:assert/strict";
import test from "node:test";
const {
BRAND_SESSION_STORAGE_KEY,
getBrandRootClassName,
parseBrandFromSearchParams,
resolveBrandSession,
} = await import(new URL("./index.ts", import.meta.url).href);
void test("parseBrandFromSearchParams returns correct brand per param value", () => {
assert.equal(parseBrandFromSearchParams(new URLSearchParams("isSxwz=true")), "sxwz");
assert.equal(parseBrandFromSearchParams(new URLSearchParams("isSxwz=false")), "default");
assert.equal(parseBrandFromSearchParams(new URLSearchParams("")), null);
});
void test("resolveBrandSession falls back to default without url or storage", () => {
assert.equal(resolveBrandSession({ urlBrand: null, storedBrand: null }), "default");
});
void test("resolveBrandSession keeps stored sxwz when later url omits the flag", () => {
assert.equal(resolveBrandSession({ urlBrand: null, storedBrand: "sxwz" }), "sxwz");
});
void test("resolveBrandSession downgrades stored sxwz when url explicitly sets isSxwz=false", () => {
const urlBrand = parseBrandFromSearchParams(new URLSearchParams("isSxwz=false"));
assert.equal(resolveBrandSession({ urlBrand, storedBrand: "sxwz" }), "default");
});
void test("resolveBrandSession upgrades to sxwz when url flag is true", () => {
const urlBrand = parseBrandFromSearchParams(new URLSearchParams("isSxwz=true"));
assert.equal(resolveBrandSession({ urlBrand, storedBrand: "default" }), "sxwz");
});
void test("getBrandRootClassName returns stable workspace hook classes", () => {
assert.equal(getBrandRootClassName("default"), "brand-default");
assert.equal(getBrandRootClassName("sxwz"), "brand-sxwz");
assert.equal(BRAND_SESSION_STORAGE_KEY, "deerflow.brand-session");
});

View File

@ -0,0 +1,76 @@
export const BRAND_SESSION_STORAGE_KEY = "deerflow.brand-session";
export const DEFAULT_BRAND = "default" as const;
const SXWZ_BRAND = "sxwz" as const;
export type Brand = typeof DEFAULT_BRAND | typeof SXWZ_BRAND;
export type BrandCopy = {
productLabel: string;
appName: string;
appLogoSrc?: string;
appLogoAlt?: string;
};
export const BRAND_COPY: Record<Brand, BrandCopy> = {
default: {
productLabel: "轻办公",
appName: "coxworker",
appLogoSrc: "/coxwork.png",
appLogoAlt: "coxworker",
},
sxwz: {
productLabel: "在线教育智能体",
appName: "coxstudy",
},
};
export function isBrand(value: string | null): value is Brand {
return value === DEFAULT_BRAND || value === SXWZ_BRAND;
}
export function parseBrandFromSearchParams(
searchParams: URLSearchParams,
): Brand | null {
const value = searchParams.get("isSxwz");
if (value === "true") return SXWZ_BRAND;
if (value === "false") return DEFAULT_BRAND;
return null;
}
export function resolveBrandSession({
urlBrand,
storedBrand,
}: {
urlBrand: Brand | null;
storedBrand: Brand | null;
}): Brand {
if (urlBrand === SXWZ_BRAND) {
return SXWZ_BRAND;
}
if (urlBrand === DEFAULT_BRAND) {
return DEFAULT_BRAND;
}
if (storedBrand === SXWZ_BRAND) {
return SXWZ_BRAND;
}
return DEFAULT_BRAND;
}
export function getBrandRootClassName(brand: Brand): string {
return brand === SXWZ_BRAND ? "brand-sxwz" : "brand-default";
}
export function readStoredBrand(storage: Pick<Storage, "getItem">): Brand | null {
const value = storage.getItem(BRAND_SESSION_STORAGE_KEY);
return isBrand(value) ? value : null;
}
export function writeStoredBrand(
storage: Pick<Storage, "setItem">,
brand: Brand,
): void {
storage.setItem(BRAND_SESSION_STORAGE_KEY, brand);
}

View File

@ -0,0 +1,41 @@
"use client";
import { useSearchParams } from "next/navigation";
import { useEffect, useRef } from "react";
import { useBrand } from "./provider";
import {
parseBrandFromSearchParams,
readStoredBrand,
resolveBrandSession,
writeStoredBrand,
type Brand,
} from "./index";
export function BrandSessionInitializer() {
const searchParams = useSearchParams();
const { brand, setBrand } = useBrand();
const prevBrandRef = useRef<Brand>(brand);
// 在 render 阶段同步计算,确保每次 searchParams 变化都能感知
const storedBrand =
typeof window !== "undefined"
? readStoredBrand(window.sessionStorage)
: null;
const urlBrand = parseBrandFromSearchParams(
new URLSearchParams(searchParams.toString()),
);
const resolvedBrand = resolveBrandSession({ urlBrand, storedBrand });
// 只在品牌确实变化时才写入 sessionStorage 并更新 context
useEffect(() => {
if (resolvedBrand !== prevBrandRef.current) {
prevBrandRef.current = resolvedBrand;
writeStoredBrand(window.sessionStorage, resolvedBrand);
setBrand(resolvedBrand);
}
}, [resolvedBrand, setBrand]);
return null;
}

View File

@ -0,0 +1,55 @@
"use client";
import { createContext, useContext, useState, type ReactNode } from "react";
import {
BRAND_COPY,
DEFAULT_BRAND,
getBrandRootClassName,
type Brand,
} from "./index";
type BrandContextValue = {
brand: Brand;
copy: (typeof BRAND_COPY)[Brand];
rootClassName: string;
setBrand: (brand: Brand) => void;
};
const BrandContext = createContext<BrandContextValue | null>(null);
function getInitialBrand(): Brand {
if (typeof window === "undefined") {
return DEFAULT_BRAND;
}
const storedBrand = window.sessionStorage.getItem("deerflow.brand-session");
return storedBrand === "sxwz" ? "sxwz" : DEFAULT_BRAND;
}
export function BrandProvider({ children }: { children: ReactNode }) {
const [brand, setBrand] = useState<Brand>(getInitialBrand);
return (
<BrandContext.Provider
value={{
brand,
copy: BRAND_COPY[brand],
rootClassName: getBrandRootClassName(brand),
setBrand,
}}
>
{children}
</BrandContext.Provider>
);
}
export function useBrand() {
const context = useContext(BrandContext);
if (!context) {
throw new Error("useBrand must be used within BrandProvider");
}
return context;
}

View File

@ -264,6 +264,27 @@ export const enUS: Translations = {
scrollToBottom: "Scroll to bottom", scrollToBottom: "Scroll to bottom",
}, },
threadMemoryPanel: {
title: "Thread Memory",
load: "Load memory",
loading: "Loading...",
save: "Save memory",
saving: "Saving...",
remove: "Delete memory",
removing: "Deleting...",
threadId: "Thread ID",
version: "Version",
unavailableVersion: "-",
summaryPlaceholder: "Thread memory summary is shown here. Edit it and save.",
toastLoadSuccess: "Thread memory loaded",
toastLoadFailed: "Failed to load thread memory",
toastConflict: "Memory changed. Please reload before saving.",
toastSaveSuccess: "Thread memory saved",
toastSaveFailed: "Failed to save thread memory",
toastDeleteSuccess: "Thread memory deleted",
toastDeleteFailed: "Failed to delete thread memory",
},
// Workspace Chat Page // Workspace Chat Page
chatPage: { chatPage: {
defaultSlogan: "Let's study and work together", defaultSlogan: "Let's study and work together",
@ -277,6 +298,7 @@ export const enUS: Translations = {
exitDialogTitle: "Notice", exitDialogTitle: "Notice",
exitDialogDescription: exitDialogDescription:
"Chat history is automatically deleted every seven days. You will return to the welcome page now. Continue?", "Chat history is automatically deleted every seven days. You will return to the welcome page now. Continue?",
exitDialogClearMemory: "Also clear memory for this thread",
exitDialogConfirm: "Confirm", exitDialogConfirm: "Confirm",
selectedSkillLoadFailed: "Failed to load skill", selectedSkillLoadFailed: "Failed to load skill",
unknownErrorRetry: "An unknown error occurred. Please try again later.", unknownErrorRetry: "An unknown error occurred. Please try again later.",

View File

@ -194,6 +194,27 @@ export interface Translations {
scrollToBottom: string; scrollToBottom: string;
}; };
threadMemoryPanel: {
title: string;
load: string;
loading: string;
save: string;
saving: string;
remove: string;
removing: string;
threadId: string;
version: string;
unavailableVersion: string;
summaryPlaceholder: string;
toastLoadSuccess: string;
toastLoadFailed: string;
toastConflict: string;
toastSaveSuccess: string;
toastSaveFailed: string;
toastDeleteSuccess: string;
toastDeleteFailed: string;
};
// Workspace Chat Page // Workspace Chat Page
chatPage: { chatPage: {
defaultSlogan: string; defaultSlogan: string;
@ -206,6 +227,7 @@ export interface Translations {
noArtifactSelectedDescription: string; noArtifactSelectedDescription: string;
exitDialogTitle: string; exitDialogTitle: string;
exitDialogDescription: string; exitDialogDescription: string;
exitDialogClearMemory: string;
exitDialogConfirm: string; exitDialogConfirm: string;
selectedSkillLoadFailed: string; selectedSkillLoadFailed: string;
unknownErrorRetry: string; unknownErrorRetry: string;

View File

@ -130,32 +130,69 @@ export const zhCN: Translations = {
prompt: prompt:
"为[主题/产品]撰写吸引人的自媒体文案,包括标题、正文和话题标签。", "为[主题/产品]撰写吸引人的自媒体文案,包括标题、正文和话题标签。",
icon: PenLineIcon, icon: PenLineIcon,
children: [{ id: "6057", name: "生辰解语" }], children: [{ id: "6057", name: "八字命理" }],
},
{
suggestion: "张雪峰・升学就业心智",
prompt: "编写[项目/功能]的需求文档,包含功能描述、用户故事和验收标准。",
icon: CompassIcon,
children: [{ id: "6094", name: "张雪峰・升学就业心智" }],
},
{
suggestion: "塔罗牌",
prompt: "编写[产品/功能]的使用指南,包含操作步骤、注意事项和常见问题。",
icon: GraduationCapIcon,
children: [{ id: "6109", name: "塔罗牌" }],
}, },
{ {
suggestion: "GPT-Image-2", suggestion: "GPT-Image-2",
prompt: "编写[项目/功能]的需求文档,包含功能描述、用户故事和验收标准。", prompt: "对[Excel文件/数据]进行分析,生成数据洞察和可视化建议。",
icon: CompassIcon, icon: MicroscopeIcon,
children: [{ id: "6130", name: "GPT-Image-2" }], children: [{ id: "6130", name: "GPT-Image-2" }],
}, },
{ {
suggestion: "音乐生成", suggestion: "音乐生成",
prompt: "编写[产品/功能]的使用指南,包含操作步骤、注意事项和常见问题。", prompt: "对[word文件/数据]进行分析",
icon: GraduationCapIcon,
children: [{ id: "6133", name: "音乐生成器" }],
},
{
suggestion: "excel数据处理",
prompt: "对[Excel文件/数据]进行分析,生成数据洞察和可视化建议。",
icon: MicroscopeIcon, icon: MicroscopeIcon,
children: [{ id: "17", name: "Excel处理" }], children: [{ id: "6133", name: "音乐生成" }],
}, },
{ {
suggestion: "微信文章撰写", suggestion: "微信公众号攥写",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。", prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon, icon: ShapesIcon,
children: [{ id: "6134", name: "微信文章撰写" }], children: [{ id: "6134", name: "微信公众号攥写" }],
}, },
{
suggestion: "word填表神器",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon,
children: [{ id: "6195", name: "word填表神器" }],
},
{
suggestion: "excel填表神器",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon,
children: [{ id: "17", name: "excel填表神器" }],
},
{
suggestion: "精美ppt生成",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon,
children: [{ id: "6129 ", name: "精美ppt生成" }],
},
{
suggestion: "论文撰写",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon,
children: [{ id: "6250", name: "论文撰写" }],
},
{
suggestion: "专利交底书编写",
prompt: "针对[行业/产品]进行市场调研,分析市场规模、竞品和趋势。",
icon: ShapesIcon,
children: [{ id: "6000", name: "专利交底书编写" }],
},
], ],
suggestionsCreate: [ suggestionsCreate: [
{ {
@ -252,6 +289,27 @@ export const zhCN: Translations = {
scrollToBottom: "滚动到底部", scrollToBottom: "滚动到底部",
}, },
threadMemoryPanel: {
title: "会话记忆",
load: "查看记忆",
loading: "加载中...",
save: "保存记忆",
saving: "保存中...",
remove: "删除记忆",
removing: "删除中...",
threadId: "threadId",
version: "版本",
unavailableVersion: "-",
summaryPlaceholder: "这里显示会话记忆总结,可编辑后保存",
toastLoadSuccess: "已加载会话记忆",
toastLoadFailed: "加载会话记忆失败",
toastConflict: "记忆已更新,请先重新加载再保存",
toastSaveSuccess: "会话记忆已保存",
toastSaveFailed: "保存会话记忆失败",
toastDeleteSuccess: "当前会话记忆已删除",
toastDeleteFailed: "删除会话记忆失败",
},
// Workspace Chat Page // Workspace Chat Page
chatPage: { chatPage: {
defaultSlogan: "来,一起学习工作吧", defaultSlogan: "来,一起学习工作吧",
@ -265,6 +323,7 @@ export const zhCN: Translations = {
exitDialogTitle: "提示", exitDialogTitle: "提示",
exitDialogDescription: exitDialogDescription:
"每七天自动删除。现在将返回欢迎页且清空聊天消息,是否继续?", "每七天自动删除。现在将返回欢迎页且清空聊天消息,是否继续?",
exitDialogClearMemory: "同时清除当前会话的记忆",
exitDialogConfirm: "确定", exitDialogConfirm: "确定",
selectedSkillLoadFailed: "技能加载失败", selectedSkillLoadFailed: "技能加载失败",
unknownErrorRetry: "发生了未知错误,请稍后重试。", unknownErrorRetry: "发生了未知错误,请稍后重试。",

View File

@ -32,32 +32,16 @@ export interface LocalSettings {
}; };
} }
function clearLocalSettingsStorage() {
localStorage.removeItem(LOCAL_SETTINGS_KEY);
}
export function getLocalSettings(): LocalSettings { export function getLocalSettings(): LocalSettings {
if (typeof window === "undefined") { if (typeof window === "undefined") {
return DEFAULT_LOCAL_SETTINGS; return DEFAULT_LOCAL_SETTINGS;
} }
const json = localStorage.getItem(LOCAL_SETTINGS_KEY);
try { clearLocalSettingsStorage();
if (json) {
const settings = JSON.parse(json);
const mergedSettings = {
...DEFAULT_LOCAL_SETTINGS,
context: {
...DEFAULT_LOCAL_SETTINGS.context,
...settings.context,
},
layout: {
...DEFAULT_LOCAL_SETTINGS.layout,
...settings.layout,
},
notification: {
...DEFAULT_LOCAL_SETTINGS.notification,
...settings.notification,
},
};
return mergedSettings;
}
} catch {}
return DEFAULT_LOCAL_SETTINGS; return DEFAULT_LOCAL_SETTINGS;
} }
@ -65,6 +49,6 @@ export function saveLocalSettings(settings: LocalSettings) {
void settings; void settings;
// 注释了,因为本地存储会污染模型配置 // 注释了,因为本地存储会污染模型配置
console.log("localStorage设置已经注释"); console.log("localStorage设置已经注释");
localStorage.removeItem(LOCAL_SETTINGS_KEY); clearLocalSettingsStorage();
// localStorage.setItem(LOCAL_SETTINGS_KEY, JSON.stringify(settings)); // localStorage.setItem(LOCAL_SETTINGS_KEY, JSON.stringify(settings));
} }

View File

@ -0,0 +1,45 @@
import assert from "node:assert/strict";
import test from "node:test";
const { clearThreadMemoryOnExit } = await import(
new URL("./exit-thread-memory.ts", import.meta.url).href
);
void test("clears thread memory when checkbox is enabled", async () => {
const calls: Array<{ input: RequestInfo | URL; init?: RequestInit }> = [];
globalThis.fetch = (async (input, init) => {
calls.push({ input, init });
return new Response(null, { status: 204 });
}) as typeof fetch;
await clearThreadMemoryOnExit({
backendBaseURL: "http://localhost:3000",
threadId: "thread-123",
shouldClearMemory: true,
});
assert.equal(calls.length, 1);
assert.equal(
calls[0]?.input,
"http://localhost:3000/api/threads/thread-123/memory",
);
assert.equal(calls[0]?.init?.method, "DELETE");
});
void test("skips clearing thread memory when checkbox is disabled", async () => {
let called = false;
globalThis.fetch = (async () => {
called = true;
return new Response(null, { status: 204 });
}) as typeof fetch;
await clearThreadMemoryOnExit({
backendBaseURL: "http://localhost:3000",
threadId: "thread-123",
shouldClearMemory: false,
});
assert.equal(called, false);
});

View File

@ -0,0 +1,24 @@
type ClearThreadMemoryOnExitParams = {
backendBaseURL?: string;
threadId?: string;
shouldClearMemory: boolean;
};
export async function clearThreadMemoryOnExit({
backendBaseURL = "",
threadId,
shouldClearMemory,
}: ClearThreadMemoryOnExitParams) {
if (!threadId || !shouldClearMemory) {
return;
}
const res = await fetch(
`${backendBaseURL}/api/threads/${encodeURIComponent(threadId)}/memory`,
{ method: "DELETE" },
);
if (!res.ok) {
throw new Error(`Failed to clear thread memory: HTTP ${res.status}`);
}
}

View File

@ -18,29 +18,27 @@ export const externalLinkClassNoUnderline = "text-primary hover:underline";
* In iframe context, sends message to parent window to handle clipboard operation. * In iframe context, sends message to parent window to handle clipboard operation.
*/ */
export async function copyToClipboard(text: string): Promise<void> { export async function copyToClipboard(text: string): Promise<void> {
const isInIframe = window.self !== window.top;
const message = { const message = {
type: POST_MESSAGE_TYPES.COPY_TO_CLIPBOARD, type: POST_MESSAGE_TYPES.COPY_TO_CLIPBOARD,
text, text,
} as const; } as const;
if (isInIframe) { console.log("[copyToClipboard] called, text length:", text.length);
try {
// Request parent window to copy // 始终发送 postMessage由 sendToParent 内部判断是否为 iframe 环境
sendToParent(message); // 与 openSkillDialog 等其他 iframe 通信保持一致
console.log( try {
"[copyToClipboard] iframe mode → postMessage to parent", sendToParent(message);
message, } catch {
); // no-op
return;
} catch (error) {
console.warn("[copyToClipboard] iframe postMessage failed", error);
}
} }
// Direct clipboard access when not in iframe // 同时也尝试直接写剪贴板(非 iframe 场景兜底)
console.log("[copyToClipboard] direct mode", message); try {
await navigator.clipboard.writeText(text); await navigator.clipboard.writeText(text);
} catch {
// no-op: 在 iframe 环境下由父窗口处理
}
} }
/** /**

View File

@ -68,6 +68,12 @@
@source inline("bg-{background,muted,primary,secondary,accent}"); @source inline("bg-{background,muted,primary,secondary,accent}");
@source inline("border-{border,input}"); @source inline("border-{border,input}");
.brand-default {
}
.brand-sxwz {
}
@custom-variant dark (&:is(.dark *)); @custom-variant dark (&:is(.dark *));
@theme { @theme {
@ -206,6 +212,7 @@
--color-ws-surface-subtle: var(--ws-color-surface-subtle); --color-ws-surface-subtle: var(--ws-color-surface-subtle);
--color-ws-surface-elevated: var(--ws-color-surface-elevated); --color-ws-surface-elevated: var(--ws-color-surface-elevated);
--color-ws-interactive-primary: var(--ws-color-interactive-primary); --color-ws-interactive-primary: var(--ws-color-interactive-primary);
--color-ws-interactive-hover: var(--ws-color-interactive-hover);
--color-ws-line-default: var(--ws-color-line-default); --color-ws-line-default: var(--ws-color-line-default);
--color-ws-text-muted: var(--ws-color-text-muted); --color-ws-text-muted: var(--ws-color-text-muted);
--color-ws-icon-muted: var(--ws-color-icon-muted); --color-ws-icon-muted: var(--ws-color-icon-muted);
@ -311,7 +318,8 @@
--ws-color-fg-primary: #333333; --ws-color-fg-primary: #333333;
--ws-color-surface-subtle: #f9f8fa; --ws-color-surface-subtle: #f9f8fa;
--ws-color-surface-elevated: #fbfafc; --ws-color-surface-elevated: #fbfafc;
--ws-color-interactive-primary: #8e47f0; --ws-color-interactive-hover: #1500331A;
--ws-color-interactive-primary: #150033;
--ws-color-line-default: #e4e7ec; --ws-color-line-default: #e4e7ec;
--ws-color-text-muted: #667085; --ws-color-text-muted: #667085;
--ws-color-icon-muted: #a3a1a1; --ws-color-icon-muted: #a3a1a1;
@ -364,7 +372,7 @@
--ws-color-fg-primary: #f5f5f5; --ws-color-fg-primary: #f5f5f5;
--ws-color-surface-subtle: #1f1f1f; --ws-color-surface-subtle: #1f1f1f;
--ws-color-surface-elevated: #24222a; --ws-color-surface-elevated: #24222a;
--ws-color-interactive-primary: #b987ff; --ws-color-interactive-primary: #150033;
--ws-color-line-default: #3b3f48; --ws-color-line-default: #3b3f48;
--ws-color-text-muted: #98a2b3; --ws-color-text-muted: #98a2b3;
--ws-color-icon-muted: #d0d0d0; --ws-color-icon-muted: #d0d0d0;

View File

@ -0,0 +1,43 @@
import { expect, test } from "@playwright/test";
import { invalidNewChatUrl } from "./support/chat-helpers";
const LOCAL_SETTINGS_KEY = "deerflow.local-settings";
test.describe("本地设置清理", () => {
test("禁用持久化后会在进入工作台时清除历史 localStorage 设置", async ({
page,
}) => {
await page.addInitScript(
({ key, value }: { key: string; value: string }) => {
window.localStorage.setItem(key, value);
},
{
key: LOCAL_SETTINGS_KEY,
value: JSON.stringify({
context: {
model_name: "gpt-5",
mode: "pro",
reasoning_effort: "high",
},
layout: {
sidebar_collapsed: true,
},
notification: {
enabled: false,
},
}),
},
);
await page.goto(invalidNewChatUrl());
await expect(page.locator("textarea[name='message']")).toBeVisible();
await expect
.poll(
() => page.evaluate((key) => window.localStorage.getItem(key), LOCAL_SETTINGS_KEY),
{ message: "expected deprecated local settings storage to be cleared" },
)
.toBeNull();
});
});

16
scripts/git/install-hooks.sh Executable file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$REPO_ROOT"
if ! git rev-parse --is-inside-work-tree >/dev/null 2>&1; then
echo "✗ Not inside a Git repository."
exit 1
fi
git config --local core.hooksPath .githooks
echo "✓ Git hooks installed"
echo " core.hooksPath = .githooks"
echo " pre-push will rebase the current branch before push"

42
scripts/git/pre-push-rebase.sh Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env bash
set -euo pipefail
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$REPO_ROOT"
current_branch="$(git symbolic-ref -q --short HEAD || true)"
if [ -z "$current_branch" ]; then
echo "✗ pre-push rebase: detached HEAD detected."
echo " Checkout a branch before pushing."
exit 1
fi
if ! git rev-parse --verify --quiet HEAD >/dev/null; then
echo "✗ pre-push rebase: current branch has no commits yet."
exit 1
fi
if [ -n "$(git status --porcelain)" ]; then
echo "✗ pre-push rebase: working tree is not clean."
echo " Commit or stash changes before pushing."
exit 1
fi
if [ -d .git/rebase-merge ] || [ -d .git/rebase-apply ]; then
echo "✗ pre-push rebase: a rebase is already in progress."
exit 1
fi
echo "Fetching origin/git-main..."
git fetch origin git-main
if ! git rev-parse --verify --quiet origin/git-main >/dev/null; then
echo "✗ pre-push rebase: origin/git-main is not available locally."
echo " Fetch origin first, then push again."
exit 1
fi
echo "Rebasing '$current_branch' onto origin/git-main before push..."
git rebase origin/git-main
echo "✓ Rebase completed for '$current_branch'"