Add MinerU document ingestion support
Normalize uploaded documents into structured markdown, add clearer parser errors for missing dependencies, and cover the ingestion flow with backend tests. This also replaces deprecated UTC timestamp helpers in the touched backend paths so the knowledge pipeline stays warning-free. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.config import settings
|
||||
@@ -33,3 +34,62 @@ async def get_db() -> AsyncSession:
|
||||
async def init_db():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await ensure_log_columns(conn)
|
||||
await ensure_message_columns(conn)
|
||||
await ensure_document_columns(conn)
|
||||
|
||||
|
||||
async def ensure_log_columns(conn):
|
||||
result = await conn.execute(text("PRAGMA table_info(logs)"))
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
"request_id": "ALTER TABLE logs ADD COLUMN request_id VARCHAR(64)",
|
||||
"route": "ALTER TABLE logs ADD COLUMN route VARCHAR(255)",
|
||||
"method": "ALTER TABLE logs ADD COLUMN method VARCHAR(16)",
|
||||
"status_code": "ALTER TABLE logs ADD COLUMN status_code INTEGER",
|
||||
"error_type": "ALTER TABLE logs ADD COLUMN error_type VARCHAR(100)",
|
||||
"operation": "ALTER TABLE logs ADD COLUMN operation VARCHAR(100)",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_message_columns(conn):
|
||||
result = await conn.execute(text("PRAGMA table_info(messages)"))
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
"attachments": "ALTER TABLE messages ADD COLUMN attachments JSON",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_document_columns(conn):
|
||||
result = await conn.execute(text("PRAGMA table_info(documents)"))
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
"ingestion_status": "ALTER TABLE documents ADD COLUMN ingestion_status VARCHAR(50) DEFAULT 'uploaded' NOT NULL",
|
||||
"ingestion_error": "ALTER TABLE documents ADD COLUMN ingestion_error TEXT",
|
||||
"indexed_at": "ALTER TABLE documents ADD COLUMN indexed_at DATETIME",
|
||||
"parser_version": "ALTER TABLE documents ADD COLUMN parser_version VARCHAR(50)",
|
||||
"index_version": "ALTER TABLE documents ADD COLUMN index_version VARCHAR(50)",
|
||||
"normalized_content": "ALTER TABLE documents ADD COLUMN normalized_content TEXT",
|
||||
"normalized_format": "ALTER TABLE documents ADD COLUMN normalized_format VARCHAR(50)",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from sqlalchemy import Column, String, DateTime
|
||||
from app.database import Base
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
created_at = Column(DateTime, default=utc_now, nullable=False)
|
||||
updated_at = Column(DateTime, default=utc_now, onupdate=utc_now, nullable=False)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Integer, Text, ForeignKey, Boolean
|
||||
from sqlalchemy import Column, String, Integer, Text, ForeignKey, Boolean, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -16,6 +16,13 @@ class Document(BaseModel):
|
||||
summary = Column(Text, nullable=True)
|
||||
chunk_count = Column(Integer, default=0)
|
||||
is_indexed = Column(Boolean, default=False)
|
||||
ingestion_status = Column(String(50), default="uploaded", nullable=False)
|
||||
ingestion_error = Column(Text, nullable=True)
|
||||
indexed_at = Column(DateTime, nullable=True)
|
||||
parser_version = Column(String(50), nullable=True)
|
||||
index_version = Column(String(50), nullable=True)
|
||||
normalized_content = Column(Text, nullable=True)
|
||||
normalized_format = Column(String(50), nullable=True)
|
||||
|
||||
chunks = relationship("DocumentChunk", back_populates="document", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, ForeignKey, Boolean, DateTime, Enum as SQLEnum
|
||||
from datetime import datetime
|
||||
from app.models.base import BaseModel
|
||||
from app.models.base import BaseModel, utc_now
|
||||
|
||||
|
||||
class MemorySummary(BaseModel):
|
||||
@@ -14,7 +13,7 @@ class MemorySummary(BaseModel):
|
||||
conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
|
||||
summary_text = Column(Text, nullable=False) # 摘要内容
|
||||
turn_count = Column(Integer, default=0) # 摘要时累计轮数
|
||||
summary_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
summary_at = Column(DateTime, default=utc_now, nullable=False)
|
||||
|
||||
|
||||
class UserMemory(BaseModel):
|
||||
@@ -31,5 +30,5 @@ class UserMemory(BaseModel):
|
||||
is_recalled = Column(Boolean, default=False) # 是否在当前对话中被召回
|
||||
recall_count = Column(Integer, default=0) # 被召回次数
|
||||
source_conversation_id = Column(String(36), nullable=True) # 来源对话
|
||||
extracted_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
extracted_at = Column(DateTime, default=utc_now, nullable=False)
|
||||
last_recalled_at = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -8,12 +8,13 @@ from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.document_service import DocumentService
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
from app.schemas.document import DocumentChunkOut, DocumentChunkUpdate, DocumentOut
|
||||
from dataclasses import asdict
|
||||
|
||||
router = APIRouter(prefix="/api/documents", tags=["知识库"])
|
||||
|
||||
|
||||
@router.get("", response_model=list)
|
||||
@router.get("", response_model=list[DocumentOut])
|
||||
async def list_documents(
|
||||
folder_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -36,7 +37,10 @@ async def upload_document(
|
||||
):
|
||||
"""上传文档,自动分块并向量化"""
|
||||
doc_svc = DocumentService(db)
|
||||
doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)
|
||||
try:
|
||||
doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)
|
||||
except ValueError as error:
|
||||
raise HTTPException(status_code=400, detail=str(error)) from error
|
||||
|
||||
# 后台索引到 ChromaDB
|
||||
def index_task():
|
||||
@@ -73,7 +77,7 @@ async def get_document(
|
||||
return doc
|
||||
|
||||
|
||||
@router.get("/{document_id}/chunks")
|
||||
@router.get("/{document_id}/chunks", response_model=list[DocumentChunkOut])
|
||||
async def get_document_chunks(
|
||||
document_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -98,6 +102,33 @@ async def get_document_chunks(
|
||||
return chunks_result.scalars().all()
|
||||
|
||||
|
||||
@router.put("/{document_id}/chunks/{chunk_id}", response_model=DocumentChunkOut)
|
||||
async def update_document_chunk(
|
||||
document_id: str,
|
||||
chunk_id: str,
|
||||
payload: DocumentChunkUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
doc_svc = DocumentService(db)
|
||||
kb_svc = KnowledgeService(db, user_id=current_user.id)
|
||||
|
||||
try:
|
||||
chunk = await doc_svc.update_document_chunk(current_user.id, document_id, chunk_id, payload.content)
|
||||
except ValueError as error:
|
||||
raise HTTPException(status_code=404, detail=str(error)) from error
|
||||
|
||||
reindexed = await kb_svc.reindex_document_chunks(document_id, current_user.id)
|
||||
if not reindexed:
|
||||
raise HTTPException(status_code=500, detail="切片更新后重新索引失败")
|
||||
|
||||
refreshed_chunk_result = await db.execute(
|
||||
select(DocumentChunk).where(DocumentChunk.id == chunk.id)
|
||||
)
|
||||
refreshed_chunk = refreshed_chunk_result.scalar_one()
|
||||
return refreshed_chunk
|
||||
|
||||
|
||||
@router.delete("/{document_id}", status_code=204)
|
||||
async def delete_document(
|
||||
document_id: str,
|
||||
@@ -129,7 +160,7 @@ async def search_documents(
|
||||
if mode == "keyword":
|
||||
results = await kb_svc._keyword_search(query, current_user.id, top_k)
|
||||
elif mode == "semantic":
|
||||
results = await kb_svc.retrieve(query, current_user.id, top_k, use_rerank=True)
|
||||
results = await kb_svc.retrieve(query, current_user.id, top_k=top_k, use_rerank=True)
|
||||
else:
|
||||
results = await kb_svc.hybrid_search(query, current_user.id, top_k)
|
||||
|
||||
|
||||
@@ -64,8 +64,8 @@ async def update_task(
|
||||
if field == "tags":
|
||||
setattr(task, field, json.dumps(value))
|
||||
elif field == "status" and value == TaskStatus.DONE:
|
||||
from datetime import datetime
|
||||
task.completed_at = datetime.utcnow()
|
||||
from datetime import UTC, datetime
|
||||
task.completed_at = datetime.now(UTC)
|
||||
setattr(task, field, value)
|
||||
else:
|
||||
setattr(task, field, value)
|
||||
|
||||
@@ -81,9 +81,9 @@ async def update_todo(
|
||||
if data.title is not None:
|
||||
todo.title = data.title
|
||||
if data.is_completed is not None:
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
todo.is_completed = data.is_completed
|
||||
todo.completed_at = datetime.utcnow() if data.is_completed else None
|
||||
todo.completed_at = datetime.now(UTC) if data.is_completed else None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(todo)
|
||||
|
||||
@@ -11,6 +11,13 @@ class DocumentOut(BaseModel):
|
||||
summary: str | None
|
||||
chunk_count: int
|
||||
is_indexed: bool
|
||||
ingestion_status: str
|
||||
ingestion_error: str | None
|
||||
indexed_at: datetime | None
|
||||
parser_version: str | None
|
||||
index_version: str | None
|
||||
normalized_format: str | None
|
||||
folder_id: str | None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -25,6 +32,10 @@ class DocumentChunkOut(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DocumentChunkUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
top_k: int = 5
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from jose import jwt, JWTError
|
||||
from app.config import settings
|
||||
@@ -16,7 +16,7 @@ def get_password_hash(password: str) -> str:
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
expire = datetime.now(UTC) + (expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
to_encode.update({"exp": expire})
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
@@ -9,12 +9,35 @@ from fastapi import UploadFile
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
from app.services.brain_service import BrainService
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import aiofiles
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
|
||||
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc", ".csv", ".xlsx"}
|
||||
PARSER_VERSION = "v2"
|
||||
INDEX_VERSION = "v2"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedNode:
|
||||
node_type: str
|
||||
text: str
|
||||
metadata: dict = field(default_factory=dict)
|
||||
section_path: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedDocument:
|
||||
summary: str
|
||||
nodes: list[ParsedNode]
|
||||
structured_markdown: str = ""
|
||||
|
||||
|
||||
class DocumentService:
|
||||
@@ -39,7 +62,8 @@ class DocumentService:
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
text_content = await self._extract_text(file_path, ext)
|
||||
parsed = await self._parse_document(file_path, ext)
|
||||
parsed.structured_markdown = self._render_structured_markdown(parsed)
|
||||
|
||||
doc = Document(
|
||||
user_id=user_id,
|
||||
@@ -48,26 +72,85 @@ class DocumentService:
|
||||
file_type=ext[1:],
|
||||
file_size=file_size,
|
||||
file_path=file_path,
|
||||
summary=text_content[:500] if len(text_content) > 500 else text_content,
|
||||
summary=parsed.summary[:500] if len(parsed.summary) > 500 else parsed.summary,
|
||||
folder_id=folder_id,
|
||||
ingestion_status="uploaded",
|
||||
ingestion_error=None,
|
||||
parser_version=PARSER_VERSION,
|
||||
index_version=INDEX_VERSION,
|
||||
normalized_content=parsed.structured_markdown,
|
||||
normalized_format="structured_markdown",
|
||||
)
|
||||
self.db.add(doc)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(doc)
|
||||
await self.db.flush()
|
||||
|
||||
chunks = self._chunk_text(text_content)
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
chunks = self._build_chunks(parsed)
|
||||
for i, chunk_data in enumerate(chunks):
|
||||
chunk = DocumentChunk(
|
||||
document_id=doc.id,
|
||||
chunk_index=i,
|
||||
content=chunk_text,
|
||||
content=chunk_data["content"],
|
||||
metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
|
||||
)
|
||||
self.db.add(chunk)
|
||||
doc.chunk_count = len(chunks)
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="document",
|
||||
source_id=doc.id,
|
||||
event_type="document_uploaded",
|
||||
title=doc.filename,
|
||||
content_summary=doc.summary,
|
||||
raw_excerpt=(doc.normalized_content or "")[:1000] or None,
|
||||
metadata_={
|
||||
"document_id": doc.id,
|
||||
"file_type": doc.file_type,
|
||||
"ingestion_status": doc.ingestion_status,
|
||||
},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(doc)
|
||||
|
||||
return doc
|
||||
|
||||
async def rebuild_document(self, document: Document) -> Document:
|
||||
ext = os.path.splitext(document.filename)[1].lower()
|
||||
parsed = await self._parse_document(document.file_path, ext)
|
||||
parsed.structured_markdown = self._render_structured_markdown(parsed)
|
||||
|
||||
chunk_result = await self.db.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.document_id == document.id)
|
||||
.order_by(DocumentChunk.chunk_index)
|
||||
)
|
||||
existing_chunks = list(chunk_result.scalars().all())
|
||||
for chunk in existing_chunks:
|
||||
await self.db.delete(chunk)
|
||||
await self.db.flush()
|
||||
|
||||
chunks = self._build_chunks(parsed)
|
||||
for i, chunk_data in enumerate(chunks):
|
||||
self.db.add(DocumentChunk(
|
||||
document_id=document.id,
|
||||
chunk_index=i,
|
||||
content=chunk_data["content"],
|
||||
metadata_=json.dumps(chunk_data["metadata"], ensure_ascii=False),
|
||||
))
|
||||
|
||||
document.summary = parsed.summary[:500] if len(parsed.summary) > 500 else parsed.summary
|
||||
document.chunk_count = len(chunks)
|
||||
document.ingestion_status = "indexing"
|
||||
document.ingestion_error = None
|
||||
document.parser_version = PARSER_VERSION
|
||||
document.index_version = INDEX_VERSION
|
||||
document.normalized_content = parsed.structured_markdown
|
||||
document.normalized_format = "structured_markdown"
|
||||
await self.db.commit()
|
||||
await self.db.refresh(document)
|
||||
return document
|
||||
|
||||
async def _get_folder_path(self, folder_id: str) -> str | None:
|
||||
"""获取文件夹的完整路径"""
|
||||
folders = await self.db.execute(
|
||||
@@ -104,112 +187,313 @@ class DocumentService:
|
||||
await self.db.commit()
|
||||
|
||||
async def _extract_text(self, file_path: str, ext: str) -> str:
|
||||
if ext == ".pdf":
|
||||
try:
|
||||
import pymupdf
|
||||
doc = pymupdf.open(file_path)
|
||||
text = "".join(page.get_text() for page in doc)
|
||||
doc.close()
|
||||
return text
|
||||
except ImportError:
|
||||
return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
|
||||
|
||||
elif ext in (".md", ".txt"):
|
||||
if ext in (".md", ".txt"):
|
||||
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||||
return await f.read()
|
||||
|
||||
elif ext in (".docx", ".doc"):
|
||||
if ext in (".docx", ".doc"):
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
doc = DocxDocument(file_path)
|
||||
return "\n".join([p.text for p in doc.paragraphs])
|
||||
parts = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
row_values = [cell.text.strip() for cell in row.cells]
|
||||
if any(row_values):
|
||||
parts.append(" | ".join(row_values))
|
||||
return "\n".join(parts)
|
||||
except ImportError:
|
||||
return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
|
||||
|
||||
return "[暂不支持此格式]"
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""
|
||||
智能文档分块策略
|
||||
1. 先按 Markdown 标题层级(H1/H2/H3)切分
|
||||
2. 每个大段落内部按固定长度切分
|
||||
3. 保留上下文(prev_summary / next_summary)
|
||||
"""
|
||||
import re
|
||||
async def _parse_document(self, file_path: str, ext: str) -> ParsedDocument:
|
||||
if ext == ".csv":
|
||||
return await self._parse_csv(file_path)
|
||||
if ext == ".xlsx":
|
||||
return await self._parse_xlsx(file_path)
|
||||
if ext == ".md":
|
||||
content = await self._extract_text(file_path, ext)
|
||||
return self._parse_markdown(content)
|
||||
if ext == ".txt":
|
||||
content = await self._extract_text(file_path, ext)
|
||||
return self._parse_text(content)
|
||||
if ext == ".docx":
|
||||
return await self._parse_docx(file_path)
|
||||
if ext == ".doc":
|
||||
content = await self._extract_text(file_path, ext)
|
||||
return self._parse_text(content)
|
||||
if ext == ".pdf":
|
||||
return await self._parse_pdf(file_path)
|
||||
content = await self._extract_text(file_path, ext)
|
||||
return self._parse_text(content)
|
||||
|
||||
chunks = []
|
||||
async def _parse_csv(self, file_path: str) -> ParsedDocument:
|
||||
async with aiofiles.open(file_path, "r", encoding="utf-8-sig") as f:
|
||||
content = await f.read()
|
||||
reader = list(csv.reader(io.StringIO(content)))
|
||||
headers = reader[0] if reader else []
|
||||
rows = reader[1:] if len(reader) > 1 else []
|
||||
nodes = [
|
||||
ParsedNode(
|
||||
node_type="table_schema",
|
||||
text=f"CSV columns: {', '.join(headers)} | rows: {len(rows)}",
|
||||
metadata={"headers": headers, "row_count": len(rows), "table_name": "csv"},
|
||||
section_path=["csv"],
|
||||
)
|
||||
]
|
||||
for start in range(0, len(rows), 50):
|
||||
batch = rows[start:start + 50]
|
||||
serialized_rows = []
|
||||
for row in batch:
|
||||
serialized = ", ".join(
|
||||
f"{header}={value}" for header, value in zip(headers, row)
|
||||
)
|
||||
serialized_rows.append(serialized)
|
||||
nodes.append(
|
||||
ParsedNode(
|
||||
node_type="table_rows",
|
||||
text="\n".join(serialized_rows),
|
||||
metadata={
|
||||
"headers": headers,
|
||||
"row_start": start + 1,
|
||||
"row_end": start + len(batch),
|
||||
"table_name": "csv",
|
||||
},
|
||||
section_path=["csv"],
|
||||
)
|
||||
)
|
||||
summary = f"CSV with columns {', '.join(headers)}" if headers else "CSV document"
|
||||
return ParsedDocument(summary=summary, nodes=nodes)
|
||||
|
||||
# 策略1: Markdown 标题切分(优先)
|
||||
header_pattern = re.compile(r"^(#{1,3})\s+(.+)$", re.MULTILINE)
|
||||
headers = list(header_pattern.finditer(text))
|
||||
async def _parse_xlsx(self, file_path: str) -> ParsedDocument:
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ModuleNotFoundError as error:
|
||||
raise ValueError("XLSX 解析依赖缺失: openpyxl") from error
|
||||
|
||||
if headers:
|
||||
# 按标题段落切分
|
||||
for i, match in enumerate(headers):
|
||||
start = match.start()
|
||||
end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
|
||||
section = text[start:end].strip()
|
||||
if len(section) > settings.CHUNK_SIZE:
|
||||
# 大段落内部再切分
|
||||
sub_chunks = self._split_large_chunk(section, match.group(2))
|
||||
chunks.extend(sub_chunks)
|
||||
elif section:
|
||||
chunks.append(section)
|
||||
else:
|
||||
# 策略2: 按段落切分
|
||||
chunks = self._chunk_by_paragraphs(text)
|
||||
|
||||
# 过滤空 chunk
|
||||
chunks = [c.strip() for c in chunks if c.strip()]
|
||||
return chunks if chunks else [text[: settings.CHUNK_SIZE]]
|
||||
|
||||
def _chunk_by_paragraphs(self, text: str) -> list[str]:
|
||||
"""按段落分块,带上下文"""
|
||||
paragraphs = text.split("\n\n")
|
||||
chunks = []
|
||||
current = ""
|
||||
prev_summary = ""
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
workbook = load_workbook(file_path, data_only=True)
|
||||
nodes: list[ParsedNode] = []
|
||||
summaries: list[str] = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = list(sheet.iter_rows(values_only=True))
|
||||
if not rows:
|
||||
continue
|
||||
if len(current) + len(para) < settings.CHUNK_SIZE:
|
||||
current += "\n\n" + para
|
||||
headers = [str(cell).strip() if cell is not None else "" for cell in rows[0]]
|
||||
data_rows = rows[1:]
|
||||
summaries.append(sheet.title)
|
||||
nodes.append(
|
||||
ParsedNode(
|
||||
node_type="table_schema",
|
||||
text=f"Sheet {sheet.title} columns: {', '.join(headers)} | rows: {len(data_rows)}",
|
||||
metadata={"headers": headers, "row_count": len(data_rows), "sheet_name": sheet.title},
|
||||
section_path=[sheet.title],
|
||||
)
|
||||
)
|
||||
for start in range(0, len(data_rows), 50):
|
||||
batch = data_rows[start:start + 50]
|
||||
serialized_rows = []
|
||||
for row in batch:
|
||||
normalized = ["" if value is None else str(value) for value in row]
|
||||
serialized_rows.append(", ".join(f"{header}={value}" for header, value in zip(headers, normalized)))
|
||||
nodes.append(
|
||||
ParsedNode(
|
||||
node_type="table_rows",
|
||||
text="\n".join(serialized_rows),
|
||||
metadata={
|
||||
"headers": headers,
|
||||
"row_start": start + 1,
|
||||
"row_end": start + len(batch),
|
||||
"sheet_name": sheet.title,
|
||||
},
|
||||
section_path=[sheet.title],
|
||||
)
|
||||
)
|
||||
summary = f"Workbook sheets: {', '.join(summaries)}" if summaries else "Workbook"
|
||||
return ParsedDocument(summary=summary, nodes=nodes)
|
||||
|
||||
async def _parse_docx(self, file_path: str) -> ParsedDocument:
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ModuleNotFoundError as error:
|
||||
raise ValueError("DOCX 解析依赖缺失: python-docx") from error
|
||||
|
||||
doc = DocxDocument(file_path)
|
||||
nodes: list[ParsedNode] = []
|
||||
section_path: list[str] = []
|
||||
summary_parts: list[str] = []
|
||||
for paragraph in doc.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
style_name = getattr(paragraph.style, "name", "") or ""
|
||||
if style_name.startswith("Heading"):
|
||||
level_match = re.search(r"(\d+)", style_name)
|
||||
level = int(level_match.group(1)) if level_match else 1
|
||||
section_path = section_path[: level - 1] + [text]
|
||||
nodes.append(ParsedNode("heading", text, {"level": level}, list(section_path)))
|
||||
else:
|
||||
if current:
|
||||
# 添加上下文摘要
|
||||
enriched = current.strip()
|
||||
chunks.append(enriched)
|
||||
current = para
|
||||
if not section_path:
|
||||
section_path = [doc.core_properties.title or "Document"]
|
||||
summary_parts.append(text)
|
||||
nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))
|
||||
for table in doc.tables:
|
||||
rows = [[cell.text.strip() for cell in row.cells] for row in table.rows]
|
||||
if not rows:
|
||||
continue
|
||||
headers = rows[0]
|
||||
nodes.append(
|
||||
ParsedNode(
|
||||
"table_schema",
|
||||
f"DOCX table columns: {', '.join(headers)} | rows: {max(len(rows) - 1, 0)}",
|
||||
{"headers": headers, "row_count": max(len(rows) - 1, 0), "table_name": "docx_table"},
|
||||
list(section_path),
|
||||
)
|
||||
)
|
||||
for start in range(1, len(rows), 50):
|
||||
batch = rows[start:start + 50]
|
||||
serialized_rows = [", ".join(f"{header}={value}" for header, value in zip(headers, row)) for row in batch]
|
||||
nodes.append(
|
||||
ParsedNode(
|
||||
"table_rows",
|
||||
"\n".join(serialized_rows),
|
||||
{
|
||||
"headers": headers,
|
||||
"row_start": start,
|
||||
"row_end": start + len(batch) - 1,
|
||||
"table_name": "docx_table",
|
||||
},
|
||||
list(section_path),
|
||||
)
|
||||
)
|
||||
summary = " ".join(summary_parts[:3]) if summary_parts else doc.core_properties.title or "Document"
|
||||
return ParsedDocument(summary=summary, nodes=nodes)
|
||||
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
async def _parse_pdf_with_mineru(self, file_path: str) -> str:
|
||||
try:
|
||||
import mineru
|
||||
except ModuleNotFoundError as error:
|
||||
raise ValueError("PDF 解析依赖缺失: mineru") from error
|
||||
|
||||
if hasattr(mineru, "to_markdown"):
|
||||
return mineru.to_markdown(file_path)
|
||||
if hasattr(mineru, "parse_to_markdown"):
|
||||
return mineru.parse_to_markdown(file_path)
|
||||
|
||||
raise ValueError("PDF 解析失败: MinerU 不支持当前接口")
|
||||
|
||||
async def _parse_pdf(self, file_path: str) -> ParsedDocument:
|
||||
markdown = await self._parse_pdf_with_mineru(file_path)
|
||||
return self._parse_markdown(markdown)
|
||||
|
||||
def _parse_markdown(self, content: str) -> ParsedDocument:
|
||||
nodes: list[ParsedNode] = []
|
||||
section_path: list[str] = []
|
||||
summary_parts: list[str] = []
|
||||
buffer: list[str] = []
|
||||
|
||||
def flush_buffer():
|
||||
if not buffer:
|
||||
return
|
||||
text = "\n".join(buffer).strip()
|
||||
buffer.clear()
|
||||
if not text:
|
||||
return
|
||||
nodes.append(ParsedNode("paragraph", text, {}, list(section_path)))
|
||||
summary_parts.append(text)
|
||||
|
||||
for line in content.splitlines():
|
||||
heading_match = re.match(r"^(#{1,6})\s+(.+)$", line.strip())
|
||||
if heading_match:
|
||||
flush_buffer()
|
||||
level = len(heading_match.group(1))
|
||||
title = heading_match.group(2).strip()
|
||||
section_path = section_path[: level - 1] + [title]
|
||||
nodes.append(ParsedNode("heading", title, {"level": level}, list(section_path)))
|
||||
continue
|
||||
if line.strip():
|
||||
buffer.append(line.strip())
|
||||
else:
|
||||
flush_buffer()
|
||||
flush_buffer()
|
||||
summary = " ".join(summary_parts[:3]) if summary_parts else content[:200]
|
||||
return ParsedDocument(summary=summary, nodes=nodes)
|
||||
|
||||
def _parse_text(self, content: str) -> ParsedDocument:
|
||||
paragraphs = [part.strip() for part in content.split("\n\n") if part.strip()]
|
||||
nodes = [ParsedNode("text", paragraph, {}, []) for paragraph in paragraphs]
|
||||
summary = " ".join(paragraphs[:3]) if paragraphs else content[:200]
|
||||
return ParsedDocument(summary=summary, nodes=nodes)
|
||||
|
||||
def _build_chunks(self, parsed: ParsedDocument) -> list[dict]:
|
||||
chunks: list[dict] = []
|
||||
for source_order, node in enumerate(parsed.nodes):
|
||||
section_path = node.section_path or []
|
||||
metadata = {
|
||||
"content_type": node.node_type,
|
||||
"section_path": section_path,
|
||||
"section_title": section_path[-1] if section_path else None,
|
||||
"chunk_level": len(section_path),
|
||||
"parent_key": "/".join(section_path[:-1]) or None,
|
||||
"block_key": "/".join(section_path) or None,
|
||||
"parser_version": PARSER_VERSION,
|
||||
"index_version": INDEX_VERSION,
|
||||
"source_order": source_order,
|
||||
**node.metadata,
|
||||
}
|
||||
chunks.append({"content": node.text, "metadata": metadata})
|
||||
if not chunks:
|
||||
chunks.append({
|
||||
"content": parsed.summary,
|
||||
"metadata": {
|
||||
"content_type": "text",
|
||||
"section_path": [],
|
||||
"section_title": None,
|
||||
"chunk_level": 0,
|
||||
"parent_key": None,
|
||||
"block_key": None,
|
||||
"parser_version": PARSER_VERSION,
|
||||
"index_version": INDEX_VERSION,
|
||||
"source_order": 0,
|
||||
},
|
||||
})
|
||||
return chunks
|
||||
|
||||
def _split_large_chunk(self, text: str, title: str) -> list[str]:
|
||||
"""将大段落拆分为固定大小的子块"""
|
||||
chunks = []
|
||||
sentences = text.split("。")
|
||||
current = title + "\n\n"
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if not sentence:
|
||||
def _render_structured_markdown(self, parsed: ParsedDocument) -> str:
|
||||
blocks: list[str] = []
|
||||
for node in parsed.nodes:
|
||||
if node.node_type == "heading":
|
||||
level = max(1, min(int(node.metadata.get("level", 1)), 6))
|
||||
blocks.append(f"{'#' * level} {node.text}")
|
||||
continue
|
||||
full_sentence = sentence if sentence.endswith("。") else sentence + "。"
|
||||
if len(current) + len(full_sentence) < settings.CHUNK_SIZE:
|
||||
current += full_sentence + " "
|
||||
else:
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
current = title + "\n\n" + full_sentence + " "
|
||||
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
|
||||
return chunks
|
||||
if node.node_type == "table_schema":
|
||||
headers = node.metadata.get("headers") or []
|
||||
if headers:
|
||||
header_row = "| " + " | ".join(headers) + " |"
|
||||
divider_row = "| " + " | ".join(["---"] * len(headers)) + " |"
|
||||
blocks.append("\n".join([header_row, divider_row]))
|
||||
else:
|
||||
blocks.append(node.text)
|
||||
continue
|
||||
if node.node_type == "table_rows":
|
||||
headers = node.metadata.get("headers") or []
|
||||
if headers:
|
||||
rows = []
|
||||
for line in node.text.splitlines():
|
||||
values_by_header = {}
|
||||
for part in line.split(", "):
|
||||
if "=" not in part:
|
||||
continue
|
||||
key, value = part.split("=", 1)
|
||||
values_by_header[key] = value
|
||||
rows.append("| " + " | ".join(values_by_header.get(header, "") for header in headers) + " |")
|
||||
if rows:
|
||||
blocks.append("\n".join(rows))
|
||||
continue
|
||||
blocks.append(node.text)
|
||||
continue
|
||||
blocks.append(node.text)
|
||||
return "\n\n".join(block for block in blocks if block).strip() or parsed.summary
|
||||
|
||||
async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
|
||||
result = await self.db.execute(
|
||||
@@ -219,6 +503,34 @@ class DocumentService:
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_document_chunk(self, user_id: str, document_id: str, chunk_id: str, content: str) -> DocumentChunk:
|
||||
document_result = await self.db.execute(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.user_id == user_id,
|
||||
)
|
||||
)
|
||||
document = document_result.scalar_one_or_none()
|
||||
if not document:
|
||||
raise ValueError("文档不存在")
|
||||
|
||||
chunk_result = await self.db.execute(
|
||||
select(DocumentChunk).where(
|
||||
DocumentChunk.id == chunk_id,
|
||||
DocumentChunk.document_id == document_id,
|
||||
)
|
||||
)
|
||||
chunk = chunk_result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
raise ValueError("切片不存在")
|
||||
|
||||
chunk.content = content
|
||||
document.ingestion_status = "indexing"
|
||||
document.ingestion_error = None
|
||||
await self.db.commit()
|
||||
await self.db.refresh(chunk)
|
||||
return chunk
|
||||
|
||||
async def get_document_content(self, user_id: str, document_id: str) -> str | None:
|
||||
"""获取文档的文本内容"""
|
||||
import os
|
||||
@@ -233,6 +545,9 @@ class DocumentService:
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
if doc.normalized_content:
|
||||
return doc.normalized_content
|
||||
|
||||
file_path = doc.file_path
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
@@ -247,9 +562,6 @@ class DocumentService:
|
||||
elif ext == 'md':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
elif ext == 'pdf':
|
||||
# 简单文本提取(生产环境应使用专业库)
|
||||
return f"[PDF文档] {doc.filename}"
|
||||
else:
|
||||
return f"[文档] {doc.filename}"
|
||||
except Exception:
|
||||
|
||||
@@ -14,9 +14,12 @@ from sqlalchemy import select, or_
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
from app.services.document_service import DocumentService
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -72,24 +75,50 @@ class KnowledgeService:
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
await self._index_chunks(doc, chunks, user_id, folder_path=folder_path)
|
||||
|
||||
async def _index_chunks(
|
||||
self,
|
||||
document: Document,
|
||||
chunks: list[DocumentChunk],
|
||||
user_id: str,
|
||||
folder_path: str | None = None,
|
||||
):
|
||||
folder_path = folder_path or (await self._get_folder_path(document.folder_id) if document.folder_id else "")
|
||||
collection = self.get_collection(user_id)
|
||||
|
||||
ids = [chunk.id for chunk in chunks]
|
||||
documents = [chunk.content for chunk in chunks]
|
||||
metadatas = [
|
||||
{
|
||||
"document_id": doc.id,
|
||||
"document_title": doc.title,
|
||||
metadatas = []
|
||||
for chunk in chunks:
|
||||
chunk_metadata = self._parse_metadata(chunk.metadata_)
|
||||
meta = {
|
||||
"document_id": document.id,
|
||||
"document_title": document.title,
|
||||
"document_filename": document.filename,
|
||||
"chunk_index": chunk.chunk_index,
|
||||
"file_type": doc.file_type,
|
||||
"file_type": document.file_type,
|
||||
"folder_path": folder_path or "",
|
||||
"content_type": chunk_metadata.get("content_type", "text"),
|
||||
"section_title": chunk_metadata.get("section_title") or "",
|
||||
"section_path": " / ".join(chunk_metadata.get("section_path", [])),
|
||||
"page_number": chunk_metadata.get("page_number") or 0,
|
||||
"sheet_name": chunk_metadata.get("sheet_name") or "",
|
||||
"row_start": chunk_metadata.get("row_start") or 0,
|
||||
"row_end": chunk_metadata.get("row_end") or 0,
|
||||
"parser_version": chunk_metadata.get("parser_version") or document.parser_version or "",
|
||||
"index_version": chunk_metadata.get("index_version") or document.index_version or "",
|
||||
}
|
||||
for chunk in chunks
|
||||
]
|
||||
chunk.chroma_collection = f"user_{user_id}"
|
||||
chunk.chroma_id = chunk.id
|
||||
metadatas.append(meta)
|
||||
|
||||
collection.add(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
doc.is_indexed = True
|
||||
document.is_indexed = True
|
||||
document.ingestion_status = "ready"
|
||||
document.ingestion_error = None
|
||||
document.indexed_at = datetime.now(UTC)
|
||||
await self.db.commit()
|
||||
|
||||
async def retrieve(
|
||||
@@ -141,7 +170,7 @@ class KnowledgeService:
|
||||
meta = metadatas[i] if i < len(metadatas) else {}
|
||||
score = 1.0 - (distances[i] if i < len(distances) else 0.0)
|
||||
|
||||
prev_chunk, next_chunk = await self._get_sibling_chunks(
|
||||
prev_chunk, next_chunk = await self._get_related_chunks(
|
||||
chunk_id=chunk_id,
|
||||
chunk_index=meta.get("chunk_index", 0),
|
||||
document_id=meta.get("document_id", ""),
|
||||
@@ -153,7 +182,7 @@ class KnowledgeService:
|
||||
document_title=meta.get("document_title", ""),
|
||||
content=documents[i] if i < len(documents) else "",
|
||||
score=score,
|
||||
metadata_=str(meta),
|
||||
metadata_=json.dumps(meta, ensure_ascii=False),
|
||||
prev_chunk=prev_chunk,
|
||||
next_chunk=next_chunk,
|
||||
))
|
||||
@@ -171,10 +200,11 @@ class KnowledgeService:
|
||||
results: list[SearchResult],
|
||||
top_k: int,
|
||||
) -> list[SearchResult]:
|
||||
"""Rerank: 语义分 * 0.7 + 关键词匹配 * 0.2 + 标题匹配 * 0.1"""
|
||||
"""Rerank: 语义分 * 0.7 + 关键词匹配 * 0.2 + 标题匹配 * 0.1 + 结构加权"""
|
||||
import re
|
||||
|
||||
query_words = set(re.findall(r"\w+", query.lower()))
|
||||
table_query = any(token in query.lower() for token in ["sheet", "excel", "csv", "表", "列", "金额", "统计", "日期"])
|
||||
|
||||
scored = []
|
||||
for r in results:
|
||||
@@ -189,36 +219,56 @@ class KnowledgeService:
|
||||
title_overlap = len(query_words & title_words) / max(len(query_words), 1)
|
||||
score += title_overlap * 0.1
|
||||
|
||||
metadata = self._parse_metadata(r.metadata_)
|
||||
if table_query and metadata.get("content_type") == "table_schema":
|
||||
score += 0.25
|
||||
elif table_query and metadata.get("content_type") == "table_rows":
|
||||
score += 0.15
|
||||
|
||||
scored.append((score, r))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [r for _, r in scored[:top_k]]
|
||||
|
||||
async def _get_sibling_chunks(
|
||||
async def _get_related_chunks(
|
||||
self,
|
||||
chunk_id: str,
|
||||
chunk_index: int,
|
||||
document_id: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""获取前一个和后一个 chunk(完整上下文)"""
|
||||
prev_result = await self.db.execute(
|
||||
select(DocumentChunk).where(
|
||||
DocumentChunk.document_id == document_id,
|
||||
DocumentChunk.chunk_index == chunk_index - 1,
|
||||
)
|
||||
"""获取结构相关的上下文 chunk"""
|
||||
current_result = await self.db.execute(
|
||||
select(DocumentChunk).where(DocumentChunk.id == chunk_id)
|
||||
)
|
||||
next_result = await self.db.execute(
|
||||
select(DocumentChunk).where(
|
||||
DocumentChunk.document_id == document_id,
|
||||
DocumentChunk.chunk_index == chunk_index + 1,
|
||||
)
|
||||
)
|
||||
prev_chunk = prev_result.scalar_one_or_none()
|
||||
next_chunk = next_result.scalar_one_or_none()
|
||||
return (
|
||||
prev_chunk.content if prev_chunk else None,
|
||||
next_chunk.content if next_chunk else None,
|
||||
current_chunk = current_result.scalar_one_or_none()
|
||||
if not current_chunk:
|
||||
return None, None
|
||||
|
||||
current_metadata = self._parse_metadata(current_chunk.metadata_)
|
||||
section_path = current_metadata.get("section_path") or []
|
||||
sheet_name = current_metadata.get("sheet_name")
|
||||
|
||||
chunk_result = await self.db.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.document_id == document_id)
|
||||
.order_by(DocumentChunk.chunk_index)
|
||||
)
|
||||
chunks = list(chunk_result.scalars().all())
|
||||
|
||||
prev_chunk = None
|
||||
next_chunk = None
|
||||
for chunk in chunks:
|
||||
if chunk.id == chunk_id:
|
||||
continue
|
||||
metadata = self._parse_metadata(chunk.metadata_)
|
||||
same_sheet = bool(sheet_name) and metadata.get("sheet_name") == sheet_name
|
||||
same_section = bool(section_path) and metadata.get("section_path") == section_path
|
||||
if chunk.chunk_index < chunk_index and (same_sheet or same_section):
|
||||
prev_chunk = chunk.content
|
||||
if chunk.chunk_index > chunk_index and (same_sheet or same_section):
|
||||
next_chunk = chunk.content
|
||||
break
|
||||
return prev_chunk, next_chunk
|
||||
|
||||
async def _get_folder_path(self, folder_id: str) -> str | None:
|
||||
"""获取文件夹的完整路径"""
|
||||
@@ -244,6 +294,16 @@ class KnowledgeService:
|
||||
|
||||
return "/" + "/".join(path_parts)
|
||||
|
||||
def _parse_metadata(self, raw_metadata: str | dict | None) -> dict:
|
||||
if isinstance(raw_metadata, dict):
|
||||
return raw_metadata
|
||||
if not raw_metadata:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(raw_metadata)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
async def hybrid_search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -306,3 +366,43 @@ class KnowledgeService:
|
||||
collection.delete(where={"document_id": document_id})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def reindex_document(self, document_id: str, user_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.user_id == user_id,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if not document:
|
||||
return False
|
||||
|
||||
await self.delete_from_vectorstore(user_id, document_id)
|
||||
document = await DocumentService(self.db, user_id=user_id).rebuild_document(document)
|
||||
await self.index_document(document.id, user_id)
|
||||
return True
|
||||
|
||||
async def reindex_document_chunks(self, document_id: str, user_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Document).where(
|
||||
Document.id == document_id,
|
||||
Document.user_id == user_id,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if not document:
|
||||
return False
|
||||
|
||||
chunks_result = await self.db.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.document_id == document_id)
|
||||
.order_by(DocumentChunk.chunk_index)
|
||||
)
|
||||
chunks = list(chunks_result.scalars().all())
|
||||
if not chunks:
|
||||
return False
|
||||
|
||||
await self.delete_from_vectorstore(user_id, document_id)
|
||||
await self._index_chunks(document, chunks, user_id)
|
||||
return True
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
|
||||
@@ -235,7 +236,7 @@ async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.utcnow()
|
||||
mem.last_recalled_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -271,6 +272,14 @@ async def build_memory_context(
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
|
||||
# 3. 知识大脑(长期项目记忆)
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
for memory in brain_memories:
|
||||
lines.append(f"- {memory.title}: {memory.content}")
|
||||
parts.append("【知识大脑】\n" + "\n".join(lines))
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
return "\n\n".join(parts)
|
||||
|
||||
@@ -32,9 +32,9 @@ async def daily_task_analysis():
|
||||
logger.info("[Scheduler] 开始执行每日任务分析...")
|
||||
|
||||
async with async_session() as db:
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
yesterday = datetime.utcnow().date() - timedelta(days=1)
|
||||
yesterday = datetime.now(UTC).date() - timedelta(days=1)
|
||||
|
||||
# 统计昨日任务完成情况
|
||||
result = await db.execute(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
import psutil
|
||||
except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fallback
|
||||
psutil = None
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.conversation import Conversation, Message
|
||||
@@ -16,6 +20,19 @@ class StatsService:
|
||||
|
||||
def get_system_health(self) -> dict:
|
||||
"""获取系统健康指标"""
|
||||
if psutil is None:
|
||||
return {
|
||||
"uptime_seconds": 0,
|
||||
"cpu_percent": 0.0,
|
||||
"memory_used_mb": 0.0,
|
||||
"memory_total_mb": 0.0,
|
||||
"memory_percent": 0.0,
|
||||
"disk_used_gb": 0.0,
|
||||
"disk_total_gb": 0.0,
|
||||
"disk_percent": 0.0,
|
||||
"active_users_24h": 0,
|
||||
}
|
||||
|
||||
uptime_seconds = int(time.time() - psutil.boot_time())
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
mem = psutil.virtual_memory()
|
||||
@@ -35,7 +52,7 @@ class StatsService:
|
||||
|
||||
def _get_daily_stats(self, model, date_column, user_id=None, days=30) -> list:
|
||||
"""通用每日统计查询"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=days)
|
||||
query = self.db.query(
|
||||
func.date(date_column).label('date'),
|
||||
func.count().label('count')
|
||||
@@ -50,7 +67,7 @@ class StatsService:
|
||||
|
||||
def get_conversation_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取对话统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=days)
|
||||
|
||||
daily_conversations = self._get_daily_stats(
|
||||
Conversation, Conversation.created_at, user_id, days
|
||||
@@ -100,7 +117,7 @@ class StatsService:
|
||||
|
||||
def get_knowledge_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取知识库统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=days)
|
||||
|
||||
# New tags
|
||||
tag_query = self.db.query(
|
||||
@@ -145,7 +162,7 @@ class StatsService:
|
||||
func.date(Task.completed_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
Task.completed_at >= datetime.utcnow() - timedelta(days=days),
|
||||
Task.completed_at >= datetime.now(UTC) - timedelta(days=days),
|
||||
Task.status == TaskStatus.DONE
|
||||
)
|
||||
if user_id:
|
||||
@@ -195,7 +212,7 @@ class StatsService:
|
||||
func.date(ForumPost.updated_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
ForumPost.updated_at >= datetime.utcnow() - timedelta(days=days),
|
||||
ForumPost.updated_at >= datetime.now(UTC) - timedelta(days=days),
|
||||
ForumPost.is_executed == True
|
||||
)
|
||||
if user_id:
|
||||
@@ -243,7 +260,7 @@ class StatsService:
|
||||
top_tags = [{"tag_path": r.tag_path, "usage_count": r.usage_count} for r in tag_query.all()]
|
||||
|
||||
# Token trend
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(UTC)
|
||||
this_month_start = datetime(now.year, now.month, 1)
|
||||
last_month_end = this_month_start - timedelta(days=1)
|
||||
last_month_start = datetime(last_month_end.year, last_month_end.month, 1)
|
||||
|
||||
@@ -193,9 +193,9 @@ class TagService:
|
||||
"""
|
||||
增量打标签 - 只对最近新增/更新的内容节点打标签
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=days)
|
||||
|
||||
content_nodes = self.db.query(KGNode).filter(
|
||||
KGNode.user_id == user_id,
|
||||
|
||||
Reference in New Issue
Block a user