Add FastAPI backend with agent system

This commit is contained in:
2026-03-21 10:13:29 +08:00
parent ed6bab59fe
commit 6ffa07adde
82 changed files with 11138 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# Services - import specific classes directly when needed
# e.g.: from app.services.agent_service import AgentService

View File

@@ -0,0 +1,261 @@
"""
Jarvis Agent 服务层
负责 LangGraph Agent 的调用、流式输出、对话历史管理
"""
import json
import uuid
from datetime import datetime
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_core.messages import HumanMessage, AIMessage
from app.models.conversation import Conversation, Message
from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.services import memory_service
class AgentService:
"""对话 Agent 服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def chat(
self,
user_id: str,
message: str,
conversation_id: str | None = None,
) -> tuple[str, str, AsyncGenerator[str, None]]:
"""
处理对话请求(流式)
Returns:
(conversation_id, message_id, response_stream)
"""
# 获取或创建对话
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
)
conv = result.scalar_one_or_none()
else:
conv = None
if not conv:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
else:
conversation_id = conv.id
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
content=message,
)
self.db.add(user_msg)
await self.db.commit()
await self.db.refresh(user_msg)
# 预创建助手消息(后续更新内容)
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
content="",
model="jarvis",
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
# 调用 LangGraph Agent
async def run_agent():
set_current_user(user_id)
try:
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
}
collected = ""
async for event in graph.astream_events(langgraph_state, version="v2"):
kind = event.get("event")
if kind == "on_chat_model_end":
content = event.get("data", {}).get("output", {})
if isinstance(content, dict):
content = content.get("content", "")
if content:
delta = content[len(collected):]
if delta:
collected += delta
yield delta
elif kind == "on_tool_end":
name = event.get("name", "")
yield f"\n[工具执行: {name}]\n"
except Exception as e:
yield f"\n执行出错: {str(e)}"
finally:
clear_current_user()
# 异步触发自动摘要和记忆提取(不阻塞响应)
import asyncio
try:
loop = asyncio.get_running_loop()
loop.create_task(
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
)
except Exception:
pass
# 最终更新数据库中的消息内容
if collected:
try:
result2 = await self.db.execute(
select(Message).where(Message.id == assistant_msg.id)
)
msg = result2.scalar_one_or_none()
if msg:
msg.content = collected
await self.db.commit()
except Exception:
pass
return conversation_id, assistant_msg.id, run_agent()
async def chat_simple(
self,
user_id: str,
message: str,
conversation_id: str | None = None,
file_ids: list[str] | None = None,
) -> tuple[str, str, str]:
"""
简单同步版对话(无流式)
Returns:
(conversation_id, message_id, response_content)
"""
# 获取或创建对话
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
)
conv = result.scalar_one_or_none()
else:
conv = None
if not conv:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
doc_svc = DocumentService(self.db)
for file_id in file_ids:
content = await doc_svc.get_document_content(user_id, file_id)
if content:
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
# 将文件上下文添加到消息
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
content=message,
attachments=[{"file_ids": file_ids}] if file_ids else None,
)
self.db.add(user_msg)
await self.db.commit()
await self.db.refresh(user_msg)
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
# 调用 LangGraph Agent
set_current_user(user_id)
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
}
try:
result_state = await graph.ainvoke(langgraph_state)
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
except Exception as e:
response_content = f"抱歉,发生错误: {str(e)}"
finally:
clear_current_user()
# 异步触发自动摘要
import asyncio
try:
asyncio.get_running_loop().create_task(
memory_service.try_auto_summarize(self.db, user_id, conversation_id)
)
except Exception:
pass
# 保存助手消息
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
content=response_content,
model="jarvis",
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
return conversation_id, assistant_msg.id, response_content

View File

@@ -0,0 +1,29 @@
from datetime import datetime, timedelta
from passlib.context import CryptContext
from jose import jwt, JWTError
from app.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
to_encode = data.copy()
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str) -> dict | None:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
return payload
except JWTError:
return None

View File

@@ -0,0 +1,256 @@
"""
文档服务 - 上传、解析、分块、存储
支持多种文档格式 + LlamaIndex 智能分块
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import UploadFile
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
import os
import aiofiles
import uuid
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
class DocumentService:
def __init__(self, db: AsyncSession, user_id: str = None):
self.db = db
self.user_id = user_id
async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
ext = os.path.splitext(file.filename)[1].lower()
if ext not in ALLOWED_EXTENSIONS:
raise ValueError(f"不支持的文件类型: {ext}")
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
file_id = str(uuid.uuid4())
file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")
content = await file.read()
file_size = len(content)
if file_size > settings.MAX_UPLOAD_SIZE:
raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
text_content = await self._extract_text(file_path, ext)
doc = Document(
user_id=user_id,
title=file.filename.rsplit('.', 1)[0],
filename=file.filename,
file_type=ext[1:],
file_size=file_size,
file_path=file_path,
summary=text_content[:500] if len(text_content) > 500 else text_content,
folder_id=folder_id,
)
self.db.add(doc)
await self.db.commit()
await self.db.refresh(doc)
chunks = self._chunk_text(text_content)
for i, chunk_text in enumerate(chunks):
chunk = DocumentChunk(
document_id=doc.id,
chunk_index=i,
content=chunk_text,
)
self.db.add(chunk)
doc.chunk_count = len(chunks)
await self.db.commit()
return doc
async def _get_folder_path(self, folder_id: str) -> str | None:
"""获取文件夹的完整路径"""
folders = await self.db.execute(
select(Folder).where(Folder.user_id == self.user_id)
)
folder_map = {f.id: f for f in folders.scalars().all()}
path_parts = []
current_id = folder_id
while current_id:
folder = folder_map.get(current_id)
if not folder:
break
path_parts.insert(0, folder.name)
current_id = folder.parent_id
return "/" + "/".join(path_parts) if path_parts else None
async def delete_document(self, user_id: str, document_id: str):
result = await self.db.execute(
select(Document).where(
Document.id == document_id,
Document.user_id == user_id,
)
)
doc = result.scalar_one_or_none()
if not doc:
raise ValueError("文档不存在")
if os.path.exists(doc.file_path):
os.remove(doc.file_path)
await self.db.delete(doc)
await self.db.commit()
async def _extract_text(self, file_path: str, ext: str) -> str:
if ext == ".pdf":
try:
import pymupdf
doc = pymupdf.open(file_path)
text = "".join(page.get_text() for page in doc)
doc.close()
return text
except ImportError:
return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
elif ext in (".md", ".txt"):
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
return await f.read()
elif ext in (".docx", ".doc"):
try:
from docx import Document as DocxDocument
doc = DocxDocument(file_path)
return "\n".join([p.text for p in doc.paragraphs])
except ImportError:
return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
return "[暂不支持此格式]"
def _chunk_text(self, text: str) -> list[str]:
"""
智能文档分块策略
1. 先按 Markdown 标题层级H1/H2/H3切分
2. 每个大段落内部按固定长度切分
3. 保留上下文prev_summary / next_summary
"""
import re
chunks = []
# 策略1: Markdown 标题切分(优先)
header_pattern = re.compile(r"^(#{1,3})\s+(.+)$", re.MULTILINE)
headers = list(header_pattern.finditer(text))
if headers:
# 按标题段落切分
for i, match in enumerate(headers):
start = match.start()
end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
section = text[start:end].strip()
if len(section) > settings.CHUNK_SIZE:
# 大段落内部再切分
sub_chunks = self._split_large_chunk(section, match.group(2))
chunks.extend(sub_chunks)
elif section:
chunks.append(section)
else:
# 策略2: 按段落切分
chunks = self._chunk_by_paragraphs(text)
# 过滤空 chunk
chunks = [c.strip() for c in chunks if c.strip()]
return chunks if chunks else [text[: settings.CHUNK_SIZE]]
def _chunk_by_paragraphs(self, text: str) -> list[str]:
"""按段落分块,带上下文"""
paragraphs = text.split("\n\n")
chunks = []
current = ""
prev_summary = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
if len(current) + len(para) < settings.CHUNK_SIZE:
current += "\n\n" + para
else:
if current:
# 添加上下文摘要
enriched = current.strip()
chunks.append(enriched)
current = para
if current.strip():
chunks.append(current.strip())
return chunks
def _split_large_chunk(self, text: str, title: str) -> list[str]:
"""将大段落拆分为固定大小的子块"""
chunks = []
sentences = text.split("")
current = title + "\n\n"
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
full_sentence = sentence if sentence.endswith("") else sentence + ""
if len(current) + len(full_sentence) < settings.CHUNK_SIZE:
current += full_sentence + " "
else:
if current.strip():
chunks.append(current.strip())
current = title + "\n\n" + full_sentence + " "
if current.strip():
chunks.append(current.strip())
return chunks
async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
result = await self.db.execute(
select(DocumentChunk)
.where(DocumentChunk.document_id == document_id)
.order_by(DocumentChunk.chunk_index)
)
return list(result.scalars().all())
async def get_document_content(self, user_id: str, document_id: str) -> str | None:
"""获取文档的文本内容"""
import os
result = await self.db.execute(
select(Document).where(
Document.id == document_id,
Document.user_id == user_id,
)
)
doc = result.scalar_one_or_none()
if not doc:
return None
file_path = doc.file_path
if not os.path.exists(file_path):
return None
# 根据文件类型读取内容
ext = doc.filename.split('.')[-1].lower()
try:
if ext == 'txt':
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
elif ext == 'md':
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
elif ext == 'pdf':
# 简单文本提取(生产环境应使用专业库)
return f"[PDF文档] {doc.filename}"
else:
return f"[文档] {doc.filename}"
except Exception:
return f"[文档] {doc.filename}"

View File

@@ -0,0 +1,342 @@
"""
知识图谱服务 - 实体识别、关系抽取、图谱查询
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.document import Document, DocumentChunk
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage
import json
import logging
logger = logging.getLogger(__name__)
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
实体类型:
- person人物人名、角色
- concept概念抽象概念、理论、方法
- topic主题话题、领域
- task任务要做的事情
- event事件发生的事件
- document文档文件、资料
关系类型:
- related_to相关于
- part_of隶属于
- caused_by由...导致)
- depends_on取决于
- contains包含
- located_in位于
- works_on从事
要求:
1. 识别文本中所有有意义的实体不超过10个
2. 识别实体之间的关系(每个实体至少一条关系)
3. 每个实体要有 name、type、description1-2句话
4. 关系要有 source、target、relation_type
文本内容:
{text}
请只返回 JSON不要有其他内容
{{
"entities": [
{{"name": "实体名", "type": "类型", "description": "描述"}}
],
"relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
]
}}
"""
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
用户问题:{question}
已知实体:
{entities}
请推断这些实体之间的隐含关系,返回 JSON
{{
"inferred_relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
]
}}
关系类型related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0,表示推断置信度
"""
class GraphService:
def __init__(self, db: AsyncSession):
self.db = db
self.llm = get_llm()
async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
"""
从文档构建/更新知识图谱
- 遍历所有 chunk
- LLM 实体识别
- LLM 关系抽取
- 去重合并
"""
query = (
select(DocumentChunk)
.join(Document)
.where(Document.user_id == user_id)
.where(Document.is_indexed == True)
)
if document_ids:
query = query.where(DocumentChunk.document_id.in_(document_ids))
result = await self.db.execute(query)
chunks = list(result.scalars().all())
logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")
for chunk in chunks:
try:
await self._process_chunk(chunk, user_id)
except Exception as e:
logger.error(f"[GraphService] 处理 chunk {chunk.id} 失败: {e}")
continue
logger.info(f"[GraphService] 图谱构建完成")
async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
"""处理单个 chunk提取实体和关系"""
prompt = ENTITY_EXTRACTION_PROMPT.format(text=chunk.content[:2000])
response = await self.llm.invoke([HumanMessage(content=prompt)])
try:
data = json.loads(response.content)
except json.JSONDecodeError:
return
entities = data.get("entities", [])
relations = data.get("relations", [])
if not entities:
return
# 先查找已存在的节点
existing_nodes = {}
for entity_data in entities:
name = entity_data["name"]
result = await self.db.execute(
select(KGNode)
.where(KGNode.user_id == user_id)
.where(KGNode.name == name)
)
node = result.scalar_one_or_none()
if node:
existing_nodes[name] = node
# 插入新节点
entity_map = {}
for entity_data in entities:
name = entity_data["name"]
if name in existing_nodes:
entity_map[name] = existing_nodes[name].id
else:
node = KGNode(
user_id=user_id,
name=name,
entity_type=entity_data["type"],
description=entity_data.get("description", ""),
source_document_id=chunk.document_id,
)
self.db.add(node)
await self.db.flush()
entity_map[name] = node.id
# 插入关系(去重)
for rel in relations:
src, tgt = rel["source"], rel["target"]
if src not in entity_map or tgt not in entity_map:
continue
# 检查关系是否已存在
result = await self.db.execute(
select(KGEdge).where(
KGEdge.source_id == entity_map[src],
KGEdge.target_id == entity_map[tgt],
KGEdge.relation_type == rel["relation_type"],
)
)
existing = result.scalar_one_or_none()
if not existing:
edge = KGEdge(
source_id=entity_map[src],
target_id=entity_map[tgt],
relation_type=rel["relation_type"],
)
self.db.add(edge)
await self.db.commit()
async def get_graph_summary(self, user_id: str) -> str:
"""获取用户图谱的整体摘要"""
# 统计
node_count = await self.db.execute(
select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
)
edge_count = await self.db.execute(
select(func.count()).select_from(KGEdge)
.select_from(KGEdge)
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGNode.user_id == user_id)
)
node_total = node_count.scalar() or 0
edge_total = edge_count.scalar() or 0
if node_total == 0:
return "知识图谱为空,请先上传文档并构建图谱。"
# 按类型统计节点
type_result = await self.db.execute(
select(KGNode.entity_type, func.count())
.where(KGNode.user_id == user_id)
.group_by(KGNode.entity_type)
)
type_stats = type_result.all()
# 关系类型统计
rel_result = await self.db.execute(
select(KGEdge.relation_type, func.count())
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGNode.user_id == user_id)
.group_by(KGEdge.relation_type)
)
rel_stats = rel_result.all()
# 列出最重要的节点(按 importance
top_nodes_result = await self.db.execute(
select(KGNode)
.where(KGNode.user_id == user_id)
.order_by(KGNode.importance.desc())
.limit(10)
)
top_nodes = list(top_nodes_result.scalars().all())
lines = [
f"## 知识图谱摘要",
f"",
f"**总节点数**: {node_total}",
f"**总关系数**: {edge_total}",
f"",
f"### 节点类型分布",
]
for etype, count in type_stats:
lines.append(f"- {etype}: {count}")
lines.append(f"\n### 关系类型分布")
for rtype, count in rel_stats:
lines.append(f"- {rtype}: {count}")
lines.append(f"\n### 核心实体 (Top 10)")
for node in top_nodes:
lines.append(f"- [{node.entity_type}] {node.name}: {node.description[:50]}...")
return "\n".join(lines)
async def get_entity_context(self, entity: str, user_id: str) -> str:
"""获取某个实体的详细上下文"""
# 查找节点
result = await self.db.execute(
select(KGNode).where(
KGNode.user_id == user_id,
KGNode.name.contains(entity),
).limit(5)
)
nodes = list(result.scalars().all())
if not nodes:
return f"未找到实体: {entity}"
lines = []
for node in nodes:
lines.append(f"### {node.name} [{node.entity_type}]")
lines.append(f"描述: {node.description or '无描述'}")
# 获取该节点的关系
edges_result = await self.db.execute(
select(KGEdge, KGNode)
.join(KGNode, KGNode.id == KGEdge.target_id)
.where(KGEdge.source_id == node.id)
.limit(10)
)
out_edges = list(edges_result.all())
in_edges_result = await self.db.execute(
select(KGEdge, KGNode)
.join(KGNode, KGNode.id == KGEdge.source_id)
.where(KGEdge.target_id == node.id)
.limit(10)
)
in_edges = list(in_edges_result.all())
if out_edges:
lines.append("**关联到**:")
for edge, target in out_edges:
lines.append(f" - {node.name} --[{edge.relation_type}]--> {target.name}")
if in_edges:
lines.append("**被关联于**:")
for edge, source in in_edges:
lines.append(f" - {source.name} --[{edge.relation_type}]--> {node.name}")
lines.append("")
return "\n".join(lines)
async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
"""获取节点的邻居节点(用于图谱可视化)"""
visited = set()
current_level = {node_id}
all_nodes = []
all_edges = []
for _ in range(depth):
if not current_level:
break
next_level = set()
for nid in current_level:
if nid in visited:
continue
visited.add(nid)
# 获取节点
node_result = await self.db.execute(
select(KGNode).where(KGNode.id == nid)
)
node = node_result.scalar_one_or_none()
if node:
all_nodes.append(node)
# 获取出边
out_result = await self.db.execute(
select(KGEdge).where(KGEdge.source_id == nid)
)
for edge in out_result.scalars().all():
all_edges.append(edge)
next_level.add(edge.target_id)
# 获取入边
in_result = await self.db.execute(
select(KGEdge).where(KGEdge.target_id == nid)
)
for edge in in_result.scalars().all():
all_edges.append(edge)
next_level.add(edge.source_id)
current_level = next_level
return {"nodes": all_nodes, "edges": all_edges}

View File

@@ -0,0 +1,308 @@
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 加权融合
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
import chromadb
from chromadb.config import Settings as ChromaSettings
from dataclasses import dataclass
@dataclass
class SearchResult:
chunk_id: str
document_id: str
document_title: str
content: str
score: float
metadata_: str | None = None
prev_chunk: str | None = None
next_chunk: str | None = None
class KnowledgeService:
"""向量知识库检索服务"""
def __init__(self, db: AsyncSession, user_id: str | None = None):
self.db = db
self.user_id = user_id
self._chroma_client = None
@property
def chroma_client(self):
if self._chroma_client is None:
self._chroma_client = chromadb.PersistentClient(
path=settings.CHROMA_PERSIST_DIR,
settings=ChromaSettings(allow_reset=True),
)
return self._chroma_client
def get_collection(self, user_id: str):
return self.chroma_client.get_or_create_collection(
name=f"user_{user_id}",
metadata={"user_id": user_id},
)
async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
"""将文档 chunks 向量化存入 ChromaDB"""
result = await self.db.execute(
select(Document).where(Document.id == document_id)
)
doc = result.scalar_one_or_none()
if not doc:
return
chunks_result = await self.db.execute(
select(DocumentChunk)
.where(DocumentChunk.document_id == document_id)
.order_by(DocumentChunk.chunk_index)
)
chunks = list(chunks_result.scalars().all())
if not chunks:
return
collection = self.get_collection(user_id)
ids = [chunk.id for chunk in chunks]
documents = [chunk.content for chunk in chunks]
metadatas = [
{
"document_id": doc.id,
"document_title": doc.title,
"chunk_index": chunk.chunk_index,
"file_type": doc.file_type,
"folder_path": folder_path or "",
}
for chunk in chunks
]
collection.add(ids=ids, documents=documents, metadatas=metadatas)
doc.is_indexed = True
await self.db.commit()
async def retrieve(
self,
query: str,
user_id: str,
folder_id: str | None = None,
top_k: int = 5,
use_rerank: bool = True,
) -> list[SearchResult]:
"""
混合检索 + Rerank支持按文件夹过滤
流程:
1. ChromaDB 向量检索 (扩大候选集)
2. 提取父 chunk完整上下文
3. Rerank 二次排序
4. 返回 top_k 结果
"""
collection = self.get_collection(user_id)
# 构建过滤条件
where = None
if folder_id:
folder_path = await self._get_folder_path(folder_id)
if folder_path:
where = {"folder_path": {"$starts_with": folder_path}}
try:
results = collection.query(
query_texts=[query],
n_results=top_k * 3,
where=where,
include=["documents", "metadatas", "distances"],
)
except Exception:
return []
if not results or not results.get("ids"):
return []
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results.get("metadatas", [[]])[0]
distances = results.get("distances", [[]])[0]
search_results: list[SearchResult] = []
for i, chunk_id in enumerate(ids):
meta = metadatas[i] if i < len(metadatas) else {}
score = 1.0 - (distances[i] if i < len(distances) else 0.0)
prev_chunk, next_chunk = await self._get_sibling_chunks(
chunk_id=chunk_id,
chunk_index=meta.get("chunk_index", 0),
document_id=meta.get("document_id", ""),
)
search_results.append(SearchResult(
chunk_id=chunk_id,
document_id=meta.get("document_id", ""),
document_title=meta.get("document_title", ""),
content=documents[i] if i < len(documents) else "",
score=score,
metadata_=str(meta),
prev_chunk=prev_chunk,
next_chunk=next_chunk,
))
if use_rerank:
search_results = self._rerank(query, search_results, top_k)
else:
search_results = search_results[:top_k]
return search_results
def _rerank(
self,
query: str,
results: list[SearchResult],
top_k: int,
) -> list[SearchResult]:
"""Rerank: 语义分 * 0.7 + 关键词匹配 * 0.2 + 标题匹配 * 0.1"""
import re
query_words = set(re.findall(r"\w+", query.lower()))
scored = []
for r in results:
score = r.score * 0.7
content_words = set(re.findall(r"\w+", r.content.lower()))
keyword_overlap = len(query_words & content_words) / max(len(query_words), 1)
score += keyword_overlap * 0.2
if r.document_title:
title_words = set(re.findall(r"\w+", r.document_title.lower()))
title_overlap = len(query_words & title_words) / max(len(query_words), 1)
score += title_overlap * 0.1
scored.append((score, r))
scored.sort(key=lambda x: x[0], reverse=True)
return [r for _, r in scored[:top_k]]
async def _get_sibling_chunks(
self,
chunk_id: str,
chunk_index: int,
document_id: str,
) -> tuple[str | None, str | None]:
"""获取前一个和后一个 chunk完整上下文"""
prev_result = await self.db.execute(
select(DocumentChunk).where(
DocumentChunk.document_id == document_id,
DocumentChunk.chunk_index == chunk_index - 1,
)
)
next_result = await self.db.execute(
select(DocumentChunk).where(
DocumentChunk.document_id == document_id,
DocumentChunk.chunk_index == chunk_index + 1,
)
)
prev_chunk = prev_result.scalar_one_or_none()
next_chunk = next_result.scalar_one_or_none()
return (
prev_chunk.content if prev_chunk else None,
next_chunk.content if next_chunk else None,
)
async def _get_folder_path(self, folder_id: str) -> str | None:
"""获取文件夹的完整路径"""
result = await self.db.execute(
select(Folder).where(Folder.id == folder_id)
)
folder = result.scalar_one_or_none()
if not folder:
return None
path_parts = [folder.name]
current_parent_id = folder.parent_id
while current_parent_id:
parent_result = await self.db.execute(
select(Folder).where(Folder.id == current_parent_id)
)
parent = parent_result.scalar_one_or_none()
if not parent:
break
path_parts.insert(0, parent.name)
current_parent_id = parent.parent_id
return "/" + "/".join(path_parts)
async def hybrid_search(
self,
query: str,
user_id: str,
top_k: int = 5,
) -> list[SearchResult]:
"""混合检索: 向量 + 关键词 + Rerank"""
vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
keyword_results = await self._keyword_search(query, user_id, top_k)
seen = set()
merged: list[SearchResult] = []
for r in vector_results + keyword_results:
if r.chunk_id not in seen:
seen.add(r.chunk_id)
merged.append(r)
return self._rerank(query, merged, top_k)
async def _keyword_search(
self,
query: str,
user_id: str,
top_k: int,
) -> list[SearchResult]:
"""SQL 关键词搜索"""
result = await self.db.execute(
select(DocumentChunk)
.join(Document)
.where(Document.user_id == user_id)
.where(
or_(
DocumentChunk.content.contains(query),
Document.title.contains(query),
)
)
.limit(top_k)
)
chunks = result.scalars().all()
results = []
for chunk in chunks:
doc_result = await self.db.execute(
select(Document).where(Document.id == chunk.document_id)
)
doc = doc_result.scalar_one_or_none()
results.append(SearchResult(
chunk_id=chunk.id,
document_id=chunk.document_id,
document_title=doc.title if doc else "",
content=chunk.content,
score=0.5,
metadata_=None,
))
return results
async def delete_from_vectorstore(self, user_id: str, document_id: str):
"""从向量库删除文档"""
collection = self.get_collection(user_id)
try:
collection.delete(where={"document_id": document_id})
except Exception:
pass

View File

@@ -0,0 +1,145 @@
"""
LLM 服务 - 支持多种 LLM 提供商
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
"""
from abc import ABC, abstractmethod
from typing import AsyncIterator
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from app.config import settings
import httpx
import os
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
class LLMService(ABC):
@abstractmethod
async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
raise NotImplementedError
@abstractmethod
async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
raise NotImplementedError
@abstractmethod
def get_model_name(self) -> str:
raise NotImplementedError
class OpenAICompatibleService(LLMService):
"""
OpenAI 兼容接口
支持 OpenAI、DeepSeek、硅基流动、任意 OpenAI API 兼容服务
"""
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
):
self.api_key = api_key or settings.OPENAI_API_KEY
self.model = model or settings.OPENAI_MODEL
self.base_url = base_url or settings.OPENAI_BASE_URL
self._llm = ChatOpenAI(
api_key=self.api_key,
model=self.model,
base_url=self.base_url,
timeout=httpx.Timeout(60.0, connect=10.0),
)
async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
return await self._llm.ainvoke(messages)
async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
async for chunk in self._llm.astream(messages):
if chunk.content:
yield chunk.content
def get_model_name(self) -> str:
return self.model
class ClaudeService(LLMService):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
max_tokens: int = 8192,
):
self.api_key = api_key or settings.ANTHROPIC_API_KEY
self.model = model or settings.CLAUDE_MODEL
self._llm = ChatAnthropic(
api_key=self.api_key,
model=self.model,
max_tokens=max_tokens,
timeout=httpx.Timeout(60.0, connect=10.0),
)
async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
return await self._llm.ainvoke(messages)
async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
async for chunk in self._llm.astream(messages):
if chunk.content:
yield chunk.content
def get_model_name(self) -> str:
return self.model
class OllamaService(LLMService):
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
):
self.base_url = base_url or settings.OLLAMA_BASE_URL
self.model = model or settings.OLLAMA_MODEL
self._llm = ChatOllama(
base_url=self.base_url,
model=self.model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
return await self._llm.ainvoke(messages)
async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
async for chunk in self._llm.astream(messages):
if chunk.content:
yield chunk.content
def get_model_name(self) -> str:
return self.model
# 单例缓存
_llm_instance: LLMService | None = None
def get_llm() -> LLMService:
"""根据配置获取 LLM 实例"""
global _llm_instance
if _llm_instance is None:
provider = settings.LLM_PROVIDER
if provider == "openai":
_llm_instance = OpenAICompatibleService()
elif provider == "deepseek":
_llm_instance = OpenAICompatibleService(
base_url="https://api.deepseek.com/v1",
model="deepseek-chat",
)
elif provider == "custom":
_llm_instance = OpenAICompatibleService()
elif provider == "claude":
_llm_instance = ClaudeService()
elif provider == "ollama":
_llm_instance = OllamaService()
else:
raise ValueError(f"Unknown LLM provider: {provider}")
return _llm_instance

View File

@@ -0,0 +1,304 @@
"""
Jarvis 记忆系统
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
"""
import json
import re
from datetime import datetime
from typing import Optional
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
db: AsyncSession,
conversation_id: str,
limit: int = 20,
) -> list[Message]:
"""加载指定对话的历史消息"""
result = await db.execute(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
.limit(limit)
)
return list(result.scalars().all())
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
"""获取对话轮数(用户消息数)"""
result = await db.execute(
select(func.count(Message.id))
.where(
Message.conversation_id == conversation_id,
Message.role == "user",
)
)
return result.scalar() or 0
# ———— 中期记忆: 对话摘要 ————
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
"""判断当前对话是否需要摘要"""
turn_count = await get_conversation_turn_count(db, conversation_id)
# 检查是否已有摘要覆盖到当前轮数
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
.order_by(desc(MemorySummary.turn_count))
.limit(1)
)
latest_summary = result.scalar_one_or_none()
if latest_summary:
return turn_count - latest_summary.turn_count >= SUMMARIZE_THRESHOLD
return turn_count >= SUMMARIZE_THRESHOLD
async def generate_summary(
db: AsyncSession,
conversation_id: str,
messages: list[Message],
) -> str:
"""调用 LLM 生成对话摘要"""
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages
)
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
HumanMessage(content=history_text),
])
return response.content.strip()
async def save_summary(
db: AsyncSession,
user_id: str,
conversation_id: str,
summary_text: str,
turn_count: int,
) -> MemorySummary:
"""保存对话摘要"""
summary = MemorySummary(
user_id=user_id,
conversation_id=conversation_id,
summary_text=summary_text,
turn_count=turn_count,
)
db.add(summary)
await db.commit()
await db.refresh(summary)
return summary
async def get_summaries(
db: AsyncSession,
conversation_id: str,
) -> list[MemorySummary]:
"""获取某对话的所有历史摘要"""
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
.order_by(MemorySummary.summary_at)
)
return list(result.scalars().all())
# ———— 长期记忆: 用户画像 ————
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的、可能对未来对话有帮助的信息,如:
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度
每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
如果没有提取到任何记忆,回复""
"""
FACT_TYPES = {"fact", "preference", "goal", "habit"}
def _parse_fact_line(line: str) -> tuple[str, str] | None:
"""解析一行记忆: [fact] 内容 -> (type, content)"""
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
if m and m.group(1) in FACT_TYPES:
return m.group(1), m.group(2).strip()
return None
async def extract_user_memories(
db: AsyncSession,
user_id: str,
conversation_id: str,
messages: list[Message],
) -> list[UserMemory]:
"""从对话中提取用户记忆并保存"""
if len(messages) < 2:
return []
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages[-10:]
)
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content=EXTRACTION_PROMPT),
HumanMessage(content=history_text),
])
text = response.content.strip()
if text == "" or not text:
return []
memories = []
for line in text.split("\n"):
parsed = _parse_fact_line(line)
if not parsed:
continue
mem_type, content = parsed
# 检查是否已有完全相同的记忆
existing = await db.execute(
select(UserMemory).where(
UserMemory.user_id == user_id,
UserMemory.content == content,
)
)
if existing.scalar_one_or_none():
continue
mem = UserMemory(
user_id=user_id,
memory_type=mem_type,
content=content,
importance=5,
source_conversation_id=conversation_id,
)
db.add(mem)
memories.append(mem)
if memories:
await db.commit()
return memories
async def recall_user_memories(
db: AsyncSession,
user_id: str,
query: str,
top_k: int = 5,
) -> list[UserMemory]:
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
# 先尝试语义相似(通过 LLM 判断)
# 降级: 直接从数据库取最近的重要记忆
result = await db.execute(
select(UserMemory)
.where(UserMemory.user_id == user_id)
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
.limit(top_k)
)
memories = list(result.scalars().all())
# 重置召回标记
for m in memories:
m.is_recalled = False
await db.commit()
return memories
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
"""标记记忆已被召回使用"""
result = await db.execute(
select(UserMemory).where(UserMemory.id == memory_id)
)
mem = result.scalar_one_or_none()
if mem:
mem.is_recalled = True
mem.recall_count = (mem.recall_count or 0) + 1
mem.last_recalled_at = datetime.utcnow()
await db.commit()
# ———— 记忆组装: 供 Agent 使用的上下文 ————
async def build_memory_context(
db: AsyncSession,
user_id: str,
conversation_id: str,
current_query: str,
) -> str:
"""
构建完整的记忆上下文字符串,
供注入到 Agent system prompt 中使用。
"""
parts = []
# 1. 用户画像(长期记忆)
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if user_memories:
lines = []
for m in user_memories:
tag = f"[{m.memory_type}]"
lines.append(f" {tag} {m.content}")
await mark_memory_recalled(db, m.id)
parts.append("【用户记忆】\n" + "\n".join(lines))
# 2. 对话摘要(中期记忆)
summaries = await get_summaries(db, conversation_id)
if summaries:
# 只取最近2条
recent = summaries[-2:]
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
if not parts:
return ""
return "\n\n".join(parts)
async def try_auto_summarize(
db: AsyncSession,
user_id: str,
conversation_id: str,
) -> bool:
"""
检查是否需要摘要,如果需要则生成并保存。
返回是否执行了摘要。
"""
if not await should_summarize(db, conversation_id):
return False
messages = await load_conversation_history(db, conversation_id, limit=30)
if len(messages) < 3:
return False
try:
summary_text = await generate_summary(db, conversation_id, messages)
turn_count = await get_conversation_turn_count(db, conversation_id)
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
# 同时提取用户记忆
await extract_user_memories(db, user_id, conversation_id, messages)
return True
except Exception:
return False

View File

@@ -0,0 +1,291 @@
"""
定时任务服务 - APScheduler 调度器
"""
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import select, and_
from app.database import async_session
from app.models.task import Task
from app.models.forum import ForumPost
from app.models.knowledge_graph import KGNode
from app.services.agent_service import AgentService
from app.services.graph_service import GraphService
from app.config import settings
import logging
logger = logging.getLogger(__name__)
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
# ===================== 定时任务函数 =====================
async def daily_task_analysis():
"""
每日凌晨任务分析
- 分析前一天完成的任务
- 生成每日报告
- 创建次日计划建议
"""
logger.info("[Scheduler] 开始执行每日任务分析...")
async with async_session() as db:
from datetime import datetime, timedelta
yesterday = datetime.utcnow().date() - timedelta(days=1)
# 统计昨日任务完成情况
result = await db.execute(
select(Task).where(Task.updated_at >= yesterday)
)
tasks = result.scalars().all()
completed = [t for t in tasks if t.status == "done"]
pending = [t for t in tasks if t.status != "done"]
report = f"""## 每日任务报告 - {yesterday.strftime('%Y-%m-%d')}
### 完成情况
- 总任务数: {len(tasks)}
- 已完成: {len(completed)}
- 未完成: {len(pending)}
### 已完成任务
{chr(10).join([f"- {t.title}" for t in completed]) or ""}
### 未完成任务
{chr(10).join([f"- {t.title} (优先级: {t.priority})" for t in pending]) or ""}
### 建议
根据未完成任务,建议明天优先处理:
{chr(10).join([f"{i+1}. {t.title}" for i, t in enumerate(sorted(pending, key=lambda x: x.priority, reverse=True)[:5])]) or "无待处理任务"}
"""
# 发布到论坛
from app.models.forum import ForumPost
post = ForumPost(
title=f"每日报告 - {yesterday.strftime('%Y-%m-%d')}",
content=report,
category="discussion",
)
db.add(post)
# 创建明日计划建议任务
for i, task in enumerate(sorted(pending, key=lambda x: x.priority, reverse=True)[:5]):
suggestion = Task(
title=f"继续: {task.title}",
description=f"昨日未完成任务,优先级: {task.priority}",
priority=task.priority,
status="todo",
)
db.add(suggestion)
await db.commit()
logger.info(f"[Scheduler] 每日任务分析完成,完成 {len(completed)} 个任务")
async def forum_scan_task():
"""
论坛扫描任务
- 扫描所有指令类帖子
- 识别可执行指令
- AI自动执行
"""
logger.info("[Scheduler] 开始扫描论坛指令...")
async with async_session() as db:
from sqlalchemy import select
result = await db.execute(
select(ForumPost).where(
ForumPost.category == "instruction",
ForumPost.is_executed == False,
).limit(5)
)
posts = result.scalars().all()
if not posts:
logger.info("[Scheduler] 暂无待执行指令")
return
agent_svc = AgentService(db)
executed_count = 0
for post in posts:
try:
# 让 Agent 分析并执行指令
conv_id, msg_id, response = await agent_svc.chat_simple(
user_id=post.user_id,
message=f"请执行以下论坛指令: {post.title}{post.content}",
conversation_id=None,
)
post.is_executed = True
post.executed_response = response
executed_count += 1
logger.info(f"[Scheduler] 执行指令: {post.title}")
except Exception as e:
logger.error(f"[Scheduler] 执行指令失败 {post.title}: {e}")
await db.commit()
logger.info(f"[Scheduler] 论坛扫描完成,执行了 {executed_count} 个指令")
async def graph_rebuild_task():
"""
知识图谱增量重建任务
- 扫描新增/更新的文档
- 更新图谱节点和边
"""
logger.info("[Scheduler] 开始重建知识图谱...")
async with async_session() as db:
try:
graph_svc = GraphService(db)
# 只处理最近7天有活动的文档
await graph_svc.build_graph(user_id="default", document_ids=None)
logger.info("[Scheduler] 知识图谱重建完成")
except Exception as e:
logger.error(f"[Scheduler] 知识图谱重建失败: {e}")
async def tag_generation_task():
"""
每日凌晨 00:00 增量标签生成任务
"""
from app.services.tag_service import TagService
from app.core.llm import get_llm_client
from sqlalchemy import select
logger.info("[Scheduler] 开始执行每日标签生成...")
async with async_session() as db:
try:
llm_client = get_llm_client()
tag_service = TagService(db, llm_client)
result = await db.execute(
select(KGNode.user_id).distinct().where(
KGNode.entity_type.in_(["conversation", "document", "chunk"])
)
)
user_ids = result.scalars().all()
total_tagged = 0
for user_id in user_ids:
sync_tag_service = TagService(db, llm_client)
result = sync_tag_service.tag_incremental_content(user_id, days=1)
total_tagged += result["tagged"]
logger.info(f"[Scheduler] 每日标签生成完成,共标签化 {total_tagged} 个内容节点")
except Exception as e:
logger.error(f"[Scheduler] 每日标签生成失败: {e}")
async def daily_todo_generation():
"""
每天早上 08:00 为所有活跃用户生成待办
- 来自前一天未完成的看板任务
- 来自前一天对话记录分析
"""
from app.models.user import User
from app.services.todo_service import generate_daily_todos
from sqlalchemy import select
logger.info("[Scheduler] 开始执行每日待办生成...")
async with async_session() as db:
try:
result = await db.execute(select(User).where(User.is_active == True))
users = result.scalars().all()
for user in users:
try:
await generate_daily_todos(user.id, db)
logger.info(f"[Scheduler] 为用户 {user.id} 生成今日待办完成")
except Exception as e:
logger.error(f"[Scheduler] 用户 {user.id} 定时生成待办失败: {e}")
logger.info(f"[Scheduler] 每日待办生成完成,共处理 {len(users)} 个用户")
except Exception as e:
logger.error(f"[Scheduler] 每日待办生成失败: {e}")
# ===================== 调度器管理 =====================
def start_scheduler():
"""启动调度器,注册所有定时任务"""
if scheduler.running:
logger.warning("[Scheduler] 调度器已在运行")
return
# 每日凌晨 00:30 执行任务分析
scheduler.add_job(
daily_task_analysis,
CronTrigger(hour=0, minute=30, timezone="Asia/Shanghai"),
id="daily_task_analysis",
name="每日任务分析",
replace_existing=True,
)
# 每小时扫描论坛指令
scheduler.add_job(
forum_scan_task,
IntervalTrigger(hours=1),
id="forum_scan",
name="论坛指令扫描",
replace_existing=True,
)
# 每天凌晨 3:00 重建图谱
scheduler.add_job(
graph_rebuild_task,
CronTrigger(hour=3, minute=0, timezone="Asia/Shanghai"),
id="graph_rebuild",
name="知识图谱重建",
replace_existing=True,
)
# 每天凌晨 00:00 生成标签
scheduler.add_job(
tag_generation_task,
CronTrigger(hour=0, minute=0, timezone="Asia/Shanghai"),
id="tag_generation",
name="每日标签生成",
replace_existing=True,
)
# 每天早上 08:00 生成今日待办
scheduler.add_job(
daily_todo_generation,
CronTrigger(hour=8, minute=0, timezone="Asia/Shanghai"),
id="daily_todo_generation",
name="每日待办生成",
replace_existing=True,
)
scheduler.start()
logger.info("[Scheduler] 定时任务调度器已启动")
def stop_scheduler():
"""停止调度器"""
if scheduler.running:
scheduler.shutdown(wait=False)
logger.info("[Scheduler] 定时任务调度器已停止")
def get_scheduler_status() -> dict:
"""获取调度器状态"""
if not scheduler.running:
return {"status": "stopped", "jobs": []}
jobs = []
for job in scheduler.get_jobs():
jobs.append({
"id": job.id,
"name": job.name,
"next_run": str(job.next_run_time) if job.next_run_time else None,
})
return {"status": "running", "jobs": jobs}

View File

@@ -0,0 +1,140 @@
import logging
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.user import User
from app.services.auth_service import verify_password, get_password_hash
logger = logging.getLogger(__name__)
async def get_user_settings(user_id: str, db: AsyncSession) -> dict:
"""获取用户完整设置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
return None
return {
"profile": user,
"llm_config": user.llm_config or {},
"scheduler_config": user.scheduler_config or {}
}
async def update_user_profile(
user_id: str,
db: AsyncSession,
full_name: Optional[str] = None,
password: Optional[str] = None,
current_password: Optional[str] = None
) -> User:
"""更新用户资料"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise ValueError("用户不存在")
if password:
if not current_password or not verify_password(current_password, user.hashed_password):
raise ValueError("当前密码错误")
user.hashed_password = get_password_hash(password)
if full_name:
user.full_name = full_name
await db.commit()
await db.refresh(user)
return user
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
"""更新 LLM 配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise ValueError("用户不存在")
current = user.llm_config or {}
# 合并配置 - 直接替换整个类型配置列表
for key, value in config.items():
if value is not None:
if isinstance(value, list):
# 列表直接替换
current[key] = value
elif isinstance(value, dict):
# 字典合并
if key in current and isinstance(current[key], dict):
current[key] = {**current[key], **value}
else:
current[key] = value
else:
current[key] = value
user.llm_config = current
await db.commit()
return current
async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession) -> dict:
"""更新定时任务配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise ValueError("用户不存在")
current = user.scheduler_config or {}
for key, value in config.items():
if value is not None:
current[key] = value
user.scheduler_config = current
await db.commit()
return current
async def test_llm_connection(
provider: str,
model: str,
base_url: str,
api_key: str
) -> dict:
"""测试 LLM 连接"""
try:
# 根据不同 provider 创建临时 LLM 实例并测试
if provider == "openai":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=30
)
elif provider == "claude":
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(
api_key=api_key,
model=model,
timeout=30
)
elif provider == "ollama":
from langchain_ollama import ChatOllama
llm = ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=30
)
elif provider == "deepseek":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or "https://api.deepseek.com/v1",
timeout=30
)
else:
return {"success": False, "error": f"不支持的 provider: {provider}"}
# 简单测试调用
from langchain_core.messages import HumanMessage
response = await llm.ainvoke([HumanMessage(content="Hi")])
return {"success": True, "message": f"连接成功,模型响应: {response.content[:50]}..."}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -0,0 +1,278 @@
import psutil
import time
from datetime import datetime, timedelta
from sqlalchemy import select, func, and_
from sqlalchemy.orm import Session
from app.models.conversation import Conversation, Message
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.task import Task, TaskStatus
from app.models.forum import ForumPost, ForumReply
from app.models.document import Document
class StatsService:
def __init__(self, db: Session):
self.db = db
def get_system_health(self) -> dict:
"""获取系统健康指标"""
uptime_seconds = int(time.time() - psutil.boot_time())
cpu_percent = psutil.cpu_percent(interval=0.1)
mem = psutil.virtual_memory()
disk = psutil.disk_usage('/')
return {
"uptime_seconds": uptime_seconds,
"cpu_percent": cpu_percent,
"memory_used_mb": round(mem.used / (1024 * 1024), 1),
"memory_total_mb": round(mem.total / (1024 * 1024), 1),
"memory_percent": mem.percent,
"disk_used_gb": round(disk.used / (1024 * 1024 * 1024), 1),
"disk_total_gb": round(disk.total / (1024 * 1024 * 1024), 1),
"disk_percent": disk.percent,
"active_users_24h": 0, # 需要 User 表的 updated_at
}
def _get_daily_stats(self, model, date_column, user_id=None, days=30) -> list:
"""通用每日统计查询"""
cutoff = datetime.utcnow() - timedelta(days=days)
query = self.db.query(
func.date(date_column).label('date'),
func.count().label('count')
).filter(date_column >= cutoff)
if user_id and hasattr(model, 'user_id'):
query = query.filter(model.user_id == user_id)
query = query.group_by(func.date(date_column)).order_by(func.date(date_column))
results = query.all()
return [{"date": str(r.date), "count": r.count} for r in results]
def get_conversation_stats(self, user_id: str = None, days=30) -> dict:
"""获取对话统计数据"""
cutoff = datetime.utcnow() - timedelta(days=days)
daily_conversations = self._get_daily_stats(
Conversation, Conversation.created_at, user_id, days
)
daily_messages = self._get_daily_stats(
Message, Message.created_at, user_id, days
)
# Daily tokens
input_query = self.db.query(
func.date(Message.created_at).label('date'),
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
).filter(
Message.created_at >= cutoff,
Message.role == 'user'
)
if user_id:
input_query = input_query.join(Conversation).filter(Conversation.user_id == user_id)
input_results = input_query.group_by(func.date(Message.created_at)).all()
output_query = self.db.query(
func.date(Message.created_at).label('date'),
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
).filter(
Message.created_at >= cutoff,
Message.role == 'assistant'
)
if user_id:
output_query = output_query.join(Conversation).filter(Conversation.user_id == user_id)
output_results = output_query.group_by(func.date(Message.created_at)).all()
daily_input_tokens = [{"date": str(r.date), "input_tokens": r.tokens} for r in input_results]
daily_output_tokens = [{"date": str(r.date), "output_tokens": r.tokens} for r in output_results]
return {
"daily_conversations": daily_conversations,
"daily_messages": daily_messages,
"daily_input_tokens": daily_input_tokens,
"daily_output_tokens": daily_output_tokens,
"totals": {
"conversations": sum(c["count"] for c in daily_conversations),
"messages": sum(m["count"] for m in daily_messages),
"input_tokens": sum(t["input_tokens"] for t in daily_input_tokens),
"output_tokens": sum(t["output_tokens"] for t in daily_output_tokens),
}
}
def get_knowledge_stats(self, user_id: str = None, days=30) -> dict:
"""获取知识库统计数据"""
cutoff = datetime.utcnow() - timedelta(days=days)
# New tags
tag_query = self.db.query(
func.date(KGNode.created_at).label('date'),
func.count().label('count')
).filter(
KGNode.created_at >= cutoff,
KGNode.entity_type == 'tag'
)
if user_id:
tag_query = tag_query.filter(KGNode.user_id == user_id)
tag_results = tag_query.group_by(func.date(KGNode.created_at)).all()
daily_new_tags = [{"date": str(r.date), "count": r.count} for r in tag_results]
daily_documents = self._get_daily_stats(
Document, Document.created_at, user_id, days
)
daily_tag_relations = self._get_daily_stats(
KGEdge, KGEdge.created_at, user_id, days
)
return {
"daily_new_tags": daily_new_tags,
"daily_documents": daily_documents,
"daily_knowledge_queries": [],
"daily_tag_relations": daily_tag_relations,
"totals": {
"new_tags": sum(t["count"] for t in daily_new_tags),
"documents": sum(d["count"] for d in daily_documents),
"tag_relations": sum(r["count"] for r in daily_tag_relations),
}
}
def get_kanban_stats(self, user_id: str = None, days=30) -> dict:
"""获取看板统计数据"""
daily_new_tasks = self._get_daily_stats(
Task, Task.created_at, user_id, days
)
# Completed tasks
completed_query = self.db.query(
func.date(Task.completed_at).label('date'),
func.count().label('count')
).filter(
Task.completed_at >= datetime.utcnow() - timedelta(days=days),
Task.status == TaskStatus.DONE
)
if user_id:
completed_query = completed_query.filter(Task.user_id == user_id)
completed_results = completed_query.group_by(func.date(Task.completed_at)).all()
daily_completed_tasks = [{"date": str(r.date), "count": r.count} for r in completed_results]
# Current pending
pending_query = self.db.query(func.count(Task.id)).filter(Task.status == TaskStatus.TODO)
if user_id:
pending_query = pending_query.filter(Task.user_id == user_id)
current_pending_tasks = pending_query.scalar() or 0
# Completion rate
daily_new_dict = {d["date"]: d["count"] for d in daily_new_tasks}
daily_completed_dict = {d["date"]: d["count"] for d in daily_completed_tasks}
all_dates = set(daily_new_dict.keys()) | set(daily_completed_dict.keys())
daily_completion_rate = []
for date in sorted(all_dates):
new = daily_new_dict.get(date, 0)
completed = daily_completed_dict.get(date, 0)
rate = (completed / new * 100) if new > 0 else 0
daily_completion_rate.append({"date": date, "rate": round(rate, 1)})
return {
"daily_new_tasks": daily_new_tasks,
"daily_completed_tasks": daily_completed_tasks,
"daily_completion_rate": daily_completion_rate,
"current_pending_tasks": current_pending_tasks,
"totals": {
"new_tasks": sum(t["count"] for t in daily_new_tasks),
"completed_tasks": sum(c["count"] for c in daily_completed_tasks),
}
}
def get_community_stats(self, user_id: str = None, days=30) -> dict:
"""获取社区统计数据"""
daily_posts = self._get_daily_stats(
ForumPost, ForumPost.created_at, user_id, days
)
daily_replies = self._get_daily_stats(
ForumReply, ForumReply.created_at, user_id, days
)
# AI executions
ai_query = self.db.query(
func.date(ForumPost.updated_at).label('date'),
func.count().label('count')
).filter(
ForumPost.updated_at >= datetime.utcnow() - timedelta(days=days),
ForumPost.is_executed == True
)
if user_id:
ai_query = ai_query.filter(ForumPost.user_id == user_id)
ai_results = ai_query.group_by(func.date(ForumPost.updated_at)).all()
daily_ai_executions = [{"date": str(r.date), "count": r.count} for r in ai_results]
return {
"daily_posts": daily_posts,
"daily_replies": daily_replies,
"daily_ai_executions": daily_ai_executions,
"daily_agent_calls": [],
"totals": {
"posts": sum(p["count"] for p in daily_posts),
"replies": sum(r["count"] for r in daily_replies),
"ai_executions": sum(a["count"] for a in daily_ai_executions),
}
}
def get_personal_insights(self, user_id: str) -> dict:
"""获取个人洞察"""
# Hourly activity
hourly_query = self.db.query(
func.extract('hour', Conversation.created_at).label('hour'),
func.count().label('count')
).filter(Conversation.user_id == user_id).group_by(
func.extract('hour', Conversation.created_at)
)
hourly_results = hourly_query.all()
hourly_activity = [{"hour": int(r.hour), "count": r.count} for r in hourly_results]
# Top tags
tag_query = self.db.query(
KGNode.properties_["tag_path"].astext.label('tag_path'),
func.count(KGEdge.id).label('usage_count')
).join(
KGEdge, KGEdge.target_id == KGNode.id
).filter(
KGNode.user_id == user_id,
KGNode.entity_type == 'tag',
KGEdge.relation_type == 'has_tag'
).group_by(
KGNode.properties_["tag_path"].astext
).order_by(func.count(KGEdge.id).desc()).limit(5)
top_tags = [{"tag_path": r.tag_path, "usage_count": r.usage_count} for r in tag_query.all()]
# Token trend
now = datetime.utcnow()
this_month_start = datetime(now.year, now.month, 1)
last_month_end = this_month_start - timedelta(days=1)
last_month_start = datetime(last_month_end.year, last_month_end.month, 1)
this_month_tokens = self.db.query(
func.coalesce(func.sum(Message.tokens_used), 0)
).join(Conversation).filter(
Conversation.user_id == user_id,
Message.created_at >= this_month_start,
Message.role == 'assistant'
).scalar() or 0
last_month_tokens = self.db.query(
func.coalesce(func.sum(Message.tokens_used), 0)
).join(Conversation).filter(
Conversation.user_id == user_id,
Message.created_at >= last_month_start,
Message.created_at < this_month_start,
Message.role == 'assistant'
).scalar() or 0
token_trend_percent = 0
if last_month_tokens > 0:
token_trend_percent = round((this_month_tokens - last_month_tokens) / last_month_tokens * 100, 1)
return {
"hourly_activity": hourly_activity,
"top_tags": top_tags,
"token_trend_percent": token_trend_percent,
"this_month_tokens": this_month_tokens,
"last_month_tokens": last_month_tokens,
}

View File

@@ -0,0 +1,239 @@
import json
from sqlalchemy.orm import Session
from app.models.knowledge_graph import KGNode, KGEdge
TAG_EXTRACTION_PROMPT = """你是一个知识分类专家。从给定内容中提取标签。
要求:
1. 标签采用层级路径格式,如 "编程语言/Python""后端/框架/FastAPI"
2. 层级深度 1-4 层,避免过深
3. 每个内容提取 3-8 个标签
4. 标签应覆盖:主题、技术栈、领域、任务类型等维度
输出格式JSON数组
[
{"path": "编程语言/Python", "description": "Python编程语言相关"},
{"path": "后端/框架/FastAPI", "description": "FastAPI框架相关"}
]
内容:
{content}
"""
TAG_RELATION_PROMPT = """分析以下标签之间的关系,输出 JSON 数组:
关系类型:
- parent_of: 父子关系(上级包含下级)
- related_to: 语义相关(但不是父子)
- synonym_of: 同义词
标签列表:
{tag_paths}
输出格式:
[
{"source": "标签1", "target": "标签2", "relation": "related_to", "weight": 0.8},
{"source": "标签1", "target": "标签3", "relation": "parent_of", "weight": 1.0}
]
"""
class TagService:
def __init__(self, db: Session, llm_client):
self.db = db
self.llm_client = llm_client
def extract_tags_from_content(self, content: str, user_id: str) -> list[dict]:
"""从内容中提取标签"""
response = self.llm_client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是一个知识分类专家。"},
{"role": "user", "content": TAG_EXTRACTION_PROMPT.format(content=content)}
],
response_format={"type": "json_object"}
)
result = json.loads(response.choices[0].message.content)
return result.get("tags", [])
def parse_tag_path(self, path: str) -> tuple[str, int, str | None]:
"""解析标签路径,返回 (short_name, level, parent_path)"""
parts = path.strip("/").split("/")
short_name = parts[-1]
level = len(parts)
parent_path = "/".join(parts[:-1]) if level > 1 else None
return short_name, level, parent_path
def get_or_create_tag_node(self, tag_info: dict, user_id: str) -> KGNode:
"""获取或创建标签节点"""
path = tag_info["path"]
existing = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.properties_["tag_path"].astext == path
).first()
if existing:
return existing
short_name, level, parent_path = self.parse_tag_path(path)
node = KGNode(
user_id=user_id,
name=short_name,
entity_type="tag",
description=tag_info.get("description"),
properties_={
"tag_path": path,
"short_name": short_name,
"level": level,
"parent_path": parent_path,
"description": tag_info.get("description"),
"color": tag_info.get("color"),
},
importance=0.5
)
self.db.add(node)
self.db.flush()
return node
def ensure_parent_tags(self, path: str, user_id: str) -> list[KGNode]:
"""确保父路径标签存在"""
parts = path.strip("/").split("/")
nodes = []
for i in range(1, len(parts)):
parent_path = "/".join(parts[:i])
tag_info = {"path": parent_path, "description": None}
node = self.get_or_create_tag_node(tag_info, user_id)
nodes.append(node)
return nodes
def create_tag_relations(self, tag_paths: list[str], user_id: str) -> list[KGEdge]:
"""分析并创建标签之间的关系边"""
path_to_node = {}
for path in tag_paths:
node = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.properties_["tag_path"].astext == path,
KGNode.entity_type == "tag"
).first()
if node:
path_to_node[path] = node
response = self.llm_client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是一个知识图谱专家。"},
{"role": "user", "content": TAG_RELATION_PROMPT.format(tag_paths=json.dumps(tag_paths))}
],
response_format={"type": "json_object"}
)
result = json.loads(response.choices[0].message.content)
relations = result.get("relations", [])
edges = []
for rel in relations:
source_node = path_to_node.get(rel["source"])
target_node = path_to_node.get(rel["target"])
if source_node and target_node:
existing = self.db.query(KGEdge).filter(
KGEdge.source_id == source_node.id,
KGEdge.target_id == target_node.id
).first()
if not existing:
edge = KGEdge(
source_id=source_node.id,
target_id=target_node.id,
relation_type=rel["relation"],
weight=rel.get("weight", 0.5)
)
self.db.add(edge)
edges.append(edge)
self.db.flush()
return edges
def tag_content(self, content: str, user_id: str, content_node: KGNode) -> list[KGNode]:
"""为内容节点打标签"""
tag_infos = self.extract_tags_from_content(content, user_id)
tag_paths = [t["path"] for t in tag_infos]
tag_nodes = []
for tag_info in tag_infos:
node = self.get_or_create_tag_node(tag_info, user_id)
tag_nodes.append(node)
self.ensure_parent_tags(tag_info["path"], user_id)
# 创建 has_tag 边
for tag_node in tag_nodes:
existing_edge = self.db.query(KGEdge).filter(
KGEdge.source_id == content_node.id,
KGEdge.target_id == tag_node.id,
KGEdge.relation_type == "has_tag"
).first()
if not existing_edge:
edge = KGEdge(
source_id=content_node.id,
target_id=tag_node.id,
relation_type="has_tag",
weight=1.0
)
self.db.add(edge)
tag_node_ids = [n.id for n in tag_nodes]
current_tag_ids = content_node.properties_.get("tag_node_ids", []) if content_node.properties_ else []
content_node.properties_["tag_node_ids"] = list(set(current_tag_ids + tag_node_ids))
if len(tag_paths) >= 2:
self.create_tag_relations(tag_paths, user_id)
self.db.commit()
return tag_nodes
def tag_incremental_content(self, user_id: str, days: int = 1) -> dict:
"""
增量打标签 - 只对最近新增/更新的内容节点打标签
"""
from datetime import datetime, timedelta
cutoff_date = datetime.utcnow() - timedelta(days=days)
content_nodes = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.entity_type.in_(["conversation", "document", "chunk"]),
KGNode.updated_at >= cutoff_date
).all()
untagged = [
n for n in content_nodes
if not n.properties_.get("tag_node_ids")
]
tagged_count = 0
for node in untagged:
content = node.description or ""
try:
self.tag_content(content, user_id, node)
tagged_count += 1
except Exception as e:
pass
return {"total": len(untagged), "tagged": tagged_count}
def get_related_content(self, tag_node_ids: list[str], user_id: str, limit: int = 10) -> list[tuple[KGNode, float]]:
"""通过标签找相关内容"""
edges = self.db.query(KGEdge).filter(
KGEdge.target_id.in_(tag_node_ids),
KGEdge.relation_type == "has_tag"
).all()
content_weights: dict[str, float] = {}
for edge in edges:
content_weights[edge.source_id] = content_weights.get(edge.source_id, 0) + edge.weight
content_ids = list(content_weights.keys())
content_nodes = self.db.query(KGNode).filter(
KGNode.id.in_(content_ids),
KGNode.entity_type.in_(["conversation", "document", "chunk"])
).all()
return [(node, content_weights[node.id]) for node in content_nodes]

View File

@@ -0,0 +1,165 @@
import json
import logging
from datetime import date, datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.todo import DailyTodo, TodoSource
from app.models.task import Task, TaskStatus
from app.models.conversation import Conversation, Message
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage, SystemMessage
logger = logging.getLogger(__name__)
async def generate_daily_todos(user_id: str, db: AsyncSession) -> list[DailyTodo]:
"""
为用户生成今日待办:
1. 来自前一天未完成的看板任务最多20条
2. 来自前一天对话记录分析最多3条
"""
today = date.today()
yesterday = (today - timedelta(days=1)).isoformat()
todos: list[DailyTodo] = []
# 1. 从看板任务导入
kanban_todos = await _import_kanban_tasks(user_id, yesterday, db)
todos.extend(kanban_todos)
# 2. 从对话记录分析
chat_todos = await _analyze_chat_history(user_id, yesterday, db)
todos.extend(chat_todos)
return todos
async def _import_kanban_tasks(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
"""导入前一天创建的、未完成的看板任务"""
q = select(Task).where(
Task.user_id == user_id,
Task.status != TaskStatus.DONE,
).order_by(Task.created_at.desc()).limit(20)
tasks = (await db.execute(q)).scalars().all()
todos = []
for task in tasks:
todo = DailyTodo(
user_id=user_id,
title=task.title,
source=TodoSource.AI_KANBAN,
source_detail=f"看板:{task.title}",
source_ref_id=task.id,
todo_date=date.today().isoformat(),
)
db.add(todo)
todos.append(todo)
if todos:
await db.commit()
for todo in todos:
await db.refresh(todo)
return todos
async def _analyze_chat_history(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
"""分析前一天对话,提取待办事项"""
try:
# 查询前一天创建的对话
conv_q = select(Conversation).where(
Conversation.user_id == user_id,
).order_by(Conversation.created_at.desc()).limit(10)
convs = (await db.execute(conv_q)).scalars().all()
# 过滤出昨天的对话
yesterday_convs = []
for conv in convs:
created = conv.created_at
if hasattr(created, 'date'):
created_date = created.date() if hasattr(created, 'date') else created
else:
created_date = datetime.fromisoformat(str(created)).date()
if str(created_date) == date_str or (created + timedelta(hours=8)).strftime('%Y-%m-%d') == date_str:
yesterday_convs.append(conv)
if not yesterday_convs:
return []
# 收集消息内容限制2000字
messages_content = []
for conv in yesterday_convs:
msg_q = select(Message).where(
Message.conversation_id == conv.id
).order_by(Message.created_at.asc()).limit(50)
msgs = (await db.execute(msg_q)).scalars().all()
for msg in msgs:
if msg.content:
messages_content.append(f"[{msg.role}]: {msg.content[:500]}")
if not messages_content:
return []
full_text = "\n".join(messages_content)[:2000]
# 调用 LLM 分析
prompt = f"""你是一个任务规划助手。请分析以下对话记录,提取其中用户想要完成但尚未明确完成的事项。
要求:
- 最多提取 3 条
- 每条格式:{{"title": "事项描述50字以内", "reason": "来源说明60字以内"}}
- 只提取用户明确表达过需求但还未完成的事项
- 如果没有可提取的内容,返回空数组 []
对话记录:
{full_text}
返回 JSON 数组:"""
llm = get_llm()
response = await llm.invoke([
SystemMessage(content="你是一个任务规划助手。"),
HumanMessage(content=prompt),
])
content = response.content if hasattr(response, 'content') else str(response)
# 尝试解析 JSON
try:
# 提取 JSON 数组
start = content.find('[')
end = content.rfind(']') + 1
if start != -1 and end > start:
items = json.loads(content[start:end])
else:
items = []
except (json.JSONDecodeError, ValueError):
logger.warning(f"LLM 返回格式异常,跳过对话分析: {content[:200]}")
items = []
if not items:
return []
todos = []
for item in items[:3]:
todo = DailyTodo(
user_id=user_id,
title=item.get("title", "")[:500],
source=TodoSource.AI_CHAT,
source_detail=f"对话:{item.get('reason', '')[:60]}",
todo_date=date.today().isoformat(),
)
db.add(todo)
todos.append(todo)
if todos:
await db.commit()
for todo in todos:
await db.refresh(todo)
return todos
except Exception as e:
logger.error(f"对话分析失败: {e}")
return []