Files
JARVIS/backend/app/services/document_service.py

683 lines
27 KiB
Python
Raw Permalink Normal View History

2026-03-21 10:13:29 +08:00
"""
文档服务 - 上传解析分块存储
支持多种文档格式 + LlamaIndex 智能分块
"""
from pathlib import Path
import tempfile
import shutil
2026-03-21 10:13:29 +08:00
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import UploadFile
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
from app.services.brain_service import BrainService
import csv
import io
import json
2026-03-21 10:13:29 +08:00
import os
import re
2026-03-21 10:13:29 +08:00
import aiofiles
from dataclasses import dataclass, field
2026-03-21 10:13:29 +08:00
# File extensions (lower-case, with leading dot) accepted by
# DocumentService.upload_document.
ALLOWED_EXTENSIONS = {
    ".csv",
    ".doc",
    ".docx",
    ".md",
    ".pdf",
    ".txt",
    ".xlsx",
}
# Version tags written to every Document row and chunk metadata entry
# (presumably so documents produced by older parser/index logic can be found
# and rebuilt — confirm with the callers of rebuild_document).
PARSER_VERSION = "v2"
INDEX_VERSION = "v2"
@dataclass
class ParsedNode:
    """A single structural unit extracted from a document.

    Parsers build nodes positionally, as
    ``ParsedNode(node_type, text, metadata, section_path)``, so field order
    is part of the interface.
    """

    # Unit kind: "heading", "paragraph", "text", "table_schema" or "table_rows".
    node_type: str
    # Textual content of the unit.
    text: str
    # Kind-specific extras, e.g. heading level, table headers, row range.
    metadata: dict = field(default_factory=dict)
    # Location of the node, outermost first (heading titles, sheet name, ...).
    section_path: list[str] = field(default_factory=list)
@dataclass
class ParsedDocument:
    """Everything a parser extracted from one file."""

    # Short plain-text summary of the document.
    summary: str
    # Structural units in the order they were produced.
    nodes: list[ParsedNode]
    # Markdown rendering of ``nodes``; empty until it is rendered after parsing.
    structured_markdown: str = ""
2026-03-21 10:13:29 +08:00
class DocumentService:
    """Upload, parse, chunk and persist user documents.

    Uploaded files live under ``settings.UPLOAD_DIR/<user_id>/<folder dirs>``;
    each upload yields one ``Document`` row plus one ``DocumentChunk`` row per
    parsed node. ``upload_document``, ``rebuild_document``, ``delete_document``
    and ``update_document_chunk`` commit the session themselves;
    ``rename_folder_directory`` only modifies loaded rows and leaves the commit
    to its caller.
    """

    # Number of table rows grouped into one "table_rows" node / chunk.
    _TABLE_ROWS_PER_CHUNK = 50
    # Document.summary is truncated to this many characters before storing.
    _SUMMARY_MAX_LENGTH = 500
    # ATX heading: 1-6 '#' characters, whitespace, then the title.
    _HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$")
    # Optional closing '#' run of an ATX heading (must follow whitespace).
    _HEADING_CLOSING_PATTERN = re.compile(r"\s+#+\s*$")
    # Opening line of a fenced code block: at least three ` or ~.
    _FENCE_PATTERN = re.compile(r"^(`{3,}|~{3,})")
    # Characters that are invalid in Windows file names (includes separators).
    _INVALID_NAME_CHARS = '<>:"/\\|?*'

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to a database session.

        Args:
            db: Async SQLAlchemy session used for every query.
            user_id: Owner used by ``_get_folder_path``; other methods take the
                user id as an explicit argument.
        """
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------ upload

    async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
        """Save an uploaded file, parse it and persist the document and its chunks.

        Args:
            user_id: Owner of the new document.
            file: Uploaded file; its name determines the type and storage name.
            folder_id: Target folder, or ``None`` for the user's root directory.

        Returns:
            The committed and refreshed ``Document``.

        Raises:
            ValueError: Missing file name, unsupported extension, file larger
                than ``settings.MAX_UPLOAD_SIZE``, unknown folder, or a parser
                failure reported as ``ValueError``. The stored file is removed
                again if parsing or persisting fails.
        """
        if not file.filename:
            raise ValueError("文件名不能为空")
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")
        folder_path = await self.ensure_folder_directory(user_id, folder_id)

        content = await file.read()
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        file_path = self._resolve_unique_file_path(folder_path, file.filename)
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        try:
            parsed = await self._parse_for_storage(str(file_path), ext)
            doc = Document(
                user_id=user_id,
                title=file.filename.rsplit(".", 1)[0],
                filename=file.filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=str(file_path),
                summary=parsed.summary[: self._SUMMARY_MAX_LENGTH],
                folder_id=folder_id,
                ingestion_status="uploaded",
                ingestion_error=None,
                parser_version=PARSER_VERSION,
                index_version=INDEX_VERSION,
                normalized_content=parsed.structured_markdown,
                normalized_format="structured_markdown",
            )
            self.db.add(doc)
            # Flush before doc.id is used below (presumably generated on insert).
            await self.db.flush()
            doc.chunk_count = self._add_chunks(doc.id, parsed)

            await BrainService(self.db).create_event(
                user_id,
                source_type="document",
                source_id=doc.id,
                event_type="document_uploaded",
                title=doc.filename,
                content_summary=doc.summary,
                raw_excerpt=(doc.normalized_content or "")[:1000] or None,
                metadata_={
                    "document_id": doc.id,
                    "file_type": doc.file_type,
                    "ingestion_status": doc.ingestion_status,
                },
                importance_signal=1.0,
            )
            await self.db.commit()
        except Exception:
            # Parsing or persisting failed: don't leave an orphaned file behind.
            try:
                file_path.unlink(missing_ok=True)
            except OSError:
                # Best effort only; never mask the original error.
                pass
            raise

        await self.db.refresh(doc)
        return doc

    async def rebuild_document(self, document: Document) -> Document:
        """Re-parse a stored document's file and replace all of its chunks.

        The file is parsed before any chunk is deleted, so a parse failure
        leaves the existing chunks untouched. Sets ``ingestion_status`` to
        ``"indexing"``, commits, and returns the refreshed document.
        """
        ext = os.path.splitext(document.filename)[1].lower()
        parsed = await self._parse_for_storage(document.file_path, ext)

        existing = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document.id)
            .order_by(DocumentChunk.chunk_index)
        )
        for chunk in existing.scalars().all():
            await self.db.delete(chunk)
        # Flush the deletes before inserting replacements (presumably to avoid
        # clashing with a (document_id, chunk_index) uniqueness constraint).
        await self.db.flush()

        document.chunk_count = self._add_chunks(document.id, parsed)
        document.summary = parsed.summary[: self._SUMMARY_MAX_LENGTH]
        document.ingestion_status = "indexing"
        document.ingestion_error = None
        document.parser_version = PARSER_VERSION
        document.index_version = INDEX_VERSION
        document.normalized_content = parsed.structured_markdown
        document.normalized_format = "structured_markdown"
        await self.db.commit()
        await self.db.refresh(document)
        return document

    async def _parse_for_storage(self, file_path: str, ext: str) -> ParsedDocument:
        """Parse a stored file and attach its structured-markdown rendering."""
        parsed = await self._parse_document(file_path, ext)
        parsed.structured_markdown = self._render_structured_markdown(parsed)
        return parsed

    def _add_chunks(self, document_id, parsed: ParsedDocument) -> int:
        """Stage one ``DocumentChunk`` per chunk of ``parsed``; return the count."""
        chunks = self._build_chunks(parsed)
        for index, chunk_data in enumerate(chunks):
            self.db.add(
                DocumentChunk(
                    document_id=document_id,
                    chunk_index=index,
                    content=chunk_data["content"],
                    metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
                )
            )
        return len(chunks)

    # ------------------------------------------------------- folders on disk

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the folder's display path (e.g. ``"/parent/child"``) for ``self.user_id``.

        Uses the unsanitized folder names. If an ancestor is missing, or the
        parent chain loops, the path is truncated at that point. Returns
        ``None`` when no folder on the chain is found.
        """
        folders = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in folders.scalars().all()}
        names: list[str] = []
        visited: set = set()
        current_id = folder_id
        # The visited set stops corrupted (cyclic) parent links from looping forever.
        while current_id and current_id not in visited:
            visited.add(current_id)
            folder = folder_map.get(current_id)
            if folder is None:
                break
            names.append(folder.name)
            current_id = folder.parent_id
        names.reverse()
        return "/" + "/".join(names) if names else None

    async def ensure_folder_directory(self, user_id: str, folder_id: str | None) -> Path:
        """Create (if needed) and return the on-disk directory for a folder."""
        folder_path = await self._get_storage_directory(user_id, folder_id)
        folder_path.mkdir(parents=True, exist_ok=True)
        return folder_path

    async def delete_folder_directory(self, user_id: str, folder_id: str) -> None:
        """Recursively delete a folder's directory, ignoring filesystem errors.

        Raises:
            ValueError: ``folder_id`` is empty. Without this guard the user's
                whole upload root would resolve as the target and be wiped.
                Also raised for unknown folders.
        """
        if not folder_id:
            raise ValueError("文件夹不存在")
        folder_path = await self._get_storage_directory(user_id, folder_id)
        if folder_path.exists():
            shutil.rmtree(folder_path, ignore_errors=True)

    async def rename_folder_directory(self, user_id: str, folder_id: str, old_name: str, new_name: str) -> None:
        """Rename a folder's directory on disk and repoint the documents inside it.

        ``old_name`` / ``new_name`` are display names and are sanitized the same
        way as when directories are created. If the old directory doesn't
        exist, the new one is simply created. Matching ``Document.file_path``
        values are rewritten but not committed. Does nothing if the folder
        doesn't exist or belongs to another user.
        """
        folder = await self.db.get(Folder, folder_id)
        if folder is None or folder.user_id != user_id:
            return
        parent_path = await self._get_storage_directory(user_id, folder.parent_id)
        old_path = parent_path / self._sanitize_storage_name(old_name)
        new_path = parent_path / self._sanitize_storage_name(new_name)
        if old_path != new_path and old_path.exists():
            old_path.rename(new_path)
        else:
            new_path.mkdir(parents=True, exist_ok=True)

        document_result = await self.db.execute(
            select(Document).where(Document.user_id == user_id)
        )
        for document in document_result.scalars().all():
            try:
                relative_path = Path(document.file_path).relative_to(old_path)
            except ValueError:
                # Document is not stored under the renamed directory.
                continue
            document.file_path = str(new_path / relative_path)

    async def _get_storage_directory(self, user_id: str, folder_id: str | None) -> Path:
        """Map a folder id to its directory under ``settings.UPLOAD_DIR/<user_id>``.

        Each folder name on the chain from the root is sanitized into one path
        segment. A falsy ``folder_id`` yields the user's root directory.

        Raises:
            ValueError: A folder on the chain doesn't exist for this user, or
                the parent chain contains a cycle.
        """
        base_path = Path(settings.UPLOAD_DIR) / user_id
        if not folder_id:
            return base_path
        folders = await self.db.execute(
            select(Folder).where(Folder.user_id == user_id)
        )
        folder_map = {folder.id: folder for folder in folders.scalars().all()}
        segments: list[str] = []
        visited: set = set()
        current_id = folder_id
        while current_id:
            if current_id in visited:
                raise ValueError("文件夹层级存在循环")
            visited.add(current_id)
            folder = folder_map.get(current_id)
            if folder is None:
                raise ValueError("父文件夹不存在")
            segments.append(self._sanitize_storage_name(folder.name))
            current_id = folder.parent_id
        segments.reverse()
        return base_path.joinpath(*segments)

    def _resolve_unique_file_path(self, directory: Path, original_name: str) -> Path:
        """Return a non-existing path in ``directory`` for ``original_name``.

        Any directory part of the name is dropped, and the name is sanitized.
        On a clash, ``-2``, ``-3``, ... is appended to the stem.
        """
        safe_name = self._sanitize_storage_name(Path(original_name).name, is_file=True)
        stem = Path(safe_name).stem
        suffix = Path(safe_name).suffix
        candidate = directory / safe_name
        counter = 2
        while candidate.exists():
            candidate = directory / f"{stem}-{counter}{suffix}"
            counter += 1
        return candidate

    def _sanitize_storage_name(self, name: str, is_file: bool = False) -> str:
        """Make ``name`` safe as a single path segment.

        Invalid and control characters become ``_``, and surrounding
        whitespace and trailing dots are removed. An empty result becomes
        ``"untitled"`` for files or ``"folder"`` for folders.
        """
        sanitized = "".join(
            "_" if char in self._INVALID_NAME_CHARS or ord(char) < 32 else char
            for char in name
        ).strip().rstrip(".")
        if not sanitized:
            return "untitled" if is_file else "folder"
        return sanitized

    async def delete_document(self, user_id: str, document_id: str):
        """Delete a user's document row and its stored file.

        Raises:
            ValueError: The document doesn't exist for this user.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")
        try:
            os.remove(doc.file_path)
        except FileNotFoundError:
            # Already gone (possibly removed concurrently); nothing to clean up.
            pass
        await self.db.delete(doc)
        await self.db.commit()

    # ----------------------------------------------------------------- parsing

    async def _read_text_file(self, file_path: str) -> str:
        """Read a text file, tolerating a UTF-8 BOM and legacy Chinese encodings.

        Decodes as UTF-8 and drops a leading BOM. A BOM would otherwise hide a
        Markdown heading on the first line. Bytes that are not valid UTF-8 are
        re-read as GB18030, with undecodable bytes replaced, so the upload does
        not fail outright.
        """
        try:
            async with aiofiles.open(file_path, "r", encoding="utf-8-sig") as f:
                return await f.read()
        except UnicodeDecodeError:
            # NOTE(review): GB18030 is assumed to be the likeliest non-UTF-8
            # encoding for this (Chinese-language) app's users; adjust if needed.
            async with aiofiles.open(file_path, "r", encoding="gb18030", errors="replace") as f:
                return await f.read()

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Return plain text for ``.md`` / ``.txt`` / Word files.

        Other types, or Word files python-docx cannot open, yield a bracketed
        placeholder string instead of raising. For example, legacy binary
        ``.doc`` files are not zip packages.
        """
        if ext in (".md", ".txt"):
            return await self._read_text_file(file_path)
        if ext in (".docx", ".doc"):
            try:
                from docx import Document as DocxDocument
                from docx.opc.exceptions import PackageNotFoundError
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
            try:
                doc = DocxDocument(file_path)
            except PackageNotFoundError:
                return "[暂不支持此格式]"
            parts = [p.text for p in doc.paragraphs if p.text.strip()]
            for table in doc.tables:
                for row in table.rows:
                    row_values = [cell.text.strip() for cell in row.cells]
                    if any(row_values):
                        parts.append(" | ".join(row_values))
            return "\n".join(parts)
        return "[暂不支持此格式]"

    async def _parse_document(self, file_path: str, ext: str) -> ParsedDocument:
        """Dispatch to the parser for ``ext``.

        ``.csv``, ``.xlsx``, ``.docx`` and ``.pdf`` have dedicated parsers.
        ``.md`` is parsed as Markdown; anything else (``.txt``, ``.doc``, ...)
        is extracted and split into plain-text paragraphs.
        """
        if ext == ".csv":
            return await self._parse_csv(file_path)
        if ext == ".xlsx":
            return await self._parse_xlsx(file_path)
        if ext == ".docx":
            return await self._parse_docx(file_path)
        if ext == ".pdf":
            return await self._parse_pdf(file_path)
        content = await self._extract_text(file_path, ext)
        if ext == ".md":
            return self._parse_markdown(content)
        return self._parse_text(content)

    @staticmethod
    def _flatten_cell(value: object) -> str:
        """Render a table cell as single-line text; ``None`` becomes ``""``.

        Line breaks are replaced by spaces. Table rows are serialized one per
        line, so a multi-line cell would otherwise split its row in two.
        """
        if value is None:
            return ""
        return " ".join(str(value).splitlines())

    @classmethod
    def _serialize_row(cls, headers: list[str], values) -> str:
        """Serialize one row as ``"h1=v1, h2=v2"``; extra values or headers are dropped."""
        return ", ".join(
            f"{header}={cls._flatten_cell(value)}" for header, value in zip(headers, values)
        )

    def _table_nodes(
        self,
        schema_text: str,
        headers: list[str],
        data_rows: list,
        extra_metadata: dict,
        section_path: list[str],
    ) -> list[ParsedNode]:
        """Build one "table_schema" node plus "table_rows" nodes for a table.

        Rows are grouped ``_TABLE_ROWS_PER_CHUNK`` at a time. ``row_start`` and
        ``row_end`` are 1-based, inclusive indexes into ``data_rows``.
        ``extra_metadata`` (table / sheet name) is added to every node.
        """
        nodes = [
            ParsedNode(
                "table_schema",
                schema_text,
                {"headers": headers, "row_count": len(data_rows), **extra_metadata},
                list(section_path),
            )
        ]
        step = self._TABLE_ROWS_PER_CHUNK
        for start in range(0, len(data_rows), step):
            batch = data_rows[start:start + step]
            nodes.append(
                ParsedNode(
                    "table_rows",
                    "\n".join(self._serialize_row(headers, row) for row in batch),
                    {
                        "headers": headers,
                        "row_start": start + 1,
                        "row_end": start + len(batch),
                        **extra_metadata,
                    },
                    list(section_path),
                )
            )
        return nodes

    async def _parse_csv(self, file_path: str) -> ParsedDocument:
        """Parse a CSV file; the first record is treated as the header row."""
        content = await self._read_text_file(file_path)
        records = list(csv.reader(io.StringIO(content)))
        headers = [self._flatten_cell(name) for name in records[0]] if records else []
        data_rows = records[1:]
        nodes = self._table_nodes(
            f"CSV columns: {', '.join(headers)} | rows: {len(data_rows)}",
            headers,
            data_rows,
            {"table_name": "csv"},
            ["csv"],
        )
        summary = f"CSV with columns {', '.join(headers)}" if headers else "CSV document"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_xlsx(self, file_path: str) -> ParsedDocument:
        """Parse every non-empty worksheet; each sheet's first row is its header.

        Raises:
            ValueError: openpyxl is not installed.
        """
        try:
            from openpyxl import load_workbook
        except ModuleNotFoundError as error:
            raise ValueError("XLSX 解析依赖缺失: openpyxl") from error
        # data_only=True reads cached formula results instead of formulas.
        workbook = load_workbook(file_path, data_only=True)
        nodes: list[ParsedNode] = []
        sheet_titles: list[str] = []
        for sheet in workbook.worksheets:
            rows = list(sheet.iter_rows(values_only=True))
            if not rows:
                continue
            headers = [self._flatten_cell(cell).strip() for cell in rows[0]]
            data_rows = rows[1:]
            sheet_titles.append(sheet.title)
            nodes.extend(
                self._table_nodes(
                    f"Sheet {sheet.title} columns: {', '.join(headers)} | rows: {len(data_rows)}",
                    headers,
                    data_rows,
                    {"sheet_name": sheet.title},
                    [sheet.title],
                )
            )
        summary = f"Workbook sheets: {', '.join(sheet_titles)}" if sheet_titles else "Workbook"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_docx(self, file_path: str) -> ParsedDocument:
        """Parse a .docx file into heading, paragraph and table nodes.

        Paragraphs whose style name starts with "Heading" become headings; the
        digit in the style name gives the level. Body text before the first
        heading is filed under the document title (or "Document").

        Raises:
            ValueError: python-docx is missing, or the file is not a valid
                Word (zip) package.
        """
        try:
            from docx import Document as DocxDocument
            from docx.opc.exceptions import PackageNotFoundError
        except ModuleNotFoundError as error:
            raise ValueError("DOCX 解析依赖缺失: python-docx") from error
        try:
            doc = DocxDocument(file_path)
        except PackageNotFoundError as error:
            raise ValueError("DOCX 解析失败: 文件不是有效的 Word 文档") from error

        nodes: list[ParsedNode] = []
        section_path: list[str] = []
        summary_parts: list[str] = []
        for paragraph in doc.paragraphs:
            text = paragraph.text.strip()
            if not text:
                continue
            style_name = getattr(paragraph.style, "name", "") or ""
            if style_name.startswith("Heading"):
                level_match = re.search(r"(\d+)", style_name)
                level = max(1, int(level_match.group(1))) if level_match else 1
                section_path = section_path[: level - 1] + [text]
                nodes.append(ParsedNode("heading", text, {"level": level}, list(section_path)))
                continue
            if not section_path:
                section_path = [doc.core_properties.title or "Document"]
            summary_parts.append(text)
            nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))

        # NOTE: tables are appended after all paragraphs and tagged with the last
        # section seen, so their original position in the document is not kept.
        for table in doc.tables:
            rows = [
                [self._flatten_cell(cell.text.strip()) for cell in row.cells]
                for row in table.rows
            ]
            if not rows:
                continue
            headers, data_rows = rows[0], rows[1:]
            nodes.extend(
                self._table_nodes(
                    f"DOCX table columns: {', '.join(headers)} | rows: {len(data_rows)}",
                    headers,
                    data_rows,
                    {"table_name": "docx_table"},
                    section_path,
                )
            )
        summary = " ".join(summary_parts[:3]) if summary_parts else doc.core_properties.title or "Document"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_pdf_with_mineru(self, file_path: str) -> str:
        """Convert a PDF to Markdown using whichever MinerU API is installed.

        Tries ``mineru.to_markdown``, then ``mineru.parse_to_markdown``, then
        the ``mineru.cli.common.do_parse`` pipeline writing into a temporary
        directory.

        Raises:
            ValueError: MinerU or one of its runtime dependencies is missing,
                or no supported API produced a Markdown file.
        """
        incompatible = (
            "PDF 解析失败: 当前安装的 MinerU 版本接口不兼容,请确认支持 to_markdown / "
            "parse_to_markdown或提供 cli.common.do_parse 能力"
        )
        try:
            import mineru
        except ModuleNotFoundError as error:
            raise ValueError("PDF 解析依赖缺失: mineru") from error
        if hasattr(mineru, "to_markdown"):
            return mineru.to_markdown(file_path)
        if hasattr(mineru, "parse_to_markdown"):
            return mineru.parse_to_markdown(file_path)
        try:
            from mineru.cli.common import do_parse, read_fn
            from mineru.utils.enum_class import MakeMode
        except Exception as error:
            # Any import-time failure here means an incompatible MinerU layout.
            raise ValueError(incompatible) from error

        pdf_name = Path(file_path).stem
        with tempfile.TemporaryDirectory(prefix="mineru-") as output_dir:
            pdf_bytes = read_fn(Path(file_path))
            try:
                do_parse(
                    output_dir,
                    [pdf_name],
                    [pdf_bytes],
                    ["zh"],
                    f_draw_layout_bbox=False,
                    f_draw_span_bbox=False,
                    f_dump_md=True,
                    f_dump_middle_json=False,
                    f_dump_model_output=False,
                    f_dump_orig_pdf=False,
                    f_dump_content_list=False,
                    f_make_md_mode=MakeMode.MM_MD,
                )
            except ModuleNotFoundError as error:
                # Prefer the module name; fall back to the quoted part of
                # "No module named 'x'".
                dependency = error.name
                if not dependency:
                    message = str(error)
                    dependency = message.split("'")[-2] if "'" in message else message
                raise ValueError(f"PDF 解析依赖缺失: MinerU 运行时依赖 {dependency}") from error

            markdown_path = Path(output_dir) / pdf_name / "pipeline" / f"{pdf_name}.md"
            if not markdown_path.exists():
                # The output sub-directory differs between MinerU versions and
                # backends; fall back to any "<pdf_name>.md" written to output_dir.
                target_name = f"{pdf_name}.md"
                markdown_path = next(
                    (path for path in Path(output_dir).rglob("*.md") if path.name == target_name),
                    markdown_path,
                )
            # Read before the temporary directory is removed.
            if markdown_path.exists():
                return markdown_path.read_text(encoding="utf-8")
        raise ValueError(incompatible)

    async def _parse_pdf(self, file_path: str) -> ParsedDocument:
        """Convert a PDF to Markdown via MinerU and parse that Markdown."""
        markdown = await self._parse_pdf_with_mineru(file_path)
        return self._parse_markdown(markdown)

    def _parse_markdown(self, content: str) -> ParsedDocument:
        """Split Markdown into heading and paragraph nodes.

        ATX headings (``#`` to ``######``) update the section path. Runs of
        non-blank lines become paragraphs. A fenced code block (``` or ~~~)
        is kept whole as one paragraph with its indentation. Lines inside it
        are never treated as headings, so ``# comment`` in shell or Python
        code doesn't create sections. The summary is the first three
        paragraphs, or the first 200 characters if there are none.
        """
        nodes: list[ParsedNode] = []
        section_path: list[str] = []
        summary_parts: list[str] = []
        buffer: list[str] = []
        # Opening marker of the fenced block we are inside, or None.
        fence: str | None = None

        def flush_buffer() -> None:
            if not buffer:
                return
            text = "\n".join(buffer).strip()
            buffer.clear()
            if not text:
                return
            nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))
            summary_parts.append(text)

        for line in content.splitlines():
            stripped = line.strip()
            if fence is not None:
                buffer.append(line.rstrip())
                # A closing fence uses the same character, is at least as long,
                # and carries no info string.
                if stripped.startswith(fence) and not stripped.strip(fence[0]):
                    flush_buffer()
                    fence = None
                continue
            fence_match = self._FENCE_PATTERN.match(stripped)
            if fence_match:
                flush_buffer()
                fence = fence_match.group(1)
                buffer.append(stripped)
                continue
            heading_match = self._HEADING_PATTERN.match(stripped)
            if heading_match:
                flush_buffer()
                level = len(heading_match.group(1))
                title = self._HEADING_CLOSING_PATTERN.sub("", heading_match.group(2)).strip()
                section_path = section_path[: level - 1] + [title]
                nodes.append(ParsedNode("heading", title, {"level": level}, list(section_path)))
                continue
            if stripped:
                buffer.append(stripped)
            else:
                flush_buffer()
        # Emit any trailing paragraph (or unterminated code block).
        flush_buffer()
        summary = " ".join(summary_parts[:3]) if summary_parts else content[:200]
        return ParsedDocument(summary=summary, nodes=nodes)

    def _parse_text(self, content: str) -> ParsedDocument:
        """Split plain text into paragraphs separated by blank or whitespace-only lines."""
        paragraphs = [part.strip() for part in re.split(r"\n\s*\n", content) if part.strip()]
        nodes = [ParsedNode("text", paragraph, {}, []) for paragraph in paragraphs]
        summary = " ".join(paragraphs[:3]) if paragraphs else content[:200]
        return ParsedDocument(summary=summary, nodes=nodes)

    # ---------------------------------------------------------------- chunks

    @staticmethod
    def _chunk_metadata(content_type: str, section_path: list[str], source_order: int) -> dict:
        """Build the standard per-chunk metadata for a node at ``section_path``."""
        return {
            "content_type": content_type,
            "section_path": section_path,
            "section_title": section_path[-1] if section_path else None,
            "chunk_level": len(section_path),
            "parent_key": "/".join(section_path[:-1]) or None,
            "block_key": "/".join(section_path) or None,
            "parser_version": PARSER_VERSION,
            "index_version": INDEX_VERSION,
            "source_order": source_order,
        }

    def _build_chunks(self, parsed: ParsedDocument) -> list[dict]:
        """Turn each parsed node into a ``{"content", "metadata"}`` chunk dict.

        The standard metadata is followed by the node's own metadata, which
        wins on key clashes. A document without nodes yields a single "text"
        chunk holding its summary.
        """
        chunks = [
            {
                "content": node.text,
                "metadata": {
                    **self._chunk_metadata(node.node_type, node.section_path or [], order),
                    **node.metadata,
                },
            }
            for order, node in enumerate(parsed.nodes)
        ]
        if not chunks:
            chunks.append({"content": parsed.summary, "metadata": self._chunk_metadata("text", [], 0)})
        return chunks

    @staticmethod
    def _split_serialized_row(line: str, headers: list[str]) -> list[str]:
        """Invert ``_serialize_row`` for one line; values are aligned with ``headers``.

        Because the header order is known, each value runs up to the next
        ``", <next header>="`` marker. Values containing ", " or "=" therefore
        survive, and duplicate headers stay distinct. Missing trailing values
        come back as "".
        """
        values = [""] * len(headers)
        remainder = line
        for index, header in enumerate(headers):
            prefix = f"{header}="
            if not remainder.startswith(prefix):
                break
            remainder = remainder[len(prefix):]
            if index + 1 == len(headers):
                values[index] = remainder
                break
            cut = remainder.find(f", {headers[index + 1]}=")
            if cut == -1:
                # Row had fewer values than headers: this was the last one.
                values[index] = remainder
                break
            values[index] = remainder[:cut]
            remainder = remainder[cut + 2:]  # skip ", " so the next prefix matches
        return values

    @classmethod
    def _markdown_row(cls, cells) -> str:
        """Format cells as one GFM pipe-table row, escaping literal ``|``."""
        return "| " + " | ".join(cls._flatten_cell(cell).replace("|", "\\|") for cell in cells) + " |"

    def _render_structured_markdown(self, parsed: ParsedDocument) -> str:
        """Render parsed nodes as Markdown; return the summary if nothing renders.

        Headings become ATX headings, with the level clamped to 1-6. A table's
        schema node and following row nodes form one contiguous pipe table;
        a blank line would end a GFM table. Every other node is emitted as its
        text. Blocks are separated by blank lines.
        """
        blocks: list[str] = []
        # True while blocks[-1] is a pipe table that row nodes may extend.
        table_open = False
        for node in parsed.nodes:
            headers = node.metadata.get("headers") or []
            if node.node_type == "table_rows" and headers:
                rows = [
                    self._markdown_row(self._split_serialized_row(line, headers))
                    for line in node.text.splitlines()
                ]
                if rows:
                    if table_open:
                        blocks[-1] += "\n" + "\n".join(rows)
                    else:
                        blocks.append("\n".join(rows))
                continue
            table_open = False
            if node.node_type == "heading":
                level = max(1, min(int(node.metadata.get("level", 1)), 6))
                blocks.append(f"{'#' * level} {node.text}")
            elif node.node_type == "table_schema" and headers:
                blocks.append(
                    self._markdown_row(headers) + "\n" + self._markdown_row(["---"] * len(headers))
                )
                table_open = True
            else:
                blocks.append(node.text)
        return "\n\n".join(block for block in blocks if block).strip() or parsed.summary

    # ------------------------------------------------------------- read/edit

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return a document's chunks ordered by ``chunk_index`` (no ownership check)."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def update_document_chunk(self, user_id: str, document_id: str, chunk_id: str, content: str) -> DocumentChunk:
        """Replace one chunk's text and mark the owning document as "indexing".

        Raises:
            ValueError: The document (for this user) or the chunk doesn't exist.
        """
        document_result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        document = document_result.scalar_one_or_none()
        if not document:
            raise ValueError("文档不存在")
        chunk_result = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.id == chunk_id,
                DocumentChunk.document_id == document_id,
            )
        )
        chunk = chunk_result.scalar_one_or_none()
        if not chunk:
            raise ValueError("切片不存在")
        chunk.content = content
        document.ingestion_status = "indexing"
        document.ingestion_error = None
        await self.db.commit()
        await self.db.refresh(chunk)
        return chunk

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return a document's text content.

        Prefers the stored ``normalized_content``. Otherwise ``.txt`` / ``.md``
        files are read from disk. Other types, or unreadable files, yield a
        ``"[文档] <filename>"`` placeholder. Returns ``None`` when the document
        doesn't exist for this user or its file is missing.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None
        if doc.normalized_content:
            return doc.normalized_content

        file_path = doc.file_path
        if not os.path.exists(file_path):
            return None
        placeholder = f"[文档] {doc.filename}"
        if doc.filename.split(".")[-1].lower() not in ("txt", "md"):
            return placeholder
        try:
            return await self._read_text_file(file_path)
        except OSError:
            return placeholder