Clawith/backend/app/services/text_extractor.py

170 lines
5.2 KiB
Python

"""Extract text from common office file formats.
Supports: PDF, DOCX, XLSX, PPTX (CSV and other plain-text formats need no extraction)
Saves extracted text as a companion .txt file alongside the original.
"""
import io
from pathlib import Path
from loguru import logger
# Binary office formats whose text has to be pulled out by a dedicated parser.
EXTRACTABLE_EXTS = {
    ".pdf",
    ".docx",
    ".xlsx",
    ".pptx",
}
# Formats that are already plain text and therefore need no extraction.
TEXT_EXTS = {
    # documents and data
    ".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml",
    # source code and markup
    ".js", ".ts", ".py", ".html", ".css", ".sh",
    # logs and configuration
    ".log", ".env",
}
def needs_extraction(filename: str) -> bool:
    """Report whether *filename* is a binary format that requires text extraction.

    Only the final extension is considered, and it is compared
    case-insensitively against ``EXTRACTABLE_EXTS``.
    """
    suffix = Path(filename).suffix
    return suffix.lower() in EXTRACTABLE_EXTS
def extract_text(file_bytes: bytes, filename: str) -> str | None:
    """Pull the text out of a binary office document.

    The parser is selected from the lower-cased extension of *filename*.

    Args:
        file_bytes: Raw content of the document.
        filename: Name of the document; only its extension is inspected,
            and the full name appears in the error log.

    Returns:
        The extracted text, or ``None`` when the extension has no parser
        or the parser raised (the failure is logged, not propagated).
    """
    suffix = Path(filename).suffix.lower()
    parsers = {
        ".pdf": _extract_pdf,
        ".docx": _extract_docx,
        ".xlsx": _extract_xlsx,
        ".pptx": _extract_pptx,
    }
    parser = parsers.get(suffix)
    if parser is None:
        # Unsupported extension: nothing to extract.
        return None
    try:
        return parser(file_bytes)
    except Exception as e:
        # Best-effort boundary: any parser failure is reported and mapped to None.
        logger.error(f"[TextExtractor] Failed to extract from {filename}: {e}")
        return None
def save_extracted_text(save_path: Path, file_bytes: bytes, filename: str) -> Path | None:
    """Extract text from a document and write it to a companion ``.txt`` file.

    The text file is created in the same directory as *save_path* and reuses
    its stem, e.g. ``report.pdf`` → ``report.txt``.

    Args:
        save_path: Location of the original document; determines where the
            text file is written.
        file_bytes: Raw content of the document.
        filename: Name passed to ``extract_text`` to choose the parser, and
            shown in the log line.

    Returns:
        Path of the written text file, or ``None`` when extraction failed or
        yielded only whitespace (no file is written in that case).
    """
    text = extract_text(file_bytes, filename)
    if not (text and text.strip()):
        return None
    txt_path = save_path.parent / f"{save_path.stem}.txt"
    txt_path.write_text(text, encoding="utf-8")
    # Keep an explicit separator between the source and target names so the
    # log line does not read as one fused filename.
    logger.info(f"[TextExtractor] Extracted {len(text)} chars from {filename} → {txt_path.name}")
    return txt_path
def _extract_pdf(data: bytes) -> str:
    """Read page text and tables out of a PDF with pdfplumber.

    Every page with non-blank text yields a section headed by a page marker.
    Each non-empty table found on a page is appended as a separate section,
    one line per row with cells joined by `` | ``. Sections are joined with
    blank lines.
    """
    import pdfplumber

    sections = []
    with pdfplumber.open(io.BytesIO(data)) as pdf:
        for page_no, page in enumerate(pdf.pages, start=1):
            body = page.extract_text()
            if body and body.strip():
                sections.append(f"--- 第{page_no}页 ---\n{body.strip()}")
            # Tables are emitted in addition to the page text.
            for table in page.extract_tables():
                if not table:
                    continue
                # Missing cells (None) are rendered as empty strings.
                lines = [
                    " | ".join(str(cell or "").strip() for cell in row)
                    for row in table
                ]
                if lines:
                    sections.append("表格:\n" + "\n".join(lines))
    return "\n\n".join(sections)
def _extract_docx(data: bytes) -> str:
    """Read paragraphs and tables out of a DOCX file with python-docx.

    Non-empty paragraphs are emitted in document order; paragraphs whose
    style name starts with ``Heading`` are prefixed with ``#`` markers
    matching the heading level. All tables are appended after the
    paragraphs, one line per row with cells joined by `` | ``.
    """
    from docx import Document

    document = Document(io.BytesIO(data))
    chunks = []
    for paragraph in document.paragraphs:
        content = paragraph.text.strip()
        if not content:
            continue
        style = paragraph.style
        if not (style and style.name and style.name.startswith("Heading")):
            chunks.append(content)
            continue
        # "Heading 2" -> 2; a style name without a numeric level counts as level 1.
        level_text = style.name.replace("Heading", "").strip()
        try:
            depth = int(level_text)
        except ValueError:
            depth = 1
        chunks.append(f"{'#' * depth} {content}")
    # Tables come after all paragraph text.
    for table in document.tables:
        lines = [
            " | ".join(cell.text.strip() for cell in row.cells)
            for row in table.rows
        ]
        if lines:
            chunks.append("\n表格:\n" + "\n".join(lines))
    return "\n\n".join(chunks)
def _extract_xlsx(data: bytes) -> str:
    """Read cell values out of an XLSX workbook with openpyxl.

    Each worksheet that has at least one non-blank row yields a section
    headed by the sheet title, one line per row with cells joined by
    `` | ``. Empty cells become empty strings and rows containing only
    blank cells are skipped. Sections are joined with blank lines.

    The workbook is opened with ``read_only=True`` and ``data_only=True``
    (per openpyxl, the latter yields stored values instead of formula text).
    """
    from openpyxl import load_workbook

    wb = load_workbook(io.BytesIO(data), read_only=True, data_only=True)
    parts = []
    try:
        # Iterate wb.worksheets rather than wb.sheetnames: the latter also
        # lists chart sheets, which have no iter_rows() and would abort the
        # whole extraction.
        for ws in wb.worksheets:
            rows = []
            for row in ws.iter_rows(values_only=True):
                cells = [str(c) if c is not None else "" for c in row]
                if any(c.strip() for c in cells):
                    rows.append(" | ".join(cells))
            if rows:
                parts.append(f"## 工作表: {ws.title}\n" + "\n".join(rows))
    finally:
        # Close the workbook even when reading a sheet fails.
        wb.close()
    return "\n\n".join(parts)
def _extract_pptx(data: bytes) -> str:
    """Read slide text out of a PPTX presentation with python-pptx.

    For every slide, the non-blank paragraphs of each shape's text frame and
    the rows of each table shape (cells joined by `` | ``) are collected in
    shape order. Slides with any collected text yield a section headed by a
    slide marker; sections are joined with blank lines.
    """
    from pptx import Presentation

    deck = Presentation(io.BytesIO(data))
    sections = []
    for slide_no, slide in enumerate(deck.slides, start=1):
        lines = []
        for shape in slide.shapes:
            if shape.has_text_frame:
                stripped = (p.text.strip() for p in shape.text_frame.paragraphs)
                lines.extend(t for t in stripped if t)
            if shape.has_table:
                # Table rows are kept even when every cell is empty.
                lines.extend(
                    " | ".join(cell.text.strip() for cell in row.cells)
                    for row in shape.table.rows
                )
        if lines:
            sections.append(f"--- 幻灯片 {slide_no} ---\n" + "\n".join(lines))
    return "\n\n".join(sections)