470 lines
19 KiB
Python
470 lines
19 KiB
Python
"""Middleware to inject uploaded files information into agent context."""
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import NotRequired, override
|
||
|
||
from langchain.agents import AgentState
|
||
from langchain.agents.middleware import AgentMiddleware
|
||
from langchain_core.messages import HumanMessage
|
||
from langgraph.runtime import Runtime
|
||
|
||
from deerflow.config.paths import Paths, get_paths
|
||
from deerflow.utils.file_conversion import extract_outline
|
||
|
||
logger = logging.getLogger(__name__)

# Maximum number of non-empty lines read from a converted ``.md`` file and
# returned as a content preview when that file yields no outline headings.
_OUTLINE_PREVIEW_LINES = 5
|
||
|
||
|
||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||
"""Return the document outline and fallback preview for *file_path*.
|
||
|
||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||
pipeline.
|
||
|
||
Returns:
|
||
(outline, preview) where:
|
||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||
Empty when no headings are found or no .md exists.
|
||
- preview: first few non-empty lines of the .md, used as a content
|
||
anchor when outline is empty so the agent has some context.
|
||
Empty when outline is non-empty (no fallback needed).
|
||
"""
|
||
md_path = file_path.with_suffix(".md")
|
||
if not md_path.is_file():
|
||
return [], []
|
||
|
||
outline = extract_outline(md_path)
|
||
if outline:
|
||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||
return outline, []
|
||
|
||
# outline is empty — read the first few non-empty lines as a content preview
|
||
preview: list[str] = []
|
||
try:
|
||
with md_path.open(encoding="utf-8") as f:
|
||
for line in f:
|
||
stripped = line.strip()
|
||
if stripped:
|
||
preview.append(stripped)
|
||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||
break
|
||
except Exception:
|
||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||
return [], preview
|
||
|
||
|
||
class UploadsMiddlewareState(AgentState):
    """State schema for uploads middleware.

    Extends the base agent state with the list of files that
    ``UploadsMiddleware.before_agent`` reports for the latest human message.
    """

    # File-metadata dicts written by ``UploadsMiddleware.before_agent``
    # (filename, size, path, extension and, when the thread's uploads dir is
    # known, outline/outline_preview). NotRequired: the key may be missing from
    # state entirely; ``None`` is also accepted by the schema.
    uploaded_files: NotRequired[list[dict] | None]
|
||
|
||
|
||
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
    """Middleware to inject uploaded files information into the agent context.

    Reads file metadata from the current message's ``additional_kwargs.files``
    (set by the frontend after upload) and prepends context blocks to the last
    human message so the model knows which files are available:

    - ``<uploaded_files>``: files uploaded with this message, plus files that
      already exist in the thread's uploads directory.
    - ``<sent_files_semantics>``: how to treat the union of uploaded and
      mentioned files when the user asks about "files I sent".
    - ``<mentioned_files>``: files referenced by path (``ref_kind ==
      "mention"``) anywhere in the conversation, when there are any.

    ``additional_kwargs.files`` is client-supplied, so malformed entries
    (non-string names, unparsable sizes, ``..`` path segments) are skipped or
    normalised instead of aborting the agent run.
    """

    state_schema = UploadsMiddlewareState

    def __init__(self, base_dir: str | None = None):
        """Initialize the middleware.

        Args:
            base_dir: Base directory for thread data. When not given (or
                empty), paths are resolved with ``get_paths()``.
        """
        super().__init__()
        self._paths = Paths(base_dir) if base_dir else get_paths()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_size(size_bytes: float) -> str:
        """Render a byte count as ``"<x.y> KB"`` below 1024 KB, else ``"<x.y> MB"``."""
        size_kb = size_bytes / 1024
        return f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"

    @staticmethod
    def _coerce_size(value: object) -> int:
        """Convert a client-supplied size to ``int``, falling back to 0.

        Valid values convert exactly as ``int(value or 0)``. A malformed value
        (e.g. ``"1.5"``, a nested object, ``inf``) yields 0 instead of raising,
        so one bad entry cannot abort the whole run.
        """
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    @staticmethod
    def _is_plain_filename(filename: object) -> bool:
        """Return True when *filename* is a non-empty bare file name.

        Rejects:
        - non-strings, which would make ``Path()`` raise;
        - names with a directory component;
        - the special ``.`` and ``..`` entries (``Path("..").name == ".."``, so
          the name-equality check alone does not catch ``..``).
        """
        return isinstance(filename, str) and filename not in ("", ".", "..") and Path(filename).name == filename

    @staticmethod
    def _is_user_data_path(path: object) -> bool:
        """Return True for a virtual ``/mnt/user-data/`` path with no ``..`` segments."""
        return isinstance(path, str) and path.startswith("/mnt/user-data/") and ".." not in path.split("/")

    # ------------------------------------------------------------------
    # Message formatting
    # ------------------------------------------------------------------

    def _format_file_entry(self, file: dict, lines: list[str]) -> None:
        """Append a single file entry (name, size, path, optional outline) to lines.

        When the entry has an ``outline``, its headings are listed. A trailing
        element flagged ``truncated`` is a sentinel meaning "more headings
        exist". Otherwise an ``outline_preview`` (if any) is quoted, followed by
        a hint to use ``grep``. Every entry ends with a blank line.
        """
        lines.append(f"- {file['filename']} ({self._format_size(file['size'])})")
        lines.append(f" Path: {file['path']}")

        outline = file.get("outline") or []
        if outline:
            truncated = outline[-1].get("truncated", False)
            visible = [e for e in outline if not e.get("truncated")]
            lines.append(" Document outline (use `read_file` with line ranges to read sections):")
            lines.extend(f" L{entry['line']}: {entry['title']}" for entry in visible)
            if truncated:
                lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
        else:
            preview = file.get("outline_preview") or []
            if preview:
                lines.append(" No structural headings detected. Document begins with:")
                lines.extend(f" > {text}" for text in preview)
            lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
        lines.append("")

    def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
        """Create a formatted message listing uploaded files.

        Args:
            new_files: Files uploaded in the current message.
            historical_files: Files uploaded in previous messages.

        Each file dict may contain an optional ``outline`` key — a list of
        ``{title, line}`` dicts extracted from the converted Markdown file.

        Returns:
            Formatted string inside <uploaded_files> tags.
        """
        lines = ["<uploaded_files>"]

        lines.append("The following files were uploaded in this message:")
        lines.append("")
        if new_files:
            for file in new_files:
                self._format_file_entry(file, lines)
        else:
            lines.append("(empty)")
            lines.append("")

        if historical_files:
            lines.append("The following files were uploaded in previous messages and are still available:")
            lines.append("")
            for file in historical_files:
                self._format_file_entry(file, lines)

        lines.append("To work with these files:")
        lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
        lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
        lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
        lines.append("- Use `glob` to find files by name pattern")
        lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
        lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
        lines.append("</uploaded_files>")

        return "\n".join(lines)

    def _merge_sent_files(self, uploaded_files: list[dict], mention_files: list[dict]) -> list[dict]:
        """Build the conversation-level sent-files view (uploads ∪ mentions, deduped by path).

        Entries without a path are ignored. For a path seen more than once:
        - ``size`` is the maximum size seen;
        - ``sent_sources`` collects every source;
        - ``ref_source`` comes from the last mention that carried one.

        Returns:
            New dicts (inputs are not mutated), sorted case-insensitively by
            filename then path. Each has ``sent_source_label`` set to
            ``"upload"``, ``"mention"`` or ``"upload+mention"``.
        """
        merged: dict[str, dict] = {}

        def _upsert(file: dict, source: str) -> None:
            path = file.get("path") or ""
            if not path:
                return
            size = self._coerce_size(file.get("size"))
            entry = merged.get(path)
            if entry is None:
                entry = {
                    "filename": file.get("filename") or Path(path).name,
                    "path": path,
                    "size": size,
                    "sent_sources": set(),
                }
                merged[path] = entry
            entry["sent_sources"].add(source)
            entry["size"] = max(entry["size"], size)
            if source == "mention" and file.get("ref_source"):
                entry["ref_source"] = file["ref_source"]

        for file in uploaded_files:
            _upsert(file, "upload")
        for file in mention_files:
            _upsert(file, "mention")

        ordered = sorted(
            merged.values(),
            key=lambda f: (str(f.get("filename", "")).lower(), str(f.get("path", "")).lower()),
        )
        for file in ordered:
            sources = file.get("sent_sources") or set()
            if "upload" in sources and "mention" in sources:
                file["sent_source_label"] = "upload+mention"
            elif "upload" in sources:
                file["sent_source_label"] = "upload"
            else:
                file["sent_source_label"] = "mention"
        return ordered

    def _create_sent_files_summary(
        self,
        sent_files: list[dict],
        current_turn_mentions: list[dict] | None = None,
    ) -> str:
        """Create the policy block describing unified 'sent files' semantics.

        Args:
            sent_files: Conversation-level union from ``_merge_sent_files``
                (each dict must carry ``sent_source_label``).
            current_turn_mentions: Files mentioned in the current message. When
                non-empty they are listed first, as the preferred binding for
                deictic references such as "this file".

        Returns:
            Formatted string inside <sent_files_semantics> tags.
        """
        current_turn_mentions = current_turn_mentions or []
        lines = [
            "<sent_files_semantics>",
            "Conversation attachment semantics:",
            "- Treat uploaded files and mentioned files as one unified concept of files the user has sent.",
            "- For questions like 'what files did I send' or 'how many files did I send', use the conversation-level union of uploaded + mentioned files.",
            "- Count unique files by path (deduplicated).",
        ]
        if current_turn_mentions:
            lines.extend(
                [
                    "- Current-turn mention priority: if the user says deictic references like 'this image/file' (e.g. '这张图', '这个文件'), bind to files mentioned in the current message first.",
                    "- Only ask for clarification when the current message itself mentions multiple files.",
                    "",
                    "Current message mentioned files (highest priority for deictic references):",
                ]
            )
            for file in current_turn_mentions:
                lines.append(f"- {file['filename']} ({self._format_size(file['size'])}, source: mention)")
                lines.append(f" Path: {file['path']}")

        # Both branches above end with the same header for the union list.
        lines.extend(["", "Conversation-level sent files (deduplicated):"])
        if sent_files:
            for file in sent_files:
                lines.append(f"- {file['filename']} ({self._format_size(file['size'])}, source: {file['sent_source_label']})")
                lines.append(f" Path: {file['path']}")
        else:
            lines.append("- (none)")
        lines.append("</sent_files_semantics>")
        return "\n".join(lines)

    def _create_mentions_message(self, mention_files: list[dict]) -> str:
        """Create the ``<mentioned_files>`` block listing path-referenced files."""
        lines = ["<mentioned_files>", "The following files were referenced by the user in this conversation:", ""]
        for file in mention_files:
            lines.append(f"- {file['filename']} ({self._format_size(file['size'])}, source: {file['ref_source']})")
            lines.append(f" Path: {file['path']}")
        lines.append("")
        lines.append("Use `read_file` with these paths directly. Do not re-upload them.")
        lines.append("</mentioned_files>")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Extraction from message metadata
    # ------------------------------------------------------------------

    def _mentioned_files_from_kwargs(self, message: HumanMessage) -> list[dict]:
        """Extract mention references from additional_kwargs.files.

        Mention entries are context references (not uploads) and should be
        surfaced to the model so it can read them directly by path.

        Only entries meeting all of the following are kept:
        - ``ref_kind == "mention"``;
        - a bare string filename;
        - a ``/mnt/user-data/`` path with no ``..`` segment.

        Duplicate ``(filename, path)`` pairs are dropped.
        """
        kwargs_files = (message.additional_kwargs or {}).get("files")
        if not isinstance(kwargs_files, list) or not kwargs_files:
            return []

        references: list[dict] = []
        seen: set[tuple[str, str]] = set()
        for item in kwargs_files:
            if not isinstance(item, dict) or item.get("ref_kind") != "mention":
                continue

            filename = item.get("filename") or ""
            path = item.get("path") or ""
            # Client-supplied: reject non-strings and anything that could
            # point outside the user-data tree.
            if not self._is_plain_filename(filename) or not self._is_user_data_path(path):
                continue

            key = (filename, path)
            if key in seen:
                continue
            seen.add(key)

            references.append(
                {
                    "filename": filename,
                    "size": self._coerce_size(item.get("size")),
                    "path": path,
                    "ref_source": item.get("ref_source") or "unknown",
                }
            )
        return references

    def _mentioned_files_from_messages(self, messages: list) -> list[dict]:
        """Extract mention references across all human messages, deduped by (filename, path)."""
        references: list[dict] = []
        seen: set[tuple[str, str]] = set()
        for message in messages:
            if not isinstance(message, HumanMessage):
                continue
            for file in self._mentioned_files_from_kwargs(message):
                key = (file["filename"], file["path"])
                if key in seen:
                    continue
                seen.add(key)
                references.append(file)
        return references

    def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
        """Extract file info from message additional_kwargs.files.

        The frontend sends uploaded file metadata in additional_kwargs.files
        after a successful upload. Each entry has: filename, size (bytes),
        path (virtual path), status.

        Args:
            message: The human message to inspect.
            uploads_dir: Physical uploads directory used to verify file existence.
                When provided, entries whose files no longer exist are skipped.

        Returns:
            List of file dicts with virtual paths (one per distinct filename),
            or None if the field is absent, empty, or has no usable entries.
        """
        kwargs_files = (message.additional_kwargs or {}).get("files")
        if not isinstance(kwargs_files, list) or not kwargs_files:
            return None

        files: list[dict] = []
        seen: set[str] = set()
        for item in kwargs_files:
            if not isinstance(item, dict):
                continue
            # Mention references are context pointers, not newly uploaded files.
            if item.get("ref_kind") == "mention":
                continue
            filename = item.get("filename") or ""
            # Reject non-strings and names that could escape the uploads dir;
            # skip repeats so a file is not listed twice.
            if not self._is_plain_filename(filename) or filename in seen:
                continue
            if uploads_dir is not None and not (uploads_dir / filename).is_file():
                continue
            seen.add(filename)
            files.append(
                {
                    "filename": filename,
                    "size": self._coerce_size(item.get("size")),
                    "path": f"/mnt/user-data/uploads/{filename}",
                    "extension": Path(filename).suffix,
                }
            )
        return files or None

    # ------------------------------------------------------------------
    # before_agent and its steps
    # ------------------------------------------------------------------

    def _resolve_thread_id(self, runtime: Runtime):
        """Return the thread id from the runtime context, else from the LangGraph config, else None."""
        thread_id = (runtime.context or {}).get("thread_id")
        if thread_id is None:
            try:
                from langgraph.config import get_config

                thread_id = get_config().get("configurable", {}).get("thread_id")
            except RuntimeError:
                pass  # get_config() raises outside a runnable context (e.g. unit tests)
        return thread_id

    def _collect_historical_files(self, uploads_dir: Path | None, exclude: set[str]) -> list[dict]:
        """List regular files in *uploads_dir* (sorted), skipping names in *exclude*.

        Each entry carries size, virtual path, extension and outline data from
        ``_extract_outline_for_file``. Returns ``[]`` when the directory is
        unknown or missing.
        """
        historical_files: list[dict] = []
        if not uploads_dir or not uploads_dir.exists():
            return historical_files
        for file_path in sorted(uploads_dir.iterdir()):
            if not file_path.is_file() or file_path.name in exclude:
                continue
            stat = file_path.stat()
            outline, preview = _extract_outline_for_file(file_path)
            historical_files.append(
                {
                    "filename": file_path.name,
                    "size": stat.st_size,
                    "path": f"/mnt/user-data/uploads/{file_path.name}",
                    "extension": file_path.suffix,
                    "outline": outline,
                    "outline_preview": preview,
                }
            )
        return historical_files

    @staticmethod
    def _prepend_context(content: object, prefix: str) -> str | list:
        """Return message content with *prefix* (plus a blank line) placed first.

        - String content yields a string.
        - List content has its text parts (bare strings and ``type == "text"``
          blocks) joined with newlines after the prefix. If the list also holds
          non-text blocks (e.g. images), they are kept after a single leading
          text block instead of being discarded.
        - Any other content type yields just the prefix.
        """
        if isinstance(content, str):
            return f"{prefix}\n\n{content}"
        if not isinstance(content, list):
            return f"{prefix}\n\n"

        text_parts: list[str] = []
        other_blocks: list = []
        for block in content:
            if isinstance(block, str):
                text_parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                text_parts.append(block.get("text") or "")
            else:
                other_blocks.append(block)

        merged_text = f"{prefix}\n\n" + "\n".join(text_parts)
        if not other_blocks:
            return merged_text
        return [{"type": "text", "text": merged_text}, *other_blocks]

    @override
    def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
        """Inject uploaded files information before agent execution.

        New files come from the current message's additional_kwargs.files.
        Historical files are scanned from the thread's uploads directory,
        excluding the new ones.

        Prepends <uploaded_files> context to the last human message content.
        The original additional_kwargs (including files metadata) is preserved
        on the updated message so the frontend can read it from the stream.

        Args:
            state: Current agent state.
            runtime: Runtime context containing thread_id.

        Returns:
            State updates including the uploaded files list and the rewritten
            messages. Returns None when the last message is not a human message
            or there are no files of any kind.
        """
        messages = list(state.get("messages", []))
        if not messages:
            return None

        last_message_index = len(messages) - 1
        last_message = messages[last_message_index]
        if not isinstance(last_message, HumanMessage):
            return None

        # Resolve the uploads directory, used for existence checks and scanning.
        thread_id = self._resolve_thread_id(runtime)
        uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None

        new_files = self._files_from_kwargs(last_message, uploads_dir) or []
        mention_files = self._mentioned_files_from_messages(messages)
        current_turn_mentions = self._mentioned_files_from_kwargs(last_message)
        historical_files = self._collect_historical_files(uploads_dir, {f["filename"] for f in new_files})

        # Attach outlines to new files as well.
        if uploads_dir:
            for file in new_files:
                outline, preview = _extract_outline_for_file(uploads_dir / file["filename"])
                file["outline"] = outline
                file["outline_preview"] = preview

        # The merged sent-files view is derived from these three lists, so it
        # is empty exactly when all of them are.
        if not new_files and not historical_files and not mention_files:
            return None

        sent_files = self._merge_sent_files(new_files + historical_files, mention_files)

        logger.debug(
            "New files: %s, historical: %s",
            [f["filename"] for f in new_files],
            [f["filename"] for f in historical_files],
        )

        # Create context message(s) and prepend to the last human message content.
        message_parts = [
            self._create_files_message(new_files, historical_files),
            self._create_sent_files_summary(sent_files, current_turn_mentions),
        ]
        if mention_files:
            message_parts.append(self._create_mentions_message(mention_files))
        files_message = "\n\n".join(message_parts)

        # Preserve id (so the reducer replaces rather than appends) and
        # additional_kwargs (so the frontend can read structured file info).
        messages[last_message_index] = HumanMessage(
            content=self._prepend_context(last_message.content, files_message),
            id=last_message.id,
            additional_kwargs=last_message.additional_kwargs,
        )

        return {
            "uploaded_files": new_files,
            "messages": messages,
        }
|