257 lines
8.5 KiB
Python
257 lines
8.5 KiB
Python
|
|
"""
|
|||
|
|
文档服务 - 上传、解析、分块、存储
|
|||
|
|
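Supports PDF, Markdown, plain-text and Word files; text is chunked by Markdown headings and paragraphs (rule-based chunking implemented in this module).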
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from fastapi import UploadFile
|
|||
|
|
from app.models.document import Document, DocumentChunk
|
|||
|
|
from app.models.folder import Folder
|
|||
|
|
from app.config import settings
|
|||
|
|
import os
|
|||
|
|
import aiofiles
|
|||
|
|
import uuid
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Lower-case file extensions (leading dot included) accepted by
# DocumentService.upload_document; anything else is rejected with ValueError.
ALLOWED_EXTENSIONS = {
    ".doc",
    ".docx",
    ".md",
    ".pdf",
    ".txt",
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DocumentService:
    """Upload, parse, chunk and persist user documents.

    Works on an injected ``AsyncSession``. The ``user_id`` given to the
    constructor is only consulted by ``_get_folder_path``; the public methods
    take the acting user's id as an explicit argument.
    """

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id

    async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
        """Validate, store, parse and chunk an uploaded file.

        The file is saved under ``settings.UPLOAD_DIR`` with a random UUID
        name. Its text is extracted, and a ``Document`` row plus one
        ``DocumentChunk`` row per chunk are written in a single transaction.

        Args:
            user_id: owner of the new document.
            file: the uploaded file; its extension must be in ``ALLOWED_EXTENSIONS``.
            folder_id: optional folder to place the document in.

        Returns:
            The persisted ``Document``, refreshed after commit.

        Raises:
            ValueError: unsupported extension, or file larger than
                ``settings.MAX_UPLOAD_SIZE`` bytes.
        """
        # UploadFile.filename can be None when the client sends no name.
        filename = file.filename or ""
        stem, ext = os.path.splitext(filename)
        ext = ext.lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        # Read at most one byte past the limit, so an oversized upload is
        # rejected without buffering all of it in memory.
        content = await file.read(settings.MAX_UPLOAD_SIZE + 1)
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        # The on-disk name is a random UUID, so the client-supplied filename
        # never becomes part of the filesystem path.
        file_path = os.path.join(settings.UPLOAD_DIR, f"{uuid.uuid4()}{ext}")
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        # NOTE(review): folder_id is not checked against user_id here.
        # Confirm that the caller validates folder ownership.
        try:
            text_content = await self._extract_text(file_path, ext)
            doc = Document(
                user_id=user_id,
                title=stem,
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=text_content[:500],
                folder_id=folder_id,
            )
            self.db.add(doc)
            # flush (not commit) populates doc.id while the document and its
            # chunks still commit atomically below.
            await self.db.flush()

            chunks = self._chunk_text(text_content)
            self.db.add_all(
                [
                    DocumentChunk(document_id=doc.id, chunk_index=i, content=chunk_text)
                    for i, chunk_text in enumerate(chunks)
                ]
            )
            doc.chunk_count = len(chunks)
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            # Best effort: don't leave an orphaned upload on disk when parsing
            # or persisting fails. The original error is re-raised either way.
            try:
                os.remove(file_path)
            except OSError:
                pass
            raise

        # Reload attributes that the commit may have expired, so callers can
        # read them without triggering a lazy load.
        await self.db.refresh(doc)
        return doc

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the "/a/b/c" path of ``folder_id`` within ``self.user_id``'s folders.

        Walks ``parent_id`` links upwards. Returns ``None`` when the folder is
        not among the user's folders. Stops at the first repeated folder, so a
        corrupted parent cycle cannot loop forever.
        """
        result = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in result.scalars().all()}

        names: list[str] = []
        visited: set = set()
        current_id = folder_id
        while current_id and current_id not in visited:
            folder = folder_map.get(current_id)
            if folder is None:
                break
            visited.add(current_id)
            names.append(folder.name)
            current_id = folder.parent_id

        if not names:
            return None
        # Names were collected leaf-first.
        names.reverse()
        return "/" + "/".join(names)

    async def delete_document(self, user_id: str, document_id: str):
        """Delete one of ``user_id``'s documents and then its stored file.

        Raises:
            ValueError: if no document with ``document_id`` belongs to ``user_id``.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        # Read the path before commit: the deleted instance may be expired afterwards.
        file_path = doc.file_path
        # NOTE(review): chunk rows go away only if the Document->DocumentChunk
        # relationship or FK cascades. Confirm in app.models.document.
        await self.db.delete(doc)
        await self.db.commit()

        # Remove the file only after the row is gone. A failed commit then
        # never leaves a record pointing at a deleted file.
        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass
            except OSError:
                import logging

                logging.getLogger(__name__).warning(
                    "Could not remove file %s of deleted document %s",
                    file_path,
                    document_id,
                    exc_info=True,
                )

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Return the plain text of a stored file, dispatching on ``ext``.

        ``ext`` is the lower-cased extension including the leading dot. A
        placeholder string is returned when an optional parser library is
        missing or the format cannot be read.
        """
        if ext == ".pdf":
            try:
                import pymupdf
            except ImportError:
                return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
            # NOTE: parsing is synchronous and runs on the event-loop thread.
            # The context manager closes the document even if get_text() raises.
            with pymupdf.open(file_path) as pdf:
                return "".join(page.get_text() for page in pdf)

        if ext in (".md", ".txt"):
            async with aiofiles.open(file_path, "rb") as f:
                raw = await f.read()
            return self._decode_text(raw)

        if ext in (".docx", ".doc"):
            try:
                from docx import Document as DocxDocument
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
            try:
                word_doc = DocxDocument(file_path)
            except Exception:
                # python-docx only reads OOXML (.docx) packages. A legacy
                # binary .doc fails to open, so report it as unsupported
                # instead of failing the upload.
                if ext == ".doc":
                    return "[暂不支持此格式]"
                raise
            return "\n".join(p.text for p in word_doc.paragraphs)

        return "[暂不支持此格式]"

    @staticmethod
    def _decode_text(raw: bytes) -> str:
        """Decode a text file's bytes and normalize line endings to LF.

        Tries UTF-8, stripping a BOM if present, then GB18030 for legacy
        Chinese files. Falls back to UTF-8 with replacement characters, so
        decoding never raises.
        """
        for encoding in ("utf-8-sig", "gb18030"):
            try:
                text = raw.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
        else:
            text = raw.decode("utf-8", errors="replace")
        # Same normalization a text-mode open() applies; chunking splits on "\n".
        return text.replace("\r\n", "\n").replace("\r", "\n")

    def _chunk_text(self, text: str) -> list[str]:
        """Split document text into chunks of at most ``settings.CHUNK_SIZE`` characters.

        1. If the text has Markdown headings (``#`` to ``###``) outside
           fenced code blocks, it is split at each heading:
           - text before the first heading is chunked by paragraphs;
           - sections longer than the limit are split further by
             ``_split_large_chunk``.
        2. Otherwise, the text is split into blank-line separated paragraphs.

        Returns stripped, non-empty chunks; an empty list for blank text.
        """
        import re

        limit = settings.CHUNK_SIZE
        # [ \t]+ rather than \s+, so a lone "#" line cannot swallow the next line.
        header_pattern = re.compile(r"^(#{1,3})[ \t]+(.+)$", re.MULTILINE)
        fence_spans = [
            m.span()
            for m in re.finditer(r"^```.*?^```", text, re.MULTILINE | re.DOTALL)
        ]
        headers = [
            m
            for m in header_pattern.finditer(text)
            # "# ..." inside a ``` code fence is code (e.g. a comment), not a heading.
            if not any(start <= m.start() < end for start, end in fence_spans)
        ]

        if not headers:
            chunks = self._chunk_by_paragraphs(text)
        else:
            chunks = []
            # Keep any introduction that precedes the first heading.
            preamble = text[: headers[0].start()]
            if preamble.strip():
                chunks.extend(self._chunk_by_paragraphs(preamble))
            for i, match in enumerate(headers):
                end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
                section = text[match.start():end].strip()
                if len(section) > limit:
                    chunks.extend(self._split_large_chunk(section, match.group(2)))
                elif section:
                    chunks.append(section)

        return [c.strip() for c in chunks if c.strip()]

    def _chunk_by_paragraphs(self, text: str) -> list[str]:
        """Pack blank-line separated paragraphs into chunks of at most ``settings.CHUNK_SIZE`` characters.

        Consecutive paragraphs are joined with a blank line while they fit. A
        paragraph that is itself too long is broken at line boundaries first.
        An example is extracted text with only single newlines. A single line
        is hard-cut only if it still does not fit.
        """
        limit = settings.CHUNK_SIZE
        pieces: list[str] = []
        for para in text.split("\n\n"):
            para = para.strip()
            if not para:
                continue
            if len(para) > limit:
                lines = [line for line in para.split("\n") if line.strip()]
                pieces.extend(self._pack(lines, limit, "\n"))
            else:
                pieces.append(para)
        return [chunk.strip() for chunk in self._pack(pieces, limit, "\n\n") if chunk.strip()]

    def _split_large_chunk(self, text: str, title: str) -> list[str]:
        """Split an oversized heading section into chunks prefixed with its title.

        How the section is split:
        - It is cut after sentence terminators (。！？) and after newlines,
          and the pieces are packed greedily.
        - Every chunk starts with ``title`` and a blank line, so it keeps its
          context on its own.
        - The section's own heading line is dropped to avoid repeating the
          title.
        - If the title would take more than half of ``settings.CHUNK_SIZE``,
          no prefix is added.

        Each resulting chunk is at most ``settings.CHUNK_SIZE`` characters.
        """
        import re

        limit = settings.CHUNK_SIZE
        heading = title.strip()
        prefix = f"{heading}\n\n" if heading else ""
        budget = limit - len(prefix)
        if budget < limit // 2:
            prefix, budget = "", limit

        body = text
        if prefix:
            first_line, _, rest = text.partition("\n")
            if first_line.lstrip("#").strip() == heading and rest.strip():
                body = rest

        # Zero-width split keeps every terminator attached to its piece,
        # so no text is added or lost.
        pieces = re.split(r"(?<=[。！？\n])", body)
        return [prefix + chunk.strip() for chunk in self._pack(pieces, budget, "") if chunk.strip()]

    @staticmethod
    def _pack(pieces: list[str], limit: int, sep: str) -> list[str]:
        """Greedily join ``pieces`` with ``sep`` into strings of at most ``limit`` characters.

        A piece longer than ``limit`` is cut into consecutive ``limit``-sized
        slices. A non-positive ``limit`` is treated as 1.
        """
        limit = max(limit, 1)
        packed: list[str] = []
        current = ""
        for piece in pieces:
            candidate = f"{current}{sep}{piece}" if current else piece
            if len(candidate) <= limit:
                current = candidate
                continue
            if current:
                packed.append(current)
            while len(piece) > limit:
                packed.append(piece[:limit])
                piece = piece[limit:]
            current = piece
        if current:
            packed.append(current)
        return packed

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return all chunks of ``document_id`` ordered by ``chunk_index``.

        NOTE(review): there is no ownership check here. Callers must make sure
        that ``document_id`` belongs to the acting user.
        """
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text content of one of ``user_id``'s documents.

        Returns ``None`` if the document does not exist for this user or its
        stored file is missing. Text is extracted with ``_extract_text``, the
        same parser used at upload time, so PDF and Word content is returned
        too. If the extension is not recognised or extraction fails, a
        placeholder naming the file is returned instead.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        file_path = doc.file_path
        if not file_path or not os.path.exists(file_path):
            return None

        ext = os.path.splitext(doc.filename)[1].lower()
        placeholder = f"[文档] {doc.filename}"
        if ext not in ALLOWED_EXTENSIONS:
            return placeholder
        try:
            return await self._extract_text(file_path, ext)
        except Exception:
            # Content retrieval is best-effort; fall back to naming the file.
            return placeholder