feat(backend): 生成PPT时预处理用户提示词
This commit is contained in:
parent
040b107647
commit
f0d93ab342
|
|
@ -10,12 +10,14 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
||||
from deerflow.runtime import (
|
||||
|
|
@ -32,6 +34,17 @@ from deerflow.runtime import (
|
|||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# LLM used to pre-process the user's prompt before PPT generation.

# Fixed instruction forwarded to the agent when the user asks for a PPT but has
# not supplied enough detail; it tells the agent to ask follow-up questions
# instead of generating immediately.
PPT_INSUFFICIENT_INFO_FORWARD = "用户想生成ppt,但是没有输入足够多的信息,所以先向用户询问更多信息"
# System prompt for the PPT skill selector: asks the user to pick between the
# ppt_gen_html and ppt_gen_reference skills, then echoes the original PPT
# instruction with the chosen skill appended.  "{user_input}" is a literal
# placeholder interpreted by the model, not a Python format field.
PPT_SELECTOR_SYSTEM_PROMPT = """#PPT
你是 PPT 技能选择器,严格执行以下流程:
用户输入生成 PPT 相关指令后,询问:你需要使用哪个生成 PPT 的技能?可选技能:1. ppt_gen_html(生成 HTML 形式 PPT)2. ppt_gen_reference(根据文档生成 PPT)
记住用户最初的 PPT 指令。
用户选择技能后,仅输出固定语句,无任何多余内容:
选 ppt_gen_html:{user_input},使用 ppt_gen_html 这个 skill 来完成
选 ppt_gen_reference:{user_input},使用 ppt_gen_reference 这个 skill 来完成
注:“{user_input}” 特指用户最初输入的 PPT 制作指令,非选择回复。"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -94,6 +107,137 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|||
return raw_input
|
||||
|
||||
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str) and text.strip():
|
||||
parts.append(text.strip())
|
||||
elif isinstance(item, str) and item.strip():
|
||||
parts.append(item.strip())
|
||||
return "\n".join(parts)
|
||||
return str(content or "")
|
||||
|
||||
|
||||
def _extract_last_human_text(graph_input: dict[str, Any]) -> str:
    """Return the stripped text of the newest human/user message, or ``""``.

    Walks ``graph_input["messages"]`` from the end, accepting either a
    ``HumanMessage`` instance or a dict whose ``role``/``type`` is
    ``"user"``/``"human"``.
    """
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        return ""

    for candidate in reversed(messages):
        if isinstance(candidate, HumanMessage):
            return _extract_text_content(candidate.content).strip()
        if not isinstance(candidate, dict):
            continue
        role = str(candidate.get("role", candidate.get("type", ""))).lower()
        if role in {"user", "human"}:
            return _extract_text_content(candidate.get("content")).strip()
    return ""
|
||||
|
||||
|
||||
def _is_ppt_request(text: str) -> bool:
|
||||
lowered = text.lower()
|
||||
return any(token in lowered for token in ("ppt", "slides", "powerpoint", "幻灯片", "演示文稿"))
|
||||
|
||||
|
||||
def _heuristic_has_enough_ppt_info(text: str) -> bool:
|
||||
lowered = text.lower()
|
||||
if len(lowered.strip()) < 12:
|
||||
return False
|
||||
|
||||
score = 0
|
||||
if len(lowered) >= 24:
|
||||
score += 1
|
||||
if re.search(r"(关于|主题|topic|题目|on\s+)", lowered):
|
||||
score += 1
|
||||
if re.search(r"(面向|给|用于|目的|audience|for\s+)", lowered):
|
||||
score += 1
|
||||
if re.search(r"(\d+\s*(页|p|slides?)|大纲|目录|章节|结构)", lowered):
|
||||
score += 1
|
||||
if re.search(r"(风格|配色|模板|视觉|语气|style|tone)", lowered):
|
||||
score += 1
|
||||
if re.search(r"(根据|参考|数据|附件|文档|material|reference)", lowered):
|
||||
score += 1
|
||||
return score >= 2
|
||||
|
||||
|
||||
async def _deepseek_ppt_info_check(user_text: str) -> bool:
    """Ask the configured OpenAI-compatible model whether *user_text* contains
    enough information to generate a PPT.

    Returns ``True`` when the information is judged sufficient (or the
    precheck is disabled via ``PPT_PRECHECK_ENABLED``), ``False`` when a
    clarification round-trip is needed.  On any failure — missing
    ``PPT_PRECHECK_BASE_URL``/``PPT_PRECHECK_API_KEY``, malformed config,
    unexpected model output, or a network/API error — it degrades to the
    local ``_heuristic_has_enough_ppt_info`` instead of failing the run.
    """
    enabled = os.getenv("PPT_PRECHECK_ENABLED", "true").strip().lower()
    if enabled in {"0", "false", "off", "no"}:
        # Precheck disabled: treat every request as sufficiently specified.
        return True

    base_url = os.getenv("PPT_PRECHECK_BASE_URL", "").strip()
    api_key = os.getenv("PPT_PRECHECK_API_KEY", "").strip()
    model = os.getenv("PPT_PRECHECK_MODEL", "deepseek-chat").strip()
    try:
        timeout_s = float(os.getenv("PPT_PRECHECK_TIMEOUT_SECONDS", "10").strip() or "10")
    except ValueError:
        # A malformed timeout value must not break request handling.
        logger.warning("Invalid PPT_PRECHECK_TIMEOUT_SECONDS; falling back to 10s")
        timeout_s = 10.0

    if not base_url or not api_key:
        # No remote model configured -> local heuristic only.
        return _heuristic_has_enough_ppt_info(user_text)

    check_instruction = (
        "你现在只做“PPT信息是否足够”的判断,不做技能追问。"
        "判断标准:至少包含主题 + 另一个关键信息(受众/用途/页数或结构/风格/参考资料)。"
        "仅输出一个词:ENOUGH 或 INSUFFICIENT。"
    )
    system_prompt = f"{PPT_SELECTOR_SYSTEM_PROMPT}\n\n{check_instruction}"

    try:
        # ``async with`` closes the client's underlying HTTP connection pool
        # even on errors; the original code leaked it on every call.
        async with AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_s) as client:
            resp = await client.chat.completions.create(
                model=model,
                temperature=0,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                ],
            )
        content = (resp.choices[0].message.content or "").strip().upper()
        # Check INSUFFICIENT first in case the reply contains both tokens.
        if "INSUFFICIENT" in content:
            return False
        if "ENOUGH" in content:
            return True
        logger.warning("PPT precheck unexpected output: %r; fallback to heuristic", content)
    except Exception:
        logger.warning("PPT precheck via DeepSeek failed; fallback to heuristic", exc_info=True)

    return _heuristic_has_enough_ppt_info(user_text)
|
||||
|
||||
|
||||
def _overwrite_last_human_message(graph_input: dict[str, Any], text: str) -> None:
    """Replace the content of the newest human/user message with *text*.

    Mutates ``graph_input`` in place.  When ``messages`` is missing or not a
    list, it is replaced by a single-element list; when no human message is
    found, a new ``HumanMessage`` is appended instead.
    """
    messages = graph_input.get("messages")
    if not isinstance(messages, list):
        graph_input["messages"] = [HumanMessage(content=text)]
        return

    # Walk from the newest message backwards and rewrite the first human one.
    for candidate in reversed(messages):
        if isinstance(candidate, HumanMessage):
            candidate.content = text
            return
        if isinstance(candidate, dict):
            role = str(candidate.get("role", candidate.get("type", ""))).lower()
            if role in {"user", "human"}:
                candidate["content"] = text
                return

    messages.append(HumanMessage(content=text))
|
||||
|
||||
|
||||
async def _maybe_apply_ppt_precheck(graph_input: dict[str, Any]) -> None:
    """Rewrite an under-specified PPT request into a clarification instruction.

    No-op unless the latest human message looks like a PPT request AND the
    info check judges it insufficient; in that case the message is replaced
    with ``PPT_INSUFFICIENT_INFO_FORWARD`` so the downstream agent asks the
    user for more detail instead of generating.
    """
    user_text = _extract_last_human_text(graph_input)
    if not user_text:
        return
    if not _is_ppt_request(user_text):
        return

    if await _deepseek_ppt_info_check(user_text):
        return

    _overwrite_last_human_message(graph_input, PPT_INSUFFICIENT_INFO_FORWARD)
    logger.info("PPT precheck flagged insufficient info; forwarded clarification instruction")
|
||||
|
||||
|
||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
|
||||
|
||||
|
|
@ -282,6 +426,7 @@ async def start_run(
|
|||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
await _maybe_apply_ppt_precheck(graph_input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
if "configurable" in config and isinstance(config["configurable"], dict):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
|
||||
def test_format_sse_basic():
|
||||
|
|
@ -81,6 +84,55 @@ def test_normalize_input_passthrough():
|
|||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_extract_last_human_text_from_human_message():
    """The newest HumanMessage wins, including list-of-chunks content."""
    from app.gateway.services import _extract_last_human_text

    history = [
        HumanMessage(content="第一条"),
        HumanMessage(content=[{"type": "text", "text": "我要做一个产品发布会PPT"}]),
    ]
    assert _extract_last_human_text({"messages": history}) == "我要做一个产品发布会PPT"
|
||||
|
||||
|
||||
def test_is_ppt_request():
    """Keyword detection covers both Chinese and English phrasing."""
    from app.gateway.services import _is_ppt_request

    for prompt in ("帮我做个PPT", "Please generate slides for roadmap"):
        assert _is_ppt_request(prompt)
    assert not _is_ppt_request("帮我写一段 SQL")
|
||||
|
||||
|
||||
def test_heuristic_has_enough_ppt_info():
    """A bare request fails; topic + audience + pages + style passes."""
    from app.gateway.services import _heuristic_has_enough_ppt_info

    assert _heuristic_has_enough_ppt_info("做个ppt") is False
    assert _heuristic_has_enough_ppt_info("做一个关于Q2复盘的PPT,面向管理层,10页,简洁风格") is True
|
||||
|
||||
|
||||
def test_overwrite_last_human_message():
    """The latest human message content is replaced in place."""
    from app.gateway.services import _overwrite_last_human_message

    forwarded = "用户想生成ppt,但是没有输入足够多的信息,所以先向用户询问更多信息"
    graph_input = {"messages": [HumanMessage(content="请生成PPT")]}
    _overwrite_last_human_message(graph_input, forwarded)
    assert graph_input["messages"][-1].content == forwarded
|
||||
|
||||
|
||||
def test_maybe_apply_ppt_precheck_rewrites_when_insufficient():
    """An INSUFFICIENT verdict rewrites the message into the forward text."""
    import asyncio

    from app.gateway.services import _maybe_apply_ppt_precheck

    forwarded = "用户想生成ppt,但是没有输入足够多的信息,所以先向用户询问更多信息"
    graph_input = {"messages": [HumanMessage(content="帮我做个PPT")]}
    insufficient = AsyncMock(return_value=False)
    with patch("app.gateway.services._deepseek_ppt_info_check", new=insufficient):
        asyncio.run(_maybe_apply_ppt_precheck(graph_input))
    assert graph_input["messages"][-1].content == forwarded
|
||||
|
||||
|
||||
def test_build_run_config_basic():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue