257 lines
8.5 KiB
Python
257 lines
8.5 KiB
Python
"""
|
||
文档服务 - 上传、解析、分块、存储
|
||
支持多种文档格式 + LlamaIndex 智能分块
|
||
"""
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from fastapi import UploadFile
|
||
from app.models.document import Document, DocumentChunk
|
||
from app.models.folder import Folder
|
||
from app.config import settings
|
||
import os
|
||
import aiofiles
|
||
import uuid
|
||
|
||
|
||
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
|
||
|
||
|
||
class DocumentService:
    """Upload, parse, chunk and persist user documents.

    Wraps an ``AsyncSession``; the methods that write (``upload_document``,
    ``delete_document``) commit their own transaction.
    """

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        # Only ``_get_folder_path`` reads this; other methods take user_id explicitly.
        self.user_id = user_id

    async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
        """Store an uploaded file, extract its text and persist it with its chunks.

        Args:
            user_id: Owner of the new document.
            file: The uploaded file; its extension must be in ``ALLOWED_EXTENSIONS``.
            folder_id: Optional folder the document is placed in.

        Returns:
            The persisted ``Document``, refreshed after the commit.

        Raises:
            ValueError: Missing/unsupported extension, file larger than
                ``settings.MAX_UPLOAD_SIZE``, or content that cannot be parsed.
        """
        filename = file.filename or ""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        # Read at most one byte past the limit: enough to detect an oversize
        # upload without buffering an arbitrarily large body in memory.
        content = await file.read(settings.MAX_UPLOAD_SIZE + 1)
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        file_path = os.path.join(settings.UPLOAD_DIR, f"{uuid.uuid4()}{ext}")
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        try:
            text_content = await self._extract_text(file_path, ext)
            chunks = self._chunk_text(text_content)

            doc = Document(
                user_id=user_id,
                title=filename.rsplit(".", 1)[0],
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=text_content[:500],
                folder_id=folder_id,
            )
            self.db.add(doc)
            # Flush (not commit) so doc.id is assigned before the chunks refer
            # to it; the document and its chunks are then committed together.
            await self.db.flush()
            self.db.add_all(
                [
                    DocumentChunk(document_id=doc.id, chunk_index=index, content=chunk)
                    for index, chunk in enumerate(chunks)
                ]
            )
            doc.chunk_count = len(chunks)
            await self.db.commit()
        except Exception:
            # Nothing references the stored file any more; don't leave it behind.
            self._discard_file(file_path)
            await self.db.rollback()
            raise

        # Reload attributes so callers can read them without a lazy load.
        await self.db.refresh(doc)
        return doc

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the slash-separated path of ``folder_id`` (e.g. ``"/a/b"``), or None.

        Follows ``parent_id`` links through the folders owned by
        ``self.user_id``. The walk stops at a missing ancestor, and also stops
        if the links form a cycle, instead of looping forever.
        """
        result = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in result.scalars().all()}

        names: list[str] = []
        visited: set = set()
        current_id = folder_id
        while current_id and current_id not in visited:
            folder = folder_map.get(current_id)
            if not folder:
                break
            visited.add(current_id)
            names.append(folder.name)
            current_id = folder.parent_id

        if not names:
            return None
        # Names were collected leaf-first; the path reads root-first.
        return "/" + "/".join(reversed(names))

    async def delete_document(self, user_id: str, document_id: str):
        """Delete one of ``user_id``'s documents and then its stored file.

        Raises:
            ValueError: The document does not exist for this user.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        # Capture before commit: the instance is deleted/expired afterwards.
        file_path = doc.file_path
        await self.db.delete(doc)
        await self.db.commit()

        # Remove the file only once the row is gone, so a failed commit never
        # leaves a record that points at a deleted file.
        if file_path:
            self._discard_file(file_path)

    @staticmethod
    def _discard_file(path: str) -> None:
        """Delete ``path``; a file that is already gone is not an error."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Extract plain text from the stored file at ``file_path``.

        Args:
            file_path: Path of the stored file.
            ext: Lower-case extension including the dot (e.g. ``".pdf"``).

        Returns:
            The extracted text, or a bracketed placeholder when the optional
            parser library is not installed or the extension is not handled.

        Raises:
            ValueError: The file could not be parsed, or a text file is not
                valid UTF-8 (``UnicodeDecodeError`` is a ``ValueError``).
        """
        import asyncio

        if ext in (".md", ".txt"):
            # utf-8-sig strips a leading BOM, which would otherwise hide a
            # heading on the first line from the "^#" pattern in _chunk_text.
            async with aiofiles.open(file_path, "r", encoding="utf-8-sig") as f:
                return await f.read()

        if ext == ".pdf":
            reader = self._read_pdf
            missing_lib = "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
        elif ext in (".docx", ".doc"):
            # python-docx reads only the OOXML .docx format; a legacy binary
            # .doc fails to open and is reported as a ValueError below.
            reader = self._read_docx
            missing_lib = "[Word 内容需要安装 python-docx: uv pip install python-docx]"
        else:
            return "[暂不支持此格式]"

        try:
            # The parsers are blocking; keep them off the event loop.
            return await asyncio.to_thread(reader, file_path)
        except ImportError:
            return missing_lib
        except Exception as err:
            raise ValueError(f"文档解析失败: {ext}") from err

    @staticmethod
    def _read_pdf(file_path: str) -> str:
        """Blocking helper: concatenate the text of every page of a PDF."""
        import pymupdf

        # The context manager closes the document even if a page fails.
        with pymupdf.open(file_path) as pdf:
            return "".join(page.get_text() for page in pdf)

    @staticmethod
    def _read_docx(file_path: str) -> str:
        """Blocking helper: join a Word document's paragraph texts with newlines."""
        from docx import Document as DocxDocument

        return "\n".join(p.text for p in DocxDocument(file_path).paragraphs)

    def _chunk_text(self, text: str) -> list[str]:
        """Split ``text`` into chunks of at most ``settings.CHUNK_SIZE`` characters.

        1. When the text has Markdown headings (H1-H3), each heading starts a
           section. Any text before the first heading is chunked by paragraphs
           rather than dropped.
        2. Sections longer than the limit are split at sentence boundaries,
           with the heading text repeated on continuation chunks.
        3. Text without headings is packed paragraph by paragraph.

        Returns an empty list when ``text`` has no non-whitespace content.
        """
        import re

        # [ \t]+ (not \s+) so a bare "#" line cannot swallow the next line.
        header_pattern = re.compile(r"^(#{1,3})[ \t]+(.+)$", re.MULTILINE)
        headers = list(header_pattern.finditer(text))

        if not headers:
            chunks = self._chunk_by_paragraphs(text)
        else:
            chunks = self._chunk_by_paragraphs(text[: headers[0].start()])
            for i, match in enumerate(headers):
                end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
                section = text[match.start():end].strip()
                if len(section) > settings.CHUNK_SIZE:
                    chunks.extend(self._split_large_chunk(section, match.group(2).strip()))
                elif section:
                    chunks.append(section)

        return [c.strip() for c in chunks if c.strip()]

    def _chunk_by_paragraphs(self, text: str) -> list[str]:
        """Pack blank-line-separated paragraphs into chunks of at most ``settings.CHUNK_SIZE`` characters.

        Consecutive paragraphs are joined with a blank line. A single paragraph
        longer than the limit is split at sentence boundaries instead of being
        emitted as one oversized chunk.
        """
        limit = settings.CHUNK_SIZE
        chunks: list[str] = []
        current = ""

        for para in (p.strip() for p in text.split("\n\n")):
            if not para:
                continue
            if len(para) > limit:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(self._split_large_chunk(para, ""))
                continue
            candidate = f"{current}\n\n{para}" if current else para
            if len(candidate) <= limit:
                current = candidate
            else:
                chunks.append(current)
                current = para

        if current:
            chunks.append(current)
        return chunks

    def _split_large_chunk(self, text: str, title: str) -> list[str]:
        """Split ``text`` into chunks of at most ``settings.CHUNK_SIZE`` characters.

        Prefers to break after Chinese sentence punctuation, after a newline,
        or after ASCII ``. ! ? ;`` followed by whitespace. A single sentence
        longer than the budget is hard-wrapped. The text itself is kept
        verbatim; no punctuation is added.

        Every chunk after the first is prefixed with ``title`` (when non-empty)
        to keep section context. The first chunk gets no prefix because the
        caller's section already starts with its heading line.
        """
        import re

        limit = max(1, settings.CHUNK_SIZE)
        prefix = f"{title}\n\n" if title else ""
        if len(prefix) > limit // 2:
            # An oversized title would crowd out the content; drop it.
            prefix = ""
        body_limit = limit - len(prefix)

        # Zero-width split points, so joining the pieces restores ``text``.
        pieces = re.split(r"(?<=[。！？；\n])|(?<=[.!?;])(?=\s)", text)

        bodies: list[str] = []
        current = ""
        for piece in pieces:
            if current and len(current) + len(piece) > body_limit and len(piece) <= body_limit:
                # The piece fits on its own: start a new chunk at this boundary.
                bodies.append(current)
                current = piece
            else:
                # Either it fits, or it is oversized and must be hard-wrapped.
                current += piece
                while len(current) > body_limit:
                    bodies.append(current[:body_limit])
                    current = current[body_limit:]
        if current:
            bodies.append(current)

        bodies = [b.strip() for b in bodies if b.strip()]
        return [body if i == 0 else prefix + body for i, body in enumerate(bodies)]

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return the chunks of ``document_id`` ordered by ``chunk_index``.

        NOTE: no ownership check is applied here; callers must verify that the
        requesting user owns the document.
        """
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text content of one of ``user_id``'s documents.

        Returns None when the document does not exist for this user or its
        stored file is missing. Text is extracted with the same parsers used at
        upload time. When extraction fails, or the extension is not supported,
        a ``"[文档] <filename>"`` placeholder is returned.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        file_path = doc.file_path
        if not file_path or not os.path.exists(file_path):
            return None

        fallback = f"[文档] {doc.filename}"
        ext = os.path.splitext(doc.filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return fallback

        try:
            return await self._extract_text(file_path, ext)
        except (OSError, ValueError):
            return fallback