Clawith/backend/app/api/upload.py

179 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""File upload API for chat — saves files to agent workspace and extracts text."""
import base64
import os
import uuid
from pathlib import Path
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Form
from loguru import logger
from app.core.security import get_current_user
from app.models.user import User
from app.config import get_settings
# Application settings are resolved once, at import time. Every agent's
# on-disk data is stored beneath WORKSPACE_ROOT.
_settings = get_settings()
WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR)

# All endpoints in this module are served under the "/chat" prefix.
router = APIRouter(tags=["chat"], prefix="/chat")
# File-extension groups (lower-case, leading dot) used to decide how an
# uploaded file's content is surfaced.

# Plain-text formats, decoded directly without a helper process.
# NOTE(review): os.path.splitext(".env") yields an empty extension, so ".env"
# only matches names like "prod.env", not a bare ".env" file — confirm intent.
TEXT_EXTENSIONS = set(
    ".txt .md .csv .json .xml .yaml .yml"
    " .py .js .ts .html .css .sql .sh .log"
    " .ini .cfg .conf .env .toml".split()
)

# Office/PDF formats routed to text extraction.
# NOTE(review): not every entry has a dedicated parser in extract_text
# (e.g. .doc, .ppt, .pptx) — confirm this list is intended.
OFFICE_EXTENSIONS = set(".pdf .docx .doc .xlsx .xls .pptx .ppt".split())

# Images are passed to vision models as base64 data URLs instead of text.
IMAGE_EXTENSIONS = set(".png .jpg .jpeg .gif .webp .bmp".split())

# Every extension for which text extraction is attempted.
EXTRACTABLE = TEXT_EXTENSIONS.union(OFFICE_EXTENSIONS)

# MIME type used in the data URL for each image extension.
MIME_MAP = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
}
# Scripts run in a child interpreter by extract_text. The file path is handed
# over as sys.argv[1] instead of being interpolated into the source, so a
# filename containing quotes or backslashes can neither break the script nor
# inject code into it.
_PDF_SCRIPT = """
import sys
path = sys.argv[1]
try:
    import PyPDF2
    reader = PyPDF2.PdfReader(path)
    text = '\\n'.join(page.extract_text() or '' for page in reader.pages)
    print(text[:8000])
except ImportError:
    # Fallback: use pdftotext if available
    import subprocess as sp
    r = sp.run(['pdftotext', path, '-'], capture_output=True, text=True)
    print(r.stdout[:8000] if r.returncode == 0 else '[无法解析PDF]')
"""

_DOCX_SCRIPT = """
import sys
try:
    from docx import Document
    doc = Document(sys.argv[1])
    text = '\\n'.join(p.text for p in doc.paragraphs)
    print(text[:8000])
except ImportError:
    print('[需要安装 python-docx 库]')
"""

_EXCEL_SCRIPT = """
import sys
try:
    import openpyxl
    wb = openpyxl.load_workbook(sys.argv[1], read_only=True)
    lines = []
    for ws in wb.worksheets[:3]:
        lines.append(f'## Sheet: {ws.title}')
        for row in ws.iter_rows(max_row=50, values_only=True):
            lines.append('\\t'.join(str(c) if c is not None else '' for c in row))
    print('\\n'.join(lines)[:8000])
except ImportError:
    print('[需要安装 openpyxl 库]')
"""


def _run_extractor(script: str, file_path: Path, timeout: float = 30) -> str:
    """Run an extraction script in a child interpreter and return its stripped stdout.

    The script receives the absolute path of ``file_path`` as ``sys.argv[1]``.
    The server's own interpreter (``sys.executable``) is used so that parser
    libraries installed alongside the server are importable in the child.

    Raises:
        subprocess.TimeoutExpired: the child ran longer than ``timeout`` seconds.
        OSError: the interpreter could not be launched.
    """
    import subprocess
    import sys

    result = subprocess.run(
        [sys.executable or "python3", "-c", script, os.path.abspath(file_path)],
        capture_output=True,
        encoding="utf-8",
        errors="replace",
        # Force UTF-8 on the child's stdout so CJK text and placeholder
        # messages round-trip regardless of the container's locale.
        env={**os.environ, "PYTHONIOENCODING": "utf-8"},
        timeout=timeout,
    )
    return result.stdout.strip()


def extract_text(file_path: Path, extension: str) -> str:
    """Extract text content from a file.

    Args:
        file_path: Path of the saved upload.
        extension: Lower-cased extension including the leading dot.

    Returns:
        The extracted text. For unsupported formats, or when parsing fails, a
        bracketed placeholder message is returned instead. Office/PDF output
        is capped at 8000 characters by the child script.
    """
    if extension in TEXT_EXTENSIONS:
        # Try strict UTF-8 first so that non-UTF-8 files actually reach the
        # GBK fallback (with errors="replace" the first read could never fail).
        try:
            return file_path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            return file_path.read_text(encoding="gbk", errors="replace")
    if extension == ".pdf":
        try:
            return _run_extractor(_PDF_SCRIPT, file_path) or "[PDF内容提取失败]"
        except Exception as e:
            return f"[PDF解析错误: {e}]"
    if extension == ".docx":
        try:
            return _run_extractor(_DOCX_SCRIPT, file_path) or "[DOCX内容提取失败]"
        except Exception as e:
            return f"[DOCX解析错误: {e}]"
    if extension in (".xlsx", ".xls"):
        # NOTE(review): openpyxl does not read legacy .xls; such files end up
        # with the "extraction failed" placeholder.
        try:
            return _run_extractor(_EXCEL_SCRIPT, file_path) or "[Excel内容提取失败]"
        except Exception as e:
            return f"[Excel解析错误: {e}]"
    return f"[不支持的文件格式: {extension}]"
def _safe_basename(name: str) -> str:
    """Return the last path component of a client-supplied name, or "" if unusable.

    Both "/" and "\\" are treated as separators because some clients send a
    full Windows path as the filename. "", ".", ".." and names containing a
    NUL byte yield "" so the result can never address a parent directory.
    """
    base = name.replace("\\", "/").rsplit("/", 1)[-1]
    if base in ("", ".", "..") or "\x00" in base:
        return ""
    return base


@router.post("/upload")
async def upload_file(
    file: UploadFile = File(...),
    agent_id: str = Form(""),
    current_user: User = Depends(get_current_user),
):
    """Upload a file for chat context. Saves to agent workspace/uploads/ and returns extracted text.

    With ``agent_id`` the file is stored under
    ``WORKSPACE_ROOT/<agent_id>/workspace/uploads/``; an existing file is never
    overwritten (a ``_N`` suffix is added). Without it, the file goes to
    ``/tmp/clawith_uploads`` under a random prefix.

    Returns a dict with ``filename``, ``saved_filename``, ``size``,
    ``extracted_text``, ``workspace_path`` (empty for the /tmp fallback),
    ``is_image`` and ``image_data_url`` (empty for non-images).

    Raises:
        HTTPException(400): missing or unusable filename, invalid agent_id,
            or an image larger than 10MB.
    """
    import asyncio

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename")
    # The client-supplied name must never be used as a path: keep only its
    # final component so "../" or absolute names cannot escape the target dir.
    safe_name = _safe_basename(file.filename)
    if not safe_name:
        raise HTTPException(status_code=400, detail="Invalid filename")
    ext = os.path.splitext(safe_name)[1].lower()
    content = await file.read()

    is_image = ext in IMAGE_EXTENSIONS
    # Reject oversized images before anything is written to disk.
    if is_image and len(content) > 10 * 1024 * 1024:  # 10MB limit
        raise HTTPException(status_code=400, detail="Image too large (max 10MB)")

    workspace_path = ""
    if agent_id:
        # agent_id becomes a directory name, so it must be a single safe component.
        # NOTE(review): nothing here checks that current_user may access this
        # agent — confirm that is enforced elsewhere.
        if _safe_basename(agent_id) != agent_id:
            raise HTTPException(status_code=400, detail="Invalid agent_id")
        uploads_dir = WORKSPACE_ROOT / agent_id / "workspace" / "uploads"
        uploads_dir.mkdir(parents=True, exist_ok=True)
        save_path = uploads_dir / safe_name
        # Avoid overwriting: append _1, _2, ... until the name is free.
        stem, suffix = save_path.stem, save_path.suffix
        counter = 1
        while save_path.exists():
            save_path = uploads_dir / f"{stem}_{counter}{suffix}"
            counter += 1
        save_path.write_bytes(content)
        workspace_path = f"workspace/uploads/{save_path.name}"
    else:
        # Fallback: save to /tmp (legacy behavior)
        fallback_dir = Path("/tmp/clawith_uploads")
        fallback_dir.mkdir(exist_ok=True)
        file_id = str(uuid.uuid4())[:8]
        save_path = fallback_dir / f"{file_id}_{safe_name}"
        save_path.write_bytes(content)

    image_data_url = ""
    if is_image:
        # Images are handed to vision models as a base64 data URL.
        mime = MIME_MAP.get(ext, "image/png")
        b64 = base64.b64encode(content).decode("ascii")
        image_data_url = f"data:{mime};base64,{b64}"
        extracted = f"[图片文件: {file.filename},需要视觉模型分析]"
    elif ext in EXTRACTABLE:
        # extract_text may run a parser subprocess for up to 30 s; run it in
        # the default thread pool so the event loop keeps serving requests.
        loop = asyncio.get_running_loop()
        extracted = await loop.run_in_executor(None, extract_text, save_path, ext)
    else:
        extracted = f"[文件已保存,格式 {ext} 暂不支持文本提取Agent 可通过 read_document 工具读取]"

    # Truncate if too long (the reported count is the untruncated length).
    if len(extracted) > 6000:
        extracted = extracted[:6000] + "\n\n...[内容已截断,共 " + str(len(extracted)) + " 字]"

    return {
        "filename": file.filename,
        "saved_filename": save_path.name,
        "size": len(content),
        "extracted_text": extracted,
        "workspace_path": workspace_path,
        "is_image": is_image,
        "image_data_url": image_data_url,
    }