Compare commits
No commits in common. "6b900ccb6067d2b791677faba6c946203b091ec3" and "0ffe5a73c1440f2d61d04bc3a16529942d62300e" have entirely different histories.
6b900ccb60
...
0ffe5a73c1
|
|
@ -30,7 +30,7 @@ class SlackChannel(Channel):
|
|||
self._socket_client = None
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
|
||||
self._allowed_users: set[str] = set(config.get("allowed_users", []))
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
|
|
|
|||
|
|
@ -23,11 +23,9 @@ from app.gateway.routers import (
|
|||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
# Configure logging with env override
|
||||
import os
|
||||
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level, logging.INFO),
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,6 @@ class GatewayConfig(BaseModel):
|
|||
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
||||
port: int = Field(default=8001, description="Port to bind the gateway server")
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
|
||||
skill_content_api_url: str = Field(
|
||||
default="https://skills.xueai.art/api/cmsContent/getContent",
|
||||
description="Remote API URL used to fetch skill YAML content by content ID",
|
||||
)
|
||||
|
||||
|
||||
_gateway_config: GatewayConfig | None = None
|
||||
|
|
@ -27,9 +23,5 @@ def get_gateway_config() -> GatewayConfig:
|
|||
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
||||
cors_origins=cors_origins_str.split(","),
|
||||
skill_content_api_url=os.getenv(
|
||||
"SKILL_CONTENT_API_URL",
|
||||
"https://skills.xueai.art/api/cmsContent/getContent",
|
||||
),
|
||||
)
|
||||
return _gateway_config
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
from app.gateway.skill_yaml_importer import materialize_skill_tree, parse_skill_yaml_spec
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.skills import Skill, load_skills
|
||||
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
|
||||
|
|
@ -56,38 +52,6 @@ class SkillInstallResponse(BaseModel):
|
|||
message: str = Field(..., description="Installation result message")
|
||||
|
||||
|
||||
class RemoteSkillBootstrapRequest(BaseModel):
|
||||
"""Request model for bootstrapping skill files from remote content API."""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID used for user-data path binding")
|
||||
content_ids: list[int] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Remote content ID sequence (maps from frontend query param skill_id)",
|
||||
)
|
||||
language_type: int = Field(default=0, description="Language type for remote API request body")
|
||||
target_dir: str = Field(
|
||||
default="/mnt/user-data/uploads/skill",
|
||||
description="Virtual base directory where each skill-{id} subdirectory is created",
|
||||
)
|
||||
clear_target: bool = Field(
|
||||
default=True,
|
||||
description="Whether to clear target directory before writing parsed files",
|
||||
)
|
||||
|
||||
|
||||
class RemoteSkillBootstrapResponse(BaseModel):
|
||||
"""Response model for remote bootstrap endpoint."""
|
||||
|
||||
success: bool = Field(..., description="Whether bootstrap succeeded")
|
||||
target_dir: str = Field(..., description="Virtual base target directory")
|
||||
content_ids: list[int] = Field(..., description="Bootstrapped content IDs")
|
||||
created_directories: int = Field(..., description="Number of created directories")
|
||||
created_files: int = Field(..., description="Number of created files")
|
||||
sandbox_id: str | None = Field(default=None, description="Acquired sandbox ID (null when sandbox is not acquired)")
|
||||
message: str = Field(..., description="Operation result message")
|
||||
|
||||
|
||||
def _skill_to_response(skill: Skill) -> SkillResponse:
|
||||
"""Convert a Skill object to a SkillResponse."""
|
||||
return SkillResponse(
|
||||
|
|
@ -207,107 +171,3 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to install skill: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/skills/bootstrap-remote",
|
||||
response_model=RemoteSkillBootstrapResponse,
|
||||
summary="Bootstrap Skill Files From Remote API",
|
||||
description=(
|
||||
"Fetch YAML text from configured remote API by content_ids/language_type and "
|
||||
"materialize files into skill-{id} directories under /mnt/user-data/uploads/skill "
|
||||
"before first thread submit."
|
||||
),
|
||||
)
|
||||
async def bootstrap_skill_from_remote(request: RemoteSkillBootstrapRequest) -> RemoteSkillBootstrapResponse:
|
||||
"""Initialize thread skill directory from remote YAML content service."""
|
||||
try:
|
||||
cfg = get_gateway_config()
|
||||
api_url = cfg.skill_content_api_url
|
||||
created_directories_total = 0
|
||||
created_files_total = 0
|
||||
|
||||
base_target_path = resolve_thread_virtual_path(request.thread_id, request.target_dir)
|
||||
if request.clear_target and base_target_path.exists():
|
||||
if base_target_path.is_dir():
|
||||
shutil.rmtree(base_target_path)
|
||||
else:
|
||||
base_target_path.unlink()
|
||||
base_target_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for content_id in request.content_ids:
|
||||
payload = {
|
||||
"contentId": content_id,
|
||||
"languageType": request.language_type,
|
||||
}
|
||||
|
||||
response = await client.post(api_url, json=payload)
|
||||
if response.status_code >= 400:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Remote skill content API failed with HTTP "
|
||||
f"{response.status_code} for content_id={content_id}"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = response.json()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=502, detail=f"Remote API did not return valid JSON: {e}") from e
|
||||
|
||||
status = response_json.get("status")
|
||||
if status != 1000:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=(
|
||||
"Remote API returned non-success status: "
|
||||
f"{status}, message: {response_json.get('message')}, content_id={content_id}"
|
||||
),
|
||||
)
|
||||
|
||||
yaml_text = response_json.get("data")
|
||||
if not isinstance(yaml_text, str) or not yaml_text.strip():
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Remote API returned empty or invalid YAML content for content_id={content_id}",
|
||||
)
|
||||
|
||||
target_dir = f"{request.target_dir.rstrip('/')}/skill-{content_id}"
|
||||
target_path = resolve_thread_virtual_path(request.thread_id, target_dir)
|
||||
parsed = parse_skill_yaml_spec(yaml_text)
|
||||
materialize_skill_tree(parsed, target_path, clear_target=False)
|
||||
|
||||
created_directories_total += len(parsed.directories)
|
||||
created_files_total += len(parsed.files)
|
||||
|
||||
logger.info(
|
||||
"Bootstrapped remote skill YAML for thread %s (content_id=%s, language_type=%s) to %s: dirs=%d files=%d",
|
||||
request.thread_id,
|
||||
content_id,
|
||||
request.language_type,
|
||||
target_dir,
|
||||
len(parsed.directories),
|
||||
len(parsed.files),
|
||||
)
|
||||
|
||||
return RemoteSkillBootstrapResponse(
|
||||
success=True,
|
||||
target_dir=request.target_dir,
|
||||
content_ids=request.content_ids,
|
||||
created_directories=created_directories_total,
|
||||
created_files=created_files_total,
|
||||
sandbox_id=None,
|
||||
message=(
|
||||
f"Bootstrapped {created_files_total} files and {created_directories_total} directories "
|
||||
f"for {len(request.content_ids)} skills under '{request.target_dir}'"
|
||||
),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to bootstrap skill from remote API: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to bootstrap skill from remote API: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -1,475 +0,0 @@
|
|||
"""Utilities for parsing YAML-defined skill package structures.
|
||||
|
||||
This module supports turning a YAML document describing files/directories into
|
||||
real filesystem content under a thread's virtual path (for example,
|
||||
``/mnt/user-data/uploads/skill``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import yaml # type: ignore[import-not-found]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedSkillTree:
|
||||
"""Normalized parsed structure from YAML spec."""
|
||||
|
||||
directories: set[str]
|
||||
files: dict[str, str]
|
||||
|
||||
|
||||
def _pick_first_existing(data: dict, keys: tuple[str, ...]):
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
|
||||
|
||||
|
||||
def _extract_spec_root(data: dict) -> dict:
|
||||
"""Extract the effective spec root.
|
||||
|
||||
Supports nested wrappers like:
|
||||
- skill: { ... }
|
||||
- package: { ... }
|
||||
- spec: { ... }
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("YAML root must be an object")
|
||||
|
||||
known_keys = {
|
||||
"entries",
|
||||
"files",
|
||||
"directories",
|
||||
"dirs",
|
||||
"tree",
|
||||
"structure",
|
||||
"file_tree",
|
||||
"fileTree",
|
||||
"file_structure",
|
||||
"paths",
|
||||
}
|
||||
if any(k in data for k in known_keys):
|
||||
return data
|
||||
|
||||
wrapper_candidates = ("skill", "package", "spec", "data", "content", "payload")
|
||||
for wrapper in wrapper_candidates:
|
||||
candidate = data.get(wrapper)
|
||||
if isinstance(candidate, dict) and any(k in candidate for k in known_keys):
|
||||
return candidate
|
||||
|
||||
# Fallback: if exactly one nested object exists, try it as spec root.
|
||||
nested_dicts = [v for v in data.values() if isinstance(v, dict)]
|
||||
if len(nested_dicts) == 1:
|
||||
return nested_dicts[0]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _normalize_relative_path(path: str) -> str:
|
||||
"""Normalize and validate a relative path.
|
||||
|
||||
Raises:
|
||||
ValueError: If path is unsafe or invalid.
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
raise ValueError("Path must be a string")
|
||||
|
||||
normalized = path.strip().replace("\\", "/")
|
||||
if normalized in {"/", ".", "./"}:
|
||||
return ""
|
||||
if not normalized:
|
||||
raise ValueError("Path cannot be empty")
|
||||
|
||||
if normalized.startswith("/"):
|
||||
raise ValueError(f"Path must be relative, got absolute path: {path}")
|
||||
|
||||
if ":" in normalized:
|
||||
raise ValueError(f"Path cannot contain ':' (possible drive path): {path}")
|
||||
|
||||
parts = [part for part in normalized.split("/") if part]
|
||||
if not parts:
|
||||
raise ValueError("Path cannot be empty")
|
||||
|
||||
if any(part in {".", ".."} for part in parts):
|
||||
raise ValueError(f"Path traversal is not allowed: {path}")
|
||||
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def _add_directory(path: str, directories: set[str]) -> None:
|
||||
normalized = _normalize_relative_path(path)
|
||||
if not normalized:
|
||||
return
|
||||
directories.add(normalized)
|
||||
|
||||
|
||||
def _add_file(path: str, content: str, files: dict[str, str], directories: set[str]) -> None:
|
||||
normalized = _normalize_relative_path(path)
|
||||
if not normalized:
|
||||
raise ValueError("File path cannot be root ('/')")
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(f"File content must be a string for '{normalized}'")
|
||||
|
||||
parent = Path(normalized).parent
|
||||
if str(parent) != ".":
|
||||
directories.add(str(parent).replace("\\", "/"))
|
||||
|
||||
files[normalized] = content
|
||||
|
||||
|
||||
def _walk_tree_dict(tree: dict, base: str, files: dict[str, str], directories: set[str]) -> None:
|
||||
for name, value in tree.items():
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("Tree keys must be strings")
|
||||
|
||||
if name.strip() in {"/", ".", "./"}:
|
||||
if isinstance(value, dict):
|
||||
_walk_tree_dict(value, base, files, directories)
|
||||
continue
|
||||
raise ValueError("Root sentinel '/' can only be used for directory/object nodes")
|
||||
|
||||
node_path = f"{base}/{name}" if base else name
|
||||
|
||||
if isinstance(value, dict):
|
||||
_add_directory(node_path, directories)
|
||||
_walk_tree_dict(value, _normalize_relative_path(node_path), files, directories)
|
||||
elif isinstance(value, str):
|
||||
_add_file(node_path, value, files, directories)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tree node type for '{node_path}': {type(value).__name__}. "
|
||||
"Use object (directory) or string (file content)."
|
||||
)
|
||||
|
||||
|
||||
def _parse_entries_node(
|
||||
node: dict,
|
||||
base: str,
|
||||
files: dict[str, str],
|
||||
directories: set[str],
|
||||
) -> None:
|
||||
raw_path = node.get("path")
|
||||
raw_name = node.get("name")
|
||||
|
||||
if raw_path is None and raw_name is None:
|
||||
raise ValueError("Each entry must have at least one of: 'path' or 'name'")
|
||||
|
||||
if raw_path is not None and not isinstance(raw_path, str):
|
||||
raise ValueError("Entry 'path' must be a string")
|
||||
if raw_name is not None and not isinstance(raw_name, str):
|
||||
raise ValueError("Entry 'name' must be a string")
|
||||
|
||||
# Common schema compatibility:
|
||||
# - `path` is parent directory (e.g. "/")
|
||||
# - `name` is current node name (e.g. "README.md")
|
||||
# Build parent then append name when both are present.
|
||||
parent = base
|
||||
if isinstance(raw_path, str) and raw_path.strip():
|
||||
rp = raw_path.strip()
|
||||
if rp not in {"/", ".", "./"}:
|
||||
parent = _normalize_relative_path(f"{base}/{rp}" if base else rp)
|
||||
|
||||
if isinstance(raw_name, str) and raw_name.strip():
|
||||
if parent:
|
||||
node_path = _normalize_relative_path(f"{parent}/{raw_name.strip()}")
|
||||
else:
|
||||
node_path = _normalize_relative_path(raw_name.strip())
|
||||
else:
|
||||
# Fallback: only path provided
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
raise ValueError("Each entry must have a non-empty 'path' or 'name'")
|
||||
rp = raw_path.strip()
|
||||
if rp in {"/", ".", "./"}:
|
||||
node_path = base
|
||||
else:
|
||||
node_path = _normalize_relative_path(f"{base}/{rp}" if base else rp)
|
||||
|
||||
node_type = node.get("type")
|
||||
content = node.get("content")
|
||||
children = node.get("children")
|
||||
|
||||
inferred_type = "directory" if isinstance(children, list) else "file" if content is not None else None
|
||||
final_type = node_type or inferred_type
|
||||
|
||||
if final_type == "directory":
|
||||
_add_directory(node_path, directories)
|
||||
if children is None:
|
||||
return
|
||||
if not isinstance(children, list):
|
||||
raise ValueError(f"Entry '{node_path}' children must be a list")
|
||||
for child in children:
|
||||
if not isinstance(child, dict):
|
||||
raise ValueError(f"Entry '{node_path}' children must be objects")
|
||||
_parse_entries_node(child, node_path, files, directories)
|
||||
return
|
||||
|
||||
if final_type == "file":
|
||||
if content is None:
|
||||
raise ValueError(f"File entry '{node_path}' is missing 'content'")
|
||||
_add_file(node_path, content, files, directories)
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to infer entry type for '{node_path}'. Set 'type' to 'file' or 'directory'."
|
||||
)
|
||||
|
||||
|
||||
def parse_skill_yaml_spec(yaml_text: str) -> ParsedSkillTree:
|
||||
"""Parse YAML text into normalized directories and files.
|
||||
|
||||
Supported forms:
|
||||
- entries: [{type,path/content/children}, ...]
|
||||
- files: {"path/to/file": "text"} + optional directories/dirs
|
||||
- tree/structure: nested dict where dict=directory and string=file content
|
||||
"""
|
||||
try:
|
||||
data = yaml.safe_load(yaml_text)
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML: {e}") from e
|
||||
|
||||
if data is None:
|
||||
raise ValueError("YAML is empty")
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("YAML root must be an object")
|
||||
|
||||
data = _extract_spec_root(data)
|
||||
|
||||
directories: set[str] = set()
|
||||
files: dict[str, str] = {}
|
||||
|
||||
# Form 1: explicit entries list
|
||||
entries = _pick_first_existing(data, ("entries", "nodes", "items"))
|
||||
if entries is not None:
|
||||
if not isinstance(entries, list):
|
||||
raise ValueError("'entries' must be a list")
|
||||
for entry in entries:
|
||||
if not isinstance(entry, dict):
|
||||
raise ValueError("Each item in 'entries' must be an object")
|
||||
_parse_entries_node(entry, "", files, directories)
|
||||
|
||||
# Form 2: files + directories
|
||||
file_map = _pick_first_existing(data, ("files", "paths", "file_map", "fileMap", "documents"))
|
||||
if file_map is not None:
|
||||
if isinstance(file_map, dict):
|
||||
for path, content in file_map.items():
|
||||
_add_file(path, content, files, directories)
|
||||
elif isinstance(file_map, list):
|
||||
for item in file_map:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Each item in 'files' list must be an object")
|
||||
path = item.get("path") or item.get("name") or item.get("file")
|
||||
content = item.get("content")
|
||||
if content is None:
|
||||
content = item.get("text")
|
||||
if content is None:
|
||||
content = item.get("body")
|
||||
if path is None or content is None:
|
||||
raise ValueError("Each file item needs 'path' and 'content'")
|
||||
_add_file(path, content, files, directories)
|
||||
else:
|
||||
raise ValueError("'files' must be a map or list")
|
||||
|
||||
directory_list = _pick_first_existing(data, ("directories", "dirs", "folders", "folder_paths"))
|
||||
if directory_list is not None:
|
||||
if not isinstance(directory_list, list):
|
||||
raise ValueError("'directories'/'dirs' must be a list")
|
||||
for path in directory_list:
|
||||
_add_directory(path, directories)
|
||||
|
||||
# Form 3: nested tree
|
||||
tree = _pick_first_existing(data, ("tree", "structure", "file_tree", "fileTree", "file_structure"))
|
||||
if tree is not None:
|
||||
if isinstance(tree, dict):
|
||||
_walk_tree_dict(tree, "", files, directories)
|
||||
elif isinstance(tree, list):
|
||||
for item in tree:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Items in 'tree' list must be objects")
|
||||
_parse_entries_node(item, "", files, directories)
|
||||
else:
|
||||
raise ValueError("'tree'/'structure' must be an object or list")
|
||||
|
||||
# Heuristic fallback: treat root as path->content map when possible.
|
||||
if not files and not directories:
|
||||
candidate_keys = [k for k in data.keys() if isinstance(k, str)]
|
||||
if candidate_keys and all(isinstance(data[k], str) for k in candidate_keys):
|
||||
for path, content in data.items():
|
||||
_add_file(path, content, files, directories)
|
||||
|
||||
if not files and not directories:
|
||||
raise ValueError(
|
||||
"No content found. Provide at least one of: entries, files, directories/dirs, tree/structure"
|
||||
)
|
||||
|
||||
# Ensure parent directories exist for every file
|
||||
for rel_file in files:
|
||||
parent = Path(rel_file).parent
|
||||
if str(parent) != ".":
|
||||
directories.add(str(parent).replace("\\", "/"))
|
||||
|
||||
return ParsedSkillTree(directories=directories, files=files)
|
||||
|
||||
|
||||
def materialize_skill_tree(parsed: ParsedSkillTree, target_root: Path, clear_target: bool = True) -> None:
|
||||
"""Create parsed directories/files under target root."""
|
||||
if clear_target and target_root.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(target_root)
|
||||
|
||||
target_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for rel_dir in sorted(parsed.directories, key=lambda p: (p.count("/"), p)):
|
||||
(target_root / rel_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for rel_file, content in parsed.files.items():
|
||||
file_path = target_root / rel_file
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def _build_cli_parser() -> argparse.ArgumentParser:
|
||||
"""Build command-line argument parser.
|
||||
|
||||
CLI usage:
|
||||
python skill_yaml_importer.py <input_path> [options]
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Parse and validate a skill YAML spec file")
|
||||
parser.add_argument("input_path", help="Path to a YAML file or a directory containing YAML files")
|
||||
parser.add_argument(
|
||||
"--show-files",
|
||||
action="store_true",
|
||||
help="Print sorted parsed file paths",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-directories",
|
||||
action="store_true",
|
||||
help="Print sorted parsed directory paths",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print parsed summary as JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
help="When input path is a directory, scan YAML files recursively",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
default=None,
|
||||
help="Optional path to save full execution results and summary as JSON",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _collect_yaml_files(input_path: Path, recursive: bool) -> list[Path]:
|
||||
if input_path.is_file():
|
||||
return [input_path]
|
||||
|
||||
if not input_path.is_dir():
|
||||
return []
|
||||
|
||||
patterns = ("*.yaml", "*.yml")
|
||||
files: list[Path] = []
|
||||
for pattern in patterns:
|
||||
iterator = input_path.rglob(pattern) if recursive else input_path.glob(pattern)
|
||||
files.extend(iterator)
|
||||
|
||||
# Stable order for deterministic output
|
||||
return sorted({p.resolve() for p in files})
|
||||
|
||||
|
||||
def _parse_one_yaml_file(yaml_path: Path, show_files: bool, show_directories: bool) -> dict:
|
||||
yaml_text = yaml_path.read_text(encoding="utf-8")
|
||||
parsed = parse_skill_yaml_spec(yaml_text)
|
||||
directories = sorted(parsed.directories)
|
||||
files = sorted(parsed.files.keys())
|
||||
|
||||
return {
|
||||
"yaml_file": str(yaml_path),
|
||||
"directories_count": len(directories),
|
||||
"files_count": len(files),
|
||||
"directories": directories if show_directories else None,
|
||||
"files": files if show_files else None,
|
||||
}
|
||||
|
||||
|
||||
def _main() -> int:
|
||||
"""CLI entrypoint for parsing one YAML file or a batch of YAML files.
|
||||
|
||||
Exit codes:
|
||||
0: all files parsed successfully
|
||||
1: invalid input path or no YAML files found
|
||||
2: processed completed with one or more parse failures
|
||||
"""
|
||||
args = _build_cli_parser().parse_args()
|
||||
|
||||
input_path = Path(args.input_path)
|
||||
if not input_path.exists():
|
||||
print(f"Input path not found: {input_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
yaml_files = _collect_yaml_files(input_path, recursive=args.recursive)
|
||||
if not yaml_files:
|
||||
print(f"No YAML files found under: {input_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
successes: list[dict] = []
|
||||
failures: list[dict[str, str]] = []
|
||||
|
||||
for yaml_path in yaml_files:
|
||||
try:
|
||||
result = _parse_one_yaml_file(
|
||||
yaml_path,
|
||||
show_files=args.show_files,
|
||||
show_directories=args.show_directories,
|
||||
)
|
||||
successes.append(result)
|
||||
if not args.json:
|
||||
print(f"OK: {yaml_path}")
|
||||
print(f" Directories: {result['directories_count']}")
|
||||
print(f" Files: {result['files_count']}")
|
||||
except Exception as e: # noqa: BLE001
|
||||
failures.append({"yaml_file": str(yaml_path), "error": str(e)})
|
||||
print(f"ERROR: {yaml_path}: {e}", file=sys.stderr)
|
||||
|
||||
summary = {
|
||||
"input_path": str(input_path),
|
||||
"total": len(yaml_files),
|
||||
"success": len(successes),
|
||||
"failed": len(failures),
|
||||
}
|
||||
|
||||
report = {"summary": summary, "successes": successes, "failures": failures}
|
||||
|
||||
if args.log_file:
|
||||
try:
|
||||
log_path = Path(args.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
log_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"Log saved: {log_path}")
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"Failed to write log file '{args.log_file}': {e}", file=sys.stderr)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print("\n[Summary]")
|
||||
print(f"Input: {summary['input_path']}")
|
||||
print(f"Total: {summary['total']}")
|
||||
print(f"Success: {summary['success']}")
|
||||
print(f"Failed: {summary['failed']}")
|
||||
|
||||
return 0 if not failures else 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(_main())
|
||||
|
|
@ -394,17 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
|||
Returns the <skill_system>...</skill_system> block listing all enabled skills,
|
||||
suitable for injection into any agent's system prompt.
|
||||
"""
|
||||
thread_id = None
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Keep regular enable/disable behavior while letting uploads be default-enabled in loader.
|
||||
skills = load_skills(enabled_only=True, thread_id=thread_id)
|
||||
skills = _get_enabled_skills()
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
|
@ -571,6 +561,4 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
|
|||
acp_section=acp_and_mounts_section,
|
||||
)
|
||||
|
||||
logger.debug("Generated full system prompt:\n%s", prompt)
|
||||
|
||||
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class ConversationContext:
|
|||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
|
|
@ -45,7 +44,6 @@ class MemoryUpdateQueue:
|
|||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
|
|
@ -54,7 +52,6 @@ class MemoryUpdateQueue:
|
|||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
|
|
@ -66,13 +63,11 @@ class MemoryUpdateQueue:
|
|||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
|
|
@ -135,7 +130,6 @@ class MemoryUpdateQueue:
|
|||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
|
|||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
return stripped.casefold()
|
||||
return stripped
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
|
|
@ -272,7 +272,6 @@ class MemoryUpdater:
|
|||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
|
|
@ -281,7 +280,6 @@ class MemoryUpdater:
|
|||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
|
|
@ -312,14 +310,6 @@ class MemoryUpdater:
|
|||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
|
|
@ -451,7 +441,6 @@ def update_memory_from_conversation(
|
|||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
|
|
@ -460,10 +449,9 @@ def update_memory_from_conversation(
|
|||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
|
|
|
|||
|
|
@ -29,22 +29,6 @@ _CORRECTION_PATTERNS = (
|
|||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
|
@ -148,29 +132,6 @@ def detect_correction(messages: list[Any]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
|
|
@ -235,14 +196,12 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -192,8 +192,8 @@ class ExtensionsConfig(BaseModel):
|
|||
"""
|
||||
skill_config = self.skills.get(skill_name)
|
||||
if skill_config is None:
|
||||
# Default to enable for public/custom/uploads skills.
|
||||
return skill_category in ("public", "custom", "uploads")
|
||||
# Default to enable for public & custom skill
|
||||
return skill_category in ("public", "custom")
|
||||
return skill_config.enabled
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -366,17 +366,12 @@ def _path_variants(path: str) -> set[str]:
|
|||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
|
||||
def _path_separator_for_style(path: str) -> str:
|
||||
return "\\" if "\\" in path and "/" not in path else "/"
|
||||
|
||||
|
||||
def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
if not relative:
|
||||
return base
|
||||
separator = _path_separator_for_style(base)
|
||||
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
|
||||
stripped_base = base.rstrip("/\\")
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
if "/" in base and "\\" not in base:
|
||||
return f"{base.rstrip('/')}/{relative}"
|
||||
return str(Path(base) / relative)
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
|
|
@ -421,10 +416,7 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
|||
return actual_base
|
||||
if path.startswith(f"{virtual_base}/"):
|
||||
rest = path[len(virtual_base) :].lstrip("/")
|
||||
result = _join_path_preserving_style(actual_base, rest)
|
||||
if path.endswith("/") and not result.endswith(("/", "\\")):
|
||||
result += _path_separator_for_style(actual_base)
|
||||
return result
|
||||
return _join_path_preserving_style(actual_base, rest)
|
||||
|
||||
return path
|
||||
|
||||
|
|
|
|||
|
|
@ -8,27 +8,6 @@ from .types import Skill
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPLOADS_SKILLS_PATH = Path("/mnt/user-data/uploads")
|
||||
|
||||
|
||||
def get_uploads_skills_path(thread_id: str | None = None) -> Path:
|
||||
"""Resolve the uploads skills root for the current execution context.
|
||||
|
||||
When called from the LangGraph process, uploaded skills live under the
|
||||
host-side per-thread data directory rather than the sandbox mount path.
|
||||
"""
|
||||
if not thread_id:
|
||||
return UPLOADS_SKILLS_PATH
|
||||
|
||||
try:
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
return get_paths().sandbox_uploads_dir(thread_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to resolve uploads skills path for thread %s: %s", thread_id, exc)
|
||||
return UPLOADS_SKILLS_PATH
|
||||
|
||||
|
||||
def get_skills_root_path() -> Path:
|
||||
"""
|
||||
Get the root path of the skills directory.
|
||||
|
|
@ -43,19 +22,12 @@ def get_skills_root_path() -> Path:
|
|||
return skills_dir
|
||||
|
||||
|
||||
def load_skills(
|
||||
skills_path: Path | None = None,
|
||||
use_config: bool = True,
|
||||
enabled_only: bool = False,
|
||||
thread_id: str | None = None,
|
||||
) -> list[Skill]:
|
||||
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
|
||||
"""
|
||||
Load all skills from the skills directory.
|
||||
|
||||
Scans public/custom skill directories under the configured skills root,
|
||||
and also scans uploaded skills under /mnt/user-data/uploads.
|
||||
SKILL.md metadata is parsed and enabled state is derived from
|
||||
skills_state_config.json.
|
||||
Scans both public and custom skill directories, parsing SKILL.md files
|
||||
to extract metadata. The enabled state is determined by the skills_state_config.json file.
|
||||
|
||||
Args:
|
||||
skills_path: Optional custom path to skills directory.
|
||||
|
|
@ -63,8 +35,6 @@ def load_skills(
|
|||
Otherwise defaults to deer-flow/skills
|
||||
use_config: Whether to load skills path from config (default: True)
|
||||
enabled_only: If True, only return enabled skills (default: False)
|
||||
thread_id: Optional thread ID used to resolve per-thread uploads skills
|
||||
from the LangGraph host process
|
||||
|
||||
Returns:
|
||||
List of Skill objects, sorted by name
|
||||
|
|
@ -87,22 +57,12 @@ def load_skills(
|
|||
|
||||
skills = []
|
||||
|
||||
# Scan built-in roots and uploaded skills mounted in personal workspace.
|
||||
scan_targets: list[tuple[str, Path]] = [
|
||||
("public", skills_path / "public"),
|
||||
("custom", skills_path / "custom"),
|
||||
("uploads", get_uploads_skills_path(thread_id)),
|
||||
]
|
||||
|
||||
for category, category_path in scan_targets:
|
||||
logger.debug("Scanning %s skills under %s", category, category_path)
|
||||
|
||||
# Scan public and custom directories
|
||||
for category in ["public", "custom"]:
|
||||
category_path = skills_path / category
|
||||
if not category_path.exists() or not category_path.is_dir():
|
||||
logger.debug("Skip %s scan: directory not found or not a directory (%s)", category, category_path)
|
||||
continue
|
||||
|
||||
scanned_skill_dirs: list[str] = []
|
||||
|
||||
for current_root, dir_names, file_names in os.walk(category_path, followlinks=True):
|
||||
# Keep traversal deterministic and skip hidden directories.
|
||||
dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
|
||||
|
|
@ -111,22 +71,11 @@ def load_skills(
|
|||
|
||||
skill_file = Path(current_root) / "SKILL.md"
|
||||
relative_path = skill_file.parent.relative_to(category_path)
|
||||
scanned_skill_dirs.append(relative_path.as_posix())
|
||||
|
||||
skill = parse_skill_file(skill_file, category=category, relative_path=relative_path)
|
||||
if skill:
|
||||
skills.append(skill)
|
||||
|
||||
if scanned_skill_dirs:
|
||||
logger.debug(
|
||||
"%s scan found %d skill directories: %s",
|
||||
category,
|
||||
len(scanned_skill_dirs),
|
||||
", ".join(sorted(scanned_skill_dirs)),
|
||||
)
|
||||
else:
|
||||
logger.debug("%s scan found no skill directories", category)
|
||||
|
||||
# Load skills state configuration and update enabled status
|
||||
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
|
||||
# to always read the latest configuration from disk. This ensures that changes
|
||||
|
|
@ -137,10 +86,6 @@ def load_skills(
|
|||
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
for skill in skills:
|
||||
if skill.category == "uploads":
|
||||
# Uploaded skills should be available by default for the current thread.
|
||||
skill.enabled = True
|
||||
continue
|
||||
skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
|
||||
except Exception as e:
|
||||
# If config loading fails, default to all enabled
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
|
|||
|
||||
Args:
|
||||
skill_file: Path to the SKILL.md file
|
||||
category: Category of the skill ('public', 'custom', or 'uploads')
|
||||
category: Category of the skill ('public' or 'custom')
|
||||
|
||||
Returns:
|
||||
Skill object if parsing succeeds, None otherwise
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class Skill:
|
|||
skill_dir: Path
|
||||
skill_file: Path
|
||||
relative_path: Path # Relative path from category root to skill directory
|
||||
category: str # 'public', 'custom', or 'uploads'
|
||||
category: str # 'public' or 'custom'
|
||||
enabled: bool = False # Whether this skill is enabled
|
||||
|
||||
@property
|
||||
|
|
@ -31,10 +31,7 @@ class Skill:
|
|||
Returns:
|
||||
Full container path to the skill directory
|
||||
"""
|
||||
if self.category == "uploads":
|
||||
category_base = "/mnt/user-data/uploads"
|
||||
else:
|
||||
category_base = f"{container_base_path}/{self.category}"
|
||||
category_base = f"{container_base_path}/{self.category}"
|
||||
skill_path = self.skill_path
|
||||
if skill_path:
|
||||
return f"{category_base}/{skill_path}"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import json
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -1988,47 +1988,6 @@ class TestSlackSendRetry:
|
|||
|
||||
_run(go())
|
||||
|
||||
|
||||
class TestSlackAllowedUsers:
|
||||
def test_numeric_allowed_users_match_string_event_user_id(self):
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
bus = MessageBus()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = SlackChannel(
|
||||
bus=bus,
|
||||
config={"allowed_users": [123456]},
|
||||
)
|
||||
channel._loop = MagicMock()
|
||||
channel._loop.is_running.return_value = True
|
||||
channel._add_reaction = MagicMock()
|
||||
channel._send_running_reply = MagicMock()
|
||||
|
||||
event = {
|
||||
"user": "123456",
|
||||
"text": "hello from slack",
|
||||
"channel": "C123",
|
||||
"ts": "1710000000.000100",
|
||||
}
|
||||
|
||||
def submit_coro(coro, loop):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
with patch(
|
||||
"app.channels.slack.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=submit_coro,
|
||||
) as submit:
|
||||
channel._handle_message_event(event)
|
||||
|
||||
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
|
||||
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
|
||||
submit.assert_called_once()
|
||||
inbound = bus.publish_inbound.call_args.args[0]
|
||||
assert inbound.user_id == "123456"
|
||||
assert inbound.chat_id == "C123"
|
||||
assert inbound.text == "hello from slack"
|
||||
|
||||
def test_raises_after_all_retries_exhausted(self):
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
|
|
|
|||
|
|
@ -47,45 +47,4 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
|||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
|
||||
|
||||
assert len(queue._queue) == 1
|
||||
assert queue._queue[0].messages == ["second"]
|
||||
assert queue._queue[0].reinforcement_detected is True
|
||||
|
||||
|
||||
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
messages=["conversation"],
|
||||
agent_name="lead_agent",
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
]
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
queue._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once_with(
|
||||
messages=["conversation"],
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -619,156 +619,3 @@ class TestUpdateMemoryStructuredResponse:
|
|||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
|
||||
def test_duplicate_fact_different_case_not_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
# Same fact with different casing should be treated as duplicate
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
# Should still have only 1 fact (duplicate rejected)
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["content"] == "User prefers Python"
|
||||
|
||||
def test_unique_fact_different_case_and_content_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert len(result["facts"]) == 2
|
||||
|
||||
|
||||
class TestReinforcementHint:
|
||||
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_model(json_response: str):
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Yes, exactly! That's what I needed."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Great to hear!"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Tell me more."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Sure."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "No wait, that's wrong. Actually yes, exactly right."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Got it."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ persisting in long-term memory:
|
|||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
|
@ -270,73 +270,3 @@ class TestStripUploadMentionsFromMemory:
|
|||
mem = {"user": {}, "history": {}, "facts": []}
|
||||
result = _strip_upload_mentions_from_memory(mem)
|
||||
assert result == {"user": {}, "history": {}, "facts": []}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# detect_reinforcement
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDetectReinforcement:
|
||||
def test_detects_english_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("Can you summarise it in bullet points?"),
|
||||
_ai("Here are the key points: ..."),
|
||||
_human("Yes, exactly! That's what I needed."),
|
||||
_ai("Glad it helped."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_perfect_signal(self):
|
||||
msgs = [
|
||||
_human("Write it more concisely."),
|
||||
_ai("Here is the concise version."),
|
||||
_human("Perfect."),
|
||||
_ai("Great!"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_chinese_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("帮我用要点来总结"),
|
||||
_ai("好的,要点如下:..."),
|
||||
_human("完全正确,就是这个意思"),
|
||||
_ai("很高兴能帮到你"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_returns_false_without_signal(self):
|
||||
msgs = [
|
||||
_human("What does this function do?"),
|
||||
_ai("It processes the input data."),
|
||||
_human("Can you show me an example?"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_only_checks_recent_messages(self):
|
||||
# Reinforcement signal buried beyond the -6 window should not trigger
|
||||
msgs = [
|
||||
_human("Yes, exactly right."),
|
||||
_ai("Noted."),
|
||||
_human("Let's discuss tests."),
|
||||
_ai("Sure."),
|
||||
_human("What about linting?"),
|
||||
_ai("Use ruff."),
|
||||
_human("And formatting?"),
|
||||
_ai("Use make format."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_does_not_conflict_with_correction(self):
|
||||
# A message can trigger correction but not reinforcement
|
||||
msgs = [
|
||||
_human("That's wrong, try again."),
|
||||
_ai("Corrected."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
|
|
|||
|
|
@ -42,53 +42,6 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
|
|||
assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash must survive virtual-to-actual path replacement.
|
||||
|
||||
Regression: '/mnt/user-data/workspace/' was previously returned without
|
||||
the trailing slash, causing string concatenations like
|
||||
output_dir + 'file.txt' to produce a missing-separator path.
|
||||
"""
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
|
||||
assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
|
||||
assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
|
||||
"""Trailing slash must be preserved as backslash when actual_base is Windows-style.
|
||||
|
||||
If actual_base uses backslash separators, appending '/' would produce a
|
||||
mixed-separator path. The separator must match the style of actual_base.
|
||||
"""
|
||||
win_thread_data = {
|
||||
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
|
||||
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
|
||||
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
|
||||
}
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
|
||||
assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
|
||||
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
|
||||
"""Nested Windows-style subdirectories must keep backslashes throughout."""
|
||||
win_thread_data = {
|
||||
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
|
||||
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
|
||||
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
|
||||
}
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
|
||||
assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
|
||||
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
|
||||
|
||||
|
||||
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash on a virtual path inside a command must be preserved."""
|
||||
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
|
||||
|
||||
|
||||
# ---------- mask_local_paths_in_output ----------
|
||||
|
||||
|
||||
|
|
@ -304,22 +257,6 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
|
|||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
|
||||
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
|
||||
"""HTTP URLs must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
|
||||
validate_local_bash_command_paths(
|
||||
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
||||
import { ArtifactTrigger } from "@/components/workspace/artifacts";
|
||||
|
|
@ -25,35 +24,15 @@ import { Welcome } from "@/components/workspace/welcome";
|
|||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { useNotification } from "@/core/notification/hooks";
|
||||
import { useThreadSettings } from "@/core/settings";
|
||||
import { bootstrapRemoteSkill } from "@/core/skills";
|
||||
import { useThreadStream } from "@/core/threads/hooks";
|
||||
import { textOfMessage } from "@/core/threads/utils";
|
||||
import { uuid } from "@/core/utils/uuid";
|
||||
import { env } from "@/env";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const UUID_REGEX =
|
||||
/^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;
|
||||
|
||||
export default function ChatPage() {
|
||||
const { t } = useI18n();
|
||||
const [showFollowups, setShowFollowups] = useState(false);
|
||||
const searchParams = useSearchParams();
|
||||
const generatedThreadIdRef = useRef<string>("");
|
||||
if (!generatedThreadIdRef.current) {
|
||||
const queryThreadId = searchParams.get("thread_id")?.trim();
|
||||
generatedThreadIdRef.current =
|
||||
queryThreadId && UUID_REGEX.test(queryThreadId) ? queryThreadId : uuid();
|
||||
}
|
||||
|
||||
// 检查 xclaw_used 参数,仅用于界面风格控制,不影响线程创建逻辑
|
||||
const xclawUsedParam = searchParams.get("xclaw_used");
|
||||
const initialForceNewStyle = xclawUsedParam === "false";
|
||||
const [forceNewStyle, setForceNewStyle] = useState(initialForceNewStyle);
|
||||
|
||||
const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat({
|
||||
newThreadId: generatedThreadIdRef.current,
|
||||
});
|
||||
const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat();
|
||||
const [settings, setSettings] = useThreadSettings(threadId);
|
||||
const [mounted, setMounted] = useState(false);
|
||||
useSpecificChatMode();
|
||||
|
|
@ -63,34 +42,6 @@ export default function ChatPage() {
|
|||
}, []);
|
||||
|
||||
const { showNotification } = useNotification();
|
||||
const skillBootstrappedKeysRef = useRef<Set<string>>(new Set());
|
||||
const skillBootstrappingKeysRef = useRef<Set<string>>(new Set());
|
||||
|
||||
const skillBootstrap = useMemo(() => {
|
||||
const skillIdRaw = searchParams.get("skill_id")?.trim();
|
||||
if (!skillIdRaw) return undefined;
|
||||
|
||||
const contentIds = skillIdRaw
|
||||
.split(",")
|
||||
.map((value) => value.trim())
|
||||
.filter((value) => value.length > 0)
|
||||
.map((value) => Number(value))
|
||||
.filter((value) => Number.isFinite(value));
|
||||
|
||||
// Deduplicate while preserving incoming order.
|
||||
const uniqueContentIds = Array.from(new Set(contentIds));
|
||||
if (uniqueContentIds.length === 0) return undefined;
|
||||
|
||||
const languageTypeRaw =
|
||||
searchParams.get("languageType")?.trim() ??
|
||||
searchParams.get("language_type")?.trim();
|
||||
const languageType = languageTypeRaw ? Number(languageTypeRaw) : 0;
|
||||
|
||||
return {
|
||||
contentIds: uniqueContentIds,
|
||||
languageType: Number.isFinite(languageType) ? languageType : 0,
|
||||
};
|
||||
}, [searchParams]);
|
||||
|
||||
const [thread, sendMessage, isUploading] = useThreadStream({
|
||||
threadId: isNewThread ? undefined : threadId,
|
||||
|
|
@ -119,54 +70,11 @@ export default function ChatPage() {
|
|||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!threadId || !skillBootstrap?.contentIds?.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const languageType = skillBootstrap.languageType ?? 0;
|
||||
const initKey = `${threadId}:${skillBootstrap.contentIds.join(",")}:${languageType}`;
|
||||
if (
|
||||
skillBootstrappedKeysRef.current.has(initKey) ||
|
||||
skillBootstrappingKeysRef.current.has(initKey)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
skillBootstrappingKeysRef.current.add(initKey);
|
||||
|
||||
const runBootstrap = async () => {
|
||||
try {
|
||||
await bootstrapRemoteSkill({
|
||||
thread_id: threadId,
|
||||
content_ids: skillBootstrap.contentIds,
|
||||
language_type: languageType,
|
||||
target_dir: "/mnt/user-data/uploads/skill",
|
||||
clear_target: true,
|
||||
});
|
||||
|
||||
skillBootstrappedKeysRef.current.add(initKey);
|
||||
} catch (error) {
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Skill initialization failed";
|
||||
showNotification("Skill initialization failed", { body: message });
|
||||
} finally {
|
||||
skillBootstrappingKeysRef.current.delete(initKey);
|
||||
}
|
||||
};
|
||||
|
||||
void runBootstrap();
|
||||
}, [threadId, skillBootstrap, showNotification]);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
void sendMessage(threadId, message);
|
||||
// 仅切换界面风格,不影响线程状态
|
||||
if (forceNewStyle) {
|
||||
setForceNewStyle(false);
|
||||
}
|
||||
},
|
||||
[sendMessage, threadId, forceNewStyle],
|
||||
[sendMessage, threadId],
|
||||
);
|
||||
const handleStop = useCallback(async () => {
|
||||
await thread.stop();
|
||||
|
|
@ -184,7 +92,7 @@ export default function ChatPage() {
|
|||
<header
|
||||
className={cn(
|
||||
"absolute top-0 right-0 left-0 z-30 flex h-12 shrink-0 items-center px-4",
|
||||
(forceNewStyle || isNewThread)
|
||||
isNewThread
|
||||
? "bg-background/0 backdrop-blur-none"
|
||||
: "bg-background/80 shadow-xs backdrop-blur",
|
||||
)}
|
||||
|
|
@ -200,22 +108,19 @@ export default function ChatPage() {
|
|||
</header>
|
||||
<main className="flex min-h-0 max-w-full grow flex-col">
|
||||
<div className="flex size-full justify-center">
|
||||
{/* forceNewStyle 时隐藏消息列表,提交后再显示 */}
|
||||
{!(forceNewStyle) && (
|
||||
<MessageList
|
||||
className={cn("size-full", !isNewThread && "pt-10")}
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
/>
|
||||
)}
|
||||
<MessageList
|
||||
className={cn("size-full", !isNewThread && "pt-10")}
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
/>
|
||||
</div>
|
||||
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
|
||||
<div
|
||||
className={cn(
|
||||
"relative w-full",
|
||||
(forceNewStyle || isNewThread) && "-translate-y-[calc(50vh-96px)]",
|
||||
(forceNewStyle || isNewThread)
|
||||
isNewThread && "-translate-y-[calc(50vh-96px)]",
|
||||
isNewThread
|
||||
? "max-w-(--container-width-sm)"
|
||||
: "max-w-(--container-width-md)",
|
||||
)}
|
||||
|
|
@ -234,9 +139,9 @@ export default function ChatPage() {
|
|||
{mounted ? (
|
||||
<InputBox
|
||||
className={cn("bg-background/5 w-full -translate-y-4")}
|
||||
isNewThread={forceNewStyle || isNewThread}
|
||||
isNewThread={isNewThread}
|
||||
threadId={threadId}
|
||||
autoFocus={forceNewStyle || isNewThread}
|
||||
autoFocus={isNewThread}
|
||||
status={
|
||||
thread.error
|
||||
? "error"
|
||||
|
|
@ -246,7 +151,7 @@ export default function ChatPage() {
|
|||
}
|
||||
context={settings.context}
|
||||
extraHeader={
|
||||
(forceNewStyle || isNewThread) && <Welcome mode={settings.context.mode} />
|
||||
isNewThread && <Welcome mode={settings.context.mode} />
|
||||
}
|
||||
disabled={
|
||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,17 @@
|
|||
"use client";
|
||||
|
||||
import { useParams, usePathname, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { uuid } from "@/core/utils/uuid";
|
||||
|
||||
type UseThreadChatOptions = {
|
||||
newThreadId?: string;
|
||||
};
|
||||
|
||||
export function useThreadChat(options?: UseThreadChatOptions) {
|
||||
export function useThreadChat() {
|
||||
const { thread_id: threadIdFromPath } = useParams<{ thread_id: string }>();
|
||||
const pathname = usePathname();
|
||||
const fallbackNewThreadIdRef = useRef<string>(options?.newThreadId ?? uuid());
|
||||
const fallbackNewThreadId = options?.newThreadId ?? fallbackNewThreadIdRef.current;
|
||||
|
||||
const searchParams = useSearchParams();
|
||||
const [threadId, setThreadId] = useState(() => {
|
||||
return threadIdFromPath === "new" ? fallbackNewThreadId : threadIdFromPath;
|
||||
return threadIdFromPath === "new" ? uuid() : threadIdFromPath;
|
||||
});
|
||||
|
||||
const [isNewThread, setIsNewThread] = useState(
|
||||
|
|
@ -27,9 +21,9 @@ export function useThreadChat(options?: UseThreadChatOptions) {
|
|||
useEffect(() => {
|
||||
if (pathname.endsWith("/new")) {
|
||||
setIsNewThread(true);
|
||||
setThreadId(fallbackNewThreadId);
|
||||
setThreadId(uuid());
|
||||
}
|
||||
}, [pathname, fallbackNewThreadId]);
|
||||
}, [pathname]);
|
||||
const isMock = searchParams.get("mock") === "true";
|
||||
return { threadId, isNewThread, setIsNewThread, isMock };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,10 +32,14 @@ import { SettingsDialog } from "./settings";
|
|||
export function CommandPalette() {
|
||||
const { t } = useI18n();
|
||||
const router = useRouter();
|
||||
const [mounted, setMounted] = useState(false);
|
||||
const [open, setOpen] = useState(false);
|
||||
const [shortcutsOpen, setShortcutsOpen] = useState(false);
|
||||
const [settingsOpen, setSettingsOpen] = useState(false);
|
||||
const [isMac, setIsMac] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setMounted(true);
|
||||
}, []);
|
||||
|
||||
const handleNewChat = useCallback(() => {
|
||||
router.push("/workspace/chats/new");
|
||||
|
|
@ -64,12 +68,14 @@ export function CommandPalette() {
|
|||
|
||||
useGlobalShortcuts(shortcuts);
|
||||
|
||||
useEffect(() => {
|
||||
setIsMac(navigator.userAgent.includes("Mac"));
|
||||
}, []);
|
||||
const isMac = mounted && navigator.userAgent.includes("Mac");
|
||||
const metaKey = isMac ? "⌘" : "Ctrl+";
|
||||
const shiftKey = isMac ? "⇧" : "Shift+";
|
||||
|
||||
if (!mounted) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsDialog open={settingsOpen} onOpenChange={setSettingsOpen} />
|
||||
|
|
|
|||
|
|
@ -35,24 +35,6 @@ export interface InstallSkillResponse {
|
|||
message: string;
|
||||
}
|
||||
|
||||
export interface BootstrapRemoteSkillRequest {
|
||||
thread_id: string;
|
||||
content_ids: number[];
|
||||
language_type?: number;
|
||||
target_dir?: string;
|
||||
clear_target?: boolean;
|
||||
}
|
||||
|
||||
export interface BootstrapRemoteSkillResponse {
|
||||
success: boolean;
|
||||
target_dir: string;
|
||||
content_ids: number[];
|
||||
created_directories: number;
|
||||
created_files: number;
|
||||
sandbox_id: string | null;
|
||||
message: string;
|
||||
}
|
||||
|
||||
export async function installSkill(
|
||||
request: InstallSkillRequest,
|
||||
): Promise<InstallSkillResponse> {
|
||||
|
|
@ -78,27 +60,3 @@ export async function installSkill(
|
|||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
export async function bootstrapRemoteSkill(
|
||||
request: BootstrapRemoteSkillRequest,
|
||||
): Promise<BootstrapRemoteSkillResponse> {
|
||||
const response = await fetch(
|
||||
`${getBackendBaseURL()}/api/skills/bootstrap-remote`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
},
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
const errorMessage =
|
||||
errorData.detail ?? `HTTP ${response.status}: ${response.statusText}`;
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue