Compare commits

..

No commits in common. "6b900ccb6067d2b791677faba6c946203b091ec3" and "0ffe5a73c1440f2d61d04bc3a16529942d62300e" have entirely different histories.

23 changed files with 53 additions and 1320 deletions

View File

@ -30,7 +30,7 @@ class SlackChannel(Channel):
self._socket_client = None self._socket_client = None
self._web_client = None self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])} self._allowed_users: set[str] = set(config.get("allowed_users", []))
async def start(self) -> None: async def start(self) -> None:
if self._running: if self._running:

View File

@ -23,11 +23,9 @@ from app.gateway.routers import (
) )
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
# Configure logging with env override # Configure logging
import os
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig( logging.basicConfig(
level=getattr(logging, log_level, logging.INFO), level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )

View File

@ -9,10 +9,6 @@ class GatewayConfig(BaseModel):
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server") host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server") port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins") cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
skill_content_api_url: str = Field(
default="https://skills.xueai.art/api/cmsContent/getContent",
description="Remote API URL used to fetch skill YAML content by content ID",
)
_gateway_config: GatewayConfig | None = None _gateway_config: GatewayConfig | None = None
@ -27,9 +23,5 @@ def get_gateway_config() -> GatewayConfig:
host=os.getenv("GATEWAY_HOST", "0.0.0.0"), host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")), port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","), cors_origins=cors_origins_str.split(","),
skill_content_api_url=os.getenv(
"SKILL_CONTENT_API_URL",
"https://skills.xueai.art/api/cmsContent/getContent",
),
) )
return _gateway_config return _gateway_config

View File

@ -1,15 +1,11 @@
import json import json
import logging import logging
import shutil
from pathlib import Path from pathlib import Path
import httpx
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.config import get_gateway_config
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
from app.gateway.skill_yaml_importer import materialize_skill_tree, parse_skill_yaml_spec
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
@ -56,38 +52,6 @@ class SkillInstallResponse(BaseModel):
message: str = Field(..., description="Installation result message") message: str = Field(..., description="Installation result message")
class RemoteSkillBootstrapRequest(BaseModel):
"""Request model for bootstrapping skill files from remote content API."""
thread_id: str = Field(..., description="Thread ID used for user-data path binding")
content_ids: list[int] = Field(
...,
min_length=1,
description="Remote content ID sequence (maps from frontend query param skill_id)",
)
language_type: int = Field(default=0, description="Language type for remote API request body")
target_dir: str = Field(
default="/mnt/user-data/uploads/skill",
description="Virtual base directory where each skill-{id} subdirectory is created",
)
clear_target: bool = Field(
default=True,
description="Whether to clear target directory before writing parsed files",
)
class RemoteSkillBootstrapResponse(BaseModel):
"""Response model for remote bootstrap endpoint."""
success: bool = Field(..., description="Whether bootstrap succeeded")
target_dir: str = Field(..., description="Virtual base target directory")
content_ids: list[int] = Field(..., description="Bootstrapped content IDs")
created_directories: int = Field(..., description="Number of created directories")
created_files: int = Field(..., description="Number of created files")
sandbox_id: str | None = Field(default=None, description="Acquired sandbox ID (null when sandbox is not acquired)")
message: str = Field(..., description="Operation result message")
def _skill_to_response(skill: Skill) -> SkillResponse: def _skill_to_response(skill: Skill) -> SkillResponse:
"""Convert a Skill object to a SkillResponse.""" """Convert a Skill object to a SkillResponse."""
return SkillResponse( return SkillResponse(
@ -207,107 +171,3 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
except Exception as e: except Exception as e:
logger.error(f"Failed to install skill: {e}", exc_info=True) logger.error(f"Failed to install skill: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
@router.post(
"/skills/bootstrap-remote",
response_model=RemoteSkillBootstrapResponse,
summary="Bootstrap Skill Files From Remote API",
description=(
"Fetch YAML text from configured remote API by content_ids/language_type and "
"materialize files into skill-{id} directories under /mnt/user-data/uploads/skill "
"before first thread submit."
),
)
async def bootstrap_skill_from_remote(request: RemoteSkillBootstrapRequest) -> RemoteSkillBootstrapResponse:
"""Initialize thread skill directory from remote YAML content service."""
try:
cfg = get_gateway_config()
api_url = cfg.skill_content_api_url
created_directories_total = 0
created_files_total = 0
base_target_path = resolve_thread_virtual_path(request.thread_id, request.target_dir)
if request.clear_target and base_target_path.exists():
if base_target_path.is_dir():
shutil.rmtree(base_target_path)
else:
base_target_path.unlink()
base_target_path.mkdir(parents=True, exist_ok=True)
async with httpx.AsyncClient(timeout=20.0) as client:
for content_id in request.content_ids:
payload = {
"contentId": content_id,
"languageType": request.language_type,
}
response = await client.post(api_url, json=payload)
if response.status_code >= 400:
raise HTTPException(
status_code=502,
detail=(
"Remote skill content API failed with HTTP "
f"{response.status_code} for content_id={content_id}"
),
)
try:
response_json = response.json()
except ValueError as e:
raise HTTPException(status_code=502, detail=f"Remote API did not return valid JSON: {e}") from e
status = response_json.get("status")
if status != 1000:
raise HTTPException(
status_code=502,
detail=(
"Remote API returned non-success status: "
f"{status}, message: {response_json.get('message')}, content_id={content_id}"
),
)
yaml_text = response_json.get("data")
if not isinstance(yaml_text, str) or not yaml_text.strip():
raise HTTPException(
status_code=502,
detail=f"Remote API returned empty or invalid YAML content for content_id={content_id}",
)
target_dir = f"{request.target_dir.rstrip('/')}/skill-{content_id}"
target_path = resolve_thread_virtual_path(request.thread_id, target_dir)
parsed = parse_skill_yaml_spec(yaml_text)
materialize_skill_tree(parsed, target_path, clear_target=False)
created_directories_total += len(parsed.directories)
created_files_total += len(parsed.files)
logger.info(
"Bootstrapped remote skill YAML for thread %s (content_id=%s, language_type=%s) to %s: dirs=%d files=%d",
request.thread_id,
content_id,
request.language_type,
target_dir,
len(parsed.directories),
len(parsed.files),
)
return RemoteSkillBootstrapResponse(
success=True,
target_dir=request.target_dir,
content_ids=request.content_ids,
created_directories=created_directories_total,
created_files=created_files_total,
sandbox_id=None,
message=(
f"Bootstrapped {created_files_total} files and {created_directories_total} directories "
f"for {len(request.content_ids)} skills under '{request.target_dir}'"
),
)
except HTTPException:
raise
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to bootstrap skill from remote API: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to bootstrap skill from remote API: {str(e)}")

View File

@ -1,475 +0,0 @@
"""Utilities for parsing YAML-defined skill package structures.
This module supports turning a YAML document describing files/directories into
real filesystem content under a thread's virtual path (for example,
``/mnt/user-data/uploads/skill``).
"""
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
import yaml # type: ignore[import-not-found]
@dataclass(frozen=True)
class ParsedSkillTree:
"""Normalized parsed structure from YAML spec."""
directories: set[str]
files: dict[str, str]
def _pick_first_existing(data: dict, keys: tuple[str, ...]):
for key in keys:
if key in data:
return data[key]
return None
def _extract_spec_root(data: dict) -> dict:
"""Extract the effective spec root.
Supports nested wrappers like:
- skill: { ... }
- package: { ... }
- spec: { ... }
"""
if not isinstance(data, dict):
raise ValueError("YAML root must be an object")
known_keys = {
"entries",
"files",
"directories",
"dirs",
"tree",
"structure",
"file_tree",
"fileTree",
"file_structure",
"paths",
}
if any(k in data for k in known_keys):
return data
wrapper_candidates = ("skill", "package", "spec", "data", "content", "payload")
for wrapper in wrapper_candidates:
candidate = data.get(wrapper)
if isinstance(candidate, dict) and any(k in candidate for k in known_keys):
return candidate
# Fallback: if exactly one nested object exists, try it as spec root.
nested_dicts = [v for v in data.values() if isinstance(v, dict)]
if len(nested_dicts) == 1:
return nested_dicts[0]
return data
def _normalize_relative_path(path: str) -> str:
"""Normalize and validate a relative path.
Raises:
ValueError: If path is unsafe or invalid.
"""
if not isinstance(path, str):
raise ValueError("Path must be a string")
normalized = path.strip().replace("\\", "/")
if normalized in {"/", ".", "./"}:
return ""
if not normalized:
raise ValueError("Path cannot be empty")
if normalized.startswith("/"):
raise ValueError(f"Path must be relative, got absolute path: {path}")
if ":" in normalized:
raise ValueError(f"Path cannot contain ':' (possible drive path): {path}")
parts = [part for part in normalized.split("/") if part]
if not parts:
raise ValueError("Path cannot be empty")
if any(part in {".", ".."} for part in parts):
raise ValueError(f"Path traversal is not allowed: {path}")
return "/".join(parts)
def _add_directory(path: str, directories: set[str]) -> None:
normalized = _normalize_relative_path(path)
if not normalized:
return
directories.add(normalized)
def _add_file(path: str, content: str, files: dict[str, str], directories: set[str]) -> None:
normalized = _normalize_relative_path(path)
if not normalized:
raise ValueError("File path cannot be root ('/')")
if not isinstance(content, str):
raise ValueError(f"File content must be a string for '{normalized}'")
parent = Path(normalized).parent
if str(parent) != ".":
directories.add(str(parent).replace("\\", "/"))
files[normalized] = content
def _walk_tree_dict(tree: dict, base: str, files: dict[str, str], directories: set[str]) -> None:
for name, value in tree.items():
if not isinstance(name, str):
raise ValueError("Tree keys must be strings")
if name.strip() in {"/", ".", "./"}:
if isinstance(value, dict):
_walk_tree_dict(value, base, files, directories)
continue
raise ValueError("Root sentinel '/' can only be used for directory/object nodes")
node_path = f"{base}/{name}" if base else name
if isinstance(value, dict):
_add_directory(node_path, directories)
_walk_tree_dict(value, _normalize_relative_path(node_path), files, directories)
elif isinstance(value, str):
_add_file(node_path, value, files, directories)
else:
raise ValueError(
f"Unsupported tree node type for '{node_path}': {type(value).__name__}. "
"Use object (directory) or string (file content)."
)
def _parse_entries_node(
node: dict,
base: str,
files: dict[str, str],
directories: set[str],
) -> None:
raw_path = node.get("path")
raw_name = node.get("name")
if raw_path is None and raw_name is None:
raise ValueError("Each entry must have at least one of: 'path' or 'name'")
if raw_path is not None and not isinstance(raw_path, str):
raise ValueError("Entry 'path' must be a string")
if raw_name is not None and not isinstance(raw_name, str):
raise ValueError("Entry 'name' must be a string")
# Common schema compatibility:
# - `path` is parent directory (e.g. "/")
# - `name` is current node name (e.g. "README.md")
# Build parent then append name when both are present.
parent = base
if isinstance(raw_path, str) and raw_path.strip():
rp = raw_path.strip()
if rp not in {"/", ".", "./"}:
parent = _normalize_relative_path(f"{base}/{rp}" if base else rp)
if isinstance(raw_name, str) and raw_name.strip():
if parent:
node_path = _normalize_relative_path(f"{parent}/{raw_name.strip()}")
else:
node_path = _normalize_relative_path(raw_name.strip())
else:
# Fallback: only path provided
if not isinstance(raw_path, str) or not raw_path.strip():
raise ValueError("Each entry must have a non-empty 'path' or 'name'")
rp = raw_path.strip()
if rp in {"/", ".", "./"}:
node_path = base
else:
node_path = _normalize_relative_path(f"{base}/{rp}" if base else rp)
node_type = node.get("type")
content = node.get("content")
children = node.get("children")
inferred_type = "directory" if isinstance(children, list) else "file" if content is not None else None
final_type = node_type or inferred_type
if final_type == "directory":
_add_directory(node_path, directories)
if children is None:
return
if not isinstance(children, list):
raise ValueError(f"Entry '{node_path}' children must be a list")
for child in children:
if not isinstance(child, dict):
raise ValueError(f"Entry '{node_path}' children must be objects")
_parse_entries_node(child, node_path, files, directories)
return
if final_type == "file":
if content is None:
raise ValueError(f"File entry '{node_path}' is missing 'content'")
_add_file(node_path, content, files, directories)
return
raise ValueError(
f"Unable to infer entry type for '{node_path}'. Set 'type' to 'file' or 'directory'."
)
def parse_skill_yaml_spec(yaml_text: str) -> ParsedSkillTree:
"""Parse YAML text into normalized directories and files.
Supported forms:
- entries: [{type,path/content/children}, ...]
- files: {"path/to/file": "text"} + optional directories/dirs
- tree/structure: nested dict where dict=directory and string=file content
"""
try:
data = yaml.safe_load(yaml_text)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML: {e}") from e
if data is None:
raise ValueError("YAML is empty")
if not isinstance(data, dict):
raise ValueError("YAML root must be an object")
data = _extract_spec_root(data)
directories: set[str] = set()
files: dict[str, str] = {}
# Form 1: explicit entries list
entries = _pick_first_existing(data, ("entries", "nodes", "items"))
if entries is not None:
if not isinstance(entries, list):
raise ValueError("'entries' must be a list")
for entry in entries:
if not isinstance(entry, dict):
raise ValueError("Each item in 'entries' must be an object")
_parse_entries_node(entry, "", files, directories)
# Form 2: files + directories
file_map = _pick_first_existing(data, ("files", "paths", "file_map", "fileMap", "documents"))
if file_map is not None:
if isinstance(file_map, dict):
for path, content in file_map.items():
_add_file(path, content, files, directories)
elif isinstance(file_map, list):
for item in file_map:
if not isinstance(item, dict):
raise ValueError("Each item in 'files' list must be an object")
path = item.get("path") or item.get("name") or item.get("file")
content = item.get("content")
if content is None:
content = item.get("text")
if content is None:
content = item.get("body")
if path is None or content is None:
raise ValueError("Each file item needs 'path' and 'content'")
_add_file(path, content, files, directories)
else:
raise ValueError("'files' must be a map or list")
directory_list = _pick_first_existing(data, ("directories", "dirs", "folders", "folder_paths"))
if directory_list is not None:
if not isinstance(directory_list, list):
raise ValueError("'directories'/'dirs' must be a list")
for path in directory_list:
_add_directory(path, directories)
# Form 3: nested tree
tree = _pick_first_existing(data, ("tree", "structure", "file_tree", "fileTree", "file_structure"))
if tree is not None:
if isinstance(tree, dict):
_walk_tree_dict(tree, "", files, directories)
elif isinstance(tree, list):
for item in tree:
if not isinstance(item, dict):
raise ValueError("Items in 'tree' list must be objects")
_parse_entries_node(item, "", files, directories)
else:
raise ValueError("'tree'/'structure' must be an object or list")
# Heuristic fallback: treat root as path->content map when possible.
if not files and not directories:
candidate_keys = [k for k in data.keys() if isinstance(k, str)]
if candidate_keys and all(isinstance(data[k], str) for k in candidate_keys):
for path, content in data.items():
_add_file(path, content, files, directories)
if not files and not directories:
raise ValueError(
"No content found. Provide at least one of: entries, files, directories/dirs, tree/structure"
)
# Ensure parent directories exist for every file
for rel_file in files:
parent = Path(rel_file).parent
if str(parent) != ".":
directories.add(str(parent).replace("\\", "/"))
return ParsedSkillTree(directories=directories, files=files)
def materialize_skill_tree(parsed: ParsedSkillTree, target_root: Path, clear_target: bool = True) -> None:
"""Create parsed directories/files under target root."""
if clear_target and target_root.exists():
import shutil
shutil.rmtree(target_root)
target_root.mkdir(parents=True, exist_ok=True)
for rel_dir in sorted(parsed.directories, key=lambda p: (p.count("/"), p)):
(target_root / rel_dir).mkdir(parents=True, exist_ok=True)
for rel_file, content in parsed.files.items():
file_path = target_root / rel_file
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(content, encoding="utf-8")
def _build_cli_parser() -> argparse.ArgumentParser:
"""Build command-line argument parser.
CLI usage:
python skill_yaml_importer.py <input_path> [options]
"""
parser = argparse.ArgumentParser(description="Parse and validate a skill YAML spec file")
parser.add_argument("input_path", help="Path to a YAML file or a directory containing YAML files")
parser.add_argument(
"--show-files",
action="store_true",
help="Print sorted parsed file paths",
)
parser.add_argument(
"--show-directories",
action="store_true",
help="Print sorted parsed directory paths",
)
parser.add_argument(
"--json",
action="store_true",
help="Print parsed summary as JSON",
)
parser.add_argument(
"--recursive",
action="store_true",
help="When input path is a directory, scan YAML files recursively",
)
parser.add_argument(
"--log-file",
default=None,
help="Optional path to save full execution results and summary as JSON",
)
return parser
def _collect_yaml_files(input_path: Path, recursive: bool) -> list[Path]:
if input_path.is_file():
return [input_path]
if not input_path.is_dir():
return []
patterns = ("*.yaml", "*.yml")
files: list[Path] = []
for pattern in patterns:
iterator = input_path.rglob(pattern) if recursive else input_path.glob(pattern)
files.extend(iterator)
# Stable order for deterministic output
return sorted({p.resolve() for p in files})
def _parse_one_yaml_file(yaml_path: Path, show_files: bool, show_directories: bool) -> dict:
yaml_text = yaml_path.read_text(encoding="utf-8")
parsed = parse_skill_yaml_spec(yaml_text)
directories = sorted(parsed.directories)
files = sorted(parsed.files.keys())
return {
"yaml_file": str(yaml_path),
"directories_count": len(directories),
"files_count": len(files),
"directories": directories if show_directories else None,
"files": files if show_files else None,
}
def _main() -> int:
"""CLI entrypoint for parsing one YAML file or a batch of YAML files.
Exit codes:
0: all files parsed successfully
1: invalid input path or no YAML files found
2: processed completed with one or more parse failures
"""
args = _build_cli_parser().parse_args()
input_path = Path(args.input_path)
if not input_path.exists():
print(f"Input path not found: {input_path}", file=sys.stderr)
return 1
yaml_files = _collect_yaml_files(input_path, recursive=args.recursive)
if not yaml_files:
print(f"No YAML files found under: {input_path}", file=sys.stderr)
return 1
successes: list[dict] = []
failures: list[dict[str, str]] = []
for yaml_path in yaml_files:
try:
result = _parse_one_yaml_file(
yaml_path,
show_files=args.show_files,
show_directories=args.show_directories,
)
successes.append(result)
if not args.json:
print(f"OK: {yaml_path}")
print(f" Directories: {result['directories_count']}")
print(f" Files: {result['files_count']}")
except Exception as e: # noqa: BLE001
failures.append({"yaml_file": str(yaml_path), "error": str(e)})
print(f"ERROR: {yaml_path}: {e}", file=sys.stderr)
summary = {
"input_path": str(input_path),
"total": len(yaml_files),
"success": len(successes),
"failed": len(failures),
}
report = {"summary": summary, "successes": successes, "failures": failures}
if args.log_file:
try:
log_path = Path(args.log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"Log saved: {log_path}")
except Exception as e: # noqa: BLE001
print(f"Failed to write log file '{args.log_file}': {e}", file=sys.stderr)
if args.json:
print(json.dumps(report, ensure_ascii=False, indent=2))
else:
print("\n[Summary]")
print(f"Input: {summary['input_path']}")
print(f"Total: {summary['total']}")
print(f"Success: {summary['success']}")
print(f"Failed: {summary['failed']}")
return 0 if not failures else 2
if __name__ == "__main__":
raise SystemExit(_main())

View File

@ -394,17 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
Returns the <skill_system>...</skill_system> block listing all enabled skills, Returns the <skill_system>...</skill_system> block listing all enabled skills,
suitable for injection into any agent's system prompt. suitable for injection into any agent's system prompt.
""" """
thread_id = None skills = _get_enabled_skills()
try:
from langgraph.config import get_config
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
except Exception:
pass
# Keep regular enable/disable behavior while letting uploads be default-enabled in loader.
skills = load_skills(enabled_only=True, thread_id=thread_id)
try: try:
from deerflow.config import get_app_config from deerflow.config import get_app_config
@ -571,6 +561,4 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
acp_section=acp_and_mounts_section, acp_section=acp_and_mounts_section,
) )
logger.debug("Generated full system prompt:\n%s", prompt)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>" return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"

View File

@ -21,7 +21,6 @@ class ConversationContext:
timestamp: datetime = field(default_factory=datetime.utcnow) timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None agent_name: str | None = None
correction_detected: bool = False correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue: class MemoryUpdateQueue:
@ -45,7 +44,6 @@ class MemoryUpdateQueue:
messages: list[Any], messages: list[Any],
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None: ) -> None:
"""Add a conversation to the update queue. """Add a conversation to the update queue.
@ -54,7 +52,6 @@ class MemoryUpdateQueue:
messages: The conversation messages. messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory. agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
""" """
config = get_memory_config() config = get_memory_config()
if not config.enabled: if not config.enabled:
@ -66,13 +63,11 @@ class MemoryUpdateQueue:
None, None,
) )
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext( context = ConversationContext(
thread_id=thread_id, thread_id=thread_id,
messages=messages, messages=messages,
agent_name=agent_name, agent_name=agent_name,
correction_detected=merged_correction_detected, correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
) )
# Check if this thread already has a pending update # Check if this thread already has a pending update
@ -135,7 +130,6 @@ class MemoryUpdateQueue:
thread_id=context.thread_id, thread_id=context.thread_id,
agent_name=context.agent_name, agent_name=context.agent_name,
correction_detected=context.correction_detected, correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
) )
if success: if success:
logger.info("Memory updated successfully for thread %s", context.thread_id) logger.info("Memory updated successfully for thread %s", context.thread_id)

View File

@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
stripped = content.strip() stripped = content.strip()
if not stripped: if not stripped:
return None return None
return stripped.casefold() return stripped
class MemoryUpdater: class MemoryUpdater:
@ -272,7 +272,6 @@ class MemoryUpdater:
thread_id: str | None = None, thread_id: str | None = None,
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool: ) -> bool:
"""Update memory based on conversation messages. """Update memory based on conversation messages.
@ -281,7 +280,6 @@ class MemoryUpdater:
thread_id: Optional thread ID for tracking source. thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
@ -312,14 +310,6 @@ class MemoryUpdater:
"and record the correct approach as a fact with category " "and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.' '"correction" and confidence >= 0.95 when appropriate.'
) )
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format( prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2), current_memory=json.dumps(current_memory, indent=2),
@ -451,7 +441,6 @@ def update_memory_from_conversation(
thread_id: str | None = None, thread_id: str | None = None,
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool: ) -> bool:
"""Convenience function to update memory from a conversation. """Convenience function to update memory from a conversation.
@ -460,10 +449,9 @@ def update_memory_from_conversation(
thread_id: Optional thread ID. thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns: Returns:
True if successful, False otherwise. True if successful, False otherwise.
""" """
updater = MemoryUpdater() updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected) return updater.update_memory(messages, thread_id, agent_name, correction_detected)

View File

@ -29,22 +29,6 @@ _CORRECTION_PATTERNS = (
re.compile(r"改用"), re.compile(r"改用"),
) )
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState): class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@ -148,29 +132,6 @@ def detect_correction(messages: list[Any]) -> bool:
return False return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution. """Middleware that queues conversation for memory update after agent execution.
@ -235,14 +196,12 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update # Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages) correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue() queue = get_memory_queue()
queue.add( queue.add(
thread_id=thread_id, thread_id=thread_id,
messages=filtered_messages, messages=filtered_messages,
agent_name=self._agent_name, agent_name=self._agent_name,
correction_detected=correction_detected, correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
) )
return None return None

View File

@ -192,8 +192,8 @@ class ExtensionsConfig(BaseModel):
""" """
skill_config = self.skills.get(skill_name) skill_config = self.skills.get(skill_name)
if skill_config is None: if skill_config is None:
# Default to enable for public/custom/uploads skills. # Default to enable for public & custom skill
return skill_category in ("public", "custom", "uploads") return skill_category in ("public", "custom")
return skill_config.enabled return skill_config.enabled

View File

@ -366,17 +366,12 @@ def _path_variants(path: str) -> set[str]:
return {path, path.replace("\\", "/"), path.replace("/", "\\")} return {path, path.replace("\\", "/"), path.replace("/", "\\")}
def _path_separator_for_style(path: str) -> str:
return "\\" if "\\" in path and "/" not in path else "/"
def _join_path_preserving_style(base: str, relative: str) -> str: def _join_path_preserving_style(base: str, relative: str) -> str:
if not relative: if not relative:
return base return base
separator = _path_separator_for_style(base) if "/" in base and "\\" not in base:
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\") return f"{base.rstrip('/')}/{relative}"
stripped_base = base.rstrip("/\\") return str(Path(base) / relative)
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str: def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
@ -421,10 +416,7 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
return actual_base return actual_base
if path.startswith(f"{virtual_base}/"): if path.startswith(f"{virtual_base}/"):
rest = path[len(virtual_base) :].lstrip("/") rest = path[len(virtual_base) :].lstrip("/")
result = _join_path_preserving_style(actual_base, rest) return _join_path_preserving_style(actual_base, rest)
if path.endswith("/") and not result.endswith(("/", "\\")):
result += _path_separator_for_style(actual_base)
return result
return path return path

View File

@ -8,27 +8,6 @@ from .types import Skill
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
UPLOADS_SKILLS_PATH = Path("/mnt/user-data/uploads")
def get_uploads_skills_path(thread_id: str | None = None) -> Path:
"""Resolve the uploads skills root for the current execution context.
When called from the LangGraph process, uploaded skills live under the
host-side per-thread data directory rather than the sandbox mount path.
"""
if not thread_id:
return UPLOADS_SKILLS_PATH
try:
from deerflow.config.paths import get_paths
return get_paths().sandbox_uploads_dir(thread_id)
except Exception as exc:
logger.warning("Failed to resolve uploads skills path for thread %s: %s", thread_id, exc)
return UPLOADS_SKILLS_PATH
def get_skills_root_path() -> Path: def get_skills_root_path() -> Path:
""" """
Get the root path of the skills directory. Get the root path of the skills directory.
@ -43,19 +22,12 @@ def get_skills_root_path() -> Path:
return skills_dir return skills_dir
def load_skills( def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
skills_path: Path | None = None,
use_config: bool = True,
enabled_only: bool = False,
thread_id: str | None = None,
) -> list[Skill]:
""" """
Load all skills from the skills directory. Load all skills from the skills directory.
Scans public/custom skill directories under the configured skills root, Scans both public and custom skill directories, parsing SKILL.md files
and also scans uploaded skills under /mnt/user-data/uploads. to extract metadata. The enabled state is determined by the skills_state_config.json file.
SKILL.md metadata is parsed and enabled state is derived from
skills_state_config.json.
Args: Args:
skills_path: Optional custom path to skills directory. skills_path: Optional custom path to skills directory.
@ -63,8 +35,6 @@ def load_skills(
Otherwise defaults to deer-flow/skills Otherwise defaults to deer-flow/skills
use_config: Whether to load skills path from config (default: True) use_config: Whether to load skills path from config (default: True)
enabled_only: If True, only return enabled skills (default: False) enabled_only: If True, only return enabled skills (default: False)
thread_id: Optional thread ID used to resolve per-thread uploads skills
from the LangGraph host process
Returns: Returns:
List of Skill objects, sorted by name List of Skill objects, sorted by name
@ -87,22 +57,12 @@ def load_skills(
skills = [] skills = []
# Scan built-in roots and uploaded skills mounted in personal workspace. # Scan public and custom directories
scan_targets: list[tuple[str, Path]] = [ for category in ["public", "custom"]:
("public", skills_path / "public"), category_path = skills_path / category
("custom", skills_path / "custom"),
("uploads", get_uploads_skills_path(thread_id)),
]
for category, category_path in scan_targets:
logger.debug("Scanning %s skills under %s", category, category_path)
if not category_path.exists() or not category_path.is_dir(): if not category_path.exists() or not category_path.is_dir():
logger.debug("Skip %s scan: directory not found or not a directory (%s)", category, category_path)
continue continue
scanned_skill_dirs: list[str] = []
for current_root, dir_names, file_names in os.walk(category_path, followlinks=True): for current_root, dir_names, file_names in os.walk(category_path, followlinks=True):
# Keep traversal deterministic and skip hidden directories. # Keep traversal deterministic and skip hidden directories.
dir_names[:] = sorted(name for name in dir_names if not name.startswith(".")) dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
@ -111,22 +71,11 @@ def load_skills(
skill_file = Path(current_root) / "SKILL.md" skill_file = Path(current_root) / "SKILL.md"
relative_path = skill_file.parent.relative_to(category_path) relative_path = skill_file.parent.relative_to(category_path)
scanned_skill_dirs.append(relative_path.as_posix())
skill = parse_skill_file(skill_file, category=category, relative_path=relative_path) skill = parse_skill_file(skill_file, category=category, relative_path=relative_path)
if skill: if skill:
skills.append(skill) skills.append(skill)
if scanned_skill_dirs:
logger.debug(
"%s scan found %d skill directories: %s",
category,
len(scanned_skill_dirs),
", ".join(sorted(scanned_skill_dirs)),
)
else:
logger.debug("%s scan found no skill directories", category)
# Load skills state configuration and update enabled status # Load skills state configuration and update enabled status
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config() # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
# to always read the latest configuration from disk. This ensures that changes # to always read the latest configuration from disk. This ensures that changes
@ -137,10 +86,6 @@ def load_skills(
extensions_config = ExtensionsConfig.from_file() extensions_config = ExtensionsConfig.from_file()
for skill in skills: for skill in skills:
if skill.category == "uploads":
# Uploaded skills should be available by default for the current thread.
skill.enabled = True
continue
skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category) skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
except Exception as e: except Exception as e:
# If config loading fails, default to all enabled # If config loading fails, default to all enabled

View File

@ -13,7 +13,7 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
Args: Args:
skill_file: Path to the SKILL.md file skill_file: Path to the SKILL.md file
category: Category of the skill ('public', 'custom', or 'uploads') category: Category of the skill ('public' or 'custom')
Returns: Returns:
Skill object if parsing succeeds, None otherwise Skill object if parsing succeeds, None otherwise

View File

@ -12,7 +12,7 @@ class Skill:
skill_dir: Path skill_dir: Path
skill_file: Path skill_file: Path
relative_path: Path # Relative path from category root to skill directory relative_path: Path # Relative path from category root to skill directory
category: str # 'public', 'custom', or 'uploads' category: str # 'public' or 'custom'
enabled: bool = False # Whether this skill is enabled enabled: bool = False # Whether this skill is enabled
@property @property
@ -31,10 +31,7 @@ class Skill:
Returns: Returns:
Full container path to the skill directory Full container path to the skill directory
""" """
if self.category == "uploads": category_base = f"{container_base_path}/{self.category}"
category_base = "/mnt/user-data/uploads"
else:
category_base = f"{container_base_path}/{self.category}"
skill_path = self.skill_path skill_path = self.skill_path
if skill_path: if skill_path:
return f"{category_base}/{skill_path}" return f"{category_base}/{skill_path}"

View File

@ -7,7 +7,7 @@ import json
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -1988,47 +1988,6 @@ class TestSlackSendRetry:
_run(go()) _run(go())
class TestSlackAllowedUsers:
def test_numeric_allowed_users_match_string_event_user_id(self):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = SlackChannel(
bus=bus,
config={"allowed_users": [123456]},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "123456",
"text": "hello from slack",
"channel": "C123",
"ts": "1710000000.000100",
}
def submit_coro(coro, loop):
coro.close()
return MagicMock()
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=submit_coro,
) as submit:
channel._handle_message_event(event)
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
submit.assert_called_once()
inbound = bus.publish_inbound.call_args.args[0]
assert inbound.user_id == "123456"
assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack"
def test_raises_after_all_retries_exhausted(self): def test_raises_after_all_retries_exhausted(self):
from app.channels.slack import SlackChannel from app.channels.slack import SlackChannel

View File

@ -47,45 +47,4 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
thread_id="thread-1", thread_id="thread-1",
agent_name="lead_agent", agent_name="lead_agent",
correction_detected=True, correction_detected=True,
reinforcement_detected=False,
)
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].reinforcement_detected is True
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
reinforcement_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=False,
reinforcement_detected=True,
) )

View File

@ -619,156 +619,3 @@ class TestUpdateMemoryStructuredResponse:
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt assert "Explicit correction signals were detected" not in prompt
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
# Same fact with different casing should be treated as duplicate
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
class TestReinforcementHint:
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
@staticmethod
def _make_mock_model(json_response: str):
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Yes, exactly! That's what I needed."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Great to hear!"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Tell me more."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Sure."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "No wait, that's wrong. Actually yes, exactly right."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Got it."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt

View File

@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@ -270,73 +270,3 @@ class TestStripUploadMentionsFromMemory:
mem = {"user": {}, "history": {}, "facts": []} mem = {"user": {}, "history": {}, "facts": []}
result = _strip_upload_mentions_from_memory(mem) result = _strip_upload_mentions_from_memory(mem)
assert result == {"user": {}, "history": {}, "facts": []} assert result == {"user": {}, "history": {}, "facts": []}
# ===========================================================================
# detect_reinforcement
# ===========================================================================
class TestDetectReinforcement:
def test_detects_english_reinforcement_signal(self):
msgs = [
_human("Can you summarise it in bullet points?"),
_ai("Here are the key points: ..."),
_human("Yes, exactly! That's what I needed."),
_ai("Glad it helped."),
]
assert detect_reinforcement(msgs) is True
def test_detects_perfect_signal(self):
msgs = [
_human("Write it more concisely."),
_ai("Here is the concise version."),
_human("Perfect."),
_ai("Great!"),
]
assert detect_reinforcement(msgs) is True
def test_detects_chinese_reinforcement_signal(self):
msgs = [
_human("帮我用要点来总结"),
_ai("好的,要点如下:..."),
_human("完全正确,就是这个意思"),
_ai("很高兴能帮到你"),
]
assert detect_reinforcement(msgs) is True
def test_returns_false_without_signal(self):
msgs = [
_human("What does this function do?"),
_ai("It processes the input data."),
_human("Can you show me an example?"),
]
assert detect_reinforcement(msgs) is False
def test_only_checks_recent_messages(self):
# Reinforcement signal buried beyond the -6 window should not trigger
msgs = [
_human("Yes, exactly right."),
_ai("Noted."),
_human("Let's discuss tests."),
_ai("Sure."),
_human("What about linting?"),
_ai("Use ruff."),
_human("And formatting?"),
_ai("Use make format."),
]
assert detect_reinforcement(msgs) is False
def test_does_not_conflict_with_correction(self):
# A message can trigger correction but not reinforcement
msgs = [
_human("That's wrong, try again."),
_ai("Corrected."),
]
assert detect_reinforcement(msgs) is False

View File

@ -42,53 +42,6 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data" assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"
def test_replace_virtual_path_preserves_trailing_slash() -> None:
"""Trailing slash must survive virtual-to-actual path replacement.
Regression: '/mnt/user-data/workspace/' was previously returned without
the trailing slash, causing string concatenations like
output_dir + 'file.txt' to produce a missing-separator path.
"""
result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
"""Trailing slash must be preserved as backslash when actual_base is Windows-style.
If actual_base uses backslash separators, appending '/' would produce a
mixed-separator path. The separator must match the style of actual_base.
"""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
"""Nested Windows-style subdirectories must keep backslashes throughout."""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
"""Trailing slash on a virtual path inside a command must be preserved."""
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
# ---------- mask_local_paths_in_output ---------- # ---------- mask_local_paths_in_output ----------
@ -304,22 +257,6 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA) validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
"""HTTP URLs must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None: def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
validate_local_bash_command_paths( validate_local_bash_command_paths(
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null", "/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",

View File

@ -1,7 +1,6 @@
"use client"; "use client";
import { useSearchParams } from "next/navigation"; import { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input"; import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
import { ArtifactTrigger } from "@/components/workspace/artifacts"; import { ArtifactTrigger } from "@/components/workspace/artifacts";
@ -25,35 +24,15 @@ import { Welcome } from "@/components/workspace/welcome";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { useNotification } from "@/core/notification/hooks"; import { useNotification } from "@/core/notification/hooks";
import { useThreadSettings } from "@/core/settings"; import { useThreadSettings } from "@/core/settings";
import { bootstrapRemoteSkill } from "@/core/skills";
import { useThreadStream } from "@/core/threads/hooks"; import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils"; import { textOfMessage } from "@/core/threads/utils";
import { uuid } from "@/core/utils/uuid";
import { env } from "@/env"; import { env } from "@/env";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
const UUID_REGEX =
/^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;
export default function ChatPage() { export default function ChatPage() {
const { t } = useI18n(); const { t } = useI18n();
const [showFollowups, setShowFollowups] = useState(false); const [showFollowups, setShowFollowups] = useState(false);
const searchParams = useSearchParams(); const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat();
const generatedThreadIdRef = useRef<string>("");
if (!generatedThreadIdRef.current) {
const queryThreadId = searchParams.get("thread_id")?.trim();
generatedThreadIdRef.current =
queryThreadId && UUID_REGEX.test(queryThreadId) ? queryThreadId : uuid();
}
// 检查 xclaw_used 参数,仅用于界面风格控制,不影响线程创建逻辑
const xclawUsedParam = searchParams.get("xclaw_used");
const initialForceNewStyle = xclawUsedParam === "false";
const [forceNewStyle, setForceNewStyle] = useState(initialForceNewStyle);
const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat({
newThreadId: generatedThreadIdRef.current,
});
const [settings, setSettings] = useThreadSettings(threadId); const [settings, setSettings] = useThreadSettings(threadId);
const [mounted, setMounted] = useState(false); const [mounted, setMounted] = useState(false);
useSpecificChatMode(); useSpecificChatMode();
@ -63,34 +42,6 @@ export default function ChatPage() {
}, []); }, []);
const { showNotification } = useNotification(); const { showNotification } = useNotification();
const skillBootstrappedKeysRef = useRef<Set<string>>(new Set());
const skillBootstrappingKeysRef = useRef<Set<string>>(new Set());
const skillBootstrap = useMemo(() => {
const skillIdRaw = searchParams.get("skill_id")?.trim();
if (!skillIdRaw) return undefined;
const contentIds = skillIdRaw
.split(",")
.map((value) => value.trim())
.filter((value) => value.length > 0)
.map((value) => Number(value))
.filter((value) => Number.isFinite(value));
// Deduplicate while preserving incoming order.
const uniqueContentIds = Array.from(new Set(contentIds));
if (uniqueContentIds.length === 0) return undefined;
const languageTypeRaw =
searchParams.get("languageType")?.trim() ??
searchParams.get("language_type")?.trim();
const languageType = languageTypeRaw ? Number(languageTypeRaw) : 0;
return {
contentIds: uniqueContentIds,
languageType: Number.isFinite(languageType) ? languageType : 0,
};
}, [searchParams]);
const [thread, sendMessage, isUploading] = useThreadStream({ const [thread, sendMessage, isUploading] = useThreadStream({
threadId: isNewThread ? undefined : threadId, threadId: isNewThread ? undefined : threadId,
@ -119,54 +70,11 @@ export default function ChatPage() {
}, },
}); });
useEffect(() => {
if (!threadId || !skillBootstrap?.contentIds?.length) {
return;
}
const languageType = skillBootstrap.languageType ?? 0;
const initKey = `${threadId}:${skillBootstrap.contentIds.join(",")}:${languageType}`;
if (
skillBootstrappedKeysRef.current.has(initKey) ||
skillBootstrappingKeysRef.current.has(initKey)
) {
return;
}
skillBootstrappingKeysRef.current.add(initKey);
const runBootstrap = async () => {
try {
await bootstrapRemoteSkill({
thread_id: threadId,
content_ids: skillBootstrap.contentIds,
language_type: languageType,
target_dir: "/mnt/user-data/uploads/skill",
clear_target: true,
});
skillBootstrappedKeysRef.current.add(initKey);
} catch (error) {
const message =
error instanceof Error ? error.message : "Skill initialization failed";
showNotification("Skill initialization failed", { body: message });
} finally {
skillBootstrappingKeysRef.current.delete(initKey);
}
};
void runBootstrap();
}, [threadId, skillBootstrap, showNotification]);
const handleSubmit = useCallback( const handleSubmit = useCallback(
(message: PromptInputMessage) => { (message: PromptInputMessage) => {
void sendMessage(threadId, message); void sendMessage(threadId, message);
// 仅切换界面风格,不影响线程状态
if (forceNewStyle) {
setForceNewStyle(false);
}
}, },
[sendMessage, threadId, forceNewStyle], [sendMessage, threadId],
); );
const handleStop = useCallback(async () => { const handleStop = useCallback(async () => {
await thread.stop(); await thread.stop();
@ -184,7 +92,7 @@ export default function ChatPage() {
<header <header
className={cn( className={cn(
"absolute top-0 right-0 left-0 z-30 flex h-12 shrink-0 items-center px-4", "absolute top-0 right-0 left-0 z-30 flex h-12 shrink-0 items-center px-4",
(forceNewStyle || isNewThread) isNewThread
? "bg-background/0 backdrop-blur-none" ? "bg-background/0 backdrop-blur-none"
: "bg-background/80 shadow-xs backdrop-blur", : "bg-background/80 shadow-xs backdrop-blur",
)} )}
@ -200,22 +108,19 @@ export default function ChatPage() {
</header> </header>
<main className="flex min-h-0 max-w-full grow flex-col"> <main className="flex min-h-0 max-w-full grow flex-col">
<div className="flex size-full justify-center"> <div className="flex size-full justify-center">
{/* forceNewStyle 时隐藏消息列表,提交后再显示 */} <MessageList
{!(forceNewStyle) && ( className={cn("size-full", !isNewThread && "pt-10")}
<MessageList threadId={threadId}
className={cn("size-full", !isNewThread && "pt-10")} thread={thread}
threadId={threadId} paddingBottom={messageListPaddingBottom}
thread={thread} />
paddingBottom={messageListPaddingBottom}
/>
)}
</div> </div>
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4"> <div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
<div <div
className={cn( className={cn(
"relative w-full", "relative w-full",
(forceNewStyle || isNewThread) && "-translate-y-[calc(50vh-96px)]", isNewThread && "-translate-y-[calc(50vh-96px)]",
(forceNewStyle || isNewThread) isNewThread
? "max-w-(--container-width-sm)" ? "max-w-(--container-width-sm)"
: "max-w-(--container-width-md)", : "max-w-(--container-width-md)",
)} )}
@ -234,9 +139,9 @@ export default function ChatPage() {
{mounted ? ( {mounted ? (
<InputBox <InputBox
className={cn("bg-background/5 w-full -translate-y-4")} className={cn("bg-background/5 w-full -translate-y-4")}
isNewThread={forceNewStyle || isNewThread} isNewThread={isNewThread}
threadId={threadId} threadId={threadId}
autoFocus={forceNewStyle || isNewThread} autoFocus={isNewThread}
status={ status={
thread.error thread.error
? "error" ? "error"
@ -246,7 +151,7 @@ export default function ChatPage() {
} }
context={settings.context} context={settings.context}
extraHeader={ extraHeader={
(forceNewStyle || isNewThread) && <Welcome mode={settings.context.mode} /> isNewThread && <Welcome mode={settings.context.mode} />
} }
disabled={ disabled={
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" || env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||

View File

@ -1,23 +1,17 @@
"use client"; "use client";
import { useParams, usePathname, useSearchParams } from "next/navigation"; import { useParams, usePathname, useSearchParams } from "next/navigation";
import { useEffect, useRef, useState } from "react"; import { useEffect, useState } from "react";
import { uuid } from "@/core/utils/uuid"; import { uuid } from "@/core/utils/uuid";
type UseThreadChatOptions = { export function useThreadChat() {
newThreadId?: string;
};
export function useThreadChat(options?: UseThreadChatOptions) {
const { thread_id: threadIdFromPath } = useParams<{ thread_id: string }>(); const { thread_id: threadIdFromPath } = useParams<{ thread_id: string }>();
const pathname = usePathname(); const pathname = usePathname();
const fallbackNewThreadIdRef = useRef<string>(options?.newThreadId ?? uuid());
const fallbackNewThreadId = options?.newThreadId ?? fallbackNewThreadIdRef.current;
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const [threadId, setThreadId] = useState(() => { const [threadId, setThreadId] = useState(() => {
return threadIdFromPath === "new" ? fallbackNewThreadId : threadIdFromPath; return threadIdFromPath === "new" ? uuid() : threadIdFromPath;
}); });
const [isNewThread, setIsNewThread] = useState( const [isNewThread, setIsNewThread] = useState(
@ -27,9 +21,9 @@ export function useThreadChat(options?: UseThreadChatOptions) {
useEffect(() => { useEffect(() => {
if (pathname.endsWith("/new")) { if (pathname.endsWith("/new")) {
setIsNewThread(true); setIsNewThread(true);
setThreadId(fallbackNewThreadId); setThreadId(uuid());
} }
}, [pathname, fallbackNewThreadId]); }, [pathname]);
const isMock = searchParams.get("mock") === "true"; const isMock = searchParams.get("mock") === "true";
return { threadId, isNewThread, setIsNewThread, isMock }; return { threadId, isNewThread, setIsNewThread, isMock };
} }

View File

@ -32,10 +32,14 @@ import { SettingsDialog } from "./settings";
export function CommandPalette() { export function CommandPalette() {
const { t } = useI18n(); const { t } = useI18n();
const router = useRouter(); const router = useRouter();
const [mounted, setMounted] = useState(false);
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
const [shortcutsOpen, setShortcutsOpen] = useState(false); const [shortcutsOpen, setShortcutsOpen] = useState(false);
const [settingsOpen, setSettingsOpen] = useState(false); const [settingsOpen, setSettingsOpen] = useState(false);
const [isMac, setIsMac] = useState(false);
useEffect(() => {
setMounted(true);
}, []);
const handleNewChat = useCallback(() => { const handleNewChat = useCallback(() => {
router.push("/workspace/chats/new"); router.push("/workspace/chats/new");
@ -64,12 +68,14 @@ export function CommandPalette() {
useGlobalShortcuts(shortcuts); useGlobalShortcuts(shortcuts);
useEffect(() => { const isMac = mounted && navigator.userAgent.includes("Mac");
setIsMac(navigator.userAgent.includes("Mac"));
}, []);
const metaKey = isMac ? "⌘" : "Ctrl+"; const metaKey = isMac ? "⌘" : "Ctrl+";
const shiftKey = isMac ? "⇧" : "Shift+"; const shiftKey = isMac ? "⇧" : "Shift+";
if (!mounted) {
return null;
}
return ( return (
<> <>
<SettingsDialog open={settingsOpen} onOpenChange={setSettingsOpen} /> <SettingsDialog open={settingsOpen} onOpenChange={setSettingsOpen} />

View File

@ -35,24 +35,6 @@ export interface InstallSkillResponse {
message: string; message: string;
} }
export interface BootstrapRemoteSkillRequest {
thread_id: string;
content_ids: number[];
language_type?: number;
target_dir?: string;
clear_target?: boolean;
}
export interface BootstrapRemoteSkillResponse {
success: boolean;
target_dir: string;
content_ids: number[];
created_directories: number;
created_files: number;
sandbox_id: string | null;
message: string;
}
export async function installSkill( export async function installSkill(
request: InstallSkillRequest, request: InstallSkillRequest,
): Promise<InstallSkillResponse> { ): Promise<InstallSkillResponse> {
@ -78,27 +60,3 @@ export async function installSkill(
return response.json(); return response.json();
} }
export async function bootstrapRemoteSkill(
request: BootstrapRemoteSkillRequest,
): Promise<BootstrapRemoteSkillResponse> {
const response = await fetch(
`${getBackendBaseURL()}/api/skills/bootstrap-remote`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(request),
},
);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
const errorMessage =
errorData.detail ?? `HTTP ${response.status}: ${response.statusText}`;
throw new Error(errorMessage);
}
return response.json();
}