Compare commits

10 Commits: 0ffe5a73c1 ... 6b900ccb60

| Author | SHA1 | Date |
|---|---|---|
| | 6b900ccb60 | |
| | 0f607441c8 | |
| | b412b5193b | |
| | 5de7a2ab46 | |
| | 5aa38ee108 | |
| | 66bdc951f8 | |
| | 117fa9b05d | |
| | 28474c47cb | |
| | 8049785de6 | |
| | 9ca68ffaaa | |
```diff
@@ -30,7 +30,7 @@ class SlackChannel(Channel):
         self._socket_client = None
         self._web_client = None
         self._loop: asyncio.AbstractEventLoop | None = None
-        self._allowed_users: set[str] = set(config.get("allowed_users", []))
+        self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}

    async def start(self) -> None:
        if self._running:
```
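The change above widens allowed-user matching: Slack delivers event user IDs as strings, so a numeric `allowed_users` entry in YAML config would never match the old raw `set(...)`. A minimal sketch of the normalization, with made-up IDs:

```python
# Made-up config values; mirrors the set comprehension in the hunk above.
config = {"allowed_users": [123456, "U07AB12CD"]}
allowed_users = {str(user_id) for user_id in config.get("allowed_users", [])}

event_user = "123456"  # Slack event payloads carry user IDs as strings
assert event_user in allowed_users  # the numeric config entry now matches
```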
@ -23,9 +23,11 @@ from app.gateway.routers import (
|
||||||
)
|
)
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging with env override
|
||||||
|
import os
|
||||||
|
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=getattr(logging, log_level, logging.INFO),
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
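The logging change reads `LOG_LEVEL` from the environment and degrades to INFO for unknown names, because `getattr` with a default never raises. A small standalone check:

```python
import logging
import os

os.environ["LOG_LEVEL"] = "NOT_A_LEVEL"  # hypothetical misconfigured value
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
level = getattr(logging, log_level, logging.INFO)
assert level == logging.INFO  # unknown names fall back to INFO rather than crashing
```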
```diff
@@ -9,6 +9,10 @@ class GatewayConfig(BaseModel):
     host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
     port: int = Field(default=8001, description="Port to bind the gateway server")
     cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
+    skill_content_api_url: str = Field(
+        default="https://skills.xueai.art/api/cmsContent/getContent",
+        description="Remote API URL used to fetch skill YAML content by content ID",
+    )


 _gateway_config: GatewayConfig | None = None
@@ -23,5 +27,9 @@ def get_gateway_config() -> GatewayConfig:
         host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
         port=int(os.getenv("GATEWAY_PORT", "8001")),
         cors_origins=cors_origins_str.split(","),
+        skill_content_api_url=os.getenv(
+            "SKILL_CONTENT_API_URL",
+            "https://skills.xueai.art/api/cmsContent/getContent",
+        ),
     )
     return _gateway_config
```
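The new `skill_content_api_url` field follows the same env-then-default pattern as the other gateway settings; a quick sketch of the resolution order, using a made-up staging URL:

```python
import os

# Unset env var: the hard-coded default wins.
os.environ.pop("SKILL_CONTENT_API_URL", None)
url = os.getenv("SKILL_CONTENT_API_URL", "https://skills.xueai.art/api/cmsContent/getContent")
assert url.endswith("/cmsContent/getContent")

# Set env var: the override wins.
os.environ["SKILL_CONTENT_API_URL"] = "https://staging.example.test/api/getContent"
assert os.getenv("SKILL_CONTENT_API_URL", url) == "https://staging.example.test/api/getContent"
```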
```diff
@@ -1,11 +1,15 @@
 import json
 import logging
+import shutil
 from pathlib import Path

+import httpx
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel, Field

+from app.gateway.config import get_gateway_config
 from app.gateway.path_utils import resolve_thread_virtual_path
+from app.gateway.skill_yaml_importer import materialize_skill_tree, parse_skill_yaml_spec
 from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
 from deerflow.skills import Skill, load_skills
 from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
@@ -52,6 +56,38 @@ class SkillInstallResponse(BaseModel):
     message: str = Field(..., description="Installation result message")


+class RemoteSkillBootstrapRequest(BaseModel):
+    """Request model for bootstrapping skill files from remote content API."""
+
+    thread_id: str = Field(..., description="Thread ID used for user-data path binding")
+    content_ids: list[int] = Field(
+        ...,
+        min_length=1,
+        description="Remote content ID sequence (maps from frontend query param skill_id)",
+    )
+    language_type: int = Field(default=0, description="Language type for remote API request body")
+    target_dir: str = Field(
+        default="/mnt/user-data/uploads/skill",
+        description="Virtual base directory where each skill-{id} subdirectory is created",
+    )
+    clear_target: bool = Field(
+        default=True,
+        description="Whether to clear target directory before writing parsed files",
+    )
+
+
+class RemoteSkillBootstrapResponse(BaseModel):
+    """Response model for remote bootstrap endpoint."""
+
+    success: bool = Field(..., description="Whether bootstrap succeeded")
+    target_dir: str = Field(..., description="Virtual base target directory")
+    content_ids: list[int] = Field(..., description="Bootstrapped content IDs")
+    created_directories: int = Field(..., description="Number of created directories")
+    created_files: int = Field(..., description="Number of created files")
+    sandbox_id: str | None = Field(default=None, description="Acquired sandbox ID (null when sandbox is not acquired)")
+    message: str = Field(..., description="Operation result message")
+
+
 def _skill_to_response(skill: Skill) -> SkillResponse:
     """Convert a Skill object to a SkillResponse."""
     return SkillResponse(
@@ -171,3 +207,107 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
     except Exception as e:
         logger.error(f"Failed to install skill: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
+
+
+@router.post(
+    "/skills/bootstrap-remote",
+    response_model=RemoteSkillBootstrapResponse,
+    summary="Bootstrap Skill Files From Remote API",
+    description=(
+        "Fetch YAML text from configured remote API by content_ids/language_type and "
+        "materialize files into skill-{id} directories under /mnt/user-data/uploads/skill "
+        "before first thread submit."
+    ),
+)
+async def bootstrap_skill_from_remote(request: RemoteSkillBootstrapRequest) -> RemoteSkillBootstrapResponse:
+    """Initialize thread skill directory from remote YAML content service."""
+    try:
+        cfg = get_gateway_config()
+        api_url = cfg.skill_content_api_url
+        created_directories_total = 0
+        created_files_total = 0
+
+        base_target_path = resolve_thread_virtual_path(request.thread_id, request.target_dir)
+        if request.clear_target and base_target_path.exists():
+            if base_target_path.is_dir():
+                shutil.rmtree(base_target_path)
+            else:
+                base_target_path.unlink()
+        base_target_path.mkdir(parents=True, exist_ok=True)
+
+        async with httpx.AsyncClient(timeout=20.0) as client:
+            for content_id in request.content_ids:
+                payload = {
+                    "contentId": content_id,
+                    "languageType": request.language_type,
+                }
+
+                response = await client.post(api_url, json=payload)
+                if response.status_code >= 400:
+                    raise HTTPException(
+                        status_code=502,
+                        detail=(
+                            "Remote skill content API failed with HTTP "
+                            f"{response.status_code} for content_id={content_id}"
+                        ),
+                    )
+
+                try:
+                    response_json = response.json()
+                except ValueError as e:
+                    raise HTTPException(status_code=502, detail=f"Remote API did not return valid JSON: {e}") from e
+
+                status = response_json.get("status")
+                if status != 1000:
+                    raise HTTPException(
+                        status_code=502,
+                        detail=(
+                            "Remote API returned non-success status: "
+                            f"{status}, message: {response_json.get('message')}, content_id={content_id}"
+                        ),
+                    )
+
+                yaml_text = response_json.get("data")
+                if not isinstance(yaml_text, str) or not yaml_text.strip():
+                    raise HTTPException(
+                        status_code=502,
+                        detail=f"Remote API returned empty or invalid YAML content for content_id={content_id}",
+                    )
+
+                target_dir = f"{request.target_dir.rstrip('/')}/skill-{content_id}"
+                target_path = resolve_thread_virtual_path(request.thread_id, target_dir)
+                parsed = parse_skill_yaml_spec(yaml_text)
+                materialize_skill_tree(parsed, target_path, clear_target=False)
+
+                created_directories_total += len(parsed.directories)
+                created_files_total += len(parsed.files)
+
+                logger.info(
+                    "Bootstrapped remote skill YAML for thread %s (content_id=%s, language_type=%s) to %s: dirs=%d files=%d",
+                    request.thread_id,
+                    content_id,
+                    request.language_type,
+                    target_dir,
+                    len(parsed.directories),
+                    len(parsed.files),
+                )
+
+        return RemoteSkillBootstrapResponse(
+            success=True,
+            target_dir=request.target_dir,
+            content_ids=request.content_ids,
+            created_directories=created_directories_total,
+            created_files=created_files_total,
+            sandbox_id=None,
+            message=(
+                f"Bootstrapped {created_files_total} files and {created_directories_total} directories "
+                f"for {len(request.content_ids)} skills under '{request.target_dir}'"
+            ),
+        )
+    except HTTPException:
+        raise
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        logger.error(f"Failed to bootstrap skill from remote API: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Failed to bootstrap skill from remote API: {str(e)}")
```
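A hedged sketch of calling the new endpoint, assuming the gateway listens on its default host/port and the router is mounted at the application root (neither is confirmed by this diff); the thread and content IDs are made up:

```python
import asyncio

import httpx


async def demo() -> None:
    payload = {
        "thread_id": "thread-demo",  # hypothetical thread ID
        "content_ids": [101, 102],   # hypothetical remote content IDs
        "language_type": 0,
    }
    async with httpx.AsyncClient(base_url="http://localhost:8001", timeout=30.0) as client:
        response = await client.post("/skills/bootstrap-remote", json=payload)
        response.raise_for_status()
        print(response.json())  # RemoteSkillBootstrapResponse fields


asyncio.run(demo())
```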
@ -0,0 +1,475 @@
|
||||||
|
"""Utilities for parsing YAML-defined skill package structures.
|
||||||
|
|
||||||
|
This module supports turning a YAML document describing files/directories into
|
||||||
|
real filesystem content under a thread's virtual path (for example,
|
||||||
|
``/mnt/user-data/uploads/skill``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ParsedSkillTree:
|
||||||
|
"""Normalized parsed structure from YAML spec."""
|
||||||
|
|
||||||
|
directories: set[str]
|
||||||
|
files: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_first_existing(data: dict, keys: tuple[str, ...]):
|
||||||
|
for key in keys:
|
||||||
|
if key in data:
|
||||||
|
return data[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_spec_root(data: dict) -> dict:
|
||||||
|
"""Extract the effective spec root.
|
||||||
|
|
||||||
|
Supports nested wrappers like:
|
||||||
|
- skill: { ... }
|
||||||
|
- package: { ... }
|
||||||
|
- spec: { ... }
|
||||||
|
"""
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("YAML root must be an object")
|
||||||
|
|
||||||
|
known_keys = {
|
||||||
|
"entries",
|
||||||
|
"files",
|
||||||
|
"directories",
|
||||||
|
"dirs",
|
||||||
|
"tree",
|
||||||
|
"structure",
|
||||||
|
"file_tree",
|
||||||
|
"fileTree",
|
||||||
|
"file_structure",
|
||||||
|
"paths",
|
||||||
|
}
|
||||||
|
if any(k in data for k in known_keys):
|
||||||
|
return data
|
||||||
|
|
||||||
|
wrapper_candidates = ("skill", "package", "spec", "data", "content", "payload")
|
||||||
|
for wrapper in wrapper_candidates:
|
||||||
|
candidate = data.get(wrapper)
|
||||||
|
if isinstance(candidate, dict) and any(k in candidate for k in known_keys):
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# Fallback: if exactly one nested object exists, try it as spec root.
|
||||||
|
nested_dicts = [v for v in data.values() if isinstance(v, dict)]
|
||||||
|
if len(nested_dicts) == 1:
|
||||||
|
return nested_dicts[0]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_relative_path(path: str) -> str:
|
||||||
|
"""Normalize and validate a relative path.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If path is unsafe or invalid.
|
||||||
|
"""
|
||||||
|
if not isinstance(path, str):
|
||||||
|
raise ValueError("Path must be a string")
|
||||||
|
|
||||||
|
normalized = path.strip().replace("\\", "/")
|
||||||
|
if normalized in {"/", ".", "./"}:
|
||||||
|
return ""
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Path cannot be empty")
|
||||||
|
|
||||||
|
if normalized.startswith("/"):
|
||||||
|
raise ValueError(f"Path must be relative, got absolute path: {path}")
|
||||||
|
|
||||||
|
if ":" in normalized:
|
||||||
|
raise ValueError(f"Path cannot contain ':' (possible drive path): {path}")
|
||||||
|
|
||||||
|
parts = [part for part in normalized.split("/") if part]
|
||||||
|
if not parts:
|
||||||
|
raise ValueError("Path cannot be empty")
|
||||||
|
|
||||||
|
if any(part in {".", ".."} for part in parts):
|
||||||
|
raise ValueError(f"Path traversal is not allowed: {path}")
|
||||||
|
|
||||||
|
return "/".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_directory(path: str, directories: set[str]) -> None:
|
||||||
|
normalized = _normalize_relative_path(path)
|
||||||
|
if not normalized:
|
||||||
|
return
|
||||||
|
directories.add(normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_file(path: str, content: str, files: dict[str, str], directories: set[str]) -> None:
|
||||||
|
normalized = _normalize_relative_path(path)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("File path cannot be root ('/')")
|
||||||
|
if not isinstance(content, str):
|
||||||
|
raise ValueError(f"File content must be a string for '{normalized}'")
|
||||||
|
|
||||||
|
parent = Path(normalized).parent
|
||||||
|
if str(parent) != ".":
|
||||||
|
directories.add(str(parent).replace("\\", "/"))
|
||||||
|
|
||||||
|
files[normalized] = content
|
||||||
|
|
||||||
|
|
||||||
|
def _walk_tree_dict(tree: dict, base: str, files: dict[str, str], directories: set[str]) -> None:
|
||||||
|
for name, value in tree.items():
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise ValueError("Tree keys must be strings")
|
||||||
|
|
||||||
|
if name.strip() in {"/", ".", "./"}:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
_walk_tree_dict(value, base, files, directories)
|
||||||
|
continue
|
||||||
|
raise ValueError("Root sentinel '/' can only be used for directory/object nodes")
|
||||||
|
|
||||||
|
node_path = f"{base}/{name}" if base else name
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
_add_directory(node_path, directories)
|
||||||
|
_walk_tree_dict(value, _normalize_relative_path(node_path), files, directories)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
_add_file(node_path, value, files, directories)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported tree node type for '{node_path}': {type(value).__name__}. "
|
||||||
|
"Use object (directory) or string (file content)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_entries_node(
|
||||||
|
node: dict,
|
||||||
|
base: str,
|
||||||
|
files: dict[str, str],
|
||||||
|
directories: set[str],
|
||||||
|
) -> None:
|
||||||
|
raw_path = node.get("path")
|
||||||
|
raw_name = node.get("name")
|
||||||
|
|
||||||
|
if raw_path is None and raw_name is None:
|
||||||
|
raise ValueError("Each entry must have at least one of: 'path' or 'name'")
|
||||||
|
|
||||||
|
if raw_path is not None and not isinstance(raw_path, str):
|
||||||
|
raise ValueError("Entry 'path' must be a string")
|
||||||
|
if raw_name is not None and not isinstance(raw_name, str):
|
||||||
|
raise ValueError("Entry 'name' must be a string")
|
||||||
|
|
||||||
|
# Common schema compatibility:
|
||||||
|
# - `path` is parent directory (e.g. "/")
|
||||||
|
# - `name` is current node name (e.g. "README.md")
|
||||||
|
# Build parent then append name when both are present.
|
||||||
|
parent = base
|
||||||
|
if isinstance(raw_path, str) and raw_path.strip():
|
||||||
|
rp = raw_path.strip()
|
||||||
|
if rp not in {"/", ".", "./"}:
|
||||||
|
parent = _normalize_relative_path(f"{base}/{rp}" if base else rp)
|
||||||
|
|
||||||
|
if isinstance(raw_name, str) and raw_name.strip():
|
||||||
|
if parent:
|
||||||
|
node_path = _normalize_relative_path(f"{parent}/{raw_name.strip()}")
|
||||||
|
else:
|
||||||
|
node_path = _normalize_relative_path(raw_name.strip())
|
||||||
|
else:
|
||||||
|
# Fallback: only path provided
|
||||||
|
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||||
|
raise ValueError("Each entry must have a non-empty 'path' or 'name'")
|
||||||
|
rp = raw_path.strip()
|
||||||
|
if rp in {"/", ".", "./"}:
|
||||||
|
node_path = base
|
||||||
|
else:
|
||||||
|
node_path = _normalize_relative_path(f"{base}/{rp}" if base else rp)
|
||||||
|
|
||||||
|
node_type = node.get("type")
|
||||||
|
content = node.get("content")
|
||||||
|
children = node.get("children")
|
||||||
|
|
||||||
|
inferred_type = "directory" if isinstance(children, list) else "file" if content is not None else None
|
||||||
|
final_type = node_type or inferred_type
|
||||||
|
|
||||||
|
if final_type == "directory":
|
||||||
|
_add_directory(node_path, directories)
|
||||||
|
if children is None:
|
||||||
|
return
|
||||||
|
if not isinstance(children, list):
|
||||||
|
raise ValueError(f"Entry '{node_path}' children must be a list")
|
||||||
|
for child in children:
|
||||||
|
if not isinstance(child, dict):
|
||||||
|
raise ValueError(f"Entry '{node_path}' children must be objects")
|
||||||
|
_parse_entries_node(child, node_path, files, directories)
|
||||||
|
return
|
||||||
|
|
||||||
|
if final_type == "file":
|
||||||
|
if content is None:
|
||||||
|
raise ValueError(f"File entry '{node_path}' is missing 'content'")
|
||||||
|
_add_file(node_path, content, files, directories)
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to infer entry type for '{node_path}'. Set 'type' to 'file' or 'directory'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_skill_yaml_spec(yaml_text: str) -> ParsedSkillTree:
|
||||||
|
"""Parse YAML text into normalized directories and files.
|
||||||
|
|
||||||
|
Supported forms:
|
||||||
|
- entries: [{type,path/content/children}, ...]
|
||||||
|
- files: {"path/to/file": "text"} + optional directories/dirs
|
||||||
|
- tree/structure: nested dict where dict=directory and string=file content
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(yaml_text)
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
raise ValueError(f"Invalid YAML: {e}") from e
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
raise ValueError("YAML is empty")
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("YAML root must be an object")
|
||||||
|
|
||||||
|
data = _extract_spec_root(data)
|
||||||
|
|
||||||
|
directories: set[str] = set()
|
||||||
|
files: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Form 1: explicit entries list
|
||||||
|
entries = _pick_first_existing(data, ("entries", "nodes", "items"))
|
||||||
|
if entries is not None:
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
raise ValueError("'entries' must be a list")
|
||||||
|
for entry in entries:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
raise ValueError("Each item in 'entries' must be an object")
|
||||||
|
_parse_entries_node(entry, "", files, directories)
|
||||||
|
|
||||||
|
# Form 2: files + directories
|
||||||
|
file_map = _pick_first_existing(data, ("files", "paths", "file_map", "fileMap", "documents"))
|
||||||
|
if file_map is not None:
|
||||||
|
if isinstance(file_map, dict):
|
||||||
|
for path, content in file_map.items():
|
||||||
|
_add_file(path, content, files, directories)
|
||||||
|
elif isinstance(file_map, list):
|
||||||
|
for item in file_map:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
raise ValueError("Each item in 'files' list must be an object")
|
||||||
|
path = item.get("path") or item.get("name") or item.get("file")
|
||||||
|
content = item.get("content")
|
||||||
|
if content is None:
|
||||||
|
content = item.get("text")
|
||||||
|
if content is None:
|
||||||
|
content = item.get("body")
|
||||||
|
if path is None or content is None:
|
||||||
|
raise ValueError("Each file item needs 'path' and 'content'")
|
||||||
|
_add_file(path, content, files, directories)
|
||||||
|
else:
|
||||||
|
raise ValueError("'files' must be a map or list")
|
||||||
|
|
||||||
|
directory_list = _pick_first_existing(data, ("directories", "dirs", "folders", "folder_paths"))
|
||||||
|
if directory_list is not None:
|
||||||
|
if not isinstance(directory_list, list):
|
||||||
|
raise ValueError("'directories'/'dirs' must be a list")
|
||||||
|
for path in directory_list:
|
||||||
|
_add_directory(path, directories)
|
||||||
|
|
||||||
|
# Form 3: nested tree
|
||||||
|
tree = _pick_first_existing(data, ("tree", "structure", "file_tree", "fileTree", "file_structure"))
|
||||||
|
if tree is not None:
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
_walk_tree_dict(tree, "", files, directories)
|
||||||
|
elif isinstance(tree, list):
|
||||||
|
for item in tree:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
raise ValueError("Items in 'tree' list must be objects")
|
||||||
|
_parse_entries_node(item, "", files, directories)
|
||||||
|
else:
|
||||||
|
raise ValueError("'tree'/'structure' must be an object or list")
|
||||||
|
|
||||||
|
# Heuristic fallback: treat root as path->content map when possible.
|
||||||
|
if not files and not directories:
|
||||||
|
candidate_keys = [k for k in data.keys() if isinstance(k, str)]
|
||||||
|
if candidate_keys and all(isinstance(data[k], str) for k in candidate_keys):
|
||||||
|
for path, content in data.items():
|
||||||
|
_add_file(path, content, files, directories)
|
||||||
|
|
||||||
|
if not files and not directories:
|
||||||
|
raise ValueError(
|
||||||
|
"No content found. Provide at least one of: entries, files, directories/dirs, tree/structure"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure parent directories exist for every file
|
||||||
|
for rel_file in files:
|
||||||
|
parent = Path(rel_file).parent
|
||||||
|
if str(parent) != ".":
|
||||||
|
directories.add(str(parent).replace("\\", "/"))
|
||||||
|
|
||||||
|
return ParsedSkillTree(directories=directories, files=files)
|
||||||
|
|
||||||
|
|
||||||
|
def materialize_skill_tree(parsed: ParsedSkillTree, target_root: Path, clear_target: bool = True) -> None:
|
||||||
|
"""Create parsed directories/files under target root."""
|
||||||
|
if clear_target and target_root.exists():
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.rmtree(target_root)
|
||||||
|
|
||||||
|
target_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for rel_dir in sorted(parsed.directories, key=lambda p: (p.count("/"), p)):
|
||||||
|
(target_root / rel_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for rel_file, content in parsed.files.items():
|
||||||
|
file_path = target_root / rel_file
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
file_path.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cli_parser() -> argparse.ArgumentParser:
|
||||||
|
"""Build command-line argument parser.
|
||||||
|
|
||||||
|
CLI usage:
|
||||||
|
python skill_yaml_importer.py <input_path> [options]
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="Parse and validate a skill YAML spec file")
|
||||||
|
parser.add_argument("input_path", help="Path to a YAML file or a directory containing YAML files")
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-files",
|
||||||
|
action="store_true",
|
||||||
|
help="Print sorted parsed file paths",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-directories",
|
||||||
|
action="store_true",
|
||||||
|
help="Print sorted parsed directory paths",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--json",
|
||||||
|
action="store_true",
|
||||||
|
help="Print parsed summary as JSON",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recursive",
|
||||||
|
action="store_true",
|
||||||
|
help="When input path is a directory, scan YAML files recursively",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-file",
|
||||||
|
default=None,
|
||||||
|
help="Optional path to save full execution results and summary as JSON",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_yaml_files(input_path: Path, recursive: bool) -> list[Path]:
|
||||||
|
if input_path.is_file():
|
||||||
|
return [input_path]
|
||||||
|
|
||||||
|
if not input_path.is_dir():
|
||||||
|
return []
|
||||||
|
|
||||||
|
patterns = ("*.yaml", "*.yml")
|
||||||
|
files: list[Path] = []
|
||||||
|
for pattern in patterns:
|
||||||
|
iterator = input_path.rglob(pattern) if recursive else input_path.glob(pattern)
|
||||||
|
files.extend(iterator)
|
||||||
|
|
||||||
|
# Stable order for deterministic output
|
||||||
|
return sorted({p.resolve() for p in files})
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_one_yaml_file(yaml_path: Path, show_files: bool, show_directories: bool) -> dict:
|
||||||
|
yaml_text = yaml_path.read_text(encoding="utf-8")
|
||||||
|
parsed = parse_skill_yaml_spec(yaml_text)
|
||||||
|
directories = sorted(parsed.directories)
|
||||||
|
files = sorted(parsed.files.keys())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"yaml_file": str(yaml_path),
|
||||||
|
"directories_count": len(directories),
|
||||||
|
"files_count": len(files),
|
||||||
|
"directories": directories if show_directories else None,
|
||||||
|
"files": files if show_files else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _main() -> int:
|
||||||
|
"""CLI entrypoint for parsing one YAML file or a batch of YAML files.
|
||||||
|
|
||||||
|
Exit codes:
|
||||||
|
0: all files parsed successfully
|
||||||
|
1: invalid input path or no YAML files found
|
||||||
|
2: processed completed with one or more parse failures
|
||||||
|
"""
|
||||||
|
args = _build_cli_parser().parse_args()
|
||||||
|
|
||||||
|
input_path = Path(args.input_path)
|
||||||
|
if not input_path.exists():
|
||||||
|
print(f"Input path not found: {input_path}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
yaml_files = _collect_yaml_files(input_path, recursive=args.recursive)
|
||||||
|
if not yaml_files:
|
||||||
|
print(f"No YAML files found under: {input_path}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
successes: list[dict] = []
|
||||||
|
failures: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
for yaml_path in yaml_files:
|
||||||
|
try:
|
||||||
|
result = _parse_one_yaml_file(
|
||||||
|
yaml_path,
|
||||||
|
show_files=args.show_files,
|
||||||
|
show_directories=args.show_directories,
|
||||||
|
)
|
||||||
|
successes.append(result)
|
||||||
|
if not args.json:
|
||||||
|
print(f"OK: {yaml_path}")
|
||||||
|
print(f" Directories: {result['directories_count']}")
|
||||||
|
print(f" Files: {result['files_count']}")
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
failures.append({"yaml_file": str(yaml_path), "error": str(e)})
|
||||||
|
print(f"ERROR: {yaml_path}: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"input_path": str(input_path),
|
||||||
|
"total": len(yaml_files),
|
||||||
|
"success": len(successes),
|
||||||
|
"failed": len(failures),
|
||||||
|
}
|
||||||
|
|
||||||
|
report = {"summary": summary, "successes": successes, "failures": failures}
|
||||||
|
|
||||||
|
if args.log_file:
|
||||||
|
try:
|
||||||
|
log_path = Path(args.log_file)
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
log_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||||
|
print(f"Log saved: {log_path}")
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
print(f"Failed to write log file '{args.log_file}': {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||||
|
else:
|
||||||
|
print("\n[Summary]")
|
||||||
|
print(f"Input: {summary['input_path']}")
|
||||||
|
print(f"Total: {summary['total']}")
|
||||||
|
print(f"Success: {summary['success']}")
|
||||||
|
print(f"Failed: {summary['failed']}")
|
||||||
|
|
||||||
|
return 0 if not failures else 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(_main())
|
||||||
|
|
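A round-trip through the new module using its `tree` form, where nested objects become directories and strings become file contents; the skill content and target path are illustrative:

```python
import tempfile
from pathlib import Path

from app.gateway.skill_yaml_importer import materialize_skill_tree, parse_skill_yaml_spec

yaml_text = """
tree:
  SKILL.md: "# Demo skill"
  scripts:
    run.py: "print('hello')"
"""

parsed = parse_skill_yaml_spec(yaml_text)
assert parsed.directories == {"scripts"}
assert set(parsed.files) == {"SKILL.md", "scripts/run.py"}

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "skill-demo"
    materialize_skill_tree(parsed, target)
    assert (target / "scripts" / "run.py").read_text() == "print('hello')"
```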
```diff
@@ -394,7 +394,17 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
     Returns the <skill_system>...</skill_system> block listing all enabled skills,
     suitable for injection into any agent's system prompt.
     """
-    skills = _get_enabled_skills()
+    thread_id = None
+    try:
+        from langgraph.config import get_config
+
+        config = get_config()
+        thread_id = config.get("configurable", {}).get("thread_id")
+    except Exception:
+        pass
+
+    # Keep regular enable/disable behavior while letting uploads be default-enabled in loader.
+    skills = load_skills(enabled_only=True, thread_id=thread_id)

     try:
         from deerflow.config import get_app_config
@@ -561,4 +571,6 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
         acp_section=acp_and_mounts_section,
     )

+    logger.debug("Generated full system prompt:\n%s", prompt)
+
     return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
```
```diff
@@ -21,6 +21,7 @@ class ConversationContext:
     timestamp: datetime = field(default_factory=datetime.utcnow)
     agent_name: str | None = None
     correction_detected: bool = False
+    reinforcement_detected: bool = False


 class MemoryUpdateQueue:
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
         messages: list[Any],
         agent_name: str | None = None,
         correction_detected: bool = False,
+        reinforcement_detected: bool = False,
     ) -> None:
         """Add a conversation to the update queue.
@@ -52,6 +54,7 @@ class MemoryUpdateQueue:
             messages: The conversation messages.
             agent_name: If provided, memory is stored per-agent. If None, uses global memory.
             correction_detected: Whether recent turns include an explicit correction signal.
+            reinforcement_detected: Whether recent turns include a positive reinforcement signal.
         """
         config = get_memory_config()
         if not config.enabled:
@@ -63,11 +66,13 @@ class MemoryUpdateQueue:
             None,
         )
         merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
+        merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
         context = ConversationContext(
             thread_id=thread_id,
             messages=messages,
             agent_name=agent_name,
             correction_detected=merged_correction_detected,
+            reinforcement_detected=merged_reinforcement_detected,
         )

         # Check if this thread already has a pending update
@@ -130,6 +135,7 @@ class MemoryUpdateQueue:
                 thread_id=context.thread_id,
                 agent_name=context.agent_name,
                 correction_detected=context.correction_detected,
+                reinforcement_detected=context.reinforcement_detected,
             )
             if success:
                 logger.info("Memory updated successfully for thread %s", context.thread_id)
```
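The merge is a sticky OR per thread: once a correction or reinforcement flag has been observed on the pending context, a later `add()` for the same thread cannot clear it, even though the newer messages replace the older ones. The rule in isolation:

```python
# Sketch of the merge rule with bare booleans.
existing_reinforcement = True   # flag on the pending context
incoming_reinforcement = False  # flag on the newest add() call
merged = incoming_reinforcement or existing_reinforcement
assert merged is True  # the signal survives until the queue is flushed
```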
```diff
@@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
     stripped = content.strip()
     if not stripped:
         return None
-    return stripped
+    return stripped.casefold()


 class MemoryUpdater:
@@ -272,6 +272,7 @@ class MemoryUpdater:
         thread_id: str | None = None,
         agent_name: str | None = None,
         correction_detected: bool = False,
+        reinforcement_detected: bool = False,
     ) -> bool:
         """Update memory based on conversation messages.
@@ -280,6 +281,7 @@ class MemoryUpdater:
             thread_id: Optional thread ID for tracking source.
             agent_name: If provided, updates per-agent memory. If None, updates global memory.
             correction_detected: Whether recent turns include an explicit correction signal.
+            reinforcement_detected: Whether recent turns include a positive reinforcement signal.

         Returns:
             True if update was successful, False otherwise.
@@ -310,6 +312,14 @@ class MemoryUpdater:
                 "and record the correct approach as a fact with category "
                 '"correction" and confidence >= 0.95 when appropriate.'
             )
+        if reinforcement_detected:
+            reinforcement_hint = (
+                "IMPORTANT: Positive reinforcement signals were detected in this conversation. "
+                "The user explicitly confirmed the agent's approach was correct or helpful. "
+                "Record the confirmed approach, style, or preference as a fact with category "
+                '"preference" or "behavior" and confidence >= 0.9 when appropriate.'
+            )
+            correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
         prompt = MEMORY_UPDATE_PROMPT.format(
             current_memory=json.dumps(current_memory, indent=2),
@@ -441,6 +451,7 @@ def update_memory_from_conversation(
     thread_id: str | None = None,
     agent_name: str | None = None,
     correction_detected: bool = False,
+    reinforcement_detected: bool = False,
 ) -> bool:
     """Convenience function to update memory from a conversation.
@@ -449,9 +460,10 @@ def update_memory_from_conversation(
         thread_id: Optional thread ID.
         agent_name: If provided, updates per-agent memory. If None, updates global memory.
         correction_detected: Whether recent turns include an explicit correction signal.
+        reinforcement_detected: Whether recent turns include a positive reinforcement signal.

     Returns:
         True if successful, False otherwise.
     """
     updater = MemoryUpdater()
-    return updater.update_memory(messages, thread_id, agent_name, correction_detected)
+    return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
```
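`casefold()` is Unicode-aware lowercasing, so the new dedup key treats facts that differ only in letter case as the same fact. A local restatement of the changed helper (not an import of it):

```python
def fact_content_key(content: str) -> str | None:
    # Mirrors _fact_content_key after this change: strip, then casefold.
    stripped = content.strip()
    return stripped.casefold() if stripped else None

assert fact_content_key("User prefers Python") == fact_content_key("user prefers PYTHON")
assert fact_content_key("   ") is None
```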
@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
|
||||||
re.compile(r"改用"),
|
re.compile(r"改用"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_REINFORCEMENT_PATTERNS = (
|
||||||
|
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddlewareState(AgentState):
|
class MemoryMiddlewareState(AgentState):
|
||||||
"""Compatible with the `ThreadState` schema."""
|
"""Compatible with the `ThreadState` schema."""
|
||||||
|
|
@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||||
|
|
||||||
|
Complements detect_correction() by identifying when the user confirms the
|
||||||
|
agent's approach was correct. This allows the memory system to record what
|
||||||
|
worked well, not just what went wrong.
|
||||||
|
|
||||||
|
The queue keeps only one pending context per thread, so callers pass the
|
||||||
|
latest filtered message list. Checking only recent user turns keeps signal
|
||||||
|
detection conservative while avoiding stale signals from long histories.
|
||||||
|
"""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = _extract_message_text(msg).strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
"""Middleware that queues conversation for memory update after agent execution.
|
"""Middleware that queues conversation for memory update after agent execution.
|
||||||
|
|
||||||
|
|
@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
|
|
||||||
# Queue the filtered conversation for memory update
|
# Queue the filtered conversation for memory update
|
||||||
correction_detected = detect_correction(filtered_messages)
|
correction_detected = detect_correction(filtered_messages)
|
||||||
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add(
|
queue.add(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=filtered_messages,
|
messages=filtered_messages,
|
||||||
agent_name=self._agent_name,
|
agent_name=self._agent_name,
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
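The short reinforcement phrases deliberately require trailing punctuation or end-of-string, so a passing mention inside a longer sentence does not count as confirmation. Two of the patterns restated locally:

```python
import re

perfect = re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE)
zh_confirm = re.compile(r"完全正确(?:[。!?!?.]|$)")

assert perfect.search("Perfect!")
assert perfect.search("a perfect storm is brewing") is None  # no end punctuation
assert zh_confirm.search("完全正确。")
```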
```diff
@@ -192,8 +192,8 @@ class ExtensionsConfig(BaseModel):
         """
         skill_config = self.skills.get(skill_name)
         if skill_config is None:
-            # Default to enable for public & custom skill
-            return skill_category in ("public", "custom")
+            # Default to enable for public/custom/uploads skills.
+            return skill_category in ("public", "custom", "uploads")
         return skill_config.enabled

```
```diff
@@ -366,12 +366,17 @@ def _path_variants(path: str) -> set[str]:
     return {path, path.replace("\\", "/"), path.replace("/", "\\")}


+def _path_separator_for_style(path: str) -> str:
+    return "\\" if "\\" in path and "/" not in path else "/"
+
+
 def _join_path_preserving_style(base: str, relative: str) -> str:
     if not relative:
         return base
-    if "/" in base and "\\" not in base:
-        return f"{base.rstrip('/')}/{relative}"
-    return str(Path(base) / relative)
+    separator = _path_separator_for_style(base)
+    normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
+    stripped_base = base.rstrip("/\\")
+    return f"{stripped_base}{separator}{normalized_relative}"


 def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
@@ -416,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
         return actual_base
     if path.startswith(f"{virtual_base}/"):
         rest = path[len(virtual_base) :].lstrip("/")
-        return _join_path_preserving_style(actual_base, rest)
+        result = _join_path_preserving_style(actual_base, rest)
+        if path.endswith("/") and not result.endswith(("/", "\\")):
+            result += _path_separator_for_style(actual_base)
+        return result
     return path
```
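The rewrite picks a single separator from the base path's style and forces the relative part to match it, instead of falling back to `Path()` joins that normalize Windows-style bases. A local restatement with sample paths:

```python
# Local restatement of the new helpers to illustrate style preservation.
def path_separator_for_style(path: str) -> str:
    return "\\" if "\\" in path and "/" not in path else "/"

def join_path_preserving_style(base: str, relative: str) -> str:
    if not relative:
        return base
    separator = path_separator_for_style(base)
    normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
    stripped_base = base.rstrip("/\\")
    return f"{stripped_base}{separator}{normalized_relative}"

assert join_path_preserving_style("/data/threads/t1", "uploads/skill") == "/data/threads/t1/uploads/skill"
assert join_path_preserving_style("C:\\data\\t1", "uploads/skill") == "C:\\data\\t1\\uploads\\skill"
```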
```diff
@@ -8,6 +8,27 @@ from .types import Skill
 logger = logging.getLogger(__name__)


+UPLOADS_SKILLS_PATH = Path("/mnt/user-data/uploads")
+
+
+def get_uploads_skills_path(thread_id: str | None = None) -> Path:
+    """Resolve the uploads skills root for the current execution context.
+
+    When called from the LangGraph process, uploaded skills live under the
+    host-side per-thread data directory rather than the sandbox mount path.
+    """
+    if not thread_id:
+        return UPLOADS_SKILLS_PATH
+
+    try:
+        from deerflow.config.paths import get_paths
+
+        return get_paths().sandbox_uploads_dir(thread_id)
+    except Exception as exc:
+        logger.warning("Failed to resolve uploads skills path for thread %s: %s", thread_id, exc)
+        return UPLOADS_SKILLS_PATH
+
+
 def get_skills_root_path() -> Path:
     """
     Get the root path of the skills directory.
@@ -22,12 +43,19 @@ def get_skills_root_path() -> Path:
     return skills_dir


-def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
+def load_skills(
+    skills_path: Path | None = None,
+    use_config: bool = True,
+    enabled_only: bool = False,
+    thread_id: str | None = None,
+) -> list[Skill]:
     """
     Load all skills from the skills directory.

-    Scans both public and custom skill directories, parsing SKILL.md files
-    to extract metadata. The enabled state is determined by the skills_state_config.json file.
+    Scans public/custom skill directories under the configured skills root,
+    and also scans uploaded skills under /mnt/user-data/uploads.
+    SKILL.md metadata is parsed and enabled state is derived from
+    skills_state_config.json.

     Args:
         skills_path: Optional custom path to skills directory.
@@ -35,6 +63,8 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
             Otherwise defaults to deer-flow/skills
         use_config: Whether to load skills path from config (default: True)
         enabled_only: If True, only return enabled skills (default: False)
+        thread_id: Optional thread ID used to resolve per-thread uploads skills
+            from the LangGraph host process

     Returns:
         List of Skill objects, sorted by name
@@ -57,12 +87,22 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable

     skills = []

-    # Scan public and custom directories
-    for category in ["public", "custom"]:
-        category_path = skills_path / category
+    # Scan built-in roots and uploaded skills mounted in personal workspace.
+    scan_targets: list[tuple[str, Path]] = [
+        ("public", skills_path / "public"),
+        ("custom", skills_path / "custom"),
+        ("uploads", get_uploads_skills_path(thread_id)),
+    ]
+
+    for category, category_path in scan_targets:
+        logger.debug("Scanning %s skills under %s", category, category_path)
+
         if not category_path.exists() or not category_path.is_dir():
+            logger.debug("Skip %s scan: directory not found or not a directory (%s)", category, category_path)
             continue

+        scanned_skill_dirs: list[str] = []
+
         for current_root, dir_names, file_names in os.walk(category_path, followlinks=True):
             # Keep traversal deterministic and skip hidden directories.
             dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
@@ -71,11 +111,22 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable

             skill_file = Path(current_root) / "SKILL.md"
             relative_path = skill_file.parent.relative_to(category_path)
+            scanned_skill_dirs.append(relative_path.as_posix())
+
             skill = parse_skill_file(skill_file, category=category, relative_path=relative_path)
             if skill:
                 skills.append(skill)

+        if scanned_skill_dirs:
+            logger.debug(
+                "%s scan found %d skill directories: %s",
+                category,
+                len(scanned_skill_dirs),
+                ", ".join(sorted(scanned_skill_dirs)),
+            )
+        else:
+            logger.debug("%s scan found no skill directories", category)
+
     # Load skills state configuration and update enabled status
     # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
     # to always read the latest configuration from disk. This ensures that changes
@@ -86,6 +137,10 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable

         extensions_config = ExtensionsConfig.from_file()
         for skill in skills:
+            if skill.category == "uploads":
+                # Uploaded skills should be available by default for the current thread.
+                skill.enabled = True
+                continue
             skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
     except Exception as e:
         # If config loading fails, default to all enabled
```
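A hedged usage sketch from the LangGraph host side; `thread-demo` is a made-up thread ID, and the printed attributes come from the `Skill` dataclass shown below:

```python
from deerflow.skills import load_skills

# Passing thread_id lets the loader also scan that thread's uploaded skills;
# without it, only the configured public/custom roots are visited.
skills = load_skills(enabled_only=True, thread_id="thread-demo")
for skill in skills:
    print(skill.category, skill.name, skill.enabled)  # uploads report enabled=True
```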
```diff
@@ -13,7 +13,7 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None

     Args:
         skill_file: Path to the SKILL.md file
-        category: Category of the skill ('public' or 'custom')
+        category: Category of the skill ('public', 'custom', or 'uploads')

     Returns:
         Skill object if parsing succeeds, None otherwise
```
```diff
@@ -12,7 +12,7 @@ class Skill:
     skill_dir: Path
     skill_file: Path
     relative_path: Path  # Relative path from category root to skill directory
-    category: str  # 'public' or 'custom'
+    category: str  # 'public', 'custom', or 'uploads'
     enabled: bool = False  # Whether this skill is enabled

     @property
@@ -31,6 +31,9 @@ class Skill:
         Returns:
             Full container path to the skill directory
         """
-        category_base = f"{container_base_path}/{self.category}"
+        if self.category == "uploads":
+            category_base = "/mnt/user-data/uploads"
+        else:
+            category_base = f"{container_base_path}/{self.category}"
         skill_path = self.skill_path
         if skill_path:
```
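A sketch of the container-path rule this hunk introduces, restated as a free function; the `/opt/skills` base is a made-up stand-in for `container_base_path`:

```python
def category_base_for(category: str, container_base_path: str = "/opt/skills") -> str:
    # uploads live at a fixed mount; other categories nest under the skills base.
    if category == "uploads":
        return "/mnt/user-data/uploads"
    return f"{container_base_path}/{category}"

assert category_base_for("public") == "/opt/skills/public"
assert category_base_for("uploads") == "/mnt/user-data/uploads"
```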
@ -7,7 +7,7 @@ import json
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@@ -1988,6 +1988,47 @@ class TestSlackSendRetry:

         _run(go())


+class TestSlackAllowedUsers:
+    def test_numeric_allowed_users_match_string_event_user_id(self):
+        from app.channels.slack import SlackChannel
+
+        bus = MessageBus()
+        bus.publish_inbound = AsyncMock()
+        channel = SlackChannel(
+            bus=bus,
+            config={"allowed_users": [123456]},
+        )
+        channel._loop = MagicMock()
+        channel._loop.is_running.return_value = True
+        channel._add_reaction = MagicMock()
+        channel._send_running_reply = MagicMock()
+
+        event = {
+            "user": "123456",
+            "text": "hello from slack",
+            "channel": "C123",
+            "ts": "1710000000.000100",
+        }
+
+        def submit_coro(coro, loop):
+            coro.close()
+            return MagicMock()
+
+        with patch(
+            "app.channels.slack.asyncio.run_coroutine_threadsafe",
+            side_effect=submit_coro,
+        ) as submit:
+            channel._handle_message_event(event)
+
+        channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
+        channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
+        submit.assert_called_once()
+        inbound = bus.publish_inbound.call_args.args[0]
+        assert inbound.user_id == "123456"
+        assert inbound.chat_id == "C123"
+        assert inbound.text == "hello from slack"
+
+
     def test_raises_after_all_retries_exhausted(self):
         from app.channels.slack import SlackChannel
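Note: this test pins down the type-normalization contract: allowed_users entries written as YAML/JSON integers must still match Slack's string user ids. A sketch of the comparison the test implies (the helper name is hypothetical, and the "empty list allows everyone" rule is an assumption, not shown in this diff):

def is_user_allowed(allowed_users: set[str], event_user_id: str) -> bool:
    # Assumed: an empty allow-list permits everyone. Otherwise compare as
    # strings, so a config entry like 123456 (int) matches the event id "123456".
    return not allowed_users or event_user_id in allowed_users

allowed = {str(u) for u in [123456]}  # normalized once at channel construction
assert is_user_allowed(allowed, "123456")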
@@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
         thread_id="thread-1",
         agent_name="lead_agent",
         correction_detected=True,
+        reinforcement_detected=False,
+    )
+
+
+def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
+    queue = MemoryUpdateQueue()
+
+    with (
+        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
+        patch.object(queue, "_reset_timer"),
+    ):
+        queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
+        queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
+
+    assert len(queue._queue) == 1
+    assert queue._queue[0].messages == ["second"]
+    assert queue._queue[0].reinforcement_detected is True
+
+
+def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
+    queue = MemoryUpdateQueue()
+    queue._queue = [
+        ConversationContext(
+            thread_id="thread-1",
+            messages=["conversation"],
+            agent_name="lead_agent",
+            reinforcement_detected=True,
+        )
+    ]
+    mock_updater = MagicMock()
+    mock_updater.update_memory.return_value = True
+
+    with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
+        queue._process_queue()
+
+    mock_updater.update_memory.assert_called_once_with(
+        messages=["conversation"],
+        thread_id="thread-1",
+        agent_name="lead_agent",
+        correction_detected=False,
+        reinforcement_detected=True,
     )
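Note: the first new test implies the queue coalesces entries per thread and latches the reinforcement flag with an OR rather than letting a later add() overwrite it. A hypothetical sketch of that merge rule (names are illustrative):

from dataclasses import dataclass

@dataclass
class _Entry:
    thread_id: str
    messages: list
    reinforcement_detected: bool = False

def merge(existing: _Entry, messages: list, reinforcement_detected: bool) -> _Entry:
    # Latest messages win, but a previously detected reinforcement signal
    # is never cleared by a later add() for the same thread.
    existing.messages = messages
    existing.reinforcement_detected = existing.reinforcement_detected or reinforcement_detected
    return existing

e = _Entry("thread-1", ["first"], reinforcement_detected=True)
e = merge(e, ["second"], reinforcement_detected=False)
assert e.messages == ["second"] and e.reinforcement_detected is True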
@@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
         assert result is True
         prompt = model.invoke.call_args[0][0]
         assert "Explicit correction signals were detected" not in prompt
+
+
+class TestFactDeduplicationCaseInsensitive:
+    """Tests that fact deduplication is case-insensitive."""
+
+    def test_duplicate_fact_different_case_not_stored(self):
+        updater = MemoryUpdater()
+        current_memory = _make_memory(
+            facts=[
+                {
+                    "id": "fact_1",
+                    "content": "User prefers Python",
+                    "category": "preference",
+                    "confidence": 0.9,
+                    "createdAt": "2026-01-01T00:00:00Z",
+                    "source": "thread-a",
+                },
+            ]
+        )
+        # Same fact with different casing should be treated as duplicate
+        update_data = {
+            "factsToRemove": [],
+            "newFacts": [
+                {"content": "user prefers python", "category": "preference", "confidence": 0.95},
+            ],
+        }
+
+        with patch(
+            "deerflow.agents.memory.updater.get_memory_config",
+            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
+        ):
+            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
+
+        # Should still have only 1 fact (duplicate rejected)
+        assert len(result["facts"]) == 1
+        assert result["facts"][0]["content"] == "User prefers Python"
+
+    def test_unique_fact_different_case_and_content_stored(self):
+        updater = MemoryUpdater()
+        current_memory = _make_memory(
+            facts=[
+                {
+                    "id": "fact_1",
+                    "content": "User prefers Python",
+                    "category": "preference",
+                    "confidence": 0.9,
+                    "createdAt": "2026-01-01T00:00:00Z",
+                    "source": "thread-a",
+                },
+            ]
+        )
+        update_data = {
+            "factsToRemove": [],
+            "newFacts": [
+                {"content": "User prefers Go", "category": "preference", "confidence": 0.85},
+            ],
+        }
+
+        with patch(
+            "deerflow.agents.memory.updater.get_memory_config",
+            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
+        ):
+            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
+
+        assert len(result["facts"]) == 2
+
+
+class TestReinforcementHint:
+    """Tests that reinforcement_detected injects the correct hint into the prompt."""
+
+    @staticmethod
+    def _make_mock_model(json_response: str):
+        model = MagicMock()
+        response = MagicMock()
+        response.content = f"```json\n{json_response}\n```"
+        model.invoke.return_value = response
+        return model
+
+    def test_reinforcement_hint_injected_when_detected(self):
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = self._make_mock_model(valid_json)
+
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Yes, exactly! That's what I needed."
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Great to hear!"
+            ai_msg.tool_calls = []
+
+            result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
+
+        assert result is True
+        prompt = model.invoke.call_args[0][0]
+        assert "Positive reinforcement signals were detected" in prompt
+
+    def test_reinforcement_hint_absent_when_not_detected(self):
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = self._make_mock_model(valid_json)
+
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Tell me more."
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Sure."
+            ai_msg.tool_calls = []
+
+            result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
+
+        assert result is True
+        prompt = model.invoke.call_args[0][0]
+        assert "Positive reinforcement signals were detected" not in prompt
+
+    def test_both_hints_present_when_both_detected(self):
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = self._make_mock_model(valid_json)
+
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "No wait, that's wrong. Actually yes, exactly right."
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Got it."
+            ai_msg.tool_calls = []
+
+            result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
+
+        assert result is True
+        prompt = model.invoke.call_args[0][0]
+        assert "Explicit correction signals were detected" in prompt
+        assert "Positive reinforcement signals were detected" in prompt
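Note: both dedup tests reduce to one rule: compare fact contents case-insensitively (and presumably whitespace-trimmed) before appending, while stored facts keep their original casing. A sketch under that assumption (helper name illustrative):

def add_facts(existing: list[dict], new_facts: list[dict]) -> list[dict]:
    # Normalize content for comparison only; stored facts keep original casing.
    seen = {fact["content"].strip().lower() for fact in existing}
    for fact in new_facts:
        key = fact["content"].strip().lower()
        if key not in seen:
            existing.append(fact)
            seen.add(key)
    return existing

facts = [{"content": "User prefers Python"}]
add_facts(facts, [{"content": "user prefers python"}])  # rejected as duplicate
add_facts(facts, [{"content": "User prefers Go"}])      # stored
assert [f["content"] for f in facts] == ["User prefers Python", "User prefers Go"]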
@@ -10,7 +10,7 @@ persisting in long-term memory:
 from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

 from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
-from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
+from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement

 # ---------------------------------------------------------------------------
 # Helpers
@@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
         mem = {"user": {}, "history": {}, "facts": []}
         result = _strip_upload_mentions_from_memory(mem)
         assert result == {"user": {}, "history": {}, "facts": []}
+
+
+# ===========================================================================
+# detect_reinforcement
+# ===========================================================================
+
+
+class TestDetectReinforcement:
+    def test_detects_english_reinforcement_signal(self):
+        msgs = [
+            _human("Can you summarise it in bullet points?"),
+            _ai("Here are the key points: ..."),
+            _human("Yes, exactly! That's what I needed."),
+            _ai("Glad it helped."),
+        ]
+
+        assert detect_reinforcement(msgs) is True
+
+    def test_detects_perfect_signal(self):
+        msgs = [
+            _human("Write it more concisely."),
+            _ai("Here is the concise version."),
+            _human("Perfect."),
+            _ai("Great!"),
+        ]
+
+        assert detect_reinforcement(msgs) is True
+
+    def test_detects_chinese_reinforcement_signal(self):
+        msgs = [
+            _human("帮我用要点来总结"),
+            _ai("好的,要点如下:..."),
+            _human("完全正确,就是这个意思"),
+            _ai("很高兴能帮到你"),
+        ]
+
+        assert detect_reinforcement(msgs) is True
+
+    def test_returns_false_without_signal(self):
+        msgs = [
+            _human("What does this function do?"),
+            _ai("It processes the input data."),
+            _human("Can you show me an example?"),
+        ]
+
+        assert detect_reinforcement(msgs) is False
+
+    def test_only_checks_recent_messages(self):
+        # Reinforcement signal buried beyond the -6 window should not trigger
+        msgs = [
+            _human("Yes, exactly right."),
+            _ai("Noted."),
+            _human("Let's discuss tests."),
+            _ai("Sure."),
+            _human("What about linting?"),
+            _ai("Use ruff."),
+            _human("And formatting?"),
+            _ai("Use make format."),
+        ]
+
+        assert detect_reinforcement(msgs) is False
+
+    def test_does_not_conflict_with_correction(self):
+        # A message can trigger correction but not reinforcement
+        msgs = [
+            _human("That's wrong, try again."),
+            _ai("Corrected."),
+        ]
+
+        assert detect_reinforcement(msgs) is False
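Note: these tests fix the contract for detect_reinforcement: phrase matching in English and Chinese, scanning only the last six messages' human turns, and staying independent of detect_correction. A minimal sketch consistent with the tests (the phrase list and function body below are assumptions; the real implementation is not in this diff):

from types import SimpleNamespace

_REINFORCEMENT_PHRASES = ("yes, exactly", "exactly right", "perfect", "完全正确")

def detect_reinforcement_sketch(messages) -> bool:
    # Scan only the last six turns; older signals are ignored by design.
    for msg in messages[-6:]:
        if getattr(msg, "type", None) != "human":
            continue
        text = str(getattr(msg, "content", "")).lower()
        if any(phrase in text for phrase in _REINFORCEMENT_PHRASES):
            return True
    return False

_h = lambda s: SimpleNamespace(type="human", content=s)
_a = lambda s: SimpleNamespace(type="ai", content=s)
assert detect_reinforcement_sketch([_h("Yes, exactly! That's what I needed.")]) is True
assert detect_reinforcement_sketch([_h("Tell me more."), _a("Sure.")]) is False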
@@ -42,6 +42,53 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
     assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"


+def test_replace_virtual_path_preserves_trailing_slash() -> None:
+    """Trailing slash must survive virtual-to-actual path replacement.
+
+    Regression: '/mnt/user-data/workspace/' was previously returned without
+    the trailing slash, causing string concatenations like
+    output_dir + 'file.txt' to produce a missing-separator path.
+    """
+    result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
+    assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
+    assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
+
+
+def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
+    """Trailing slash must be preserved as backslash when actual_base is Windows-style.
+
+    If actual_base uses backslash separators, appending '/' would produce a
+    mixed-separator path. The separator must match the style of actual_base.
+    """
+    win_thread_data = {
+        "workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
+        "uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
+        "outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
+    }
+    result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
+    assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
+    assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
+
+
+def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
+    """Nested Windows-style subdirectories must keep backslashes throughout."""
+    win_thread_data = {
+        "workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
+        "uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
+        "outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
+    }
+    result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
+    assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
+    assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
+
+
+def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
+    """Trailing slash on a virtual path inside a command must be preserved."""
+    cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
+    result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
+    assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
+
+
 # ---------- mask_local_paths_in_output ----------
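Note: all four regression tests encode the same invariant: if the virtual path ends with a separator, the replacement must too, and the separator style must follow the actual base path. A sketch of that tail-handling (function name and fixed prefix are assumptions drawn from the tests, not the project's real implementation):

def replace_virtual_path_sketch(virtual_path: str, actual_base: str) -> str:
    # Map the virtual workspace root onto actual_base, keeping any subpath
    # and converting separators to match the style of actual_base.
    prefix = "/mnt/user-data/workspace"
    sep = "\\" if "\\" in actual_base else "/"
    rest = virtual_path[len(prefix):]  # e.g. "/subdir/" or "/"
    return actual_base + rest.replace("/", sep)

assert replace_virtual_path_sketch(
    "/mnt/user-data/workspace/", r"C:\deer-flow\threads\t1\user-data\workspace"
).endswith("\\")
assert replace_virtual_path_sketch(
    "/mnt/user-data/workspace/subdir/", "/tmp/deer-flow/threads/t1/user-data/workspace"
) == "/tmp/deer-flow/threads/t1/user-data/workspace/subdir/"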
@@ -257,6 +304,22 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
         validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)


+def test_validate_local_bash_command_paths_allows_https_urls() -> None:
+    """URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
+    validate_local_bash_command_paths(
+        "cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
+        _THREAD_DATA,
+    )
+
+
+def test_validate_local_bash_command_paths_allows_http_urls() -> None:
+    """HTTP URLs must not be flagged as unsafe absolute paths."""
+    validate_local_bash_command_paths(
+        "curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
+        _THREAD_DATA,
+    )
+
+
 def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
     validate_local_bash_command_paths(
         "/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
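Note: the two new tests force the validator to distinguish URL slashes from filesystem paths. One common approach, assumed here (the diff does not show the validator's internals), is to blank out URLs before scanning the command for absolute paths:

import re

_URL_RE = re.compile(r"https?://\S+")

def extract_candidate_paths(command: str) -> list[str]:
    # Strip URLs first so "https://github.com/..." never looks like "/github.com/...".
    sanitized = _URL_RE.sub(" ", command)
    return re.findall(r"(?<!\S)/[^\s'\"]+", sanitized)

cmd = "cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git"
assert extract_candidate_paths(cmd) == ["/mnt/user-data/workspace"]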
@@ -1,6 +1,7 @@
 "use client";

-import { useCallback, useEffect, useState } from "react";
+import { useSearchParams } from "next/navigation";
+import { useCallback, useEffect, useMemo, useRef, useState } from "react";

 import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
 import { ArtifactTrigger } from "@/components/workspace/artifacts";
@@ -24,15 +25,35 @@ import { Welcome } from "@/components/workspace/welcome";
 import { useI18n } from "@/core/i18n/hooks";
 import { useNotification } from "@/core/notification/hooks";
 import { useThreadSettings } from "@/core/settings";
+import { bootstrapRemoteSkill } from "@/core/skills";
 import { useThreadStream } from "@/core/threads/hooks";
 import { textOfMessage } from "@/core/threads/utils";
+import { uuid } from "@/core/utils/uuid";
 import { env } from "@/env";
 import { cn } from "@/lib/utils";

+const UUID_REGEX =
+  /^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i;
+
 export default function ChatPage() {
   const { t } = useI18n();
   const [showFollowups, setShowFollowups] = useState(false);
-  const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat();
+  const searchParams = useSearchParams();
+  const generatedThreadIdRef = useRef<string>("");
+  if (!generatedThreadIdRef.current) {
+    const queryThreadId = searchParams.get("thread_id")?.trim();
+    generatedThreadIdRef.current =
+      queryThreadId && UUID_REGEX.test(queryThreadId) ? queryThreadId : uuid();
+  }
+
+  // Check the xclaw_used param; it only controls UI style and never affects thread-creation logic
+  const xclawUsedParam = searchParams.get("xclaw_used");
+  const initialForceNewStyle = xclawUsedParam === "false";
+  const [forceNewStyle, setForceNewStyle] = useState(initialForceNewStyle);
+
+  const { threadId, isNewThread, setIsNewThread, isMock } = useThreadChat({
+    newThreadId: generatedThreadIdRef.current,
+  });
   const [settings, setSettings] = useThreadSettings(threadId);
   const [mounted, setMounted] = useState(false);
   useSpecificChatMode();
@@ -42,6 +63,34 @@ export default function ChatPage() {
   }, []);

   const { showNotification } = useNotification();
+  const skillBootstrappedKeysRef = useRef<Set<string>>(new Set());
+  const skillBootstrappingKeysRef = useRef<Set<string>>(new Set());
+
+  const skillBootstrap = useMemo(() => {
+    const skillIdRaw = searchParams.get("skill_id")?.trim();
+    if (!skillIdRaw) return undefined;
+
+    const contentIds = skillIdRaw
+      .split(",")
+      .map((value) => value.trim())
+      .filter((value) => value.length > 0)
+      .map((value) => Number(value))
+      .filter((value) => Number.isFinite(value));
+
+    // Deduplicate while preserving incoming order.
+    const uniqueContentIds = Array.from(new Set(contentIds));
+    if (uniqueContentIds.length === 0) return undefined;
+
+    const languageTypeRaw =
+      searchParams.get("languageType")?.trim() ??
+      searchParams.get("language_type")?.trim();
+    const languageType = languageTypeRaw ? Number(languageTypeRaw) : 0;
+
+    return {
+      contentIds: uniqueContentIds,
+      languageType: Number.isFinite(languageType) ? languageType : 0,
+    };
+  }, [searchParams]);
+
   const [thread, sendMessage, isUploading] = useThreadStream({
     threadId: isNewThread ? undefined : threadId,
@@ -70,11 +119,54 @@ export default function ChatPage() {
     },
   });

+  useEffect(() => {
+    if (!threadId || !skillBootstrap?.contentIds?.length) {
+      return;
+    }
+
+    const languageType = skillBootstrap.languageType ?? 0;
+    const initKey = `${threadId}:${skillBootstrap.contentIds.join(",")}:${languageType}`;
+    if (
+      skillBootstrappedKeysRef.current.has(initKey) ||
+      skillBootstrappingKeysRef.current.has(initKey)
+    ) {
+      return;
+    }
+
+    skillBootstrappingKeysRef.current.add(initKey);
+
+    const runBootstrap = async () => {
+      try {
+        await bootstrapRemoteSkill({
+          thread_id: threadId,
+          content_ids: skillBootstrap.contentIds,
+          language_type: languageType,
+          target_dir: "/mnt/user-data/uploads/skill",
+          clear_target: true,
+        });
+
+        skillBootstrappedKeysRef.current.add(initKey);
+      } catch (error) {
+        const message =
+          error instanceof Error ? error.message : "Skill initialization failed";
+        showNotification("Skill initialization failed", { body: message });
+      } finally {
+        skillBootstrappingKeysRef.current.delete(initKey);
+      }
+    };
+
+    void runBootstrap();
+  }, [threadId, skillBootstrap, showNotification]);
+
   const handleSubmit = useCallback(
     (message: PromptInputMessage) => {
       void sendMessage(threadId, message);
+      // Only switch the UI style; thread state is untouched
+      if (forceNewStyle) {
+        setForceNewStyle(false);
+      }
     },
-    [sendMessage, threadId],
+    [sendMessage, threadId, forceNewStyle],
   );
   const handleStop = useCallback(async () => {
     await thread.stop();
@@ -92,7 +184,7 @@ export default function ChatPage() {
       <header
         className={cn(
           "absolute top-0 right-0 left-0 z-30 flex h-12 shrink-0 items-center px-4",
-          isNewThread
+          (forceNewStyle || isNewThread)
             ? "bg-background/0 backdrop-blur-none"
             : "bg-background/80 shadow-xs backdrop-blur",
         )}
@@ -108,19 +200,22 @@ export default function ChatPage() {
       </header>
       <main className="flex min-h-0 max-w-full grow flex-col">
         <div className="flex size-full justify-center">
+          {/* Hide the message list while forceNewStyle is on; show it again after submit */}
+          {!(forceNewStyle) && (
             <MessageList
               className={cn("size-full", !isNewThread && "pt-10")}
               threadId={threadId}
               thread={thread}
               paddingBottom={messageListPaddingBottom}
             />
+          )}
         </div>
         <div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
           <div
             className={cn(
               "relative w-full",
-              isNewThread && "-translate-y-[calc(50vh-96px)]",
+              (forceNewStyle || isNewThread) && "-translate-y-[calc(50vh-96px)]",
-              isNewThread
+              (forceNewStyle || isNewThread)
                 ? "max-w-(--container-width-sm)"
                 : "max-w-(--container-width-md)",
             )}
@@ -139,9 +234,9 @@ export default function ChatPage() {
               {mounted ? (
                 <InputBox
                   className={cn("bg-background/5 w-full -translate-y-4")}
-                  isNewThread={isNewThread}
+                  isNewThread={forceNewStyle || isNewThread}
                   threadId={threadId}
-                  autoFocus={isNewThread}
+                  autoFocus={forceNewStyle || isNewThread}
                   status={
                     thread.error
                       ? "error"
@@ -151,7 +246,7 @@ export default function ChatPage() {
                   }
                   context={settings.context}
                   extraHeader={
-                    isNewThread && <Welcome mode={settings.context.mode} />
+                    (forceNewStyle || isNewThread) && <Welcome mode={settings.context.mode} />
                   }
                   disabled={
                     env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
@@ -1,17 +1,23 @@
 "use client";

 import { useParams, usePathname, useSearchParams } from "next/navigation";
-import { useEffect, useState } from "react";
+import { useEffect, useRef, useState } from "react";

 import { uuid } from "@/core/utils/uuid";

-export function useThreadChat() {
+type UseThreadChatOptions = {
+  newThreadId?: string;
+};
+
+export function useThreadChat(options?: UseThreadChatOptions) {
   const { thread_id: threadIdFromPath } = useParams<{ thread_id: string }>();
   const pathname = usePathname();
+  const fallbackNewThreadIdRef = useRef<string>(options?.newThreadId ?? uuid());
+  const fallbackNewThreadId = options?.newThreadId ?? fallbackNewThreadIdRef.current;
+
   const searchParams = useSearchParams();
   const [threadId, setThreadId] = useState(() => {
-    return threadIdFromPath === "new" ? uuid() : threadIdFromPath;
+    return threadIdFromPath === "new" ? fallbackNewThreadId : threadIdFromPath;
   });

   const [isNewThread, setIsNewThread] = useState(
@@ -21,9 +27,9 @@ export function useThreadChat(options?: UseThreadChatOptions) {
   useEffect(() => {
     if (pathname.endsWith("/new")) {
       setIsNewThread(true);
-      setThreadId(uuid());
+      setThreadId(fallbackNewThreadId);
     }
-  }, [pathname]);
+  }, [pathname, fallbackNewThreadId]);
   const isMock = searchParams.get("mock") === "true";
   return { threadId, isNewThread, setIsNewThread, isMock };
 }
@@ -32,14 +32,10 @@ import { SettingsDialog } from "./settings";
 export function CommandPalette() {
   const { t } = useI18n();
   const router = useRouter();
-  const [mounted, setMounted] = useState(false);
   const [open, setOpen] = useState(false);
   const [shortcutsOpen, setShortcutsOpen] = useState(false);
   const [settingsOpen, setSettingsOpen] = useState(false);
+  const [isMac, setIsMac] = useState(false);
-
-  useEffect(() => {
-    setMounted(true);
-  }, []);

   const handleNewChat = useCallback(() => {
     router.push("/workspace/chats/new");
@@ -68,14 +64,12 @@ export function CommandPalette() {

   useGlobalShortcuts(shortcuts);

-  const isMac = mounted && navigator.userAgent.includes("Mac");
+  useEffect(() => {
+    setIsMac(navigator.userAgent.includes("Mac"));
+  }, []);
   const metaKey = isMac ? "⌘" : "Ctrl+";
   const shiftKey = isMac ? "⇧" : "Shift+";

-  if (!mounted) {
-    return null;
-  }
-
   return (
     <>
       <SettingsDialog open={settingsOpen} onOpenChange={setSettingsOpen} />
@@ -35,6 +35,24 @@ export interface InstallSkillResponse {
   message: string;
 }

+export interface BootstrapRemoteSkillRequest {
+  thread_id: string;
+  content_ids: number[];
+  language_type?: number;
+  target_dir?: string;
+  clear_target?: boolean;
+}
+
+export interface BootstrapRemoteSkillResponse {
+  success: boolean;
+  target_dir: string;
+  content_ids: number[];
+  created_directories: number;
+  created_files: number;
+  sandbox_id: string | null;
+  message: string;
+}
+
 export async function installSkill(
   request: InstallSkillRequest,
 ): Promise<InstallSkillResponse> {
@@ -60,3 +78,27 @@ export async function installSkill(

   return response.json();
 }
+
+export async function bootstrapRemoteSkill(
+  request: BootstrapRemoteSkillRequest,
+): Promise<BootstrapRemoteSkillResponse> {
+  const response = await fetch(
+    `${getBackendBaseURL()}/api/skills/bootstrap-remote`,
+    {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify(request),
+    },
+  );
+
+  if (!response.ok) {
+    const errorData = await response.json().catch(() => ({}));
+    const errorMessage =
+      errorData.detail ?? `HTTP ${response.status}: ${response.statusText}`;
+    throw new Error(errorMessage);
+  }
+
+  return response.json();
+}
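Note: the request/response shapes above imply a matching gateway endpoint on the Python side, whose implementation is not part of this diff. A heavily hedged sketch of what such a handler could look like (Pydantic models mirror the TypeScript interfaces; the route body is a placeholder, not the real logic):

from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()

class BootstrapRemoteSkillRequest(BaseModel):
    thread_id: str
    content_ids: list[int]
    language_type: int = 0
    target_dir: str = "/mnt/user-data/uploads/skill"
    clear_target: bool = False

class BootstrapRemoteSkillResponse(BaseModel):
    success: bool
    target_dir: str
    content_ids: list[int]
    created_directories: int
    created_files: int
    sandbox_id: str | None
    message: str

@router.post("/api/skills/bootstrap-remote")
async def bootstrap_remote(req: BootstrapRemoteSkillRequest) -> BootstrapRemoteSkillResponse:
    # Placeholder body: the real handler would fetch each content id from the
    # remote skill-content API and materialize files under req.target_dir.
    return BootstrapRemoteSkillResponse(
        success=True,
        target_dir=req.target_dir,
        content_ids=req.content_ids,
        created_directories=0,
        created_files=0,
        sandbox_id=None,
        message="not implemented in this sketch",
    )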