"""Run lifecycle service layer.
|
||
|
||
Centralizes the business logic for creating runs, formatting SSE
|
||
frames, and consuming stream bridge events. Router modules
|
||
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import time
|
||
from typing import Any
|
||
|
||
from fastapi import HTTPException, Request
|
||
from langchain_core.messages import HumanMessage
|
||
from openai import AsyncOpenAI
|
||
|
||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
||
from deerflow.runtime import (
|
||
END_SENTINEL,
|
||
HEARTBEAT_SENTINEL,
|
||
ConflictError,
|
||
DisconnectMode,
|
||
RunManager,
|
||
RunRecord,
|
||
RunStatus,
|
||
StreamBridge,
|
||
UnsupportedStrategyError,
|
||
run_agent,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)

# LLM-based pre-screening of PPT prompts.

# Forwarded to the agent verbatim when the precheck decides the user asked for
# a PPT without enough detail. It reads: "the user wants to generate a PPT but
# has not provided enough information, so first ask the user for more details".
PPT_INSUFFICIENT_INFO_FORWARD = "用户想生成ppt,但是没有输入足够多的信息,所以先向用户询问更多信息"

# System prompt for the PPT skill selector. It instructs the model to ask which
# PPT skill to use (ppt_gen_html for HTML-style decks, ppt_gen_reference for
# document-based decks), remember the user's original PPT instruction, and,
# once a skill is chosen, emit only the fixed sentence
# "{user_input},使用 <skill> 这个 skill 来完成" where {user_input} is the original
# instruction, not the selection reply. Kept in Chinese because it is sent to
# the model verbatim.
PPT_SELECTOR_SYSTEM_PROMPT = """#PPT
你是 PPT 技能选择器,严格执行以下流程:
用户输入生成 PPT 相关指令后,询问:你需要使用哪个生成 PPT 的技能?可选技能:1. ppt_gen_html(生成 HTML 形式 PPT)2. ppt_gen_reference(根据文档生成 PPT)
记住用户最初的 PPT 指令。
用户选择技能后,仅输出固定语句,无任何多余内容:
选 ppt_gen_html:{user_input},使用 ppt_gen_html 这个 skill 来完成
选 ppt_gen_reference:{user_input},使用 ppt_gen_reference 这个 skill 来完成
注:“{user_input}” 特指用户最初输入的 PPT 制作指令,非选择回复。"""


# ---------------------------------------------------------------------------
# SSE formatting
# ---------------------------------------------------------------------------


def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
    r"""Format a single SSE frame.

    Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
    This matches the LangGraph Platform wire format consumed by the
    ``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
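
    Doctest sketch of a produced frame (derived from the code below)::

        >>> format_sse("values", {"ok": True}, event_id="1")
        'event: values\ndata: {"ok": true}\nid: 1\n\n'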
"""
|
||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||
parts = [f"event: {event}", f"data: {payload}"]
|
||
if event_id:
|
||
parts.append(f"id: {event_id}")
|
||
parts.append("")
|
||
parts.append("")
|
||
return "\n".join(parts)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Input / config helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
    """Normalize the ``stream_mode`` parameter to a list.

    ``None`` and empty lists default to ``["values"]``; a bare string is
    wrapped in a single-element list.
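
    Doctest examples (mirroring the branches below)::

        >>> normalize_stream_modes(None)
        ['values']
        >>> normalize_stream_modes("updates")
        ['updates']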
"""
|
||
if raw is None:
|
||
return ["values"]
|
||
if isinstance(raw, str):
|
||
return [raw]
|
||
return raw if raw else ["values"]
|
||
|
||
|
||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
    """Convert LangGraph Platform input format to LangChain state dict.
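
    Doctest sketch (``role: user`` dicts become ``HumanMessage``)::

        >>> state = normalize_input({"messages": [{"role": "user", "content": "hi"}]})
        >>> state["messages"][0].content
        'hi'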
    """
    if raw_input is None:
        return {}
    messages = raw_input.get("messages")
    if messages and isinstance(messages, list):
        converted = []
        for msg in messages:
            if isinstance(msg, dict):
                role = msg.get("role", msg.get("type", "user"))
                content = msg.get("content", "")
                if role in ("user", "human"):
                    converted.append(HumanMessage(content=content))
                else:
                    # TODO: handle other message types (system, ai, tool)
                    converted.append(HumanMessage(content=content))
            else:
                converted.append(msg)
        return {**raw_input, "messages": converted}
    return raw_input


def _extract_text_content(content: Any) -> str:
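    """Flatten string-or-content-block message content into plain text.

    Doctest sketch (content blocks in the OpenAI/LangChain list style)::

        >>> print(_extract_text_content([{"type": "text", "text": "hi"}, "there"]))
        hi
        there
    """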
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                text = item.get("text")
                if isinstance(text, str) and text.strip():
                    parts.append(text.strip())
            elif isinstance(item, str) and item.strip():
                parts.append(item.strip())
        return "\n".join(parts)
    return str(content or "")


def _extract_last_human_text(graph_input: dict[str, Any]) -> str:
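    """Return the text of the most recent human/user message, or ``""``.

    >>> _extract_last_human_text({"messages": [{"role": "user", "content": "hi"}]})
    'hi'
    """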
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        return ""
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            return _extract_text_content(msg.content).strip()
        if isinstance(msg, dict):
            role = str(msg.get("role", msg.get("type", ""))).lower()
            if role in {"user", "human"}:
                return _extract_text_content(msg.get("content")).strip()
    return ""


def _is_ppt_request(text: str) -> bool:
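    """Cheap keyword check for PPT-generation intent.

    >>> _is_ppt_request("Make me a PowerPoint about cats")
    True
    >>> _is_ppt_request("summarize this paper")
    False
    """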
    lowered = text.lower()
    return any(token in lowered for token in ("ppt", "slides", "powerpoint", "幻灯片", "演示文稿"))


def _heuristic_has_enough_ppt_info(text: str) -> bool:
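    """Score-based fallback used when the LLM precheck is unavailable.

    Requires a minimum length, then at least two of: length >= 24, topic,
    audience/purpose, page count or structure, style, reference material.

    >>> _heuristic_has_enough_ppt_info("make me a ppt")
    False
    >>> _heuristic_has_enough_ppt_info("make a ppt about quantum computing for undergraduates, 10 slides")
    True
    """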
    lowered = text.lower()
    if len(lowered.strip()) < 12:
        return False

    score = 0
    if len(lowered) >= 24:
        score += 1
    if re.search(r"(关于|主题|topic|题目|on\s+)", lowered):
        score += 1
    if re.search(r"(面向|给|用于|目的|audience|for\s+)", lowered):
        score += 1
    if re.search(r"(\d+\s*(页|p|slides?)|大纲|目录|章节|结构)", lowered):
        score += 1
    if re.search(r"(风格|配色|模板|视觉|语气|style|tone)", lowered):
        score += 1
    if re.search(r"(根据|参考|数据|附件|文档|material|reference)", lowered):
        score += 1
    return score >= 2


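# Environment knobs read by the DeepSeek-backed precheck below (values are
# illustrative, not defaults shipped with the service):
#   PPT_PRECHECK_ENABLED=true
#   PPT_PRECHECK_BASE_URL=https://api.deepseek.com
#   PPT_PRECHECK_API_KEY=sk-...
#   PPT_PRECHECK_MODEL=deepseek-chat
#   PPT_PRECHECK_TIMEOUT_SECONDS=10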
async def _deepseek_ppt_info_check(user_text: str) -> bool:
    enabled = os.getenv("PPT_PRECHECK_ENABLED", "true").strip().lower()
    if enabled in {"0", "false", "off", "no"}:
        return True

    base_url = os.getenv("PPT_PRECHECK_BASE_URL", "").strip()
    api_key = os.getenv("PPT_PRECHECK_API_KEY", "").strip()
    model = os.getenv("PPT_PRECHECK_MODEL", "deepseek-chat").strip()
    timeout_s = float(os.getenv("PPT_PRECHECK_TIMEOUT_SECONDS", "10").strip() or "10")

    if not base_url or not api_key:
        return _heuristic_has_enough_ppt_info(user_text)

    # Tells the model to judge only whether the PPT request carries enough
    # information (topic plus one of: audience/purpose, page count or
    # structure, style, reference material) and to answer with exactly one
    # word: ENOUGH or INSUFFICIENT. Kept in Chinese; sent verbatim.
    check_instruction = (
        "你现在只做“PPT信息是否足够”的判断,不做技能追问。"
        "判断标准:至少包含主题 + 另一个关键信息(受众/用途/页数或结构/风格/参考资料)。"
        "仅输出一个词:ENOUGH 或 INSUFFICIENT。"
    )
    system_prompt = f"{PPT_SELECTOR_SYSTEM_PROMPT}\n\n{check_instruction}"

    try:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s)
        resp = await client.chat.completions.create(
            model=model,
            temperature=0,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ],
        )
        content = (resp.choices[0].message.content or "").strip().upper()
        if "INSUFFICIENT" in content:
            return False
        if "ENOUGH" in content:
            return True
        logger.warning("PPT precheck unexpected output: %r; falling back to heuristic", content)
    except Exception:
        logger.warning("PPT precheck via DeepSeek failed; falling back to heuristic", exc_info=True)

    return _heuristic_has_enough_ppt_info(user_text)


def _overwrite_last_human_message(graph_input: dict[str, Any], text: str) -> None:
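    """Replace the content of the most recent human message in place.

    >>> state = {"messages": [{"role": "user", "content": "old"}]}
    >>> _overwrite_last_human_message(state, "new")
    >>> state["messages"][0]["content"]
    'new'
    """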
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        graph_input["messages"] = [HumanMessage(content=text)]
        return

    for idx in range(len(messages) - 1, -1, -1):
        msg = messages[idx]
        if isinstance(msg, HumanMessage):
            msg.content = text
            return
        if isinstance(msg, dict):
            role = str(msg.get("role", msg.get("type", ""))).lower()
            if role in {"user", "human"}:
                msg["content"] = text
                return

    messages.append(HumanMessage(content=text))


async def _maybe_apply_ppt_precheck(graph_input: dict[str, Any]) -> None:
    user_text = _extract_last_human_text(graph_input)
    if not user_text or not _is_ppt_request(user_text):
        return

    enough = await _deepseek_ppt_info_check(user_text)
    if enough:
        return

    _overwrite_last_human_message(graph_input, PPT_INSUFFICIENT_INFO_FORWARD)
    logger.info("PPT precheck flagged insufficient info; forwarded clarification instruction")


_DEFAULT_ASSISTANT_ID = "lead_agent"


def resolve_agent_factory(assistant_id: str | None):
    """Resolve the agent factory callable from config.

    Custom agents are implemented as ``lead_agent`` + an ``agent_name``
    injected into ``configurable`` — see :func:`build_run_config`. All
    ``assistant_id`` values therefore map to the same factory; the routing
    happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
    """
    from deerflow.agents.lead_agent.agent import make_lead_agent

    return make_lead_agent


def build_run_config(
    thread_id: str,
    request_config: dict[str, Any] | None,
    metadata: dict[str, Any] | None,
    *,
    assistant_id: str | None = None,
) -> dict[str, Any]:
    """Build a RunnableConfig dict for the agent.

    When *assistant_id* refers to a custom agent (anything other than
    ``"lead_agent"`` / ``None``), the name is forwarded as
    ``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
    load the matching ``agents/<name>/SOUL.md`` and per-agent config —
    without it the agent silently runs as the default lead agent.

    This mirrors the channel manager's ``_resolve_run_params`` logic so that
    the LangGraph Platform-compatible HTTP API and the IM channel path behave
    identically.
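
    Doctest sketch (``data_analyst`` is a hypothetical custom assistant id;
    it is normalized and injected into ``configurable``)::

        >>> cfg = build_run_config("t1", None, None, assistant_id="data_analyst")
        >>> cfg["configurable"]["agent_name"]
        'data-analyst'
        >>> cfg["configurable"]["thread_id"]
        't1'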
"""
|
||
config: dict[str, Any] = {"recursion_limit": 100}
|
||
if request_config:
|
||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
||
# pass thread-level data and rejects requests that include both
|
||
# ``configurable`` and ``context``. If the caller already sends
|
||
# ``context``, honour it and skip our own ``configurable`` dict.
|
||
if "context" in request_config:
|
||
if "configurable" in request_config:
|
||
logger.warning(
|
||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
||
thread_id,
|
||
list(request_config.get("configurable", {}).keys()),
|
||
)
|
||
config["context"] = request_config["context"]
|
||
else:
|
||
configurable = {"thread_id": thread_id}
|
||
configurable.update(request_config.get("configurable", {}))
|
||
config["configurable"] = configurable
|
||
for k, v in request_config.items():
|
||
if k not in ("configurable", "context"):
|
||
config[k] = v
|
||
else:
|
||
config["configurable"] = {"thread_id": thread_id}
|
||
|
||
# Inject custom agent name when the caller specified a non-default assistant.
|
||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||
if "agent_name" not in config["configurable"]:
|
||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||
config["configurable"]["agent_name"] = normalized
|
||
if metadata:
|
||
config.setdefault("metadata", {}).update(metadata)
|
||
return config
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Run lifecycle
# ---------------------------------------------------------------------------


async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
    """Create or refresh the thread record in the Store.

    Called from :func:`start_run` so that threads created via the stateless
    ``/runs/stream`` endpoint (which never calls ``POST /threads``) still
    appear in ``/threads/search`` results.
    """
    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_upsert

    try:
        await _store_upsert(store, thread_id, metadata=metadata)
    except Exception:
        logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)


async def _sync_thread_title_after_run(
    run_task: asyncio.Task,
    thread_id: str,
    checkpointer: Any,
    store: Any,
) -> None:
    """Wait for *run_task* to finish, then persist the generated title to the Store.

    TitleMiddleware writes the generated title to the LangGraph agent state
    (checkpointer) but the Gateway's Store record is not updated automatically.
    This coroutine closes that gap by reading the final checkpoint after the
    run completes and syncing ``values.title`` into the Store record so that
    subsequent ``/threads/search`` responses include the correct title.

    Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
    logged at DEBUG level and never propagate.
    """
    # Wait for the background run task to complete (any outcome).
    # asyncio.wait does not propagate task exceptions — it just returns
    # when the task is done, cancelled, or failed.
    await asyncio.wait({run_task})

    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_get, _store_put

    try:
        ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
        ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
        if ckpt_tuple is None:
            return

        channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
        title = channel_values.get("title")
        if not title:
            return

        existing = await _store_get(store, thread_id)
        if existing is None:
            return

        updated = dict(existing)
        updated.setdefault("values", {})["title"] = title
        updated["updated_at"] = time.time()
        await _store_put(store, updated)
        logger.debug("Synced title %r for thread %s", title, thread_id)
    except Exception:
        logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)


async def start_run(
    body: Any,
    thread_id: str,
    request: Request,
) -> RunRecord:
    """Create a RunRecord and launch the background agent task.

    Parameters
    ----------
    body : RunCreateRequest
        The validated request body (typed as Any to avoid circular import
        with the router module that defines the Pydantic model).
    thread_id : str
        Target thread.
    request : Request
        FastAPI request — used to retrieve singletons from ``app.state``.
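
    A router handler might wire this together roughly as follows (sketch;
    ``RunCreateRequest`` is the router-defined model mentioned above and
    ``StreamingResponse`` comes from ``fastapi.responses``)::

        @router.post("/threads/{thread_id}/runs/stream")
        async def stream_run(thread_id: str, body: RunCreateRequest, request: Request):
            record = await start_run(body, thread_id, request)
            return StreamingResponse(
                sse_consumer(get_stream_bridge(request), record, request,
                             get_run_manager(request)),
                media_type="text/event-stream",
            )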
"""
|
||
bridge = get_stream_bridge(request)
|
||
run_mgr = get_run_manager(request)
|
||
checkpointer = get_checkpointer(request)
|
||
store = get_store(request)
|
||
|
||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||
|
||
try:
|
||
record = await run_mgr.create_or_reject(
|
||
thread_id,
|
||
body.assistant_id,
|
||
on_disconnect=disconnect,
|
||
metadata=body.metadata or {},
|
||
kwargs={"input": body.input, "config": body.config},
|
||
multitask_strategy=body.multitask_strategy,
|
||
)
|
||
except ConflictError as exc:
|
||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||
except UnsupportedStrategyError as exc:
|
||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||
|
||
# Ensure the thread is visible in /threads/search, even for threads that
|
||
# were never explicitly created via POST /threads (e.g. stateless runs).
|
||
store = get_store(request)
|
||
if store is not None:
|
||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
||
|
||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||
graph_input = normalize_input(body.input)
|
||
await _maybe_apply_ppt_precheck(graph_input)
|
||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||
|
||
if "configurable" in config and isinstance(config["configurable"], dict):
|
||
config["configurable"].setdefault("run_id", record.run_id)
|
||
if "context" in config and isinstance(config["context"], dict):
|
||
config["context"].setdefault("run_id", record.run_id)
|
||
|
||
# Merge DeerFlow-specific context overrides into configurable.
|
||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||
context = getattr(body, "context", None)
|
||
if context:
|
||
_CONTEXT_CONFIGURABLE_KEYS = {
|
||
"model_name",
|
||
"mode",
|
||
"thinking_enabled",
|
||
"reasoning_effort",
|
||
"is_plan_mode",
|
||
"subagent_enabled",
|
||
"max_concurrent_subagents",
|
||
}
|
||
configurable = config.setdefault("configurable", {})
|
||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||
if key in context:
|
||
configurable.setdefault(key, context[key])
|
||
|
||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||
|
||
task = asyncio.create_task(
|
||
run_agent(
|
||
bridge,
|
||
run_mgr,
|
||
record,
|
||
checkpointer=checkpointer,
|
||
store=store,
|
||
agent_factory=agent_factory,
|
||
graph_input=graph_input,
|
||
config=config,
|
||
stream_modes=stream_modes,
|
||
stream_subgraphs=body.stream_subgraphs,
|
||
interrupt_before=body.interrupt_before,
|
||
interrupt_after=body.interrupt_after,
|
||
)
|
||
)
|
||
record.task = task
|
||
|
||
# After the run completes, sync the title generated by TitleMiddleware from
|
||
# the checkpointer into the Store record so that /threads/search returns the
|
||
# correct title instead of an empty values dict.
|
||
if store is not None:
|
||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
|
||
|
||
return record
|
||
|
||
|
||
async def sse_consumer(
    bridge: StreamBridge,
    record: RunRecord,
    request: Request,
    run_mgr: RunManager,
):
    """Async generator that yields SSE frames from the bridge.

    The ``finally`` block implements ``on_disconnect`` semantics:

    - ``cancel``: abort the background task on client disconnect.
    - ``continue``: let the task run; events are discarded.
    """
    try:
        async for entry in bridge.subscribe(record.run_id):
            if await request.is_disconnected():
                break

            if entry is HEARTBEAT_SENTINEL:
                yield ": heartbeat\n\n"
                continue

            if entry is END_SENTINEL:
                yield format_sse("end", None, event_id=entry.id or None)
                return

            yield format_sse(entry.event, entry.data, event_id=entry.id or None)

    finally:
        if record.status in (RunStatus.pending, RunStatus.running):
            if record.on_disconnect == DisconnectMode.cancel:
                await run_mgr.cancel(record.run_id)