Add sub-commander orchestration updates, align frontend integrations, and refine knowledge view behavior without including local data artifacts. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
606 lines
24 KiB
Python
606 lines
24 KiB
Python
"""
|
||
文档服务 - 上传、解析、分块、存储
|
||
支持多种文档格式 + LlamaIndex 智能分块
|
||
"""
|
||
|
||
from pathlib import Path
|
||
import tempfile
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from fastapi import UploadFile
|
||
from app.models.document import Document, DocumentChunk
|
||
from app.models.folder import Folder
|
||
from app.config import settings
|
||
from app.services.brain_service import BrainService
|
||
import csv
|
||
import io
|
||
import json
|
||
import os
|
||
import re
|
||
import aiofiles
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
|
||
|
||
# Lower-case file suffixes (with leading dot) that upload_document accepts.
ALLOWED_EXTENSIONS = {
    ".pdf",
    ".md",
    ".txt",
    ".docx",
    ".doc",
    ".csv",
    ".xlsx",
}

# Version tags written onto documents and into every chunk's metadata.
PARSER_VERSION = "v2"
INDEX_VERSION = "v2"
|
||
|
||
|
||
@dataclass
class ParsedNode:
    """A single structural unit emitted by one of the document parsers."""

    # Node kind; the parsers in this file emit "heading", "paragraph", "text",
    # "table_schema" and "table_rows".
    node_type: str
    # Text payload of the node.
    text: str
    # Parser-specific extras (e.g. heading level, table headers, row range).
    metadata: dict = field(default_factory=dict)
    # Trail of section titles (or sheet/table names) the node belongs to.
    section_path: list[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class ParsedDocument:
    """Result of parsing one file: a short summary plus its ordered nodes."""

    # Short human-readable description of the document.
    summary: str
    # Parsed nodes in source order.
    nodes: list[ParsedNode]
    # Markdown rendering of ``nodes``; empty until a caller fills it in.
    structured_markdown: str = ""
|
||
|
||
|
||
class DocumentService:
    """Upload, parse, chunk, and persist documents through an async DB session."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to a database session.

        Args:
            db: Async SQLAlchemy session used for every query and commit.
            user_id: Optional owner id; ``_get_folder_path`` filters folders
                by it. Defaults to ``None``.
        """
        self.db = db
        self.user_id = user_id
|
||
|
||
async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
    """Validate, store, parse, and index an uploaded file.

    The file is written under ``settings.UPLOAD_DIR`` with a random name,
    parsed into nodes, stored as a ``Document`` plus one ``DocumentChunk``
    per node, and announced through ``BrainService.create_event``.

    Args:
        user_id: Owner of the new document.
        file: Uploaded file; its extension must be in ``ALLOWED_EXTENSIONS``.
        folder_id: Optional folder to attach the document to.

    Returns:
        The committed and refreshed ``Document``.

    Raises:
        ValueError: If the extension is not allowed or the content exceeds
            ``settings.MAX_UPLOAD_SIZE``. Parser errors propagate as raised.
    """
    # NOTE(review): UploadFile.filename is presumably optional; a missing
    # name is treated as empty so the extension check rejects it cleanly
    # instead of failing with TypeError.
    filename = file.filename or ""
    ext = os.path.splitext(filename)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"不支持的文件类型: {ext}")

    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
    file_id = str(uuid.uuid4())
    file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

    content = await file.read()
    file_size = len(content)
    if file_size > settings.MAX_UPLOAD_SIZE:
        raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

    try:
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        parsed = await self._parse_document(file_path, ext)
        parsed.structured_markdown = self._render_structured_markdown(parsed)

        doc = Document(
            user_id=user_id,
            title=filename.rsplit('.', 1)[0],
            filename=filename,
            file_type=ext[1:],
            file_size=file_size,
            file_path=file_path,
            # Slicing already leaves shorter summaries untouched.
            summary=parsed.summary[:500],
            folder_id=folder_id,
            ingestion_status="uploaded",
            ingestion_error=None,
            parser_version=PARSER_VERSION,
            index_version=INDEX_VERSION,
            normalized_content=parsed.structured_markdown,
            normalized_format="structured_markdown",
        )
        self.db.add(doc)
        # Flush so doc.id is available for the chunks and the event below.
        await self.db.flush()

        chunks = self._build_chunks(parsed)
        for i, chunk_data in enumerate(chunks):
            self.db.add(
                DocumentChunk(
                    document_id=doc.id,
                    chunk_index=i,
                    content=chunk_data["content"],
                    metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
                )
            )
        doc.chunk_count = len(chunks)

        brain_service = BrainService(self.db)
        await brain_service.create_event(
            user_id,
            source_type="document",
            source_id=doc.id,
            event_type="document_uploaded",
            title=doc.filename,
            content_summary=doc.summary,
            raw_excerpt=(doc.normalized_content or "")[:1000] or None,
            metadata_={
                "document_id": doc.id,
                "file_type": doc.file_type,
                "ingestion_status": doc.ingestion_status,
            },
            importance_signal=1.0,
        )
        await self.db.commit()
    except Exception:
        # Nothing was committed, so no row references the stored file;
        # remove it rather than leaving an orphan in the upload directory.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise

    await self.db.refresh(doc)
    return doc
|
||
|
||
async def rebuild_document(self, document: Document) -> Document:
    """Re-parse a stored document and replace its chunks and derived fields.

    Args:
        document: Existing document whose ``file_path`` is parsed again.

    Returns:
        The same document after commit and refresh, with
        ``ingestion_status`` set to ``"indexing"``.
    """
    suffix = os.path.splitext(document.filename)[1].lower()
    parsed = await self._parse_document(document.file_path, suffix)
    parsed.structured_markdown = self._render_structured_markdown(parsed)

    # Drop every chunk from the previous parse before inserting new ones.
    stale_query = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document.id)
        .order_by(DocumentChunk.chunk_index)
    )
    stale_result = await self.db.execute(stale_query)
    for stale_chunk in list(stale_result.scalars().all()):
        await self.db.delete(stale_chunk)
    await self.db.flush()

    fresh_chunks = self._build_chunks(parsed)
    for position, payload in enumerate(fresh_chunks):
        replacement = DocumentChunk(
            document_id=document.id,
            chunk_index=position,
            content=payload["content"],
            metadata_=json.dumps(payload["metadata"], ensure_ascii=False),
        )
        self.db.add(replacement)

    # Refresh the fields derived from the parse result.
    document.summary = parsed.summary[:500]
    document.chunk_count = len(fresh_chunks)
    document.ingestion_status = "indexing"
    document.ingestion_error = None
    document.parser_version = PARSER_VERSION
    document.index_version = INDEX_VERSION
    document.normalized_content = parsed.structured_markdown
    document.normalized_format = "structured_markdown"

    await self.db.commit()
    await self.db.refresh(document)
    return document
|
||
|
||
async def _get_folder_path(self, folder_id: str) -> str | None:
    """Return the full ``/parent/child`` path of a folder.

    Loads the folders belonging to ``self.user_id`` and walks ``parent_id``
    links upward from ``folder_id``.

    Args:
        folder_id: Id of the innermost folder.

    Returns:
        The slash-separated path starting with ``/``, or ``None`` when no
        folder could be resolved.
    """
    result = await self.db.execute(
        select(Folder).where(Folder.user_id == self.user_id)
    )
    folders_by_id = {folder.id: folder for folder in result.scalars().all()}

    names_leaf_first: list[str] = []
    visited = set()
    current_id = folder_id
    # Stop on an already-visited id so a corrupt, cyclic parent chain cannot
    # spin forever.
    while current_id and current_id not in visited:
        folder = folders_by_id.get(current_id)
        if not folder:
            break
        visited.add(current_id)
        names_leaf_first.append(folder.name)
        current_id = folder.parent_id

    if not names_leaf_first:
        return None
    # Names were collected leaf-first; reverse once instead of insert(0).
    return "/" + "/".join(reversed(names_leaf_first))
|
||
|
||
async def delete_document(self, user_id: str, document_id: str):
    """Delete a user's document row and then its stored file.

    The database delete is committed first, so a failed commit never leaves
    a row pointing at an already-removed file. A file that is already gone
    is ignored; other ``OSError``s from the removal propagate after the row
    has been deleted.

    Args:
        user_id: Owner the document must belong to.
        document_id: Id of the document to delete.

    Raises:
        ValueError: If no matching document exists for this user.
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        raise ValueError("文档不存在")

    # Read the path now; the instance should not be touched once its
    # deletion has been committed.
    file_path = doc.file_path

    await self.db.delete(doc)
    await self.db.commit()

    if file_path:
        # EAFP instead of exists()+remove(): no race, one syscall.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
|
||
|
||
async def _extract_text(self, file_path: str, ext: str) -> str:
|
||
if ext in (".md", ".txt"):
|
||
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||
return await f.read()
|
||
|
||
if ext in (".docx", ".doc"):
|
||
try:
|
||
from docx import Document as DocxDocument
|
||
doc = DocxDocument(file_path)
|
||
parts = [p.text for p in doc.paragraphs if p.text.strip()]
|
||
for table in doc.tables:
|
||
for row in table.rows:
|
||
row_values = [cell.text.strip() for cell in row.cells]
|
||
if any(row_values):
|
||
parts.append(" | ".join(row_values))
|
||
return "\n".join(parts)
|
||
except ImportError:
|
||
return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
|
||
|
||
return "[暂不支持此格式]"
|
||
|
||
async def _parse_document(self, file_path: str, ext: str) -> ParsedDocument:
    """Route a stored file to the parser that matches its extension.

    Args:
        file_path: Path of the stored file.
        ext: Lower-case extension including the leading dot.

    Returns:
        The ``ParsedDocument`` produced by the selected parser.
    """
    # Formats that have a dedicated parser working directly on the file.
    file_parsers = {
        ".csv": self._parse_csv,
        ".xlsx": self._parse_xlsx,
        ".docx": self._parse_docx,
        ".pdf": self._parse_pdf,
    }
    file_parser = file_parsers.get(ext)
    if file_parser is not None:
        return await file_parser(file_path)

    # Everything else is first reduced to text.
    raw_text = await self._extract_text(file_path, ext)
    if ext == ".md":
        return self._parse_markdown(raw_text)
    # ".txt", ".doc" and any unrecognised extension are parsed as plain text.
    return self._parse_text(raw_text)
|
||
|
||
async def _parse_csv(self, file_path: str) -> ParsedDocument:
    """Parse a CSV file into one schema node plus row nodes of up to 50 rows.

    The first record is used as the header row. Each data row is serialized
    as ``header=value`` pairs joined by ``", "``.

    Args:
        file_path: Path of the CSV file, read as ``utf-8-sig``.

    Returns:
        A ``ParsedDocument`` whose summary lists the column names.
    """
    async with aiofiles.open(file_path, "r", encoding="utf-8-sig") as handle:
        raw = await handle.read()

    records = list(csv.reader(io.StringIO(raw)))
    headers = records[0] if records else []
    rows = records[1:]
    column_list = ", ".join(headers)

    nodes = [
        ParsedNode(
            node_type="table_schema",
            text=f"CSV columns: {column_list} | rows: {len(rows)}",
            metadata={"headers": headers, "row_count": len(rows), "table_name": "csv"},
            section_path=["csv"],
        )
    ]

    rows_per_node = 50
    for offset in range(0, len(rows), rows_per_node):
        batch = rows[offset:offset + rows_per_node]
        lines = [
            ", ".join(f"{header}={value}" for header, value in zip(headers, row))
            for row in batch
        ]
        row_node = ParsedNode(
            node_type="table_rows",
            text="\n".join(lines),
            metadata={
                "headers": headers,
                # Row numbers are 1-based over the data rows.
                "row_start": offset + 1,
                "row_end": offset + len(batch),
                "table_name": "csv",
            },
            section_path=["csv"],
        )
        nodes.append(row_node)

    if headers:
        summary = f"CSV with columns {column_list}"
    else:
        summary = "CSV document"
    return ParsedDocument(summary=summary, nodes=nodes)
|
||
|
||
async def _parse_xlsx(self, file_path: str) -> ParsedDocument:
    """Parse every non-empty worksheet into schema and row nodes.

    For each sheet the first row supplies the headers; data rows are
    serialized as ``header=value`` pairs in nodes of up to 50 rows.

    Args:
        file_path: Path of the workbook.

    Returns:
        A ``ParsedDocument`` whose summary lists the sheet titles.

    Raises:
        ValueError: If openpyxl is not installed.
    """
    try:
        from openpyxl import load_workbook
    except ModuleNotFoundError as error:
        raise ValueError("XLSX 解析依赖缺失: openpyxl") from error

    rows_per_node = 50
    workbook = load_workbook(file_path, data_only=True)
    nodes: list[ParsedNode] = []
    sheet_titles: list[str] = []

    for sheet in workbook.worksheets:
        grid = list(sheet.iter_rows(values_only=True))
        if not grid:
            continue

        header_cells, data_rows = grid[0], grid[1:]
        headers = ["" if cell is None else str(cell).strip() for cell in header_cells]
        sheet_titles.append(sheet.title)

        column_list = ", ".join(headers)
        schema_node = ParsedNode(
            node_type="table_schema",
            text=f"Sheet {sheet.title} columns: {column_list} | rows: {len(data_rows)}",
            metadata={"headers": headers, "row_count": len(data_rows), "sheet_name": sheet.title},
            section_path=[sheet.title],
        )
        nodes.append(schema_node)

        for offset in range(0, len(data_rows), rows_per_node):
            batch = data_rows[offset:offset + rows_per_node]
            lines = []
            for row in batch:
                # Empty cells become empty strings rather than "None".
                cells = ["" if value is None else str(value) for value in row]
                lines.append(", ".join(f"{header}={value}" for header, value in zip(headers, cells)))
            rows_node = ParsedNode(
                node_type="table_rows",
                text="\n".join(lines),
                metadata={
                    "headers": headers,
                    # Row numbers are 1-based over the data rows.
                    "row_start": offset + 1,
                    "row_end": offset + len(batch),
                    "sheet_name": sheet.title,
                },
                section_path=[sheet.title],
            )
            nodes.append(rows_node)

    if sheet_titles:
        summary = "Workbook sheets: " + ", ".join(sheet_titles)
    else:
        summary = "Workbook"
    return ParsedDocument(summary=summary, nodes=nodes)
|
||
|
||
async def _parse_docx(self, file_path: str) -> ParsedDocument:
    """Parse a DOCX file into heading, paragraph, and table nodes.

    Paragraphs whose style name starts with ``"Heading"`` update the section
    path; all other non-empty paragraphs become paragraph nodes. Tables are
    emitted after all paragraphs, using the section path that was current at
    the end of the paragraph pass.

    Args:
        file_path: Path of the DOCX file.

    Returns:
        A ``ParsedDocument`` summarised by its first three body paragraphs,
        or by the document title when there are none.

    Raises:
        ValueError: If python-docx is not installed.
    """
    try:
        from docx import Document as DocxDocument
    except ModuleNotFoundError as error:
        raise ValueError("DOCX 解析依赖缺失: python-docx") from error

    word_doc = DocxDocument(file_path)
    nodes: list[ParsedNode] = []
    current_path: list[str] = []
    body_texts: list[str] = []

    for para in word_doc.paragraphs:
        text = para.text.strip()
        if not text:
            continue

        style_name = getattr(para.style, "name", "") or ""
        if not style_name.startswith("Heading"):
            # Body text before any heading is filed under the document title.
            if not current_path:
                current_path = [word_doc.core_properties.title or "Document"]
            body_texts.append(text)
            nodes.append(ParsedNode("paragraph", text, {}, list(current_path)))
            continue

        # The digits in the style name give the level; default to 1.
        digits = re.search(r"(\d+)", style_name)
        level = int(digits.group(1)) if digits else 1
        current_path = current_path[: level - 1] + [text]
        nodes.append(ParsedNode("heading", text, {"level": level}, list(current_path)))

    rows_per_node = 50
    for table in word_doc.tables:
        grid = [[cell.text.strip() for cell in row.cells] for row in table.rows]
        if not grid:
            continue

        headers = grid[0]
        data_row_count = max(len(grid) - 1, 0)
        column_list = ", ".join(headers)
        nodes.append(
            ParsedNode(
                "table_schema",
                f"DOCX table columns: {column_list} | rows: {data_row_count}",
                {"headers": headers, "row_count": data_row_count, "table_name": "docx_table"},
                list(current_path),
            )
        )

        # Index 0 is the header row, so data batches start at 1.
        for first in range(1, len(grid), rows_per_node):
            batch = grid[first:first + rows_per_node]
            lines = []
            for row in batch:
                lines.append(", ".join(f"{header}={value}" for header, value in zip(headers, row)))
            nodes.append(
                ParsedNode(
                    "table_rows",
                    "\n".join(lines),
                    {
                        "headers": headers,
                        "row_start": first,
                        "row_end": first + len(batch) - 1,
                        "table_name": "docx_table",
                    },
                    list(current_path),
                )
            )

    if body_texts:
        summary = " ".join(body_texts[:3])
    else:
        summary = word_doc.core_properties.title or "Document"
    return ParsedDocument(summary=summary, nodes=nodes)
|
||
|
||
async def _parse_pdf_with_mineru(self, file_path: str) -> str:
|
||
try:
|
||
import mineru
|
||
except ModuleNotFoundError as error:
|
||
raise ValueError("PDF 解析依赖缺失: mineru") from error
|
||
|
||
if hasattr(mineru, "to_markdown"):
|
||
return mineru.to_markdown(file_path)
|
||
if hasattr(mineru, "parse_to_markdown"):
|
||
return mineru.parse_to_markdown(file_path)
|
||
|
||
try:
|
||
from mineru.cli.common import do_parse, read_fn
|
||
from mineru.utils.enum_class import MakeMode
|
||
except Exception as error:
|
||
raise ValueError(
|
||
"PDF 解析失败: 当前安装的 MinerU 版本接口不兼容,请确认支持 to_markdown / parse_to_markdown,或提供 cli.common.do_parse 能力"
|
||
) from error
|
||
|
||
with tempfile.TemporaryDirectory(prefix="mineru-") as output_dir:
|
||
pdf_name = Path(file_path).stem
|
||
pdf_bytes = read_fn(Path(file_path))
|
||
try:
|
||
do_parse(
|
||
output_dir,
|
||
[pdf_name],
|
||
[pdf_bytes],
|
||
["zh"],
|
||
f_draw_layout_bbox=False,
|
||
f_draw_span_bbox=False,
|
||
f_dump_md=True,
|
||
f_dump_middle_json=False,
|
||
f_dump_model_output=False,
|
||
f_dump_orig_pdf=False,
|
||
f_dump_content_list=False,
|
||
f_make_md_mode=MakeMode.MM_MD,
|
||
)
|
||
except ModuleNotFoundError as error:
|
||
dependency = getattr(error, "name", None) or str(error).split("'")[-2] if "'" in str(error) else str(error)
|
||
raise ValueError(f"PDF 解析依赖缺失: MinerU 运行时依赖 {dependency}") from error
|
||
markdown_path = Path(output_dir) / pdf_name / "pipeline" / f"{pdf_name}.md"
|
||
if markdown_path.exists():
|
||
return markdown_path.read_text(encoding="utf-8")
|
||
|
||
raise ValueError(
|
||
"PDF 解析失败: 当前安装的 MinerU 版本接口不兼容,请确认支持 to_markdown / parse_to_markdown,或提供 cli.common.do_parse 能力"
|
||
)
|
||
|
||
async def _parse_pdf(self, file_path: str) -> ParsedDocument:
    """Convert a PDF to Markdown via MinerU, then parse that Markdown into nodes."""
    return self._parse_markdown(await self._parse_pdf_with_mineru(file_path))
|
||
|
||
def _parse_markdown(self, content: str) -> ParsedDocument:
    """Split Markdown text into heading and paragraph nodes.

    ATX headings (``#`` to ``######``) update the section path. Consecutive
    non-blank lines are stripped, joined with newlines, and emitted as one
    paragraph node. Lines inside fenced code blocks are never treated as
    headings, so a ``# comment`` in a code sample cannot corrupt the section
    path.

    Args:
        content: Markdown source text.

    Returns:
        A ``ParsedDocument`` summarised by its first three paragraphs, or by
        the first 200 characters of ``content`` when there are none.
    """
    nodes: list[ParsedNode] = []
    section_path: list[str] = []
    summary_parts: list[str] = []
    buffer: list[str] = []
    # Marker that opened the current code fence, or None outside a fence.
    open_fence: str | None = None

    def flush_buffer() -> None:
        """Emit the buffered lines as one paragraph under the current section."""
        if not buffer:
            return
        text = "\n".join(buffer).strip()
        buffer.clear()
        if not text:
            return
        nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))
        summary_parts.append(text)

    for line in content.splitlines():
        stripped = line.strip()

        fence_match = re.match(r"^(`{3,}|~{3,})", stripped)
        if fence_match:
            marker = fence_match.group(1)
            rest = stripped[len(marker):]
            # A backtick fence may not carry backticks after the marker;
            # such a line is inline code, not a fence.
            is_fence_line = not (marker[0] == "`" and "`" in rest)
            if is_fence_line:
                if open_fence is None:
                    open_fence = marker
                elif (
                    marker[0] == open_fence[0]
                    and len(marker) >= len(open_fence)
                    and not rest
                ):
                    open_fence = None
                buffer.append(stripped)
                continue

        if open_fence is None:
            heading_match = re.match(r"^(#{1,6})\s+(.+)$", stripped)
            if heading_match:
                flush_buffer()
                level = len(heading_match.group(1))
                title = heading_match.group(2).strip()
                # Keep the ancestors above this level, then append the heading.
                section_path = section_path[: level - 1] + [title]
                nodes.append(ParsedNode("heading", title, {"level": level}, list(section_path)))
                continue

        if stripped:
            buffer.append(stripped)
        else:
            # A blank line ends the current paragraph.
            flush_buffer()

    flush_buffer()
    summary = " ".join(summary_parts[:3]) if summary_parts else content[:200]
    return ParsedDocument(summary=summary, nodes=nodes)
|
||
|
||
def _parse_text(self, content: str) -> ParsedDocument:
    """Split plain text on blank lines into ``"text"`` nodes.

    Args:
        content: Raw text.

    Returns:
        A ``ParsedDocument`` summarised by its first three paragraphs, or by
        the first 200 characters of ``content`` when there are none.
    """
    paragraphs = []
    for block in content.split("\n\n"):
        cleaned = block.strip()
        if cleaned:
            paragraphs.append(cleaned)

    nodes = [ParsedNode("text", body, {}, []) for body in paragraphs]

    if paragraphs:
        summary = " ".join(paragraphs[:3])
    else:
        summary = content[:200]
    return ParsedDocument(summary=summary, nodes=nodes)
|
||
|
||
def _build_chunks(self, parsed: ParsedDocument) -> list[dict]:
    """Turn parsed nodes into chunk dicts with ``content`` and ``metadata``.

    One chunk is produced per node. Node metadata is merged last, so its
    keys override the generated ones. When there are no nodes, a single
    chunk holding the document summary is returned.

    Args:
        parsed: Parsed document whose nodes are chunked.

    Returns:
        Chunk dicts in node order.
    """
    chunks: list[dict] = []
    for order, node in enumerate(parsed.nodes):
        path = node.section_path or []
        metadata = {
            "content_type": node.node_type,
            "section_path": path,
            "section_title": path[-1] if path else None,
            "chunk_level": len(path),
            "parent_key": "/".join(path[:-1]) or None,
            "block_key": "/".join(path) or None,
            "parser_version": PARSER_VERSION,
            "index_version": INDEX_VERSION,
            "source_order": order,
        }
        # Node-specific metadata wins over the generated keys.
        metadata.update(node.metadata)
        chunks.append({"content": node.text, "metadata": metadata})

    if chunks:
        return chunks

    # No nodes: fall back to a single chunk carrying the summary.
    fallback_metadata = {
        "content_type": "text",
        "section_path": [],
        "section_title": None,
        "chunk_level": 0,
        "parent_key": None,
        "block_key": None,
        "parser_version": PARSER_VERSION,
        "index_version": INDEX_VERSION,
        "source_order": 0,
    }
    chunks.append({"content": parsed.summary, "metadata": fallback_metadata})
    return chunks
|
||
|
||
def _render_structured_markdown(self, parsed: ParsedDocument) -> str:
|
||
blocks: list[str] = []
|
||
for node in parsed.nodes:
|
||
if node.node_type == "heading":
|
||
level = max(1, min(int(node.metadata.get("level", 1)), 6))
|
||
blocks.append(f"{'#' * level} {node.text}")
|
||
continue
|
||
if node.node_type == "table_schema":
|
||
headers = node.metadata.get("headers") or []
|
||
if headers:
|
||
header_row = "| " + " | ".join(headers) + " |"
|
||
divider_row = "| " + " | ".join(["---"] * len(headers)) + " |"
|
||
blocks.append("\n".join([header_row, divider_row]))
|
||
else:
|
||
blocks.append(node.text)
|
||
continue
|
||
if node.node_type == "table_rows":
|
||
headers = node.metadata.get("headers") or []
|
||
if headers:
|
||
rows = []
|
||
for line in node.text.splitlines():
|
||
values_by_header = {}
|
||
for part in line.split(", "):
|
||
if "=" not in part:
|
||
continue
|
||
key, value = part.split("=", 1)
|
||
values_by_header[key] = value
|
||
rows.append("| " + " | ".join(values_by_header.get(header, "") for header in headers) + " |")
|
||
if rows:
|
||
blocks.append("\n".join(rows))
|
||
continue
|
||
blocks.append(node.text)
|
||
continue
|
||
blocks.append(node.text)
|
||
return "\n\n".join(block for block in blocks if block).strip() or parsed.summary
|
||
|
||
async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
    """Return all chunks of a document ordered by ``chunk_index``.

    Args:
        document_id: Id of the document whose chunks are fetched.
    """
    query = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    rows = await self.db.execute(query)
    return list(rows.scalars().all())
|
||
|
||
async def update_document_chunk(self, user_id: str, document_id: str, chunk_id: str, content: str) -> DocumentChunk:
    """Replace one chunk's content and mark its document for re-indexing.

    Args:
        user_id: Owner the document must belong to.
        document_id: Id of the document containing the chunk.
        chunk_id: Id of the chunk to update.
        content: New chunk text.

    Returns:
        The refreshed chunk after commit.

    Raises:
        ValueError: If the document or the chunk cannot be found.
    """
    # The document lookup doubles as the ownership check.
    document_query = select(Document).where(
        Document.id == document_id,
        Document.user_id == user_id,
    )
    document = (await self.db.execute(document_query)).scalar_one_or_none()
    if not document:
        raise ValueError("文档不存在")

    chunk_query = select(DocumentChunk).where(
        DocumentChunk.id == chunk_id,
        DocumentChunk.document_id == document_id,
    )
    chunk = (await self.db.execute(chunk_query)).scalar_one_or_none()
    if not chunk:
        raise ValueError("切片不存在")

    chunk.content = content
    # Edited content puts the document back into the "indexing" state.
    document.ingestion_status = "indexing"
    document.ingestion_error = None

    await self.db.commit()
    await self.db.refresh(chunk)
    return chunk
|
||
|
||
async def get_document_content(self, user_id: str, document_id: str) -> str | None:
    """Return the text content of a user's document.

    The stored ``normalized_content`` is preferred. Otherwise ``.txt`` and
    ``.md`` files are read from disk; other types, and files that cannot be
    read or decoded, yield a ``"[文档] <filename>"`` placeholder.

    Args:
        user_id: Owner the document must belong to.
        document_id: Id of the document.

    Returns:
        The content or placeholder, or ``None`` when the document does not
        exist or its file is missing.
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        return None

    if doc.normalized_content:
        return doc.normalized_content

    file_path = doc.file_path
    # A missing path is treated like a missing file.
    if not file_path or not os.path.exists(file_path):
        return None

    placeholder = f"[文档] {doc.filename}"

    # Only plain-text formats are read back directly from disk.
    ext = doc.filename.split('.')[-1].lower()
    if ext not in ("txt", "md"):
        return placeholder

    try:
        # Non-blocking read, consistent with the other file access in this service.
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            return await f.read()
    except (OSError, ValueError):
        # OSError: unreadable file; ValueError covers UnicodeDecodeError.
        return placeholder
|