feat(backend): 生成PPT时预处理用户提示词

This commit is contained in:
肖应宇 2026-04-20 14:55:19 +08:00
parent 040b107647
commit f0d93ab342
2 changed files with 197 additions and 0 deletions

View File

@ -10,12 +10,14 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import os
import re import re
import time import time
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from openai import AsyncOpenAI
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import ( from deerflow.runtime import (
@ -32,6 +34,17 @@ from deerflow.runtime import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 预处理提示词的大模型
PPT_INSUFFICIENT_INFO_FORWARD = "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
PPT_SELECTOR_SYSTEM_PROMPT = """#PPT
你是 PPT 技能选择器严格执行以下流程
用户输入生成 PPT 相关指令后询问你需要使用哪个生成 PPT 的技能可选技能1. ppt_gen_html生成 HTML 形式 PPT2. ppt_gen_reference根据文档生成 PPT
记住用户最初的 PPT 指令
用户选择技能后仅输出固定语句无任何多余内容
ppt_gen_html{user_input}使用 ppt_gen_html 这个 skill 来完成
ppt_gen_reference{user_input}使用 ppt_gen_reference 这个 skill 来完成
{user_input} 特指用户最初输入的 PPT 制作指令非选择回复"""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -94,6 +107,137 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
return raw_input return raw_input
def _extract_text_content(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict):
text = item.get("text")
if isinstance(text, str) and text.strip():
parts.append(text.strip())
elif isinstance(item, str) and item.strip():
parts.append(item.strip())
return "\n".join(parts)
return str(content or "")
def _extract_last_human_text(graph_input: dict[str, Any]) -> str:
messages = graph_input.get("messages")
if not isinstance(messages, list):
return ""
for msg in reversed(messages):
if isinstance(msg, HumanMessage):
return _extract_text_content(msg.content).strip()
if isinstance(msg, dict):
role = str(msg.get("role", msg.get("type", ""))).lower()
if role in {"user", "human"}:
return _extract_text_content(msg.get("content")).strip()
return ""
def _is_ppt_request(text: str) -> bool:
lowered = text.lower()
return any(token in lowered for token in ("ppt", "slides", "powerpoint", "幻灯片", "演示文稿"))
def _heuristic_has_enough_ppt_info(text: str) -> bool:
lowered = text.lower()
if len(lowered.strip()) < 12:
return False
score = 0
if len(lowered) >= 24:
score += 1
if re.search(r"(关于|主题|topic|题目|on\s+)", lowered):
score += 1
if re.search(r"(面向|给|用于|目的|audience|for\s+)", lowered):
score += 1
if re.search(r"(\d+\s*(页|p|slides?)|大纲|目录|章节|结构)", lowered):
score += 1
if re.search(r"(风格|配色|模板|视觉|语气|style|tone)", lowered):
score += 1
if re.search(r"(根据|参考|数据|附件|文档|material|reference)", lowered):
score += 1
return score >= 2
async def _deepseek_ppt_info_check(user_text: str) -> bool:
enabled = os.getenv("PPT_PRECHECK_ENABLED", "true").strip().lower()
if enabled in {"0", "false", "off", "no"}:
return True
base_url = os.getenv("PPT_PRECHECK_BASE_URL", "").strip()
api_key = os.getenv("PPT_PRECHECK_API_KEY", "").strip()
model = os.getenv("PPT_PRECHECK_MODEL", "deepseek-chat").strip()
timeout_s = float(os.getenv("PPT_PRECHECK_TIMEOUT_SECONDS", "10").strip() or "10")
if not base_url or not api_key:
return _heuristic_has_enough_ppt_info(user_text)
check_instruction = (
"你现在只做“PPT信息是否足够”的判断不做技能追问。"
"判断标准:至少包含主题 + 另一个关键信息(受众/用途/页数或结构/风格/参考资料)。"
"仅输出一个词ENOUGH 或 INSUFFICIENT。"
)
system_prompt = f"{PPT_SELECTOR_SYSTEM_PROMPT}\n\n{check_instruction}"
try:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s)
resp = await client.chat.completions.create(
model=model,
temperature=0,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_text},
],
)
content = (resp.choices[0].message.content or "").strip().upper()
if "INSUFFICIENT" in content:
return False
if "ENOUGH" in content:
return True
logger.warning("PPT precheck unexpected output: %r; fallback to heuristic", content)
except Exception:
logger.warning("PPT precheck via DeepSeek failed; fallback to heuristic", exc_info=True)
return _heuristic_has_enough_ppt_info(user_text)
def _overwrite_last_human_message(graph_input: dict[str, Any], text: str) -> None:
messages = graph_input.get("messages")
if not isinstance(messages, list):
graph_input["messages"] = [HumanMessage(content=text)]
return
for idx in range(len(messages) - 1, -1, -1):
msg = messages[idx]
if isinstance(msg, HumanMessage):
msg.content = text
return
if isinstance(msg, dict):
role = str(msg.get("role", msg.get("type", ""))).lower()
if role in {"user", "human"}:
msg["content"] = text
return
messages.append(HumanMessage(content=text))
async def _maybe_apply_ppt_precheck(graph_input: dict[str, Any]) -> None:
user_text = _extract_last_human_text(graph_input)
if not user_text or not _is_ppt_request(user_text):
return
enough = await _deepseek_ppt_info_check(user_text)
if enough:
return
_overwrite_last_human_message(graph_input, PPT_INSUFFICIENT_INFO_FORWARD)
logger.info("PPT precheck flagged insufficient info; forwarded clarification instruction")
_DEFAULT_ASSISTANT_ID = "lead_agent" _DEFAULT_ASSISTANT_ID = "lead_agent"
@ -282,6 +426,7 @@ async def start_run(
agent_factory = resolve_agent_factory(body.assistant_id) agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input) graph_input = normalize_input(body.input)
await _maybe_apply_ppt_precheck(graph_input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
if "configurable" in config and isinstance(config["configurable"], dict): if "configurable" in config and isinstance(config["configurable"], dict):

View File

@ -3,6 +3,9 @@
from __future__ import annotations from __future__ import annotations
import json import json
from unittest.mock import AsyncMock, patch
from langchain_core.messages import HumanMessage
def test_format_sse_basic(): def test_format_sse_basic():
@ -81,6 +84,55 @@ def test_normalize_input_passthrough():
assert result == {"custom_key": "value"} assert result == {"custom_key": "value"}
def test_extract_last_human_text_from_human_message():
from app.gateway.services import _extract_last_human_text
graph_input = {
"messages": [
HumanMessage(content="第一条"),
HumanMessage(content=[{"type": "text", "text": "我要做一个产品发布会PPT"}]),
]
}
assert _extract_last_human_text(graph_input) == "我要做一个产品发布会PPT"
def test_is_ppt_request():
from app.gateway.services import _is_ppt_request
assert _is_ppt_request("帮我做个PPT")
assert _is_ppt_request("Please generate slides for roadmap")
assert not _is_ppt_request("帮我写一段 SQL")
def test_heuristic_has_enough_ppt_info():
from app.gateway.services import _heuristic_has_enough_ppt_info
assert not _heuristic_has_enough_ppt_info("做个ppt")
assert _heuristic_has_enough_ppt_info("做一个关于Q2复盘的PPT面向管理层10页简洁风格")
def test_overwrite_last_human_message():
from app.gateway.services import _overwrite_last_human_message
graph_input = {"messages": [HumanMessage(content="请生成PPT")]}
_overwrite_last_human_message(graph_input, "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息")
assert graph_input["messages"][-1].content == "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
def test_maybe_apply_ppt_precheck_rewrites_when_insufficient():
from app.gateway.services import _maybe_apply_ppt_precheck
graph_input = {"messages": [HumanMessage(content="帮我做个PPT")]}
with patch(
"app.gateway.services._deepseek_ppt_info_check",
new=AsyncMock(return_value=False),
):
import asyncio
asyncio.run(_maybe_apply_ppt_precheck(graph_input))
assert graph_input["messages"][-1].content == "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
def test_build_run_config_basic(): def test_build_run_config_basic():
from app.gateway.services import build_run_config from app.gateway.services import build_run_config