Add FastAPI backend with agent system
This commit is contained in:
54
backend/.env.example
Normal file
54
backend/.env.example
Normal file
@@ -0,0 +1,54 @@
|
||||
# =============================================
|
||||
# Jarvis 后端配置
|
||||
# 复制此文件为 .env 并填入实际值
|
||||
# =============================================
|
||||
|
||||
# === 应用基础 ===
|
||||
DEBUG=false
|
||||
SECRET_KEY=change-me-to-a-random-secret-key
|
||||
|
||||
# === LLM 配置 ===
|
||||
# 支持: openai / claude / deepseek / ollama / custom
|
||||
LLM_PROVIDER=openai
|
||||
|
||||
# OpenAI(默认)
|
||||
OPENAI_API_KEY=your-openai-api-key-here
|
||||
OPENAI_MODEL=gpt-4o
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# Claude(可选)
|
||||
# ANTHROPIC_API_KEY=your-anthropic-api-key-here
|
||||
# CLAUDE_MODEL=claude-sonnet-4-20250514
|
||||
|
||||
# DeepSeek(可选)
|
||||
# LLM_PROVIDER=deepseek
|
||||
# OPENAI_API_KEY=your-deepseek-api-key
|
||||
# OPENAI_BASE_URL=https://api.deepseek.com/v1
|
||||
|
||||
# Ollama 本地模型(可选)
|
||||
# LLM_PROVIDER=ollama
|
||||
# OLLAMA_BASE_URL=http://localhost:11434
|
||||
# OLLAMA_MODEL=llama3
|
||||
|
||||
# 自定义 OpenAI 兼容接口(可选)
|
||||
# LLM_PROVIDER=custom
|
||||
# OPENAI_API_KEY=your-api-key
|
||||
# OPENAI_BASE_URL=https://your-custom-endpoint/v1
|
||||
|
||||
# === NAS 部署路径 ===
|
||||
NAS_DATA_ROOT=/data/jarvis
|
||||
DATA_DIR=/data/jarvis/data
|
||||
CHROMA_PERSIST_DIR=/data/jarvis/chroma
|
||||
UPLOAD_DIR=/data/jarvis/uploads
|
||||
|
||||
|
||||
# === LangSmith 可观测性 ===
|
||||
# 启用 LangSmith 追踪(可选)
|
||||
LANGSMITH_TRACING=false
|
||||
LANGSMITH_API_KEY=your-langsmith-api-key
|
||||
LANGSMITH_PROJECT=jarvis-agent
|
||||
|
||||
# === 定时任务 ===
|
||||
SCHEDULER_ENABLED=true
|
||||
DAILY_PLAN_TIME=00:00
|
||||
FORUM_SCAN_INTERVAL_MINUTES=30
|
||||
21
backend/Dockerfile
Normal file
21
backend/Dockerfile
Normal file
@@ -0,0 +1,21 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装依赖
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir uv && \
|
||||
uv pip install --system --no-cache -r pyproject.toml
|
||||
|
||||
# 安装可选依赖
|
||||
RUN uv pip install --system --no-cache pymupdf python-docx
|
||||
|
||||
# 复制代码
|
||||
COPY app/ ./app/
|
||||
|
||||
# 创建数据目录
|
||||
RUN mkdir -p /data/jarvis/data /data/jarvis/chroma /data/jarvis/uploads
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
35
backend/README.md
Normal file
35
backend/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Jarvis Backend
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
uv sync
|
||||
```
|
||||
|
||||
### 2. 配置环境变量
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填入 API Key
|
||||
```
|
||||
|
||||
### 3. 启动开发服务器
|
||||
|
||||
```bash
|
||||
uv run uvicorn app.main:app --reload --port 8000
|
||||
```
|
||||
|
||||
### 4. API 文档
|
||||
|
||||
启动后访问 http://localhost:8000/docs 查看交互式 API 文档。
|
||||
|
||||
## 环境变量
|
||||
|
||||
见 `.env.example`
|
||||
|
||||
## 数据库
|
||||
|
||||
SQLite 数据库默认位于 `./data/jarvis.db`(可通过环境变量 `DATABASE_URL` 修改),首次启动自动创建表。
|
||||
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Jarvis Backend
|
||||
24
backend/app/agents/context.py
Normal file
24
backend/app/agents/context.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Agent 运行时上下文
|
||||
用于在工具调用链中传递 user_id 等上下文信息
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
_current_user_id: ContextVar[Optional[str]] = ContextVar("current_user_id", default=None)
|
||||
|
||||
|
||||
def set_current_user(user_id: str) -> None:
    """Store ``user_id`` in the module-level ``_current_user_id`` context variable.

    Args:
        user_id: Identifier that later ``get_current_user()`` calls made in the
            same context should return.
    """
    _current_user_id.set(user_id)
|
||||
|
||||
|
||||
def get_current_user() -> str:
    """Return the user ID stored for the current context.

    Returns:
        The stored ID, or the literal ``"default"`` when nothing is stored or
        the stored value is falsy (``None`` or an empty string).
    """
    stored_id = _current_user_id.get()
    if stored_id:
        return stored_id
    return "default"
|
||||
|
||||
|
||||
def clear_current_user() -> None:
    """Reset the current-user context variable to ``None``.

    After this call ``get_current_user()`` falls back to its ``"default"`` value.
    """
    _current_user_id.set(None)
|
||||
265
backend/app/agents/graph.py
Normal file
265
backend/app/agents/graph.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Jarvis LangGraph Agent 主图定义
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from app.agents.state import AgentState, AgentRole
|
||||
from app.agents.prompts import (
|
||||
MASTER_SYSTEM_PROMPT,
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
EXECUTOR_SYSTEM_PROMPT,
|
||||
LIBRARIAN_SYSTEM_PROMPT,
|
||||
ANALYST_SYSTEM_PROMPT,
|
||||
)
|
||||
from app.agents.tools import ALL_TOOLS
|
||||
from app.services.llm_service import get_llm
|
||||
|
||||
|
||||
def _msg_type(msg: BaseMessage) -> str:
    """Return the kind tag of a message object.

    ``msg.type`` is preferred. When it is missing or falsy the function falls
    back to ``msg.role``, and finally to the literal ``"human"``.
    """
    kind = getattr(msg, "type", None)
    if kind:
        return kind
    return getattr(msg, "role", "human")
|
||||
|
||||
|
||||
def _filter_user_messages(messages: list) -> list[BaseMessage]:
    """Return, in order, only the messages whose kind tag is "human" or "user"."""
    user_kinds = ("human", "user")
    selected = []
    for message in messages:
        if _msg_type(message) in user_kinds:
            selected.append(message)
    return selected
|
||||
|
||||
|
||||
# ===================== 节点定义 (async) =====================
|
||||
|
||||
async def master_node(state: AgentState) -> AgentState:
    """Master agent node: query the LLM and decide which sub-agent handles the turn.

    The LLM is called with the master system prompt (plus the optional memory
    context) followed by the conversation messages. Its lower-cased reply is
    scanned for routing keywords, and the first matching rule wins.

    Args:
        state: Current graph state. Reads ``messages`` and ``memory_context``.

    Returns:
        The same state object, mutated in place. Either ``current_agent`` and
        ``active_agents`` are updated for delegation, or ``final_response``
        holds the LLM reply when no keyword matched.
    """
    llm = get_llm()
    messages: list[BaseMessage] = state["messages"]

    system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]

    # Inject the memory context as an extra system message when present.
    memory_ctx = state.get("memory_context")
    if memory_ctx:
        system_msgs.append(
            SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
        )

    # ainvoke() is the awaitable entry point; invoke() returns the message
    # directly, so awaiting its result fails.
    response: AIMessage = await llm.ainvoke(system_msgs + messages)
    content = response.content.strip().lower()

    # Ordered routing rules: the first role whose keywords occur in the reply wins.
    keyword_routes = (
        (AgentRole.LIBRARIAN, ("搜索", "查找", "知识", "检索")),
        (AgentRole.PLANNER, ("计划", "安排", "拆解", "规划")),
        (AgentRole.EXECUTOR, ("执行", "做", "操作", "创建", "更新")),
        (AgentRole.ANALYST, ("分析", "报告", "统计", "总结")),
    )
    next_agent = next(
        (role for role, keywords in keyword_routes if any(kw in content for kw in keywords)),
        None,
    )

    if next_agent is None:
        # No delegation: the master's own reply is the final answer.
        state["final_response"] = response.content
        state["should_respond"] = True
        return state

    state["current_agent"] = next_agent
    state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
    # The router ends the graph as soon as final_response is truthy, so a value
    # left over from an earlier run must not survive a delegation decision.
    state["final_response"] = None
    state["should_respond"] = True
    return state
|
||||
|
||||
|
||||
async def planner_node(state: AgentState) -> AgentState:
    """Planner agent node: ask the LLM for a plan and split it into steps.

    Only the most recent user message is sent to the LLM, together with the
    planner system prompt.

    Args:
        state: Current graph state. Reads ``messages``.

    Returns:
        The same state object with ``plan`` and ``final_response`` set to the
        raw plan text, ``plan_steps`` set to the parsed list items, and
        ``should_respond`` set to ``True``.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    # ainvoke() is the awaitable entry point; invoke() cannot be awaited.
    response = await llm.ainvoke(
        [SystemMessage(content=PLANNER_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
    )

    plan_text = response.content
    steps = []
    for raw_line in plan_text.split("\n"):
        # Strip first so indented list items are recognised too.
        line = raw_line.strip()
        if line and (line[0].isdigit() or line.startswith("- ")):
            # Number steps consecutively, not by their raw line index.
            steps.append({"step": len(steps) + 1, "description": line})

    state["plan"] = plan_text
    state["plan_steps"] = steps
    state["final_response"] = plan_text
    state["should_respond"] = True
    return state
|
||||
|
||||
|
||||
async def _run_tool_agent(state: AgentState, system_prompt: str) -> AgentState:
    """Shared tool-calling flow used by the executor, librarian and analyst nodes.

    Flow: send the latest user message to a tool-bound LLM, run every requested
    tool, then ask the LLM to phrase a reply from the collected tool output.

    Args:
        state: Current graph state. Reads ``messages``.
        system_prompt: System prompt of the calling agent.

    Returns:
        The same state object. ``final_response`` is always set. ``tool_calls``
        and ``last_tool_result`` are set only when the LLM requested tools.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""

    # ainvoke() is the awaitable entry point; invoke() cannot be awaited.
    response = await llm.bind_tools(ALL_TOOLS).ainvoke(
        [SystemMessage(content=system_prompt), HumanMessage(content=f"用户请求: {user_query}")]
    )

    tool_calls = getattr(response, "tool_calls", None) or []
    if not tool_calls:
        state["final_response"] = response.content
        return state

    results = []
    for tc in tool_calls:
        tool_name = tc.get("name")
        args = tc.get("args", {})
        # First tool with a matching name, as in the original linear scan.
        matched = next((t for t in ALL_TOOLS if t.name == tool_name), None)
        if matched is None:
            # Report unknown tools instead of silently dropping the call.
            results.append(f"[{tool_name}] 未知工具")
            continue
        try:
            # ainvoke() keeps the synchronous tool body off the event-loop thread.
            result = await matched.ainvoke(args)
            results.append(f"[{tool_name}] {result}")
        except Exception as e:  # best effort: one failing tool must not abort the turn
            results.append(f"[{tool_name}] 执行失败: {e}")

    state["tool_calls"] = tool_calls
    state["last_tool_result"] = "\n".join(results)
    follow_up = await llm.ainvoke(
        [SystemMessage(content=system_prompt),
         HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
    )
    state["final_response"] = follow_up.content
    return state


async def executor_node(state: AgentState) -> AgentState:
    """Executor agent node: run the tool-calling flow with the executor prompt."""
    state = await _run_tool_agent(state, EXECUTOR_SYSTEM_PROMPT)
    state["should_respond"] = True
    return state


async def librarian_node(state: AgentState) -> AgentState:
    """Librarian agent node: run the tool-calling flow with the librarian prompt.

    Also copies ``last_tool_result`` (or ``""`` when absent) into
    ``knowledge_context``.
    """
    state = await _run_tool_agent(state, LIBRARIAN_SYSTEM_PROMPT)
    state["knowledge_context"] = state.get("last_tool_result", "")
    state["should_respond"] = True
    return state


async def analyst_node(state: AgentState) -> AgentState:
    """Analyst agent node: run the tool-calling flow with the analyst prompt.

    Also stores the final reply in ``analysis_report``.
    """
    state = await _run_tool_agent(state, ANALYST_SYSTEM_PROMPT)
    state["analysis_report"] = state.get("final_response", "")
    state["should_respond"] = True
    return state
|
||||
|
||||
|
||||
def route_agent(state: AgentState) -> str:
    """Routing function for the conditional edge leaving the master node.

    Args:
        state: Current graph state. Reads ``final_response`` and
            ``current_agent``.

    Returns:
        ``END`` when a final response already exists or no sub-agent was
        selected; otherwise the node name of the selected sub-agent.
    """
    if state.get("final_response"):
        return END

    target = state.get("current_agent", AgentRole.MASTER)
    # Accept both enum members and plain strings; a plain str has no ``.value``.
    target_name = target.value if isinstance(target, AgentRole) else str(target)

    # The master node is not a valid edge target, so "no delegation" ends the
    # run rather than returning an unmapped node name.
    if target_name == AgentRole.MASTER.value:
        return END
    return target_name
|
||||
|
||||
|
||||
# ===================== 构建图 =====================
|
||||
|
||||
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the agent graph.

    Topology: the master node is the entry point. ``route_agent`` sends the run
    to one of the four sub-agent nodes or to ``END``, and every sub-agent node
    leads straight to ``END``.

    Args:
        callbacks: Optional callback handlers to attach to the compiled graph.

    Returns:
        The compiled graph, bound to ``callbacks`` when any were given.
    """
    graph = StateGraph(AgentState)

    node_functions = (
        (AgentRole.MASTER, master_node),
        (AgentRole.PLANNER, planner_node),
        (AgentRole.EXECUTOR, executor_node),
        (AgentRole.LIBRARIAN, librarian_node),
        (AgentRole.ANALYST, analyst_node),
    )
    for role, node_fn in node_functions:
        graph.add_node(role.value, node_fn)

    graph.set_entry_point(AgentRole.MASTER.value)

    sub_roles = [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]
    routes = {role.value: role.value for role in sub_roles}
    routes[END] = END
    graph.add_conditional_edges(AgentRole.MASTER.value, route_agent, routes)

    for role in sub_roles:
        graph.add_edge(role.value, END)

    # compile() has no ``callbacks`` keyword; callbacks are runnable config, so
    # they are bound to the compiled graph with with_config() instead.
    compiled = graph.compile()
    if callbacks:
        compiled = compiled.with_config({"callbacks": callbacks})
    return compiled
|
||||
|
||||
|
||||
_agent_graph = None
|
||||
|
||||
|
||||
def get_agent_graph(callbacks: list | None = None):
    """Return the compiled agent graph, building it on first use.

    The graph is cached in the module-level ``_agent_graph``. Callbacks are
    fixed when the graph is first built, and later calls ignore the
    ``callbacks`` argument. Changing them requires a fresh process.

    Args:
        callbacks: Optional extra callbacks. On the first call they are placed
            ahead of the ones returned by ``get_langsmith_callbacks()``.

    Returns:
        The cached compiled graph.
    """
    global _agent_graph
    if _agent_graph is not None:
        return _agent_graph

    # Imported lazily, as in the original, so importing this module does not
    # pull in the tracing configuration.
    from app.config_tracing import get_langsmith_callbacks

    tracing_callbacks = get_langsmith_callbacks()
    combined = (callbacks or []) + tracing_callbacks
    _agent_graph = create_agent_graph(callbacks=combined if combined else None)
    return _agent_graph
|
||||
127
backend/app/agents/prompts.py
Normal file
127
backend/app/agents/prompts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Jarvis 多Agent系统的提示词定义
|
||||
"""
|
||||
|
||||
MASTER_SYSTEM_PROMPT = """你叫 Jarvis,是用户的私人AI助理。
|
||||
|
||||
你的职责是理解用户意图,并将任务分发给最合适的子Agent。
|
||||
|
||||
## 你的4个子Agent:
|
||||
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
|
||||
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
|
||||
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
|
||||
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
|
||||
|
||||
## 判断规则:
|
||||
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
|
||||
- 用户要计划、安排、拆解任务 -> 分发给 planner
|
||||
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
|
||||
- 用户要分析、统计、生成报告 -> 分发给 analyst
|
||||
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
|
||||
|
||||
## 响应格式:
|
||||
简短回复用户,告知你将调用哪个Agent处理。如果用户不需要任何子Agent,直接给出回答。
|
||||
|
||||
注意: 你是协调者,不需要亲自执行具体任务,让专业Agent去做。
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = """你是 Jarvis 的规划Agent,负责制定计划、拆解任务。
|
||||
|
||||
## 你的能力:
|
||||
- 分析复杂请求,拆解成可执行的步骤
|
||||
- 评估任务优先级
|
||||
- 估算时间安排
|
||||
- 制定执行顺序
|
||||
|
||||
## 工作流程:
|
||||
1. 理解用户的总目标
|
||||
2. 拆解成具体步骤
|
||||
3. 标注每步的优先级
|
||||
4. 给出清晰的执行计划
|
||||
|
||||
## 响应要求:
|
||||
- 用编号列表展示计划步骤
|
||||
- 每步清晰描述要做什么
|
||||
- 可以为每步指定优先级(P1/P2/P3)
|
||||
- 如果需要执行,先输出计划,然后用户确认后再执行
|
||||
"""
|
||||
|
||||
|
||||
EXECUTOR_SYSTEM_PROMPT = """你是 Jarvis 的执行Agent,负责执行具体任务。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- create_task: 创建新任务
|
||||
- update_task_status: 更新任务状态
|
||||
- get_tasks: 查看任务列表
|
||||
- create_forum_post: 在论坛发布帖子
|
||||
- get_forum_posts: 查看论坛帖子
|
||||
- scan_forum_for_instructions: 扫描论坛指令
|
||||
|
||||
## 工作流程:
|
||||
1. 理解用户要执行什么
|
||||
2. 调用相应工具
|
||||
3. 报告执行结果
|
||||
4. 询问用户是否需要下一步操作
|
||||
|
||||
## 响应要求:
|
||||
- 明确告知用户正在执行什么
|
||||
- 工具调用结果要格式化呈现
|
||||
- 如果执行成功,给出确认
|
||||
- 如果需要更多信息,明确告知用户
|
||||
"""
|
||||
|
||||
|
||||
LIBRARIAN_SYSTEM_PROMPT = """你是 Jarvis 的知识管理员,负责管理用户的私人知识库。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- search_knowledge: 搜索知识库,返回相关文档片段
|
||||
- get_knowledge_graph_context: 获取知识图谱上下文
|
||||
- build_knowledge_graph: 从文档构建知识图谱
|
||||
|
||||
## 你的职责:
|
||||
1. 理解用户关于知识的问题
|
||||
2. 搜索相关知识
|
||||
3. 综合多篇文档给出完整回答
|
||||
4. 帮助用户整理和理解知识
|
||||
|
||||
## 工作流程:
|
||||
1. 分析用户的知识查询
|
||||
2. 搜索相关文档
|
||||
3. 综合相关信息给出回答
|
||||
4. 如果有图谱关联,可以引用图谱中的关系
|
||||
|
||||
## 响应要求:
|
||||
- 回答要有文档依据
|
||||
- 引用时标注来源
|
||||
- 如果知识不足,诚实告知用户
|
||||
- 可以补充相关知识背景
|
||||
"""
|
||||
|
||||
|
||||
ANALYST_SYSTEM_PROMPT = """你是 Jarvis 的分析师,负责分析数据和工作状态。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- get_tasks: 获取任务列表,统计工作进度
|
||||
- get_forum_posts: 获取论坛帖子,分析讨论趋势
|
||||
- scan_forum_for_instructions: 检查待执行指令
|
||||
- search_knowledge: 结合知识进行分析
|
||||
|
||||
## 你的职责:
|
||||
1. 统计任务完成情况
|
||||
2. 分析工作进度和趋势
|
||||
3. 生成数据报告
|
||||
4. 识别潜在问题和风险
|
||||
|
||||
## 工作流程:
|
||||
1. 收集相关数据(任务、论坛、知识)
|
||||
2. 进行数据分析
|
||||
3. 生成结构化报告
|
||||
4. 给出建议
|
||||
|
||||
## 响应要求:
|
||||
- 用数据说话,有数字有结论
|
||||
- 报告结构清晰
|
||||
- 给出可行的改进建议
|
||||
- 识别需要关注的问题
|
||||
"""
|
||||
105
backend/app/agents/state.py
Normal file
105
backend/app/agents/state.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict, Annotated
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
    """Names of the agents in the graph.

    Members are also ``str`` instances, and each ``.value`` is used as the
    node name when the graph is built.
    """

    MASTER = "master"
    PLANNER = "planner"
    EXECUTOR = "executor"
    LIBRARIAN = "librarian"
    ANALYST = "analyst"
|
||||
|
||||
|
||||
@dataclass
class AgentInfo:
    """Descriptive record for one agent: display name, role and description."""

    name: str
    role: AgentRole
    description: str
|
||||
|
||||
|
||||
@dataclass
class ToolCall:
    """Record of a single tool invocation."""

    tool: str  # tool name
    args: dict  # arguments passed to the tool
    result: str | None = None  # tool output; None until a result is recorded
|
||||
|
||||
|
||||
@dataclass
class ConversationTurn:
    """One turn of a conversation in a framework-independent form."""

    role: str  # "user" | "assistant"
    content: str
    agent: AgentRole | None = None  # agent associated with the turn, if any
    model: str | None = None  # model name attached to the turn, if any
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> "HumanMessage":
    """Convert a conversation turn into a LangChain ``HumanMessage``.

    Only ``turn.content`` is carried over.

    NOTE(review): every turn becomes a ``HumanMessage`` regardless of
    ``turn.role`` — confirm that assistant turns are never passed here.

    Args:
        turn: The turn to convert.

    Returns:
        A ``HumanMessage`` whose content is ``turn.content``.
    """
    # This module has no top-level import of HumanMessage, so the unquoted
    # return annotation raised NameError at import time. The annotation is
    # quoted and the class is imported where it is used.
    from langchain_core.messages import HumanMessage

    return HumanMessage(content=turn.content)
|
||||
|
||||
|
||||
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
    """Convert a message object into a ``ConversationTurn``.

    The message kind is read from ``msg.type``, falling back to ``msg.role``
    and then to ``"assistant"``. Kinds ``"human"`` and ``"user"`` map to the
    ``"user"`` role; everything else maps to ``"assistant"``.

    Args:
        msg: Object exposing ``content`` and optionally ``type``, ``role``
            and ``model``.
        agent: Agent to record on the turn, if any.

    Returns:
        The populated ``ConversationTurn``.
    """
    kind = getattr(msg, "type", None)
    if not kind:
        kind = getattr(msg, "role", "assistant")

    if kind in ("human", "user"):
        speaker = "user"
    else:
        speaker = "assistant"

    return ConversationTurn(
        role=speaker,
        content=msg.content,
        agent=agent,
        model=getattr(msg, "model", None),
    )
|
||||
|
||||
|
||||
class AgentState(TypedDict):
    """Shared state dictionary for the agent graph.

    ``initial_state()`` in this module builds a fully populated instance.
    """

    # Conversation messages.
    # NOTE(review): the Annotated metadata is None, so no merge/reducer
    # function is declared here — confirm whether one was intended.
    messages: Annotated[list, None]
    user_id: str
    conversation_id: str

    # Agent routing
    current_agent: AgentRole
    active_agents: list[AgentRole]

    # Task tracking
    pending_tasks: list[dict]
    completed_tasks: list[dict]

    # Tool usage
    # NOTE(review): annotated as ToolCall instances — confirm what writers
    # of this key actually store.
    tool_calls: list[ToolCall]
    last_tool_result: str | None

    # Knowledge context
    knowledge_context: str | None
    graph_context: str | None

    # Planning
    plan: str | None
    plan_steps: list[dict]

    # Analysis
    analysis_report: str | None

    # Output control
    final_response: str | None
    should_respond: bool

    # Memory context (injected at start of each conversation)
    memory_context: str | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
    """Build the starting state for a conversation.

    Args:
        user_id: Owner of the conversation.
        conversation_id: Identifier of the conversation.

    Returns:
        A state with the master agent active, empty lists for every list
        field, ``None`` for every optional text field and ``should_respond``
        set to ``True``.
    """
    fresh: AgentState = {
        "messages": [],
        "user_id": user_id,
        "conversation_id": conversation_id,
        # Routing starts at the master agent.
        "current_agent": AgentRole.MASTER,
        "active_agents": [AgentRole.MASTER],
        "pending_tasks": [],
        "completed_tasks": [],
        "tool_calls": [],
        "last_tool_result": None,
        "knowledge_context": None,
        "graph_context": None,
        "plan": None,
        "plan_steps": [],
        "analysis_report": None,
        "final_response": None,
        "should_respond": True,
        "memory_context": None,
    }
    return fresh
|
||||
22
backend/app/agents/tools/__init__.py
Normal file
22
backend/app/agents/tools/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
|
||||
ALL_TOOLS = [
|
||||
# 知识库工具
|
||||
search_knowledge,
|
||||
get_knowledge_graph_context,
|
||||
build_knowledge_graph,
|
||||
hybrid_search,
|
||||
# 任务工具
|
||||
get_tasks,
|
||||
create_task,
|
||||
update_task_status,
|
||||
# 论坛工具
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
scan_forum_for_instructions,
|
||||
]
|
||||
134
backend/app/agents/tools/forum.py
Normal file
134
backend/app/agents/tools/forum.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Agent 工具集 - 论坛相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
    """Run a coroutine to completion from synchronous code and return its result.

    With no event loop running in the calling thread the coroutine is run
    directly with ``asyncio.run``. Otherwise it is run on a fresh event loop in
    a worker thread and the caller waits for the result.

    Args:
        coro: The coroutine object to run.
        timeout: Seconds to wait for the worker thread. Applies only when an
            event loop is already running in the calling thread.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the worker does not finish in time.
        Exception: Any exception raised by the coroutine is propagated.
    """
    from concurrent.futures import ThreadPoolExecutor

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to drive the coroutine directly.
        return asyncio.run(coro)

    # submit() returns a concurrent.futures.Future, whose result() accepts a
    # timeout. loop.run_in_executor() returns an asyncio future that does not.
    # NOTE: waiting here blocks the calling thread, and therefore its loop.
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        # Do not wait for a timed-out worker; just release the pool.
        pool.shutdown(wait=False)
|
||||
|
||||
|
||||
@tool
def get_forum_posts(category: str | None = None, limit: int = 10) -> str:
    """
    获取论坛帖子列表。

    Args:
        category: 可选,筛选分类 (discussion/instruction/question)
        limit: 返回数量,默认10

    Returns:
        帖子列表
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Lists the current user's posts, newest first, one line per post.
    # Failures are reported as a message string, never raised.
    owner_id = get_current_user()

    async def _fetch_posts():
        async with async_session() as db:
            from app.models.user import User

            stmt = select(ForumPost).join(User, User.id == ForumPost.user_id)
            stmt = stmt.where(User.id == owner_id)
            if category:
                stmt = stmt.where(ForumPost.category == category)
            stmt = stmt.order_by(ForumPost.created_at.desc()).limit(limit)

            rows = (await db.execute(stmt)).scalars().all()
            if not rows:
                return "暂无帖子"

            rendered = []
            for post in rows:
                executed_suffix = " [已执行]" if post.is_executed else ""
                rendered.append(
                    f"- [{post.id[:8]}] [{post.category}] {post.title} | "
                    f"{post.content[:50]}...{executed_suffix}"
                )
            return "\n".join(rendered)

    try:
        return _run_async(_fetch_posts())
    except Exception as exc:
        return f"获取帖子失败: {exc}"
|
||||
|
||||
|
||||
@tool
def create_forum_post(title: str, content: str, category: str = "discussion") -> str:
    """
    在论坛发布新帖子。

    Args:
        title: 帖子标题
        content: 帖子内容
        category: 分类 (discussion/instruction/question),默认discussion

    Returns:
        创建结果
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Inserts one post owned by the current user and reports the first eight
    # characters of its ID. Failures are reported as a message string.
    author_id = get_current_user()

    async def _insert_post():
        async with async_session() as db:
            new_post = ForumPost(
                user_id=author_id,
                title=title,
                content=content,
                category=category,
            )
            db.add(new_post)
            await db.commit()
            # Reload the row before reading its ID.
            await db.refresh(new_post)
            short_id = new_post.id[:8]
            return f"帖子发布成功: [{short_id}] {title}"

    try:
        return _run_async(_insert_post())
    except Exception as exc:
        return f"发布帖子失败: {exc}"
|
||||
|
||||
|
||||
@tool
def scan_forum_for_instructions() -> str:
    """
    扫描论坛中的指令类帖子,检查是否有待执行的指令。

    Returns:
        待执行指令的列表
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Lists up to ten of the current user's not-yet-executed "instruction"
    # posts, newest first. Failures are reported as a message string.
    owner_id = get_current_user()

    async def _find_pending():
        async with async_session() as db:
            from app.models.user import User

            stmt = select(ForumPost).join(User, User.id == ForumPost.user_id)
            stmt = stmt.where(ForumPost.user_id == owner_id)
            stmt = stmt.where(ForumPost.category == "instruction")
            # SQL comparison on the column, not a Python truth test.
            stmt = stmt.where(ForumPost.is_executed == False)  # noqa: E712
            stmt = stmt.order_by(ForumPost.created_at.desc()).limit(10)

            pending = (await db.execute(stmt)).scalars().all()
            if not pending:
                return "暂无待执行的指令"

            entries = [
                f"- [{post.id[:8]}] {post.title}\n  内容: {post.content[:100]}..."
                for post in pending
            ]
            return "\n".join(["待执行的指令:"] + entries)

    try:
        return _run_async(_find_pending())
    except Exception as exc:
        return f"扫描论坛失败: {exc}"
|
||||
|
||||
|
||||
__all__ = ["get_forum_posts", "create_forum_post", "scan_forum_for_instructions"]
|
||||
159
backend/app/agents/tools/search.py
Normal file
159
backend/app/agents/tools/search.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Agent 工具集 - 知识库 & 图谱相关
|
||||
|
||||
这些工具由 Agent 节点(见 app/agents/graph.py)通过 LangChain 工具接口调用。
|
||||
由于这里的工具函数是同步定义的,内部通过 _run_async 在线程池中运行 async 逻辑。
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
# Shared worker pool for coroutines that must run off the caller's event loop.
_executor = ThreadPoolExecutor(max_workers=4)


def _run_async(coro, timeout: int = 30):
    """Run a coroutine to completion from synchronous code and return its result.

    With no event loop running in the calling thread the coroutine is run
    directly with ``asyncio.run``. Otherwise it is run on a fresh event loop in
    a ``_executor`` worker thread and the caller waits for the result.

    Args:
        coro: The coroutine object to run.
        timeout: Seconds to wait for the worker thread. Applies only when an
            event loop is already running in the calling thread.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the worker does not finish in time.
        Exception: Any exception raised by the coroutine is propagated.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to drive the coroutine directly.
        return asyncio.run(coro)

    # submit() returns a concurrent.futures.Future, whose result() accepts a
    # timeout. loop.run_in_executor() returns an asyncio future that does not.
    # NOTE: waiting here blocks the calling thread, and therefore its loop.
    return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
@tool
def search_knowledge(query: str, top_k: int = 5) -> str:
    """
    搜索用户的私人知识库。根据查询返回最相关的文档片段,支持语义检索。

    Args:
        query: 搜索查询
        top_k: 返回结果数量,默认5条

    Returns:
        包含相关文档片段和来源信息的格式化文本
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Formats each hit with its source title, score, neighbouring-chunk
    # previews and up to 300 characters of content. Failures are reported as
    # a message string.
    from app.services.knowledge_service import KnowledgeService

    owner_id = get_current_user()

    async def _retrieve():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=owner_id)
            hits = await service.retrieve(query, user_id=owner_id, top_k=top_k)
            if not hits:
                return "未找到相关知识。知识库可能为空,或尝试用其他关键词搜索。"

            sections = []
            for rank, hit in enumerate(hits, 1):
                before = f"\n上一段: {hit.prev_chunk[:100]}..." if hit.prev_chunk else ""
                after = f"\n下一段: {hit.next_chunk[:100]}..." if hit.next_chunk else ""
                # Ellipsis only when the content was actually truncated.
                tail = "..." if len(hit.content) > 300 else ""
                sections.append(
                    f"[{rank}] 来源: {hit.document_title}\n"
                    f"相关度: {hit.score:.2f}\n"
                    f"{before}{after}\n"
                    f"内容: {hit.content[:300]}{tail}"
                )
            return "\n\n---\n\n".join(sections)

    try:
        return _run_async(_retrieve(), timeout=30)
    except Exception as exc:
        return f"知识检索失败: {exc}"
|
||||
|
||||
|
||||
@tool
def get_knowledge_graph_context(entity: str | None = None) -> str:
    """
    获取用户知识图谱的上下文信息。

    Args:
        entity: 可选,指定要查询的实体名称。如果为空则返回整体图谱摘要。

    Returns:
        知识图谱节点和关系的描述
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Delegates to GraphService: entity context when ``entity`` is truthy,
    # otherwise the graph summary. Failures are reported as a message string.
    from app.services.graph_service import GraphService

    owner_id = get_current_user()

    async def _lookup():
        async with async_session() as db:
            graph_service = GraphService(db)
            if not entity:
                return await graph_service.get_graph_summary(owner_id)
            return await graph_service.get_entity_context(entity, owner_id)

    try:
        return _run_async(_lookup(), timeout=30)
    except Exception as exc:
        return f"图谱查询失败: {exc}"
|
||||
|
||||
|
||||
@tool
def build_knowledge_graph(document_ids: list[str] | None = None) -> str:
    """
    从文档构建/更新知识图谱。

    Args:
        document_ids: 可选,指定要处理的文档ID列表。如果为空则处理所有文档。

    Returns:
        构建结果摘要
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Delegates to GraphService.build_graph with a 120-second wait, longer
    # than the 30 seconds used by the other tools in this module. Failures
    # are reported as a message string.
    from app.services.graph_service import GraphService

    owner_id = get_current_user()

    async def _rebuild():
        async with async_session() as db:
            graph_service = GraphService(db)
            await graph_service.build_graph(user_id=owner_id, document_ids=document_ids)
        return "知识图谱构建完成"

    try:
        return _run_async(_rebuild(), timeout=120)
    except Exception as exc:
        return f"图谱构建失败: {exc}"
|
||||
|
||||
|
||||
@tool
def hybrid_search(query: str, top_k: int = 5) -> str:
    """
    混合搜索,结合向量语义检索和关键词匹配,返回最相关结果。

    Args:
        query: 搜索查询
        top_k: 返回结果数量,默认5条

    Returns:
        混合检索结果
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Formats each hit with its title, score and up to 200 characters of
    # content. Failures are reported as a message string.
    from app.services.knowledge_service import KnowledgeService

    owner_id = get_current_user()

    async def _search_hybrid():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=owner_id)
            hits = await service.hybrid_search(query, user_id=owner_id, top_k=top_k)
            if not hits:
                return "未找到相关知识。"

            sections = []
            for rank, hit in enumerate(hits, 1):
                # Ellipsis only when the content was actually truncated.
                tail = "..." if len(hit.content) > 200 else ""
                sections.append(
                    f"[{rank}] {hit.document_title} (相关度: {hit.score:.2f})\n"
                    f"{hit.content[:200]}{tail}"
                )
            return "\n\n---\n\n".join(sections)

    try:
        return _run_async(_search_hybrid(), timeout=30)
    except Exception as exc:
        return f"混合搜索失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
]
|
||||
142
backend/app/agents/tools/task.py
Normal file
142
backend/app/agents/tools/task.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
|
||||
# Optional shared worker pool; when left as None a short-lived pool is created
# for each call that needs one.
_executor = None


def _run_async(coro, timeout: int = 30):
    """Run a coroutine to completion from synchronous code and return its result.

    With no event loop running in the calling thread the coroutine is run
    directly with ``asyncio.run``. Otherwise it is run on a fresh event loop in
    a worker thread and the caller waits for the result.

    Args:
        coro: The coroutine object to run.
        timeout: Seconds to wait for the worker thread. Applies only when an
            event loop is already running in the calling thread.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the worker does not finish in time.
        Exception: Any exception raised by the coroutine is propagated.
    """
    from concurrent.futures import ThreadPoolExecutor

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to drive the coroutine directly.
        return asyncio.run(coro)

    # submit() returns a concurrent.futures.Future, whose result() accepts a
    # timeout. loop.run_in_executor() returns an asyncio future that does not.
    # NOTE: waiting here blocks the calling thread, and therefore its loop.
    if _executor is not None:
        return _executor.submit(asyncio.run, coro).result(timeout=timeout)

    pool = ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        # Do not wait for a timed-out worker; just release the pool.
        pool.shutdown(wait=False)
|
||||
|
||||
|
||||
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。

    Args:
        status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
        limit: 返回数量,默认20

    Returns:
        任务列表
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Lists the current user's tasks ordered by priority, then by last
    # update, both descending. Failures are reported as a message string.
    owner_id = get_current_user()

    async def _load_tasks():
        async with async_session() as db:
            from app.models.user import User

            stmt = select(Task).join(User, User.id == Task.user_id)
            stmt = stmt.where(User.id == owner_id)
            if status:
                stmt = stmt.where(Task.status == status)
            stmt = stmt.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)

            rows = (await db.execute(stmt)).scalars().all()
            if not rows:
                return "暂无任务"

            return "\n".join(
                f"- [{item.id[:8]}] {item.title} | "
                f"状态:{item.status} | 优先级:{item.priority} | 截止:{item.due_date or '无'}"
                for item in rows
            )

    try:
        return _run_async(_load_tasks())
    except Exception as exc:
        return f"获取任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
    """
    创建新任务。

    Args:
        title: 任务标题(必填)
        description: 任务描述
        priority: 优先级 1-4,数字越大优先级越高,默认2
        due_date: 截止日期,格式 YYYY-MM-DD
    
    Returns:
        创建结果
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Inserts one task in status "todo" for the current user and reports the
    # first eight characters of its ID. Failures are reported as a message
    # string.
    # NOTE(review): ``due_date`` is passed to the model as the raw string —
    # confirm the column type accepts that.
    owner_id = get_current_user()

    async def _insert_task():
        async with async_session() as db:
            fields = {
                "user_id": owner_id,
                "title": title,
                "description": description,
                "priority": priority,
                "due_date": due_date,
                "status": "todo",
            }
            new_task = Task(**fields)
            db.add(new_task)
            await db.commit()
            # Reload the row before reading its ID.
            await db.refresh(new_task)
            short_id = new_task.id[:8]
            return f"任务创建成功: [{short_id}] {title}"

    try:
        return _run_async(_insert_task())
    except Exception as exc:
        return f"创建任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。

    Args:
        task_id: 任务ID(完整ID或前8位)
        status: 新状态 (todo/in_progress/done/blocked)

    Returns:
        更新结果
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # as the tool description, which makes it runtime text.
    # Looks up one of the current user's tasks, by full ID or by 8-character
    # prefix, and sets its status. Every failure is reported as a message
    # string, never raised.
    uid = get_current_user()

    async def _update():
        async with async_session() as db:
            from app.models.user import User

            query = (
                select(Task)
                .join(User, User.id == Task.user_id)
                .where(User.id == uid)
            )
            if len(task_id) == 8:
                # Exactly eight characters is treated as an ID prefix.
                query = query.where(Task.id.like(f"{task_id}%"))
            else:
                query = query.where(Task.id == task_id)

            # Fetch at most two rows: enough to tell "none", "one" and
            # "ambiguous" apart. scalar_one_or_none() raised on multiple matches.
            result = await db.execute(query.limit(2))
            matches = result.scalars().all()
            if not matches:
                return f"任务不存在: {task_id}"
            if len(matches) > 1:
                return f"任务ID前缀不唯一,请提供完整ID: {task_id}"

            task = matches[0]
            task.status = status
            await db.commit()
            return f"任务状态已更新: {task.title} -> {status}"

    try:
        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
|
||||
|
||||
|
||||
__all__ = ["get_tasks", "create_task", "update_task_status"]
|
||||
69
backend/app/config.py
Normal file
69
backend/app/config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from the environment and an optional ``.env`` file.

    Keys that do not correspond to a field below are ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    # === Application basics ===
    APP_NAME: str = "Jarvis"
    APP_VERSION: str = "0.1.0"
    DEBUG: bool = False

    # === Security ===
    # NOTE(review): the default is a placeholder; deployments should override it.
    SECRET_KEY: str = "change-me-in-production"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7  # 7 days, expressed in minutes

    # === Database ===
    DATABASE_URL: str = "sqlite+aiosqlite:///./data/jarvis.db"
    DATA_DIR: str = "./data"

    # === ChromaDB ===
    CHROMA_PERSIST_DIR: str = "./data/chroma"

    # === LLM configuration ===
    # Supported: openai / claude / ollama / deepseek / custom
    LLM_PROVIDER: Literal["openai", "claude", "ollama", "deepseek", "custom"] = "openai"

    # OpenAI (default)
    OPENAI_API_KEY: str = ""
    OPENAI_MODEL: str = "gpt-4o"
    OPENAI_BASE_URL: str = "https://api.openai.com/v1"

    # Claude
    ANTHROPIC_API_KEY: str = ""
    CLAUDE_MODEL: str = "claude-sonnet-4-20250514"
    CLAUDE_MAX_TOKENS: int = 8192

    # Ollama (local models)
    OLLAMA_BASE_URL: str = "http://localhost:11434"
    OLLAMA_MODEL: str = "llama3"

    # === Scheduled jobs ===
    SCHEDULER_ENABLED: bool = True
    DAILY_PLAN_TIME: str = "00:00"
    FORUM_SCAN_INTERVAL_MINUTES: int = 30

    # === CORS ===
    CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://localhost:3000"]

    # === File uploads ===
    UPLOAD_DIR: str = "./data/uploads"
    MAX_UPLOAD_SIZE: int = 50 * 1024 * 1024  # presumably bytes (50 MiB) — confirm against the upload handler

    # === Embeddings / chunking ===
    EMBEDDING_MODEL: str = "text-embedding-3-small"
    CHUNK_SIZE: int = 500
    CHUNK_OVERLAP: int = 50

    # === LangSmith observability ===
    LANGSMITH_TRACING: bool = False
    LANGSMITH_API_KEY: str = ""
    LANGSMITH_PROJECT: str = "jarvis-agent"

    # === NAS deployment ===
    NAS_DATA_ROOT: str = "/data/jarvis"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
26
backend/app/config_tracing.py
Normal file
26
backend/app/config_tracing.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
LangSmith Tracing 配置
|
||||
提供 Callback 工厂函数,用于 LangGraph 追踪
|
||||
"""
|
||||
|
||||
from langchain_core.tracers import LangChainTracer
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def get_langsmith_callbacks() -> list:
    """Build the LangSmith callback list from the application settings.

    Returns:
        A one-element list holding a ``LangChainTracer`` bound to the
        configured project, or an empty list when tracing is switched off or
        no API key is configured.
    """
    tracing_configured = settings.LANGSMITH_TRACING and settings.LANGSMITH_API_KEY
    if not tracing_configured:
        return []

    tracer = LangChainTracer(project_name=settings.LANGSMITH_PROJECT)
    return [tracer]
|
||||
35
backend/app/database.py
Normal file
35
backend/app/database.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from app.config import settings
|
||||
import os
|
||||
|
||||
# Make sure the data directory exists at import time, before the engine is created.
os.makedirs(settings.DATA_DIR, exist_ok=True)

# Shared async engine; SQL statements are echoed when DEBUG is enabled.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
)

# Session factory; expire_on_commit=False keeps loaded attributes usable after commit.
async_session = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base whose metadata ``init_db`` uses to create the tables."""

    pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
    """Yield one database session and close it when the consumer is done.

    NOTE(review): this is an async generator, so the declared return type
    describes the yielded value rather than the function's actual return type.
    """
    session_scope = async_session()
    async with session_scope as db_session:
        try:
            yield db_session
        finally:
            # Explicit close in addition to the context manager's own exit.
            await db_session.close()
|
||||
|
||||
|
||||
async def init_db():
    """Create every table registered on ``Base.metadata`` inside one transaction."""
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)
|
||||
72
backend/app/main.py
Normal file
72
backend/app/main.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.database import init_db
|
||||
from app.routers import (
|
||||
auth_router,
|
||||
conversation_router,
|
||||
document_router,
|
||||
task_router,
|
||||
forum_router,
|
||||
graph_router,
|
||||
agent_router,
|
||||
todo_router,
|
||||
settings_router,
|
||||
folder_router,
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.config import settings
|
||||
import os
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare storage and the scheduler, then tear down.

    Startup creates the data, upload and Chroma directories, initialises the
    database and starts the scheduler. Shutdown stops the scheduler.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    # Startup
    for directory in (settings.DATA_DIR, settings.UPLOAD_DIR, settings.CHROMA_PERSIST_DIR):
        os.makedirs(directory, exist_ok=True)
    await init_db()
    start_scheduler()
    try:
        yield
    finally:
        # Shutdown: runs even when an exception or cancellation is thrown in
        # at the yield point, so the scheduler is never left running.
        stop_scheduler()
|
||||
|
||||
|
||||
# Application instance; title and version come from the settings object.
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.APP_VERSION,
    lifespan=lifespan,
)

# CORS: only the configured origins are allowed; all methods and headers pass.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register routers
app.include_router(auth_router)
app.include_router(conversation_router)
app.include_router(document_router)
app.include_router(task_router)
app.include_router(forum_router)
app.include_router(graph_router)
app.include_router(agent_router)
app.include_router(todo_router)
app.include_router(settings_router)
app.include_router(folder_router)
app.include_router(scheduler_router)
|
||||
|
||||
|
||||
@app.get("/api/health")
async def health():
    """Report liveness together with version, LLM provider and scheduler status."""
    scheduler_state = get_scheduler_status()
    return dict(
        status="ok",
        version=settings.APP_VERSION,
        llm_provider=settings.LLM_PROVIDER,
        scheduler=scheduler_state,
    )
|
||||
31
backend/app/models/__init__.py
Normal file
31
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from app.models.base import Base
|
||||
from app.models.user import User
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.task import Task, TaskHistory
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.models.agent import Agent, AgentMessage
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"Document",
|
||||
"DocumentChunk",
|
||||
"Task",
|
||||
"TaskHistory",
|
||||
"ForumPost",
|
||||
"ForumReply",
|
||||
"Agent",
|
||||
"AgentMessage",
|
||||
"Conversation",
|
||||
"Message",
|
||||
"KGNode",
|
||||
"KGEdge",
|
||||
"MemorySummary",
|
||||
"UserMemory",
|
||||
"DailyTodo",
|
||||
"TodoSource",
|
||||
]
|
||||
28
backend/app/models/agent.py
Normal file
28
backend/app/models/agent.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Agent(BaseModel):
    """ORM model for the ``agents`` table: a named agent with a role and a system prompt."""

    __tablename__ = "agents"

    name = Column(String(100), nullable=False)
    role = Column(String(100), nullable=False)  # master, planner, executor, librarian, analyst
    description = Column(Text, nullable=True)
    system_prompt = Column(Text, nullable=False)
    is_active = Column(Boolean, default=True)
    is_default = Column(Boolean, default=False)

    # Messages are owned by the agent: removing the agent removes them too.
    messages = relationship("AgentMessage", back_populates="agent", cascade="all, delete-orphan")
    # Forum replies only reference the agent; no cascade is configured.
    replies = relationship("ForumReply", back_populates="agent")
|
||||
|
||||
|
||||
class AgentMessage(BaseModel):
    """ORM model for ``agent_messages``: one message tied to an agent and a conversation."""

    __tablename__ = "agent_messages"

    agent_id = Column(String(36), ForeignKey("agents.id"), nullable=False, index=True)
    conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
    role = Column(String(20), nullable=False)  # system, user, assistant
    content = Column(Text, nullable=False)

    agent = relationship("Agent", back_populates="messages")
|
||||
12
backend/app/models/base.py
Normal file
12
backend/app/models/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, DateTime
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class BaseModel(Base):
    """Abstract base giving every model a UUID string id and created/updated timestamps."""

    __abstract__ = True

    # Primary key: a fresh UUID4 rendered as a 36-character string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # NOTE(review): datetime.utcnow yields naive datetimes and is deprecated
    # since Python 3.12; consider a timezone-aware replacement.
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
26
backend/app/models/conversation.py
Normal file
26
backend/app/models/conversation.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Conversation(BaseModel):
    """ORM model for ``conversations``: a user's chat thread and its messages."""

    __tablename__ = "conversations"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    title = Column(String(500), nullable=True)
    message_count = Column(Integer, default=0)

    # Messages are deleted together with their conversation.
    messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Message(BaseModel):
    """ORM model for ``messages``: a single entry inside a conversation."""

    __tablename__ = "messages"

    conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
    role = Column(String(20), nullable=False)  # user, assistant, system
    content = Column(Text, nullable=False)
    model = Column(String(100), nullable=True)
    tokens_used = Column(Integer, nullable=True)
    attachments = Column(JSON, nullable=True)  # newly added: [{file_id, filename, file_type, file_size}]

    conversation = relationship("Conversation", back_populates="messages")
|
||||
33
backend/app/models/document.py
Normal file
33
backend/app/models/document.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqlalchemy import Column, String, Integer, Text, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Document(BaseModel):
    """ORM model for ``documents``: an uploaded file plus its indexing state."""

    __tablename__ = "documents"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    filename = Column(String(500), nullable=False)
    file_type = Column(String(50), nullable=False)  # pdf, md, txt, docx
    file_size = Column(Integer, nullable=False)
    file_path = Column(String(1000), nullable=False)
    folder_id = Column(String(36), ForeignKey("folders.id"), nullable=True)  # newly added
    summary = Column(Text, nullable=True)
    chunk_count = Column(Integer, default=0)
    is_indexed = Column(Boolean, default=False)

    # Chunks are deleted together with their document.
    chunks = relationship("DocumentChunk", back_populates="document", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
    """ORM model for ``document_chunks``: one text chunk of a document."""

    __tablename__ = "document_chunks"

    document_id = Column(String(36), ForeignKey("documents.id"), nullable=False, index=True)
    chunk_index = Column(Integer, nullable=False)
    content = Column(Text, nullable=False)
    metadata_ = Column(String(2000), nullable=True)  # metadata stored as JSON text
    chroma_collection = Column(String(255), nullable=True)
    chroma_id = Column(String(255), nullable=True)

    document = relationship("Document", back_populates="chunks")
|
||||
13
backend/app/models/folder.py
Normal file
13
backend/app/models/folder.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import Column, String, ForeignKey, UniqueConstraint
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Folder(BaseModel):
    """ORM model for ``folders``: a per-user folder that may nest under a parent folder."""

    __tablename__ = "folders"
    # A folder name must be unique for a given (user, parent) pair.
    # NOTE(review): parent_id is nullable; check how the target database treats
    # NULLs in unique constraints for top-level folders.
    __table_args__ = (
        UniqueConstraint('user_id', 'parent_id', 'name', name='uq_user_parent_name'),
    )

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    name = Column(String(255), nullable=False)
    # Self-referencing key; NULL is allowed (presumably a top-level folder).
    parent_id = Column(String(36), ForeignKey("folders.id"), nullable=True)
|
||||
30
backend/app/models/forum.py
Normal file
30
backend/app/models/forum.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class ForumPost(BaseModel):
    """ORM model for ``forum_posts``: a user's post with execution state and replies."""

    __tablename__ = "forum_posts"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    content = Column(Text, nullable=False)
    category = Column(String(100), nullable=True)  # instruction, discussion, question
    is_executed = Column(Boolean, default=False)
    execution_result = Column(Text, nullable=True)
    reply_count = Column(Integer, default=0)

    # Replies are deleted together with their post.
    replies = relationship("ForumReply", back_populates="post", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class ForumReply(BaseModel):
    """ORM model for ``forum_replies``: a reply to a post, written by a user or an agent."""

    __tablename__ = "forum_replies"

    post_id = Column(String(36), ForeignKey("forum_posts.id"), nullable=False, index=True)
    # Both author references are nullable.
    user_id = Column(String(36), ForeignKey("users.id"), nullable=True)
    agent_id = Column(String(36), ForeignKey("agents.id"), nullable=True)
    content = Column(Text, nullable=False)
    is_ai_reply = Column(Boolean, default=False)

    post = relationship("ForumPost", back_populates="replies")
    agent = relationship("Agent", back_populates="replies")
|
||||
32
backend/app/models/knowledge_graph.py
Normal file
32
backend/app/models/knowledge_graph.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, Float, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class KGNode(BaseModel):
    """ORM model for ``kg_nodes``: one entity in a user's knowledge graph."""

    __tablename__ = "kg_nodes"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    name = Column(String(500), nullable=False)
    entity_type = Column(String(100), nullable=False)  # person, concept, task, document, chunk, tag
    description = Column(Text, nullable=True)
    properties_ = Column(JSON, nullable=True)  # extra properties
    source_document_id = Column(String(36), ForeignKey("documents.id"), nullable=True)
    importance = Column(Float, default=0.5)  # importance, 0-1
    last_updated_by = Column(String(36), nullable=True)  # which agent last updated the node

    # Two relationships target KGEdge, so each names its foreign key explicitly.
    # Edges on either side are deleted together with the node.
    outgoing_edges = relationship("KGEdge", foreign_keys="KGEdge.source_id", back_populates="source_node", cascade="all, delete-orphan")
    incoming_edges = relationship("KGEdge", foreign_keys="KGEdge.target_id", back_populates="target_node", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class KGEdge(BaseModel):
    """ORM model for ``kg_edges``: a directed, typed relation between two nodes."""

    __tablename__ = "kg_edges"

    source_id = Column(String(36), ForeignKey("kg_nodes.id"), nullable=False, index=True)
    target_id = Column(String(36), ForeignKey("kg_nodes.id"), nullable=False, index=True)
    relation_type = Column(String(100), nullable=False)  # related_to, part_of, caused_by, depends_on, etc.
    weight = Column(Float, default=0.5)  # relation strength, 0-1
    properties_ = Column(JSON, nullable=True)

    source_node = relationship("KGNode", foreign_keys=[source_id], back_populates="outgoing_edges")
    target_node = relationship("KGNode", foreign_keys=[target_id], back_populates="incoming_edges")
|
||||
35
backend/app/models/memory.py
Normal file
35
backend/app/models/memory.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, ForeignKey, Boolean, DateTime, Enum as SQLEnum
|
||||
from datetime import datetime
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class MemorySummary(BaseModel):
    """
    Conversation summary — mid-term memory.

    When a conversation exceeds a threshold number of turns, a summary is
    generated automatically and stored in this table.
    """
    __tablename__ = "memory_summaries"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
    summary_text = Column(Text, nullable=False)  # summary content
    turn_count = Column(Integer, default=0)  # cumulative turn count at summary time
    summary_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class UserMemory(BaseModel):
    """
    User profile memory — long-term memory.

    Facts, preferences and goals about the user extracted from conversations.
    """
    __tablename__ = "user_memories"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    memory_type = Column(String(50), nullable=False)  # fact | preference | goal | habit | other
    content = Column(Text, nullable=False)  # memory content
    importance = Column(Integer, default=5)  # importance, 1-10
    is_recalled = Column(Boolean, default=False)  # whether it was recalled in the current conversation
    recall_count = Column(Integer, default=0)  # number of times recalled
    source_conversation_id = Column(String(36), nullable=True)  # originating conversation (no foreign key declared)
    extracted_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    last_recalled_at = Column(DateTime, nullable=True)
|
||||
45
backend/app/models/task.py
Normal file
45
backend/app/models/task.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from sqlalchemy import Column, String, Text, Integer, ForeignKey, DateTime, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class TaskStatus(str, PyEnum):
    """Lifecycle states of a task; members are also plain strings."""

    TODO = "todo"
    IN_PROGRESS = "in_progress"
    DONE = "done"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class TaskPriority(str, PyEnum):
    """Priority levels of a task; members are also plain strings."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    URGENT = "urgent"
|
||||
|
||||
|
||||
class Task(BaseModel):
    """ORM model for ``tasks``: a user's task with status, priority and dates."""

    __tablename__ = "tasks"

    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    description = Column(Text, nullable=True)
    # Enum-typed columns: values must correspond to TaskStatus / TaskPriority members.
    status = Column(Enum(TaskStatus), default=TaskStatus.TODO, nullable=False, index=True)
    priority = Column(Enum(TaskPriority), default=TaskPriority.MEDIUM, nullable=False)
    due_date = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    tags = Column(String(1000), nullable=True)  # JSON array stored as text

    # History rows are deleted together with their task.
    history = relationship("TaskHistory", back_populates="task", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class TaskHistory(BaseModel):
    """ORM model for ``task_histories``: one audit entry describing a change to a task."""

    __tablename__ = "task_histories"

    task_id = Column(String(36), ForeignKey("tasks.id"), nullable=False, index=True)
    action = Column(String(100), nullable=False)  # created, status_changed, updated, deleted
    old_value = Column(Text, nullable=True)
    new_value = Column(Text, nullable=True)

    task = relationship("Task", back_populates="history")
|
||||
8
backend/app/models/test_folder.py
Normal file
8
backend/app/models/test_folder.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import pytest
|
||||
from app.models.folder import Folder
|
||||
|
||||
|
||||
def test_folder_model_creation():
    """A Folder built without a parent keeps its name and has no parent id."""
    created = Folder(name="Test Folder", user_id="test-user")
    assert created.name == "Test Folder"
    assert created.parent_id is None
|
||||
24
backend/app/models/todo.py
Normal file
24
backend/app/models/todo.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Enum
|
||||
from enum import Enum as PyEnum
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class TodoSource(str, PyEnum):
    """Origin of a daily todo item; members are also plain strings."""

    AI_KANBAN = "ai_kanban"
    AI_CHAT = "ai_chat"
    MANUAL = "manual"
|
||||
|
||||
|
||||
class DailyTodo(BaseModel):
    """ORM model for ``daily_todos``: one todo item assigned to a calendar day."""

    __tablename__ = "daily_todos"

    # NOTE(review): no ForeignKey is declared on this column.
    user_id = Column(String(36), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    is_completed = Column(Boolean, default=False, nullable=False)
    source = Column(Enum(TodoSource), default=TodoSource.MANUAL, nullable=False)
    source_detail = Column(String(500), nullable=True)
    source_ref_id = Column(String(36), nullable=True)
    todo_date = Column(String(10), nullable=False)  # YYYY-MM-DD
    completed_at = Column(DateTime, nullable=True)
|
||||
15
backend/app/models/user.py
Normal file
15
backend/app/models/user.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import Column, String, Boolean, JSON
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class User(BaseModel):
    """ORM model for ``users``: account credentials, flags and per-user configuration."""

    __tablename__ = "users"

    email = Column(String(255), unique=True, nullable=False, index=True)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(255), nullable=True)
    is_active = Column(Boolean, default=True, nullable=False)
    is_superuser = Column(Boolean, default=False, nullable=False)
    # Per-user configuration
    llm_config = Column(JSON, nullable=True)  # LLM model configuration
    scheduler_config = Column(JSON, nullable=True)  # scheduled-job configuration
|
||||
10
backend/app/routers/__init__.py
Normal file
10
backend/app/routers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from app.routers.auth import router as auth_router
|
||||
from app.routers.conversation import router as conversation_router
|
||||
from app.routers.document import router as document_router
|
||||
from app.routers.task import router as task_router
|
||||
from app.routers.forum import router as forum_router
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.routers.settings import router as settings_router
|
||||
from app.routers.folder import router as folder_router
|
||||
240
backend/app/routers/agent.py
Normal file
240
backend/app/routers/agent.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
||||
|
||||
router = APIRouter(prefix="/api/agents", tags=["Agent"])

# Runtime call statistics (kept in memory only, not persisted)
_agent_call_counts: dict[str, int] = {}
_agent_current_tasks: dict[str, str | None] = {}
_agent_statuses: dict[str, str] = {}

# Default agent roles
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
|
||||
|
||||
|
||||
def record_agent_call(agent_id: str):
    """Increment the in-memory call counter of ``agent_id``, starting from zero."""
    previous_calls = _agent_call_counts.get(agent_id, 0)
    _agent_call_counts[agent_id] = previous_calls + 1
|
||||
|
||||
|
||||
def set_agent_task(agent_id: str, task: str | None):
    """Record the agent's current task and derive its status from it.

    A truthy task marks the agent "active"; ``None`` or an empty string marks
    it "idle".
    """
    _agent_current_tasks[agent_id] = task
    if task:
        _agent_statuses[agent_id] = "active"
    else:
        _agent_statuses[agent_id] = "idle"
|
||||
|
||||
|
||||
def set_agent_status(agent_id: str, status: str):
    """Overwrite the in-memory status recorded for ``agent_id``."""
    _agent_statuses.update({agent_id: status})
|
||||
|
||||
|
||||
@router.get("", response_model=list[AgentOut])
async def list_agents(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all active agents ordered by role; requires an authenticated user."""
    active_agents = select(Agent).where(Agent.is_active == True).order_by(Agent.role)  # noqa: E712
    rows = await db.execute(active_agents)
    return rows.scalars().all()
|
||||
|
||||
|
||||
# ———— 运行时统计(必须在 /{agent_id} 之前)————
|
||||
@router.get("/stats", response_model=list[AgentStats])
async def get_agent_stats(
    current_user: User = Depends(get_current_user),
):
    """
    Return runtime statistics (call count, current task, status) for each default agent role.
    """
    return [
        AgentStats(
            agent_id=role_name,
            call_count=_agent_call_counts.get(role_name, 0),
            current_task=_agent_current_tasks.get(role_name),
            status=_agent_statuses.get(role_name, "idle"),
        )
        for role_name in DEFAULT_AGENT_ROLES
    ]
|
||||
|
||||
|
||||
# ———— 配置管理(必须在 /{agent_id} 之前)————
|
||||
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the full configuration of one agent, addressed by its role.

    Args:
        agent_id: Role name of the agent (e.g. "master"); also used as the
            ``id`` of the returned configuration.
        current_user: Authenticated user; required for consistency with the
            other endpoints of this router, since the response exposes the
            system prompt.
        db: Database session.

    Returns:
        The stored configuration, or the built-in default for a known role
        that has no database row yet.

    Raises:
        HTTPException: 404 when the role is neither stored nor a built-in default.
    """
    # ``role`` is not unique (create_agent accepts any role), so pick the
    # oldest row instead of failing when several rows share a role.
    result = await db.execute(
        select(Agent).where(Agent.role == agent_id).order_by(Agent.created_at)
    )
    agent = result.scalars().first()

    if agent is not None:
        return AgentConfigOut(
            id=agent.role,
            name=agent.name,
            role=agent.role,
            description=agent.description,
            system_prompt=agent.system_prompt,
            enabled=agent.is_active,
            is_active=agent.is_active,
        )

    # Nothing stored: fall back to the built-in defaults.
    from app.agents.prompts import MASTER_SYSTEM_PROMPT, PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT

    defaults = {
        "master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
        "planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
        "executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
        "librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
        "analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
    }
    if agent_id not in defaults:
        raise HTTPException(status_code=404, detail="Agent 不存在")
    name, desc, prompt = defaults[agent_id]
    return AgentConfigOut(
        id=agent_id, name=name, role=agent_id,
        description=desc, system_prompt=prompt, enabled=True, is_active=True,
    )
|
||||
|
||||
|
||||
@router.put("/config/{agent_id}", response_model=AgentConfigOut)
async def update_agent_config(
    agent_id: str,
    data: AgentConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update an agent's configuration (name, description, prompt, enabled flag).

    Only fields that are not ``None`` in ``data`` are applied.

    Args:
        agent_id: Role name of the agent to update.
        data: Partial update payload.
        current_user: Authenticated user (required).
        db: Database session.

    Returns:
        The configuration after the update.

    Raises:
        HTTPException: 404 when no stored agent has this role.
    """
    # ``role`` is not unique, so take the oldest matching row rather than
    # failing when several rows share a role.
    result = await db.execute(
        select(Agent).where(Agent.role == agent_id).order_by(Agent.created_at)
    )
    agent = result.scalars().first()

    if agent is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    if data.name is not None:
        agent.name = data.name
    if data.description is not None:
        agent.description = data.description
    if data.system_prompt is not None:
        agent.system_prompt = data.system_prompt
    if data.enabled is not None:
        agent.is_active = data.enabled
        # Keep the in-memory runtime status in step with the stored flag.
        _agent_statuses[agent_id] = "idle" if data.enabled else "disabled"

    await db.commit()
    await db.refresh(agent)
    return AgentConfigOut(
        id=agent.role,
        name=agent.name,
        role=agent.role,
        description=agent.description,
        system_prompt=agent.system_prompt,
        enabled=agent.is_active,
        is_active=agent.is_active,
    )
|
||||
|
||||
|
||||
@router.post("", response_model=AgentOut, status_code=201)
async def create_agent(
    data: AgentCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Persist a new agent built from the request payload and return it."""
    agent_fields = {
        "name": data.name,
        "role": data.role,
        "description": data.description,
        "system_prompt": data.system_prompt,
    }
    new_agent = Agent(**agent_fields)
    db.add(new_agent)
    await db.commit()
    # Reload so generated columns are populated in the response.
    await db.refresh(new_agent)
    return new_agent
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentOut)
async def get_agent(
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one agent by primary key.

    Args:
        agent_id: Primary key of the agent row.
        current_user: Authenticated user; required for consistency with the
            other endpoints of this router.
        db: Database session.

    Raises:
        HTTPException: 404 when no agent has this id.
    """
    result = await db.execute(select(Agent).where(Agent.id == agent_id))
    agent = result.scalar_one_or_none()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")
    return agent
|
||||
|
||||
|
||||
|
||||
# NOTE: a second, identical registration of GET /config/{agent_id} used to live
# here. It came after the /{agent_id} route and after the earlier definition of
# the same path, so the router never dispatched to it; it was removed as a
# dead duplicate.
|
||||
|
||||
|
||||
# NOTE: a second, identical registration of PUT /config/{agent_id} used to live
# here. The earlier definition of the same path and method always matched
# first, so this copy was never dispatched to; it was removed as a dead
# duplicate.
|
||||
83
backend/app/routers/auth.py
Normal file
83
backend/app/routers/auth.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the bearer token to an active ``User``.

    Raises a 401 ``HTTPException`` when the token cannot be decoded, carries
    no ``sub`` claim, or refers to a missing or inactive user.
    """
    payload = decode_token(token)
    # An undecodable token and a token without "sub" get the same response.
    user_id = None if payload is None else payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证令牌")

    row = await db.execute(select(User).where(User.id == user_id))
    user = row.scalar_one_or_none()
    if user is not None and user.is_active:
        return user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已禁用")
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    Raises a 400 ``HTTPException`` when the email is already registered.
    """
    # Reject duplicate emails before inserting.
    existing = await db.execute(select(User).where(User.email == user_data.email))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")

    # Only the password hash is stored, never the plain password.
    password_hash = get_password_hash(user_data.password)
    new_user = User(
        email=user_data.email,
        hashed_password=password_hash,
        full_name=user_data.full_name,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate a user and return an access token.

    The ``username`` form field may be a user UUID, a full email address, or
    the local part of an email (the text before ``@``). The local-part lookup
    is only accepted when it identifies exactly one account.

    Raises:
        HTTPException: 401 when no account matches or the password is wrong,
            403 when the account is disabled.
    """
    import re

    identifier = form_data.username.strip()
    user = None

    # 1. Try the identifier as a UUID primary key.
    if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
        result = await db.execute(select(User).where(User.id == identifier))
        user = result.scalar_one_or_none()

    # 2. Try the identifier as a full email address.
    if not user:
        result = await db.execute(select(User).where(User.email == identifier))
        user = result.scalar_one_or_none()

    # 3. Try the identifier as the local part of an email.
    if not user and '@' not in identifier:
        # autoescape stops "%" / "_" in user input from acting as LIKE wildcards.
        result = await db.execute(
            select(User)
            .where(User.email.startswith(f"{identifier}@", autoescape=True))
            .limit(2)
        )
        candidates = result.scalars().all()
        # Several accounts can share a local part (a@x.com, a@y.com); an
        # ambiguous prefix must fail authentication instead of raising
        # MultipleResultsFound and surfacing as a 500.
        if len(candidates) == 1:
            user = candidates[0]

    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the authenticated user."""
    me = current_user
    return me
|
||||
217
backend/app/routers/conversation.py
Normal file
217
backend/app/routers/conversation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
|
||||
from app.services.agent_service import AgentService
|
||||
import json
|
||||
|
||||
router = APIRouter(prefix="/api/conversations", tags=["对话"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[ConversationOut])
async def list_conversations(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's 50 most recently updated conversations."""
    stmt = select(Conversation).where(Conversation.user_id == current_user.id)
    stmt = stmt.order_by(desc(Conversation.updated_at)).limit(50)
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=ConversationOut, status_code=201)
async def create_conversation(
    data: ConversationCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create an empty conversation owned by the caller.

    A falsy title falls back to the default title.
    """
    title = data.title if data.title else "新对话"
    conversation = Conversation(user_id=current_user.id, title=title)
    db.add(conversation)
    await db.commit()
    await db.refresh(conversation)
    return conversation
|
||||
|
||||
|
||||
@router.get("/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return every message of a conversation, oldest first.

    Raises a 404 ``HTTPException`` when the conversation does not exist or
    belongs to another user.
    """
    # Ownership check: the conversation must belong to the caller.
    owned_stmt = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    owned = (await db.execute(owned_stmt)).scalar_one_or_none()
    if not owned:
        raise HTTPException(status_code=404, detail="对话不存在")

    message_stmt = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    message_rows = await db.execute(message_stmt)
    return message_rows.scalars().all()
|
||||
|
||||
|
||||
@router.delete("/{conversation_id}", status_code=204)
async def delete_conversation(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's conversations.

    Raises a 404 ``HTTPException`` when the conversation does not exist or
    belongs to another user.
    """
    stmt = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    target = (await db.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    await db.delete(target)
    await db.commit()
|
||||
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
async def chat(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Simple (non-streaming) chat: send one message, get the full reply."""
    service = AgentService(db)
    conv_id, msg_id, content = await service.chat_simple(
        user_id=current_user.id,
        message=data.message,
        conversation_id=data.conversation_id,
        file_ids=data.file_ids,
    )

    # One user message plus one assistant reply were added to the conversation.
    lookup = await db.execute(select(Conversation).where(Conversation.id == conv_id))
    conversation = lookup.scalar_one_or_none()
    if conversation:
        conversation.message_count += 2
    await db.commit()

    return ChatResponse(
        conversation_id=conv_id,
        message_id=msg_id,
        content=content,
        agent_name="jarvis",
    )
|
||||
|
||||
|
||||
@router.post("/chat/stream")
async def chat_stream(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Streaming chat delivered as Server-Sent Events.

    Event sequence: one ``metadata`` event (conversation and message ids),
    zero or more ``chunk`` events, an ``error`` event if generation fails,
    and a final ``done`` event.
    """
    agent_svc = AgentService(db)

    async def stream_generator():
        conv_id, msg_id, stream = await agent_svc.chat(
            user_id=current_user.id,
            message=data.message,
            conversation_id=data.conversation_id,
        )

        # Send the ids first so the client can attach chunks to the message.
        yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"

        # Collect pieces in a list and join once instead of repeated "+=".
        pieces = []
        try:
            async for chunk in stream:
                if chunk:
                    pieces.append(chunk)
                    yield f"event: chunk\ndata: {json.dumps({'content': chunk})}\n\n"

            # Persist the full reply once the stream has finished.
            await agent_svc.save_response(msg_id, "".join(pieces))
        except Exception as e:
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

        # The terminal event is emitted after the try block, not in a
        # ``finally``: yielding while the generator is being closed (client
        # disconnect -> GeneratorExit) raises RuntimeError in async generators.
        yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
@router.websocket("/ws/{user_id}/{conversation_id}")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    conversation_id: str,
):
    """Streaming chat over a WebSocket.

    Each incoming text frame is a JSON object with a ``message`` key. The
    reply is sent as ``chunk`` frames followed by a ``done`` frame. When the
    path's ``conversation_id`` is the literal ``"new"``, a ``metadata`` frame
    with the created ids is sent first.
    """
    # NOTE(review): user_id is taken from the URL and no token is checked in
    # this handler, so the caller's identity is not verified here — confirm
    # whether authentication happens elsewhere.
    await websocket.accept()

    is_new = conversation_id == "new"
    try:
        from app.database import async_session

        async for raw in websocket.iter_text():
            payload = json.loads(raw)
            text = payload.get("message", "")

            # A fresh database session is opened for every received message.
            async with async_session() as db:
                service = AgentService(db)
                conv_id, msg_id, stream = await service.chat(
                    user_id=user_id,
                    message=text,
                    conversation_id=None if is_new else conversation_id,
                )
                if is_new:
                    await websocket.send_json({
                        "type": "metadata",
                        "conversation_id": conv_id,
                        "message_id": msg_id,
                    })

                collected = ""
                async for piece in stream:
                    if not piece:
                        continue
                    collected += piece
                    await websocket.send_json({"type": "chunk", "content": piece})

                await service.save_response(msg_id, collected)
                await websocket.send_json({"type": "done", "message_id": msg_id})

    except WebSocketDisconnect:
        # Client went away; nothing left to send.
        pass
    except Exception as e:
        # Best effort: report the failure if the socket is still writable.
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass
|
||||
154
backend/app/routers/document.py
Normal file
154
backend/app/routers/document.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.document_service import DocumentService
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
from dataclasses import asdict
|
||||
|
||||
router = APIRouter(prefix="/api/documents", tags=["知识库"])
|
||||
|
||||
|
||||
@router.get("", response_model=list)
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the caller's documents, newest first, optionally within one folder."""
    conditions = [Document.user_id == current_user.id]
    if folder_id:
        conditions.append(Document.folder_id == folder_id)

    stmt = select(Document).where(*conditions).order_by(Document.created_at.desc())
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document and schedule its vector indexing in the background.

    Returns the new document's id, title and chunk count together with a
    status message. Indexing runs after the response via ``BackgroundTasks``.
    """
    doc_svc = DocumentService(db)
    doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)

    # Read plain values now: the background task runs later in another thread,
    # where the request-scoped ORM objects must not be touched any more.
    user_id = current_user.id
    doc_id = doc.id

    def index_task():
        """Index the uploaded document using a session of its own."""
        import asyncio
        from app.database import async_session

        async def _index():
            async with async_session() as session:
                kb_svc = KnowledgeService(session, user_id=user_id)
                await kb_svc.index_document(doc_id, user_id=user_id)

        asyncio.run(_index())

    background.add_task(index_task)
    return {"id": doc_id, "title": doc.title, "chunk_count": doc.chunk_count, "status": "上传成功,正在索引..."}
|
||||
|
||||
|
||||
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one of the caller's documents.

    Raises a 404 ``HTTPException`` when the document does not exist or
    belongs to another user.
    """
    stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    document = (await db.execute(stmt)).scalar_one_or_none()
    if document:
        return document
    raise HTTPException(status_code=404, detail="文档不存在")
|
||||
|
||||
|
||||
@router.get("/{document_id}/chunks")
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a document ordered by chunk index.

    Raises a 404 ``HTTPException`` when the document does not exist or
    belongs to another user.
    """
    # The parent document must belong to the caller.
    owner_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    if not (await db.execute(owner_stmt)).scalar_one_or_none():
        raise HTTPException(status_code=404, detail="文档不存在")

    chunk_stmt = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunk_rows = await db.execute(chunk_stmt)
    return chunk_rows.scalars().all()
|
||||
|
||||
|
||||
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document on behalf of the caller via ``DocumentService``."""
    service = DocumentService(db)
    owner_id = current_user.id
    await service.delete_document(owner_id, document_id)
|
||||
|
||||
|
||||
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Search the knowledge base.

    - query: the search query
    - top_k: number of results to return, default 5
    - mode: hybrid / semantic / keyword; any other value is treated as hybrid
    """
    owner_id = current_user.id
    kb_svc = KnowledgeService(db, user_id=owner_id)

    # Pick the search strategy first, then await it in one place.
    if mode == "keyword":
        pending = kb_svc._keyword_search(query, owner_id, top_k)
    elif mode == "semantic":
        pending = kb_svc.retrieve(query, owner_id, top_k, use_rerank=True)
    else:
        pending = kb_svc.hybrid_search(query, owner_id, top_k)
    hits = await pending

    return [asdict(hit) for hit in hits]
|
||||
|
||||
|
||||
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a document's text content (intended for AI consumption).

    Raises a 404 ``HTTPException`` when the service returns ``None``.
    """
    from app.services.document_service import DocumentService

    text = await DocumentService(db).get_document_content(current_user.id, document_id)
    if text is not None:
        return {"content": text}
    raise HTTPException(status_code=404, detail="文档不存在或无内容")
|
||||
143
backend/app/routers/folder.py
Normal file
143
backend/app/routers/folder.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from typing import List
|
||||
from app.database import get_db
|
||||
from app.models.folder import Folder
|
||||
from app.models.user import User
|
||||
from app.schemas.folder import FolderCreate, FolderUpdate, FolderOut, FolderTreeOut
|
||||
from app.services.auth_service import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/folders", tags=["文件夹"])
|
||||
|
||||
def build_folder_tree(folders: list[Folder], parent_id: str = None) -> List[FolderTreeOut]:
    """Build the nested folder tree below ``parent_id``.

    Args:
        folders: Flat list of folders to arrange.
        parent_id: Id of the folder whose children form the top level of the
            result; ``None`` selects the root folders.

    Returns:
        The children of ``parent_id`` as ``FolderTreeOut`` items, each with its
        own children nested recursively. Siblings keep their order in
        ``folders``.
    """
    # Group by parent once so each level is a dict lookup instead of a rescan
    # of the whole list for every node.
    by_parent = {}
    for folder in folders:
        by_parent.setdefault(folder.parent_id, []).append(folder)

    def _subtree(pid):
        """Return the tree nodes for the direct children of ``pid``."""
        return [
            FolderTreeOut(
                id=child.id,
                name=child.name,
                parent_id=child.parent_id,
                children=_subtree(child.id),
            )
            for child in by_parent.get(pid, ())
        ]

    return _subtree(parent_id)
|
||||
|
||||
@router.get("", response_model=List[FolderTreeOut])
async def get_folders(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the caller's complete folder tree."""
    stmt = select(Folder).where(Folder.user_id == current_user.id)
    rows = await db.execute(stmt)
    flat = list(rows.scalars().all())
    return build_folder_tree(flat)
|
||||
|
||||
@router.post("", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
async def create_folder(
    folder_data: FolderCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a folder for the caller.

    Raises:
        HTTPException: 404 when ``parent_id`` is given but no such folder
            belongs to the caller; 400 when a folder with the same name
            already exists under the same parent.
    """
    owner_id = current_user.id
    parent_id = folder_data.parent_id

    # The parent folder, if any, must exist and belong to the caller.
    if parent_id:
        parent_stmt = select(Folder).where(
            and_(Folder.id == parent_id, Folder.user_id == owner_id)
        )
        parent = (await db.execute(parent_stmt)).scalar_one_or_none()
        if not parent:
            raise HTTPException(status_code=404, detail="父文件夹不存在")

    # Sibling names must be unique.
    duplicate_stmt = select(Folder).where(
        and_(
            Folder.user_id == owner_id,
            Folder.parent_id == parent_id,
            Folder.name == folder_data.name
        )
    )
    duplicate = (await db.execute(duplicate_stmt)).scalar_one_or_none()
    if duplicate:
        raise HTTPException(status_code=400, detail="同名文件夹已存在")

    created = Folder(user_id=owner_id, name=folder_data.name, parent_id=parent_id)
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
|
||||
|
||||
@router.put("/{folder_id}", response_model=FolderOut)
async def rename_folder(
    folder_id: str,
    folder_data: FolderUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Rename one of the caller's folders.

    Raises a 404 ``HTTPException`` when the folder does not exist or belongs
    to another user.
    """
    stmt = select(Folder).where(
        and_(Folder.id == folder_id, Folder.user_id == current_user.id)
    )
    target = (await db.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="文件夹不存在")

    target.name = folder_data.name
    await db.commit()
    await db.refresh(target)
    return target
|
||||
|
||||
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(
    folder_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a folder together with its sub-folders and their documents.

    Each document is removed from the vector store through
    ``KnowledgeService.delete_from_vectorstore`` before its row is deleted.
    Raises a 404 ``HTTPException`` when the folder does not exist or belongs
    to another user.
    """
    from app.models.document import Document
    from app.services.knowledge_service import KnowledgeService

    root_stmt = select(Folder).where(
        and_(Folder.id == folder_id, Folder.user_id == current_user.id)
    )
    root = (await db.execute(root_stmt)).scalar_one_or_none()
    if not root:
        raise HTTPException(status_code=404, detail="文件夹不存在")

    async def delete_recursive(fid: str):
        """Remove the subtree rooted at ``fid`` depth-first."""
        # Sub-folders go first so a parent is deleted after its children.
        sub_folders = await db.execute(select(Folder).where(Folder.parent_id == fid))
        for sub in sub_folders.scalars():
            await delete_recursive(sub.id)

        # Then the documents stored directly in this folder.
        folder_docs = await db.execute(select(Document).where(Document.folder_id == fid))
        for document in folder_docs.scalars():
            kb = KnowledgeService(db, current_user.id)
            await kb.delete_from_vectorstore(current_user.id, document.id)
            await db.delete(document)

        # Finally the folder row itself.
        doomed = await db.get(Folder, fid)
        if doomed:
            await db.delete(doomed)

    await delete_recursive(folder_id)
    await db.commit()
|
||||
111
backend/app/routers/forum.py
Normal file
111
backend/app/routers/forum.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.forum import ForumPostCreate, ForumPostOut, ForumReplyCreate, ForumReplyOut
|
||||
|
||||
router = APIRouter(prefix="/api/forum", tags=["论坛"])
|
||||
|
||||
|
||||
@router.get("/posts", response_model=list[ForumPostOut])
async def list_posts(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's forum posts, newest first."""
    stmt = select(ForumPost).where(ForumPost.user_id == current_user.id)
    stmt = stmt.order_by(desc(ForumPost.created_at))
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("/posts", response_model=ForumPostOut, status_code=201)
async def create_post(
    data: ForumPostCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a forum post authored by the caller."""
    fields = {
        "user_id": current_user.id,
        "title": data.title,
        "content": data.content,
        "category": data.category,
    }
    new_post = ForumPost(**fields)
    db.add(new_post)
    await db.commit()
    await db.refresh(new_post)
    return new_post
|
||||
|
||||
|
||||
@router.get("/posts/{post_id}", response_model=ForumPostOut)
async def get_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a forum post by id.

    The lookup is by id only and is not restricted to the caller's posts.
    Raises a 404 ``HTTPException`` when the post does not exist.
    """
    rows = await db.execute(select(ForumPost).where(ForumPost.id == post_id))
    found = rows.scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="帖子不存在")
|
||||
|
||||
|
||||
@router.delete("/posts/{post_id}", status_code=204)
async def delete_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's own forum posts.

    Raises a 404 ``HTTPException`` when the post does not exist or was
    written by another user.
    """
    stmt = select(ForumPost).where(
        ForumPost.id == post_id,
        ForumPost.user_id == current_user.id,
    )
    target = (await db.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="帖子不存在")

    await db.delete(target)
    await db.commit()
|
||||
|
||||
|
||||
@router.get("/posts/{post_id}/replies", response_model=list[ForumReplyOut])
async def list_replies(
    post_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the replies to a post, oldest first."""
    stmt = select(ForumReply).where(ForumReply.post_id == post_id)
    stmt = stmt.order_by(ForumReply.created_at)
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("/posts/{post_id}/replies", response_model=ForumReplyOut, status_code=201)
async def create_reply(
    post_id: str,
    data: ForumReplyCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add the caller's reply to a post and bump the post's reply counter.

    Raises:
        HTTPException: 404 when the post does not exist.
    """
    # Look the post up first so a reply is never stored for a missing post.
    result = await db.execute(select(ForumPost).where(ForumPost.id == post_id))
    post = result.scalar_one_or_none()
    if not post:
        raise HTTPException(status_code=404, detail="帖子不存在")

    reply = ForumReply(
        post_id=post_id,
        user_id=current_user.id,
        content=data.content,
    )
    db.add(reply)
    # Keep the denormalised reply counter in step with the new row.
    post.reply_count += 1

    await db.commit()
    await db.refresh(reply)
    return reply
|
||||
240
backend/app/routers/graph.py
Normal file
240
backend/app/routers/graph.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_
|
||||
from app.database import get_db
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.graph_service import GraphService
|
||||
from app.schemas.graph import KGNodeOut, TagProperties, TagExtractRequest, TagExtractResponse, RelatedContentRequest
|
||||
|
||||
router = APIRouter(prefix="/api/graph", tags=["知识图谱"])
|
||||
|
||||
|
||||
@router.get("")
async def get_graph(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's knowledge graph.

    The response contains up to 200 of the caller's nodes ordered by
    descending importance, every edge that touches at least one of those
    nodes, and node/edge counts under ``stats``.
    """
    nodes_result = await db.execute(
        select(KGNode)
        .where(KGNode.user_id == current_user.id)
        .order_by(KGNode.importance.desc())
        .limit(200)
    )
    nodes = list(nodes_result.scalars().all())
    node_ids = [n.id for n in nodes]

    # Filter edges in the database instead of loading every edge of every
    # user and discarding most of them in Python.
    if node_ids:
        edges_result = await db.execute(
            select(KGEdge).where(
                or_(KGEdge.source_id.in_(node_ids), KGEdge.target_id.in_(node_ids))
            )
        )
        edges = list(edges_result.scalars().all())
    else:
        # No nodes means no edge can match; skip the query.
        edges = []

    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description, "importance": n.importance,
                   "created_at": str(n.created_at)}
                  for n in nodes],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type, "weight": e.weight}
                  for e in edges],
        "stats": {
            "node_count": len(nodes),
            "edge_count": len(edges),
        }
    }
|
||||
|
||||
|
||||
@router.post("/build")
async def build_graph(
    background: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a (re)build of the caller's knowledge graph in the background.

    Returns immediately with a status payload; the build itself runs after
    the response via ``BackgroundTasks``.
    """
    # Read the id now: the task runs later in another thread, where the
    # request-scoped ORM object must not be touched any more.
    user_id = current_user.id

    def build_task():
        """Build the graph using a session of its own."""
        import asyncio
        from app.database import async_session
        from app.services.graph_service import GraphService

        async def _build():
            async with async_session() as session:
                svc = GraphService(session)
                await svc.build_graph(user_id=user_id, document_ids=None)

        asyncio.run(_build())

    background.add_task(build_task)
    return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}
|
||||
|
||||
|
||||
@router.get("/entity/{entity}")
async def get_entity_context(
    entity: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detailed context of one entity for the caller."""
    graph = GraphService(db)
    entity_context = await graph.get_entity_context(entity, current_user.id)
    return {"context": entity_context}
|
||||
|
||||
|
||||
@router.get("/summary")
async def get_graph_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a summary of the caller's knowledge graph."""
    graph = GraphService(db)
    graph_summary = await graph.get_graph_summary(current_user.id)
    return {"summary": graph_summary}
|
||||
|
||||
|
||||
@router.get("/neighbors/{node_id}")
async def get_node_neighbors(
    node_id: str,
    depth: int = 1,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the neighbourhood of a node (used for click-to-expand in the UI).

    Raises:
        HTTPException: 404 when the node does not exist or belongs to
            another user.
    """
    # Ownership check, same rule as delete_node: only the owner may expand a
    # node. Without it any authenticated user could read any node by id.
    owned = await db.execute(
        select(KGNode.id).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    svc = GraphService(db)
    data = await svc.get_neighbors(node_id, depth)
    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description} for n in data["nodes"]],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type} for e in data["edges"]],
    }
|
||||
|
||||
|
||||
@router.delete("/nodes/{node_id}", status_code=204)
async def delete_node(
    node_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's graph nodes.

    Raises a 404 ``HTTPException`` when the node does not exist or belongs
    to another user.
    """
    stmt = select(KGNode).where(KGNode.id == node_id, KGNode.user_id == current_user.id)
    target = (await db.execute(stmt)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="节点不存在")

    await db.delete(target)
    await db.commit()
|
||||
|
||||
|
||||
@router.post("/tags/extract", response_model=TagExtractResponse)
async def extract_tags(request: TagExtractRequest, db: AsyncSession = Depends(get_db)):
    """Extract tags from content without saving them to any node."""
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    service = TagService(db, get_llm_client())
    extracted = service.extract_tags_from_content(request.content, request.user_id)

    tags = []
    for info in extracted:
        path = info["path"]
        # parse_tag_path splits a tag path into name, depth and parent path.
        short_name, level, parent_path = service.parse_tag_path(path)
        tags.append(
            TagProperties(
                tag_path=path,
                short_name=short_name,
                level=level,
                parent_path=parent_path,
                description=info.get("description"),
            )
        )

    return TagExtractResponse(tags=tags, tag_count=len(tags))
|
||||
|
||||
|
||||
@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
async def tag_content_node(
    node_id: str,
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db)
):
    """Tag a content node and return the tags that were applied.

    Raises a 404 ``HTTPException`` when the node does not exist.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    lookup = await db.execute(select(KGNode).where(KGNode.id == node_id))
    content_node = lookup.scalar_one_or_none()
    if not content_node:
        raise HTTPException(status_code=404, detail="Node not found")

    service = TagService(db, get_llm_client())
    tagged = service.tag_content(request.content, request.user_id, content_node)

    def _to_properties(tag_node):
        """Convert a tag node into the response schema."""
        # Nodes without stored properties fall back to defaults.
        props = tag_node.properties_ or {}
        return TagProperties(
            tag_path=props.get("tag_path", tag_node.name),
            short_name=tag_node.name,
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=tag_node.description,
        )

    tags = [_to_properties(tag_node) for tag_node in tagged]
    return TagExtractResponse(tags=tags, tag_count=len(tags))
|
||||
|
||||
|
||||
@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
async def get_related_tags(tag_id: str, db: AsyncSession = Depends(get_db)):
    """Return the tags linked to ``tag_id`` by a related_to or synonym_of edge."""
    edge_stmt = select(KGEdge).where(
        or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
        KGEdge.relation_type.in_(["related_to", "synonym_of"])
    )
    edge_rows = await db.execute(edge_stmt)

    # For every edge take the endpoint that is not the requested tag.
    related_ids = {
        edge.target_id if edge.source_id == tag_id else edge.source_id
        for edge in edge_rows.scalars().all()
    }
    if not related_ids:
        return []

    node_rows = await db.execute(select(KGNode).where(KGNode.id.in_(related_ids)))
    return list(node_rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
async def get_user_tags(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return every ``KGNode`` of entity type ``"tag"`` owned by ``user_id``.

    Rows are ordered by the ``level`` entry of ``properties_``.
    """
    # NOTE(review): ``.astext`` orders by the textual value, so a level of
    # "10" would sort before "2" - confirm levels stay in single digits.
    tag_query = (
        select(KGNode)
        .where(
            KGNode.user_id == user_id,
            KGNode.entity_type == "tag"
        )
        .order_by(KGNode.properties_["level"].astext)
    )
    rows = await db.execute(tag_query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.post("/content/related", response_model=list[KGNodeOut])
async def get_related_content(
    request: RelatedContentRequest,
    db: AsyncSession = Depends(get_db)
):
    """Find content related to the given tags.

    Delegates to ``TagService.get_related_content`` and returns the first
    element of every result item.
    """
    # Service and LLM client are imported lazily inside the handler.
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    service = TagService(db, get_llm_client())

    # NOTE(review): called without ``await`` - confirm that
    # TagService.get_related_content is synchronous.
    matches = service.get_related_content(request.tag_ids, request.user_id, request.limit)
    return [match[0] for match in matches]
|
||||
42
backend/app/routers/scheduler.py
Normal file
42
backend/app/routers/scheduler.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.scheduler_service import (
|
||||
get_scheduler_status,
|
||||
scheduler,
|
||||
daily_task_analysis,
|
||||
forum_scan_task,
|
||||
graph_rebuild_task,
|
||||
tag_generation_task,
|
||||
)
|
||||
import logging
|
||||
|
||||
# Router for the scheduler endpoints; every route below lives under /api/scheduler.
router = APIRouter(prefix="/api/scheduler", tags=["定时任务"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/status")
async def get_status(current_user: User = Depends(get_current_user)):
    """Return the scheduler status reported by ``get_scheduler_status``.

    ``current_user`` is only declared so that the route requires an
    authenticated user; its value is not used.
    """
    status_report = get_scheduler_status()
    return status_report
|
||||
|
||||
|
||||
@router.post("/trigger/{job_id}")
async def trigger_job(job_id: str, current_user: User = Depends(get_current_user)):
    """Manually run one of the scheduled jobs and report the outcome.

    Args:
        job_id: One of ``daily_task_analysis``, ``forum_scan``,
            ``graph_rebuild`` or ``tag_generation``.
        current_user: Declared so the route requires authentication.

    Returns:
        ``{"error": ...}`` for an unknown ``job_id``; otherwise a dict with
        ``status`` set to ``"ok"`` or ``"error"``. The job is awaited, so the
        response is only sent once it has finished or failed.
    """
    job_map = {
        "daily_task_analysis": daily_task_analysis,
        "forum_scan": forum_scan_task,
        "graph_rebuild": graph_rebuild_task,
        "tag_generation": tag_generation_task,
    }

    job = job_map.get(job_id)
    if job is None:
        return {"error": f"未知任务: {job_id}"}

    try:
        await job()
    except Exception as e:
        # Deliberate catch-all: any job failure is reported in the response
        # body. logger.exception keeps the traceback, which the previous
        # logger.error(f"...") call discarded.
        logger.exception("手动触发任务失败 %s: %s", job_id, e)
        return {"status": "error", "job": job_id, "error": str(e)}
    return {"status": "ok", "job": job_id, "message": "任务已触发执行"}
|
||||
87
backend/app/routers/settings.py
Normal file
87
backend/app/routers/settings.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.settings import (
|
||||
SettingsOut, ProfileUpdateIn, LLMConfigIn, SchedulerConfigIn, LLMTestIn
|
||||
)
|
||||
from app.services.settings_service import (
|
||||
get_user_settings, update_user_profile, update_llm_config,
|
||||
update_scheduler_config, test_llm_connection
|
||||
)
|
||||
|
||||
# Router for the settings endpoints; every route below lives under /api/settings.
router = APIRouter(prefix="/api/settings", tags=["设置"])
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsOut)
async def get_settings(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the settings of the authenticated user.

    Raises:
        HTTPException: 404 when ``get_user_settings`` returns nothing.
    """
    user_settings = await get_user_settings(current_user.id, db)
    if user_settings:
        return user_settings
    raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
|
||||
@router.put("/profile")
async def update_profile(
    data: ProfileUpdateIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the authenticated user's name and/or password.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` raised
            by ``update_user_profile``.
    """
    # NOTE(review): no response_model is declared, so whatever
    # update_user_profile returns is serialised as-is - confirm it carries no
    # sensitive fields.
    changes = {
        "full_name": data.full_name,
        "password": data.password,
        "current_password": data.current_password,
    }
    try:
        updated = await update_user_profile(current_user.id, db, **changes)
        return updated
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.put("/llm")
async def update_llm(
    data: LLMConfigIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Store the user's LLM configuration and echo the stored value.

    The payload is dumped with ``exclude_none=True`` before being handed to
    ``update_llm_config``.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError``.
    """
    try:
        payload = data.model_dump(exclude_none=True)
        stored = await update_llm_config(current_user.id, payload, db)
        return {"llm_config": stored}
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/llm/test")
async def test_llm(
    data: LLMTestIn,
    current_user: User = Depends(get_current_user),
):
    """Probe an LLM endpoint with the supplied connection details.

    Only ``provider``, ``model``, ``base_url`` and ``api_key`` are forwarded
    to ``test_llm_connection``; the request's ``type`` field is not.
    """
    probe_args = dict(
        provider=data.provider,
        model=data.model,
        base_url=data.base_url,
        api_key=data.api_key
    )
    return await test_llm_connection(**probe_args)
|
||||
|
||||
|
||||
@router.put("/scheduler")
async def update_scheduler(
    data: SchedulerConfigIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Store the user's scheduler configuration and echo the stored value.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError``.
    """
    try:
        # NOTE(review): exclude_none only drops explicit nulls; fields the
        # client omits are still dumped with the schema defaults. Confirm the
        # service expects a full replacement rather than a partial patch.
        payload = data.model_dump(exclude_none=True)
        stored = await update_scheduler_config(current_user.id, payload, db)
        return {"scheduler_config": stored}
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
|
||||
77
backend/app/routers/stats.py
Normal file
77
backend/app/routers/stats.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.stats import (
|
||||
SystemHealth,
|
||||
ConversationStats,
|
||||
KnowledgeStats,
|
||||
KanbanStats,
|
||||
CommunityStats,
|
||||
PersonalInsights,
|
||||
)
|
||||
from app.services.stats_service import StatsService
|
||||
|
||||
# Router for the statistics endpoints; every route below lives under /api/stats.
router = APIRouter(prefix="/api/stats", tags=["统计"])
|
||||
|
||||
|
||||
@router.get("/system", response_model=SystemHealth)
async def get_system_health(db: Session = Depends(get_db)):
    """Return system health metrics from ``StatsService.get_system_health``.

    Unlike the other routes in this module, no ``get_current_user``
    dependency is declared here.
    """
    # NOTE(review): ``db`` is annotated as a synchronous Session while other
    # routers annotate the same ``get_db`` dependency as AsyncSession.
    return StatsService(db).get_system_health()
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=ConversationStats)
async def get_conversation_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return conversation statistics for the authenticated user.

    ``days`` (default 30) is passed straight to
    ``StatsService.get_conversation_stats``.
    """
    return StatsService(db).get_conversation_stats(user_id=current_user.id, days=days)
|
||||
|
||||
|
||||
@router.get("/knowledge", response_model=KnowledgeStats)
async def get_knowledge_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return knowledge-base statistics for the authenticated user.

    ``days`` (default 30) is passed straight to
    ``StatsService.get_knowledge_stats``.
    """
    return StatsService(db).get_knowledge_stats(user_id=current_user.id, days=days)
|
||||
|
||||
|
||||
@router.get("/kanban", response_model=KanbanStats)
async def get_kanban_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return kanban statistics for the authenticated user.

    ``days`` (default 30) is passed straight to
    ``StatsService.get_kanban_stats``.
    """
    return StatsService(db).get_kanban_stats(user_id=current_user.id, days=days)
|
||||
|
||||
|
||||
@router.get("/community", response_model=CommunityStats)
async def get_community_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return community statistics for the authenticated user.

    ``days`` (default 30) is passed straight to
    ``StatsService.get_community_stats``.
    """
    return StatsService(db).get_community_stats(user_id=current_user.id, days=days)
|
||||
|
||||
|
||||
@router.get("/insights", response_model=PersonalInsights)
async def get_personal_insights(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return personal insights for the authenticated user."""
    return StatsService(db).get_personal_insights(user_id=current_user.id)
|
||||
91
backend/app/routers/task.py
Normal file
91
backend/app/routers/task.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.task import TaskCreate, TaskUpdate, TaskOut
|
||||
|
||||
# Router for the kanban task endpoints; every route below lives under /api/tasks.
router = APIRouter(prefix="/api/tasks", tags=["看板"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TaskOut])
async def list_tasks(
    status: TaskStatus | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the authenticated user's tasks, newest first.

    When ``status`` is given, only tasks in that status are returned.
    """
    filters = [Task.user_id == current_user.id]
    if status:
        filters.append(Task.status == status)

    rows = await db.execute(
        select(Task).where(*filters).order_by(desc(Task.created_at))
    )
    return rows.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=TaskOut, status_code=201)
async def create_task(
    data: TaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a task owned by the authenticated user and return it.

    ``tags`` is stored as a JSON-encoded string; a missing or empty list is
    stored as ``None``.
    """
    import json

    encoded_tags = json.dumps(data.tags) if data.tags else None
    new_task = Task(
        user_id=current_user.id,
        title=data.title,
        description=data.description,
        priority=data.priority,
        due_date=data.due_date,
        tags=encoded_tags,
    )
    db.add(new_task)
    await db.commit()
    # Reload so database-generated values are present on the returned object.
    await db.refresh(new_task)
    return new_task
|
||||
|
||||
|
||||
@router.patch("/{task_id}", response_model=TaskOut)
async def update_task(
    task_id: str,
    data: TaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update one of the authenticated user's tasks.

    Only fields whose value is not ``None`` are applied. ``tags`` is stored
    as a JSON-encoded string. Setting ``status`` to ``DONE`` stamps
    ``completed_at`` with the current UTC time; setting any other status
    clears it, so a reopened task no longer looks completed (the same rule
    ``update_todo`` applies to todos).

    Raises:
        HTTPException: 404 when the task does not exist or belongs to
            another user.
    """
    import json
    from datetime import datetime, timezone

    result = await db.execute(
        select(Task).where(Task.id == task_id, Task.user_id == current_user.id)
    )
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    for field, value in data.model_dump(exclude_none=True).items():
        if field == "tags":
            value = json.dumps(value)
        elif field == "status":
            if value == TaskStatus.DONE:
                # Naive UTC timestamp, same value datetime.utcnow() produced,
                # without calling the API deprecated since Python 3.12.
                task.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)
            else:
                task.completed_at = None
        setattr(task, field, value)

    await db.commit()
    await db.refresh(task)
    return task
|
||||
|
||||
|
||||
@router.delete("/{task_id}", status_code=204)
async def delete_task(
    task_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the authenticated user's tasks.

    Raises:
        HTTPException: 404 when the task does not exist or belongs to
            another user.
    """
    owned_task_query = select(Task).where(
        Task.id == task_id, Task.user_id == current_user.id
    )
    doomed = (await db.execute(owned_task_query)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(status_code=404, detail="任务不存在")
    await db.delete(doomed)
    await db.commit()
|
||||
154
backend/app/routers/todo.py
Normal file
154
backend/app/routers/todo.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from datetime import date
|
||||
from app.database import get_db
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.todo import (
|
||||
TodoCreate, TodoUpdate, TodoOut, TodoListOut, TodoSummaryOut
|
||||
)
|
||||
from app.services.todo_service import generate_daily_todos
|
||||
|
||||
# Router for the daily todo endpoints; every route below lives under /api/todos.
router = APIRouter(prefix="/api/todos", tags=["待办"])
|
||||
|
||||
|
||||
@router.get("", response_model=TodoListOut)
async def list_todos(
    date_str: str | None = Query(default=None),  # YYYY-MM-DD; today when omitted
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the authenticated user's todos for a given day.

    Args:
        date_str: ISO date of the day to list; a missing or empty value means
            today (``date.today()``).
        page: 1-based page number.
        page_size: Items per page (1-100).

    Raises:
        HTTPException: 400 when ``date_str`` is not a valid ISO date. It was
            previously compared verbatim and silently matched nothing.
    """
    if date_str:
        try:
            # Round-trip through ``date`` to validate and normalise the value.
            target_date = date.fromisoformat(date_str).isoformat()
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail="date_str 格式无效,应为 YYYY-MM-DD"
            ) from exc
    else:
        target_date = date.today().isoformat()
    offset = (page - 1) * page_size

    # Total number of todos for that day (ignores paging).
    count_q = select(func.count()).select_from(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == target_date,
    )
    total = (await db.execute(count_q)).scalar()

    # The requested page, newest first.
    q = select(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == target_date,
    ).order_by(DailyTodo.created_at.desc()).offset(offset).limit(page_size)

    items = (await db.execute(q)).scalars().all()
    return TodoListOut(items=items, total=total, page=page, page_size=page_size)
|
||||
|
||||
|
||||
@router.post("", response_model=TodoOut, status_code=201)
async def create_todo(
    data: TodoCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a manual todo dated today for the authenticated user."""
    today = date.today().isoformat()
    new_todo = DailyTodo(
        user_id=current_user.id,
        title=data.title,
        source=TodoSource.MANUAL,
        todo_date=today,
    )
    db.add(new_todo)
    await db.commit()
    # Reload so database-generated values are present on the returned object.
    await db.refresh(new_todo)
    return new_todo
|
||||
|
||||
|
||||
@router.patch("/{todo_id}", response_model=TodoOut)
async def update_todo(
    todo_id: str,
    data: TodoUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the title and/or completion state of one of today's todos.

    Raises:
        HTTPException: 404 when the todo does not exist or belongs to
            another user; 403 when the todo is not dated today.
    """
    result = await db.execute(
        select(DailyTodo).where(DailyTodo.id == todo_id, DailyTodo.user_id == current_user.id)
    )
    todo = result.scalar_one_or_none()
    if not todo:
        raise HTTPException(status_code=404, detail="待办不存在")

    # Todos from any other day are read-only.
    # NOTE(review): "today" is the server-local date while completed_at is
    # stamped in UTC - confirm this mix is intended.
    if todo.todo_date != date.today().isoformat():
        raise HTTPException(status_code=403, detail="历史待办不可修改")

    if data.title is not None:
        todo.title = data.title
    if data.is_completed is not None:
        from datetime import datetime, timezone

        todo.is_completed = data.is_completed
        # Naive UTC timestamp, same value datetime.utcnow() produced, without
        # calling the API deprecated since Python 3.12. Cleared on un-complete.
        todo.completed_at = (
            datetime.now(timezone.utc).replace(tzinfo=None)
            if data.is_completed
            else None
        )

    await db.commit()
    await db.refresh(todo)
    return todo
|
||||
|
||||
|
||||
@router.delete("/{todo_id}", status_code=204)
async def delete_todo(
    todo_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of today's todos.

    Raises:
        HTTPException: 404 when the todo does not exist or belongs to
            another user; 403 when the todo is not dated today.
    """
    owned_todo_query = select(DailyTodo).where(
        DailyTodo.id == todo_id, DailyTodo.user_id == current_user.id
    )
    doomed = (await db.execute(owned_todo_query)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(status_code=404, detail="待办不存在")
    # Todos from any other day cannot be removed.
    if doomed.todo_date != date.today().isoformat():
        raise HTTPException(status_code=403, detail="历史待办不可删除")

    await db.delete(doomed)
    await db.commit()
|
||||
|
||||
|
||||
@router.post("/ai-generate", response_model=TodoListOut)
async def ai_generate_todos(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate today's AI todos once; later calls return today's list.

    If the user already has a todo dated today whose source is ``AI_KANBAN``
    or ``AI_CHAT``, all of today's todos are returned (newest first) and no
    generation runs. Otherwise ``generate_daily_todos`` is called and its
    result returned. The response always reports ``page=1, page_size=50``.
    """
    today = date.today().isoformat()

    # Idempotency check: count today's AI-sourced todos.
    ai_sources = [TodoSource.AI_KANBAN, TodoSource.AI_CHAT]
    existing_ai = (
        await db.execute(
            select(func.count()).select_from(DailyTodo).where(
                DailyTodo.user_id == current_user.id,
                DailyTodo.todo_date == today,
                DailyTodo.source.in_(ai_sources),
            )
        )
    ).scalar()

    if existing_ai <= 0:
        # Nothing generated yet today: run the AI generation.
        generated = await generate_daily_todos(current_user.id, db)
        return TodoListOut(items=generated, total=len(generated), page=1, page_size=50)

    # Already generated: return the existing records for today.
    rows = await db.execute(
        select(DailyTodo).where(
            DailyTodo.user_id == current_user.id,
            DailyTodo.todo_date == today,
        ).order_by(DailyTodo.created_at.desc())
    )
    current = rows.scalars().all()
    return TodoListOut(items=current, total=len(current), page=1, page_size=50)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=TodoSummaryOut)
async def get_summary(
    date_str: str | None = Query(default=None),  # YYYY-MM-DD; today when omitted
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return total / completed / pending todo counts for one day.

    Args:
        date_str: ISO date of the day to summarise; a missing or empty value
            means today (``date.today()``).

    Raises:
        HTTPException: 400 when ``date_str`` is not a valid ISO date. It was
            previously compared verbatim and silently matched nothing.
    """
    if date_str:
        try:
            # Round-trip through ``date`` to validate and normalise the value.
            target_date = date.fromisoformat(date_str).isoformat()
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail="date_str 格式无效,应为 YYYY-MM-DD"
            ) from exc
    else:
        target_date = date.today().isoformat()

    q = select(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == target_date,
    )
    todos = (await db.execute(q)).scalars().all()
    completed = sum(1 for t in todos if t.is_completed)
    return TodoSummaryOut(date=target_date, total=len(todos), completed=completed, pending=len(todos) - completed)
|
||||
2
backend/app/schemas/__init__.py
Normal file
2
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Schemas package - import directly from submodules
|
||||
# e.g.: from app.schemas.auth import UserCreate
|
||||
55
backend/app/schemas/agent.py
Normal file
55
backend/app/schemas/agent.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentCreate(BaseModel):
    """Request body for creating an agent."""

    name: str
    role: str
    description: str | None = None  # optional
    system_prompt: str
|
||||
|
||||
|
||||
class AgentOut(BaseModel):
    """Agent summary returned by the API."""

    id: str
    name: str
    role: str
    description: str | None
    is_active: bool
    is_default: bool

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AgentMessageOut(BaseModel):
    """A message tied to an agent and a conversation, as returned by the API."""

    id: str
    agent_id: str
    conversation_id: str
    role: str
    content: str

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AgentStats(BaseModel):
    """Usage figures for one agent."""

    agent_id: str
    call_count: int
    current_task: str | None
    status: str  # active | idle | disabled
|
||||
|
||||
|
||||
class AgentConfigUpdate(BaseModel):
    """Partial update of an agent's configuration; every field is optional."""

    name: str | None = None
    description: str | None = None
    system_prompt: str | None = None
    enabled: bool | None = None
|
||||
|
||||
|
||||
class AgentConfigOut(BaseModel):
    """Full agent configuration, including the system prompt."""

    id: str
    name: str
    role: str
    description: str | None
    system_prompt: str
    enabled: bool
    is_active: bool

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
26
backend/app/schemas/auth.py
Normal file
26
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Request body for creating a user."""

    email: EmailStr  # validated as an e-mail address by pydantic
    password: str
    full_name: str | None = None
|
||||
|
||||
|
||||
class UserOut(BaseModel):
    """Public view of a user; contains no password field."""

    id: str
    email: str
    full_name: str | None
    is_active: bool
    is_superuser: bool

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class Token(BaseModel):
    """Access-token response."""

    access_token: str
    token_type: str = "bearer"  # defaults to the bearer scheme
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Token payload holding only the ``sub`` claim."""

    sub: str
|
||||
45
backend/app/schemas/conversation.py
Normal file
45
backend/app/schemas/conversation.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
    """Request body for adding a message."""

    content: str
|
||||
|
||||
|
||||
class MessageOut(BaseModel):
    """A stored conversation message as returned by the API."""

    id: str
    role: str
    content: str
    model: str | None
    tokens_used: int | None
    created_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ConversationCreate(BaseModel):
    """Request body for creating a conversation; the title is optional."""

    title: str | None = None
|
||||
|
||||
|
||||
class ConversationOut(BaseModel):
    """Conversation summary returned by the API."""

    id: str
    title: str | None
    message_count: int
    created_at: datetime
    updated_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for a chat turn."""

    message: str
    conversation_id: str | None = None
    agent_id: str | None = None
    file_ids: list[str] = []  # newly added
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Response to a chat turn."""

    conversation_id: str
    message_id: str
    content: str
    agent_name: str
|
||||
40
backend/app/schemas/document.py
Normal file
40
backend/app/schemas/document.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class DocumentOut(BaseModel):
    """Document metadata returned by the API."""

    id: str
    title: str
    filename: str
    file_type: str
    file_size: int
    summary: str | None
    chunk_count: int
    is_indexed: bool
    created_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DocumentChunkOut(BaseModel):
    """One chunk of a document."""

    id: str
    chunk_index: int
    content: str
    metadata_: str | None

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request body for a document search."""

    query: str
    top_k: int = 5  # defaults to 5
    # NOTE(review): the user id is supplied by the client in the body -
    # confirm the handler checks it against the authenticated user.
    user_id: str
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """One search hit: a chunk, its parent document and a score."""

    chunk_id: str
    document_id: str
    document_title: str
    content: str
    score: float
    metadata_: str | None
|
||||
39
backend/app/schemas/folder.py
Normal file
39
backend/app/schemas/folder.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
# Characters that may not appear in a folder name.
_FORBIDDEN_FOLDER_CHARS = '/\\*?:'


def _check_folder_name(name):
    """Return ``name`` unchanged, or raise ``ValueError`` if it has a forbidden character."""
    if any(c in name for c in _FORBIDDEN_FOLDER_CHARS):
        raise ValueError(f'Folder name cannot contain: {_FORBIDDEN_FOLDER_CHARS}')
    return name


class FolderCreate(BaseModel):
    """Request body for creating a folder."""

    name: str = Field(..., min_length=1, max_length=255)
    parent_id: Optional[str] = None  # optional

    @field_validator('name')
    @classmethod
    def validate_name(cls, v):
        """Reject names containing any of ``/ \\ * ? :``."""
        return _check_folder_name(v)


class FolderUpdate(BaseModel):
    """Request body for renaming a folder.

    Applies the same character restrictions as ``FolderCreate``; previously a
    rename could introduce a name that creation would have rejected.
    """

    name: str = Field(..., min_length=1, max_length=255)

    @field_validator('name')
    @classmethod
    def validate_name(cls, v):
        """Reject names containing any of ``/ \\ * ? :``."""
        return _check_folder_name(v)
|
||||
|
||||
class FolderOut(BaseModel):
    """Folder record returned by the API."""

    id: str
    name: str
    parent_id: Optional[str]
    created_at: datetime
    updated_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
class FolderTreeOut(BaseModel):
    """Folder node with its nested child folders."""

    id: str
    name: str
    parent_id: Optional[str]
    children: List["FolderTreeOut"] = []  # recursive; empty by default

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}


# The recursive model needs its forward reference resolved.
FolderTreeOut.model_rebuild()
|
||||
37
backend/app/schemas/forum.py
Normal file
37
backend/app/schemas/forum.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ForumPostCreate(BaseModel):
    """Request body for creating a forum post."""

    title: str
    content: str
    category: str | None = "discussion"  # defaults to "discussion"
|
||||
|
||||
|
||||
class ForumPostOut(BaseModel):
    """Forum post returned by the API."""

    id: str
    user_id: str
    title: str
    content: str
    category: str | None
    is_executed: bool
    execution_result: str | None
    reply_count: int
    # NOTE(review): typed as str here, while the conversation, document, task
    # and todo schemas use datetime for created_at - confirm the source
    # attribute really is a string.
    created_at: str

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ForumReplyCreate(BaseModel):
    """Request body for replying to a forum post."""

    content: str
|
||||
|
||||
|
||||
class ForumReplyOut(BaseModel):
    """Forum reply returned by the API."""

    id: str
    post_id: str
    user_id: str | None
    agent_id: str | None
    content: str
    is_ai_reply: bool
    # NOTE(review): typed as str here, while other schemas use datetime for
    # created_at - confirm the source attribute really is a string.
    created_at: str

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
66
backend/app/schemas/graph.py
Normal file
66
backend/app/schemas/graph.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class TagProperties(BaseModel):
    """Descriptor of a hierarchical tag.

    ``tag_path``, ``short_name`` and ``level`` are required; ``level`` must be
    at least 1. The remaining fields default to ``None``.
    """

    tag_path: str = Field(..., description="完整标签路径,如 '编程语言/Python/异步'")
    short_name: str = Field(..., description="显示名称,如 '异步'")
    level: int = Field(..., ge=1, description="层级深度,1为顶级")
    parent_path: str | None = Field(None, description="父路径,如 '编程语言/Python'")
    description: str | None = Field(None, description="AI生成的标签描述")
    color: str | None = Field(None, description="标签颜色,如 '#FF5733'")
|
||||
|
||||
|
||||
class KGNodeOut(BaseModel):
    """Knowledge-graph node returned by the API.

    For nodes whose ``entity_type`` is ``"tag"`` and that carry properties,
    ``tag_properties`` is derived from ``properties_`` after validation.
    """

    id: str
    name: str
    entity_type: str
    description: str | None
    properties_: dict | None
    importance: float
    # NOTE(review): typed as str, while other schemas use datetime for
    # created_at - confirm the source attribute really is a string.
    created_at: str
    # Added: populated for tag nodes only (see model_post_init).
    tag_properties: TagProperties | None = None

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}

    def model_post_init(self, __context):
        """Fill ``tag_properties`` for tag nodes.

        Missing keys fall back to the node's own ``name`` / ``description``
        and to level 1. Unpacking ``properties_`` directly raised a
        validation error whenever a required key (``tag_path``,
        ``short_name``, ``level``) was absent.
        """
        if self.entity_type != "tag" or not self.properties_:
            return
        props = self.properties_
        self.tag_properties = TagProperties(
            tag_path=props.get("tag_path", self.name),
            short_name=props.get("short_name", self.name),
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=props.get("description", self.description),
            color=props.get("color"),
        )
|
||||
|
||||
|
||||
class KGEdgeOut(BaseModel):
    """Knowledge-graph edge returned by the API."""

    id: str
    source_id: str
    target_id: str
    relation_type: str
    weight: float
    properties_: dict | None

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class GraphOut(BaseModel):
    """A graph expressed as its node list and edge list."""

    nodes: list[KGNodeOut]
    edges: list[KGEdgeOut]
|
||||
|
||||
|
||||
class KGBuildRequest(BaseModel):
    """Request body for building the knowledge graph."""

    user_id: str
    document_ids: list[str] | None = None  # None = full rebuild
|
||||
|
||||
|
||||
class TagExtractRequest(BaseModel):
    """Request body for extracting tags from text."""

    content: str = Field(..., min_length=10)  # at least 10 characters
    user_id: str
|
||||
|
||||
|
||||
class TagExtractResponse(BaseModel):
    """Extracted tags together with their count."""

    tags: list[TagProperties]
    tag_count: int
|
||||
|
||||
|
||||
class RelatedContentRequest(BaseModel):
    """Request body for finding content related to a set of tags."""

    tag_ids: list[str]
    user_id: str
    limit: int = 10  # defaults to 10
|
||||
58
backend/app/schemas/settings.py
Normal file
58
backend/app/schemas/settings.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Literal, List
|
||||
from app.schemas.auth import UserOut
|
||||
|
||||
# LLM provider identifiers accepted by the settings schemas.
LLMProviderType = Literal["openai", "claude", "ollama", "deepseek", "custom"]
# Model categories accepted by the settings schemas.
LLMType = Literal["chat", "vlm", "embedding", "rerank"]
|
||||
|
||||
|
||||
# Configuration of a single model.
class LLMModelConfig(BaseModel):
    """Connection settings for one model; every field has a default."""

    name: str = ""  # model name/alias, used as an identifier
    provider: LLMProviderType = "openai"
    model: str = ""
    base_url: str = ""
    api_key: str = ""
    enabled: bool = True  # whether the model is enabled
|
||||
|
||||
|
||||
# LLM configuration input - each model type supports multiple models.
class LLMConfigIn(BaseModel):
    """Request body for the LLM configuration: one model list per model type."""

    chat: Optional[List[LLMModelConfig]] = []
    vlm: Optional[List[LLMModelConfig]] = []
    embedding: Optional[List[LLMModelConfig]] = []
    rerank: Optional[List[LLMModelConfig]] = []
|
||||
|
||||
|
||||
# Scheduled-task configuration.
class SchedulerConfigIn(BaseModel):
    """Request body for the scheduler configuration; every field has a default."""

    daily_plan_time: Optional[str] = "08:00"
    forum_scan_interval_minutes: Optional[int] = 30
    todo_ai_generate_time: Optional[str] = "08:00"
    enabled: Optional[bool] = True
|
||||
|
||||
|
||||
# User profile update.
class ProfileUpdateIn(BaseModel):
    """Request body for updating the profile; every field is optional."""

    full_name: Optional[str] = Field(None, min_length=2, max_length=50)  # 2-50 characters
    password: Optional[str] = Field(None, min_length=8)  # at least 8 characters
    current_password: Optional[str] = None
|
||||
|
||||
|
||||
# Complete settings output.
class SettingsOut(BaseModel):
    """All settings of a user: profile plus optional LLM and scheduler config."""

    profile: UserOut
    llm_config: Optional[dict] = None
    scheduler_config: Optional[dict] = None

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# Request for testing an LLM connection.
class LLMTestIn(BaseModel):
    """Connection details of the model to test; every field is required."""

    type: LLMType
    provider: LLMProviderType
    model: str
    base_url: str
    api_key: str
|
||||
82
backend/app/schemas/stats.py
Normal file
82
backend/app/schemas/stats.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ===== System Health =====
class SystemHealth(BaseModel):
    """System health metrics; units follow the field-name suffixes."""

    uptime_seconds: int
    cpu_percent: float
    memory_used_mb: float
    memory_total_mb: float
    memory_percent: float
    disk_used_gb: float
    disk_total_gb: float
    disk_percent: float
    active_users_24h: int
|
||||
|
||||
|
||||
# ===== Daily Stats Base =====
class DailyStatItem(BaseModel):
    """One integer count for one date."""

    date: str
    count: int
|
||||
|
||||
|
||||
class DailyTokenStatItem(BaseModel):
    """Input and output token counts for one date."""

    date: str
    input_tokens: int
    output_tokens: int
|
||||
|
||||
|
||||
# ===== Conversation Stats =====
class ConversationStats(BaseModel):
    """Per-day conversation statistics plus a free-form totals mapping."""

    daily_conversations: list[DailyStatItem]
    daily_messages: list[DailyStatItem]
    # Both token series use the item type that carries input AND output counts.
    daily_input_tokens: list[DailyTokenStatItem]
    daily_output_tokens: list[DailyTokenStatItem]
    totals: dict
|
||||
|
||||
|
||||
# ===== Knowledge Stats =====
class KnowledgeStats(BaseModel):
    """Per-day knowledge-base statistics plus a free-form totals mapping."""

    daily_new_tags: list[DailyStatItem]
    daily_documents: list[DailyStatItem]
    daily_knowledge_queries: list[DailyStatItem]
    daily_tag_relations: list[DailyStatItem]
    totals: dict
|
||||
|
||||
|
||||
# ===== Kanban Stats =====
class KanbanStats(BaseModel):
    """Per-day kanban statistics plus a free-form totals mapping."""

    daily_new_tasks: list[DailyStatItem]
    daily_completed_tasks: list[DailyStatItem]
    # NOTE(review): the rate is carried in DailyStatItem.count, an int -
    # presumably a whole-number percentage; confirm against the service.
    daily_completion_rate: list[DailyStatItem]
    current_pending_tasks: int
    totals: dict
|
||||
|
||||
|
||||
# ===== Community Stats =====
class CommunityStats(BaseModel):
    """Per-day community statistics plus a free-form totals mapping."""

    daily_posts: list[DailyStatItem]
    daily_replies: list[DailyStatItem]
    daily_ai_executions: list[DailyStatItem]
    daily_agent_calls: list[DailyStatItem]
    totals: dict
|
||||
|
||||
|
||||
# ===== Personal Insights =====
class HourlyActivity(BaseModel):
    """Activity count for one hour."""

    hour: int
    count: int
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
    """Usage count for one tag path."""

    tag_path: str
    usage_count: int
|
||||
|
||||
|
||||
class PersonalInsights(BaseModel):
    """Personal usage insights: hourly activity, top tags and token usage."""

    hourly_activity: list[HourlyActivity]
    top_tags: list[TagUsage]
    token_trend_percent: float
    this_month_tokens: int
    last_month_tokens: int
|
||||
39
backend/app/schemas/task.py
Normal file
39
backend/app/schemas/task.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from app.models.task import TaskStatus, TaskPriority
|
||||
|
||||
|
||||
class TaskCreate(BaseModel):
    """Request body for creating a task."""

    title: str
    description: str | None = None
    priority: TaskPriority = TaskPriority.MEDIUM  # defaults to MEDIUM
    due_date: datetime | None = None
    tags: list[str] | None = None
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
    """Partial update of a task; every field is optional."""

    title: str | None = None
    description: str | None = None
    status: TaskStatus | None = None
    priority: TaskPriority | None = None
    due_date: datetime | None = None
    tags: list[str] | None = None
|
||||
|
||||
|
||||
class TaskOut(BaseModel):
    """Task record returned by the API."""

    id: str
    title: str
    description: str | None
    status: TaskStatus
    priority: TaskPriority
    due_date: datetime | None
    completed_at: datetime | None
    tags: str | None  # a string here, whereas TaskCreate/TaskUpdate take a list
    created_at: datetime
    updated_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DailyPlanRequest(BaseModel):
    """Request body carrying the id of the user to plan for."""

    user_id: str
|
||||
40
backend/app/schemas/todo.py
Normal file
40
backend/app/schemas/todo.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from app.models.todo import TodoSource
|
||||
|
||||
|
||||
class TodoCreate(BaseModel):
    """Request body for creating a todo."""

    title: str
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
    """Partial update of a todo; every field is optional."""

    title: str | None = None
    is_completed: bool | None = None
|
||||
|
||||
|
||||
class TodoOut(BaseModel):
    """Todo record returned by the API."""

    id: str
    title: str
    is_completed: bool
    source: TodoSource
    source_detail: str | None
    todo_date: str  # a string, unlike the datetime timestamp fields below
    completed_at: datetime | None
    created_at: datetime
    updated_at: datetime

    # Lets the model be populated from object attributes (e.g. ORM rows).
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class TodoListOut(BaseModel):
    """Paginated list of todo items."""

    items: list[TodoOut]
    total: int
    page: int
    page_size: int
|
||||
|
||||
|
||||
class TodoSummaryOut(BaseModel):
    """Per-date todo counters: total, completed and pending."""

    date: str
    total: int
    completed: int
    pending: int
|
||||
2
backend/app/services/__init__.py
Normal file
2
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Services - import specific classes directly when needed
|
||||
# e.g.: from app.services.agent_service import AgentService
|
||||
261
backend/app/services/agent_service.py
Normal file
261
backend/app/services/agent_service.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Jarvis Agent 服务层
|
||||
负责 LangGraph Agent 的调用、流式输出、对话历史管理
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
|
||||
|
||||
class AgentService:
    """Chat service that drives the Jarvis LangGraph agent.

    Each request resolves (or creates) the conversation, persists the user
    and assistant messages, loads the memory context, runs the agent graph
    and finally triggers the best-effort auto-summary.
    """

    # Reply returned by ``chat_simple`` when the graph yields no final answer.
    _FALLBACK_REPLY = "抱歉,我无法处理这个请求。"

    # Strong references to fire-and-forget tasks. The event loop only keeps
    # weak references, so an otherwise unreferenced task may be collected
    # before it finishes.
    _background_tasks: set = set()

    def __init__(self, db: AsyncSession):
        """Bind the service to the request's database session."""
        self.db = db

    async def chat(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
    ) -> tuple[str, str, AsyncGenerator[str, None]]:
        """Handle a chat request and stream the agent's reply.

        Args:
            user_id: Owner of the conversation.
            message: The user's message text.
            conversation_id: Existing conversation to continue. A new
                conversation is created when it is missing, unknown, or not
                owned by ``user_id``.

        Returns:
            ``(conversation_id, message_id, response_stream)`` where
            ``message_id`` identifies the pre-created assistant message and
            ``response_stream`` yields reply text as the graph produces it.
        """
        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        await self._save_message(
            conversation_id=conversation_id,
            role="user",
            content=message,
        )

        # The assistant message is created empty and filled in once the
        # stream has finished.
        assistant_msg = await self._save_message(
            conversation_id=conversation_id,
            role="assistant",
            content="",
            model="jarvis",
        )
        # Capture the id now: reading an ORM attribute later, after further
        # commits, may require a lazy reload that async code cannot perform.
        assistant_msg_id = assistant_msg.id

        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        async def run_agent():
            # Bound before ``try`` so the ``finally`` block can always read it.
            collected = ""
            set_current_user(user_id)
            try:
                graph = get_agent_graph()
                langgraph_state = self._build_state(
                    user_id, conversation_id, message, memory_ctx
                )
                async for event in graph.astream_events(langgraph_state, version="v2"):
                    kind = event.get("event")
                    if kind == "on_chat_model_end":
                        # Each completed model call is emitted in full; the
                        # output is a message object (or dict), not a string.
                        text = self._message_text(event.get("data", {}).get("output"))
                        if text:
                            collected += text
                            yield text
                    elif kind == "on_tool_end":
                        name = event.get("name", "")
                        yield f"\n[工具执行: {name}]\n"
            except Exception as e:
                yield f"\n执行出错: {str(e)}"
            finally:
                clear_current_user()
                # Persist first, then schedule the summary, so the background
                # task does not race this write on the shared session.
                if collected:
                    await self._store_reply(assistant_msg_id, collected)
                self._schedule_auto_summarize(user_id, conversation_id)

        return conversation_id, assistant_msg_id, run_agent()

    async def chat_simple(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
        file_ids: list[str] | None = None,
    ) -> tuple[str, str, str]:
        """Handle a chat request and return the complete reply (no streaming).

        Args:
            user_id: Owner of the conversation.
            message: The user's message text.
            conversation_id: Existing conversation to continue, if any.
            file_ids: Uploaded documents whose text is appended to the
                message sent to the agent (the stored user message keeps the
                original text and records the ids as an attachment).

        Returns:
            ``(conversation_id, message_id, response_content)``.
        """
        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        # Read uploaded files, if any, and use their content as extra context.
        file_context = ""
        if file_ids:
            from app.services.document_service import DocumentService

            doc_svc = DocumentService(self.db)
            for file_id in file_ids:
                content = await doc_svc.get_document_content(user_id, file_id)
                if content:
                    file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"

        full_message = f"{message}\n{file_context}" if file_context else message

        await self._save_message(
            conversation_id=conversation_id,
            role="user",
            content=message,
            attachments=[{"file_ids": file_ids}] if file_ids else None,
        )

        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        set_current_user(user_id)
        try:
            graph = get_agent_graph()
            result_state = await graph.ainvoke(
                self._build_state(user_id, conversation_id, full_message, memory_ctx)
            )
            # The initial state already holds ``final_response: None``, so a
            # plain ``.get(key, default)`` would return None, not the default.
            response_content = result_state.get("final_response") or self._FALLBACK_REPLY
        except Exception as e:
            response_content = f"抱歉,发生错误: {str(e)}"
        finally:
            clear_current_user()

        assistant_msg = await self._save_message(
            conversation_id=conversation_id,
            role="assistant",
            content=response_content,
            model="jarvis",
        )
        assistant_msg_id = assistant_msg.id

        # Scheduled only after the reply is stored so the background task does
        # not use the shared session concurrently with the write above.
        self._schedule_auto_summarize(user_id, conversation_id)

        return conversation_id, assistant_msg_id, response_content

    async def _get_or_create_conversation(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None,
    ) -> str:
        """Return the id of the user's conversation, creating one if needed."""
        conv = None
        if conversation_id:
            result = await self.db.execute(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    # Scope to the owner so one user cannot append to
                    # another user's conversation.
                    Conversation.user_id == user_id,
                )
            )
            conv = result.scalar_one_or_none()

        if conv is None:
            # The first 50 characters of the message serve as the title.
            conv = Conversation(user_id=user_id, title=message[:50])
            self.db.add(conv)
            await self.db.commit()
            await self.db.refresh(conv)
        return conv.id

    async def _save_message(self, **fields) -> Message:
        """Insert a ``Message`` built from ``fields`` and return it refreshed."""
        msg = Message(**fields)
        self.db.add(msg)
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def _store_reply(self, message_id: str, content: str) -> None:
        """Write the final reply text into the pre-created assistant message.

        Best effort: a failure is logged instead of being raised, because the
        reply has already been streamed to the client.
        """
        import logging

        try:
            result = await self.db.execute(
                select(Message).where(Message.id == message_id)
            )
            msg = result.scalar_one_or_none()
            if msg:
                msg.content = content
                await self.db.commit()
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to persist assistant reply for message %s", message_id
            )

    def _schedule_auto_summarize(self, user_id: str, conversation_id: str) -> None:
        """Start the auto-summary / memory extraction without blocking the reply.

        NOTE(review): the task receives this request's session; confirm the
        session outlives the request or hand the task its own session.
        """
        import asyncio
        import logging

        try:
            task = asyncio.get_running_loop().create_task(
                memory_service.try_auto_summarize(self.db, user_id, conversation_id)
            )
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not schedule auto-summary for conversation %s",
                conversation_id,
                exc_info=True,
            )
            return
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    @staticmethod
    def _build_state(user_id: str, conversation_id: str, text: str, memory_ctx) -> dict:
        """Build the initial LangGraph state for a single user turn."""
        return {
            "messages": [HumanMessage(content=text)],  # type: ignore[arg-type]
            "user_id": user_id,
            "conversation_id": conversation_id,
            "current_agent": "master",
            "active_agents": ["master"],
            "pending_tasks": [],
            "completed_tasks": [],
            "tool_calls": [],
            "last_tool_result": None,
            "knowledge_context": None,
            "graph_context": None,
            "plan": None,
            "plan_steps": [],
            "analysis_report": None,
            "final_response": None,
            "should_respond": True,
            "memory_context": memory_ctx,
        }

    @staticmethod
    def _message_text(output) -> str:
        """Extract plain text from a chat-model output.

        Accepts a message-like object with a ``content`` attribute, a dict
        with a ``content`` key, or a bare string. List content is flattened
        by joining string items and the ``text`` of dict items. Anything else
        yields an empty string.
        """
        if isinstance(output, dict):
            content = output.get("content", "")
        else:
            content = getattr(output, "content", output)

        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return ""
|
||||
29
backend/app/services/auth_service.py
Normal file
29
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
from jose import jwt, JWTError
|
||||
from app.config import settings
|
||||
|
||||
# Shared passlib context used by the password helpers below; bcrypt is the
# only configured hashing scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches ``hashed_password``.

    The comparison is delegated to the module-level ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module-level ``pwd_context`` and return it."""
    hashed = pwd_context.hash(password)
    return hashed
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Token lifetime. Falls back to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` when omitted.

    Returns:
        The token encoded with ``settings.SECRET_KEY`` and
        ``settings.ALGORITHM``.
    """
    # Local import keeps this block self-contained.
    from datetime import timezone

    to_encode = data.copy()
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # ``datetime.utcnow()`` is deprecated and returns a naive value; use an
    # explicit, timezone-aware UTC timestamp for the expiry instead.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict | None:
    """Decode and verify ``token``; return its claims, or ``None`` if invalid.

    Verification uses ``settings.SECRET_KEY`` and only accepts
    ``settings.ALGORITHM``. Any ``JWTError`` is mapped to ``None``.
    """
    try:
        return jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError:
        return None
|
||||
256
backend/app/services/document_service.py
Normal file
256
backend/app/services/document_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
文档服务 - 上传、解析、分块、存储
|
||||
支持多种文档格式 + LlamaIndex 智能分块
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from fastapi import UploadFile
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
import os
|
||||
import aiofiles
|
||||
import uuid
|
||||
|
||||
|
||||
# Lower-case file extensions (with leading dot) accepted by
# DocumentService.upload_document; anything else is rejected.
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
|
||||
|
||||
|
||||
class DocumentService:
    """Upload, parse, chunk and persist user documents."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to a session and, optionally, to one user.

        ``user_id`` is only read by ``_get_folder_path``; the other methods
        take the user id as an explicit argument.
        """
        self.db = db
        self.user_id = user_id

    async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
        """Store an uploaded file, extract its text and persist its chunks.

        Args:
            user_id: Owner of the new document.
            file: The uploaded file.
            folder_id: Optional folder to file the document under.

        Returns:
            The persisted ``Document`` with ``chunk_count`` set.

        Raises:
            ValueError: If the extension is not allowed or the file is larger
                than ``settings.MAX_UPLOAD_SIZE``.
        """
        # ``filename`` may be missing on a malformed upload.
        filename = file.filename or ""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        content = await file.read()
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        file_id = str(uuid.uuid4())
        file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        try:
            text_content = await self._extract_text(file_path, ext)

            doc = Document(
                user_id=user_id,
                title=filename.rsplit(".", 1)[0],
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=text_content[:500],
                folder_id=folder_id,
            )
            self.db.add(doc)
            # Flush to obtain ``doc.id`` while keeping the document and its
            # chunks in one transaction.
            await self.db.flush()

            chunks = self._chunk_text(text_content)
            for index, chunk_text in enumerate(chunks):
                self.db.add(
                    DocumentChunk(
                        document_id=doc.id,
                        chunk_index=index,
                        content=chunk_text,
                    )
                )
            doc.chunk_count = len(chunks)
            await self.db.commit()
            await self.db.refresh(doc)
        except Exception:
            # Do not leave an orphaned upload or a half-written document.
            await self.db.rollback()
            try:
                os.remove(file_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised
            raise

        return doc

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the slash-separated path of a folder, or ``None``.

        The path is built from the names of ``folder_id`` and its ancestors
        among the folders owned by ``self.user_id``, e.g. ``/Parent/Child``.
        """
        folders = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in folders.scalars().all()}

        names: list[str] = []
        seen: set[str] = set()
        current_id = folder_id
        # ``seen`` stops the walk if the parent chain contains a cycle.
        while current_id and current_id not in seen:
            seen.add(current_id)
            folder = folder_map.get(current_id)
            if not folder:
                break
            names.append(folder.name)
            current_id = folder.parent_id

        if not names:
            return None
        return "/" + "/".join(reversed(names))

    async def delete_document(self, user_id: str, document_id: str):
        """Delete a user's document row and its stored file.

        Raises:
            ValueError: If the document does not exist for this user.
        """
        import logging

        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        file_path = doc.file_path

        # Delete the row first: if the commit fails the file is still there
        # and the row keeps pointing at something that exists.
        await self.db.delete(doc)
        await self.db.commit()

        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass  # already gone
            except OSError:
                logging.getLogger(__name__).warning(
                    "Could not remove stored file %s", file_path, exc_info=True
                )

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Extract plain text from a stored file.

        Args:
            file_path: Path of the stored file.
            ext: Lower-case extension including the dot (e.g. ``".pdf"``).

        Returns:
            The extracted text, or a bracketed placeholder message when the
            required parser is not installed or the format is unsupported.
        """
        if ext == ".pdf":
            try:
                import pymupdf
            except ImportError:
                return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
            # The context manager closes the document even if a page fails.
            with pymupdf.open(file_path) as pdf:
                return "".join(page.get_text() for page in pdf)

        if ext in (".md", ".txt"):
            async with aiofiles.open(file_path, "rb") as f:
                raw = await f.read()
            return self._decode_text(raw)

        if ext == ".docx":
            try:
                from docx import Document as DocxDocument
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
            return "\n".join(p.text for p in DocxDocument(file_path).paragraphs)

        # Includes legacy ``.doc``: python-docx only reads the .docx format.
        return "[暂不支持此格式]"

    @staticmethod
    def _decode_text(raw: bytes) -> str:
        """Decode text-file bytes without failing on non-UTF-8 input.

        Tries UTF-8, then GB18030, then UTF-8 with replacement characters,
        and normalizes line endings to ``\\n`` as text-mode reading would.
        """
        text = None
        for encoding in ("utf-8", "gb18030"):
            try:
                text = raw.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
        if text is None:
            text = raw.decode("utf-8", errors="replace")
        return text.replace("\r\n", "\n").replace("\r", "\n")

    @staticmethod
    def _resolve_chunk_size(chunk_size: int | None) -> int:
        """Return the effective chunk size (at least 1).

        ``None`` selects ``settings.CHUNK_SIZE``.
        """
        limit = settings.CHUNK_SIZE if chunk_size is None else chunk_size
        return max(int(limit), 1)

    def _chunk_text(self, text: str, chunk_size: int | None = None) -> list[str]:
        """Split document text into chunks.

        Strategy:
        1. If the text has Markdown headings (H1-H3), split per heading
           section; text before the first heading is chunked by paragraph.
        2. Sections longer than the chunk size are split further.
        3. Without headings, chunk by paragraph.

        Args:
            text: Full document text.
            chunk_size: Maximum chunk length; defaults to
                ``settings.CHUNK_SIZE``.

        Returns:
            A non-empty list of chunks (a single truncated chunk when the
            text yields none).
        """
        import re

        limit = self._resolve_chunk_size(chunk_size)
        header_pattern = re.compile(r"^(#{1,3})\s+(.+)$", re.MULTILINE)
        headers = list(header_pattern.finditer(text))

        chunks: list[str] = []
        if headers:
            # Keep the content that precedes the first heading.
            preamble = text[: headers[0].start()].strip()
            if preamble:
                chunks.extend(self._chunk_by_paragraphs(preamble, limit))

            for i, match in enumerate(headers):
                end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
                section = text[match.start():end].strip()
                if len(section) > limit:
                    chunks.extend(self._split_large_chunk(section, match.group(2), limit))
                elif section:
                    chunks.append(section)
        else:
            chunks = self._chunk_by_paragraphs(text, limit)

        chunks = [c.strip() for c in chunks if c.strip()]
        return chunks if chunks else [text[:limit]]

    def _chunk_by_paragraphs(self, text: str, chunk_size: int | None = None) -> list[str]:
        """Pack blank-line-separated paragraphs into chunks of bounded size.

        Paragraphs are joined with a blank line. A paragraph longer than the
        chunk size is split on its own via ``_split_large_chunk``.
        """
        limit = self._resolve_chunk_size(chunk_size)
        chunks: list[str] = []
        current = ""

        for para in text.split("\n\n"):
            para = para.strip()
            if not para:
                continue

            if len(para) > limit:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(self._split_large_chunk(para, "", limit))
                continue

            candidate = f"{current}\n\n{para}" if current else para
            if len(candidate) <= limit:
                current = candidate
            else:
                chunks.append(current)
                current = para

        if current:
            chunks.append(current)
        return chunks

    def _split_large_chunk(self, text: str, title: str, chunk_size: int | None = None) -> list[str]:
        """Split an oversized section into sub-chunks, each prefixed by ``title``.

        Break points are sentence ends (CJK or Latin terminators) and
        newlines. The original characters are kept as they are, and a
        sentence longer than the available room is cut into fixed-size
        windows.
        """
        import re

        limit = self._resolve_chunk_size(chunk_size)
        prefix = f"{title}\n\n" if title else ""
        # Room left for content once the title prefix is accounted for.
        room = max(limit - len(prefix), 1)

        # Zero-width split points: after a full-width terminator or newline,
        # or after an ASCII terminator followed by whitespace.
        pieces = re.split(r"(?<=[\u3002\uff01\uff1f\n])|(?<=[.!?]\s)", text)

        bodies: list[str] = []
        body = ""
        for piece in pieces:
            if not piece:
                continue
            if body and len(body) + len(piece) > room:
                bodies.append(body)
                body = ""
            while len(piece) > room:
                bodies.append(piece[:room])
                piece = piece[room:]
            body += piece
        if body:
            bodies.append(body)

        # Skip whitespace-only bodies so no chunk consists of the title alone.
        return [(prefix + b).strip() for b in bodies if b.strip()]

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return the chunks of a document ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text content of one of the user's documents.

        Returns:
            The extracted text; a bracketed placeholder naming the file when
            the format is unsupported or extraction fails; ``None`` when the
            document or its file does not exist.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        file_path = doc.file_path
        if not file_path or not os.path.exists(file_path):
            return None

        placeholder = f"[文档] {doc.filename}"
        ext = os.path.splitext(doc.filename or "")[1].lower()
        if ext not in (".txt", ".md", ".pdf", ".docx"):
            return placeholder

        try:
            # Reuse the upload-time extractor so PDF/DOCX yield real text.
            text = await self._extract_text(file_path, ext)
        except Exception:
            return placeholder

        if ext == ".pdf" and not text.strip():
            # e.g. a scanned PDF without a text layer
            return f"[PDF文档] {doc.filename}"
        return text
|
||||
342
backend/app/services/graph_service.py
Normal file
342
backend/app/services/graph_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
知识图谱服务 - 实体识别、关系抽取、图谱查询
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
|
||||
|
||||
实体类型:
|
||||
- person(人物):人名、角色
|
||||
- concept(概念):抽象概念、理论、方法
|
||||
- topic(主题):话题、领域
|
||||
- task(任务):要做的事情
|
||||
- event(事件):发生的事件
|
||||
- document(文档):文件、资料
|
||||
|
||||
关系类型:
|
||||
- related_to(相关于)
|
||||
- part_of(隶属于)
|
||||
- caused_by(由...导致)
|
||||
- depends_on(取决于)
|
||||
- contains(包含)
|
||||
- located_in(位于)
|
||||
- works_on(从事)
|
||||
|
||||
要求:
|
||||
1. 识别文本中所有有意义的实体(不超过10个)
|
||||
2. 识别实体之间的关系(每个实体至少一条关系)
|
||||
3. 每个实体要有 name、type、description(1-2句话)
|
||||
4. 关系要有 source、target、relation_type
|
||||
|
||||
文本内容:
|
||||
{text}
|
||||
|
||||
请只返回 JSON,不要有其他内容:
|
||||
{{
|
||||
"entities": [
|
||||
{{"name": "实体名", "type": "类型", "description": "描述"}}
|
||||
],
|
||||
"relations": [
|
||||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
|
||||
|
||||
用户问题:{question}
|
||||
|
||||
已知实体:
|
||||
{entities}
|
||||
|
||||
请推断这些实体之间的隐含关系,返回 JSON:
|
||||
{{
|
||||
"inferred_relations": [
|
||||
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
|
||||
]
|
||||
}}
|
||||
|
||||
关系类型:related_to / part_of / caused_by / depends_on / contains / works_on / located_in
|
||||
confidence: 0.0-1.0,表示推断置信度
|
||||
"""
|
||||
|
||||
|
||||
class GraphService:
    """Knowledge-graph service: entity/relation extraction and graph queries."""

    def __init__(self, db: AsyncSession):
        """Bind the service to a session and obtain the configured LLM."""
        self.db = db
        self.llm = get_llm()

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Build or update the user's knowledge graph from indexed documents.

        Every chunk of the user's indexed documents (optionally restricted to
        ``document_ids``) is sent to the LLM for entity and relation
        extraction; results are merged into the existing graph. A failing
        chunk is rolled back, logged and skipped.
        """
        # Plain columns are selected instead of ORM entities so that a
        # rollback after a failed chunk cannot invalidate the remaining rows.
        query = (
            select(DocumentChunk.id, DocumentChunk.document_id, DocumentChunk.content)
            .select_from(DocumentChunk)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(Document.is_indexed.is_(True))
        )
        if document_ids:
            query = query.where(DocumentChunk.document_id.in_(document_ids))

        rows = (await self.db.execute(query)).all()
        logger.info("[GraphService] 开始构建图谱,共 %s 个 chunks", len(rows))

        for chunk_id, document_id, content in rows:
            try:
                await self._ingest_text(content or "", document_id, user_id)
            except Exception as e:
                # Discard rows flushed for the failed chunk so they are not
                # committed together with the next one.
                await self.db.rollback()
                logger.error("[GraphService] 处理 chunk %s 失败: %s", chunk_id, e)

        logger.info("[GraphService] 图谱构建完成")

    async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
        """Extract entities and relations from a single chunk and store them."""
        await self._ingest_text(chunk.content or "", chunk.document_id, user_id)

    async def _ingest_text(self, text: str, document_id: str, user_id: str) -> None:
        """Run extraction on ``text`` and merge the result into the graph.

        Only the first 2000 characters are sent to the LLM. Malformed
        entities or relations in the LLM output are skipped rather than
        aborting the chunk.
        """
        prompt = ENTITY_EXTRACTION_PROMPT.format(text=text[:2000])
        response = await self._call_llm([HumanMessage(content=prompt)])
        data = self._parse_llm_json(getattr(response, "content", response))
        if not data:
            return

        # Keep well-formed entities only, first occurrence per name.
        entities: dict[str, dict] = {}
        for item in data.get("entities") or []:
            if not isinstance(item, dict):
                continue
            name, etype = item.get("name"), item.get("type")
            if not isinstance(name, str) or not name.strip():
                continue
            if not isinstance(etype, str) or not etype.strip():
                continue
            entities.setdefault(name.strip(), item)
        if not entities:
            return

        # Look up the nodes that already exist, in a single query.
        existing = await self.db.execute(
            select(KGNode).where(
                KGNode.user_id == user_id,
                KGNode.name.in_(list(entities)),
            )
        )
        node_ids: dict[str, str] = {}
        for node in existing.scalars().all():
            node_ids.setdefault(node.name, node.id)

        # Insert the nodes that are new.
        for name, item in entities.items():
            if name in node_ids:
                continue
            node = KGNode(
                user_id=user_id,
                name=name,
                entity_type=item["type"].strip(),
                description=item.get("description", ""),
                source_document_id=document_id,
            )
            self.db.add(node)
            await self.db.flush()  # assigns ``node.id``
            node_ids[name] = node.id

        # Insert relations, skipping ones already stored or repeated here.
        seen: set[tuple] = set()
        for rel in data.get("relations") or []:
            if not isinstance(rel, dict):
                continue
            src, tgt, rtype = rel.get("source"), rel.get("target"), rel.get("relation_type")
            if not all(isinstance(v, str) and v.strip() for v in (src, tgt, rtype)):
                continue
            src, tgt = src.strip(), tgt.strip()
            if src not in node_ids or tgt not in node_ids:
                continue

            key = (node_ids[src], node_ids[tgt], rtype)
            if key in seen:
                continue
            seen.add(key)

            result = await self.db.execute(
                select(KGEdge).where(
                    KGEdge.source_id == key[0],
                    KGEdge.target_id == key[1],
                    KGEdge.relation_type == rtype,
                )
            )
            if result.scalars().first() is None:
                self.db.add(
                    KGEdge(source_id=key[0], target_id=key[1], relation_type=rtype)
                )

        await self.db.commit()

    async def _call_llm(self, messages):
        """Invoke the LLM asynchronously, whichever call style it exposes.

        Prefers ``ainvoke``; otherwise calls ``invoke`` and awaits the result
        only if it is awaitable.
        """
        import inspect

        ainvoke = getattr(self.llm, "ainvoke", None)
        if callable(ainvoke):
            return await ainvoke(messages)
        result = self.llm.invoke(messages)
        if inspect.isawaitable(result):
            result = await result
        return result

    @staticmethod
    def _parse_llm_json(raw) -> dict | None:
        """Parse a JSON object from LLM output, tolerating surrounding text.

        Tries the whole string first, then the span from the first ``{`` to
        the last ``}`` (which handles Markdown code fences). Returns ``None``
        when no JSON object can be obtained.
        """
        if not isinstance(raw, str):
            return None
        text = raw.strip()
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            return data if isinstance(data, dict) else None
        return None

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown summary of the user's knowledge graph.

        The summary lists node/edge totals, the distribution of node and
        relation types, and the ten nodes with the highest ``importance``.
        """
        node_count = await self.db.execute(
            select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
        )
        # Edges are attributed to the user through their source node.
        edge_count = await self.db.execute(
            select(func.count())
            .select_from(KGEdge)
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
        )

        node_total = node_count.scalar() or 0
        edge_total = edge_count.scalar() or 0

        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        type_result = await self.db.execute(
            select(KGNode.entity_type, func.count())
            .where(KGNode.user_id == user_id)
            .group_by(KGNode.entity_type)
        )
        type_stats = type_result.all()

        rel_result = await self.db.execute(
            select(KGEdge.relation_type, func.count())
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
            .group_by(KGEdge.relation_type)
        )
        rel_stats = rel_result.all()

        top_nodes_result = await self.db.execute(
            select(KGNode)
            .where(KGNode.user_id == user_id)
            .order_by(KGNode.importance.desc())
            .limit(10)
        )
        top_nodes = list(top_nodes_result.scalars().all())

        lines = [
            "## 知识图谱摘要",
            "",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            "",
            "### 节点类型分布",
        ]
        lines.extend(f"- {etype}: {count} 个" for etype, count in type_stats)

        lines.append("\n### 关系类型分布")
        lines.extend(f"- {rtype}: {count} 条" for rtype, count in rel_stats)

        lines.append("\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # ``description`` may be NULL.
            description = node.description or ""
            lines.append(f"- [{node.entity_type}] {node.name}: {description[:50]}...")

        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Return a Markdown description of entities whose name contains ``entity``.

        Up to five matching nodes are described, each with up to ten outgoing
        and ten incoming relations.
        """
        result = await self.db.execute(
            select(KGNode).where(
                KGNode.user_id == user_id,
                KGNode.name.contains(entity),
            ).limit(5)
        )
        nodes = list(result.scalars().all())

        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            # Outgoing relations, joined with their target node.
            edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.target_id)
                .where(KGEdge.source_id == node.id)
                .limit(10)
            )
            out_edges = list(edges_result.all())

            # Incoming relations, joined with their source node.
            in_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGEdge.target_id == node.id)
                .limit(10)
            )
            in_edges = list(in_edges_result.all())

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(f" - {node.name} --[{edge.relation_type}]--> {target.name}")

            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(f" - {source.name} --[{edge.relation_type}]--> {node.name}")

            lines.append("")

        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Return the neighbourhood of a node (for graph visualisation).

        Walks incoming and outgoing edges breadth-first for ``depth`` levels.

        Returns:
            ``{"nodes": [...], "edges": [...]}``. Every edge appears once and
            the nodes at both ends of every returned edge are included.
            Nodes are listed in discovery order, starting with ``node_id``.
        """
        if depth <= 0:
            return {"nodes": [], "edges": []}

        visited: set = set()
        order: list = [node_id]      # discovery order of node ids
        discovered: set = {node_id}
        frontier: set = {node_id}
        edges: dict = {}             # id(edge) -> edge, for de-duplication

        for _ in range(depth):
            frontier -= visited
            if not frontier:
                break
            visited |= frontier

            ids = list(frontier)
            edge_rows = await self.db.execute(
                select(KGEdge).where(
                    KGEdge.source_id.in_(ids) | KGEdge.target_id.in_(ids)
                )
            )
            next_frontier: set = set()
            for edge in edge_rows.scalars().all():
                # The same edge is reachable from both of its endpoints.
                edges.setdefault(id(edge), edge)
                for endpoint in (edge.source_id, edge.target_id):
                    if endpoint not in discovered:
                        discovered.add(endpoint)
                        order.append(endpoint)
                    next_frontier.add(endpoint)
            frontier = next_frontier

        # Load every discovered node, including the outermost level, so no
        # returned edge points at a node that is missing from the result.
        node_rows = await self.db.execute(
            select(KGNode).where(KGNode.id.in_(order))
        )
        by_id = {node.id: node for node in node_rows.scalars().all()}
        nodes = [by_id[nid] for nid in order if nid in by_id]

        return {"nodes": nodes, "edges": list(edges.values())}
|
||||
308
backend/app/services/knowledge_service.py
Normal file
308
backend/app/services/knowledge_service.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
|
||||
|
||||
检索策略:
|
||||
1. 语义检索 (dense) - ChromaDB 向量相似度
|
||||
2. 关键词检索 (sparse) - SQL LIKE
|
||||
3. 混合检索 - 语义 + 关键词 加权融合
|
||||
4. Rerank - 二次排序优化结果
|
||||
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.config import settings
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """One retrieved chunk together with its score and neighbouring context."""

    chunk_id: str
    document_id: str
    document_title: str
    content: str
    # Relevance score; KnowledgeService.retrieve computes it as 1 - distance.
    score: float
    # String form of the chunk's metadata dict, when available.
    metadata_: str | None = None
    # Content of the chunks directly before / after this one, if any.
    prev_chunk: str | None = None
    next_chunk: str | None = None
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""向量知识库检索服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession, user_id: str | None = None):
|
||||
self.db = db
|
||||
self.user_id = user_id
|
||||
self._chroma_client = None
|
||||
|
||||
@property
|
||||
def chroma_client(self):
|
||||
if self._chroma_client is None:
|
||||
self._chroma_client = chromadb.PersistentClient(
|
||||
path=settings.CHROMA_PERSIST_DIR,
|
||||
settings=ChromaSettings(allow_reset=True),
|
||||
)
|
||||
return self._chroma_client
|
||||
|
||||
def get_collection(self, user_id: str):
    """Return the user's Chroma collection, creating it if it does not exist.

    The collection is named ``user_<user_id>`` and tagged with the user id
    in its metadata.
    """
    collection_name = f"user_{user_id}"
    client = self.chroma_client
    return client.get_or_create_collection(
        name=collection_name,
        metadata={"user_id": user_id},
    )
|
||||
|
||||
async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
    """Store a document's chunks in the user's Chroma collection.

    Does nothing when the document does not exist for ``user_id`` or has no
    chunks. On success the document is flagged ``is_indexed``.

    Args:
        document_id: Document to index.
        user_id: Owner of the document and of the target collection.
        folder_path: Folder path stored in each chunk's metadata (empty
            string when omitted).
    """
    result = await self.db.execute(
        select(Document).where(
            Document.id == document_id,
            # Only index documents owned by the collection's user.
            Document.user_id == user_id,
        )
    )
    doc = result.scalar_one_or_none()
    if not doc:
        return

    chunks_result = await self.db.execute(
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    chunks = list(chunks_result.scalars().all())
    if not chunks:
        return

    collection = self.get_collection(user_id)

    ids = [chunk.id for chunk in chunks]
    documents = [chunk.content for chunk in chunks]
    metadatas = [
        {
            "document_id": doc.id,
            "document_title": doc.title,
            "chunk_index": chunk.chunk_index,
            "file_type": doc.file_type,
            "folder_path": folder_path or "",
        }
        for chunk in chunks
    ]

    # ``upsert`` makes re-indexing idempotent: chunk ids that are already
    # present are updated instead of being rejected or skipped.
    collection.upsert(ids=ids, documents=documents, metadatas=metadatas)

    doc.is_indexed = True
    await self.db.commit()
|
||||
|
||||
async def retrieve(
    self,
    query: str,
    user_id: str,
    folder_id: str | None = None,
    top_k: int = 5,
    use_rerank: bool = True,
) -> list[SearchResult]:
    """Retrieve the most relevant chunks for ``query``.

    Steps:
    1. Vector search in the user's Chroma collection with an enlarged
       candidate set.
    2. Optional restriction to a folder and its sub-folders.
    3. Optional rerank, then truncation to ``top_k``.
    4. Attach the previous/next chunk of each returned result.

    Args:
        query: Search text.
        user_id: Owner of the collection to search.
        folder_id: Restrict results to this folder (and below) when its
            path can be resolved.
        top_k: Maximum number of results.
        use_rerank: Whether to rerank the candidates before truncating.

    Returns:
        Up to ``top_k`` results; an empty list if the search fails.
    """
    import logging

    if top_k <= 0:
        return []

    collection = self.get_collection(user_id)

    folder_prefix = None
    if folder_id:
        folder_path = await self._get_folder_path(folder_id)
        if folder_path:
            folder_prefix = folder_path.rstrip("/")

    # Chroma metadata filters have no prefix operator, so the folder
    # restriction is applied to the candidates in Python. A larger
    # candidate set compensates for the rows that get filtered out.
    n_results = top_k * 3 if folder_prefix is None else top_k * 15

    try:
        results = collection.query(
            query_texts=[query],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
    except Exception:
        logging.getLogger(__name__).exception(
            "Vector search failed for user %s", user_id
        )
        return []

    if not results or not results.get("ids"):
        return []

    ids = results["ids"][0]
    documents = (results.get("documents") or [[]])[0]
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    candidates: list[SearchResult] = []
    chunk_indexes: dict[str, int] = {}
    for i, chunk_id in enumerate(ids):
        meta = (metadatas[i] if i < len(metadatas) else None) or {}

        if folder_prefix is not None:
            path = str(meta.get("folder_path", ""))
            if path != folder_prefix and not path.startswith(folder_prefix + "/"):
                continue

        # NOTE(review): ``1 - distance`` is only bounded for cosine
        # distance -- confirm the collection's distance metric.
        score = 1.0 - (distances[i] if i < len(distances) else 0.0)

        chunk_indexes[chunk_id] = meta.get("chunk_index", 0)
        candidates.append(SearchResult(
            chunk_id=chunk_id,
            document_id=meta.get("document_id", ""),
            document_title=meta.get("document_title", ""),
            content=documents[i] if i < len(documents) else "",
            score=score,
            metadata_=str(meta),
        ))

    if use_rerank:
        selected = self._rerank(query, candidates, top_k)
    else:
        selected = candidates[:top_k]

    # Neighbouring chunks are loaded only for the results actually returned;
    # they play no part in ranking.
    for item in selected:
        item.prev_chunk, item.next_chunk = await self._get_sibling_chunks(
            chunk_id=item.chunk_id,
            chunk_index=chunk_indexes.get(item.chunk_id, 0),
            document_id=item.document_id,
        )

    return selected
|
||||
|
||||
def _rerank(
    self,
    query: str,
    results: list[SearchResult],
    top_k: int,
) -> list[SearchResult]:
    """Re-order candidates by a blended relevance score and keep the best ``top_k``.

    The blended score is:
    semantic score * 0.7 + query/content word overlap * 0.2
    + query/title word overlap * 0.1.
    Overlap is the fraction of distinct query words (``\\w+`` tokens,
    lower-cased) that also occur in the compared text. The title term is
    only added when the result has a non-empty title. Ties keep their
    incoming order because the sort is stable.
    """
    import re

    def tokens(text: str) -> set:
        # Distinct lower-cased word tokens of ``text``.
        return set(re.findall(r"\w+", text.lower()))

    wanted = tokens(query)
    # Guard against division by zero for an empty / punctuation-only query.
    denominator = max(len(wanted), 1)

    def blended(item) -> float:
        total = item.score * 0.7
        total += len(wanted & tokens(item.content)) / denominator * 0.2
        if item.document_title:
            total += len(wanted & tokens(item.document_title)) / denominator * 0.1
        return total

    ranked = sorted(
        ((blended(item), item) for item in results),
        key=lambda pair: pair[0],
        reverse=True,
    )
    return [item for _, item in ranked[:top_k]]
|
||||
|
||||
async def _get_sibling_chunks(
    self,
    chunk_id: str,
    chunk_index: int,
    document_id: str,
) -> tuple[str | None, str | None]:
    """Fetch the text of the chunks directly before and after a matched chunk.

    Both neighbours are loaded with a single query (this method is called
    once per search hit, so it used to cost two round trips per hit).

    Args:
        chunk_id: ID of the matched chunk. Not used by the lookup; kept so
            existing callers keep working.
        chunk_index: Position of the matched chunk inside its document.
        document_id: ID of the document whose chunks are inspected.

    Returns:
        ``(previous_content, next_content)``. An element is ``None`` when
        no chunk is stored at that index.
    """
    prev_index = chunk_index - 1
    next_index = chunk_index + 1

    result = await self.db.execute(
        select(DocumentChunk).where(
            DocumentChunk.document_id == document_id,
            DocumentChunk.chunk_index.in_((prev_index, next_index)),
        )
    )

    # Keep the first row seen per index so that accidentally duplicated
    # chunk indexes cannot abort the whole search.
    neighbours = {}
    for row in result.scalars().all():
        neighbours.setdefault(row.chunk_index, row.content)

    return neighbours.get(prev_index), neighbours.get(next_index)
|
||||
|
||||
async def _get_folder_path(self, folder_id: str) -> str | None:
    """Build the absolute path of a folder by walking up its parent chain.

    Args:
        folder_id: ID of the folder to resolve.

    Returns:
        A path such as ``"/root/child/leaf"``, or ``None`` when the folder
        does not exist. If an ancestor is missing, the path is built from
        the ancestors found so far.
    """
    result = await self.db.execute(
        select(Folder).where(Folder.id == folder_id)
    )
    folder = result.scalar_one_or_none()
    if not folder:
        return None

    # Collected leaf-first and reversed once at the end.
    names = [folder.name]
    # Remember visited IDs: a corrupted (cyclic) parent chain would
    # otherwise keep this loop running forever.
    visited = {folder_id}
    current_parent_id = folder.parent_id

    while current_parent_id and current_parent_id not in visited:
        visited.add(current_parent_id)
        parent_result = await self.db.execute(
            select(Folder).where(Folder.id == current_parent_id)
        )
        parent = parent_result.scalar_one_or_none()
        if not parent:
            break
        names.append(parent.name)
        current_parent_id = parent.parent_id

    names.reverse()
    return "/" + "/".join(names)
|
||||
|
||||
async def hybrid_search(
    self,
    query: str,
    user_id: str,
    top_k: int = 5,
) -> list[SearchResult]:
    """Combine vector retrieval and SQL keyword search, then rerank.

    Vector retrieval is asked for ``top_k * 2`` un-reranked candidates and
    keyword search for ``top_k``. Results are de-duplicated by ``chunk_id``
    (the first occurrence wins, vector hits first) before the merged list
    is reranked down to ``top_k``.
    """
    semantic_hits = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
    lexical_hits = await self._keyword_search(query, user_id, top_k)

    # A dict keeps insertion order, so it de-duplicates while preserving
    # the "vector hits first" ordering.
    unique_hits = {}
    for hit in semantic_hits + lexical_hits:
        unique_hits.setdefault(hit.chunk_id, hit)

    return self._rerank(query, list(unique_hits.values()), top_k)
|
||||
|
||||
async def _keyword_search(
    self,
    query: str,
    user_id: str,
    top_k: int,
) -> list[SearchResult]:
    """Substring search over the user's chunks and document titles via SQL.

    A chunk matches when its content, or the title of its document,
    contains ``query``. The document title is selected together with the
    chunk in one joined query instead of one extra lookup per chunk.
    ``autoescape=True`` makes ``%`` and ``_`` in the user's query match
    literally rather than act as LIKE wildcards.

    Args:
        query: Text to look for.
        user_id: Only documents owned by this user are searched.
        top_k: Maximum number of chunks to return.

    Returns:
        Matching chunks as ``SearchResult`` objects with a fixed score of
        0.5 and no metadata.
    """
    result = await self.db.execute(
        select(DocumentChunk, Document.title)
        .select_from(DocumentChunk)
        .join(Document, Document.id == DocumentChunk.document_id)
        .where(Document.user_id == user_id)
        .where(
            or_(
                DocumentChunk.content.contains(query, autoescape=True),
                Document.title.contains(query, autoescape=True),
            )
        )
        .limit(top_k)
    )

    return [
        SearchResult(
            chunk_id=chunk.id,
            document_id=chunk.document_id,
            document_title=title,
            content=chunk.content,
            # Keyword hits have no similarity score; use a neutral constant.
            score=0.5,
            metadata_=None,
        )
        for chunk, title in result.all()
    ]
|
||||
|
||||
async def delete_from_vectorstore(self, user_id: str, document_id: str):
    """Remove every vector belonging to one document from the user's collection.

    Deletion is best-effort: a failure is logged with its traceback but
    never raised, so callers are not interrupted by vector-store errors.

    Args:
        user_id: Owner whose collection is modified.
        document_id: Document whose vectors are deleted.
    """
    import logging

    collection = self.get_collection(user_id)
    try:
        collection.delete(where={"document_id": document_id})
    except Exception:
        # Previously swallowed silently, which hid orphaned vectors.
        logging.getLogger(__name__).warning(
            "Failed to delete document %s from the vector store of user %s",
            document_id,
            user_id,
            exc_info=True,
        )
|
||||
145
backend/app/services/llm_service.py
Normal file
145
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
LLM 服务 - 支持多种 LLM 提供商
|
||||
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
from app.config import settings
|
||||
import httpx
|
||||
import os
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
|
||||
|
||||
class LLMService(ABC):
    """Provider-independent interface every LLM backend must implement."""

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the model this service talks to."""
        raise NotImplementedError

    @abstractmethod
    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Send ``messages`` and return the complete reply."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Send ``messages`` and produce the reply incrementally as text."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class OpenAICompatibleService(LLMService):
    """LLM backend for any endpoint that speaks the OpenAI chat API.

    Works with OpenAI itself as well as DeepSeek, SiliconFlow and other
    OpenAI-compatible services.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Create the client; any falsy argument falls back to app settings."""
        self.api_key = api_key if api_key else settings.OPENAI_API_KEY
        self.model = model if model else settings.OPENAI_MODEL
        self.base_url = base_url if base_url else settings.OPENAI_BASE_URL

        client_options = {
            "api_key": self.api_key,
            "model": self.model,
            "base_url": self.base_url,
            # 60 s overall budget, 10 s to establish the connection.
            "timeout": httpx.Timeout(60.0, connect=10.0),
        }
        self._llm = ChatOpenAI(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the model's full reply to ``messages``."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the reply piece by piece, skipping empty chunks."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
class ClaudeService(LLMService):
    """LLM backend backed by ``ChatAnthropic``."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 8192,
    ):
        """Create the client; falsy ``api_key`` / ``model`` fall back to app settings."""
        self.api_key = api_key if api_key else settings.ANTHROPIC_API_KEY
        self.model = model if model else settings.CLAUDE_MODEL

        client_options = {
            "api_key": self.api_key,
            "model": self.model,
            "max_tokens": max_tokens,
            # 60 s overall budget, 10 s to establish the connection.
            "timeout": httpx.Timeout(60.0, connect=10.0),
        }
        self._llm = ChatAnthropic(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the model's full reply to ``messages``."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the reply piece by piece, skipping empty chunks."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
class OllamaService(LLMService):
    """LLM backend backed by ``ChatOllama``."""

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
    ):
        """Create the client; falsy arguments fall back to app settings."""
        self.base_url = base_url if base_url else settings.OLLAMA_BASE_URL
        self.model = model if model else settings.OLLAMA_MODEL

        client_options = {
            "base_url": self.base_url,
            "model": self.model,
            # 120 s overall budget (twice the hosted providers), 10 s to connect.
            "timeout": httpx.Timeout(120.0, connect=10.0),
        }
        self._llm = ChatOllama(**client_options)

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Return the model's full reply to ``messages``."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the reply piece by piece, skipping empty chunks."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
|
||||
|
||||
|
||||
# 单例缓存
|
||||
_llm_instance: LLMService | None = None
|
||||
|
||||
|
||||
def get_llm() -> LLMService:
    """Return the process-wide LLM service, creating it on first use.

    The backend is chosen by ``settings.LLM_PROVIDER``. ``openai`` and
    ``custom`` both use the OpenAI-compatible client with its settings
    defaults; ``deepseek`` uses the same client with a fixed base URL and
    model.

    Raises:
        ValueError: if the configured provider is not recognised.
    """
    global _llm_instance

    if _llm_instance is not None:
        return _llm_instance

    def _deepseek() -> LLMService:
        return OpenAICompatibleService(
            base_url="https://api.deepseek.com/v1",
            model="deepseek-chat",
        )

    # Checked in order with ``==`` so the comparison semantics stay the
    # same whatever type the setting has.
    factories = (
        ("openai", OpenAICompatibleService),
        ("deepseek", _deepseek),
        ("custom", OpenAICompatibleService),
        ("claude", ClaudeService),
        ("ollama", OllamaService),
    )

    provider = settings.LLM_PROVIDER
    for name, factory in factories:
        if provider == name:
            _llm_instance = factory()
            break
    else:
        raise ValueError(f"Unknown LLM provider: {provider}")

    return _llm_instance
|
||||
304
backend/app/services/memory_service.py
Normal file
304
backend/app/services/memory_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation, oldest first.

    The query previously sorted ascending before applying ``limit``, so a
    conversation longer than ``limit`` always yielded its *first*
    messages. It now selects the newest ``limit`` messages and puts them
    back into chronological order.

    Args:
        db: Session used for the query.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages ordered by ``created_at`` ascending.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    newest_first = list(result.scalars().all())
    # Callers expect chronological order (the tail is the latest turn).
    newest_first.reverse()
    return newest_first
|
||||
|
||||
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Count the turns of a conversation, i.e. its messages with role ``"user"``."""
    user_turns = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    counted = (await db.execute(user_turns)).scalar()
    return counted if counted else 0
|
||||
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Tell whether enough new turns have accumulated to warrant a summary.

    Returns ``True`` when the number of user turns not yet covered by the
    latest stored summary (or all turns, if there is none) has reached
    ``SUMMARIZE_THRESHOLD``.
    """
    turn_count = await get_conversation_turn_count(db, conversation_id)

    # The summary with the highest turn_count marks how far we have summarized.
    latest_query = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    latest_summary = (await db.execute(latest_query)).scalar_one_or_none()

    if not latest_summary:
        return turn_count >= SUMMARIZE_THRESHOLD
    return turn_count - latest_summary.turn_count >= SUMMARIZE_THRESHOLD
|
||||
|
||||
|
||||
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of ``messages``.

    Each message is rendered as ``[role] content`` on its own line and sent
    with a fixed system prompt. ``db`` and ``conversation_id`` are accepted
    but not used here.

    Returns:
        The model's reply with surrounding whitespace removed.
    """
    transcript_lines = [f"[{m.role}] {m.content}" for m in messages]
    history_text = "\n".join(transcript_lines)

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    system_prompt = SystemMessage(
        content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
        "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    response = await llm.invoke([system_prompt, HumanMessage(content=history_text)])
    return response.content.strip()
|
||||
|
||||
|
||||
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> MemorySummary:
    """Persist one conversation summary and return the refreshed row.

    ``turn_count`` records how many user turns the summary covers.
    """
    fields = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "summary_text": summary_text,
        "turn_count": turn_count,
    }
    record = MemorySummary(**fields)
    db.add(record)
    await db.commit()
    # Reload so database-generated values are populated on the object.
    await db.refresh(record)
    return record
|
||||
|
||||
|
||||
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[MemorySummary]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    by_time = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(by_time)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[UserMemory]:
    """Extract long-term facts about the user from a conversation and store them.

    The last ten messages are sent to the LLM with ``EXTRACTION_PROMPT``;
    every reply line that parses as ``[type] content`` becomes a
    ``UserMemory`` unless the user already has a memory with identical
    content.

    Args:
        db: Session used for lookups and the final commit.
        user_id: Owner of the memories.
        conversation_id: Stored as the source of each new memory.
        messages: Conversation history, oldest first.

    Returns:
        The newly created memories (empty when there are fewer than two
        messages, the model answered "无", or nothing new was found).
    """
    if len(messages) < 2:
        return []

    history_text = "\n".join(
        f"[{m.role}] {m.content}" for m in messages[-10:]
    )

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage
    response = await llm.invoke([
        SystemMessage(content=EXTRACTION_PROMPT),
        HumanMessage(content=history_text),
    ])

    text = response.content.strip()
    if text == "无" or not text:
        return []

    memories = []
    # Contents handled in this batch, so a line the model repeats is only
    # stored once even if the session does not autoflush.
    seen_contents = set()
    for line in text.split("\n"):
        parsed = _parse_fact_line(line)
        if not parsed:
            continue
        mem_type, content = parsed
        if content in seen_contents:
            continue
        seen_contents.add(content)

        # Skip content the user already has. ``first()`` is used instead of
        # ``scalar_one_or_none()`` so that pre-existing duplicate rows do
        # not raise and abort the whole extraction.
        existing = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.content == content,
            )
            .limit(1)
        )
        if existing.scalars().first():
            continue

        mem = UserMemory(
            user_id=user_id,
            memory_type=mem_type,
            content=content,
            importance=5,
            source_conversation_id=conversation_id,
        )
        db.add(mem)
        memories.append(mem)

    if memories:
        await db.commit()
    return memories
|
||||
|
||||
|
||||
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[UserMemory]:
    """Return the user's top memories for injection into the prompt.

    ``query`` is accepted but not used for matching: the selection is the
    ``top_k`` memories ordered by importance, then by recall count (both
    descending). Each returned memory has its ``is_recalled`` flag reset
    to ``False`` and the change is committed.
    """
    top_memories_query = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
        .limit(top_k)
    )
    rows = await db.execute(top_memories_query)
    recalled = list(rows.scalars().all())

    # Clear the "recalled" marker on everything we hand out.
    for memory in recalled:
        memory.is_recalled = False
    await db.commit()

    return recalled
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
    """Record that a memory was used: set the flag, bump the counter, stamp the time.

    Does nothing when no memory with ``memory_id`` exists.
    """
    lookup = select(UserMemory).where(UserMemory.id == memory_id)
    memory = (await db.execute(lookup)).scalar_one_or_none()
    if not memory:
        return

    memory.is_recalled = True
    # recall_count may still be unset on older rows.
    memory.recall_count = (memory.recall_count or 0) + 1
    memory.last_recalled_at = datetime.utcnow()
    await db.commit()
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory text that is injected into the agent's system prompt.

    The result has up to two sections separated by a blank line:
    the user's recalled long-term memories (each one is also marked as
    recalled) and the two most recent summaries of this conversation.
    Returns an empty string when neither section has content.
    """
    sections = []

    # Section 1: long-term user memories.
    remembered = await recall_user_memories(db, user_id, current_query, top_k=5)
    if remembered:
        memory_lines = []
        for memory in remembered:
            label = f"[{memory.memory_type}]"
            memory_lines.append(f" {label} {memory.content}")
            await mark_memory_recalled(db, memory.id)
        sections.append("【用户记忆】\n" + "\n".join(memory_lines))

    # Section 2: mid-term conversation summaries (latest two only).
    summaries = await get_summaries(db, conversation_id)
    if summaries:
        summary_lines = [
            f"[对话摘要{position+1}] {summary.summary_text}"
            for position, summary in enumerate(summaries[-2:])
        ]
        sections.append("【之前对话摘要】\n" + "\n".join(summary_lines))

    # Joining an empty list yields "", matching the no-content case.
    return "\n\n".join(sections)
|
||||
|
||||
|
||||
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarize the conversation if it is due, and extract user memories.

    Summarization is best-effort: any error while generating or saving is
    logged with its traceback and reported as ``False`` rather than raised.

    Args:
        db: Session used for all reads and writes.
        user_id: Owner of the summary and extracted memories.
        conversation_id: Conversation to inspect.

    Returns:
        ``True`` when a summary was generated and saved, otherwise ``False``
        (not due yet, fewer than three messages loaded, or a failure).
    """
    import logging

    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)

        # Use the same pass to refresh the long-term user profile.
        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        # Previously swallowed silently, which made failures invisible.
        logging.getLogger(__name__).exception(
            "Auto-summarization failed for conversation %s", conversation_id
        )
        return False
|
||||
291
backend/app/services/scheduler_service.py
Normal file
291
backend/app/services/scheduler_service.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
定时任务服务 - APScheduler 调度器
|
||||
"""
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlalchemy import select, and_
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.models.forum import ForumPost
|
||||
from app.models.knowledge_graph import KGNode
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.graph_service import GraphService
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
|
||||
|
||||
# ===================== 定时任务函数 =====================
|
||||
|
||||
async def daily_task_analysis():
    """Nightly job: report on recently updated tasks and queue follow-ups.

    Loads every task whose ``updated_at`` is on or after yesterday's date,
    publishes a markdown report as a forum post in the ``discussion``
    category, and creates a ``todo`` follow-up task for each of the five
    unfinished tasks with the highest ``priority``.

    NOTE(review): the date is derived from UTC while the scheduler runs in
    Asia/Shanghai, and the query has no upper bound — confirm the intended
    reporting window.
    """
    logger.info("[Scheduler] 开始执行每日任务分析...")

    async with async_session() as db:
        from datetime import datetime, timedelta

        yesterday = datetime.utcnow().date() - timedelta(days=1)
        day_label = yesterday.strftime('%Y-%m-%d')

        # Tasks touched since yesterday, split by completion state.
        rows = await db.execute(
            select(Task).where(Task.updated_at >= yesterday)
        )
        tasks = rows.scalars().all()
        completed = [t for t in tasks if t.status == "done"]
        pending = [t for t in tasks if t.status != "done"]

        # The five most important unfinished tasks, used by both the report
        # and the follow-up tasks below.
        top_pending = sorted(pending, key=lambda x: x.priority, reverse=True)[:5]

        completed_lines = "\n".join(f"- {t.title}" for t in completed) or "无"
        pending_lines = "\n".join(
            f"- {t.title} (优先级: {t.priority})" for t in pending
        ) or "无"
        advice_lines = "\n".join(
            f"{i+1}. {t.title}" for i, t in enumerate(top_pending)
        ) or "无待处理任务"

        report = f"""## 每日任务报告 - {day_label}

### 完成情况
- 总任务数: {len(tasks)}
- 已完成: {len(completed)}
- 未完成: {len(pending)}

### 已完成任务
{completed_lines}

### 未完成任务
{pending_lines}

### 建议
根据未完成任务,建议明天优先处理:
{advice_lines}
"""

        # Publish the report on the forum.
        db.add(
            ForumPost(
                title=f"每日报告 - {day_label}",
                content=report,
                category="discussion",
            )
        )

        # Queue a follow-up task for each top unfinished task.
        for unfinished in top_pending:
            db.add(
                Task(
                    title=f"继续: {unfinished.title}",
                    description=f"昨日未完成任务,优先级: {unfinished.priority}",
                    priority=unfinished.priority,
                    status="todo",
                )
            )

        await db.commit()
        logger.info(f"[Scheduler] 每日任务分析完成,完成 {len(completed)} 个任务")
|
||||
|
||||
|
||||
async def forum_scan_task():
    """Periodic job: let the agent execute pending forum instructions.

    Picks up to five posts in the ``instruction`` category that are not yet
    marked executed, sends each one to the agent as a chat message, and
    stores the agent's reply on the post. A failure on one post is logged
    and does not stop the remaining posts.
    """
    logger.info("[Scheduler] 开始扫描论坛指令...")

    async with async_session() as db:
        pending_query = select(ForumPost).where(
            ForumPost.category == "instruction",
            ForumPost.is_executed == False,  # noqa: E712 - SQL expression, not a Python comparison
        ).limit(5)
        posts = (await db.execute(pending_query)).scalars().all()

        if not posts:
            logger.info("[Scheduler] 暂无待执行指令")
            return

        agent_svc = AgentService(db)
        executed_count = 0

        for post in posts:
            try:
                # Only the reply text is needed; conversation and message IDs are ignored.
                _, _, response = await agent_svc.chat_simple(
                    user_id=post.user_id,
                    message=f"请执行以下论坛指令: {post.title}。{post.content}",
                    conversation_id=None,
                )
                post.is_executed = True
                post.executed_response = response
                executed_count += 1
                logger.info(f"[Scheduler] 执行指令: {post.title}")
            except Exception as e:
                logger.error(f"[Scheduler] 执行指令失败 {post.title}: {e}")

        await db.commit()
        logger.info(f"[Scheduler] 论坛扫描完成,执行了 {executed_count} 个指令")
|
||||
|
||||
|
||||
async def graph_rebuild_task():
    """Scheduled job: rebuild the knowledge graph.

    Calls ``GraphService.build_graph`` for the user ``"default"`` with
    ``document_ids=None``. Any failure is logged and swallowed so the
    scheduler keeps running.

    NOTE(review): no document filter is passed here, so whether the build
    is incremental depends on ``build_graph`` itself — confirm.
    """
    logger.info("[Scheduler] 开始重建知识图谱...")

    async with async_session() as db:
        try:
            await GraphService(db).build_graph(user_id="default", document_ids=None)
            logger.info("[Scheduler] 知识图谱重建完成")
        except Exception as e:
            logger.error(f"[Scheduler] 知识图谱重建失败: {e}")
|
||||
|
||||
|
||||
async def tag_generation_task():
    """Daily job: tag the content each user added within the last day.

    Finds every user that owns a ``conversation``, ``document`` or ``chunk``
    graph node and runs ``tag_incremental_content(user_id, days=1)`` for
    each, summing the reported ``"tagged"`` counts. Any failure is logged
    and swallowed.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client
    from sqlalchemy import select

    logger.info("[Scheduler] 开始执行每日标签生成...")

    async with async_session() as db:
        try:
            llm_client = get_llm_client()
            # NOTE(review): this instance is never used afterwards; kept so
            # the construction still happens exactly as before.
            tag_service = TagService(db, llm_client)

            content_types = ["conversation", "document", "chunk"]
            owners = await db.execute(
                select(KGNode.user_id).distinct().where(
                    KGNode.entity_type.in_(content_types)
                )
            )
            user_ids = owners.scalars().all()

            total_tagged = 0
            for user_id in user_ids:
                # NOTE(review): called without ``await`` — presumably a
                # synchronous method; verify against TagService.
                per_user_service = TagService(db, llm_client)
                outcome = per_user_service.tag_incremental_content(user_id, days=1)
                total_tagged += outcome["tagged"]

            logger.info(f"[Scheduler] 每日标签生成完成,共标签化 {total_tagged} 个内容节点")
        except Exception as e:
            logger.error(f"[Scheduler] 每日标签生成失败: {e}")
|
||||
|
||||
|
||||
async def daily_todo_generation():
    """Morning job: generate today's to-do list for every active user.

    Calls ``generate_daily_todos`` once per active user. A failure for one
    user is logged and the loop continues; a failure outside the loop is
    logged and ends the job.
    """
    from app.models.user import User
    from app.services.todo_service import generate_daily_todos
    from sqlalchemy import select

    logger.info("[Scheduler] 开始执行每日待办生成...")

    async with async_session() as db:
        try:
            active_users_query = select(User).where(User.is_active == True)  # noqa: E712 - SQL expression
            users = (await db.execute(active_users_query)).scalars().all()

            for user in users:
                try:
                    await generate_daily_todos(user.id, db)
                    logger.info(f"[Scheduler] 为用户 {user.id} 生成今日待办完成")
                except Exception as e:
                    logger.error(f"[Scheduler] 用户 {user.id} 定时生成待办失败: {e}")

            logger.info(f"[Scheduler] 每日待办生成完成,共处理 {len(users)} 个用户")
        except Exception as e:
            logger.error(f"[Scheduler] 每日待办生成失败: {e}")
|
||||
|
||||
|
||||
# ===================== 调度器管理 =====================
|
||||
|
||||
def start_scheduler():
    """Register all periodic jobs and start the scheduler.

    Does nothing (apart from a warning) when the scheduler is already
    running. Every job is registered with ``replace_existing=True`` so a
    restart never duplicates a job.
    """
    if scheduler.running:
        logger.warning("[Scheduler] 调度器已在运行")
        return

    zone = "Asia/Shanghai"
    # (callable, trigger, job id, display name), registered in this order.
    job_table = (
        # 00:30 every day: task analysis.
        (
            daily_task_analysis,
            CronTrigger(hour=0, minute=30, timezone=zone),
            "daily_task_analysis",
            "每日任务分析",
        ),
        # Every hour: scan forum instructions.
        (
            forum_scan_task,
            IntervalTrigger(hours=1),
            "forum_scan",
            "论坛指令扫描",
        ),
        # 03:00 every day: rebuild the knowledge graph.
        (
            graph_rebuild_task,
            CronTrigger(hour=3, minute=0, timezone=zone),
            "graph_rebuild",
            "知识图谱重建",
        ),
        # 00:00 every day: generate tags.
        (
            tag_generation_task,
            CronTrigger(hour=0, minute=0, timezone=zone),
            "tag_generation",
            "每日标签生成",
        ),
        # 08:00 every day: generate today's to-dos.
        (
            daily_todo_generation,
            CronTrigger(hour=8, minute=0, timezone=zone),
            "daily_todo_generation",
            "每日待办生成",
        ),
    )

    for job_func, trigger, job_id, job_name in job_table:
        scheduler.add_job(
            job_func,
            trigger,
            id=job_id,
            name=job_name,
            replace_existing=True,
        )

    scheduler.start()
    logger.info("[Scheduler] 定时任务调度器已启动")
|
||||
|
||||
|
||||
def stop_scheduler():
    """Shut the scheduler down without waiting for running jobs; no-op if stopped."""
    if not scheduler.running:
        return
    scheduler.shutdown(wait=False)
    logger.info("[Scheduler] 定时任务调度器已停止")
|
||||
|
||||
|
||||
def get_scheduler_status() -> dict:
    """Describe the scheduler state and its registered jobs.

    Returns:
        ``{"status": "stopped", "jobs": []}`` when the scheduler is not
        running; otherwise ``{"status": "running", "jobs": [...]}`` with
        one ``{"id", "name", "next_run"}`` entry per job, where
        ``next_run`` is the next run time as a string, or ``None``.
    """
    if not scheduler.running:
        return {"status": "stopped", "jobs": []}

    job_summaries = [
        {
            "id": job.id,
            "name": job.name,
            "next_run": str(job.next_run_time) if job.next_run_time else None,
        }
        for job in scheduler.get_jobs()
    ]
    return {"status": "running", "jobs": job_summaries}
|
||||
140
backend/app/services/settings_service.py
Normal file
140
backend/app/services/settings_service.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import verify_password, get_password_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_user_settings(user_id: str, db: AsyncSession) -> Optional[dict]:
    """Load a user's complete settings.

    Args:
        user_id: ID of the user to look up.
        db: Session used for the query.

    Returns:
        ``None`` when the user does not exist (the annotation previously
        claimed a dict was always returned). Otherwise a dict with the
        ``User`` row under ``"profile"`` and the stored ``"llm_config"``
        and ``"scheduler_config"``, each defaulting to ``{}`` when unset.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        return None
    return {
        "profile": user,
        "llm_config": user.llm_config or {},
        "scheduler_config": user.scheduler_config or {}
    }
|
||||
|
||||
|
||||
async def update_user_profile(
    user_id: str,
    db: AsyncSession,
    full_name: Optional[str] = None,
    password: Optional[str] = None,
    current_password: Optional[str] = None
) -> User:
    """Update a user's display name and/or password.

    A new password is only accepted together with the correct current
    password. Empty or missing ``full_name`` / ``password`` leave the
    respective field unchanged.

    Raises:
        ValueError: if the user does not exist, or if a password change is
            requested and ``current_password`` is missing or wrong.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    user = lookup.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    if password:
        # Short-circuits: the hash is only checked when a current password was supplied.
        verified = current_password and verify_password(current_password, user.hashed_password)
        if not verified:
            raise ValueError("当前密码错误")
        user.hashed_password = get_password_hash(password)

    if full_name:
        user.full_name = full_name

    await db.commit()
    await db.refresh(user)
    return user
|
||||
|
||||
|
||||
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Merge ``config`` into the user's stored LLM configuration.

    Merge rules per key: ``None`` values are ignored; a dict value is
    shallow-merged into an existing dict value; every other value
    (including lists) replaces what was stored.

    The stored dict is copied before it is modified. Mutating the loaded
    object in place and assigning the same object back is not detected as
    a change by a plain JSON column, so later updates could be lost.

    Args:
        user_id: ID of the user to update.
        config: Partial configuration to apply.
        db: Session used for the lookup and commit.

    Returns:
        The merged configuration that was stored.

    Raises:
        ValueError: if the user does not exist.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    # Work on a copy so the ORM sees a new value on assignment.
    current = dict(user.llm_config or {})
    for key, value in config.items():
        if value is None:
            continue
        existing = current.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            # Build a new nested dict instead of mutating the stored one.
            current[key] = {**existing, **value}
        else:
            current[key] = value

    user.llm_config = current
    await db.commit()
    return current
|
||||
|
||||
|
||||
async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Apply ``config`` on top of the user's stored scheduler configuration.

    Keys whose value is ``None`` are ignored; every other key overwrites
    the stored value.

    The stored dict is copied before it is modified. Mutating the loaded
    object in place and assigning the same object back is not detected as
    a change by a plain JSON column, so later updates could be lost.

    Args:
        user_id: ID of the user to update.
        config: Partial configuration to apply.
        db: Session used for the lookup and commit.

    Returns:
        The resulting configuration that was stored.

    Raises:
        ValueError: if the user does not exist.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    # Work on a copy so the ORM sees a new value on assignment.
    current = dict(user.scheduler_config or {})
    current.update(
        {key: value for key, value in config.items() if value is not None}
    )

    user.scheduler_config = current
    await db.commit()
    return current
|
||||
|
||||
|
||||
async def test_llm_connection(
    provider: str,
    model: str,
    base_url: str,
    api_key: str
) -> dict:
    """Check that an LLM configuration works by sending a one-word prompt.

    A throw-away client is built for the given provider with a 30 second
    timeout. ``custom`` (any OpenAI-compatible endpoint) is accepted in
    addition to ``openai``, ``deepseek``, ``claude`` and ``ollama``, so
    every provider the application can be configured with is testable.

    Args:
        provider: Provider key.
        model: Model name to request.
        base_url: Endpoint override; when empty, the provider default is used.
        api_key: Credential for the provider (not used for ``ollama``).

    Returns:
        ``{"success": True, "message": ...}`` with the first 50 characters
        of the reply, or ``{"success": False, "error": ...}`` for an
        unsupported provider or any failure. Never raises.
    """
    try:
        if provider in ("openai", "custom", "deepseek"):
            # All three speak the OpenAI API; only the default endpoint differs.
            from langchain_openai import ChatOpenAI
            default_url = "https://api.deepseek.com/v1" if provider == "deepseek" else None
            llm = ChatOpenAI(
                api_key=api_key,
                model=model,
                base_url=base_url or default_url,
                timeout=30
            )
        elif provider == "claude":
            from langchain_anthropic import ChatAnthropic
            llm = ChatAnthropic(
                api_key=api_key,
                model=model,
                timeout=30
            )
        elif provider == "ollama":
            from langchain_ollama import ChatOllama
            llm = ChatOllama(
                base_url=base_url or "http://localhost:11434",
                model=model,
                timeout=30
            )
        else:
            return {"success": False, "error": f"不支持的 provider: {provider}"}

        # Minimal round trip to prove that credentials and endpoint work.
        from langchain_core.messages import HumanMessage
        response = await llm.ainvoke([HumanMessage(content="Hi")])
        return {"success": True, "message": f"连接成功,模型响应: {response.content[:50]}..."}
    except Exception as e:
        # Report the failure to the caller instead of raising.
        return {"success": False, "error": str(e)}
|
||||
278
backend/app/services/stats_service.py
Normal file
278
backend/app/services/stats_service.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.models.document import Document
|
||||
|
||||
|
||||
class StatsService:
|
||||
def __init__(self, db: Session):
    """Keep the database session that the statistics queries run on."""
    self.db = db
|
||||
|
||||
def get_system_health(self) -> dict:
|
||||
"""获取系统健康指标"""
|
||||
uptime_seconds = int(time.time() - psutil.boot_time())
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
mem = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
|
||||
return {
|
||||
"uptime_seconds": uptime_seconds,
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_used_mb": round(mem.used / (1024 * 1024), 1),
|
||||
"memory_total_mb": round(mem.total / (1024 * 1024), 1),
|
||||
"memory_percent": mem.percent,
|
||||
"disk_used_gb": round(disk.used / (1024 * 1024 * 1024), 1),
|
||||
"disk_total_gb": round(disk.total / (1024 * 1024 * 1024), 1),
|
||||
"disk_percent": disk.percent,
|
||||
"active_users_24h": 0, # 需要 User 表的 updated_at
|
||||
}
|
||||
|
||||
def _get_daily_stats(self, model, date_column, user_id=None, days=30) -> list:
|
||||
"""通用每日统计查询"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
query = self.db.query(
|
||||
func.date(date_column).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(date_column >= cutoff)
|
||||
|
||||
if user_id and hasattr(model, 'user_id'):
|
||||
query = query.filter(model.user_id == user_id)
|
||||
|
||||
query = query.group_by(func.date(date_column)).order_by(func.date(date_column))
|
||||
results = query.all()
|
||||
return [{"date": str(r.date), "count": r.count} for r in results]
|
||||
|
||||
def get_conversation_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取对话统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
daily_conversations = self._get_daily_stats(
|
||||
Conversation, Conversation.created_at, user_id, days
|
||||
)
|
||||
daily_messages = self._get_daily_stats(
|
||||
Message, Message.created_at, user_id, days
|
||||
)
|
||||
|
||||
# Daily tokens
|
||||
input_query = self.db.query(
|
||||
func.date(Message.created_at).label('date'),
|
||||
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
|
||||
).filter(
|
||||
Message.created_at >= cutoff,
|
||||
Message.role == 'user'
|
||||
)
|
||||
if user_id:
|
||||
input_query = input_query.join(Conversation).filter(Conversation.user_id == user_id)
|
||||
input_results = input_query.group_by(func.date(Message.created_at)).all()
|
||||
|
||||
output_query = self.db.query(
|
||||
func.date(Message.created_at).label('date'),
|
||||
func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
|
||||
).filter(
|
||||
Message.created_at >= cutoff,
|
||||
Message.role == 'assistant'
|
||||
)
|
||||
if user_id:
|
||||
output_query = output_query.join(Conversation).filter(Conversation.user_id == user_id)
|
||||
output_results = output_query.group_by(func.date(Message.created_at)).all()
|
||||
|
||||
daily_input_tokens = [{"date": str(r.date), "input_tokens": r.tokens} for r in input_results]
|
||||
daily_output_tokens = [{"date": str(r.date), "output_tokens": r.tokens} for r in output_results]
|
||||
|
||||
return {
|
||||
"daily_conversations": daily_conversations,
|
||||
"daily_messages": daily_messages,
|
||||
"daily_input_tokens": daily_input_tokens,
|
||||
"daily_output_tokens": daily_output_tokens,
|
||||
"totals": {
|
||||
"conversations": sum(c["count"] for c in daily_conversations),
|
||||
"messages": sum(m["count"] for m in daily_messages),
|
||||
"input_tokens": sum(t["input_tokens"] for t in daily_input_tokens),
|
||||
"output_tokens": sum(t["output_tokens"] for t in daily_output_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
def get_knowledge_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取知识库统计数据"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
# New tags
|
||||
tag_query = self.db.query(
|
||||
func.date(KGNode.created_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
KGNode.created_at >= cutoff,
|
||||
KGNode.entity_type == 'tag'
|
||||
)
|
||||
if user_id:
|
||||
tag_query = tag_query.filter(KGNode.user_id == user_id)
|
||||
tag_results = tag_query.group_by(func.date(KGNode.created_at)).all()
|
||||
daily_new_tags = [{"date": str(r.date), "count": r.count} for r in tag_results]
|
||||
|
||||
daily_documents = self._get_daily_stats(
|
||||
Document, Document.created_at, user_id, days
|
||||
)
|
||||
daily_tag_relations = self._get_daily_stats(
|
||||
KGEdge, KGEdge.created_at, user_id, days
|
||||
)
|
||||
|
||||
return {
|
||||
"daily_new_tags": daily_new_tags,
|
||||
"daily_documents": daily_documents,
|
||||
"daily_knowledge_queries": [],
|
||||
"daily_tag_relations": daily_tag_relations,
|
||||
"totals": {
|
||||
"new_tags": sum(t["count"] for t in daily_new_tags),
|
||||
"documents": sum(d["count"] for d in daily_documents),
|
||||
"tag_relations": sum(r["count"] for r in daily_tag_relations),
|
||||
}
|
||||
}
|
||||
|
||||
def get_kanban_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取看板统计数据"""
|
||||
daily_new_tasks = self._get_daily_stats(
|
||||
Task, Task.created_at, user_id, days
|
||||
)
|
||||
|
||||
# Completed tasks
|
||||
completed_query = self.db.query(
|
||||
func.date(Task.completed_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
Task.completed_at >= datetime.utcnow() - timedelta(days=days),
|
||||
Task.status == TaskStatus.DONE
|
||||
)
|
||||
if user_id:
|
||||
completed_query = completed_query.filter(Task.user_id == user_id)
|
||||
completed_results = completed_query.group_by(func.date(Task.completed_at)).all()
|
||||
daily_completed_tasks = [{"date": str(r.date), "count": r.count} for r in completed_results]
|
||||
|
||||
# Current pending
|
||||
pending_query = self.db.query(func.count(Task.id)).filter(Task.status == TaskStatus.TODO)
|
||||
if user_id:
|
||||
pending_query = pending_query.filter(Task.user_id == user_id)
|
||||
current_pending_tasks = pending_query.scalar() or 0
|
||||
|
||||
# Completion rate
|
||||
daily_new_dict = {d["date"]: d["count"] for d in daily_new_tasks}
|
||||
daily_completed_dict = {d["date"]: d["count"] for d in daily_completed_tasks}
|
||||
all_dates = set(daily_new_dict.keys()) | set(daily_completed_dict.keys())
|
||||
daily_completion_rate = []
|
||||
for date in sorted(all_dates):
|
||||
new = daily_new_dict.get(date, 0)
|
||||
completed = daily_completed_dict.get(date, 0)
|
||||
rate = (completed / new * 100) if new > 0 else 0
|
||||
daily_completion_rate.append({"date": date, "rate": round(rate, 1)})
|
||||
|
||||
return {
|
||||
"daily_new_tasks": daily_new_tasks,
|
||||
"daily_completed_tasks": daily_completed_tasks,
|
||||
"daily_completion_rate": daily_completion_rate,
|
||||
"current_pending_tasks": current_pending_tasks,
|
||||
"totals": {
|
||||
"new_tasks": sum(t["count"] for t in daily_new_tasks),
|
||||
"completed_tasks": sum(c["count"] for c in daily_completed_tasks),
|
||||
}
|
||||
}
|
||||
|
||||
def get_community_stats(self, user_id: str = None, days=30) -> dict:
|
||||
"""获取社区统计数据"""
|
||||
daily_posts = self._get_daily_stats(
|
||||
ForumPost, ForumPost.created_at, user_id, days
|
||||
)
|
||||
daily_replies = self._get_daily_stats(
|
||||
ForumReply, ForumReply.created_at, user_id, days
|
||||
)
|
||||
|
||||
# AI executions
|
||||
ai_query = self.db.query(
|
||||
func.date(ForumPost.updated_at).label('date'),
|
||||
func.count().label('count')
|
||||
).filter(
|
||||
ForumPost.updated_at >= datetime.utcnow() - timedelta(days=days),
|
||||
ForumPost.is_executed == True
|
||||
)
|
||||
if user_id:
|
||||
ai_query = ai_query.filter(ForumPost.user_id == user_id)
|
||||
ai_results = ai_query.group_by(func.date(ForumPost.updated_at)).all()
|
||||
daily_ai_executions = [{"date": str(r.date), "count": r.count} for r in ai_results]
|
||||
|
||||
return {
|
||||
"daily_posts": daily_posts,
|
||||
"daily_replies": daily_replies,
|
||||
"daily_ai_executions": daily_ai_executions,
|
||||
"daily_agent_calls": [],
|
||||
"totals": {
|
||||
"posts": sum(p["count"] for p in daily_posts),
|
||||
"replies": sum(r["count"] for r in daily_replies),
|
||||
"ai_executions": sum(a["count"] for a in daily_ai_executions),
|
||||
}
|
||||
}
|
||||
|
||||
def get_personal_insights(self, user_id: str) -> dict:
|
||||
"""获取个人洞察"""
|
||||
# Hourly activity
|
||||
hourly_query = self.db.query(
|
||||
func.extract('hour', Conversation.created_at).label('hour'),
|
||||
func.count().label('count')
|
||||
).filter(Conversation.user_id == user_id).group_by(
|
||||
func.extract('hour', Conversation.created_at)
|
||||
)
|
||||
hourly_results = hourly_query.all()
|
||||
hourly_activity = [{"hour": int(r.hour), "count": r.count} for r in hourly_results]
|
||||
|
||||
# Top tags
|
||||
tag_query = self.db.query(
|
||||
KGNode.properties_["tag_path"].astext.label('tag_path'),
|
||||
func.count(KGEdge.id).label('usage_count')
|
||||
).join(
|
||||
KGEdge, KGEdge.target_id == KGNode.id
|
||||
).filter(
|
||||
KGNode.user_id == user_id,
|
||||
KGNode.entity_type == 'tag',
|
||||
KGEdge.relation_type == 'has_tag'
|
||||
).group_by(
|
||||
KGNode.properties_["tag_path"].astext
|
||||
).order_by(func.count(KGEdge.id).desc()).limit(5)
|
||||
top_tags = [{"tag_path": r.tag_path, "usage_count": r.usage_count} for r in tag_query.all()]
|
||||
|
||||
# Token trend
|
||||
now = datetime.utcnow()
|
||||
this_month_start = datetime(now.year, now.month, 1)
|
||||
last_month_end = this_month_start - timedelta(days=1)
|
||||
last_month_start = datetime(last_month_end.year, last_month_end.month, 1)
|
||||
|
||||
this_month_tokens = self.db.query(
|
||||
func.coalesce(func.sum(Message.tokens_used), 0)
|
||||
).join(Conversation).filter(
|
||||
Conversation.user_id == user_id,
|
||||
Message.created_at >= this_month_start,
|
||||
Message.role == 'assistant'
|
||||
).scalar() or 0
|
||||
|
||||
last_month_tokens = self.db.query(
|
||||
func.coalesce(func.sum(Message.tokens_used), 0)
|
||||
).join(Conversation).filter(
|
||||
Conversation.user_id == user_id,
|
||||
Message.created_at >= last_month_start,
|
||||
Message.created_at < this_month_start,
|
||||
Message.role == 'assistant'
|
||||
).scalar() or 0
|
||||
|
||||
token_trend_percent = 0
|
||||
if last_month_tokens > 0:
|
||||
token_trend_percent = round((this_month_tokens - last_month_tokens) / last_month_tokens * 100, 1)
|
||||
|
||||
return {
|
||||
"hourly_activity": hourly_activity,
|
||||
"top_tags": top_tags,
|
||||
"token_trend_percent": token_trend_percent,
|
||||
"this_month_tokens": this_month_tokens,
|
||||
"last_month_tokens": last_month_tokens,
|
||||
}
|
||||
239
backend/app/services/tag_service.py
Normal file
239
backend/app/services/tag_service.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import json
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.knowledge_graph import KGNode, KGEdge
|
||||
|
||||
TAG_EXTRACTION_PROMPT = """你是一个知识分类专家。从给定内容中提取标签。
|
||||
|
||||
要求:
|
||||
1. 标签采用层级路径格式,如 "编程语言/Python"、"后端/框架/FastAPI"
|
||||
2. 层级深度 1-4 层,避免过深
|
||||
3. 每个内容提取 3-8 个标签
|
||||
4. 标签应覆盖:主题、技术栈、领域、任务类型等维度
|
||||
|
||||
输出格式(JSON数组):
|
||||
[
|
||||
{"path": "编程语言/Python", "description": "Python编程语言相关"},
|
||||
{"path": "后端/框架/FastAPI", "description": "FastAPI框架相关"}
|
||||
]
|
||||
|
||||
内容:
|
||||
{content}
|
||||
"""
|
||||
|
||||
TAG_RELATION_PROMPT = """分析以下标签之间的关系,输出 JSON 数组:
|
||||
|
||||
关系类型:
|
||||
- parent_of: 父子关系(上级包含下级)
|
||||
- related_to: 语义相关(但不是父子)
|
||||
- synonym_of: 同义词
|
||||
|
||||
标签列表:
|
||||
{tag_paths}
|
||||
|
||||
输出格式:
|
||||
[
|
||||
{"source": "标签1", "target": "标签2", "relation": "related_to", "weight": 0.8},
|
||||
{"source": "标签1", "target": "标签3", "relation": "parent_of", "weight": 1.0}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class TagService:
    """Extract hierarchical tags with an LLM and store them as knowledge-graph nodes and edges.

    Tags are ``KGNode`` rows with ``entity_type == "tag"`` whose ``properties_``
    carry the full ``tag_path``; content nodes are linked to tags through
    ``has_tag`` edges.
    """

    # Relation types the relation prompt asks the model to choose from.
    _ALLOWED_RELATIONS = frozenset({"parent_of", "related_to", "synonym_of"})
    # Node entity types treated as taggable content.
    _CONTENT_TYPES = ("conversation", "document", "chunk")

    def __init__(self, db: Session, llm_client, model: str = "gpt-4o"):
        """Store the session, the OpenAI-style chat client and the model name to use."""
        self.db = db
        self.llm_client = llm_client
        self.model = model

    @staticmethod
    def _normalize_path(path: str) -> str:
        """Strip surrounding whitespace/slashes and drop empty segments from a tag path."""
        return "/".join(seg.strip() for seg in path.split("/") if seg.strip())

    def _ask_for_list(self, system_prompt: str, user_prompt: str, key: str) -> list:
        """Run one JSON-mode chat completion and return the list of dicts it contains.

        The prompts ask for a JSON array while JSON mode makes the model answer
        with an object, so both shapes are accepted: a bare array, an object
        holding the array under ``key``, or an object whose first list value is
        taken as the array.  Non-dict items are dropped.

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
        """
        response = self.llm_client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )
        payload = json.loads(response.choices[0].message.content)
        if isinstance(payload, dict):
            value = payload.get(key)
            if not isinstance(value, list):
                value = next((v for v in payload.values() if isinstance(v, list)), [])
            payload = value
        if not isinstance(payload, list):
            return []
        return [item for item in payload if isinstance(item, dict)]

    def extract_tags_from_content(self, content: str, user_id: str) -> list[dict]:
        """Ask the LLM for tags describing ``content``.

        Returns:
            A list of dicts as produced by the model (expected keys: ``path``
            and ``description``).  ``user_id`` is currently unused.
        """
        # The prompt contains literal JSON braces, so str.format() would raise
        # KeyError on them; substitute the single placeholder instead.
        prompt = TAG_EXTRACTION_PROMPT.replace("{content}", content)
        return self._ask_for_list("你是一个知识分类专家。", prompt, "tags")

    def parse_tag_path(self, path: str) -> tuple[str, int, str | None]:
        """Split a tag path into ``(short_name, level, parent_path)``.

        ``level`` is the number of ``/``-separated segments after stripping
        leading/trailing slashes; ``parent_path`` is ``None`` for level 1.
        """
        parts = path.strip("/").split("/")
        short_name = parts[-1]
        level = len(parts)
        parent_path = "/".join(parts[:-1]) if level > 1 else None
        return short_name, level, parent_path

    def get_or_create_tag_node(self, tag_info: dict, user_id: str) -> KGNode:
        """Return the user's tag node for ``tag_info["path"]``, creating it if missing.

        The path is normalised first so ``"/a/b/"`` and ``"a/b"`` map to the
        same node.  A new node is added and flushed but not committed.
        """
        path = self._normalize_path(tag_info["path"])
        existing = self.db.query(KGNode).filter(
            KGNode.user_id == user_id,
            KGNode.properties_["tag_path"].astext == path,
            KGNode.entity_type == "tag",
        ).first()

        if existing:
            return existing

        short_name, level, parent_path = self.parse_tag_path(path)

        node = KGNode(
            user_id=user_id,
            name=short_name,
            entity_type="tag",
            description=tag_info.get("description"),
            properties_={
                "tag_path": path,
                "short_name": short_name,
                "level": level,
                "parent_path": parent_path,
                "description": tag_info.get("description"),
                "color": tag_info.get("color"),
            },
            importance=0.5
        )
        self.db.add(node)
        self.db.flush()
        return node

    def ensure_parent_tags(self, path: str, user_id: str) -> list[KGNode]:
        """Make sure every ancestor of ``path`` exists as a tag node.

        Returns:
            The ancestor nodes from the root down; empty for a single-level path.
        """
        parts = self._normalize_path(path).split("/")
        return [
            self.get_or_create_tag_node(
                {"path": "/".join(parts[:depth]), "description": None}, user_id
            )
            for depth in range(1, len(parts))
        ]

    def create_tag_relations(self, tag_paths: list[str], user_id: str) -> list[KGEdge]:
        """Ask the LLM how the given tags relate and store the missing edges.

        Only relations between two distinct, already existing tag nodes of the
        user and with one of the prompt's relation types are stored; a pair
        that already has an edge is skipped.  Edges are flushed, not committed.

        Returns:
            The newly created edges.
        """
        path_to_node = {}
        for path in tag_paths:
            node = self.db.query(KGNode).filter(
                KGNode.user_id == user_id,
                KGNode.properties_["tag_path"].astext == path,
                KGNode.entity_type == "tag"
            ).first()
            if node:
                path_to_node[path] = node

        # A relation needs two known endpoints; skip the LLM round trip otherwise.
        if len(path_to_node) < 2:
            return []

        # ensure_ascii=False keeps non-ASCII paths readable so the model echoes
        # them back verbatim and they match the lookup keys.
        prompt = TAG_RELATION_PROMPT.replace(
            "{tag_paths}", json.dumps(tag_paths, ensure_ascii=False)
        )
        relations = self._ask_for_list("你是一个知识图谱专家。", prompt, "relations")

        edges = []
        for rel in relations:
            source, target = rel.get("source"), rel.get("target")
            if not isinstance(source, str) or not isinstance(target, str):
                continue
            source_node = path_to_node.get(source)
            target_node = path_to_node.get(target)
            if not source_node or not target_node or source_node is target_node:
                continue
            relation = rel.get("relation")
            if relation not in self._ALLOWED_RELATIONS:
                continue

            existing = self.db.query(KGEdge).filter(
                KGEdge.source_id == source_node.id,
                KGEdge.target_id == target_node.id
            ).first()
            if existing:
                continue

            weight = rel.get("weight", 0.5)
            if isinstance(weight, bool) or not isinstance(weight, (int, float)):
                weight = 0.5
            edge = KGEdge(
                source_id=source_node.id,
                target_id=target_node.id,
                relation_type=relation,
                weight=weight
            )
            self.db.add(edge)
            edges.append(edge)

        self.db.flush()
        return edges

    def tag_content(self, content: str, user_id: str, content_node: KGNode) -> list[KGNode]:
        """Tag ``content_node`` with the tags the LLM extracts from ``content``.

        Creates missing tag nodes (and their ancestors), adds ``has_tag`` edges,
        records the tag ids in ``content_node.properties_["tag_node_ids"]``,
        derives tag-to-tag relations when at least two tags were found, and
        commits the session.

        Returns:
            The tag nodes attached to the content node.
        """
        # Normalise and de-duplicate the paths; entries without a usable path are ignored.
        infos_by_path: dict[str, dict] = {}
        for info in self.extract_tags_from_content(content, user_id):
            raw_path = info.get("path")
            if not isinstance(raw_path, str):
                continue
            path = self._normalize_path(raw_path)
            if path and path not in infos_by_path:
                infos_by_path[path] = {**info, "path": path}

        tag_nodes = []
        for path, info in infos_by_path.items():
            tag_nodes.append(self.get_or_create_tag_node(info, user_id))
            self.ensure_parent_tags(path, user_id)

        # Create the has_tag edges that do not exist yet.
        for tag_node in tag_nodes:
            existing_edge = self.db.query(KGEdge).filter(
                KGEdge.source_id == content_node.id,
                KGEdge.target_id == tag_node.id,
                KGEdge.relation_type == "has_tag"
            ).first()
            if not existing_edge:
                self.db.add(KGEdge(
                    source_id=content_node.id,
                    target_id=tag_node.id,
                    relation_type="has_tag",
                    weight=1.0
                ))

        # Assign a new dict instead of mutating in place: works when
        # properties_ is None and is always seen as a change by the ORM.
        props = content_node.properties_ or {}
        known_ids = props.get("tag_node_ids") or []
        merged_ids = list(dict.fromkeys([*known_ids, *(n.id for n in tag_nodes)]))
        content_node.properties_ = {**props, "tag_node_ids": merged_ids}

        if len(infos_by_path) >= 2:
            self.create_tag_relations(list(infos_by_path), user_id)

        self.db.commit()
        return tag_nodes

    def tag_incremental_content(self, user_id: str, days: int = 1) -> dict:
        """Tag the user's content nodes updated in the last ``days`` days that have no tags yet.

        Nodes without a description are skipped (nothing to send to the LLM).
        A failure on one node is logged and rolled back, and processing
        continues with the next node.

        Returns:
            ``{"total": <untagged nodes found>, "tagged": <nodes tagged successfully>}``.
        """
        import logging
        from datetime import datetime, timedelta

        log = logging.getLogger(__name__)
        cutoff_date = datetime.utcnow() - timedelta(days=days)

        content_nodes = self.db.query(KGNode).filter(
            KGNode.user_id == user_id,
            KGNode.entity_type.in_(self._CONTENT_TYPES),
            KGNode.updated_at >= cutoff_date
        ).all()

        untagged = [
            n for n in content_nodes
            if not (n.properties_ or {}).get("tag_node_ids")
        ]

        tagged_count = 0
        for node in untagged:
            content = node.description or ""
            if not content.strip():
                continue
            # Read the id up front: after a rollback the instance is expired.
            node_id = node.id
            try:
                self.tag_content(content, user_id, node)
                tagged_count += 1
            except Exception:
                # Best effort per node, but leave the session usable and keep a trace.
                self.db.rollback()
                log.exception("Failed to tag content node %s", node_id)

        return {"total": len(untagged), "tagged": tagged_count}

    def get_related_content(self, tag_node_ids: list[str], user_id: str, limit: int = 10) -> list[tuple[KGNode, float]]:
        """Find the user's content nodes attached to any of the given tags.

        Each node is scored with the summed weight of its ``has_tag`` edges to
        the requested tags.

        Returns:
            Up to ``limit`` ``(node, score)`` pairs, highest score first.
        """
        if not tag_node_ids:
            return []

        edges = self.db.query(KGEdge).filter(
            KGEdge.target_id.in_(tag_node_ids),
            KGEdge.relation_type == "has_tag"
        ).all()

        content_weights: dict[str, float] = {}
        for edge in edges:
            content_weights[edge.source_id] = (
                content_weights.get(edge.source_id, 0) + (edge.weight or 0)
            )
        if not content_weights:
            return []

        content_nodes = self.db.query(KGNode).filter(
            KGNode.id.in_(list(content_weights)),
            KGNode.user_id == user_id,
            KGNode.entity_type.in_(self._CONTENT_TYPES)
        ).all()

        ranked = sorted(
            ((node, content_weights[node.id]) for node in content_nodes),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:limit]
|
||||
165
backend/app/services/todo_service.py
Normal file
165
backend/app/services/todo_service.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_daily_todos(user_id: str, db: AsyncSession) -> list[DailyTodo]:
    """Build today's todo list for a user.

    Two sources are combined, in this order:
    1. open kanban tasks (at most 20 entries)
    2. items extracted from the previous day's conversations (at most 3 entries)

    Returns:
        The created ``DailyTodo`` rows, kanban entries first.
    """
    # Both helpers receive yesterday's local date as an ISO string.
    yesterday = (date.today() - timedelta(days=1)).isoformat()

    from_kanban = await _import_kanban_tasks(user_id, yesterday, db)
    from_chat = await _analyze_chat_history(user_id, yesterday, db)

    return [*from_kanban, *from_chat]
|
||||
|
||||
|
||||
async def _import_kanban_tasks(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
    """Create a ``DailyTodo`` for each of the user's 20 newest unfinished kanban tasks.

    The query selects tasks whose status is not ``DONE``, newest first; it does
    not filter by date, and ``date_str`` is not used.  The new rows are
    committed and refreshed before they are returned.
    """
    stmt = (
        select(Task)
        .where(
            Task.user_id == user_id,
            Task.status != TaskStatus.DONE,
        )
        .order_by(Task.created_at.desc())
        .limit(20)
    )
    result = await db.execute(stmt)

    created = [
        DailyTodo(
            user_id=user_id,
            title=task.title,
            source=TodoSource.AI_KANBAN,
            source_detail=f"看板:{task.title}",
            source_ref_id=task.id,
            todo_date=date.today().isoformat(),
        )
        for task in result.scalars().all()
    ]
    if not created:
        return created

    db.add_all(created)
    await db.commit()
    for row in created:
        await db.refresh(row)
    return created
|
||||
|
||||
|
||||
def _conversation_matches_date(created, date_str: str) -> bool:
    """Return True when ``created`` falls on ``date_str`` (``YYYY-MM-DD``).

    A timestamp matches when either its own calendar date or its date after
    adding 8 hours equals ``date_str``.
    NOTE(review): the +8 h shift presumably maps UTC to UTC+8 -- confirm.
    """
    if created is None:
        return False
    if not isinstance(created, datetime):
        # Plain dates and ISO strings are parsed; unparseable values never match.
        try:
            created = datetime.fromisoformat(str(created))
        except ValueError:
            return False
    if created.date().isoformat() == date_str:
        return True
    return (created + timedelta(hours=8)).strftime('%Y-%m-%d') == date_str


def _extract_todo_items(content: str) -> list[dict]:
    """Parse the outermost JSON array found in an LLM reply and keep its dict items."""
    start = content.find('[')
    end = content.rfind(']') + 1
    if start == -1 or end <= start:
        return []
    try:
        parsed = json.loads(content[start:end])
    except (json.JSONDecodeError, ValueError):
        logger.warning("LLM 返回格式异常,跳过对话分析: %s", content[:200])
        return []
    if not isinstance(parsed, list):
        return []
    return [item for item in parsed if isinstance(item, dict)]


async def _analyze_chat_history(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
    """Extract up to three todo items from the conversations created on ``date_str``.

    Looks at the user's 10 most recent conversations, keeps those created on
    ``date_str``, sends at most 2000 characters of their messages to the LLM
    and stores every returned item that has a title as a ``DailyTodo`` for today.

    Returns:
        The committed and refreshed todos; an empty list when nothing matches
        or when any step fails (the failure is logged).
    """
    try:
        conv_q = select(Conversation).where(
            Conversation.user_id == user_id,
        ).order_by(Conversation.created_at.desc()).limit(10)
        convs = (await db.execute(conv_q)).scalars().all()

        yesterday_convs = [
            conv for conv in convs
            if _conversation_matches_date(conv.created_at, date_str)
        ]
        if not yesterday_convs:
            return []

        # Collect message text: 50 messages per conversation, 500 chars per message.
        messages_content = []
        for conv in yesterday_convs:
            msg_q = select(Message).where(
                Message.conversation_id == conv.id
            ).order_by(Message.created_at.asc()).limit(50)
            msgs = (await db.execute(msg_q)).scalars().all()
            messages_content.extend(
                f"[{msg.role}]: {msg.content[:500]}" for msg in msgs if msg.content
            )

        if not messages_content:
            return []

        # Cap the prompt payload at 2000 characters.
        full_text = "\n".join(messages_content)[:2000]

        prompt = f"""你是一个任务规划助手。请分析以下对话记录,提取其中用户想要完成但尚未明确完成的事项。

要求:
- 最多提取 3 条
- 每条格式:{{"title": "事项描述(50字以内)", "reason": "来源说明(60字以内)"}}
- 只提取用户明确表达过需求但还未完成的事项
- 如果没有可提取的内容,返回空数组 []

对话记录:
{full_text}

返回 JSON 数组:"""

        llm = get_llm()
        llm_messages = [
            SystemMessage(content="你是一个任务规划助手。"),
            HumanMessage(content=prompt),
        ]
        # LangChain chat models expose the coroutine as ``ainvoke``; ``invoke``
        # is synchronous there and must not be awaited.  Objects without
        # ``ainvoke`` keep the original ``await llm.invoke(...)`` behaviour.
        if hasattr(llm, "ainvoke"):
            response = await llm.ainvoke(llm_messages)
        else:
            response = await llm.invoke(llm_messages)

        content = response.content if hasattr(response, 'content') else str(response)
        if not isinstance(content, str):
            content = str(content)

        # Keep the first three items that actually carry a title.
        titled_items = []
        for item in _extract_todo_items(content):
            title = str(item.get("title") or "").strip()
            if title:
                titled_items.append((title, str(item.get("reason") or "")))
        titled_items = titled_items[:3]
        if not titled_items:
            return []

        todos = [
            DailyTodo(
                user_id=user_id,
                title=title[:500],
                source=TodoSource.AI_CHAT,
                source_detail=f"对话:{reason[:60]}",
                todo_date=date.today().isoformat(),
            )
            for title, reason in titled_items
        ]
        db.add_all(todos)
        await db.commit()
        for todo in todos:
            await db.refresh(todo)

        return todos

    except Exception as e:
        # Chat analysis is best effort: log with traceback and fall back to no todos.
        logger.exception("对话分析失败: %s", e)
        return []
|
||||
BIN
backend/data/chroma/chroma.sqlite3
Normal file
BIN
backend/data/chroma/chroma.sqlite3
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/data/jarvis.db
Normal file
BIN
backend/data/jarvis.db
Normal file
Binary file not shown.
BIN
backend/data/uploads/973133b8-94ea-498e-95db-cfceec981e09.docx
Normal file
BIN
backend/data/uploads/973133b8-94ea-498e-95db-cfceec981e09.docx
Normal file
Binary file not shown.
119
backend/data/uploads/c53861cd-9eca-485c-9048-80e93cfee8b2.txt
Normal file
119
backend/data/uploads/c53861cd-9eca-485c-9048-80e93cfee8b2.txt
Normal file
@@ -0,0 +1,119 @@
|
||||
远光软件股份有限公司科技项目可行性研究报告
|
||||
|
||||
项目名称:大模型微调技术研究与应用
|
||||
|
||||
申请部门:
|
||||
|
||||
起止时间:年至年
|
||||
|
||||
项目负责人:
|
||||
|
||||
联系电话:
|
||||
|
||||
申请日期:年 月
|
||||
|
||||
大模型微调技术可行性研究报告
|
||||
|
||||
远光软件股份有限公司科技项目可行性研究报告
|
||||
|
||||
项目名称: 大模型微调技术研究与应用
|
||||
|
||||
申请部门:
|
||||
|
||||
起止时间: 年 月至 年 月
|
||||
|
||||
项目负责人:
|
||||
|
||||
联系电话:
|
||||
|
||||
申请日期: 年 月
|
||||
|
||||
一、目的和意义
|
||||
|
||||
1.1 项目背景与需求
|
||||
|
||||
近年来,以深度学习为基础的大型预训练语言模型(Large Language Models, LLMs)如GPT系列、BERT、LLaMA等在自然语言处理领域取得了突破性进展,通过海量数据的预训练和超大规模参数量,这些模型展现出强大的通用语言理解与生成能力,在机器翻译、文本摘要、问答系统、内容创作等众多任务中表现出色,引领了人工智能技术的新浪潮。然而,这些通用大模型在面对特定专业领域任务时,往往存在知识覆盖不足、专业术语理解偏差、领域特定逻辑推理能力欠缺、输出风格不符合行业特点等问题,难以直接满足垂直场景的应用需求。
|
||||
|
||||
模型微调(Fine-tuning)技术作为将通用大模型适配到特定场景的关键手段,通过在领域相关数据上进一步训练模型参数,使模型能够吸收领域知识、适应特定任务要求,从而显著提升模型在目标任务上的性能表现。随着大模型参数规模的不断扩大,传统的全参数微调方式面临着计算资源消耗大、存储成本高、容易产生灾难性遗忘等挑战,因此,参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法如LoRA、Adapter、Prefix-tuning等技术应运而生,为低成本、高效率的大模型领域适配提供了新的技术路径。
|
||||
|
||||
本项目旨在探索适合特定领域特点的高效微调策略,解决数据稀缺性、专业术语理解、领域知识融合等关键技术问题,提升模型在特定场景下的准确性、可靠性和实用性。
|
||||
|
||||
项目成果对行业现状和技术发展的作用,主要体现在技术推动和应用落地支撑两方面。
|
||||
|
||||
二、国内外研究水平综述
|
||||
|
||||
2.1 技术发展历史简要回顾
|
||||
|
||||
大模型微调技术的发展历程分为四个阶段:
|
||||
|
||||
第一阶段(2018年前):传统迁移学习与微调雏形阶段。模型适配多采用传统迁移学习思路,将通用数据集上训练的基础模型迁移至特定任务场景。
|
||||
|
||||
第二阶段(2018-2020年):预训练-微调范式确立阶段。2018年谷歌提出BERT模型,首次构建"预训练通用知识+下游任务微调"的技术框架。
|
||||
|
||||
第三阶段(2020-2022年):高效微调技术爆发阶段。LoRA、QLoRA、Adapter等参数高效微调技术相继出现,将微调参数规模大幅降低。
|
||||
|
||||
第四阶段(2022年至今):垂直领域深化与协同优化阶段。"基座模型+领域微调"的架构成为主流,微调技术与知识图谱进一步融合。
|
||||
|
||||
2.2 国内外研究水平现状和发展趋势
|
||||
|
||||
国际层面,Hugging Face、DeepSpeed等开源社区为参数高效微调技术的普及提供了重要支撑。国内层面,阿里云基于通义千问进行财税领域定制微调,验证了微调技术在财务领域的应用价值。
|
||||
|
||||
三、项目的理论和实践依据
|
||||
|
||||
3.1 项目研究内容原理简述
|
||||
|
||||
本项目采用"基座模型+领域适配"分层微调架构,选取开源基座模型,针对财务问答场景特性采用LoRA参数高效微调策略。
|
||||
|
||||
3.2 项目研究内容理论和实践依据
|
||||
|
||||
理论依据包括国家战略层面的政策支持和成熟的技术理论体系。实践依据包括大模型微调技术在财务等垂直领域的成功案例。
|
||||
|
||||
3.3 项目研究的关键和难点
|
||||
|
||||
关键点包括高质量数据集构建、高效微调策略适配、知识精准注入与幻觉抑制、效果评估体系建设。难点集中在数据处理、微调策略、知识注入和评估体系四个方面。
|
||||
|
||||
四、项目研究内容和实施方案
|
||||
|
||||
4.1 项目研究内容详细说明
|
||||
|
||||
本项目研究内容包括数据格式研究、微调框架研究、模型微调后评估体系研究三个方面。
|
||||
|
||||
4.2 理论研究步骤和试验计划
|
||||
|
||||
包括数据处理流程、训练数据生成流程、数据验证流程三个主要环节。
|
||||
|
||||
4.3 项目组织方式和协作分工
|
||||
|
||||
本项目由项目负责人统筹协调,下设数据组、算法组、应用组三个工作小组。
|
||||
|
||||
五、预期目标和成果形式
|
||||
|
||||
5.1 项目研究预期达到的目标
|
||||
|
||||
技术目标:问答准确率达到85%以上。应用目标:开发财务智能知识问答原型系统。效益目标:替代财务专家70%以上的重复性咨询工作。
|
||||
|
||||
5.2 明确叙述提高研究成果的形式
|
||||
|
||||
包括技术方案文档、原型系统、训练数据集、微调模型、技术论文/报告等成果形式。
|
||||
|
||||
六、项目承担团队的条件
|
||||
|
||||
项目团队具备人工智能、大数据等领域的技术背景,具备财务信息系统开发经验,具备充足的GPU计算资源和完善的开发测试环境。
|
||||
|
||||
七、项目进度安排
|
||||
|
||||
第1-2月:项目启动、需求分析;第3-4月:数据收集、清洗;第5-7月:数据集生成;第8-10月:模型训练;第11-12月:系统开发;第13-14月:优化整理;第15-16月:验收转化。
|
||||
|
||||
八、项目经费预算
|
||||
|
||||
本项目经费预算根据实际研究工作需要编制,包括人工费、设备使用费、业务费、场地使用费、专家咨询费等科目。
|
||||
|
||||
分管领导审核意见:
|
||||
|
||||
(对经费预算是否合理,有无其他经费来源,能否保证研究计划实施所需的人力,工作时间等基本条件提出具体意见)
|
||||
|
||||
分管领导(签字): 年 月 日
|
||||
77
backend/pyproject.toml
Normal file
77
backend/pyproject.toml
Normal file
@@ -0,0 +1,77 @@
|
||||
[project]
|
||||
name = "jarvis-backend"
|
||||
version = "0.1.0"
|
||||
description = "Jarvis Personal AI Assistant - Backend"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
license = { text = "MIT" }
|
||||
|
||||
dependencies = [
|
||||
# Web 框架
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.30.0",
|
||||
"python-multipart>=0.0.12",
|
||||
"websockets>=12.0",
|
||||
"aiofiles>=24.0.0",
|
||||
|
||||
# Agent 框架
|
||||
"langgraph>=0.2.36",
|
||||
"langchain-anthropic>=0.3.14",
|
||||
"langchain-openai>=0.3.18",
|
||||
"langchain-core>=0.3.52",
|
||||
"langchain-ollama>=0.4.0",
|
||||
"langsmith>=0.1.0",
|
||||
|
||||
# 知识库框架
|
||||
"llama-index>=0.12.0",
|
||||
"llama-index-vector-stores-chroma>=0.3.0",
|
||||
"chromadb>=0.5.0",
|
||||
|
||||
# 数据库
|
||||
"sqlalchemy>=2.0.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"alembic>=1.13.0",
|
||||
|
||||
# 认证 & 安全
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"bcrypt>=4.0.0,<5.0.0",
|
||||
|
||||
# 配置 & 验证
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"email-validator>=2.0.0",
|
||||
|
||||
# 定时任务
|
||||
"APScheduler>=3.10.0",
|
||||
|
||||
# 工具
|
||||
"python-dotenv>=1.0.0",
|
||||
"httpx>=0.27.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"ruff>=0.5.0",
|
||||
"mypy>=1.10.0",
|
||||
"pre-commit>=3.7.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["app"]
|
||||
|
||||
[tool.ruff]
target-version = "py312"
line-length = 100

[tool.ruff.lint]
# Rule selection belongs to the lint section; the top-level `select` key is deprecated.
select = ["E", "F", "I", "N", "W", "UP"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
0
backend/tests/backend/__init__.py
Normal file
0
backend/tests/backend/__init__.py
Normal file
0
backend/tests/backend/app/__init__.py
Normal file
0
backend/tests/backend/app/__init__.py
Normal file
0
backend/tests/backend/app/services/__init__.py
Normal file
0
backend/tests/backend/app/services/__init__.py
Normal file
120
backend/tests/backend/app/services/test_tag_service.py
Normal file
120
backend/tests/backend/app/services/test_tag_service.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from app.services.tag_service import TagService, TAG_EXTRACTION_PROMPT, TAG_RELATION_PROMPT
|
||||
|
||||
|
||||
class TestTagService:
|
||||
"""TagService 单元测试"""
|
||||
|
||||
def test_parse_tag_path_single_level(self):
|
||||
"""测试单层标签路径解析"""
|
||||
service = TagService(db=MagicMock(), llm_client=MagicMock())
|
||||
short_name, level, parent_path = service.parse_tag_path("Python")
|
||||
|
||||
assert short_name == "Python"
|
||||
assert level == 1
|
||||
assert parent_path is None
|
||||
|
||||
def test_parse_tag_path_nested(self):
|
||||
"""测试多层标签路径解析"""
|
||||
service = TagService(db=MagicMock(), llm_client=MagicMock())
|
||||
short_name, level, parent_path = service.parse_tag_path("编程语言/Python/异步")
|
||||
|
||||
assert short_name == "异步"
|
||||
assert level == 3
|
||||
assert parent_path == "编程语言/Python"
|
||||
|
||||
def test_parse_tag_path_strips_slashes(self):
|
||||
"""测试标签路径斜杠处理"""
|
||||
service = TagService(db=MagicMock(), llm_client=MagicMock())
|
||||
short_name, level, parent_path = service.parse_tag_path("/后端/框架/")
|
||||
|
||||
assert short_name == "框架"
|
||||
assert level == 2
|
||||
assert parent_path == "后端"
|
||||
|
||||
def test_parse_tag_path_empty_parts(self):
|
||||
"""测试空路径部分处理"""
|
||||
service = TagService(db=MagicMock(), llm_client=MagicMock())
|
||||
short_name, level, parent_path = service.parse_tag_path("a/b/c/d")
|
||||
|
||||
assert short_name == "d"
|
||||
assert level == 4
|
||||
assert parent_path == "a/b/c"
|
||||
|
||||
@patch('app.services.tag_service.KGNode')
|
||||
@patch('app.services.tag_service.KGEdge')
|
||||
def test_get_or_create_tag_node_creates_new(self, mock_edge, mock_node):
|
||||
"""测试创建新标签节点"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
service = TagService(db=mock_db, llm_client=MagicMock())
|
||||
tag_info = {"path": "Python", "description": "Python语言"}
|
||||
|
||||
result = service.get_or_create_tag_node(tag_info, "user_123")
|
||||
|
||||
assert result is not None
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.flush.assert_called_once()
|
||||
|
||||
@patch('app.services.tag_service.KGNode')
def test_get_or_create_tag_node_returns_existing(self, mock_node):
    """When the lookup finds a node, it is returned and nothing is added."""
    session = MagicMock()
    existing_node = MagicMock()
    # Simulate a lookup hit: query(...).filter(...).first() returns the node.
    lookup = session.query.return_value.filter.return_value
    lookup.first.return_value = existing_node
    tag_service = TagService(db=session, llm_client=MagicMock())

    node = tag_service.get_or_create_tag_node(
        {"path": "Python", "description": "Python语言"},
        "user_123",
    )

    assert node == existing_node
    session.add.assert_not_called()
|
||||
|
||||
def test_ensure_parent_tags_creates_parents(self):
    """ensure_parent_tags on a three-level path requests two tag nodes.

    get_or_create_tag_node is replaced by a mock on the instance, so only
    the number of calls made to it is checked here.
    """
    mock_db = MagicMock()
    mock_db.query.return_value.filter.return_value.first.return_value = None

    service = TagService(db=mock_db, llm_client=MagicMock())

    with patch.object(service, 'get_or_create_tag_node') as mock_create:
        mock_create.return_value = MagicMock()
        # The return value is not inspected; only the call count matters.
        service.ensure_parent_tags("a/b/c", "user_123")
        assert mock_create.call_count == 2
|
||||
|
||||
def test_ensure_parent_tags_single_level(self):
    """ensure_parent_tags on a single-level path requests no tag nodes.

    get_or_create_tag_node is replaced by a mock on the instance, and the
    test asserts that it is never invoked.
    """
    mock_db = MagicMock()
    service = TagService(db=mock_db, llm_client=MagicMock())

    with patch.object(service, 'get_or_create_tag_node') as mock_create:
        mock_create.return_value = MagicMock()
        # The return value is not inspected; only the absence of calls matters.
        service.ensure_parent_tags("Python", "user_123")
        mock_create.assert_not_called()
|
||||
|
||||
@patch('app.services.tag_service.KGNode')
def test_get_related_content_empty_tags(self, mock_kg_node):
    """An empty tag list produces an empty related-content result."""
    session = MagicMock()
    # Any query(...).filter(...).all() issued by the service yields no rows.
    filtered = session.query.return_value.filter.return_value
    filtered.all.return_value = []
    tag_service = TagService(db=session, llm_client=MagicMock())

    related = tag_service.get_related_content([], "user_123")

    assert related == []
|
||||
|
||||
def test_tag_extraction_prompt_format(self):
    """The tag-extraction prompt contains its required phrases and placeholder."""
    required_fragments = ("层级路径格式", "3-8 个标签", "{content}")
    for fragment in required_fragments:
        assert fragment in TAG_EXTRACTION_PROMPT
|
||||
|
||||
def test_tag_relation_prompt_format(self):
    """The tag-relation prompt names each relation type and its placeholder."""
    required_fragments = ("parent_of", "related_to", "synonym_of", "{tag_paths}")
    for fragment in required_fragments:
        assert fragment in TAG_RELATION_PROMPT
|
||||
4261
backend/uv.lock
generated
Normal file
4261
backend/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user