feat(backend): 生成PPT时预处理用户提示词

This commit is contained in:
肖应宇 2026-04-20 14:55:19 +08:00
parent 040b107647
commit f0d93ab342
2 changed files with 197 additions and 0 deletions

View File

@ -10,12 +10,14 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from openai import AsyncOpenAI
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import (
@ -32,6 +34,17 @@ from deerflow.runtime import (
)
logger = logging.getLogger(__name__)
# Prompts for the PPT "prompt preprocessing" LLM step.
# NOTE: both constants below are runtime strings (one is injected into the
# conversation, the other is sent to the model as a system prompt) — their
# wording is part of program behavior; do not edit casually.
# Fixed instruction forwarded to the agent when the user's PPT request
# lacks enough detail ("user wants a PPT but gave too little info; ask
# the user for more information first").
PPT_INSUFFICIENT_INFO_FORWARD = "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
# System prompt that frames the model as a PPT skill selector choosing
# between ppt_gen_html and ppt_gen_reference; {user_input} is a literal
# placeholder referring to the user's original PPT instruction.
PPT_SELECTOR_SYSTEM_PROMPT = """#PPT
你是 PPT 技能选择器严格执行以下流程
用户输入生成 PPT 相关指令后询问你需要使用哪个生成 PPT 的技能可选技能1. ppt_gen_html生成 HTML 形式 PPT2. ppt_gen_reference根据文档生成 PPT
记住用户最初的 PPT 指令
用户选择技能后仅输出固定语句无任何多余内容
ppt_gen_html{user_input}使用 ppt_gen_html 这个 skill 来完成
ppt_gen_reference{user_input}使用 ppt_gen_reference 这个 skill 来完成
{user_input} 特指用户最初输入的 PPT 制作指令非选择回复"""
# ---------------------------------------------------------------------------
@ -94,6 +107,137 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
return raw_input
def _extract_text_content(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict):
text = item.get("text")
if isinstance(text, str) and text.strip():
parts.append(text.strip())
elif isinstance(item, str) and item.strip():
parts.append(item.strip())
return "\n".join(parts)
return str(content or "")
def _extract_last_human_text(graph_input: dict[str, Any]) -> str:
    """Return the stripped text of the newest user/human message, or "".

    Scans ``graph_input["messages"]`` from newest to oldest and accepts
    either ``HumanMessage`` instances or dicts whose ``role``/``type`` is
    "user" or "human".
    """
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        return ""
    for candidate in reversed(messages):
        if isinstance(candidate, HumanMessage):
            return _extract_text_content(candidate.content).strip()
        if isinstance(candidate, dict):
            role = str(candidate.get("role", candidate.get("type", ""))).lower()
            if role in {"human", "user"}:
                return _extract_text_content(candidate.get("content")).strip()
    return ""
def _is_ppt_request(text: str) -> bool:
lowered = text.lower()
return any(token in lowered for token in ("ppt", "slides", "powerpoint", "幻灯片", "演示文稿"))
def _heuristic_has_enough_ppt_info(text: str) -> bool:
lowered = text.lower()
if len(lowered.strip()) < 12:
return False
score = 0
if len(lowered) >= 24:
score += 1
if re.search(r"(关于|主题|topic|题目|on\s+)", lowered):
score += 1
if re.search(r"(面向|给|用于|目的|audience|for\s+)", lowered):
score += 1
if re.search(r"(\d+\s*(页|p|slides?)|大纲|目录|章节|结构)", lowered):
score += 1
if re.search(r"(风格|配色|模板|视觉|语气|style|tone)", lowered):
score += 1
if re.search(r"(根据|参考|数据|附件|文档|material|reference)", lowered):
score += 1
return score >= 2
async def _deepseek_ppt_info_check(user_text: str) -> bool:
enabled = os.getenv("PPT_PRECHECK_ENABLED", "true").strip().lower()
if enabled in {"0", "false", "off", "no"}:
return True
base_url = os.getenv("PPT_PRECHECK_BASE_URL", "").strip()
api_key = os.getenv("PPT_PRECHECK_API_KEY", "").strip()
model = os.getenv("PPT_PRECHECK_MODEL", "deepseek-chat").strip()
timeout_s = float(os.getenv("PPT_PRECHECK_TIMEOUT_SECONDS", "10").strip() or "10")
if not base_url or not api_key:
return _heuristic_has_enough_ppt_info(user_text)
check_instruction = (
"你现在只做“PPT信息是否足够”的判断不做技能追问。"
"判断标准:至少包含主题 + 另一个关键信息(受众/用途/页数或结构/风格/参考资料)。"
"仅输出一个词ENOUGH 或 INSUFFICIENT。"
)
system_prompt = f"{PPT_SELECTOR_SYSTEM_PROMPT}\n\n{check_instruction}"
try:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s)
resp = await client.chat.completions.create(
model=model,
temperature=0,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_text},
],
)
content = (resp.choices[0].message.content or "").strip().upper()
if "INSUFFICIENT" in content:
return False
if "ENOUGH" in content:
return True
logger.warning("PPT precheck unexpected output: %r; fallback to heuristic", content)
except Exception:
logger.warning("PPT precheck via DeepSeek failed; fallback to heuristic", exc_info=True)
return _heuristic_has_enough_ppt_info(user_text)
def _overwrite_last_human_message(graph_input: dict[str, Any], text: str) -> None:
    """Replace the content of the newest user/human message with *text*.

    If ``graph_input`` has no message list, one is installed containing a
    single ``HumanMessage``; if the list holds no human message, a new one
    is appended. Otherwise the newest human entry (``HumanMessage`` or a
    dict with role/type "user"/"human") is mutated in place.
    """
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        graph_input["messages"] = [HumanMessage(content=text)]
        return
    for entry in reversed(messages):
        if isinstance(entry, HumanMessage):
            entry.content = text
            return
        if isinstance(entry, dict):
            role = str(entry.get("role", entry.get("type", ""))).lower()
            if role in {"human", "user"}:
                entry["content"] = text
                return
    messages.append(HumanMessage(content=text))
async def _maybe_apply_ppt_precheck(graph_input: dict[str, Any]) -> None:
    """Rewrite the latest user message when a PPT request lacks detail.

    Guard clauses: no-op when there is no user text, the text is not a PPT
    request, or the pre-check deems the information sufficient. Otherwise
    the newest human message is replaced with the fixed
    ``PPT_INSUFFICIENT_INFO_FORWARD`` clarification instruction.
    """
    user_text = _extract_last_human_text(graph_input)
    if not user_text:
        return
    if not _is_ppt_request(user_text):
        return
    if await _deepseek_ppt_info_check(user_text):
        return
    _overwrite_last_human_message(graph_input, PPT_INSUFFICIENT_INFO_FORWARD)
    logger.info("PPT precheck flagged insufficient info; forwarded clarification instruction")
# Default assistant id — presumably the fallback when a request omits
# assistant_id; its consumer is not visible in this chunk (TODO confirm).
_DEFAULT_ASSISTANT_ID = "lead_agent"
@ -282,6 +426,7 @@ async def start_run(
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
await _maybe_apply_ppt_precheck(graph_input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
if "configurable" in config and isinstance(config["configurable"], dict):

View File

@ -3,6 +3,9 @@
from __future__ import annotations
import json
from unittest.mock import AsyncMock, patch
from langchain_core.messages import HumanMessage
def test_format_sse_basic():
@ -81,6 +84,55 @@ def test_normalize_input_passthrough():
assert result == {"custom_key": "value"}
def test_extract_last_human_text_from_human_message():
    """The newest HumanMessage wins, including list-of-parts content."""
    from app.gateway.services import _extract_last_human_text

    history = [
        HumanMessage(content="第一条"),
        HumanMessage(content=[{"type": "text", "text": "我要做一个产品发布会PPT"}]),
    ]
    extracted = _extract_last_human_text({"messages": history})
    assert extracted == "我要做一个产品发布会PPT"
def test_is_ppt_request():
    """Keyword detection covers Chinese and English PPT phrasings."""
    from app.gateway.services import _is_ppt_request

    positives = ["帮我做个PPT", "Please generate slides for roadmap"]
    assert all(_is_ppt_request(sample) for sample in positives)
    assert not _is_ppt_request("帮我写一段 SQL")
def test_heuristic_has_enough_ppt_info():
    """A bare request fails; topic + audience + pages + style passes."""
    from app.gateway.services import _heuristic_has_enough_ppt_info

    assert _heuristic_has_enough_ppt_info("做个ppt") is False
    detailed = "做一个关于Q2复盘的PPT面向管理层10页简洁风格"
    assert _heuristic_has_enough_ppt_info(detailed) is True
def test_overwrite_last_human_message():
    """The last human message content is replaced in place."""
    from app.gateway.services import _overwrite_last_human_message

    forwarded = "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
    graph_input = {"messages": [HumanMessage(content="请生成PPT")]}
    _overwrite_last_human_message(graph_input, forwarded)
    assert graph_input["messages"][-1].content == forwarded
def test_maybe_apply_ppt_precheck_rewrites_when_insufficient():
    """An insufficient PPT request is swapped for the clarification text."""
    import asyncio

    from app.gateway.services import _maybe_apply_ppt_precheck

    graph_input = {"messages": [HumanMessage(content="帮我做个PPT")]}
    stubbed_check = AsyncMock(return_value=False)
    with patch("app.gateway.services._deepseek_ppt_info_check", new=stubbed_check):
        asyncio.run(_maybe_apply_ppt_precheck(graph_input))
    expected = "用户想生成ppt但是没有输入足够多的信息所以先向用户询问更多信息"
    assert graph_input["messages"][-1].content == expected
def test_build_run_config_basic():
from app.gateway.services import build_run_config