"""
Document service: upload, parse, chunk, and store documents.

Supports multiple document formats + LlamaIndex smart chunking.
NOTE(review): no LlamaIndex usage is visible in this module -- confirm
whether the chunking described here lives elsewhere.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import UploadFile
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
from app.services.brain_service import BrainService
import csv
import io
import json
import os
import re
import aiofiles
import uuid
from dataclasses import dataclass, field

ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc", ".csv", ".xlsx"}
PARSER_VERSION = "v2"
INDEX_VERSION = "v2"

# Maximum number of characters of the parsed summary stored on a Document.
_SUMMARY_MAX_LENGTH = 500
# Number of table data rows serialized into a single "table_rows" node.
_TABLE_ROW_BATCH_SIZE = 50


@dataclass
class ParsedNode:
    """One parsed unit of a document (heading, paragraph, text, table schema/rows)."""

    node_type: str
    text: str
    metadata: dict = field(default_factory=dict)
    section_path: list[str] = field(default_factory=list)


@dataclass
class ParsedDocument:
    """Result of parsing a file: a short summary plus the ordered parsed nodes."""

    summary: str
    nodes: list[ParsedNode]
    structured_markdown: str = ""


class DocumentService:
    """Upload, parse, chunk, persist, and read back user documents."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Store the DB session and the (optional) user the service acts for."""
        self.db = db
        self.user_id = user_id

    async def upload_document(
        self,
        user_id: str,
        file: UploadFile,
        folder_id: str | None = None,
    ) -> Document:
        """Save an uploaded file, parse it, and persist the document and its chunks.

        Args:
            user_id: Owner of the new document.
            file: The uploaded file; its extension must be in ALLOWED_EXTENSIONS.
            folder_id: Optional folder to attach the document to.

        Returns:
            The committed and refreshed Document row.

        Raises:
            ValueError: If the extension is not allowed or the file exceeds
                settings.MAX_UPLOAD_SIZE. Parser errors propagate unchanged.
        """
        # UploadFile.filename may be None; treat it as "no extension" so the
        # caller gets the regular unsupported-type error instead of a TypeError.
        filename = file.filename or ""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        file_id = str(uuid.uuid4())
        file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

        content = await file.read()
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        persisted = False
        try:
            async with aiofiles.open(file_path, "wb") as f:
                await f.write(content)

            parsed = await self._parse_document(file_path, ext)
            parsed.structured_markdown = self._render_structured_markdown(parsed)

            doc = Document(
                user_id=user_id,
                title=filename.rsplit(".", 1)[0],
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=parsed.summary[:_SUMMARY_MAX_LENGTH],
                folder_id=folder_id,
                ingestion_status="uploaded",
                ingestion_error=None,
                parser_version=PARSER_VERSION,
                index_version=INDEX_VERSION,
                normalized_content=parsed.structured_markdown,
                normalized_format="structured_markdown",
            )
            self.db.add(doc)
            # Flush so doc.id is available for the chunk rows below.
            await self.db.flush()

            chunks = self._build_chunks(parsed)
            for i, chunk_data in enumerate(chunks):
                chunk = DocumentChunk(
                    document_id=doc.id,
                    chunk_index=i,
                    content=chunk_data["content"],
                    metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
                )
                self.db.add(chunk)
            doc.chunk_count = len(chunks)

            brain_service = BrainService(self.db)
            await brain_service.create_event(
                user_id,
                source_type="document",
                source_id=doc.id,
                event_type="document_uploaded",
                title=doc.filename,
                content_summary=doc.summary,
                raw_excerpt=(doc.normalized_content or "")[:1000] or None,
                metadata_={
                    "document_id": doc.id,
                    "file_type": doc.file_type,
                    "ingestion_status": doc.ingestion_status,
                },
                importance_signal=1.0,
            )
            await self.db.commit()
            persisted = True
        finally:
            if not persisted:
                # No committed row references the file, so remove it rather
                # than leaving an orphan in the upload directory. Best effort:
                # the original error is what the caller needs to see.
                try:
                    os.remove(file_path)
                except OSError:
                    pass

        await self.db.refresh(doc)
        return doc

    async def rebuild_document(self, document: Document) -> Document:
        """Re-parse a stored document from disk and replace all of its chunks.

        Args:
            document: The Document row to rebuild; its file_path is re-read.

        Returns:
            The same Document, committed and refreshed, with
            ingestion_status set to "indexing".
        """
        ext = os.path.splitext(document.filename)[1].lower()
        parsed = await self._parse_document(document.file_path, ext)
        parsed.structured_markdown = self._render_structured_markdown(parsed)

        chunk_result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document.id)
            .order_by(DocumentChunk.chunk_index)
        )
        existing_chunks = list(chunk_result.scalars().all())
        for chunk in existing_chunks:
            await self.db.delete(chunk)
        # Flush the deletes before inserting the replacement chunks.
        await self.db.flush()

        chunks = self._build_chunks(parsed)
        for i, chunk_data in enumerate(chunks):
            self.db.add(DocumentChunk(
                document_id=document.id,
                chunk_index=i,
                content=chunk_data["content"],
                metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
            ))

        document.summary = parsed.summary[:_SUMMARY_MAX_LENGTH]
        document.chunk_count = len(chunks)
        document.ingestion_status = "indexing"
        document.ingestion_error = None
        document.parser_version = PARSER_VERSION
        document.index_version = INDEX_VERSION
        document.normalized_content = parsed.structured_markdown
        document.normalized_format = "structured_markdown"
        await self.db.commit()
        await self.db.refresh(document)
        return document

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the full "/a/b/c" path of a folder, or None if it cannot be resolved.

        Only folders belonging to self.user_id are considered. The walk stops
        at the first missing parent, and also if a folder id repeats, so a
        cyclic parent chain cannot loop forever.
        """
        folders = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in folders.scalars().all()}

        path_parts: list[str] = []
        visited = set()
        current_id = folder_id
        while current_id and current_id not in visited:
            folder = folder_map.get(current_id)
            if not folder:
                break
            visited.add(current_id)
            path_parts.append(folder.name)
            current_id = folder.parent_id

        # Names were collected leaf-first; the path reads root-first.
        path_parts.reverse()
        return "/" + "/".join(path_parts) if path_parts else None

    async def delete_document(self, user_id: str, document_id: str):
        """Delete a user's document row and its file on disk.

        Raises:
            ValueError: If no document with this id belongs to the user.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        # Guard against an unset path: os.path.exists(None) raises TypeError.
        if doc.file_path and os.path.exists(doc.file_path):
            os.remove(doc.file_path)

        await self.db.delete(doc)
        await self.db.commit()

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Return the plain text of a .md/.txt/.docx/.doc file.

        Word files are read through python-docx (paragraphs, then table rows
        joined with " | "). If python-docx is not installed, or the extension
        is not handled here, a placeholder string is returned instead.
        """
        if ext in (".md", ".txt"):
            async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
                return await f.read()

        if ext in (".docx", ".doc"):
            try:
                from docx import Document as DocxDocument
                doc = DocxDocument(file_path)
                parts = [p.text for p in doc.paragraphs if p.text.strip()]
                for table in doc.tables:
                    for row in table.rows:
                        row_values = [cell.text.strip() for cell in row.cells]
                        if any(row_values):
                            parts.append(" | ".join(row_values))
                return "\n".join(parts)
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"

        return "[暂不支持此格式]"

    async def _parse_document(self, file_path: str, ext: str) -> ParsedDocument:
        """Dispatch to the parser for the given extension (leading dot, lower case).

        .csv, .xlsx, .docx and .pdf have dedicated parsers; .md is parsed as
        markdown; everything else (.txt, .doc, unknown) goes through
        _extract_text and is parsed as plain text.
        """
        if ext == ".csv":
            return await self._parse_csv(file_path)
        if ext == ".xlsx":
            return await self._parse_xlsx(file_path)
        if ext == ".docx":
            return await self._parse_docx(file_path)
        if ext == ".pdf":
            return await self._parse_pdf(file_path)

        content = await self._extract_text(file_path, ext)
        if ext == ".md":
            return self._parse_markdown(content)
        return self._parse_text(content)

    async def _parse_csv(self, file_path: str) -> ParsedDocument:
        """Parse a CSV file into one schema node plus batched row nodes.

        The first row is treated as the header row. The file is read as
        "utf-8-sig", so a leading BOM is dropped.
        """
        async with aiofiles.open(file_path, "r", encoding="utf-8-sig") as f:
            content = await f.read()

        reader = list(csv.reader(io.StringIO(content)))
        headers = reader[0] if reader else []
        rows = reader[1:] if len(reader) > 1 else []

        nodes = [
            ParsedNode(
                node_type="table_schema",
                text=f"CSV columns: {', '.join(headers)} | rows: {len(rows)}",
                metadata={"headers": headers, "row_count": len(rows), "table_name": "csv"},
                section_path=["csv"],
            )
        ]

        for start in range(0, len(rows), _TABLE_ROW_BATCH_SIZE):
            batch = rows[start:start + _TABLE_ROW_BATCH_SIZE]
            serialized_rows = []
            for row in batch:
                serialized = ", ".join(
                    f"{header}={value}" for header, value in zip(headers, row)
                )
                serialized_rows.append(serialized)
            nodes.append(
                ParsedNode(
                    node_type="table_rows",
                    text="\n".join(serialized_rows),
                    metadata={
                        "headers": headers,
                        # 1-based, inclusive positions among the data rows.
                        "row_start": start + 1,
                        "row_end": start + len(batch),
                        "table_name": "csv",
                    },
                    section_path=["csv"],
                )
            )

        summary = f"CSV with columns {', '.join(headers)}" if headers else "CSV document"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_xlsx(self, file_path: str) -> ParsedDocument:
        """Parse every non-empty worksheet into a schema node plus batched row nodes.

        Raises:
            ValueError: If openpyxl is not installed.
        """
        try:
            from openpyxl import load_workbook
        except ModuleNotFoundError as error:
            raise ValueError("XLSX 解析依赖缺失: openpyxl") from error

        workbook = load_workbook(file_path, data_only=True)
        nodes: list[ParsedNode] = []
        summaries: list[str] = []

        try:
            for sheet in workbook.worksheets:
                rows = list(sheet.iter_rows(values_only=True))
                if not rows:
                    continue

                # The first row of each sheet is treated as its header row.
                headers = [str(cell).strip() if cell is not None else "" for cell in rows[0]]
                data_rows = rows[1:]
                summaries.append(sheet.title)
                nodes.append(
                    ParsedNode(
                        node_type="table_schema",
                        text=f"Sheet {sheet.title} columns: {', '.join(headers)} | rows: {len(data_rows)}",
                        metadata={"headers": headers, "row_count": len(data_rows), "sheet_name": sheet.title},
                        section_path=[sheet.title],
                    )
                )

                for start in range(0, len(data_rows), _TABLE_ROW_BATCH_SIZE):
                    batch = data_rows[start:start + _TABLE_ROW_BATCH_SIZE]
                    serialized_rows = []
                    for row in batch:
                        normalized = ["" if value is None else str(value) for value in row]
                        serialized_rows.append(
                            ", ".join(f"{header}={value}" for header, value in zip(headers, normalized))
                        )
                    nodes.append(
                        ParsedNode(
                            node_type="table_rows",
                            text="\n".join(serialized_rows),
                            metadata={
                                "headers": headers,
                                "row_start": start + 1,
                                "row_end": start + len(batch),
                                "sheet_name": sheet.title,
                            },
                            section_path=[sheet.title],
                        )
                    )
        finally:
            # Release the workbook handle even if a sheet fails to serialize.
            workbook.close()

        summary = f"Workbook sheets: {', '.join(summaries)}" if summaries else "Workbook"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_docx(self, file_path: str) -> ParsedDocument:
        """Parse a .docx file into heading/paragraph nodes followed by table nodes.

        Paragraphs whose style name starts with "Heading" become heading nodes
        and update the running section path. All tables are emitted after the
        paragraphs and carry whatever section path was current at the end.

        Raises:
            ValueError: If python-docx is not installed.
        """
        try:
            from docx import Document as DocxDocument
        except ModuleNotFoundError as error:
            raise ValueError("DOCX 解析依赖缺失: python-docx") from error

        doc = DocxDocument(file_path)
        nodes: list[ParsedNode] = []
        section_path: list[str] = []
        summary_parts: list[str] = []

        for paragraph in doc.paragraphs:
            text = paragraph.text.strip()
            if not text:
                continue

            style_name = getattr(paragraph.style, "name", "") or ""
            if style_name.startswith("Heading"):
                # The first number in the style name is the level; default 1.
                level_match = re.search(r"(\d+)", style_name)
                level = int(level_match.group(1)) if level_match else 1
                section_path = section_path[: level - 1] + [text]
                nodes.append(ParsedNode("heading", text, {"level": level}, list(section_path)))
            else:
                if not section_path:
                    # Body text before any heading is filed under the title.
                    section_path = [doc.core_properties.title or "Document"]
                summary_parts.append(text)
                nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))

        for table in doc.tables:
            rows = [[cell.text.strip() for cell in row.cells] for row in table.rows]
            if not rows:
                continue

            headers = rows[0]
            nodes.append(
                ParsedNode(
                    "table_schema",
                    f"DOCX table columns: {', '.join(headers)} | rows: {max(len(rows) - 1, 0)}",
                    {"headers": headers, "row_count": max(len(rows) - 1, 0), "table_name": "docx_table"},
                    list(section_path),
                )
            )
            # Start at 1: rows[0] is the header row.
            for start in range(1, len(rows), _TABLE_ROW_BATCH_SIZE):
                batch = rows[start:start + _TABLE_ROW_BATCH_SIZE]
                serialized_rows = [
                    ", ".join(f"{header}={value}" for header, value in zip(headers, row))
                    for row in batch
                ]
                nodes.append(
                    ParsedNode(
                        "table_rows",
                        "\n".join(serialized_rows),
                        {
                            "headers": headers,
                            "row_start": start,
                            "row_end": start + len(batch) - 1,
                            "table_name": "docx_table",
                        },
                        list(section_path),
                    )
                )

        summary = " ".join(summary_parts[:3]) if summary_parts else doc.core_properties.title or "Document"
        return ParsedDocument(summary=summary, nodes=nodes)

    async def _parse_pdf_with_mineru(self, file_path: str) -> str:
        """Convert a PDF to markdown using whichever MinerU entry point exists.

        Raises:
            ValueError: If mineru is not installed, or exposes neither
                to_markdown nor parse_to_markdown.
        """
        try:
            import mineru
        except ModuleNotFoundError as error:
            raise ValueError("PDF 解析依赖缺失: mineru") from error

        if hasattr(mineru, "to_markdown"):
            return mineru.to_markdown(file_path)
        if hasattr(mineru, "parse_to_markdown"):
            return mineru.parse_to_markdown(file_path)
        raise ValueError("PDF 解析失败: MinerU 不支持当前接口")

    async def _parse_pdf(self, file_path: str) -> ParsedDocument:
        """Parse a PDF by converting it to markdown and parsing that markdown."""
        markdown = await self._parse_pdf_with_mineru(file_path)
        return self._parse_markdown(markdown)

    def _parse_markdown(self, content: str) -> ParsedDocument:
        """Split markdown into heading and paragraph nodes.

        ATX headings ("#" to "######") become heading nodes and update the
        section path; consecutive non-blank lines are merged into a paragraph.
        The summary is the first three paragraphs joined by spaces, or the
        first 200 characters of the content when there are no paragraphs.
        """
        nodes: list[ParsedNode] = []
        section_path: list[str] = []
        summary_parts: list[str] = []
        buffer: list[str] = []

        def flush_buffer():
            # Emit the buffered lines as one paragraph under the section path
            # that is current at the time of the call.
            if not buffer:
                return
            text = "\n".join(buffer).strip()
            buffer.clear()
            if not text:
                return
            nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))
            summary_parts.append(text)

        for line in content.splitlines():
            heading_match = re.match(r"^(#{1,6})\s+(.+)$", line.strip())
            if heading_match:
                flush_buffer()
                level = len(heading_match.group(1))
                title = heading_match.group(2).strip()
                section_path = section_path[: level - 1] + [title]
                nodes.append(ParsedNode("heading", title, {"level": level}, list(section_path)))
                continue

            if line.strip():
                buffer.append(line.strip())
            else:
                flush_buffer()

        flush_buffer()
        summary = " ".join(summary_parts[:3]) if summary_parts else content[:200]
        return ParsedDocument(summary=summary, nodes=nodes)

    def _parse_text(self, content: str) -> ParsedDocument:
        """Split plain text on blank lines into "text" nodes with no section path."""
        paragraphs = [part.strip() for part in content.split("\n\n") if part.strip()]
        nodes = [ParsedNode("text", paragraph, {}, []) for paragraph in paragraphs]
        summary = " ".join(paragraphs[:3]) if paragraphs else content[:200]
        return ParsedDocument(summary=summary, nodes=nodes)

    def _build_chunks(self, parsed: ParsedDocument) -> list[dict]:
        """Turn parsed nodes into chunk dicts with "content" and "metadata" keys.

        One chunk is produced per node. Node metadata is merged last, so it
        overrides the derived keys on collision. A document with no nodes
        yields a single chunk holding the summary.
        """
        chunks: list[dict] = []
        for source_order, node in enumerate(parsed.nodes):
            section_path = node.section_path or []
            metadata = {
                "content_type": node.node_type,
                "section_path": section_path,
                "section_title": section_path[-1] if section_path else None,
                "chunk_level": len(section_path),
                "parent_key": "/".join(section_path[:-1]) or None,
                "block_key": "/".join(section_path) or None,
                "parser_version": PARSER_VERSION,
                "index_version": INDEX_VERSION,
                "source_order": source_order,
                **node.metadata,
            }
            chunks.append({"content": node.text, "metadata": metadata})

        if not chunks:
            chunks.append({
                "content": parsed.summary,
                "metadata": {
                    "content_type": "text",
                    "section_path": [],
                    "section_title": None,
                    "chunk_level": 0,
                    "parent_key": None,
                    "block_key": None,
                    "parser_version": PARSER_VERSION,
                    "index_version": INDEX_VERSION,
                    "source_order": 0,
                },
            })
        return chunks

    def _render_structured_markdown(self, parsed: ParsedDocument) -> str:
        """Render parsed nodes as markdown blocks separated by blank lines.

        Headings become "#"-prefixed lines (level clamped to 1..6), table
        schemas become a header row plus divider row, and table row nodes are
        re-parsed from their "header=value, ..." text into pipe-delimited
        rows. Other nodes are emitted as-is. Falls back to the summary when
        nothing is rendered.
        """
        blocks: list[str] = []
        for node in parsed.nodes:
            if node.node_type == "heading":
                level = max(1, min(int(node.metadata.get("level", 1)), 6))
                blocks.append(f"{'#' * level} {node.text}")
                continue

            if node.node_type == "table_schema":
                headers = node.metadata.get("headers") or []
                if headers:
                    header_row = "| " + " | ".join(headers) + " |"
                    divider_row = "| " + " | ".join(["---"] * len(headers)) + " |"
                    blocks.append("\n".join([header_row, divider_row]))
                else:
                    blocks.append(node.text)
                continue

            if node.node_type == "table_rows":
                headers = node.metadata.get("headers") or []
                if headers:
                    rows = []
                    for line in node.text.splitlines():
                        values_by_header = {}
                        for part in line.split(", "):
                            if "=" not in part:
                                continue
                            key, value = part.split("=", 1)
                            values_by_header[key] = value
                        rows.append(
                            "| " + " | ".join(values_by_header.get(header, "") for header in headers) + " |"
                        )
                    if rows:
                        blocks.append("\n".join(rows))
                        continue
                blocks.append(node.text)
                continue

            blocks.append(node.text)

        return "\n\n".join(block for block in blocks if block).strip() or parsed.summary

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return all chunks of a document ordered by chunk_index."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def update_document_chunk(
        self,
        user_id: str,
        document_id: str,
        chunk_id: str,
        content: str,
    ) -> DocumentChunk:
        """Replace one chunk's content and mark its document as "indexing".

        Raises:
            ValueError: If the document does not belong to the user, or the
                chunk does not belong to the document.
        """
        document_result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        document = document_result.scalar_one_or_none()
        if not document:
            raise ValueError("文档不存在")

        chunk_result = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.id == chunk_id,
                DocumentChunk.document_id == document_id,
            )
        )
        chunk = chunk_result.scalar_one_or_none()
        if not chunk:
            raise ValueError("切片不存在")

        chunk.content = content
        document.ingestion_status = "indexing"
        document.ingestion_error = None
        await self.db.commit()
        await self.db.refresh(chunk)
        return chunk

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text content of a user's document.

        Prefers the stored normalized_content. Otherwise .txt/.md files are
        read from disk as UTF-8; other types, and files that cannot be read or
        decoded, yield a "[文档] <filename>" placeholder.

        Returns:
            The content or placeholder, or None when the document does not
            exist for this user or its file is missing from disk.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        if doc.normalized_content:
            return doc.normalized_content

        file_path = doc.file_path
        if not file_path or not os.path.exists(file_path):
            return None

        placeholder = f"[文档] {doc.filename}"

        # Only plain-text formats are read back directly from disk.
        ext = doc.filename.split('.')[-1].lower()
        if ext not in ("txt", "md"):
            return placeholder

        try:
            # aiofiles keeps the read from blocking the event loop.
            async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
                return await f.read()
        except (OSError, UnicodeDecodeError):
            return placeholder