Files
JARVIS/backend/app/services/document_service.py

257 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
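Document service: upload, parse, chunk, and store documents.
Supports multiple document formats (PDF, Markdown, plain text, Word).
NOTE: this header originally advertised "LlamaIndex smart chunking", but no
LlamaIndex import or call appears in the code visible in this file; chunking
here is heading- and paragraph-based.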
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import UploadFile
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
import os
import aiofiles
import uuid
# Lower-case file extensions (including the leading dot) that
# DocumentService.upload_document accepts; anything else is rejected with
# ValueError.
# NOTE(review): ".doc" is handed to python-docx by _extract_text together with
# ".docx" — confirm that legacy binary .doc files actually parse there.
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
class DocumentService:
def __init__(self, db: AsyncSession, user_id: str | None = None) -> None:
    """Bind the service to a database session and, optionally, a user.

    Args:
        db: Async SQLAlchemy session used by every query and commit issued
            from this service.
        user_id: ID of the acting user. It is read by ``_get_folder_path``
            to scope the folder lookup; the other methods take the user ID
            as an explicit argument instead.
    """
    self.db = db
    self.user_id = user_id
async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
    """Validate an upload, store it on disk, then persist it with its chunks.

    The file is written to ``settings.UPLOAD_DIR`` under a random UUID name,
    its text is extracted and chunked, and the ``Document`` row plus all
    ``DocumentChunk`` rows are committed in a single transaction. If any
    step after the file has been written fails, the transaction is rolled
    back and the stored file is removed, so no orphan file or chunk-less
    document is left behind.

    Args:
        user_id: Owner of the new document.
        file: The uploaded file.
        folder_id: Optional folder to place the document in.

    Returns:
        The persisted ``Document``, refreshed after the commit.

    Raises:
        ValueError: If the extension is not in ``ALLOWED_EXTENSIONS`` (this
            includes uploads without a filename) or the payload is larger
            than ``settings.MAX_UPLOAD_SIZE``.
    """
    import contextlib

    # UploadFile.filename may be None; treat that as "no extension" so the
    # caller gets the usual ValueError instead of a TypeError.
    filename = file.filename or ""
    title, raw_ext = os.path.splitext(filename)
    ext = raw_ext.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"不支持的文件类型: {ext}")

    content = await file.read()
    file_size = len(content)
    if file_size > settings.MAX_UPLOAD_SIZE:
        raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

    # Only touch the filesystem once the upload has passed validation.
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
    file_id = str(uuid.uuid4())
    file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

    try:
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        text_content = await self._extract_text(file_path, ext)

        doc = Document(
            user_id=user_id,
            title=title,
            filename=filename,
            file_type=ext[1:],
            file_size=file_size,
            file_path=file_path,
            summary=text_content[:500],
            folder_id=folder_id,
        )
        self.db.add(doc)
        # flush() assigns doc.id without ending the transaction, so the
        # document and its chunks are committed (or rolled back) together.
        await self.db.flush()

        chunks = self._chunk_text(text_content)
        self.db.add_all(
            [
                DocumentChunk(document_id=doc.id, chunk_index=i, content=chunk_text)
                for i, chunk_text in enumerate(chunks)
            ]
        )
        doc.chunk_count = len(chunks)
        await self.db.commit()
    except Exception:
        await self.db.rollback()
        # Best-effort cleanup: the file may not have been created yet.
        with contextlib.suppress(OSError):
            os.remove(file_path)
        raise

    # Reload the row so the returned instance is not left expired by the
    # commit (attribute access on an expired instance needs DB I/O).
    await self.db.refresh(doc)
    return doc
async def _get_folder_path(self, folder_id: str) -> str | None:
    """Return the full slash-separated path of a folder.

    All folders belonging to ``self.user_id`` are loaded once, then the
    ``parent_id`` links are followed upwards starting at ``folder_id``.

    Args:
        folder_id: ID of the folder whose path is wanted.

    Returns:
        A path such as ``"/parent/child"``, or ``None`` when ``folder_id``
        is empty or is not one of the user's folders. If an ancestor is
        missing, the path built up to that point is returned.
    """
    result = await self.db.execute(
        select(Folder).where(Folder.user_id == self.user_id)
    )
    folder_map = {folder.id: folder for folder in result.scalars().all()}

    names: list[str] = []  # collected leaf-first, reversed at the end
    visited = set()
    current_id = folder_id
    # The visited set stops the walk if the parent links form a cycle, which
    # would otherwise loop forever.
    while current_id and current_id not in visited:
        folder = folder_map.get(current_id)
        if not folder:
            break
        visited.add(current_id)
        names.append(folder.name)
        current_id = folder.parent_id

    if not names:
        return None
    return "/" + "/".join(reversed(names))
async def delete_document(self, user_id: str, document_id: str) -> None:
    """Delete a user's document row and then its stored file.

    The database row is deleted and committed first; the file is removed
    only afterwards. If the commit fails, the file is therefore still in
    place and the record stays usable.

    Args:
        user_id: Owner the document must belong to.
        document_id: ID of the document to delete.

    Raises:
        ValueError: If no document with this ID exists for this user.
        OSError: If the file exists but cannot be removed. At that point
            the database row has already been deleted.
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        raise ValueError("文档不存在")

    # Read the path before the delete/commit: afterwards the instance is no
    # longer backed by a row.
    file_path = doc.file_path

    # NOTE(review): DocumentChunk rows are not deleted explicitly here;
    # presumably a cascade on the model or foreign key handles them — confirm.
    await self.db.delete(doc)
    await self.db.commit()

    if file_path:
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Already gone — nothing left to clean up.
            pass
async def _extract_text(self, file_path: str, ext: str) -> str:
    """Extract plain text from a stored file according to its extension.

    Args:
        file_path: Path of the file on disk.
        ext: Lower-case extension including the dot, e.g. ``".pdf"``.

    Returns:
        The extracted text. A bracketed placeholder string is returned
        instead when the required parser package is not installed or the
        format is not supported.

    NOTE(review): the PDF and Word parsers are synchronous calls made from
    a coroutine, so they hold the event loop while they run.
    """
    if ext == ".pdf":
        try:
            import pymupdf
        except ImportError:
            return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
        # The context manager closes the document even if get_text() raises.
        with pymupdf.open(file_path) as pdf:
            return "".join(page.get_text() for page in pdf)

    if ext in (".md", ".txt"):
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            return await f.read()

    if ext in (".docx", ".doc"):
        try:
            from docx import Document as DocxDocument
            from docx.opc.exceptions import PackageNotFoundError
        except ImportError:
            return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
        try:
            word_doc = DocxDocument(file_path)
        except PackageNotFoundError:
            # python-docx only reads .docx (zip) packages. A legacy binary
            # .doc lands here, so report it as unsupported rather than
            # failing the caller; a broken .docx still raises.
            if ext == ".doc":
                return "[暂不支持此格式]"
            raise
        return "\n".join(p.text for p in word_doc.paragraphs)

    return "[暂不支持此格式]"
def _chunk_text(self, text: str) -> list[str]:
    """Split document text into chunks for storage.

    Strategy:
    1. If the text contains Markdown headings (``#`` to ``###``), each
       heading starts a section that runs up to the next heading. Sections
       longer than ``settings.CHUNK_SIZE`` are split further with
       ``_split_large_chunk``.
    2. Any text that precedes the first heading is kept as well (it used
       to be dropped); it is split by paragraphs when it is longer than
       ``settings.CHUNK_SIZE``.
    3. Without headings, the whole text is split by paragraphs.

    Args:
        text: Full extracted text of the document.

    Returns:
        The non-empty, stripped chunks in document order. When nothing
        survives filtering, a single chunk holding the first
        ``settings.CHUNK_SIZE`` characters of ``text`` is returned.
    """
    import re

    limit = settings.CHUNK_SIZE
    header_pattern = re.compile(r"^(#{1,3})\s+(.+)$", re.MULTILINE)
    headers = list(header_pattern.finditer(text))

    chunks: list[str] = []
    if headers:
        # Keep whatever comes before the first heading.
        preamble = text[: headers[0].start()].strip()
        if len(preamble) > limit:
            chunks.extend(self._chunk_by_paragraphs(preamble))
        elif preamble:
            chunks.append(preamble)

        for i, match in enumerate(headers):
            start = match.start()
            end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
            section = text[start:end].strip()
            if len(section) > limit:
                # group(2) is the heading text without the leading '#' marks.
                chunks.extend(self._split_large_chunk(section, match.group(2)))
            elif section:
                chunks.append(section)
    else:
        chunks = self._chunk_by_paragraphs(text)

    # Drop chunks that are empty once stripped.
    chunks = [c.strip() for c in chunks if c.strip()]
    return chunks if chunks else [text[:limit]]
def _chunk_by_paragraphs(self, text: str) -> list[str]:
"""按段落分块,带上下文"""
paragraphs = text.split("\n\n")
chunks = []
current = ""
prev_summary = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
if len(current) + len(para) < settings.CHUNK_SIZE:
current += "\n\n" + para
else:
if current:
# 添加上下文摘要
enriched = current.strip()
chunks.append(enriched)
current = para
if current.strip():
chunks.append(current.strip())
return chunks
def _split_large_chunk(self, text: str, title: str) -> list[str]:
"""将大段落拆分为固定大小的子块"""
chunks = []
sentences = text.split("")
current = title + "\n\n"
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
full_sentence = sentence if sentence.endswith("") else sentence + ""
if len(current) + len(full_sentence) < settings.CHUNK_SIZE:
current += full_sentence + " "
else:
if current.strip():
chunks.append(current.strip())
current = title + "\n\n" + full_sentence + " "
if current.strip():
chunks.append(current.strip())
return chunks
async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
    """Return all chunks of a document, ordered by ``chunk_index``.

    Args:
        document_id: ID of the document whose chunks are wanted.

    Returns:
        The matching ``DocumentChunk`` rows in ascending ``chunk_index``
        order; an empty list when there are none.

    NOTE(review): the query filters on ``document_id`` only; no owner check
    is made in this method — presumably the caller verifies access.
    """
    belongs_to_document = DocumentChunk.document_id == document_id
    query = select(DocumentChunk).where(belongs_to_document)
    query = query.order_by(DocumentChunk.chunk_index)
    rows = await self.db.execute(query)
    return [*rows.scalars().all()]
async def get_document_content(self, user_id: str, document_id: str) -> str | None:
    """Return the text content of one of the user's documents.

    Args:
        user_id: Owner the document must belong to.
        document_id: ID of the document to read.

    Returns:
        ``None`` when the document does not exist for this user or its file
        is missing on disk. For ``txt`` and ``md`` files, the file content
        decoded as UTF-8. For every other type, and when reading fails, a
        bracketed placeholder containing the filename.
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        return None

    file_path = doc.file_path
    if not file_path or not os.path.exists(file_path):
        return None

    fallback = f"[文档] {doc.filename}"
    ext = doc.filename.split('.')[-1].lower()

    if ext in ("txt", "md"):
        try:
            # Async read so a large file does not block the event loop.
            async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
                return await f.read()
        except (OSError, ValueError):
            # Unreadable file or undecodable bytes (UnicodeDecodeError is a
            # ValueError): keep the best-effort placeholder.
            return fallback

    if ext == "pdf":
        # Placeholder only; no PDF text extraction is done in this method.
        return f"[PDF文档] {doc.filename}"

    return fallback