Add FastAPI backend with agent system
This commit is contained in:
2
backend/app/services/__init__.py
Normal file
2
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Services - import specific classes directly when needed
|
||||
# e.g.: from app.services.agent_service import AgentService
|
||||
261
backend/app/services/agent_service.py
Normal file
261
backend/app/services/agent_service.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Jarvis Agent 服务层
|
||||
负责 LangGraph Agent 的调用、流式输出、对话历史管理
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
|
||||
|
||||
class AgentService:
    """Conversation service that drives the Jarvis LangGraph agent.

    The class resolves (or creates) the conversation row, persists the user and
    assistant messages, builds the initial graph state, runs the graph either
    as a stream (``chat``) or in one shot (``chat_simple``) and finally
    schedules the background auto-summarisation.
    """

    # Reply used by ``chat_simple`` when the graph finishes without an answer.
    _FALLBACK_RESPONSE = "抱歉,我无法处理这个请求。"

    # Strong references to fire-and-forget tasks. The event loop only keeps
    # weak references, so an unreferenced task can be garbage-collected before
    # it finishes.
    _background_tasks: set = set()

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ helpers

    async def _get_or_create_conversation(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None,
    ) -> str:
        """Return the id of the caller's conversation, creating one if needed.

        The lookup is scoped to ``user_id`` so that a conversation id belonging
        to another user is treated as "not found" (a fresh conversation is
        created) instead of letting the caller append to a foreign conversation.
        """
        conv = None
        if conversation_id:
            result = await self.db.execute(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id,
                )
            )
            conv = result.scalar_one_or_none()

        if conv is None:
            # The first 50 characters of the message double as the title.
            conv = Conversation(user_id=user_id, title=message[:50])
            self.db.add(conv)
            await self.db.commit()
            await self.db.refresh(conv)

        return conv.id

    @staticmethod
    def _build_initial_state(
        user_id: str,
        conversation_id: str,
        content: str,
        memory_ctx,
    ) -> dict:
        """Build the initial LangGraph state shared by both chat entry points."""
        return {
            "messages": [HumanMessage(content=content)],  # type: ignore[arg-type]
            "user_id": user_id,
            "conversation_id": conversation_id,
            "current_agent": "master",
            "active_agents": ["master"],
            "pending_tasks": [],
            "completed_tasks": [],
            "tool_calls": [],
            "last_tool_result": None,
            "knowledge_context": None,
            "graph_context": None,
            "plan": None,
            "plan_steps": [],
            "analysis_report": None,
            "final_response": None,
            "should_respond": True,
            "memory_context": memory_ctx,
        }

    @staticmethod
    def _message_text(output) -> str:
        """Extract plain text from a chat-model output.

        Accepts a dict with a ``content`` key, an object exposing ``.content``
        (for example a LangChain message), a plain string, or a list of content
        blocks (strings or dicts carrying a ``text`` entry). Anything else
        yields an empty string.
        """
        if isinstance(output, dict):
            content = output.get("content", "")
        else:
            content = getattr(output, "content", output)

        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return ""

    @classmethod
    def _on_background_task_done(cls, task) -> None:
        """Drop the finished task and log its failure, if any."""
        import logging

        cls._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logging.getLogger(__name__).warning(
                "auto-summarize task failed: %r", error
            )

    def _schedule_auto_summarize(self, user_id: str, conversation_id: str) -> None:
        """Fire-and-forget the auto-summarisation (best effort, never raises).

        NOTE(review): the task receives ``self.db``. If the session is
        request-scoped it may already be closed when the task runs; consider
        giving the task its own session.
        """
        import asyncio
        import logging

        try:
            task = asyncio.get_running_loop().create_task(
                memory_service.try_auto_summarize(self.db, user_id, conversation_id)
            )
        except Exception:
            logging.getLogger(__name__).exception(
                "could not schedule auto-summarize for conversation %s",
                conversation_id,
            )
            return
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    # ------------------------------------------------------------ entry points

    async def chat(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
    ) -> tuple[str, str, AsyncGenerator[str, None]]:
        """Handle a chat request and stream the agent's reply.

        Returns:
            ``(conversation_id, message_id, response_stream)`` where
            ``message_id`` is the id of the pre-created assistant message and
            ``response_stream`` is an async generator of text fragments.
        """
        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        # Persist the user message.
        self.db.add(
            Message(conversation_id=conversation_id, role="user", content=message)
        )
        await self.db.commit()

        # Pre-create the assistant message; its content is filled in once the
        # stream has finished.
        assistant_msg = Message(
            conversation_id=conversation_id,
            role="assistant",
            content="",
            model="jarvis",
        )
        self.db.add(assistant_msg)
        await self.db.commit()
        await self.db.refresh(assistant_msg)
        # Capture the id now: the ORM object may be expired by later commits.
        assistant_msg_id = assistant_msg.id

        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        async def run_agent():
            import logging

            log = logging.getLogger(__name__)
            # Initialised before the try-block so the finalizer can always
            # read it, even when graph construction fails.
            collected = ""
            set_current_user(user_id)
            try:
                graph = get_agent_graph()
                state = self._build_initial_state(
                    user_id, conversation_id, message, memory_ctx
                )
                async for event in graph.astream_events(state, version="v2"):
                    kind = event.get("event")
                    if kind == "on_chat_model_end":
                        # Each "end" event carries the complete output of one
                        # model call, so it is emitted whole rather than sliced
                        # against previously collected text.
                        data = event.get("data") or {}
                        text = self._message_text(data.get("output"))
                        if text:
                            collected += text
                            yield text
                    elif kind == "on_tool_end":
                        name = event.get("name", "")
                        yield f"\n[工具执行: {name}]\n"
            except Exception as e:
                log.exception("agent run failed for conversation %s", conversation_id)
                yield f"\n执行出错: {str(e)}"
            finally:
                clear_current_user()

                # Persist whatever was produced, then schedule the summary. The
                # order matters: scheduling first would let the background task
                # use the same session concurrently with this write.
                if collected:
                    try:
                        stored = await self.db.execute(
                            select(Message).where(Message.id == assistant_msg_id)
                        )
                        msg = stored.scalar_one_or_none()
                        if msg:
                            msg.content = collected
                            await self.db.commit()
                    except Exception:
                        log.exception(
                            "could not persist assistant message %s", assistant_msg_id
                        )

                self._schedule_auto_summarize(user_id, conversation_id)

        return conversation_id, assistant_msg_id, run_agent()

    async def chat_simple(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
        file_ids: list[str] | None = None,
    ) -> tuple[str, str, str]:
        """Handle a chat request without streaming.

        Returns:
            ``(conversation_id, message_id, response_content)``.
        """
        import logging

        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        # Uploaded files are appended to the prompt as extra context; the
        # stored user message keeps only the original text.
        file_context = ""
        if file_ids:
            from app.services.document_service import DocumentService

            doc_svc = DocumentService(self.db)
            sections = []
            for file_id in file_ids:
                content = await doc_svc.get_document_content(user_id, file_id)
                if content:
                    sections.append(f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]")
            file_context = "".join(sections)

        full_message = f"{message}\n{file_context}" if file_context else message

        self.db.add(
            Message(
                conversation_id=conversation_id,
                role="user",
                content=message,
                attachments=[{"file_ids": file_ids}] if file_ids else None,
            )
        )
        await self.db.commit()

        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        set_current_user(user_id)
        try:
            graph = get_agent_graph()
            result_state = await graph.ainvoke(
                self._build_initial_state(
                    user_id, conversation_id, full_message, memory_ctx
                )
            )
            # The key is always present in the state (initialised to None), so
            # a ``.get`` default would never apply; fall back explicitly.
            response_content = (
                result_state.get("final_response") or self._FALLBACK_RESPONSE
            )
        except Exception as e:
            logging.getLogger(__name__).exception(
                "agent run failed for conversation %s", conversation_id
            )
            response_content = f"抱歉,发生错误: {str(e)}"
        finally:
            clear_current_user()

        assistant_msg = Message(
            conversation_id=conversation_id,
            role="assistant",
            content=response_content,
            model="jarvis",
        )
        self.db.add(assistant_msg)
        await self.db.commit()
        await self.db.refresh(assistant_msg)
        assistant_msg_id = assistant_msg.id

        # Scheduled only after the assistant message is stored so the
        # background task does not race this method on the shared session.
        self._schedule_auto_summarize(user_id, conversation_id)

        return conversation_id, assistant_msg_id, response_content
|
||||
29
backend/app/services/auth_service.py
Normal file
29
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from jose import jwt, JWTError
|
||||
from app.config import settings
|
||||
|
||||
# Shared passlib context used by the password helpers below; bcrypt is the
# only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when ``plain_password`` matches ``hashed_password``.

    A stored hash that passlib cannot identify raises ``UnknownHashError``
    (a ``ValueError`` subclass). That case is reported as a failed
    verification instead of propagating as an unhandled error.
    """
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except ValueError:
        return False
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module-level ``pwd_context`` and return it."""
    hashed = pwd_context.hash(password)
    return hashed
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Encode ``data`` as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Token lifetime. When omitted (or falsy) the lifetime is
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token string.
    """
    from datetime import timezone

    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now": ``datetime.utcnow()`` is deprecated and returns a
    # naive value that is easy to misinterpret as local time.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict | None:
    """Decode ``token`` with the configured key and algorithm.

    Returns the payload dict, or ``None`` when decoding raises ``JWTError``.
    """
    allowed_algorithms = [settings.ALGORITHM]
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=allowed_algorithms)
    except JWTError:
        return None
|
||||
256
backend/app/services/document_service.py
Normal file
256
backend/app/services/document_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
文档服务 - 上传、解析、分块、存储
|
||||
支持多种文档格式 + LlamaIndex 智能分块
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from fastapi import UploadFile
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
import os
|
||||
import aiofiles
|
||||
import uuid
|
||||
|
||||
|
||||
# Lower-cased file extensions (including the leading dot) accepted by
# DocumentService.upload_document.
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
|
||||
|
||||
|
||||
class DocumentService:
    """Upload, parse, chunk and store user documents."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id

    async def upload_document(
        self, user_id: str, file: UploadFile, folder_id: str | None = None
    ) -> Document:
        """Store an uploaded file, extract its text and persist its chunks.

        Raises:
            ValueError: for an unsupported extension or an oversize file.

        The file on disk is removed again (and the session rolled back) when
        extraction or persistence fails, so no orphan file is left behind.
        """
        filename = file.filename or ""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        content = await file.read()
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(
                f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB"
            )

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        file_id = str(uuid.uuid4())
        file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

        try:
            async with aiofiles.open(file_path, "wb") as f:
                await f.write(content)

            text_content = await self._extract_text(file_path, ext)
            chunks = self._chunk_text(text_content)

            doc = Document(
                user_id=user_id,
                title=filename.rsplit(".", 1)[0],
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=text_content[:500],
                folder_id=folder_id,
            )
            self.db.add(doc)
            # Flush (not commit) so doc.id is available while the document and
            # its chunks are still written in a single transaction.
            await self.db.flush()

            self.db.add_all(
                DocumentChunk(document_id=doc.id, chunk_index=i, content=chunk_text)
                for i, chunk_text in enumerate(chunks)
            )
            doc.chunk_count = len(chunks)
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            if os.path.exists(file_path):
                os.remove(file_path)
            raise

        # Reload after the commit so the returned object is not expired.
        await self.db.refresh(doc)
        return doc

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the "/a/b/c" path of a folder owned by ``self.user_id``.

        Returns ``None`` when the folder is unknown. A cycle in the parent
        chain stops the walk instead of looping forever.
        """
        rows = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in rows.scalars().all()}

        names: list[str] = []
        seen: set = set()
        current_id = folder_id
        while current_id and current_id not in seen:
            seen.add(current_id)
            folder = folder_map.get(current_id)
            if folder is None:
                break
            names.append(folder.name)
            current_id = folder.parent_id

        if not names:
            return None
        # Names were collected leaf-to-root.
        return "/" + "/".join(reversed(names))

    async def delete_document(self, user_id: str, document_id: str):
        """Delete a document owned by ``user_id`` together with its file.

        Raises:
            ValueError: when no such document exists for this user.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        # Read the path before the commit expires the object, and remove the
        # file only after the row is gone so a failed commit keeps both.
        file_path = doc.file_path
        await self.db.delete(doc)
        await self.db.commit()

        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass  # already gone; nothing to clean up

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Return the plain text of the file at ``file_path``.

        ``ext`` is the lower-cased extension including the dot. When an
        optional parser package is missing, an installation hint string is
        returned instead of raising.
        """
        if ext == ".pdf":
            try:
                import pymupdf
            except ImportError:
                return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
            pdf = pymupdf.open(file_path)
            try:
                return "".join(page.get_text() for page in pdf)
            finally:
                pdf.close()

        if ext in (".md", ".txt"):
            # errors="replace" keeps non-UTF-8 uploads from aborting with
            # UnicodeDecodeError.
            async with aiofiles.open(
                file_path, "r", encoding="utf-8", errors="replace"
            ) as f:
                return await f.read()

        if ext in (".docx", ".doc"):
            try:
                from docx import Document as DocxDocument
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
            try:
                parsed = DocxDocument(file_path)
            except Exception:
                # Legacy binary .doc files are not readable by python-docx.
                if ext == ".doc":
                    return "[暂不支持此格式]"
                raise
            return "\n".join(p.text for p in parsed.paragraphs)

        return "[暂不支持此格式]"

    @staticmethod
    def _resolve_chunk_size(chunk_size: int | None) -> int:
        """Return the effective chunk size (``settings.CHUNK_SIZE`` by default)."""
        size = settings.CHUNK_SIZE if chunk_size is None else chunk_size
        return max(size, 1)

    def _chunk_text(self, text: str, chunk_size: int | None = None) -> list[str]:
        """Split ``text`` into chunks.

        1. When Markdown headings (H1-H3) are present, each heading starts a
           section; text before the first heading is kept as its own chunk(s).
        2. Sections longer than the chunk size are split further.
        3. Text without headings is packed paragraph by paragraph.

        Args:
            text: Document text.
            chunk_size: Maximum chunk length; defaults to ``settings.CHUNK_SIZE``.

        Returns:
            Non-empty, stripped chunks. Blank input yields an empty list.
        """
        import re

        if not text or not text.strip():
            return []

        limit = self._resolve_chunk_size(chunk_size)
        headers = list(re.finditer(r"^(#{1,3})\s+(.+)$", text, re.MULTILINE))

        if not headers:
            chunks = self._chunk_by_paragraphs(text, limit)
        else:
            chunks = []
            preamble = text[: headers[0].start()].strip()
            if preamble:
                chunks.extend(self._chunk_by_paragraphs(preamble, limit))

            bounds = [m.start() for m in headers] + [len(text)]
            for match, start, end in zip(headers, bounds, bounds[1:]):
                section = text[start:end].strip()
                if len(section) > limit:
                    chunks.extend(
                        self._split_large_chunk(section, match.group(2), limit)
                    )
                elif section:
                    chunks.append(section)

        return [c.strip() for c in chunks if c.strip()]

    def _chunk_by_paragraphs(
        self, text: str, chunk_size: int | None = None
    ) -> list[str]:
        """Pack blank-line separated paragraphs into chunks of bounded length.

        A single paragraph longer than the chunk size is cut into fixed-size
        slices.
        """
        limit = self._resolve_chunk_size(chunk_size)
        chunks: list[str] = []
        current = ""

        for para in text.split("\n\n"):
            para = para.strip()
            if not para:
                continue

            if len(para) > limit:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(
                    para[i : i + limit] for i in range(0, len(para), limit)
                )
                continue

            candidate = f"{current}\n\n{para}" if current else para
            if len(candidate) <= limit:
                current = candidate
            else:
                chunks.append(current)
                current = para

        if current:
            chunks.append(current)
        return chunks

    def _split_large_chunk(
        self, text: str, title: str, chunk_size: int | None = None
    ) -> list[str]:
        """Split an oversize section into sub-chunks, each prefixed with ``title``.

        The text is cut at sentence boundaries (CJK and Western terminators,
        newlines). Sentences that are themselves too long are sliced. No
        characters are added to the text.
        """
        import re

        limit = self._resolve_chunk_size(chunk_size)
        prefix = f"{title}\n\n"
        budget = max(limit - len(prefix), 1)

        # Consecutive pieces that concatenate back to ``text`` exactly.
        sentences = re.findall(
            r"[\s\S]+?(?:[\u3002\uff01\uff1f]+|[.!?]+(?=\s)|\n+|\Z)", text
        )

        bodies: list[str] = []
        current = ""
        for sentence in sentences:
            for start in range(0, len(sentence), budget):
                piece = sentence[start : start + budget]
                if current and len(current) + len(piece) > budget:
                    bodies.append(current)
                    current = ""
                current += piece
        if current:
            bodies.append(current)

        return [prefix + body.strip() for body in bodies if body.strip()]

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return the chunks of a document ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text of a document owned by ``user_id``.

        Returns ``None`` when the document or its file is missing. Extraction
        goes through ``_extract_text`` so PDF and Word files yield their real
        text. An unsupported type or a failed extraction yields a
        "[文档] <filename>" placeholder.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        file_path = doc.file_path
        if not file_path or not os.path.exists(file_path):
            return None

        placeholder = f"[文档] {doc.filename}"
        ext = os.path.splitext(doc.filename or "")[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return placeholder

        try:
            return await self._extract_text(file_path, ext)
        except Exception:
            return placeholder
|
||||
342
backend/app/services/graph_service.py
Normal file
342
backend/app/services/graph_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
知识图谱服务 - 实体识别、关系抽取、图谱查询
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Module-level logger used by GraphService.
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
|
||||
|
||||
实体类型:
|
||||
- person(人物):人名、角色
|
||||
- concept(概念):抽象概念、理论、方法
|
||||
- topic(主题):话题、领域
|
||||
- task(任务):要做的事情
|
||||
- event(事件):发生的事件
|
||||
- document(文档):文件、资料
|
||||
|
||||
关系类型:
|
||||
- related_to(相关于)
|
||||
- part_of(隶属于)
|
||||
- caused_by(由...导致)
|
||||
- depends_on(取决于)
|
||||
- contains(包含)
|
||||
- located_in(位于)
|
||||
- works_on(从事)
|
||||
|
||||
要求:
|
||||
1. 识别文本中所有有意义的实体(不超过10个)
|
||||
2. 识别实体之间的关系(每个实体至少一条关系)
|
||||
3. 每个实体要有 name、type、description(1-2句话)
|
||||
4. 关系要有 source、target、relation_type
|
||||
|
||||
文本内容:
|
||||
{text}
|
||||
|
||||
请只返回 JSON,不要有其他内容:
|
||||
{{
|
||||
"entities": [
|
||||
{{"name": "实体名", "type": "类型", "description": "描述"}}
|
||||
],
|
||||
"relations": [
|
||||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
|
||||
|
||||
用户问题:{question}
|
||||
|
||||
已知实体:
|
||||
{entities}
|
||||
|
||||
请推断这些实体之间的隐含关系,返回 JSON:
|
||||
{{
|
||||
"inferred_relations": [
|
||||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
|
||||
]
|
||||
}}
|
||||
|
||||
关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
|
||||
confidence: 0.0-1.0,表示推断置信度
|
||||
"""
|
||||
|
||||
|
||||
class GraphService:
    """Knowledge-graph service: entity/relation extraction and graph queries."""

    def __init__(self, db: AsyncSession):
        self.db = db
        self.llm = get_llm()

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Build or update the user's knowledge graph from indexed documents.

        Every chunk of the user's indexed documents (optionally restricted to
        ``document_ids``) is sent to the LLM for entity and relation
        extraction. Results are merged with existing nodes and edges. A failing
        chunk is logged, the session is rolled back and processing continues.
        """
        # Plain column rows are selected instead of ORM objects so the values
        # stay readable after the per-chunk commits and rollbacks below.
        query = (
            select(DocumentChunk.id, DocumentChunk.document_id, DocumentChunk.content)
            .select_from(DocumentChunk)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(Document.is_indexed.is_(True))
        )
        if document_ids:
            query = query.where(DocumentChunk.document_id.in_(document_ids))

        chunks = (await self.db.execute(query)).all()
        logger.info("[GraphService] 开始构建图谱,共 %d 个 chunks", len(chunks))

        for chunk in chunks:
            try:
                await self._process_chunk(chunk, user_id)
            except Exception as e:
                logger.error("[GraphService] 处理 chunk %s 失败: %s", chunk.id, e)
                # A failed flush leaves the transaction unusable until it is
                # rolled back; without this every later chunk would fail too.
                await self.db.rollback()

        logger.info("[GraphService] 图谱构建完成")

    async def _invoke_llm(self, prompt: str):
        """Send ``prompt`` to the LLM and return its response.

        Uses ``ainvoke`` when the model offers it. Otherwise calls ``invoke``
        and awaits the result only if it is awaitable.
        """
        import inspect

        messages = [HumanMessage(content=prompt)]
        ainvoke = getattr(self.llm, "ainvoke", None)
        if ainvoke is not None:
            return await ainvoke(messages)

        response = self.llm.invoke(messages)
        if inspect.isawaitable(response):
            response = await response
        return response

    @staticmethod
    def _parse_llm_json(raw) -> dict | None:
        """Parse the JSON object contained in an LLM reply.

        Tolerates Markdown code fences and surrounding prose by parsing the
        span from the first "{" to the last "}". Returns ``None`` when no JSON
        object can be decoded.
        """
        if isinstance(raw, list):
            # Content-block replies: keep the textual parts only.
            raw = "".join(
                block if isinstance(block, str) else str(block.get("text", ""))
                for block in raw
                if isinstance(block, (str, dict))
            )
        if not isinstance(raw, str):
            return None

        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            data = json.loads(raw[start : end + 1])
        except json.JSONDecodeError:
            return None
        return data if isinstance(data, dict) else None

    async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
        """Extract entities and relations from one chunk and persist them.

        ``chunk`` only needs ``content`` and ``document_id`` attributes.
        Malformed entities or relations (missing required keys) are skipped.
        """
        text = (chunk.content or "")[:2000]
        document_id = chunk.document_id
        if not text.strip():
            return

        response = await self._invoke_llm(ENTITY_EXTRACTION_PROMPT.format(text=text))
        data = self._parse_llm_json(getattr(response, "content", response))
        if data is None:
            logger.warning("[GraphService] LLM reply contained no JSON object")
            return

        entities = [
            e
            for e in (data.get("entities") or [])
            if isinstance(e, dict) and e.get("name") and e.get("type")
        ]
        if not entities:
            return
        relations = [
            r
            for r in (data.get("relations") or [])
            if isinstance(r, dict)
            and r.get("source")
            and r.get("target")
            and r.get("relation_type")
        ]

        # Resolve each distinct entity name to a node id, creating the node
        # when the user has none with that name yet.
        entity_map: dict = {}
        for entity in entities:
            name = entity["name"]
            if name in entity_map:
                continue  # same name repeated in this reply
            found = await self.db.execute(
                select(KGNode).where(KGNode.user_id == user_id, KGNode.name == name)
            )
            node = found.scalars().first()
            if node is None:
                node = KGNode(
                    user_id=user_id,
                    name=name,
                    entity_type=entity["type"],
                    description=entity.get("description", ""),
                    source_document_id=document_id,
                )
                self.db.add(node)
                await self.db.flush()  # assigns node.id
            entity_map[name] = node.id

        # Insert relations whose endpoints are known and which do not exist yet.
        for rel in relations:
            src, tgt = rel["source"], rel["target"]
            if src not in entity_map or tgt not in entity_map:
                continue
            found = await self.db.execute(
                select(KGEdge).where(
                    KGEdge.source_id == entity_map[src],
                    KGEdge.target_id == entity_map[tgt],
                    KGEdge.relation_type == rel["relation_type"],
                )
            )
            if found.scalars().first() is None:
                self.db.add(
                    KGEdge(
                        source_id=entity_map[src],
                        target_id=entity_map[tgt],
                        relation_type=rel["relation_type"],
                    )
                )

        await self.db.commit()

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown summary of the user's graph.

        The summary lists totals, node and relation type distribution, and the
        ten nodes with the highest ``importance``.
        """
        node_total = (
            await self.db.execute(
                select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
            )
        ).scalar() or 0
        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        # Edges are attributed to the user through their source node.
        edge_total = (
            await self.db.execute(
                select(func.count())
                .select_from(KGEdge)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGNode.user_id == user_id)
            )
        ).scalar() or 0

        type_stats = (
            await self.db.execute(
                select(KGNode.entity_type, func.count())
                .where(KGNode.user_id == user_id)
                .group_by(KGNode.entity_type)
            )
        ).all()

        rel_stats = (
            await self.db.execute(
                select(KGEdge.relation_type, func.count())
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGNode.user_id == user_id)
                .group_by(KGEdge.relation_type)
            )
        ).all()

        top_nodes = (
            await self.db.execute(
                select(KGNode)
                .where(KGNode.user_id == user_id)
                .order_by(KGNode.importance.desc())
                .limit(10)
            )
        ).scalars().all()

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        lines.extend(f"- {etype}: {count} 个" for etype, count in type_stats)

        lines.append("\n### 关系类型分布")
        lines.extend(f"- {rtype}: {count} 条" for rtype, count in rel_stats)

        lines.append("\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # description may be NULL; slicing None would raise.
            summary = (node.description or "")[:50]
            lines.append(f"- [{node.entity_type}] {node.name}: {summary}...")

        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe up to five nodes whose name contains ``entity``.

        For each node, up to ten outgoing and ten incoming relations are listed.
        """
        not_found = f"未找到实体: {entity}"
        if not entity or not entity.strip():
            return not_found

        result = await self.db.execute(
            select(KGNode)
            .where(
                KGNode.user_id == user_id,
                # autoescape keeps "%" and "_" in the search term literal.
                KGNode.name.contains(entity, autoescape=True),
            )
            .limit(5)
        )
        nodes = list(result.scalars().all())
        if not nodes:
            return not_found

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            out_edges = (
                await self.db.execute(
                    select(KGEdge, KGNode)
                    .join(KGNode, KGNode.id == KGEdge.target_id)
                    .where(KGEdge.source_id == node.id)
                    .limit(10)
                )
            ).all()
            in_edges = (
                await self.db.execute(
                    select(KGEdge, KGNode)
                    .join(KGNode, KGNode.id == KGEdge.source_id)
                    .where(KGEdge.target_id == node.id)
                    .limit(10)
                )
            ).all()

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(
                        f" - {node.name} --[{edge.relation_type}]--> {target.name}"
                    )

            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(
                        f" - {source.name} --[{edge.relation_type}]--> {node.name}"
                    )

            lines.append("")

        return "\n".join(lines)

    async def _load_nodes(self, node_ids) -> list:
        """Load the nodes with the given ids in a single query."""
        ids = list(node_ids)
        if not ids:
            return []
        result = await self.db.execute(select(KGNode).where(KGNode.id.in_(ids)))
        return list(result.scalars().all())

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Collect the neighbourhood of a node (used for graph visualisation).

        Walks ``depth`` levels outwards over incoming and outgoing edges.

        Returns:
            ``{"nodes": [...], "edges": [...]}``. Every endpoint of a returned
            edge is included in ``nodes``, and each edge appears once.
        """
        visited: set = set()
        frontier = {node_id}
        nodes: list = []
        # Keyed by object identity: within one session the same row maps to
        # the same ORM instance, so this removes edges seen from both ends.
        edges: dict = {}

        for _ in range(max(depth, 0)):
            frontier -= visited
            if not frontier:
                break

            nodes.extend(await self._load_nodes(frontier))
            visited |= frontier

            level = list(frontier)
            edge_rows = await self.db.execute(
                select(KGEdge).where(
                    KGEdge.source_id.in_(level) | KGEdge.target_id.in_(level)
                )
            )
            next_frontier: set = set()
            for edge in edge_rows.scalars().all():
                edges.setdefault(id(edge), edge)
                next_frontier.add(edge.source_id)
                next_frontier.add(edge.target_id)
            frontier = next_frontier

        # Nodes on the outer boundary are referenced by collected edges but
        # were never expanded; load them so no edge has a missing endpoint.
        nodes.extend(await self._load_nodes(frontier - visited))

        return {"nodes": nodes, "edges": list(edges.values())}
|
||||
308
backend/app/services/knowledge_service.py
Normal file
308
backend/app/services/knowledge_service.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
|
||||
|
||||
检索策略:
|
||||
1. 语义检索 (dense) - ChromaDB 向量相似度
|
||||
2. 关键词检索 (sparse) - SQL LIKE
|
||||
3. 混合检索 - 语义 + 关键词 加权融合
|
||||
4. Rerank - 二次排序优化结果
|
||||
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """One retrieved chunk with its score and optional neighbouring context."""

    chunk_id: str
    document_id: str
    document_title: str
    content: str
    # KnowledgeService.retrieve computes this as 1.0 - distance.
    score: float
    # KnowledgeService.retrieve stores str(<metadata dict>) here.
    metadata_: str | None = None
    # Content of the chunk before / after this one in the same document.
    prev_chunk: str | None = None
    next_chunk: str | None = None
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""向量知识库检索服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession, user_id: str | None = None):
    """Remember the session and user; the Chroma client is created on demand."""
    # Filled in by the ``chroma_client`` property on first access.
    self._chroma_client = None
    self.user_id = user_id
    self.db = db
|
||||
|
||||
@property
def chroma_client(self):
    """Return the persistent Chroma client, creating it on first access."""
    client = self._chroma_client
    if client is not None:
        return client

    client = chromadb.PersistentClient(
        path=settings.CHROMA_PERSIST_DIR,
        settings=ChromaSettings(allow_reset=True),
    )
    self._chroma_client = client
    return client
|
||||
|
||||
def get_collection(self, user_id: str):
    """Return (creating it if necessary) the collection named ``user_<user_id>``."""
    collection_name = f"user_{user_id}"
    collection_metadata = {"user_id": user_id}
    return self.chroma_client.get_or_create_collection(
        name=collection_name,
        metadata=collection_metadata,
    )
|
||||
|
||||
async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
    """Store a document's chunks in the user's Chroma collection.

    The document must belong to ``user_id``. Unknown or foreign documents and
    documents without chunks are ignored. Chunks are upserted, so indexing the
    same document again replaces its entries instead of failing on duplicate
    ids. On success the document is flagged ``is_indexed``.
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        return

    chunks_result = await self.db.execute(
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunks = list(chunks_result.scalars().all())
    if not chunks:
        return

    # Values shared by every chunk; None is replaced with "" so no metadata
    # value is ever None.
    shared_meta = {
        "document_id": doc.id,
        "document_title": doc.title or "",
        "file_type": doc.file_type or "",
        "folder_path": folder_path or "",
    }

    collection = self.get_collection(user_id)
    collection.upsert(
        ids=[chunk.id for chunk in chunks],
        documents=[chunk.content for chunk in chunks],
        metadatas=[{**shared_meta, "chunk_index": chunk.chunk_index} for chunk in chunks],
    )

    doc.is_indexed = True
    await self.db.commit()
|
||||
|
||||
async def retrieve(
    self,
    query: str,
    user_id: str,
    folder_id: str | None = None,
    top_k: int = 5,
    use_rerank: bool = True,
) -> list[SearchResult]:
    """Vector search with optional folder filter and rerank.

    Steps:
    1. Query the user's Chroma collection for an enlarged candidate set.
    2. When ``folder_id`` resolves to a path, keep only candidates stored in
       that folder or one of its sub-folders.
    3. Attach the previous and next chunk of every candidate.
    4. Rerank (or truncate) to ``top_k`` results.

    Returns an empty list when the query fails or nothing matches.
    """
    import logging

    if top_k <= 0:
        return []

    collection = self.get_collection(user_id)

    # The folder filter is applied in Python after the query. Chroma's
    # metadata ``where`` clause has no prefix operator, so a "$starts_with"
    # filter made every folder-scoped query fail and return nothing.
    folder_path = await self._get_folder_path(folder_id) if folder_id else None
    # Fetch a larger pool when filtering, since many candidates may be dropped.
    n_candidates = top_k * (10 if folder_path else 3)

    try:
        results = collection.query(
            query_texts=[query],
            n_results=n_candidates,
            include=["documents", "metadatas", "distances"],
        )
    except Exception:
        logging.getLogger(__name__).exception(
            "vector query failed for user %s", user_id
        )
        return []

    if not results or not results.get("ids"):
        return []

    ids = results["ids"][0]
    documents = (results.get("documents") or [[]])[0] or []
    metadatas = (results.get("metadatas") or [[]])[0] or []
    distances = (results.get("distances") or [[]])[0] or []

    # NOTE(review): assumes folder paths stored at index time use the same
    # "/a/b" form that _get_folder_path returns — confirm against the caller
    # of index_document.
    folder_prefix = folder_path.rstrip("/") + "/" if folder_path else None

    search_results: list[SearchResult] = []
    for i, chunk_id in enumerate(ids):
        meta = (metadatas[i] if i < len(metadatas) else None) or {}

        if folder_path:
            stored_path = meta.get("folder_path", "") or ""
            if stored_path != folder_path and not stored_path.startswith(folder_prefix):
                continue

        # NOTE(review): 1 - distance can be negative when the collection's
        # distance metric is unbounded; only the relative order is used here.
        score = 1.0 - (distances[i] if i < len(distances) else 0.0)

        prev_chunk, next_chunk = await self._get_sibling_chunks(
            chunk_id=chunk_id,
            chunk_index=meta.get("chunk_index", 0),
            document_id=meta.get("document_id", ""),
        )

        search_results.append(
            SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                score=score,
                metadata_=str(meta),
                prev_chunk=prev_chunk,
                next_chunk=next_chunk,
            )
        )

    if use_rerank:
        return self._rerank(query, search_results, top_k)
    return search_results[:top_k]
|
||||
|
||||
def _rerank(
    self,
    query: str,
    results: list[SearchResult],
    top_k: int,
) -> list[SearchResult]:
    """Re-order candidates by a blended relevance score and keep the best ``top_k``.

    The blended score is ``semantic score * 0.7 + keyword overlap * 0.2 +
    title overlap * 0.1``. Both overlaps are the fraction of the query's
    lower-cased word tokens that also appear in the chunk content or in
    the document title.

    Args:
        query: The user's search text.
        results: Candidates; ``score``, ``content`` and ``document_title``
            are read from each one.
        top_k: Maximum number of results to return.

    Returns:
        The highest-scoring candidates, best first. Candidates with equal
        scores keep their incoming order.
    """
    import re

    def tokenize(text: str) -> set:
        return set(re.findall(r"\w+", text.lower()))

    query_tokens = tokenize(query)
    # An empty query has no tokens; clamp so the ratios never divide by zero.
    denominator = max(len(query_tokens), 1)

    def blended(candidate) -> float:
        total = candidate.score * 0.7
        total += len(query_tokens & tokenize(candidate.content)) / denominator * 0.2
        if candidate.document_title:
            total += len(query_tokens & tokenize(candidate.document_title)) / denominator * 0.1
        return total

    ranked = sorted(results, key=blended, reverse=True)
    return ranked[:top_k]
|
||||
|
||||
async def _get_sibling_chunks(
    self,
    chunk_id: str,
    chunk_index: int,
    document_id: str,
) -> tuple[str | None, str | None]:
    """Fetch the text of the chunks directly before and after a given chunk.

    Neighbours are looked up in the ``DocumentChunk`` table by
    ``document_id`` and ``chunk_index - 1`` / ``chunk_index + 1``.
    ``chunk_id`` is accepted but not used by the lookup.

    Returns:
        ``(previous_content, next_content)``; an entry is ``None`` when no
        row exists at that index.
    """
    pending = []
    for offset in (-1, 1):
        rows = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.document_id == document_id,
                DocumentChunk.chunk_index == chunk_index + offset,
            )
        )
        pending.append(rows)

    before, after = (rows.scalar_one_or_none() for rows in pending)
    previous_text = before.content if before else None
    following_text = after.content if after else None
    return (previous_text, following_text)
|
||||
|
||||
async def _get_folder_path(self, folder_id: str) -> str | None:
    """Build the full slash-separated path of a folder.

    Starting at ``folder_id`` the parent chain is followed via
    ``Folder.parent_id`` until a folder without a parent (or a missing
    parent row) is reached.

    Args:
        folder_id: Primary key of the folder whose path is wanted.

    Returns:
        A path such as ``"/root/child/leaf"``, or ``None`` when the folder
        does not exist.
    """
    rows = await self.db.execute(select(Folder).where(Folder.id == folder_id))
    folder = rows.scalar_one_or_none()
    if not folder:
        return None

    # Collected leaf-first and reversed once at the end.
    names = [folder.name]
    # Remember every folder already walked so a corrupt hierarchy (a folder
    # that is its own ancestor) terminates instead of looping forever.
    visited = {folder.id}
    parent_id = folder.parent_id

    while parent_id and parent_id not in visited:
        visited.add(parent_id)
        parent_rows = await self.db.execute(
            select(Folder).where(Folder.id == parent_id)
        )
        parent = parent_rows.scalar_one_or_none()
        if not parent:
            break
        names.append(parent.name)
        parent_id = parent.parent_id

    names.reverse()
    return "/" + "/".join(names)
|
||||
|
||||
async def hybrid_search(
    self,
    query: str,
    user_id: str,
    top_k: int = 5,
) -> list[SearchResult]:
    """Combine vector and keyword retrieval, then rerank the union.

    Vector retrieval is asked for ``top_k * 2`` candidates with its own
    reranking disabled; keyword search is asked for ``top_k``. The two
    lists are merged (vector hits first), de-duplicated by ``chunk_id``
    keeping the first occurrence, and passed through ``_rerank``.

    Returns:
        At most ``top_k`` results as ordered by ``_rerank``.
    """
    semantic_hits = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
    keyword_hits = await self._keyword_search(query, user_id, top_k)

    # A dict keeps insertion order, and setdefault never overwrites, so the
    # first hit seen for each chunk id wins.
    unique_by_chunk: dict = {}
    for hit in [*semantic_hits, *keyword_hits]:
        unique_by_chunk.setdefault(hit.chunk_id, hit)

    return self._rerank(query, list(unique_by_chunk.values()), top_k)
|
||||
|
||||
async def _keyword_search(
    self,
    query: str,
    user_id: str,
    top_k: int,
) -> list[SearchResult]:
    """Substring search over chunk content and document titles in SQL.

    Only chunks whose document belongs to ``user_id`` are considered. A
    chunk matches when ``query`` occurs in its content or in its
    document's title.

    The document title is selected in the same joined query, so no
    per-chunk follow-up query is needed. ``autoescape=True`` makes ``%``
    and ``_`` in ``query`` match literally instead of acting as LIKE
    wildcards.

    Returns:
        Up to ``top_k`` results, each with a fixed score of ``0.5`` and no
        metadata.
    """
    rows = await self.db.execute(
        select(DocumentChunk, Document.title)
        .join(Document)
        .where(Document.user_id == user_id)
        .where(
            or_(
                DocumentChunk.content.contains(query, autoescape=True),
                Document.title.contains(query, autoescape=True),
            )
        )
        .limit(top_k)
    )
    return [
        SearchResult(
            chunk_id=chunk.id,
            document_id=chunk.document_id,
            document_title=title,
            content=chunk.content,
            score=0.5,
            metadata_=None,
        )
        for chunk, title in rows.all()
    ]
|
||||
|
||||
async def delete_from_vectorstore(self, user_id: str, document_id: str):
    """Remove every vector entry of one document from the user's collection.

    Deletion is best-effort: a failure inside the vector store is logged
    with its traceback and not propagated to the caller.

    Args:
        user_id: Owner whose collection is looked up via ``get_collection``.
        document_id: Value matched against the ``document_id`` metadata field.
    """
    import logging

    collection = self.get_collection(user_id)
    try:
        collection.delete(where={"document_id": document_id})
    except Exception:
        # Keep the original "never raise" contract, but leave a trace so
        # orphaned vectors can be diagnosed.
        logging.getLogger(__name__).exception(
            "Failed to delete document %s from the vector store of user %s",
            document_id,
            user_id,
        )
|
||||
145
backend/app/services/llm_service.py
Normal file
145
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
LLM 服务 - 支持多种 LLM 提供商
|
||||
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
from app.config import settings
|
||||
import httpx
|
||||
import os
|
||||
|
||||
# Create the data directory and the Chroma persistence directory at import
# time; directories that already exist are left untouched.
for _required_dir in (settings.DATA_DIR, settings.CHROMA_PERSIST_DIR):
    os.makedirs(_required_dir, exist_ok=True)
|
||||
|
||||
|
||||
class LLMService(ABC):
    """Abstract interface shared by the LLM provider wrappers in this module."""

    @abstractmethod
    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Send ``messages`` to the model and return its complete reply."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Send ``messages`` to the model and produce the reply piece by piece."""
        raise NotImplementedError

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the model this service talks to."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAICompatibleService(LLMService):
    """Wrapper around an OpenAI-compatible chat endpoint.

    Usable with OpenAI, DeepSeek, SiliconFlow, or any other service that
    exposes an OpenAI-compatible API.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Create the client; any argument left empty falls back to ``settings``."""
        self.api_key = api_key if api_key else settings.OPENAI_API_KEY
        self.model = model if model else settings.OPENAI_MODEL
        self.base_url = base_url if base_url else settings.OPENAI_BASE_URL

        client_options = {
            "api_key": self.api_key,
            "model": self.model,
            "base_url": self.base_url,
            "timeout": httpx.Timeout(60.0, connect=10.0),
        }
        self._llm = ChatOpenAI(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one non-streaming completion and return the model's message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each streamed chunk, skipping empty ones."""
        async for piece in self._llm.astream(messages):
            if not piece.content:
                continue
            yield piece.content

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
class ClaudeService(LLMService):
    """Wrapper around ``ChatAnthropic`` for Claude models."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 8192,
    ):
        """Create the client; an empty key or model falls back to ``settings``."""
        self.api_key = api_key if api_key else settings.ANTHROPIC_API_KEY
        self.model = model if model else settings.CLAUDE_MODEL

        client_options = {
            "api_key": self.api_key,
            "model": self.model,
            "max_tokens": max_tokens,
            "timeout": httpx.Timeout(60.0, connect=10.0),
        }
        self._llm = ChatAnthropic(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one non-streaming completion and return the model's message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each streamed chunk, skipping empty ones."""
        async for piece in self._llm.astream(messages):
            if not piece.content:
                continue
            yield piece.content

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
class OllamaService(LLMService):
    """Wrapper around ``ChatOllama`` for an Ollama server."""

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
    ):
        """Create the client; an empty URL or model falls back to ``settings``."""
        self.base_url = base_url if base_url else settings.OLLAMA_BASE_URL
        self.model = model if model else settings.OLLAMA_MODEL

        client_options = {
            "base_url": self.base_url,
            "model": self.model,
            "timeout": httpx.Timeout(120.0, connect=10.0),
        }
        self._llm = ChatOllama(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one non-streaming completion and return the model's message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each streamed chunk, skipping empty ones."""
        async for piece in self._llm.astream(messages):
            if not piece.content:
                continue
            yield piece.content

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
# Module-level cache holding the single LLM service instance.
_llm_instance: LLMService | None = None


def get_llm() -> LLMService:
    """Return the LLM service selected by ``settings.LLM_PROVIDER``.

    The instance is created on the first call and cached in
    ``_llm_instance``; later calls return that same object without
    re-reading the settings.

    Raises:
        ValueError: If the configured provider is not one of ``openai``,
            ``deepseek``, ``custom``, ``claude`` or ``ollama``.
    """
    global _llm_instance
    if _llm_instance is not None:
        return _llm_instance

    provider = settings.LLM_PROVIDER
    if provider in ("openai", "custom"):
        # Both use the OpenAI-compatible wrapper with values from settings.
        service = OpenAICompatibleService()
    elif provider == "deepseek":
        service = OpenAICompatibleService(
            base_url="https://api.deepseek.com/v1",
            model="deepseek-chat",
        )
    elif provider == "claude":
        service = ClaudeService()
    elif provider == "ollama":
        service = OllamaService()
    else:
        raise ValueError(f"Unknown LLM provider: {provider}")

    _llm_instance = service
    return service
|
||||
304
backend/app/services/memory_service.py
Normal file
304
backend/app/services/memory_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation in chronological order.

    The query selects the newest ``limit`` messages (ordered by
    ``created_at`` descending) and the result is reversed before it is
    returned. Ordering ascending and then applying ``LIMIT`` would always
    return the oldest messages, so a long conversation would never expose
    its latest turns.

    Args:
        db: Session used to run the query.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages, oldest first.
    """
    rows = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    history = list(rows.scalars().all())
    # Fetched newest-first so LIMIT keeps the tail; flip back for callers.
    history.reverse()
    return history
|
||||
|
||||
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Count the turns of a conversation, i.e. its messages with role ``"user"``.

    Returns:
        The number of matching messages, or ``0`` when the query yields
        nothing.
    """
    user_turns = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    rows = await db.execute(user_turns)
    total = rows.scalar()
    return total if total else 0
|
||||
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
# Number of turns after which a conversation is summarized.
SUMMARIZE_THRESHOLD = 8
# Maximum number of history turns the agent gets to see.
MAX_HISTORY_TURNS = 10
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation has enough new turns to be summarized.

    The current turn count is compared with the ``turn_count`` stored on
    the summary that covers the most turns (if any). A new summary is due
    once the difference reaches ``SUMMARIZE_THRESHOLD``.
    """
    turn_count = await get_conversation_turn_count(db, conversation_id)

    furthest_summary = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    rows = await db.execute(furthest_summary)
    latest = rows.scalar_one_or_none()

    # With no summary yet, every turn counts as un-summarized.
    already_covered = latest.turn_count if latest else 0
    return turn_count - already_covered >= SUMMARIZE_THRESHOLD
|
||||
|
||||
|
||||
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of the given messages.

    Each message is rendered as ``[role] content`` on its own line and
    sent together with a fixed system instruction. ``db`` and
    ``conversation_id`` are accepted but not used here.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    transcript_lines = [f"[{m.role}] {m.content}" for m in messages]
    history_text = "\n".join(transcript_lines)

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    instruction = (
        "你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    prompt = [
        SystemMessage(content=instruction),
        HumanMessage(content=history_text),
    ]
    reply = await llm.invoke(prompt)
    return reply.content.strip()
|
||||
|
||||
|
||||
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> MemorySummary:
    """Persist one conversation summary and return the stored row.

    The row is added, committed, and refreshed from the database before
    it is returned.
    """
    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[MemorySummary]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    ordered = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(ordered)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
# System prompt for long-term memory extraction. It asks the model to emit
# one "[type] content" line per memory (type: fact / preference / goal /
# habit) or the single word "无" when there is nothing to extract.
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的、可能对未来对话有帮助的信息,如:
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度

每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)

如果没有提取到任何记忆,回复"无"。
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[UserMemory]:
    """Extract long-term user memories from a conversation and store the new ones.

    The last ten messages are rendered as ``[role] content`` lines and
    sent to the LLM with ``EXTRACTION_PROMPT``. Each reply line that
    ``_parse_fact_line`` accepts becomes a ``UserMemory`` with importance
    5, unless a memory with identical content already exists for the user
    or appeared earlier in the same reply.

    Args:
        db: Session used for the duplicate check and the insert.
        user_id: Owner of the memories.
        conversation_id: Stored as ``source_conversation_id``.
        messages: Conversation messages; fewer than two yields nothing.

    Returns:
        The newly created memories (possibly empty). A commit is issued
        only when at least one memory was added.
    """
    if len(messages) < 2:
        return []

    history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    response = await llm.invoke([
        SystemMessage(content=EXTRACTION_PROMPT),
        HumanMessage(content=history_text),
    ])

    text = response.content.strip()
    if not text or text == "无":
        return []

    memories = []
    seen_contents = set()
    for line in text.split("\n"):
        parsed = _parse_fact_line(line)
        if not parsed:
            continue
        mem_type, content = parsed

        # The same fact may be repeated within one reply.
        if content in seen_contents:
            continue
        seen_contents.add(content)

        # Existence check only: LIMIT 1 + first() cannot raise when the
        # table already holds more than one identical row, which
        # scalar_one_or_none() would.
        existing = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.content == content,
            )
            .limit(1)
        )
        if existing.scalars().first() is not None:
            continue

        mem = UserMemory(
            user_id=user_id,
            memory_type=mem_type,
            content=content,
            importance=5,
            source_conversation_id=conversation_id,
        )
        db.add(mem)
        memories.append(mem)

    if memories:
        await db.commit()
    return memories
|
||||
|
||||
|
||||
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[UserMemory]:
    """Return the user's top memories ranked by importance, then recall count.

    ``query`` is accepted but not used for matching: the selection is a
    plain database ordering. Every returned memory has ``is_recalled``
    reset to ``False`` and the change is committed before returning.

    Returns:
        Up to ``top_k`` memories.
    """
    ranked = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
        .limit(top_k)
    )
    rows = await db.execute(ranked)
    memories = list(rows.scalars().all())

    # Clear the recall flag on everything handed back.
    for memory in memories:
        memory.is_recalled = False
    await db.commit()

    return memories
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
    """Record that a memory was used: set the flag, bump the counter, stamp the time.

    Nothing happens when no memory with ``memory_id`` exists.
    """
    rows = await db.execute(select(UserMemory).where(UserMemory.id == memory_id))
    memory = rows.scalar_one_or_none()
    if not memory:
        return

    memory.is_recalled = True
    # recall_count may still be unset on a row that was never recalled.
    memory.recall_count = (memory.recall_count or 0) + 1
    memory.last_recalled_at = datetime.utcnow()
    await db.commit()
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory text that is injected into the agent's system prompt.

    Two optional sections are produced and joined by a blank line:

    1. The user's recalled long-term memories (up to five); each one is
       also marked as recalled.
    2. The two most recent summaries of this conversation.

    Returns:
        The combined text, or an empty string when neither section has
        content.
    """
    sections = []

    # Section 1: long-term user memories.
    recalled = await recall_user_memories(db, user_id, current_query, top_k=5)
    if recalled:
        rendered = []
        for memory in recalled:
            rendered.append(f" [{memory.memory_type}] {memory.content}")
            await mark_memory_recalled(db, memory.id)
        sections.append("【用户记忆】\n" + "\n".join(rendered))

    # Section 2: mid-term conversation summaries, last two only.
    history = await get_summaries(db, conversation_id)
    if history:
        rendered = [
            f"[对话摘要{position}] {entry.summary_text}"
            for position, entry in enumerate(history[-2:], start=1)
        ]
        sections.append("【之前对话摘要】\n" + "\n".join(rendered))

    # Joining an empty list yields "", which covers the nothing-to-report case.
    return "\n\n".join(sections)
|
||||
|
||||
|
||||
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarize the conversation when it is due and extract user memories.

    The work is skipped when ``should_summarize`` says no or when fewer
    than three messages are loaded. Otherwise a summary is generated and
    saved together with the current turn count, and user memories are
    extracted from the same messages.

    Returns:
        ``True`` when a summary was produced and stored. ``False`` when
        the work was skipped or any step failed; failures are logged with
        their traceback rather than raised.
    """
    import logging

    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)

        # Long-term memories are extracted from the same messages.
        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        # Summarizing is a background nicety and must not break the chat
        # flow, but the failure should be visible in the logs.
        logging.getLogger(__name__).exception(
            "Auto-summarize failed for conversation %s", conversation_id
        )
        return False
|
||||
291
backend/app/services/scheduler_service.py
Normal file
291
backend/app/services/scheduler_service.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
定时任务服务 - APScheduler 调度器
|
||||
"""
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy import select, and_
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.models.forum import ForumPost
|
||||
from app.models.knowledge_graph import KGNode
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.graph_service import GraphService
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared scheduler instance for this module, configured for the
# Asia/Shanghai timezone.
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
|
||||
|
||||
# ===================== 定时任务函数 =====================
|
||||
|
||||
async def daily_task_analysis():
    """Nightly task analysis: build a report, post it, and queue follow-ups.

    Steps:
      1. Load every task whose ``updated_at`` is on or after yesterday's
         (UTC) date and split them into done / not done.
      2. Render a Markdown report and store it as a ``ForumPost`` in the
         ``discussion`` category.
      3. Create a "继续: ..." follow-up task for up to five of the
         highest-priority unfinished tasks.

    Tasks that are themselves follow-ups (title already starts with
    "继续: ") get no further follow-up; otherwise each nightly run would
    stack another prefix on yesterday's follow-ups and duplicate them.

    NOTE(review): "yesterday" comes from ``datetime.utcnow()`` while the
    job is scheduled in Asia/Shanghai, and the filter has no upper bound,
    so tasks updated today are included too. Confirm the intended window.
    """
    from datetime import datetime, timedelta

    logger.info("[Scheduler] 开始执行每日任务分析...")

    carry_over_prefix = "继续: "

    async with async_session() as db:
        yesterday = datetime.utcnow().date() - timedelta(days=1)
        report_date = yesterday.strftime('%Y-%m-%d')

        # Tasks touched since yesterday.
        result = await db.execute(
            select(Task).where(Task.updated_at >= yesterday)
        )
        tasks = result.scalars().all()

        completed = [t for t in tasks if t.status == "done"]
        pending = [t for t in tasks if t.status != "done"]
        # Sorted once; the report and the follow-up creation both use it.
        by_priority = sorted(pending, key=lambda t: t.priority, reverse=True)

        completed_lines = "\n".join(f"- {t.title}" for t in completed) or "无"
        pending_lines = "\n".join(
            f"- {t.title} (优先级: {t.priority})" for t in pending
        ) or "无"
        suggestion_lines = "\n".join(
            f"{rank}. {t.title}" for rank, t in enumerate(by_priority[:5], start=1)
        ) or "无待处理任务"

        report = "\n".join([
            f"## 每日任务报告 - {report_date}",
            "",
            "### 完成情况",
            f"- 总任务数: {len(tasks)}",
            f"- 已完成: {len(completed)}",
            f"- 未完成: {len(pending)}",
            "",
            "### 已完成任务",
            completed_lines,
            "",
            "### 未完成任务",
            pending_lines,
            "",
            "### 建议",
            "根据未完成任务,建议明天优先处理:",
            suggestion_lines,
            "",
        ])

        # Publish the report to the forum.
        db.add(ForumPost(
            title=f"每日报告 - {report_date}",
            content=report,
            category="discussion",
        ))

        # Queue follow-ups, skipping tasks that already are follow-ups.
        carry_over = [
            t for t in by_priority
            if not (t.title or "").startswith(carry_over_prefix)
        ][:5]
        for task in carry_over:
            db.add(Task(
                title=f"{carry_over_prefix}{task.title}",
                description=f"昨日未完成任务,优先级: {task.priority}",
                priority=task.priority,
                status="todo",
            ))

        await db.commit()
        logger.info(f"[Scheduler] 每日任务分析完成,完成 {len(completed)} 个任务")
|
||||
|
||||
|
||||
async def forum_scan_task():
    """Pick up unexecuted instruction posts from the forum and run them through the agent.

    At most five posts with category ``instruction`` and
    ``is_executed == False`` are processed per run. Each post's title and
    content are sent to ``AgentService.chat_simple``; on success the post
    is flagged as executed and the reply is stored on it. A failing post
    is logged and left untouched. All changes are committed once at the
    end.
    """
    logger.info("[Scheduler] 开始扫描论坛指令...")

    async with async_session() as db:
        open_instructions = (
            select(ForumPost)
            .where(
                ForumPost.category == "instruction",
                ForumPost.is_executed == False,  # noqa: E712
            )
            .limit(5)
        )
        posts = (await db.execute(open_instructions)).scalars().all()

        if not posts:
            logger.info("[Scheduler] 暂无待执行指令")
            return

        agent_svc = AgentService(db)
        executed_count = 0

        for post in posts:
            try:
                # Let the agent analyse and carry out the instruction.
                _conv_id, _msg_id, response = await agent_svc.chat_simple(
                    user_id=post.user_id,
                    message=f"请执行以下论坛指令: {post.title}。{post.content}",
                    conversation_id=None,
                )
                post.is_executed = True
                post.executed_response = response
                executed_count += 1
                logger.info(f"[Scheduler] 执行指令: {post.title}")
            except Exception as e:
                logger.error(f"[Scheduler] 执行指令失败 {post.title}: {e}")

        await db.commit()
        logger.info(f"[Scheduler] 论坛扫描完成,执行了 {executed_count} 个指令")
|
||||
|
||||
|
||||
async def graph_rebuild_task():
    """Rebuild the knowledge graph through ``GraphService.build_graph``.

    The call is made for ``user_id="default"`` with ``document_ids=None``.
    Any failure is logged and swallowed so the scheduler keeps running.

    NOTE(review): the original comment said only documents active in the
    last 7 days are processed, but no such filter is passed here. Confirm
    whether ``build_graph`` applies it internally.
    """
    logger.info("[Scheduler] 开始重建知识图谱...")

    async with async_session() as db:
        try:
            await GraphService(db).build_graph(user_id="default", document_ids=None)
            logger.info("[Scheduler] 知识图谱重建完成")
        except Exception as e:
            logger.error(f"[Scheduler] 知识图谱重建失败: {e}")
|
||||
|
||||
|
||||
async def tag_generation_task():
    """Incremental tag generation for all users with taggable graph nodes.

    Users are discovered from ``KGNode`` rows whose ``entity_type`` is
    ``conversation``, ``document`` or ``chunk``. For each user a
    ``TagService`` tags the content of the last day, and the ``"tagged"``
    counts are summed for the final log line. Any failure is logged and
    swallowed.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    logger.info("[Scheduler] 开始执行每日标签生成...")

    async with async_session() as db:
        try:
            llm_client = get_llm_client()

            user_rows = await db.execute(
                select(KGNode.user_id).distinct().where(
                    KGNode.entity_type.in_(["conversation", "document", "chunk"])
                )
            )
            user_ids = user_rows.scalars().all()

            total_tagged = 0
            for user_id in user_ids:
                # NOTE(review): tag_incremental_content is called without
                # await, as before. Confirm it is a synchronous method.
                outcome = TagService(db, llm_client).tag_incremental_content(user_id, days=1)
                total_tagged += outcome["tagged"]

            logger.info(f"[Scheduler] 每日标签生成完成,共标签化 {total_tagged} 个内容节点")
        except Exception as e:
            logger.error(f"[Scheduler] 每日标签生成失败: {e}")
|
||||
|
||||
|
||||
async def daily_todo_generation():
    """Generate today's to-do list for every active user.

    Each user is handled by ``generate_daily_todos``. A failure for one
    user is logged and does not stop the others; a failure outside the
    per-user loop is logged as a failure of the whole job.
    """
    from app.models.user import User
    from app.services.todo_service import generate_daily_todos

    logger.info("[Scheduler] 开始执行每日待办生成...")

    async with async_session() as db:
        try:
            active_users = select(User).where(User.is_active == True)  # noqa: E712
            users = (await db.execute(active_users)).scalars().all()

            for user in users:
                try:
                    await generate_daily_todos(user.id, db)
                    logger.info(f"[Scheduler] 为用户 {user.id} 生成今日待办完成")
                except Exception as e:
                    logger.error(f"[Scheduler] 用户 {user.id} 定时生成待办失败: {e}")

            logger.info(f"[Scheduler] 每日待办生成完成,共处理 {len(users)} 个用户")
        except Exception as e:
            logger.error(f"[Scheduler] 每日待办生成失败: {e}")
|
||||
|
||||
|
||||
# ===================== 调度器管理 =====================
|
||||
|
||||
def start_scheduler():
    """Register all periodic jobs and start the scheduler.

    Does nothing except log a warning when the scheduler is already
    running. Every job is added with ``replace_existing=True``.
    """
    if scheduler.running:
        logger.warning("[Scheduler] 调度器已在运行")
        return

    tz = "Asia/Shanghai"
    # (callable, trigger, job id, display name), registered in this order.
    job_table = [
        # Task analysis, daily at 00:30.
        (daily_task_analysis, CronTrigger(hour=0, minute=30, timezone=tz),
         "daily_task_analysis", "每日任务分析"),
        # Forum instruction scan, every hour.
        (forum_scan_task, IntervalTrigger(hours=1),
         "forum_scan", "论坛指令扫描"),
        # Knowledge-graph rebuild, daily at 03:00.
        (graph_rebuild_task, CronTrigger(hour=3, minute=0, timezone=tz),
         "graph_rebuild", "知识图谱重建"),
        # Tag generation, daily at 00:00.
        (tag_generation_task, CronTrigger(hour=0, minute=0, timezone=tz),
         "tag_generation", "每日标签生成"),
        # To-do generation, daily at 08:00.
        (daily_todo_generation, CronTrigger(hour=8, minute=0, timezone=tz),
         "daily_todo_generation", "每日待办生成"),
    ]
    for job_func, trigger, job_id, display_name in job_table:
        scheduler.add_job(
            job_func,
            trigger,
            id=job_id,
            name=display_name,
            replace_existing=True,
        )

    scheduler.start()
    logger.info("[Scheduler] 定时任务调度器已启动")
|
||||
|
||||
|
||||
def stop_scheduler():
    """Shut the scheduler down without waiting for running jobs; no-op when stopped."""
    if not scheduler.running:
        return
    scheduler.shutdown(wait=False)
    logger.info("[Scheduler] 定时任务调度器已停止")
|
||||
|
||||
|
||||
def get_scheduler_status() -> dict:
    """Report whether the scheduler runs and which jobs are registered.

    Returns:
        ``{"status": "stopped", "jobs": []}`` when the scheduler is not
        running. Otherwise ``{"status": "running", "jobs": [...]}`` with
        one entry per job holding its ``id``, ``name`` and ``next_run``
        (string form of the next run time, or ``None``).
    """
    if not scheduler.running:
        return {"status": "stopped", "jobs": []}

    job_summaries = [
        {
            "id": job.id,
            "name": job.name,
            "next_run": str(job.next_run_time) if job.next_run_time else None,
        }
        for job in scheduler.get_jobs()
    ]
    return {"status": "running", "jobs": job_summaries}
|
||||
140
backend/app/services/settings_service.py
Normal file
140
backend/app/services/settings_service.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import verify_password, get_password_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_user_settings(user_id: str, db: AsyncSession) -> Optional[dict]:
    """Fetch a user's complete settings.

    Args:
        user_id: Primary key of the user.
        db: Session used for the lookup.

    Returns:
        ``None`` when the user does not exist. Otherwise a dict with
        ``profile`` (the ``User`` row itself), ``llm_config`` and
        ``scheduler_config``; the two configs default to ``{}`` when
        unset.
    """
    rows = await db.execute(select(User).where(User.id == user_id))
    user = rows.scalar_one_or_none()
    if not user:
        return None
    return {
        "profile": user,
        "llm_config": user.llm_config or {},
        "scheduler_config": user.scheduler_config or {},
    }
|
||||
|
||||
|
||||
async def update_user_profile(
    user_id: str,
    db: AsyncSession,
    full_name: Optional[str] = None,
    password: Optional[str] = None,
    current_password: Optional[str] = None
) -> User:
    """Update a user's display name and/or password.

    A new password is only accepted together with the correct current
    password. Empty or ``None`` values for ``full_name`` and ``password``
    leave the corresponding field unchanged.

    Returns:
        The refreshed ``User`` row.

    Raises:
        ValueError: If the user does not exist, or if a password change
            is requested without a valid ``current_password``.
    """
    rows = await db.execute(select(User).where(User.id == user_id))
    user = rows.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    if password:
        # verify_password is only reached when a current password was given.
        current_ok = bool(current_password) and verify_password(
            current_password, user.hashed_password
        )
        if not current_ok:
            raise ValueError("当前密码错误")
        user.hashed_password = get_password_hash(password)

    if full_name:
        user.full_name = full_name

    await db.commit()
    await db.refresh(user)
    return user
|
||||
|
||||
|
||||
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Merge ``config`` into the user's stored LLM configuration.

    Merge rules per key (keys whose value is ``None`` are ignored):
      * dict value over an existing dict: shallow merge, new keys win;
      * anything else (lists, scalars, dict over a non-dict): replace.

    The merged result is built in a new dict and then assigned. Mutating
    the loaded dict in place and re-assigning the same object is not seen
    as a change by the ORM on a plain JSON column, so updates to an
    existing configuration would not be written.

    Returns:
        The merged configuration that was stored.

    Raises:
        ValueError: If the user does not exist.
    """
    rows = await db.execute(select(User).where(User.id == user_id))
    user = rows.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    # Copy first so the value loaded by the ORM stays untouched.
    merged = dict(user.llm_config or {})
    for key, value in config.items():
        if value is None:
            continue
        existing = merged.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = {**existing, **value}
        else:
            merged[key] = value

    user.llm_config = merged
    await db.commit()
    return merged
|
||||
|
||||
|
||||
async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Overlay ``config`` onto the user's stored scheduler configuration.

    Keys whose value is ``None`` are ignored; every other key replaces
    the stored value. The result is built in a new dict, because an
    in-place mutation of the loaded JSON value followed by re-assigning
    the same object is not detected as a change by the ORM on a plain
    JSON column.

    Returns:
        The configuration that was stored.

    Raises:
        ValueError: If the user does not exist.
    """
    rows = await db.execute(select(User).where(User.id == user_id))
    user = rows.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    overrides = {key: value for key, value in config.items() if value is not None}
    merged = {**(user.scheduler_config or {}), **overrides}

    user.scheduler_config = merged
    await db.commit()
    return merged
|
||||
|
||||
|
||||
async def test_llm_connection(
    provider: str,
    model: str,
    base_url: str,
    api_key: str
) -> dict:
    """Check that an LLM endpoint is reachable by sending a one-word prompt.

    A temporary client is built for ``provider`` (``openai``, ``claude``,
    ``ollama`` or ``deepseek``) with a 30-second timeout and asked to
    answer ``"Hi"``.

    Returns:
        ``{"success": True, "message": ...}`` with the first 50 characters
        of the reply, or ``{"success": False, "error": ...}`` for an
        unsupported provider or any exception raised along the way.
    """
    try:
        if provider in ("openai", "deepseek"):
            from langchain_openai import ChatOpenAI

            # DeepSeek falls back to its public endpoint; OpenAI to the
            # client's own default.
            fallback_url = "https://api.deepseek.com/v1" if provider == "deepseek" else None
            llm = ChatOpenAI(
                api_key=api_key,
                model=model,
                base_url=base_url or fallback_url,
                timeout=30
            )
        elif provider == "claude":
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(api_key=api_key, model=model, timeout=30)
        elif provider == "ollama":
            from langchain_ollama import ChatOllama

            llm = ChatOllama(
                base_url=base_url or "http://localhost:11434",
                model=model,
                timeout=30
            )
        else:
            return {"success": False, "error": f"不支持的 provider: {provider}"}

        # Minimal round trip to prove the credentials and endpoint work.
        from langchain_core.messages import HumanMessage

        response = await llm.ainvoke([HumanMessage(content="Hi")])
        return {"success": True, "message": f"连接成功,模型响应: {response.content[:50]}..."}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||||
278
backend/app/services/stats_service.py
Normal file
278
backend/app/services/stats_service.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.models.document import Document
|
||||
|
||||
|
||||
class StatsService:
|
||||
def __init__(self, db: Session):
    """Keep the database session used by the statistics queries."""
    self.db = db
|
||||
|
||||
def get_system_health(self) -> dict:
|
||||
"""获取系统健康指标"""
|
||||
uptime_seconds = int(time.time() - psutil.boot_time())
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
mem = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
|
||||
return {
|
||||
"uptime_seconds": uptime_seconds,
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_used_mb": round(mem.used / (1024 * 1024), 1),
|
||||
"memory_total_mb": round(mem.total / (1024 * 1024), 1),
|
||||
"memory_percent": mem.percent,
|
||||
"disk_used_gb": round(disk.used / (1024 * 1024 * 1024), 1),
|
||||
"disk_total_gb": round(disk.total / (1024 * 1024 * 1024), 1),
|
||||
"disk_percent": disk.percent,
|
||||
"active_users_24h": 0, # 需要 User 表的 updated_at
|
||||
}
|
||||
|
||||
def _get_daily_stats(self, model, date_column, user_id=None, days=30) -> list:
|
||||
"""通用每日统计查询"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
query = self.db.query(
|
||||
func.date(date_column).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(date_column >= cutoff)
|
||||
|
||||
if user_id and hasattr(model, 'user_id'):
|
||||
query = query.filter(model.user_id == user_id)
|
||||
|
||||
query = query.group_by(func.date(date_column)).order_by(func.date(date_column))
|
||||
results = query.all()
|
||||
return [{"date": str(r.date), "count": r.count} for r in results]
|
||||
|
||||
def get_conversation_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取对话统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
daily_conversations = self._get_daily_stats(
|
||||
Conversation, Conversation.created_at, user_id, days
|
||||
)
|
||||
daily_messages = self._get_daily_stats(
|
||||
Message, Message.created_at, user_id, days
|
||||
)
|
||||
|
||||
# Daily tokens
|
||||
input_query = self.db.query(
|
||||
func.date(Message.created_at).label('date'),
|
||||
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
|
||||
).filter(
|
||||
Message.created_at >= cutoff,
|
||||
Message.role == 'user'
|
||||
)
|
||||
if user_id:
|
||||
input_query = input_query.join(Conversation).filter(Conversation.user_id == user_id)
|
||||
input_results = input_query.group_by(func.date(Message.created_at)).all()
|
||||
|
||||
output_query = self.db.query(
|
||||
func.date(Message.created_at).label('date'),
|
||||
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
|
||||
).filter(
|
||||
Message.created_at >= cutoff,
|
||||
Message.role == 'assistant'
|
||||
)
|
||||
if user_id:
|
||||
output_query = output_query.join(Conversation).filter(Conversation.user_id == user_id)
|
||||
output_results = output_query.group_by(func.date(Message.created_at)).all()
|
||||
|
||||
daily_input_tokens = [{"date": str(r.date), "input_tokens": r.tokens} for r in input_results]
|
||||
daily_output_tokens = [{"date": str(r.date), "output_tokens": r.tokens} for r in output_results]
|
||||
|
||||
return {
|
||||
"daily_conversations": daily_conversations,
|
||||
"daily_messages": daily_messages,
|
||||
"daily_input_tokens": daily_input_tokens,
|
||||
"daily_output_tokens": daily_output_tokens,
|
||||
"totals": {
|
||||
"conversations": sum(c["count"] for c in daily_conversations),
|
||||
"messages": sum(m["count"] for m in daily_messages),
|
||||
"input_tokens": sum(t["input_tokens"] for t in daily_input_tokens),
|
||||
"output_tokens": sum(t["output_tokens"] for t in daily_output_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
def get_knowledge_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取知识库统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# New tags
|
||||
tag_query = self.db.query(
|
||||
func.date(KGNode.created_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
KGNode.created_at >= cutoff,
|
||||
KGNode.entity_type == 'tag'
|
||||
)
|
||||
if user_id:
|
||||
tag_query = tag_query.filter(KGNode.user_id == user_id)
|
||||
tag_results = tag_query.group_by(func.date(KGNode.created_at)).all()
|
||||
daily_new_tags = [{"date": str(r.date), "count": r.count} for r in tag_results]
|
||||
|
||||
daily_documents = self._get_daily_stats(
|
||||
Document, Document.created_at, user_id, days
|
||||
)
|
||||
daily_tag_relations = self._get_daily_stats(
|
||||
KGEdge, KGEdge.created_at, user_id, days
|
||||
)
|
||||
|
||||
return {
|
||||
"daily_new_tags": daily_new_tags,
|
||||
"daily_documents": daily_documents,
|
||||
"daily_knowledge_queries": [],
|
||||
"daily_tag_relations": daily_tag_relations,
|
||||
"totals": {
|
||||
"new_tags": sum(t["count"] for t in daily_new_tags),
|
||||
"documents": sum(d["count"] for d in daily_documents),
|
||||
"tag_relations": sum(r["count"] for r in daily_tag_relations),
|
||||
}
|
||||
}
|
||||
|
||||
def get_kanban_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取看板统计数据"""
|
||||
daily_new_tasks = self._get_daily_stats(
|
||||
Task, Task.created_at, user_id, days
|
||||
)
|
||||
|
||||
# Completed tasks
|
||||
completed_query = self.db.query(
|
||||
func.date(Task.completed_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
Task.completed_at >= datetime.utcnow() - timedelta(days=days),
|
||||
Task.status == TaskStatus.DONE
|
||||
)
|
||||
if user_id:
|
||||
completed_query = completed_query.filter(Task.user_id == user_id)
|
||||
completed_results = completed_query.group_by(func.date(Task.completed_at)).all()
|
||||
daily_completed_tasks = [{"date": str(r.date), "count": r.count} for r in completed_results]
|
||||
|
||||
# Current pending
|
||||
pending_query = self.db.query(func.count(Task.id)).filter(Task.status == TaskStatus.TODO)
|
||||
if user_id:
|
||||
pending_query = pending_query.filter(Task.user_id == user_id)
|
||||
current_pending_tasks = pending_query.scalar() or 0
|
||||
|
||||
# Completion rate
|
||||
daily_new_dict = {d["date"]: d["count"] for d in daily_new_tasks}
|
||||
daily_completed_dict = {d["date"]: d["count"] for d in daily_completed_tasks}
|
||||
all_dates = set(daily_new_dict.keys()) | set(daily_completed_dict.keys())
|
||||
daily_completion_rate = []
|
||||
for date in sorted(all_dates):
|
||||
new = daily_new_dict.get(date, 0)
|
||||
completed = daily_completed_dict.get(date, 0)
|
||||
rate = (completed / new * 100) if new > 0 else 0
|
||||
daily_completion_rate.append({"date": date, "rate": round(rate, 1)})
|
||||
|
||||
return {
|
||||
"daily_new_tasks": daily_new_tasks,
|
||||
"daily_completed_tasks": daily_completed_tasks,
|
||||
"daily_completion_rate": daily_completion_rate,
|
||||
"current_pending_tasks": current_pending_tasks,
|
||||
"totals": {
|
||||
"new_tasks": sum(t["count"] for t in daily_new_tasks),
|
||||
"completed_tasks": sum(c["count"] for c in daily_completed_tasks),
|
||||
}
|
||||
}
|
||||
|
||||
def get_community_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取社区统计数据"""
|
||||
daily_posts = self._get_daily_stats(
|
||||
ForumPost, ForumPost.created_at, user_id, days
|
||||
)
|
||||
daily_replies = self._get_daily_stats(
|
||||
ForumReply, ForumReply.created_at, user_id, days
|
||||
)
|
||||
|
||||
# AI executions
|
||||
ai_query = self.db.query(
|
||||
func.date(ForumPost.updated_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
ForumPost.updated_at >= datetime.utcnow() - timedelta(days=days),
|
||||
ForumPost.is_executed == True
|
||||
)
|
||||
if user_id:
|
||||
ai_query = ai_query.filter(ForumPost.user_id == user_id)
|
||||
ai_results = ai_query.group_by(func.date(ForumPost.updated_at)).all()
|
||||
daily_ai_executions = [{"date": str(r.date), "count": r.count} for r in ai_results]
|
||||
|
||||
return {
|
||||
"daily_posts": daily_posts,
|
||||
"daily_replies": daily_replies,
|
||||
"daily_ai_executions": daily_ai_executions,
|
||||
"daily_agent_calls": [],
|
||||
"totals": {
|
||||
"posts": sum(p["count"] for p in daily_posts),
|
||||
"replies": sum(r["count"] for r in daily_replies),
|
||||
"ai_executions": sum(a["count"] for a in daily_ai_executions),
|
||||
}
|
||||
}
|
||||
|
||||
def get_personal_insights(self, user_id: str) -> dict:
|
||||
"""获取个人洞察"""
|
||||
# Hourly activity
|
||||
hourly_query = self.db.query(
|
||||
func.extract('hour', Conversation.created_at).label('hour'),
|
||||
func.count().label('count')
|
||||
).filter(Conversation.user_id == user_id).group_by(
|
||||
func.extract('hour', Conversation.created_at)
|
||||
)
|
||||
hourly_results = hourly_query.all()
|
||||
hourly_activity = [{"hour": int(r.hour), "count": r.count} for r in hourly_results]
|
||||
|
||||
# Top tags
|
||||
tag_query = self.db.query(
|
||||
KGNode.properties_["tag_path"].astext.label('tag_path'),
|
||||
func.count(KGEdge.id).label('usage_count')
|
||||
).join(
|
||||
KGEdge, KGEdge.target_id == KGNode.id
|
||||
).filter(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.entity_type == 'tag',
|
||||
KGEdge.relation_type == 'has_tag'
|
||||
).group_by(
|
||||
KGNode.properties_["tag_path"].astext
|
||||
).order_by(func.count(KGEdge.id).desc()).limit(5)
|
||||
top_tags = [{"tag_path": r.tag_path, "usage_count": r.usage_count} for r in tag_query.all()]
|
||||
|
||||
# Token trend
|
||||
now = datetime.utcnow()
|
||||
this_month_start = datetime(now.year, now.month, 1)
|
||||
last_month_end = this_month_start - timedelta(days=1)
|
||||
last_month_start = datetime(last_month_end.year, last_month_end.month, 1)
|
||||
|
||||
this_month_tokens = self.db.query(
|
||||
func.coalesce(func.sum(Message.tokens_used), 0)
|
||||
).join(Conversation).filter(
|
||||
Conversation.user_id == user_id,
|
||||
Message.created_at >= this_month_start,
|
||||
Message.role == 'assistant'
|
||||
).scalar() or 0
|
||||
|
||||
last_month_tokens = self.db.query(
|
||||
func.coalesce(func.sum(Message.tokens_used), 0)
|
||||
).join(Conversation).filter(
|
||||
Conversation.user_id == user_id,
|
||||
Message.created_at >= last_month_start,
|
||||
Message.created_at < this_month_start,
|
||||
Message.role == 'assistant'
|
||||
).scalar() or 0
|
||||
|
||||
token_trend_percent = 0
|
||||
if last_month_tokens > 0:
|
||||
token_trend_percent = round((this_month_tokens - last_month_tokens) / last_month_tokens * 100, 1)
|
||||
|
||||
return {
|
||||
"hourly_activity": hourly_activity,
|
||||
"top_tags": top_tags,
|
||||
"token_trend_percent": token_trend_percent,
|
||||
"this_month_tokens": this_month_tokens,
|
||||
"last_month_tokens": last_month_tokens,
|
||||
}
|
||||
239
backend/app/services/tag_service.py
Normal file
239
backend/app/services/tag_service.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import json
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
|
||||
# Prompt template for tag extraction. The only placeholder is ``{content}``.
# NOTE: the JSON example below contains literal ``{`` / ``}`` characters, so
# this template cannot be filled with ``str.format`` (it raises KeyError on the
# example object); substitute the placeholder with ``str.replace`` instead.
TAG_EXTRACTION_PROMPT = """你是一个知识分类专家。从给定内容中提取标签。

要求:
1. 标签采用层级路径格式,如 "编程语言/Python"、"后端/框架/FastAPI"
2. 层级深度 1-4 层,避免过深
3. 每个内容提取 3-8 个标签
4. 标签应覆盖:主题、技术栈、领域、任务类型等维度

输出格式(JSON数组):
[
{"path": "编程语言/Python", "description": "Python编程语言相关"},
{"path": "后端/框架/FastAPI", "description": "FastAPI框架相关"}
]

内容:
{content}
"""

# Prompt template for relation analysis between tags. The only placeholder is
# ``{tag_paths}``; the same ``str.format`` caveat applies because of the
# literal braces in the JSON example.
TAG_RELATION_PROMPT = """分析以下标签之间的关系,输出 JSON 数组:

关系类型:
- parent_of: 父子关系(上级包含下级)
- related_to: 语义相关(但不是父子)
- synonym_of: 同义词

标签列表:
{tag_paths}

输出格式:
[
{"source": "标签1", "target": "标签2", "relation": "related_to", "weight": 0.8},
{"source": "标签1", "target": "标签3", "relation": "parent_of", "weight": 1.0}
]
"""
|
||||
|
||||
|
||||
class TagService:
|
||||
def __init__(self, db: Session, llm_client):
|
||||
self.db = db
|
||||
self.llm_client = llm_client
|
||||
|
||||
def extract_tags_from_content(self, content: str, user_id: str) -> list[dict]:
|
||||
"""从内容中提取标签"""
|
||||
response = self.llm_client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一个知识分类专家。"},
|
||||
{"role": "user", "content": TAG_EXTRACTION_PROMPT.format(content=content)}
|
||||
],
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
return result.get("tags", [])
|
||||
|
||||
def parse_tag_path(self, path: str) -> tuple[str, int, str | None]:
|
||||
"""解析标签路径,返回 (short_name, level, parent_path)"""
|
||||
parts = path.strip("/").split("/")
|
||||
short_name = parts[-1]
|
||||
level = len(parts)
|
||||
parent_path = "/".join(parts[:-1]) if level > 1 else None
|
||||
return short_name, level, parent_path
|
||||
|
||||
def get_or_create_tag_node(self, tag_info: dict, user_id: str) -> KGNode:
|
||||
"""获取或创建标签节点"""
|
||||
path = tag_info["path"]
|
||||
existing = self.db.query(KGNode).filter(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.properties_["tag_path"].astext == path
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
short_name, level, parent_path = self.parse_tag_path(path)
|
||||
|
||||
node = KGNode(
|
||||
user_id=user_id,
|
||||
name=short_name,
|
||||
entity_type="tag",
|
||||
description=tag_info.get("description"),
|
||||
properties_={
|
||||
"tag_path": path,
|
||||
"short_name": short_name,
|
||||
"level": level,
|
||||
"parent_path": parent_path,
|
||||
"description": tag_info.get("description"),
|
||||
"color": tag_info.get("color"),
|
||||
},
|
||||
importance=0.5
|
||||
)
|
||||
self.db.add(node)
|
||||
self.db.flush()
|
||||
return node
|
||||
|
||||
def ensure_parent_tags(self, path: str, user_id: str) -> list[KGNode]:
|
||||
"""确保父路径标签存在"""
|
||||
parts = path.strip("/").split("/")
|
||||
nodes = []
|
||||
for i in range(1, len(parts)):
|
||||
parent_path = "/".join(parts[:i])
|
||||
tag_info = {"path": parent_path, "description": None}
|
||||
node = self.get_or_create_tag_node(tag_info, user_id)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def create_tag_relations(self, tag_paths: list[str], user_id: str) -> list[KGEdge]:
|
||||
"""分析并创建标签之间的关系边"""
|
||||
path_to_node = {}
|
||||
for path in tag_paths:
|
||||
node = self.db.query(KGNode).filter(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.properties_["tag_path"].astext == path,
|
||||
KGNode.entity_type == "tag"
|
||||
).first()
|
||||
if node:
|
||||
path_to_node[path] = node
|
||||
|
||||
response = self.llm_client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一个知识图谱专家。"},
|
||||
{"role": "user", "content": TAG_RELATION_PROMPT.format(tag_paths=json.dumps(tag_paths))}
|
||||
],
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
relations = result.get("relations", [])
|
||||
|
||||
edges = []
|
||||
for rel in relations:
|
||||
source_node = path_to_node.get(rel["source"])
|
||||
target_node = path_to_node.get(rel["target"])
|
||||
if source_node and target_node:
|
||||
existing = self.db.query(KGEdge).filter(
|
||||
KGEdge.source_id == source_node.id,
|
||||
KGEdge.target_id == target_node.id
|
||||
).first()
|
||||
if not existing:
|
||||
edge = KGEdge(
|
||||
source_id=source_node.id,
|
||||
target_id=target_node.id,
|
||||
relation_type=rel["relation"],
|
||||
weight=rel.get("weight", 0.5)
|
||||
)
|
||||
self.db.add(edge)
|
||||
edges.append(edge)
|
||||
|
||||
self.db.flush()
|
||||
return edges
|
||||
|
||||
def tag_content(self, content: str, user_id: str, content_node: KGNode) -> list[KGNode]:
|
||||
"""为内容节点打标签"""
|
||||
tag_infos = self.extract_tags_from_content(content, user_id)
|
||||
tag_paths = [t["path"] for t in tag_infos]
|
||||
|
||||
tag_nodes = []
|
||||
for tag_info in tag_infos:
|
||||
node = self.get_or_create_tag_node(tag_info, user_id)
|
||||
tag_nodes.append(node)
|
||||
self.ensure_parent_tags(tag_info["path"], user_id)
|
||||
|
||||
# 创建 has_tag 边
|
||||
for tag_node in tag_nodes:
|
||||
existing_edge = self.db.query(KGEdge).filter(
|
||||
KGEdge.source_id == content_node.id,
|
||||
KGEdge.target_id == tag_node.id,
|
||||
KGEdge.relation_type == "has_tag"
|
||||
).first()
|
||||
if not existing_edge:
|
||||
edge = KGEdge(
|
||||
source_id=content_node.id,
|
||||
target_id=tag_node.id,
|
||||
relation_type="has_tag",
|
||||
weight=1.0
|
||||
)
|
||||
self.db.add(edge)
|
||||
|
||||
tag_node_ids = [n.id for n in tag_nodes]
|
||||
current_tag_ids = content_node.properties_.get("tag_node_ids", []) if content_node.properties_ else []
|
||||
content_node.properties_["tag_node_ids"] = list(set(current_tag_ids + tag_node_ids))
|
||||
|
||||
if len(tag_paths) >= 2:
|
||||
self.create_tag_relations(tag_paths, user_id)
|
||||
|
||||
self.db.commit()
|
||||
return tag_nodes
|
||||
|
||||
def tag_incremental_content(self, user_id: str, days: int = 1) -> dict:
|
||||
"""
|
||||
增量打标签 - 只对最近新增/更新的内容节点打标签
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
content_nodes = self.db.query(KGNode).filter(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.entity_type.in_(["conversation", "document", "chunk"]),
|
||||
KGNode.updated_at >= cutoff_date
|
||||
).all()
|
||||
|
||||
untagged = [
|
||||
n for n in content_nodes
|
||||
if not n.properties_.get("tag_node_ids")
|
||||
]
|
||||
|
||||
tagged_count = 0
|
||||
for node in untagged:
|
||||
content = node.description or ""
|
||||
try:
|
||||
self.tag_content(content, user_id, node)
|
||||
tagged_count += 1
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return {"total": len(untagged), "tagged": tagged_count}
|
||||
|
||||
def get_related_content(self, tag_node_ids: list[str], user_id: str, limit: int = 10) -> list[tuple[KGNode, float]]:
|
||||
"""通过标签找相关内容"""
|
||||
edges = self.db.query(KGEdge).filter(
|
||||
KGEdge.target_id.in_(tag_node_ids),
|
||||
KGEdge.relation_type == "has_tag"
|
||||
).all()
|
||||
|
||||
content_weights: dict[str, float] = {}
|
||||
for edge in edges:
|
||||
content_weights[edge.source_id] = content_weights.get(edge.source_id, 0) + edge.weight
|
||||
|
||||
content_ids = list(content_weights.keys())
|
||||
content_nodes = self.db.query(KGNode).filter(
|
||||
KGNode.id.in_(content_ids),
|
||||
KGNode.entity_type.in_(["conversation", "document", "chunk"])
|
||||
).all()
|
||||
|
||||
return [(node, content_weights[node.id]) for node in content_nodes]
|
||||
165
backend/app/services/todo_service.py
Normal file
165
backend/app/services/todo_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_daily_todos(user_id: str, db: AsyncSession) -> list[DailyTodo]:
    """Build today's todo list for a user.

    Two sources are combined, in this order:
    1. open kanban tasks (at most 20), via ``_import_kanban_tasks``;
    2. items extracted from the previous day's conversations (at most 3),
       via ``_analyze_chat_history``.

    Both helpers receive the previous day's date as an ISO string.
    """
    previous_day = (date.today() - timedelta(days=1)).isoformat()

    from_kanban = await _import_kanban_tasks(user_id, previous_day, db)
    from_chat = await _analyze_chat_history(user_id, previous_day, db)

    return [*from_kanban, *from_chat]
|
||||
|
||||
|
||||
async def _import_kanban_tasks(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
    """Copy the user's unfinished kanban tasks into today's todo list.

    Selects the 20 most recently created tasks of the user whose status is not
    ``TaskStatus.DONE`` and creates one ``DailyTodo`` per task, dated today.
    ``date_str`` is accepted but not used by the query.

    The new todos are committed and refreshed before being returned.
    """
    open_tasks_query = (
        select(Task)
        .where(Task.user_id == user_id, Task.status != TaskStatus.DONE)
        .order_by(Task.created_at.desc())
        .limit(20)
    )
    result = await db.execute(open_tasks_query)

    imported = [
        DailyTodo(
            user_id=user_id,
            title=task.title,
            source=TodoSource.AI_KANBAN,
            source_detail=f"看板:{task.title}",
            source_ref_id=task.id,
            todo_date=date.today().isoformat(),
        )
        for task in result.scalars().all()
    ]
    for entry in imported:
        db.add(entry)

    # Nothing to persist: skip the commit entirely.
    if not imported:
        return imported

    await db.commit()
    for entry in imported:
        await db.refresh(entry)
    return imported
|
||||
|
||||
|
||||
async def _analyze_chat_history(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
|
||||
"""分析前一天对话,提取待办事项"""
|
||||
try:
|
||||
# 查询前一天创建的对话
|
||||
conv_q = select(Conversation).where(
|
||||
Conversation.user_id == user_id,
|
||||
).order_by(Conversation.created_at.desc()).limit(10)
|
||||
convs = (await db.execute(conv_q)).scalars().all()
|
||||
|
||||
# 过滤出昨天的对话
|
||||
yesterday_convs = []
|
||||
for conv in convs:
|
||||
created = conv.created_at
|
||||
if hasattr(created, 'date'):
|
||||
created_date = created.date() if hasattr(created, 'date') else created
|
||||
else:
|
||||
created_date = datetime.fromisoformat(str(created)).date()
|
||||
|
||||
if str(created_date) == date_str or (created + timedelta(hours=8)).strftime('%Y-%m-%d') == date_str:
|
||||
yesterday_convs.append(conv)
|
||||
|
||||
if not yesterday_convs:
|
||||
return []
|
||||
|
||||
# 收集消息内容(限制2000字)
|
||||
messages_content = []
|
||||
for conv in yesterday_convs:
|
||||
msg_q = select(Message).where(
|
||||
Message.conversation_id == conv.id
|
||||
).order_by(Message.created_at.asc()).limit(50)
|
||||
msgs = (await db.execute(msg_q)).scalars().all()
|
||||
for msg in msgs:
|
||||
if msg.content:
|
||||
messages_content.append(f"[{msg.role}]: {msg.content[:500]}")
|
||||
|
||||
if not messages_content:
|
||||
return []
|
||||
|
||||
full_text = "\n".join(messages_content)[:2000]
|
||||
|
||||
# 调用 LLM 分析
|
||||
prompt = f"""你是一个任务规划助手。请分析以下对话记录,提取其中用户想要完成但尚未明确完成的事项。
|
||||
|
||||
要求:
|
||||
- 最多提取 3 条
|
||||
- 每条格式:{{"title": "事项描述(50字以内)", "reason": "来源说明(60字以内)"}}
|
||||
- 只提取用户明确表达过需求但还未完成的事项
|
||||
- 如果没有可提取的内容,返回空数组 []
|
||||
|
||||
对话记录:
|
||||
{full_text}
|
||||
|
||||
返回 JSON 数组:"""
|
||||
|
||||
llm = get_llm()
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个任务规划助手。"),
|
||||
HumanMessage(content=prompt),
|
||||
])
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
# 提取 JSON 数组
|
||||
start = content.find('[')
|
||||
end = content.rfind(']') + 1
|
||||
if start != -1 and end > start:
|
||||
items = json.loads(content[start:end])
|
||||
else:
|
||||
items = []
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(f"LLM 返回格式异常,跳过对话分析: {content[:200]}")
|
||||
items = []
|
||||
|
||||
if not items:
|
||||
return []
|
||||
|
||||
todos = []
|
||||
for item in items[:3]:
|
||||
todo = DailyTodo(
|
||||
user_id=user_id,
|
||||
title=item.get("title", "")[:500],
|
||||
source=TodoSource.AI_CHAT,
|
||||
source_detail=f"对话:{item.get('reason', '')[:60]}",
|
||||
todo_date=date.today().isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
todos.append(todo)
|
||||
|
||||
if todos:
|
||||
await db.commit()
|
||||
for todo in todos:
|
||||
await db.refresh(todo)
|
||||
|
||||
return todos
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"对话分析失败: {e}")
|
||||
return []
|
||||
Reference in New Issue
Block a user