Add FastAPI backend with agent system

This commit is contained in:
2026-03-21 10:13:29 +08:00
parent ed6bab59fe
commit 6ffa07adde
82 changed files with 11138 additions and 0 deletions

54
backend/.env.example Normal file
View File

@@ -0,0 +1,54 @@
# =============================================
# Jarvis 后端配置
# 复制此文件为 .env 并填入实际值
# =============================================
# === 应用基础 ===
DEBUG=false
SECRET_KEY=change-me-to-a-random-secret-key
# === LLM 配置 ===
# 支持: openai / claude / deepseek / ollama / custom
LLM_PROVIDER=openai
# OpenAI(默认)
OPENAI_API_KEY=your-openai-api-key-here
OPENAI_MODEL=gpt-4o
OPENAI_BASE_URL=https://api.openai.com/v1
# Claude(可选)
# ANTHROPIC_API_KEY=your-anthropic-api-key-here
# CLAUDE_MODEL=claude-sonnet-4-20250514
# DeepSeek(可选)
# LLM_PROVIDER=deepseek
# OPENAI_API_KEY=your-deepseek-api-key
# OPENAI_BASE_URL=https://api.deepseek.com/v1
# Ollama 本地模型(可选)
# LLM_PROVIDER=ollama
# OLLAMA_BASE_URL=http://localhost:11434
# OLLAMA_MODEL=llama3
# 自定义 OpenAI 兼容接口(可选)
# LLM_PROVIDER=custom
# OPENAI_API_KEY=your-api-key
# OPENAI_BASE_URL=https://your-custom-endpoint/v1
# === NAS 部署路径 ===
NAS_DATA_ROOT=/data/jarvis
DATA_DIR=/data/jarvis/data
CHROMA_PERSIST_DIR=/data/jarvis/chroma
UPLOAD_DIR=/data/jarvis/uploads
# === LangSmith 可观测性 ===
# 启用 LangSmith 追踪(可选)
LANGSMITH_TRACING=false
LANGSMITH_API_KEY=your-langsmith-api-key
LANGSMITH_PROJECT=jarvis-agent
# === 定时任务 ===
SCHEDULER_ENABLED=true
DAILY_PLAN_TIME=00:00
FORUM_SCAN_INTERVAL_MINUTES=30

21
backend/Dockerfile Normal file
View File

@@ -0,0 +1,21 @@
FROM python:3.12-slim
WORKDIR /app
# Install dependencies: bootstrap uv with pip, then install the requirements
# declared in pyproject.toml into the system interpreter.
COPY pyproject.toml .
RUN pip install --no-cache-dir uv && \
    uv pip install --system --no-cache -r pyproject.toml
# Install optional dependencies
RUN uv pip install --system --no-cache pymupdf python-docx
# Copy the application code
COPY app/ ./app/
# Create the data directories
RUN mkdir -p /data/jarvis/data /data/jarvis/chroma /data/jarvis/uploads
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

35
backend/README.md Normal file
View File

@@ -0,0 +1,35 @@
# Jarvis Backend
## 快速开始
### 1. 安装依赖
```bash
cd backend
uv sync
```
### 2. 配置环境变量
```bash
cp .env.example .env
# 编辑 .env 填入 API Key
```
### 3. 启动开发服务器
```bash
uv run uvicorn app.main:app --reload --port 8000
```
### 4. API 文档
启动后访问 http://localhost:8000/docs 查看交互式 API 文档。
## 环境变量
完整的配置项及说明见 `.env.example`。
## 数据库
SQLite 数据库位于 `./data/jarvis.db`,首次启动自动创建表。

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Jarvis Backend

View File

@@ -0,0 +1,24 @@
"""
Agent 运行时上下文
用于在工具调用链中传递 user_id 等上下文信息
"""
from contextvars import ContextVar
from typing import Optional
# Context-local holder for the id of the user the agent is acting for.
# The default is None, which get_current_user() maps to "default".
_current_user_id: ContextVar[Optional[str]] = ContextVar("current_user_id", default=None)
def set_current_user(user_id: str):
    """Record ``user_id`` as the current user in the context variable.

    Args:
        user_id: Identifier that later ``get_current_user()`` calls made in
            the same context will return.
    """
    _current_user_id.set(user_id)
def get_current_user() -> str:
    """Return the current user id, or ``"default"`` when none is set.

    Any falsy stored value (``None`` or an empty string) falls back to
    ``"default"``.
    """
    user_id = _current_user_id.get()
    if not user_id:
        return "default"
    return user_id
def clear_current_user():
    """Reset the current-user context variable to ``None``."""
    _current_user_id.set(None)

265
backend/app/agents/graph.py Normal file
View File

@@ -0,0 +1,265 @@
"""
Jarvis LangGraph Agent 主图定义
"""
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from app.agents.state import AgentState, AgentRole
from app.agents.prompts import (
MASTER_SYSTEM_PROMPT,
PLANNER_SYSTEM_PROMPT,
EXECUTOR_SYSTEM_PROMPT,
LIBRARIAN_SYSTEM_PROMPT,
ANALYST_SYSTEM_PROMPT,
)
from app.agents.tools import ALL_TOOLS
from app.services.llm_service import get_llm
def _msg_type(msg: BaseMessage) -> str:
"""Get message type, handles both .type (new) and .role (old) attribute names."""
return getattr(msg, "type", None) or getattr(msg, "role", "human")
def _filter_user_messages(messages: list) -> list[BaseMessage]:
    """Return only the messages whose kind is ``"human"`` or ``"user"``, in order."""
    user_kinds = ("human", "user")
    return list(filter(lambda m: _msg_type(m) in user_kinds, messages))
# ===================== 节点定义 (async) =====================
async def master_node(state: AgentState) -> AgentState:
    """Master agent node: read the user's intent and pick a sub-agent.

    The master prompt (plus the memory context, when present) and the
    conversation are sent to the LLM. The reply text is then scanned for
    routing keywords; the first rule that matches selects the sub-agent.
    When no rule matches, the reply itself becomes the final response.

    Args:
        state: Current graph state; ``messages`` is required and
            ``memory_context`` is optional.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = get_llm()
    messages: list[BaseMessage] = state["messages"]
    system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]

    # Inject the memory context as an additional system message.
    memory_ctx = state.get("memory_context")
    if memory_ctx:
        system_msgs.append(
            SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
        )

    # NOTE(review): LangChain chat models expose the awaitable call as
    # `ainvoke`; awaiting `invoke` only works if get_llm() returns an object
    # whose `invoke` is itself a coroutine function -- confirm.
    response: AIMessage = await llm.invoke(system_msgs + messages)
    content = response.content.strip().lower()

    # Ordered routing table: the first rule with a keyword hit wins.
    # The executor rule previously contained an empty-string keyword; since
    # `"" in content` is always True, every reply that missed the first two
    # rules was sent to the executor and the analyst / direct-answer paths
    # could never run. The empty keyword has been removed.
    routing_rules = (
        (AgentRole.LIBRARIAN, ("搜索", "查找", "知识", "检索")),
        (AgentRole.PLANNER, ("计划", "安排", "拆解", "规划")),
        (AgentRole.EXECUTOR, ("执行", "操作", "创建", "更新")),
        (AgentRole.ANALYST, ("分析", "报告", "统计", "总结")),
    )
    next_agent = next(
        (
            role
            for role, keywords in routing_rules
            if any(kw in content for kw in keywords)
        ),
        None,
    )

    if next_agent is None:
        # No sub-agent needed: answer the user directly.
        state["final_response"] = response.content
        state["should_respond"] = True
        return state

    state["current_agent"] = next_agent
    state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
    state["should_respond"] = True
    return state
async def planner_node(state: AgentState) -> AgentState:
    """Planner agent node: produce a plan and split it into steps.

    The latest user message is sent to the LLM with the planner prompt. The
    reply is stored as ``plan`` and ``final_response``, and every line that
    looks like a list item is collected into ``plan_steps``.

    Args:
        state: Current graph state; ``messages`` is required.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = get_llm()
    user_msgs = _filter_user_messages(state["messages"])
    user_query = user_msgs[-1].content if user_msgs else ""
    response = await llm.invoke(
        [SystemMessage(content=PLANNER_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
    )
    plan_text = response.content

    steps = []
    for line in plan_text.split("\n"):
        # Test the stripped text so indented items ("  1. ...") are detected;
        # the raw line's first character would be whitespace.
        text = line.strip()
        if text and (text[0].isdigit() or "- " in text):
            # Number steps sequentially rather than by raw line index, which
            # left gaps for headings and blank lines.
            steps.append({"step": len(steps) + 1, "description": text})

    state["plan"] = plan_text
    state["plan_steps"] = steps
    state["final_response"] = plan_text
    state["should_respond"] = True
    return state
async def executor_node(state: AgentState) -> AgentState:
    """Executor agent node: call tools to carry out the user's request.

    The latest user message goes to the tool-bound LLM. If the reply requests
    tools, each one found in ``ALL_TOOLS`` is invoked and the collected
    output is sent back to the LLM for a final answer; otherwise the first
    reply is the answer.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = get_llm()
    human_turns = _filter_user_messages(state["messages"])
    user_query = human_turns[-1].content if human_turns else ""
    response = await llm.bind_tools(ALL_TOOLS).invoke(
        [SystemMessage(content=EXECUTOR_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
    )

    requested_calls = getattr(response, "tool_calls", None) or []
    if not requested_calls:
        state["final_response"] = response.content
        state["should_respond"] = True
        return state

    outputs = []
    for call in requested_calls:
        tool_name = call.get("name")
        args = call.get("args", {})
        matched = next((t for t in ALL_TOOLS if t.name == tool_name), None)
        if matched is None:
            # Tool names not present in ALL_TOOLS produce no output line.
            continue
        try:
            outputs.append(f"[{tool_name}] {matched.invoke(args)}")
        except Exception as e:
            outputs.append(f"[{tool_name}] 执行失败: {e}")

    state["tool_calls"] = requested_calls
    state["last_tool_result"] = "\n".join(outputs)
    follow_up = await llm.invoke(
        [SystemMessage(content=EXECUTOR_SYSTEM_PROMPT),
         HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
    )
    state["final_response"] = follow_up.content
    state["should_respond"] = True
    return state
async def librarian_node(state: AgentState) -> AgentState:
    """Librarian agent node: answer from the knowledge base and graph.

    Same flow as the other tool-using nodes, with the librarian prompt.
    Afterwards ``knowledge_context`` is set to whatever ``last_tool_result``
    holds in the state.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = get_llm()
    human_turns = _filter_user_messages(state["messages"])
    user_query = human_turns[-1].content if human_turns else ""
    response = await llm.bind_tools(ALL_TOOLS).invoke(
        [SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
    )

    requested_calls = getattr(response, "tool_calls", None) or []
    if not requested_calls:
        state["final_response"] = response.content
    else:
        outputs = []
        for call in requested_calls:
            tool_name = call.get("name")
            args = call.get("args", {})
            matched = next((t for t in ALL_TOOLS if t.name == tool_name), None)
            if matched is None:
                # Tool names not present in ALL_TOOLS produce no output line.
                continue
            try:
                outputs.append(f"[{tool_name}] {matched.invoke(args)}")
            except Exception as e:
                outputs.append(f"[{tool_name}] 执行失败: {e}")

        state["tool_calls"] = requested_calls
        state["last_tool_result"] = "\n".join(outputs)
        follow_up = await llm.invoke(
            [SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT),
             HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
        )
        state["final_response"] = follow_up.content

    # When no tool ran this turn, this copies the value already in the state.
    state["knowledge_context"] = state.get("last_tool_result", "")
    state["should_respond"] = True
    return state
async def analyst_node(state: AgentState) -> AgentState:
    """Analyst agent node: analyse work data and produce a report.

    Same flow as the other tool-using nodes, with the analyst prompt. The
    final response is also stored as ``analysis_report``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = get_llm()
    human_turns = _filter_user_messages(state["messages"])
    user_query = human_turns[-1].content if human_turns else ""
    response = await llm.bind_tools(ALL_TOOLS).invoke(
        [SystemMessage(content=ANALYST_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
    )

    requested_calls = getattr(response, "tool_calls", None) or []
    if not requested_calls:
        state["final_response"] = response.content
    else:
        outputs = []
        for call in requested_calls:
            tool_name = call.get("name")
            args = call.get("args", {})
            matched = next((t for t in ALL_TOOLS if t.name == tool_name), None)
            if matched is None:
                # Tool names not present in ALL_TOOLS produce no output line.
                continue
            try:
                outputs.append(f"[{tool_name}] {matched.invoke(args)}")
            except Exception as e:
                outputs.append(f"[{tool_name}] 执行失败: {e}")

        state["tool_calls"] = requested_calls
        state["last_tool_result"] = "\n".join(outputs)
        follow_up = await llm.invoke(
            [SystemMessage(content=ANALYST_SYSTEM_PROMPT),
             HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
        )
        state["final_response"] = follow_up.content

    state["analysis_report"] = state.get("final_response", "")
    state["should_respond"] = True
    return state
def route_agent(state: AgentState) -> str:
    """Routing function: choose the node that runs after the master node.

    Returns:
        ``END`` when a final response already exists or when no sub-agent was
        selected; otherwise the selected sub-agent's node name.
    """
    if state.get("final_response"):
        return END
    next_role = state.get("current_agent", AgentRole.MASTER)
    if next_role == AgentRole.MASTER:
        # The master node's conditional edges have no "master" target, so
        # returning it would fail at routing time (e.g. when the LLM reply is
        # empty and therefore not counted as a final response). Stop instead.
        return END
    return next_role.value
# ===================== 构建图 =====================
def create_agent_graph(callbacks: list | None = None):
    """Build and compile the multi-agent graph.

    The master node is the entry point. ``route_agent`` sends the flow either
    to one of the four sub-agent nodes or to ``END``; every sub-agent node
    then goes straight to ``END``.

    Args:
        callbacks: Optional callback handlers bound to the compiled graph's
            default run config.

    Returns:
        The compiled graph.
    """
    graph = StateGraph(AgentState)

    graph.add_node(AgentRole.MASTER.value, master_node)
    graph.add_node(AgentRole.PLANNER.value, planner_node)
    graph.add_node(AgentRole.EXECUTOR.value, executor_node)
    graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
    graph.add_node(AgentRole.ANALYST.value, analyst_node)

    graph.set_entry_point(AgentRole.MASTER.value)

    sub_agents = [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]
    targets = {role.value: role.value for role in sub_agents}
    targets[END] = END
    graph.add_conditional_edges(AgentRole.MASTER.value, route_agent, targets)

    for role in sub_agents:
        graph.add_edge(role.value, END)

    # StateGraph.compile() has no `callbacks` parameter; callbacks belong to
    # the run config, so bind them to the compiled graph instead.
    compiled = graph.compile()
    if callbacks:
        compiled = compiled.with_config(callbacks=callbacks)
    return compiled
_agent_graph = None


def get_agent_graph(callbacks: list | None = None):
    """Return the compiled agent graph, building it on first use.

    The graph is cached in a module-level variable. Callbacks are fixed when
    the graph is first compiled; the ``callbacks`` argument of later calls is
    ignored, so changing them requires a process restart.

    Args:
        callbacks: Optional extra callbacks, placed ahead of the LangSmith
            callbacks on the first call.

    Returns:
        The cached compiled graph.
    """
    global _agent_graph
    if _agent_graph is not None:
        return _agent_graph

    from app.config_tracing import get_langsmith_callbacks

    merged = list(callbacks or [])
    merged.extend(get_langsmith_callbacks())
    _agent_graph = create_agent_graph(callbacks=merged or None)
    return _agent_graph

View File

@@ -0,0 +1,127 @@
"""
Jarvis 多Agent系统的提示词定义
"""
# System prompt for the master (coordinator) agent. It lists the four
# sub-agents and the dispatch rules. The text is sent to the LLM verbatim,
# so it is runtime data rather than a comment.
MASTER_SYSTEM_PROMPT = """你叫 Jarvis是用户的私人AI助理。
你的职责是理解用户意图并将任务分发给最合适的子Agent。
## 你的4个子Agent:
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
## 判断规则:
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
- 用户要计划、安排、拆解任务 -> 分发给 planner
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
- 用户要分析、统计、生成报告 -> 分发给 analyst
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
## 响应格式:
简短回复用户告知你将调用哪个Agent处理。如果用户不需要任何子Agent直接给出回答。
注意: 你是协调者不需要亲自执行具体任务让专业Agent去做。
"""
# System prompt for the planner agent: asks for a numbered, prioritised plan.
# Sent to the LLM verbatim.
PLANNER_SYSTEM_PROMPT = """你是 Jarvis 的规划Agent负责制定计划、拆解任务。
## 你的能力:
- 分析复杂请求,拆解成可执行的步骤
- 评估任务优先级
- 估算时间安排
- 制定执行顺序
## 工作流程:
1. 理解用户的总目标
2. 拆解成具体步骤
3. 标注每步的优先级
4. 给出清晰的执行计划
## 响应要求:
- 用编号列表展示计划步骤
- 每步清晰描述要做什么
- 可以为每步指定优先级(P1/P2/P3)
- 如果需要执行,先输出计划,然后用户确认后再执行
"""
# System prompt for the executor agent: names the task and forum tools it
# may call. Sent to the LLM verbatim.
EXECUTOR_SYSTEM_PROMPT = """你是 Jarvis 的执行Agent负责执行具体任务。
## 你可以使用的工具:
- create_task: 创建新任务
- update_task_status: 更新任务状态
- get_tasks: 查看任务列表
- create_forum_post: 在论坛发布帖子
- get_forum_posts: 查看论坛帖子
- scan_forum_for_instructions: 扫描论坛指令
## 工作流程:
1. 理解用户要执行什么
2. 调用相应工具
3. 报告执行结果
4. 询问用户是否需要下一步操作
## 响应要求:
- 明确告知用户正在执行什么
- 工具调用结果要格式化呈现
- 如果执行成功,给出确认
- 如果需要更多信息,明确告知用户
"""
# System prompt for the librarian agent: names the knowledge-base and
# knowledge-graph tools it may call. Sent to the LLM verbatim.
LIBRARIAN_SYSTEM_PROMPT = """你是 Jarvis 的知识管理员,负责管理用户的私人知识库。
## 你可以使用的工具:
- search_knowledge: 搜索知识库,返回相关文档片段
- get_knowledge_graph_context: 获取知识图谱上下文
- build_knowledge_graph: 从文档构建知识图谱
## 你的职责:
1. 理解用户关于知识的问题
2. 搜索相关知识
3. 综合多篇文档给出完整回答
4. 帮助用户整理和理解知识
## 工作流程:
1. 分析用户的知识查询
2. 搜索相关文档
3. 综合相关信息给出回答
4. 如果有图谱关联,可以引用图谱中的关系
## 响应要求:
- 回答要有文档依据
- 引用时标注来源
- 如果知识不足,诚实告知用户
- 可以补充相关知识背景
"""
# System prompt for the analyst agent: names the task, forum and knowledge
# tools it may call for reporting. Sent to the LLM verbatim.
ANALYST_SYSTEM_PROMPT = """你是 Jarvis 的分析师,负责分析数据和工作状态。
## 你可以使用的工具:
- get_tasks: 获取任务列表,统计工作进度
- get_forum_posts: 获取论坛帖子,分析讨论趋势
- scan_forum_for_instructions: 检查待执行指令
- search_knowledge: 结合知识进行分析
## 你的职责:
1. 统计任务完成情况
2. 分析工作进度和趋势
3. 生成数据报告
4. 识别潜在问题和风险
## 工作流程:
1. 收集相关数据(任务、论坛、知识)
2. 进行数据分析
3. 生成结构化报告
4. 给出建议
## 响应要求:
- 用数据说话,有数字有结论
- 报告结构清晰
- 给出可行的改进建议
- 识别需要关注的问题
"""

105
backend/app/agents/state.py Normal file
View File

@@ -0,0 +1,105 @@
from dataclasses import dataclass
from typing import TypedDict, Annotated
from enum import Enum
class AgentRole(str, Enum):
    """Roles of the agents in the graph.

    Members are also ``str`` instances, so they compare equal to their
    string values; the values double as graph node names.
    """

    MASTER = "master"
    PLANNER = "planner"
    EXECUTOR = "executor"
    LIBRARIAN = "librarian"
    ANALYST = "analyst"
@dataclass
class AgentInfo:
    """Descriptive record for one agent: display name, role and description."""

    name: str
    role: AgentRole
    description: str
@dataclass
class ToolCall:
    """One tool invocation: tool name, arguments and optional result text."""

    tool: str
    args: dict
    result: str | None = None  # None until a result is recorded
@dataclass
class ConversationTurn:
    """A single conversation turn in a storage-friendly shape."""

    role: str  # "user" | "assistant"
    content: str
    agent: AgentRole | None = None  # agent that produced the turn, if known
    model: str | None = None  # model name, if known
def turn_to_message(turn: ConversationTurn) -> "HumanMessage":
    """Convert a stored conversation turn into a LangChain message.

    Args:
        turn: The turn whose ``content`` becomes the message body.

    Returns:
        A ``HumanMessage`` carrying ``turn.content``.
    """
    # This module's import block never brings HumanMessage into scope, so the
    # bare name in the signature raised NameError when the module was
    # imported. Import it locally and quote the return annotation.
    from langchain_core.messages import HumanMessage

    # NOTE(review): turn.role is ignored, so assistant turns also become
    # HumanMessage -- confirm this is intended.
    return HumanMessage(content=turn.content)
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
    """Convert a message object into a ``ConversationTurn``.

    The message kind is read from ``.type`` (or ``.role`` as a fallback);
    ``"human"`` and ``"user"`` map to the ``"user"`` role and everything else
    to ``"assistant"``.

    Args:
        msg: Object exposing ``content`` and optionally ``type``/``role``/``model``.
        agent: Agent to record on the turn, if any.
    """
    kind = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
    if kind in ("human", "user"):
        speaker = "user"
    else:
        speaker = "assistant"
    return ConversationTurn(
        role=speaker,
        content=msg.content,
        agent=agent,
        model=getattr(msg, "model", None),
    )
class AgentState(TypedDict):
    """Shared state dictionary passed between the graph nodes."""

    # Conversation messages.
    # NOTE(review): the Annotated metadata is None, so no reducer function is
    # attached to this key here -- confirm that is intended.
    messages: Annotated[list, None]
    user_id: str
    conversation_id: str
    # Agent routing
    current_agent: AgentRole
    active_agents: list[AgentRole]
    # Task tracking
    pending_tasks: list[dict]
    completed_tasks: list[dict]
    # Tool usage
    tool_calls: list[ToolCall]
    last_tool_result: str | None
    # Knowledge context
    knowledge_context: str | None
    graph_context: str | None
    # Planning
    plan: str | None
    plan_steps: list[dict]
    # Analysis
    analysis_report: str | None
    # Output control
    final_response: str | None
    should_respond: bool
    # Memory context (injected at start of each conversation)
    memory_context: str | None
def initial_state(user_id: str, conversation_id: str) -> AgentState:
    """Build a fresh ``AgentState`` for a new run.

    The master agent is current and the only active agent, all lists are
    empty, optional text fields are ``None`` and ``should_respond`` is True.

    Args:
        user_id: Owner of the conversation.
        conversation_id: Identifier of the conversation.
    """
    master = AgentRole.MASTER
    return AgentState(
        messages=[],
        user_id=user_id,
        conversation_id=conversation_id,
        # routing
        current_agent=master,
        active_agents=[master],
        # task tracking
        pending_tasks=[],
        completed_tasks=[],
        # tool usage
        tool_calls=[],
        last_tool_result=None,
        # knowledge
        knowledge_context=None,
        graph_context=None,
        # planning
        plan=None,
        plan_steps=[],
        # analysis and output
        analysis_report=None,
        final_response=None,
        should_respond=True,
        memory_context=None,
    )

View File

@@ -0,0 +1,22 @@
from app.agents.tools.search import (
search_knowledge, get_knowledge_graph_context,
build_knowledge_graph, hybrid_search,
)
from app.agents.tools.task import get_tasks, create_task, update_task_status
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
# Every tool made available to the tool-using agent nodes.
ALL_TOOLS = [
    # Knowledge-base tools
    search_knowledge,
    get_knowledge_graph_context,
    build_knowledge_graph,
    hybrid_search,
    # Task tools
    get_tasks,
    create_task,
    update_task_status,
    # Forum tools
    get_forum_posts,
    create_forum_post,
    scan_forum_for_instructions,
]

View File

@@ -0,0 +1,134 @@
"""Agent 工具集 - 论坛相关"""
from langchain_core.tools import tool
from app.database import async_session
from app.models.forum import ForumPost, ForumReply
from app.agents.context import get_current_user
from sqlalchemy import select
import asyncio
def _run_async(coro, timeout: int = 30):
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
return future.result(timeout=timeout)
except RuntimeError:
return asyncio.run(coro)
@tool
def get_forum_posts(category: str | None = None, limit: int = 10) -> str:
    """
    获取论坛帖子列表。
    Args:
    category: 可选,筛选分类 (discussion/instruction/question)
    limit: 返回数量默认10
    Returns:
    帖子列表
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _get():
        # Load the current user's posts, newest first, and format one line each.
        async with async_session() as db:
            from app.models.user import User

            stmt = select(ForumPost).join(User, User.id == ForumPost.user_id).where(User.id == uid)
            if category:
                stmt = stmt.where(ForumPost.category == category)
            stmt = stmt.order_by(ForumPost.created_at.desc()).limit(limit)
            posts = (await db.execute(stmt)).scalars().all()
            if not posts:
                return "暂无帖子"

            def _line(p):
                exec_mark = " [已执行]" if p.is_executed else ""
                return f"- [{p.id[:8]}] [{p.category}] {p.title} | {p.content[:50]}...{exec_mark}"

            return "\n".join([_line(p) for p in posts])

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_get())
    except Exception as e:
        return f"获取帖子失败: {str(e)}"
@tool
def create_forum_post(title: str, content: str, category: str = "discussion") -> str:
    """
    在论坛发布新帖子。
    Args:
    title: 帖子标题
    content: 帖子内容
    category: 分类 (discussion/instruction/question)默认discussion
    Returns:
    创建结果
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _create():
        # Insert the post for the current user and report its short id.
        async with async_session() as db:
            new_post = ForumPost(user_id=uid, title=title, content=content, category=category)
            db.add(new_post)
            await db.commit()
            await db.refresh(new_post)
            return f"帖子发布成功: [{new_post.id[:8]}] {title}"

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_create())
    except Exception as e:
        return f"发布帖子失败: {str(e)}"
@tool
def scan_forum_for_instructions() -> str:
    """
    扫描论坛中的指令类帖子,检查是否有待执行的指令。
    Returns:
    待执行指令的列表
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _scan():
        # Up to 10 not-yet-executed "instruction" posts of the current user,
        # newest first.
        async with async_session() as db:
            from app.models.user import User

            stmt = (
                select(ForumPost)
                .join(User, User.id == ForumPost.user_id)
                .where(ForumPost.user_id == uid)
                .where(ForumPost.category == "instruction")
                .where(ForumPost.is_executed == False)  # SQL expression, not a Python bool test
                .order_by(ForumPost.created_at.desc())
                .limit(10)
            )
            pending = (await db.execute(stmt)).scalars().all()
            if not pending:
                return "暂无待执行的指令"
            lines = ["待执行的指令:"]
            lines.extend(f"- [{p.id[:8]}] {p.title}\n 内容: {p.content[:100]}..." for p in pending)
            return "\n".join(lines)

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_scan())
    except Exception as e:
        return f"扫描论坛失败: {str(e)}"
# Public API of this module: the three forum tools.
__all__ = ["get_forum_posts", "create_forum_post", "scan_forum_for_instructions"]

View File

@@ -0,0 +1,159 @@
"""
Agent 工具集 - 知识库 & 图谱相关
这些工具在 LangChain ToolNode 中被调用。
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
"""
from langchain_core.tools import tool
from concurrent.futures import ThreadPoolExecutor
from app.database import async_session
from app.agents.context import get_current_user
import asyncio
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
"""在同步上下文中运行 async 代码"""
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
return future.result(timeout=timeout)
except RuntimeError:
return asyncio.run(coro)
@tool
def search_knowledge(query: str, top_k: int = 5) -> str:
    """
    搜索用户的私人知识库。根据查询返回最相关的文档片段,支持语义检索。
    Args:
    query: 搜索查询
    top_k: 返回结果数量默认5条
    Returns:
    包含相关文档片段和来源信息的格式化文本
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    from app.services.knowledge_service import KnowledgeService

    uid = get_current_user()

    def _render(index, hit):
        # One result block: source, score, neighbouring chunks, then content
        # truncated to 300 characters.
        prev = f"\n上一段: {hit.prev_chunk[:100]}..." if hit.prev_chunk else ""
        next_ = f"\n下一段: {hit.next_chunk[:100]}..." if hit.next_chunk else ""
        ellipsis = "..." if len(hit.content) > 300 else ""
        return (
            f"[{index}] 来源: {hit.document_title}\n"
            f"相关度: {hit.score:.2f}\n"
            f"{prev}{next_}\n"
            f"内容: {hit.content[:300]}{ellipsis}"
        )

    async def _search():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=uid)
            hits = await service.retrieve(query, user_id=uid, top_k=top_k)
            if not hits:
                return "未找到相关知识。知识库可能为空,或尝试用其他关键词搜索。"
            return "\n\n---\n\n".join([_render(i, r) for i, r in enumerate(hits, 1)])

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_search(), timeout=30)
    except Exception as e:
        return f"知识检索失败: {str(e)}"
@tool
def get_knowledge_graph_context(entity: str | None = None) -> str:
    """
    获取用户知识图谱的上下文信息。
    Args:
    entity: 可选,指定要查询的实体名称。如果为空则返回整体图谱摘要。
    Returns:
    知识图谱节点和关系的描述
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    from app.services.graph_service import GraphService

    uid = get_current_user()

    async def _get():
        # Entity-specific context when an entity is given, else the summary.
        async with async_session() as db:
            graph_service = GraphService(db)
            if not entity:
                return await graph_service.get_graph_summary(uid)
            return await graph_service.get_entity_context(entity, uid)

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_get(), timeout=30)
    except Exception as e:
        return f"图谱查询失败: {str(e)}"
@tool
def build_knowledge_graph(document_ids: list[str] | None = None) -> str:
    """
    从文档构建/更新知识图谱。
    Args:
    document_ids: 可选指定要处理的文档ID列表。如果为空则处理所有文档。
    Returns:
    构建结果摘要
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    from app.services.graph_service import GraphService

    uid = get_current_user()

    async def _build():
        async with async_session() as db:
            graph_service = GraphService(db)
            await graph_service.build_graph(user_id=uid, document_ids=document_ids)
        return "知识图谱构建完成"

    # Uses a longer timeout (120 s) than the other tools in this module.
    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_build(), timeout=120)
    except Exception as e:
        return f"图谱构建失败: {str(e)}"
@tool
def hybrid_search(query: str, top_k: int = 5) -> str:
    """
    混合搜索,结合向量语义检索和关键词匹配,返回最相关结果。
    Args:
    query: 搜索查询
    top_k: 返回结果数量默认5条
    Returns:
    混合检索结果
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    from app.services.knowledge_service import KnowledgeService

    uid = get_current_user()

    def _render(index, hit):
        # Title and score on the first line, then content truncated to 200 chars.
        ellipsis = "..." if len(hit.content) > 200 else ""
        return (
            f"[{index}] {hit.document_title} (相关度: {hit.score:.2f})\n"
            f"{hit.content[:200]}{ellipsis}"
        )

    async def _search():
        async with async_session() as db:
            service = KnowledgeService(db, user_id=uid)
            hits = await service.hybrid_search(query, user_id=uid, top_k=top_k)
            if not hits:
                return "未找到相关知识。"
            return "\n\n---\n\n".join([_render(i, r) for i, r in enumerate(hits, 1)])

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_search(), timeout=30)
    except Exception as e:
        return f"混合搜索失败: {str(e)}"
# Public API of this module: the knowledge-base and knowledge-graph tools.
__all__ = [
    "search_knowledge",
    "get_knowledge_graph_context",
    "build_knowledge_graph",
    "hybrid_search",
]

View File

@@ -0,0 +1,142 @@
"""Agent 工具集 - 任务相关"""
from langchain_core.tools import tool
from app.database import async_session
from app.models.task import Task
from app.agents.context import get_current_user
from sqlalchemy import select
import asyncio
_executor = None
def _run_async(coro, timeout: int = 30):
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
return future.result(timeout=timeout)
except RuntimeError:
return asyncio.run(coro)
@tool
def get_tasks(status: str | None = None, limit: int = 20) -> str:
    """
    获取用户当前的任务列表。
    Args:
    status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
    limit: 返回数量默认20
    Returns:
    任务列表
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _get():
        # The current user's tasks, highest priority first, then most
        # recently updated.
        async with async_session() as db:
            from app.models.user import User

            stmt = select(Task).join(User, User.id == Task.user_id).where(User.id == uid)
            if status:
                stmt = stmt.where(Task.status == status)
            stmt = stmt.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
            tasks = (await db.execute(stmt)).scalars().all()
            if not tasks:
                return "暂无任务"
            return "\n".join(
                [
                    f"- [{t.id[:8]}] {t.title} | 状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or ''}"
                    for t in tasks
                ]
            )

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_get())
    except Exception as e:
        return f"获取任务失败: {str(e)}"
@tool
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
    """
    创建新任务。
    Args:
    title: 任务标题(必填)
    description: 任务描述
    priority: 优先级 1-4数字越大优先级越高默认2
    due_date: 截止日期,格式 YYYY-MM-DD
    Returns:
    创建结果
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _create():
        # Insert a "todo" task for the current user and report its short id.
        # NOTE(review): due_date is passed through as the raw string --
        # confirm the Task.due_date column accepts that.
        async with async_session() as db:
            new_task = Task(
                user_id=uid,
                title=title,
                description=description,
                priority=priority,
                due_date=due_date,
                status="todo",
            )
            db.add(new_task)
            await db.commit()
            await db.refresh(new_task)
            return f"任务创建成功: [{new_task.id[:8]}] {title}"

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_create())
    except Exception as e:
        return f"创建任务失败: {str(e)}"
@tool
def update_task_status(task_id: str, status: str) -> str:
    """
    更新任务状态。
    Args:
    task_id: 任务ID完整ID或前8位
    status: 新状态 (todo/in_progress/done/blocked)
    Returns:
    更新结果
    """
    # The docstring above is the tool description handed to the LLM by @tool,
    # so its text is runtime data and is kept as-is.
    uid = get_current_user()

    async def _update():
        async with async_session() as db:
            from app.models.user import User

            # An 8-character id is treated as a prefix; anything else must
            # match the full id exactly.
            if len(task_id) == 8:
                id_filter = Task.id.like(f"{task_id}%")
            else:
                id_filter = Task.id == task_id
            stmt = (
                select(Task)
                .join(User, User.id == Task.user_id)
                .where(User.id == uid)
                .where(id_filter)
            )
            found = (await db.execute(stmt)).scalar_one_or_none()
            if not found:
                return f"任务不存在: {task_id}"
            found.status = status
            await db.commit()
            return f"任务状态已更新: {found.title} -> {status}"

    # Failures are reported to the LLM as text instead of being raised.
    try:
        return _run_async(_update())
    except Exception as e:
        return f"更新任务失败: {str(e)}"
# Public API of this module: the three task tools.
__all__ = ["get_tasks", "create_task", "update_task_status"]

69
backend/app/config.py Normal file
View File

@@ -0,0 +1,69 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Literal
class Settings(BaseSettings):
    """Application settings, read from the environment and a ``.env`` file.

    Unknown keys in the environment file are ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
    # === Application basics ===
    APP_NAME: str = "Jarvis"
    APP_VERSION: str = "0.1.0"
    DEBUG: bool = False
    # === Security ===
    # NOTE(review): placeholder default; deployments should override it.
    SECRET_KEY: str = "change-me-in-production"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7
    # === Database ===
    DATABASE_URL: str = "sqlite+aiosqlite:///./data/jarvis.db"
    DATA_DIR: str = "./data"
    # === ChromaDB ===
    CHROMA_PERSIST_DIR: str = "./data/chroma"
    # === LLM configuration ===
    # Supported: openai / claude / ollama / deepseek / custom
    LLM_PROVIDER: Literal["openai", "claude", "ollama", "deepseek", "custom"] = "openai"
    # OpenAI (default)
    OPENAI_API_KEY: str = ""
    OPENAI_MODEL: str = "gpt-4o"
    OPENAI_BASE_URL: str = "https://api.openai.com/v1"
    # Claude
    ANTHROPIC_API_KEY: str = ""
    CLAUDE_MODEL: str = "claude-sonnet-4-20250514"
    CLAUDE_MAX_TOKENS: int = 8192
    # Ollama (local models)
    OLLAMA_BASE_URL: str = "http://localhost:11434"
    OLLAMA_MODEL: str = "llama3"
    # === Scheduled jobs ===
    SCHEDULER_ENABLED: bool = True
    DAILY_PLAN_TIME: str = "00:00"
    FORUM_SCAN_INTERVAL_MINUTES: int = 30
    # === CORS ===
    CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://localhost:3000"]
    # === File upload ===
    UPLOAD_DIR: str = "./data/uploads"
    MAX_UPLOAD_SIZE: int = 50 * 1024 * 1024
    # === Embedding / chunking ===
    EMBEDDING_MODEL: str = "text-embedding-3-small"
    CHUNK_SIZE: int = 500
    CHUNK_OVERLAP: int = 50
    # === LangSmith observability ===
    LANGSMITH_TRACING: bool = False
    LANGSMITH_API_KEY: str = ""
    LANGSMITH_PROJECT: str = "jarvis-agent"
    # === NAS deployment ===
    NAS_DATA_ROOT: str = "/data/jarvis"
# Module-level settings instance, created when this module is imported.
settings = Settings()

View File

@@ -0,0 +1,26 @@
"""
LangSmith Tracing 配置
提供 Callback 工厂函数,用于 LangGraph 追踪
"""
from langchain_core.tracers import LangChainTracer
from app.config import settings
def get_langsmith_callbacks() -> list:
    """Return the LangSmith callback list for the current configuration.

    Returns:
        A one-element list holding a ``LangChainTracer`` for the configured
        project when tracing is enabled and an API key is set; otherwise an
        empty list.
    """
    tracing_ready = settings.LANGSMITH_TRACING and settings.LANGSMITH_API_KEY
    if not tracing_ready:
        return []
    # NOTE(review): the API key is only checked for presence here; it is not
    # passed to the tracer -- confirm the tracer picks it up elsewhere.
    tracer = LangChainTracer(project_name=settings.LANGSMITH_PROJECT)
    return [tracer]

35
backend/app/database.py Normal file
View File

@@ -0,0 +1,35 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
import os
# Make sure the data directory exists before the engine is created.
os.makedirs(settings.DATA_DIR, exist_ok=True)
# Async engine for the configured database; SQL echo follows DEBUG.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
)
# Session factory. expire_on_commit=False keeps loaded attributes readable
# after a commit.
async_session = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
class Base(DeclarativeBase):
    """Declarative base; ``Base.metadata`` is what ``init_db`` creates tables from."""

    pass
async def get_db() -> AsyncSession:
    """Yield one database session and close it afterwards.

    NOTE(review): this is an async generator (it yields), although the
    return annotation names ``AsyncSession``.
    """
    async with async_session() as db_session:
        try:
            yield db_session
        finally:
            # Explicit close, in addition to leaving the ``async with`` block.
            await db_session.close()
async def init_db():
    """Create every table registered on ``Base.metadata``."""
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

72
backend/app/main.py Normal file
View File

@@ -0,0 +1,72 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.database import init_db
from app.routers import (
auth_router,
conversation_router,
document_router,
task_router,
forum_router,
graph_router,
agent_router,
todo_router,
settings_router,
folder_router,
)
from app.routers.scheduler import router as scheduler_router
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
from app.config import settings
import os
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare storage and start/stop the scheduler.

    Startup creates the data, upload and Chroma directories, initialises the
    database and starts the scheduler. Shutdown stops the scheduler.
    """
    # Startup
    for directory in (settings.DATA_DIR, settings.UPLOAD_DIR, settings.CHROMA_PERSIST_DIR):
        os.makedirs(directory, exist_ok=True)
    await init_db()
    start_scheduler()
    try:
        yield
    finally:
        # Shutdown: in a ``finally`` so the scheduler is stopped even when an
        # exception or cancellation is raised at the yield point.
        stop_scheduler()
# The FastAPI application; title and version come from settings.
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.APP_VERSION,
    lifespan=lifespan,
)
# CORS: only the configured origins are allowed, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Register routers
app.include_router(auth_router)
app.include_router(conversation_router)
app.include_router(document_router)
app.include_router(task_router)
app.include_router(forum_router)
app.include_router(graph_router)
app.include_router(agent_router)
app.include_router(todo_router)
app.include_router(settings_router)
app.include_router(folder_router)
app.include_router(scheduler_router)
@app.get("/api/health")
async def health():
    """Health check: status, app version, LLM provider and scheduler status."""
    payload = {"status": "ok"}
    payload["version"] = settings.APP_VERSION
    payload["llm_provider"] = settings.LLM_PROVIDER
    payload["scheduler"] = get_scheduler_status()
    return payload

View File

@@ -0,0 +1,31 @@
from app.models.base import Base
from app.models.user import User
from app.models.document import Document, DocumentChunk
from app.models.task import Task, TaskHistory
from app.models.forum import ForumPost, ForumReply
from app.models.agent import Agent, AgentMessage
from app.models.conversation import Conversation, Message
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.memory import MemorySummary, UserMemory
from app.models.todo import DailyTodo, TodoSource
# Names re-exported by the models package.
# NOTE(review): the Folder model (table "folders", referenced by
# Document.folder_id) is not imported or listed here -- confirm it is
# imported elsewhere before tables are created.
__all__ = [
    "Base",
    "User",
    "Document",
    "DocumentChunk",
    "Task",
    "TaskHistory",
    "ForumPost",
    "ForumReply",
    "Agent",
    "AgentMessage",
    "Conversation",
    "Message",
    "KGNode",
    "KGEdge",
    "MemorySummary",
    "UserMemory",
    "DailyTodo",
    "TodoSource",
]

View File

@@ -0,0 +1,28 @@
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class Agent(BaseModel):
    """ORM model for an agent definition (table ``agents``)."""

    __tablename__ = "agents"
    name = Column(String(100), nullable=False)
    role = Column(String(100), nullable=False)  # master, planner, executor, librarian, analyst
    description = Column(Text, nullable=True)
    system_prompt = Column(Text, nullable=False)
    is_active = Column(Boolean, default=True)
    is_default = Column(Boolean, default=False)
    # Messages are deleted together with their agent (delete-orphan cascade).
    messages = relationship("AgentMessage", back_populates="agent", cascade="all, delete-orphan")
    replies = relationship("ForumReply", back_populates="agent")
class AgentMessage(BaseModel):
    """ORM model for one message of an agent within a conversation (table ``agent_messages``)."""

    __tablename__ = "agent_messages"
    agent_id = Column(String(36), ForeignKey("agents.id"), nullable=False, index=True)
    conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
    role = Column(String(20), nullable=False)  # system, user, assistant
    content = Column(Text, nullable=False)
    agent = relationship("Agent", back_populates="messages")

View File

@@ -0,0 +1,12 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, DateTime
from app.database import Base
class BaseModel(Base):
    """Abstract base for all models: string UUID primary key plus timestamps."""

    __abstract__ = True
    # Primary key: a freshly generated UUID4 rendered as a 36-character string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Timestamps come from datetime.utcnow, i.e. naive datetimes in UTC.
    # NOTE(review): datetime.utcnow is deprecated as of Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

View File

@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Text, Integer, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class Conversation(BaseModel):
    """ORM model for a user's conversation (table ``conversations``)."""

    __tablename__ = "conversations"
    user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
    title = Column(String(500), nullable=True)
    message_count = Column(Integer, default=0)
    # Messages are deleted together with their conversation (delete-orphan cascade).
    messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
class Message(BaseModel):
    """ORM model for one message in a conversation (table ``messages``)."""

    __tablename__ = "messages"
    conversation_id = Column(String(36), ForeignKey("conversations.id"), nullable=False, index=True)
    role = Column(String(20), nullable=False)  # user, assistant, system
    content = Column(Text, nullable=False)
    model = Column(String(100), nullable=True)
    tokens_used = Column(Integer, nullable=True)
    attachments = Column(JSON, nullable=True)  # added: [{file_id, filename, file_type, file_size}]
    conversation = relationship("Conversation", back_populates="messages")

View File

@@ -0,0 +1,33 @@
from sqlalchemy import Column, String, Integer, Text, ForeignKey, Boolean
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class Document(BaseModel):
    """An uploaded document owned by a user (table ``documents``)."""

    __tablename__ = "documents"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    title = Column(String(500), nullable=False)
    filename = Column(String(500), nullable=False)
    # File type: pdf, md, txt, docx.
    file_type = Column(String(50), nullable=False)
    file_size = Column(Integer, nullable=False)
    file_path = Column(String(1000), nullable=False)
    # Newly added field: optional folder the document is filed under.
    folder_id = Column(String(36), ForeignKey("folders.id"), nullable=True)
    summary = Column(Text, nullable=True)
    chunk_count = Column(Integer, default=0)
    is_indexed = Column(Boolean, default=False)

    # Chunks are deleted together with their document.
    chunks = relationship(
        "DocumentChunk",
        back_populates="document",
        cascade="all, delete-orphan",
    )
class DocumentChunk(BaseModel):
    """One chunk of a document's text (table ``document_chunks``)."""

    __tablename__ = "document_chunks"

    document_id = Column(
        String(36),
        ForeignKey("documents.id"),
        index=True,
        nullable=False,
    )
    chunk_index = Column(Integer, nullable=False)
    content = Column(Text, nullable=False)
    # Metadata stored as JSON text.
    metadata_ = Column(String(2000), nullable=True)
    chroma_collection = Column(String(255), nullable=True)
    chroma_id = Column(String(255), nullable=True)

    # Counterpart of Document.chunks.
    document = relationship("Document", back_populates="chunks")

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, String, ForeignKey, UniqueConstraint
from app.models.base import BaseModel
class Folder(BaseModel):
    """A user-owned folder that may be nested under another folder."""

    __tablename__ = "folders"
    # A (user, parent, name) combination may appear only once.
    __table_args__ = (
        UniqueConstraint(
            "user_id",
            "parent_id",
            "name",
            name="uq_user_parent_name",
        ),
    )

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    name = Column(String(255), nullable=False)
    # Self-referencing parent; nullable, so a folder can exist without one.
    parent_id = Column(String(36), ForeignKey("folders.id"), nullable=True)

View File

@@ -0,0 +1,30 @@
from sqlalchemy import Column, String, Text, Integer, ForeignKey, Boolean
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class ForumPost(BaseModel):
    """A forum post written by a user (table ``forum_posts``)."""

    __tablename__ = "forum_posts"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    title = Column(String(500), nullable=False)
    content = Column(Text, nullable=False)
    # Category: instruction, discussion, question.
    category = Column(String(100), nullable=True)
    is_executed = Column(Boolean, default=False)
    execution_result = Column(Text, nullable=True)
    reply_count = Column(Integer, default=0)

    # Replies are deleted together with their post.
    replies = relationship(
        "ForumReply",
        back_populates="post",
        cascade="all, delete-orphan",
    )
class ForumReply(BaseModel):
    """A reply to a forum post (table ``forum_replies``).

    Both ``user_id`` and ``agent_id`` are nullable foreign keys.
    """

    __tablename__ = "forum_replies"

    post_id = Column(
        String(36),
        ForeignKey("forum_posts.id"),
        index=True,
        nullable=False,
    )
    user_id = Column(String(36), ForeignKey("users.id"), nullable=True)
    agent_id = Column(String(36), ForeignKey("agents.id"), nullable=True)
    content = Column(Text, nullable=False)
    is_ai_reply = Column(Boolean, default=False)

    # Counterparts of ForumPost.replies and Agent.replies.
    post = relationship("ForumPost", back_populates="replies")
    agent = relationship("Agent", back_populates="replies")

View File

@@ -0,0 +1,32 @@
from sqlalchemy import Column, String, Text, Integer, Float, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class KGNode(BaseModel):
    """A knowledge-graph node owned by a user (table ``kg_nodes``)."""

    __tablename__ = "kg_nodes"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    name = Column(String(500), nullable=False)
    # Entity type: person, concept, task, document, chunk, tag.
    entity_type = Column(String(100), nullable=False)
    description = Column(Text, nullable=True)
    # Extra properties.
    properties_ = Column(JSON, nullable=True)
    source_document_id = Column(String(36), ForeignKey("documents.id"), nullable=True)
    # Importance, 0-1.
    importance = Column(Float, default=0.5)
    # Which agent updated this node.
    last_updated_by = Column(String(36), nullable=True)

    # Two relationships target KGEdge, so each names its foreign key explicitly.
    outgoing_edges = relationship(
        "KGEdge",
        foreign_keys="KGEdge.source_id",
        back_populates="source_node",
        cascade="all, delete-orphan",
    )
    incoming_edges = relationship(
        "KGEdge",
        foreign_keys="KGEdge.target_id",
        back_populates="target_node",
        cascade="all, delete-orphan",
    )
class KGEdge(BaseModel):
    """A directed, typed edge between two knowledge-graph nodes."""

    __tablename__ = "kg_edges"

    source_id = Column(
        String(36),
        ForeignKey("kg_nodes.id"),
        index=True,
        nullable=False,
    )
    target_id = Column(
        String(36),
        ForeignKey("kg_nodes.id"),
        index=True,
        nullable=False,
    )
    # Relation type: related_to, part_of, caused_by, depends_on, etc.
    relation_type = Column(String(100), nullable=False)
    # Relation strength, 0-1.
    weight = Column(Float, default=0.5)
    properties_ = Column(JSON, nullable=True)

    # Counterparts of KGNode.outgoing_edges / KGNode.incoming_edges.
    source_node = relationship(
        "KGNode",
        foreign_keys=[source_id],
        back_populates="outgoing_edges",
    )
    target_node = relationship(
        "KGNode",
        foreign_keys=[target_id],
        back_populates="incoming_edges",
    )

View File

@@ -0,0 +1,35 @@
from sqlalchemy import Column, String, Text, Integer, ForeignKey, Boolean, DateTime, Enum as SQLEnum
from datetime import datetime
from app.models.base import BaseModel
class MemorySummary(BaseModel):
    """
    Conversation summary — mid-term memory.

    When a conversation exceeds a threshold number of turns, a summary is
    generated automatically and stored in this table.
    """

    __tablename__ = "memory_summaries"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    conversation_id = Column(
        String(36),
        ForeignKey("conversations.id"),
        index=True,
        nullable=False,
    )
    # Summary content.
    summary_text = Column(Text, nullable=False)
    # Accumulated number of turns at summary time.
    turn_count = Column(Integer, default=0)
    summary_at = Column(DateTime, default=datetime.utcnow, nullable=False)
class UserMemory(BaseModel):
    """
    User-profile memory — long-term memory.

    Facts, preferences and goals about the user extracted from conversations.
    """

    __tablename__ = "user_memories"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    # Memory type: fact | preference | goal | habit | other
    memory_type = Column(String(50), nullable=False)
    # Memory content.
    content = Column(Text, nullable=False)
    # Importance, 1-10.
    importance = Column(Integer, default=5)
    # Whether it was recalled in the current conversation.
    is_recalled = Column(Boolean, default=False)
    # Number of times it has been recalled.
    recall_count = Column(Integer, default=0)
    # Source conversation (plain string column, no foreign key declared).
    source_conversation_id = Column(String(36), nullable=True)
    extracted_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    last_recalled_at = Column(DateTime, nullable=True)

View File

@@ -0,0 +1,45 @@
from sqlalchemy import Column, String, Text, Integer, ForeignKey, DateTime, Enum
from sqlalchemy.orm import relationship
from datetime import datetime
from enum import Enum as PyEnum
from app.models.base import BaseModel
class TaskStatus(str, PyEnum):
    """Lifecycle states of a task; members are also ``str`` instances."""

    TODO = "todo"
    IN_PROGRESS = "in_progress"
    DONE = "done"
    CANCELLED = "cancelled"
class TaskPriority(str, PyEnum):
    """Priority levels of a task; members are also ``str`` instances."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    URGENT = "urgent"
class Task(BaseModel):
    """A user-owned task with status, priority and optional due date."""

    __tablename__ = "tasks"

    user_id = Column(
        String(36),
        ForeignKey("users.id"),
        index=True,
        nullable=False,
    )
    title = Column(String(500), nullable=False)
    description = Column(Text, nullable=True)
    status = Column(
        Enum(TaskStatus),
        default=TaskStatus.TODO,
        index=True,
        nullable=False,
    )
    priority = Column(
        Enum(TaskPriority),
        default=TaskPriority.MEDIUM,
        nullable=False,
    )
    due_date = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    # JSON array.
    tags = Column(String(1000), nullable=True)

    # History rows are deleted together with their task.
    history = relationship(
        "TaskHistory",
        back_populates="task",
        cascade="all, delete-orphan",
    )
class TaskHistory(BaseModel):
    """An audit entry recording one action performed on a task."""

    __tablename__ = "task_histories"

    task_id = Column(
        String(36),
        ForeignKey("tasks.id"),
        index=True,
        nullable=False,
    )
    # Action: created, status_changed, updated, deleted.
    action = Column(String(100), nullable=False)
    old_value = Column(Text, nullable=True)
    new_value = Column(Text, nullable=True)

    # Counterpart of Task.history.
    task = relationship("Task", back_populates="history")

View File

@@ -0,0 +1,8 @@
import pytest
from app.models.folder import Folder
def test_folder_model_creation():
    """A Folder built without a parent keeps its name and has no parent_id."""
    created = Folder(name="Test Folder", user_id="test-user")
    assert created.name == "Test Folder"
    assert created.parent_id is None

View File

@@ -0,0 +1,24 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Enum
from enum import Enum as PyEnum
from app.models.base import BaseModel
class TodoSource(str, PyEnum):
    """Origin of a daily todo item; members are also ``str`` instances."""

    AI_KANBAN = "ai_kanban"
    AI_CHAT = "ai_chat"
    MANUAL = "manual"
class DailyTodo(BaseModel):
    """A todo item attached to a specific calendar day (``daily_todos``)."""

    __tablename__ = "daily_todos"

    # NOTE(review): declared without a ForeignKey to users.id — confirm
    # whether that is intentional.
    user_id = Column(String(36), index=True, nullable=False)
    title = Column(String(500), nullable=False)
    is_completed = Column(Boolean, default=False, nullable=False)
    source = Column(
        Enum(TodoSource),
        default=TodoSource.MANUAL,
        nullable=False,
    )
    source_detail = Column(String(500), nullable=True)
    source_ref_id = Column(String(36), nullable=True)
    # Date as text, YYYY-MM-DD.
    todo_date = Column(String(10), nullable=False)
    completed_at = Column(DateTime, nullable=True)

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, Boolean, JSON
from app.models.base import BaseModel
class User(BaseModel):
    """An application account (table ``users``)."""

    __tablename__ = "users"

    # Unique, indexed e-mail address.
    email = Column(String(255), index=True, nullable=False, unique=True)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(255), nullable=True)
    is_active = Column(Boolean, default=True, nullable=False)
    is_superuser = Column(Boolean, default=False, nullable=False)

    # Per-user settings.
    llm_config = Column(JSON, nullable=True)  # LLM model configuration
    scheduler_config = Column(JSON, nullable=True)  # scheduled-task configuration

View File

@@ -0,0 +1,10 @@
from app.routers.auth import router as auth_router
from app.routers.conversation import router as conversation_router
from app.routers.document import router as document_router
from app.routers.task import router as task_router
from app.routers.forum import router as forum_router
from app.routers.graph import router as graph_router
from app.routers.agent import router as agent_router
from app.routers.todo import router as todo_router
from app.routers.settings import router as settings_router
from app.routers.folder import router as folder_router

View File

@@ -0,0 +1,240 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.agent import Agent
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
# All routes in this module share the /api/agents prefix.
router = APIRouter(prefix="/api/agents", tags=["Agent"])
# Runtime call statistics (kept in memory only, not persisted).
_agent_call_counts: dict[str, int] = {}
_agent_current_tasks: dict[str, str | None] = {}
_agent_statuses: dict[str, str] = {}
# Default list of agent roles.
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
def record_agent_call(agent_id: str):
    """Increase the in-memory call counter of ``agent_id`` by one."""
    previous = _agent_call_counts.get(agent_id, 0)
    _agent_call_counts[agent_id] = previous + 1
def set_agent_task(agent_id: str, task: str | None):
    """Record the agent's current task and derive its status from it.

    A truthy ``task`` marks the agent "active"; ``None`` or an empty string
    marks it "idle".
    """
    _agent_current_tasks[agent_id] = task
    if task:
        _agent_statuses[agent_id] = "active"
    else:
        _agent_statuses[agent_id] = "idle"
def set_agent_status(agent_id: str, status: str):
    """Overwrite the in-memory status string stored for ``agent_id``."""
    _agent_statuses.update({agent_id: status})
@router.get("", response_model=list[AgentOut])
async def list_agents(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return every active agent, ordered by role."""
    # ``== True`` is intentional: it builds a SQL expression, not a Python bool.
    active_agents = (
        select(Agent)
        .where(Agent.is_active == True)  # noqa: E712
        .order_by(Agent.role)
    )
    rows = await db.execute(active_agents)
    return rows.scalars().all()
# ---- Runtime statistics (must be registered before /{agent_id}) ----
@router.get("/stats", response_model=list[AgentStats])
async def get_agent_stats(
    current_user: User = Depends(get_current_user),
):
    """
    Return the runtime statistics of each default agent role
    (call count, current task, status).

    Values come from the in-memory registries; roles with no recorded data
    report 0 calls, no task and the status "idle".
    """
    return [
        AgentStats(
            agent_id=role,
            call_count=_agent_call_counts.get(role, 0),
            current_task=_agent_current_tasks.get(role),
            status=_agent_statuses.get(role, "idle"),
        )
        for role in DEFAULT_AGENT_ROLES
    ]
# ---- Configuration management (must be registered before /{agent_id}) ----
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the full configuration of a single agent.

    ``agent_id`` is matched against the agent's *role*. When no database row
    exists, the built-in default for that role is returned instead.

    Requires an authenticated user, like the other endpoints of this router:
    the response contains the agent's system prompt.

    Raises:
        HTTPException: 404 when the role is neither stored nor a built-in default.
    """
    result = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = result.scalar_one_or_none()
    if agent is not None:
        return AgentConfigOut(
            id=agent.role,
            name=agent.name,
            role=agent.role,
            description=agent.description,
            system_prompt=agent.system_prompt,
            enabled=agent.is_active,
            is_active=agent.is_active,
        )

    # Not stored in the database: fall back to the default configuration.
    from app.agents.prompts import (
        ANALYST_SYSTEM_PROMPT,
        EXECUTOR_SYSTEM_PROMPT,
        LIBRARIAN_SYSTEM_PROMPT,
        MASTER_SYSTEM_PROMPT,
        PLANNER_SYSTEM_PROMPT,
    )

    defaults = {
        "master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
        "planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
        "executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
        "librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
        "analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
    }
    if agent_id not in defaults:
        raise HTTPException(status_code=404, detail="Agent 不存在")
    name, desc, prompt = defaults[agent_id]
    return AgentConfigOut(
        id=agent_id,
        name=name,
        role=agent_id,
        description=desc,
        system_prompt=prompt,
        enabled=True,
        is_active=True,
    )
@router.put("/config/{agent_id}", response_model=AgentConfigOut)
async def update_agent_config(
    agent_id: str,
    data: AgentConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update an agent's configuration (name, description, prompt, enabled flag).

    ``agent_id`` is matched against the agent's role. Fields left as ``None``
    in the payload are not touched. Responds 404 when no such agent is stored.
    """
    lookup = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = lookup.scalar_one_or_none()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    # Copy over only the text fields the client actually supplied.
    for field in ("name", "description", "system_prompt"):
        value = getattr(data, field)
        if value is not None:
            setattr(agent, field, value)

    if data.enabled is not None:
        agent.is_active = data.enabled
        # Keep the in-memory runtime status in step with the stored flag.
        _agent_statuses[agent_id] = "idle" if data.enabled else "disabled"

    await db.commit()
    await db.refresh(agent)

    return AgentConfigOut(
        id=agent.role,
        name=agent.name,
        role=agent.role,
        description=agent.description,
        system_prompt=agent.system_prompt,
        enabled=agent.is_active,
        is_active=agent.is_active,
    )
@router.post("", response_model=AgentOut, status_code=201)
async def create_agent(
    data: AgentCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Persist a new agent built from the request payload and return it."""
    new_agent = Agent(
        name=data.name,
        role=data.role,
        description=data.description,
        system_prompt=data.system_prompt,
    )
    db.add(new_agent)
    await db.commit()
    # Reload so generated column values are present on the returned object.
    await db.refresh(new_agent)
    return new_agent
@router.get("/{agent_id}", response_model=AgentOut)
async def get_agent(
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one agent by primary key.

    Requires an authenticated user, consistent with the listing, creation and
    update endpoints of this router.

    Raises:
        HTTPException: 404 when no agent has this id.
    """
    result = await db.execute(select(Agent).where(Agent.id == agent_id))
    agent = result.scalar_one_or_none()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")
    return agent
# ---- Configuration management ----
# NOTE(review): this handler repeats the GET /config/{agent_id} route that is
# already registered earlier in this module; the earlier registration is the
# one matched first, so this copy looks redundant — confirm and remove.
@router.get("/config/{agent_id}", response_model=AgentConfigOut)
async def get_agent_config(
    agent_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the full configuration of a single agent.

    ``agent_id`` is matched against the agent's *role*. When no database row
    exists, the built-in default for that role is returned instead.

    Requires an authenticated user, like the other endpoints of this router:
    the response contains the agent's system prompt.

    Raises:
        HTTPException: 404 when the role is neither stored nor a built-in default.
    """
    result = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = result.scalar_one_or_none()
    if agent is not None:
        return AgentConfigOut(
            id=agent.role,
            name=agent.name,
            role=agent.role,
            description=agent.description,
            system_prompt=agent.system_prompt,
            enabled=agent.is_active,
            is_active=agent.is_active,
        )

    # Not stored in the database: return the default configuration.
    from app.agents.prompts import (
        ANALYST_SYSTEM_PROMPT,
        EXECUTOR_SYSTEM_PROMPT,
        LIBRARIAN_SYSTEM_PROMPT,
        MASTER_SYSTEM_PROMPT,
        PLANNER_SYSTEM_PROMPT,
    )

    defaults = {
        "master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
        "planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
        "executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
        "librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
        "analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
    }
    if agent_id not in defaults:
        raise HTTPException(status_code=404, detail="Agent 不存在")
    name, desc, prompt = defaults[agent_id]
    return AgentConfigOut(
        id=agent_id,
        name=name,
        role=agent_id,
        description=desc,
        system_prompt=prompt,
        enabled=True,
        is_active=True,
    )
# NOTE(review): this handler repeats the PUT /config/{agent_id} route that is
# already registered earlier in this module — confirm and remove one copy.
@router.put("/config/{agent_id}", response_model=AgentConfigOut)
async def update_agent_config(
    agent_id: str,
    data: AgentConfigUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update an agent's configuration (name, description, prompt, enabled flag).

    ``agent_id`` is matched against the agent's role. Fields left as ``None``
    in the payload are not touched. Responds 404 when no such agent is stored.
    """
    lookup = await db.execute(select(Agent).where(Agent.role == agent_id))
    agent = lookup.scalar_one_or_none()
    if agent is None:
        raise HTTPException(status_code=404, detail="Agent 不存在")

    # Copy over only the text fields the client actually supplied.
    for field in ("name", "description", "system_prompt"):
        value = getattr(data, field)
        if value is not None:
            setattr(agent, field, value)

    if data.enabled is not None:
        agent.is_active = data.enabled
        # Keep the in-memory runtime status in step with the stored flag.
        _agent_statuses[agent_id] = "idle" if data.enabled else "disabled"

    await db.commit()
    await db.refresh(agent)

    return AgentConfigOut(
        id=agent.role,
        name=agent.name,
        role=agent.role,
        description=agent.description,
        system_prompt=agent.system_prompt,
        enabled=agent.is_active,
        is_active=agent.is_active,
    )

View File

@@ -0,0 +1,83 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
from app.config import settings
# All routes in this module share the /api/auth prefix.
router = APIRouter(prefix="/api/auth", tags=["认证"])
# Bearer-token dependency; tokenUrl is the path of the login route below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the bearer token to an active ``User``.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``sub`` claim, or refers to a missing or inactive user.
    """
    payload = decode_token(token)
    # An undecodable token and a token without a subject are rejected alike.
    user_id = None if payload is None else payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证令牌")

    lookup = await db.execute(select(User).where(User.id == user_id))
    user = lookup.scalar_one_or_none()
    if user is not None and user.is_active:
        return user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已禁用")
@router.post("/register", response_model=UserOut, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new account; responds 400 if the e-mail is already registered."""
    # Check whether the e-mail already exists.
    existing = await db.execute(select(User).where(User.email == user_data.email))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="邮箱已被注册")

    # Create the user; only the password hash is stored.
    password_hash = get_password_hash(user_data.password)
    new_user = User(
        email=user_data.email,
        hashed_password=password_hash,
        full_name=user_data.full_name,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
    """Authenticate a user and return an access token.

    The submitted ``username`` may be a user id (UUID), a full e-mail address,
    or the local part of an e-mail (the text before ``@``). The local-part
    lookup is only accepted when it identifies exactly one account.

    Raises:
        HTTPException: 401 when no unique user matches or the password is
            wrong; 403 when the matched user is disabled.
    """
    import re

    identifier = form_data.username.strip()
    # Supported identifiers: e-mail / UUID / user-name prefix.
    user = None

    # 1. Try UUID.
    if re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', identifier, re.I):
        result = await db.execute(select(User).where(User.id == identifier))
        user = result.scalar_one_or_none()

    # 2. Try e-mail.
    if not user:
        result = await db.execute(select(User).where(User.email == identifier))
        user = result.scalar_one_or_none()

    # 3. Try user-name prefix (the part before "@" in the e-mail).
    if not user and '@' not in identifier:
        # autoescape stops "%" / "_" typed by the client from acting as LIKE
        # wildcards. Two rows are enough to detect an ambiguous prefix, which
        # is treated as "no match" instead of raising MultipleResultsFound.
        result = await db.execute(
            select(User)
            .where(User.email.startswith(f"{identifier}@", autoescape=True))
            .limit(2)
        )
        candidates = result.scalars().all()
        if len(candidates) == 1:
            user = candidates[0]

    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")

    access_token = create_access_token(data={"sub": user.id})
    return Token(access_token=access_token)
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the user resolved from the request's bearer token."""
    me = current_user
    return me

View File

@@ -0,0 +1,217 @@
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from app.database import get_db
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.conversation import ConversationCreate, ConversationOut, ChatRequest, ChatResponse
from app.services.agent_service import AgentService
import json
# All routes in this module share the /api/conversations prefix.
router = APIRouter(prefix="/api/conversations", tags=["对话"])
@router.get("", response_model=list[ConversationOut])
async def list_conversations(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's 50 most recently updated conversations."""
    recent = (
        select(Conversation)
        .where(Conversation.user_id == current_user.id)
        .order_by(desc(Conversation.updated_at))
        .limit(50)
    )
    rows = await db.execute(recent)
    return rows.scalars().all()
@router.post("", response_model=ConversationOut, status_code=201)
async def create_conversation(
    data: ConversationCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a conversation for the caller, using a default title if none is given."""
    title = data.title or "新对话"
    conversation = Conversation(user_id=current_user.id, title=title)
    db.add(conversation)
    await db.commit()
    await db.refresh(conversation)
    return conversation
@router.get("/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all messages of a conversation, ordered by creation time.

    Responds 404 when the conversation does not exist or belongs to another user.
    """
    # Ownership check first: the conversation must belong to the caller.
    owned_stmt = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    owned = (await db.execute(owned_stmt)).scalar_one_or_none()
    if owned is None:
        raise HTTPException(status_code=404, detail="对话不存在")

    messages_stmt = (
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at)
    )
    rows = await db.execute(messages_stmt)
    return rows.scalars().all()
@router.delete("/{conversation_id}", status_code=204)
async def delete_conversation(
    conversation_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the caller's conversations; responds 404 if it is not theirs."""
    owned_stmt = select(Conversation).where(
        Conversation.id == conversation_id,
        Conversation.user_id == current_user.id,
    )
    target = (await db.execute(owned_stmt)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="对话不存在")

    await db.delete(target)
    await db.commit()
@router.post("/chat", response_model=ChatResponse)
async def chat(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Simple (non-streaming) chat.

    Delegates to ``AgentService.chat_simple`` and then adds 2 to the
    conversation's ``message_count``.
    """
    agent_svc = AgentService(db)
    conv_id, msg_id, content = await agent_svc.chat_simple(
        user_id=current_user.id,
        message=data.message,
        conversation_id=data.conversation_id,
        file_ids=data.file_ids,
    )

    # Update the conversation's message counter.
    result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
    conv = result.scalar_one_or_none()
    if conv:
        # The column is nullable, so treat a NULL counter as 0 instead of
        # failing with a TypeError on ``None += 2``.
        conv.message_count = (conv.message_count or 0) + 2
        await db.commit()

    return ChatResponse(
        conversation_id=conv_id,
        message_id=msg_id,
        content=content,
        agent_name="jarvis",
    )
@router.post("/chat/stream")
async def chat_stream(
    data: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Streaming chat delivered as Server-Sent Events.

    Event sequence: one ``metadata`` event, zero or more ``chunk`` events, an
    optional ``error`` event, and a final ``done`` event.
    """
    agent_svc = AgentService(db)

    async def stream_generator():
        conv_id, msg_id, stream = await agent_svc.chat(
            user_id=current_user.id,
            message=data.message,
            conversation_id=data.conversation_id,
        )

        # Send the metadata first.
        yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"

        # Stream the content, collecting it for persistence.
        parts = []
        try:
            async for chunk in stream:
                if chunk:
                    parts.append(chunk)
                    yield f"event: chunk\ndata: {json.dumps({'content': chunk})}\n\n"

            # Store the complete reply on the message record.
            await agent_svc.save_response(msg_id, "".join(parts))
        except Exception as e:
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

        # Sent after both the success path and a reported error. Deliberately
        # NOT inside ``finally``: yielding there while the generator is being
        # closed (client disconnect) raises RuntimeError.
        yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@router.websocket("/ws/{user_id}/{conversation_id}")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    conversation_id: str,
):
    """Streaming chat over WebSocket.

    Each incoming text frame is JSON with a ``message`` field. The server
    answers with ``chunk`` frames followed by a ``done`` frame. When the
    socket is opened with ``conversation_id == "new"``, the first message
    creates a conversation, a ``metadata`` frame announces its id, and later
    messages on the same socket continue in that conversation.
    """
    # NOTE(review): ``user_id`` is taken from the URL and no token is checked
    # here, unlike the HTTP endpoints that depend on get_current_user —
    # confirm whether authentication is enforced elsewhere.
    await websocket.accept()

    # None means "no conversation yet"; filled in once one has been created so
    # follow-up messages do not each start another conversation.
    active_conversation_id = None if conversation_id == "new" else conversation_id

    try:
        from app.database import async_session

        async for message in websocket.iter_text():
            data = json.loads(message)
            user_message = data.get("message", "")

            # A fresh database session is opened for every incoming message.
            async with async_session() as db:
                agent_svc = AgentService(db)
                starting_new = active_conversation_id is None

                conv_id, msg_id, stream = await agent_svc.chat(
                    user_id=user_id,
                    message=user_message,
                    conversation_id=active_conversation_id,
                )

                if starting_new:
                    # New conversation: remember it and tell the client its id.
                    active_conversation_id = conv_id
                    await websocket.send_json({
                        "type": "metadata",
                        "conversation_id": conv_id,
                        "message_id": msg_id,
                    })

                parts = []
                async for chunk in stream:
                    if chunk:
                        parts.append(chunk)
                        await websocket.send_json({"type": "chunk", "content": chunk})

                await agent_svc.save_response(msg_id, "".join(parts))
                await websocket.send_json({"type": "done", "message_id": msg_id})

    except WebSocketDisconnect:
        # Client closed the socket; nothing to report.
        pass
    except Exception as e:
        # Best effort: the socket may already be unusable at this point.
        try:
            await websocket.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass

View File

@@ -0,0 +1,154 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, BackgroundTasks, Form
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.document import Document, DocumentChunk
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.document_service import DocumentService
from app.services.knowledge_service import KnowledgeService
from dataclasses import asdict
# All routes in this module share the /api/documents prefix.
router = APIRouter(prefix="/api/documents", tags=["知识库"])
@router.get("", response_model=list)
async def list_documents(
    folder_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's documents, newest first, optionally limited to one folder."""
    conditions = [Document.user_id == current_user.id]
    if folder_id:
        conditions.append(Document.folder_id == folder_id)

    stmt = select(Document).where(*conditions).order_by(Document.created_at.desc())
    rows = await db.execute(stmt)
    return rows.scalars().all()
@router.post("/upload", status_code=201)
async def upload_document(
    background: BackgroundTasks,
    file: UploadFile = File(...),
    folder_id: Optional[str] = Form(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Upload a document; it is chunked and vectorised automatically.

    The upload itself is handled by ``DocumentService``. Indexing is queued
    as a background task that opens its own database session.
    """
    doc_svc = DocumentService(db)
    doc = await doc_svc.upload_document(current_user.id, file, folder_id=folder_id)

    # Snapshot plain values now: the background task runs after the response
    # and must not touch ORM instances bound to the request's session.
    doc_id = doc.id
    user_id = current_user.id
    response = {
        "id": doc_id,
        "title": doc.title,
        "chunk_count": doc.chunk_count,
        "status": "上传成功,正在索引...",
    }

    # Index into ChromaDB in the background.
    def index_task():
        import asyncio
        from app.database import async_session

        async def _index():
            async with async_session() as session:
                kb_svc = KnowledgeService(session, user_id=user_id)
                await kb_svc.index_document(doc_id, user_id=user_id)

        asyncio.run(_index())

    background.add_task(index_task)
    return response
@router.get("/{document_id}")
async def get_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one of the caller's documents; responds 404 if it is not theirs."""
    owned_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    document = (await db.execute(owned_stmt)).scalar_one_or_none()
    if document is None:
        raise HTTPException(status_code=404, detail="文档不存在")
    return document
@router.get("/{document_id}/chunks")
async def get_document_chunks(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return all chunks of a document, ordered by chunk index.

    Responds 404 when the document does not exist or belongs to another user.
    """
    # Ownership check first: the document must belong to the caller.
    owned_stmt = select(Document).where(
        Document.id == document_id,
        Document.user_id == current_user.id,
    )
    owned = (await db.execute(owned_stmt)).scalar_one_or_none()
    if owned is None:
        raise HTTPException(status_code=404, detail="文档不存在")

    chunks_stmt = (
        select(DocumentChunk)
        .where(DocumentChunk.document_id == document_id)
        .order_by(DocumentChunk.chunk_index)
    )
    rows = await db.execute(chunks_stmt)
    return rows.scalars().all()
@router.delete("/{document_id}", status_code=204)
async def delete_document(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a document by delegating to ``DocumentService.delete_document``."""
    service = DocumentService(db)
    await service.delete_document(current_user.id, document_id)
@router.post("/search")
async def search_documents(
    query: str,
    top_k: int = 5,
    mode: str = "hybrid",
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Search the knowledge base.

    - query: the search query
    - top_k: number of results to return (default 5)
    - mode: hybrid / semantic / keyword; any other value is treated as hybrid
    """
    owner_id = current_user.id
    kb_svc = KnowledgeService(db, user_id=owner_id)

    if mode == "semantic":
        hits = await kb_svc.retrieve(query, owner_id, top_k, use_rerank=True)
    elif mode == "keyword":
        hits = await kb_svc._keyword_search(query, owner_id, top_k)
    else:
        hits = await kb_svc.hybrid_search(query, owner_id, top_k)

    # Results are dataclass instances; convert them to plain dicts.
    return list(map(asdict, hits))
@router.get("/{document_id}/content")
async def get_document_content(
    document_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the document's text content (used for AI understanding).

    Responds 404 when the service returns ``None`` for this document.
    """
    from app.services.document_service import DocumentService

    text = await DocumentService(db).get_document_content(current_user.id, document_id)
    if text is not None:
        return {"content": text}
    raise HTTPException(status_code=404, detail="文档不存在或无内容")

View File

@@ -0,0 +1,143 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from typing import List
from app.database import get_db
from app.models.folder import Folder
from app.models.user import User
from app.schemas.folder import FolderCreate, FolderUpdate, FolderOut, FolderTreeOut
from app.services.auth_service import get_current_user
# All routes in this module share the /api/folders prefix.
router = APIRouter(prefix="/api/folders", tags=["文件夹"])
def build_folder_tree(folders: list[Folder], parent_id: str | None = None) -> List[FolderTreeOut]:
    """Build the nested folder tree below ``parent_id``.

    Args:
        folders: flat list of folders to arrange.
        parent_id: id of the folder whose children form the top level of the
            result; ``None`` selects the folders that have no parent.

    Returns:
        ``FolderTreeOut`` nodes, with siblings in the same relative order as
        they appear in ``folders``.
    """
    # Group by parent once so each level is a dictionary lookup instead of a
    # rescan of the whole list for every node.
    children_by_parent: dict = {}
    for folder in folders:
        children_by_parent.setdefault(folder.parent_id, []).append(folder)

    def _subtree(pid):
        """Return the tree nodes whose parent is ``pid``."""
        return [
            FolderTreeOut(
                id=child.id,
                name=child.name,
                parent_id=child.parent_id,
                children=_subtree(child.id),
            )
            for child in children_by_parent.get(pid, [])
        ]

    return _subtree(parent_id)
@router.get("", response_model=List[FolderTreeOut])
async def get_folders(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the caller's complete folder tree."""
    own_folders = select(Folder).where(Folder.user_id == current_user.id)
    rows = await db.execute(own_folders)
    flat = list(rows.scalars().all())
    return build_folder_tree(flat)
@router.post("", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
async def create_folder(
    folder_data: FolderCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a folder for the caller.

    Responds 404 when the requested parent is missing or not the caller's,
    and 400 when a sibling with the same name already exists.
    """
    owner_id = current_user.id
    parent_id = folder_data.parent_id

    # The parent folder must exist and belong to the current user.
    if parent_id:
        parent_lookup = await db.execute(
            select(Folder).where(and_(Folder.id == parent_id, Folder.user_id == owner_id))
        )
        if parent_lookup.scalar_one_or_none() is None:
            raise HTTPException(status_code=404, detail="父文件夹不存在")

    # Reject a folder with the same name under the same parent.
    same_name = and_(
        Folder.user_id == owner_id,
        Folder.parent_id == parent_id,
        Folder.name == folder_data.name,
    )
    duplicate = await db.execute(select(Folder).where(same_name))
    if duplicate.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="同名文件夹已存在")

    new_folder = Folder(user_id=owner_id, name=folder_data.name, parent_id=parent_id)
    db.add(new_folder)
    await db.commit()
    await db.refresh(new_folder)
    return new_folder
@router.put("/{folder_id}", response_model=FolderOut)
async def rename_folder(
    folder_id: str,
    folder_data: FolderUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Rename one of the caller's folders.

    Raises:
        HTTPException: 404 when the folder is missing or not the caller's;
            400 when another folder under the same parent already has the
            requested name (the same rule ``create_folder`` enforces).
    """
    result = await db.execute(
        select(Folder).where(
            and_(Folder.id == folder_id, Folder.user_id == current_user.id)
        )
    )
    folder = result.scalar_one_or_none()
    if not folder:
        raise HTTPException(status_code=404, detail="文件夹不存在")

    # Look for a *different* sibling already using the new name, so the rename
    # cannot violate the (user_id, parent_id, name) uniqueness rule.
    clash = await db.execute(
        select(Folder.id)
        .where(
            and_(
                Folder.user_id == current_user.id,
                Folder.parent_id == folder.parent_id,
                Folder.name == folder_data.name,
                Folder.id != folder.id,
            )
        )
        .limit(1)
    )
    if clash.first() is not None:
        raise HTTPException(status_code=400, detail="同名文件夹已存在")

    folder.name = folder_data.name
    await db.commit()
    await db.refresh(folder)
    return folder
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(
    folder_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a folder (cascading to its sub-folders and documents).

    Responds 404 when the folder is missing or not the caller's. Everything
    is committed once, after the whole subtree has been processed.
    """
    from app.models.document import Document
    from app.services.knowledge_service import KnowledgeService

    owned_stmt = select(Folder).where(
        and_(Folder.id == folder_id, Folder.user_id == current_user.id)
    )
    root = (await db.execute(owned_stmt)).scalar_one_or_none()
    if root is None:
        raise HTTPException(status_code=404, detail="文件夹不存在")

    async def delete_recursive(fid: str):
        # Sub-folders first (depth-first recursion).
        sub_rows = await db.execute(select(Folder).where(Folder.parent_id == fid))
        for sub in sub_rows.scalars():
            await delete_recursive(sub.id)

        # Then the documents filed directly in this folder: remove their
        # vector-store entries before deleting the rows.
        doc_rows = await db.execute(select(Document).where(Document.folder_id == fid))
        for document in doc_rows.scalars():
            knowledge_service = KnowledgeService(db, current_user.id)
            await knowledge_service.delete_from_vectorstore(current_user.id, document.id)
            await db.delete(document)

        # Finally the folder itself.
        current = await db.get(Folder, fid)
        if current:
            await db.delete(current)

    await delete_recursive(folder_id)
    await db.commit()

View File

@@ -0,0 +1,111 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from app.database import get_db
from app.models.forum import ForumPost, ForumReply
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.forum import ForumPostCreate, ForumPostOut, ForumReplyCreate, ForumReplyOut
router = APIRouter(prefix="/api/forum", tags=["论坛"])
@router.get("/posts", response_model=list[ForumPostOut])
async def list_posts(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's forum posts, newest first."""
    own_posts_newest_first = (
        select(ForumPost)
        .where(ForumPost.user_id == current_user.id)
        .order_by(desc(ForumPost.created_at))
    )
    rows = await db.execute(own_posts_newest_first)
    return rows.scalars().all()
@router.post("/posts", response_model=ForumPostOut, status_code=201)
async def create_post(
    data: ForumPostCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a forum post authored by the current user."""
    new_post = ForumPost(
        user_id=current_user.id,
        title=data.title,
        content=data.content,
        category=data.category,
    )
    db.add(new_post)
    await db.commit()
    # Reload so generated columns are available for the response model.
    await db.refresh(new_post)
    return new_post
@router.get("/posts/{post_id}", response_model=ForumPostOut)
async def get_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single forum post by id.

    Raises:
        HTTPException: 404 when no post has the given id.
    """
    # NOTE(review): unlike list_posts/delete_post this lookup is not filtered
    # by owner, so any authenticated user can read any post - confirm intended.
    lookup = await db.execute(
        select(ForumPost).where(ForumPost.id == post_id)
    )
    found = lookup.scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="帖子不存在")
@router.delete("/posts/{post_id}", status_code=204)
async def delete_post(
    post_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's forum posts.

    Raises:
        HTTPException: 404 when the post does not exist or is not owned by
            the user.
    """
    owned_post_query = select(ForumPost).where(
        ForumPost.id == post_id,
        ForumPost.user_id == current_user.id,
    )
    target = (await db.execute(owned_post_query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="帖子不存在")
    await db.delete(target)
    await db.commit()
@router.get("/posts/{post_id}/replies", response_model=list[ForumReplyOut])
async def list_replies(
    post_id: str,
    db: AsyncSession = Depends(get_db),
):
    """List the replies of a post, oldest first."""
    # NOTE(review): this route has no get_current_user dependency, unlike the
    # other forum routes - confirm that unauthenticated reads are intended.
    replies_oldest_first = (
        select(ForumReply)
        .where(ForumReply.post_id == post_id)
        .order_by(ForumReply.created_at)
    )
    rows = await db.execute(replies_oldest_first)
    return rows.scalars().all()
@router.post("/posts/{post_id}/replies", response_model=ForumReplyOut, status_code=201)
async def create_reply(
    post_id: str,
    data: ForumReplyCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a reply by the current user to an existing post.

    The post is looked up first so that a reply is never stored against a
    non-existent ``post_id``. On success the post's ``reply_count`` is
    incremented in the same transaction.

    Raises:
        HTTPException: 404 when no post has the given id.
    """
    post_lookup = await db.execute(select(ForumPost).where(ForumPost.id == post_id))
    post = post_lookup.scalar_one_or_none()
    if post is None:
        raise HTTPException(status_code=404, detail="帖子不存在")

    reply = ForumReply(
        post_id=post_id,
        user_id=current_user.id,
        content=data.content,
    )
    db.add(reply)
    # Keep the denormalised counter in step with the new reply.
    post.reply_count += 1
    await db.commit()
    await db.refresh(reply)
    return reply

View File

@@ -0,0 +1,240 @@
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from app.database import get_db
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.graph_service import GraphService
from app.schemas.graph import KGNodeOut, TagProperties, TagExtractRequest, TagExtractResponse, RelatedContentRequest
router = APIRouter(prefix="/api/graph", tags=["知识图谱"])
@router.get("")
async def get_graph(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's knowledge graph.

    At most 200 nodes are returned, ordered by descending importance. Edges
    are those with at least one endpoint among the returned nodes; the
    filter runs in SQL so edges of unrelated nodes are never loaded.

    Returns:
        dict with ``nodes``, ``edges`` and ``stats`` (node and edge counts).
    """
    nodes_result = await db.execute(
        select(KGNode)
        .where(KGNode.user_id == current_user.id)
        .order_by(KGNode.importance.desc())
        .limit(200)
    )
    nodes = list(nodes_result.scalars().all())
    node_ids = {n.id for n in nodes}

    if node_ids:
        edges_result = await db.execute(
            select(KGEdge).where(
                or_(
                    KGEdge.source_id.in_(node_ids),
                    KGEdge.target_id.in_(node_ids),
                )
            )
        )
        edges = list(edges_result.scalars().all())
    else:
        # No nodes means no edges can match; skip the query entirely.
        edges = []

    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description, "importance": n.importance,
                   "created_at": str(n.created_at)}
                  for n in nodes],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type, "weight": e.weight}
                  for e in edges],
        "stats": {
            "node_count": len(nodes),
            "edge_count": len(edges),
        }
    }
@router.post("/build")
async def build_graph(
    background: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Schedule a build/rebuild of the knowledge graph as a background task.

    The build covers all documents (``document_ids=None``) of the current
    user. The response is returned immediately; it only reports that the
    task was queued.
    """
    def build_task():
        # Plain (non-async) callable that drives the coroutine with
        # asyncio.run, i.e. on a freshly created event loop.
        # NOTE(review): if the engine behind async_session pools connections
        # created on the server's loop, reusing them from this new loop may
        # fail - confirm against the database driver in use.
        import asyncio
        from app.database import async_session
        from app.services.graph_service import GraphService
        async def _build():
            # Uses its own session; the request-scoped `db` is not used here.
            async with async_session() as session:
                svc = GraphService(session)
                await svc.build_graph(user_id=current_user.id, document_ids=None)
        asyncio.run(_build())
    background.add_task(build_task)
    return {"status": "started", "message": "图谱构建任务已启动,请稍后刷新查看"}
@router.get("/entity/{entity}")
async def get_entity_context(
    entity: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detailed context of one entity for the current user."""
    entity_context = await GraphService(db).get_entity_context(entity, current_user.id)
    return {"context": entity_context}
@router.get("/summary")
async def get_graph_summary(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a summary of the current user's knowledge graph."""
    graph_summary = await GraphService(db).get_graph_summary(current_user.id)
    return {"summary": graph_summary}
@router.get("/neighbors/{node_id}")
async def get_node_neighbors(
    node_id: str,
    depth: int = 1,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the neighbourhood of a node (used for click-to-expand in the UI).

    The node must belong to the current user, matching the ownership rule
    applied by ``delete_node``.

    Raises:
        HTTPException: 404 when the node does not exist or is not owned by
            the user.
    """
    owner_check = await db.execute(
        select(KGNode.id).where(
            KGNode.id == node_id, KGNode.user_id == current_user.id
        )
    )
    if owner_check.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="节点不存在")

    svc = GraphService(db)
    data = await svc.get_neighbors(node_id, depth)
    return {
        "nodes": [{"id": n.id, "name": n.name, "type": n.entity_type,
                   "description": n.description} for n in data["nodes"]],
        "edges": [{"id": e.id, "source": e.source_id, "target": e.target_id,
                   "relation": e.relation_type} for e in data["edges"]],
    }
@router.delete("/nodes/{node_id}", status_code=204)
async def delete_node(
    node_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's graph nodes.

    Raises:
        HTTPException: 404 when the node does not exist or is not owned by
            the user.
    """
    owned_node_query = select(KGNode).where(
        KGNode.id == node_id, KGNode.user_id == current_user.id
    )
    target = (await db.execute(owned_node_query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="节点不存在")
    await db.delete(target)
    await db.commit()
@router.post("/tags/extract", response_model=TagExtractResponse)
async def extract_tags(
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Extract tags from a piece of content without saving them to a node.

    The caller must be authenticated and ``request.user_id`` must be the
    caller's own id, like the other graph routes that act per user.

    Raises:
        HTTPException: 403 when ``request.user_id`` is not the caller's id.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    if request.user_id != current_user.id:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match the authenticated user",
        )

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): called without await, as in the original - confirm that
    # TagService.extract_tags_from_content is a synchronous method.
    tag_infos = tag_service.extract_tags_from_content(request.content, current_user.id)

    tags = []
    for t in tag_infos:
        short_name, level, parent_path = tag_service.parse_tag_path(t["path"])
        tags.append(TagProperties(
            tag_path=t["path"],
            short_name=short_name,
            level=level,
            parent_path=parent_path,
            description=t.get("description")
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))
@router.post("/tags/content/{node_id}", response_model=TagExtractResponse)
async def tag_content_node(
    node_id: str,
    request: TagExtractRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Tag a content node that belongs to the current user.

    The caller must be authenticated, ``request.user_id`` must be the
    caller's own id, and the node must be owned by the caller.

    Raises:
        HTTPException: 403 when ``request.user_id`` is not the caller's id;
            404 when the node does not exist or is not owned by the caller.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    if request.user_id != current_user.id:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match the authenticated user",
        )

    result = await db.execute(
        select(KGNode).where(
            KGNode.id == node_id, KGNode.user_id == current_user.id
        )
    )
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): called without await, as in the original - confirm that
    # TagService.tag_content is a synchronous method.
    tag_nodes = tag_service.tag_content(request.content, current_user.id, node)

    tags = []
    for n in tag_nodes:
        # Tag nodes may carry no properties; fall back to the node's own fields.
        props = n.properties_ or {}
        tags.append(TagProperties(
            tag_path=props.get("tag_path", n.name),
            short_name=n.name,
            level=props.get("level", 1),
            parent_path=props.get("parent_path"),
            description=n.description
        ))
    return TagExtractResponse(tags=tags, tag_count=len(tags))
@router.get("/tags/{tag_id}/related", response_model=list[KGNodeOut])
async def get_related_tags(
    tag_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the current user's nodes related to a tag.

    A node is related when a ``related_to`` or ``synonym_of`` edge connects
    it to ``tag_id`` in either direction. Only nodes owned by the caller are
    returned.
    """
    edge_rows = await db.execute(
        select(KGEdge).where(
            or_(KGEdge.source_id == tag_id, KGEdge.target_id == tag_id),
            KGEdge.relation_type.in_(["related_to", "synonym_of"])
        )
    )
    # Collect the endpoint on the other side of each edge.
    related_ids = {
        e.target_id if e.source_id == tag_id else e.source_id
        for e in edge_rows.scalars().all()
    }
    if not related_ids:
        return []

    node_rows = await db.execute(
        select(KGNode).where(
            KGNode.id.in_(related_ids),
            KGNode.user_id == current_user.id,
        )
    )
    return list(node_rows.scalars().all())
@router.get("/tags/{user_id}", response_model=list[KGNodeOut])
async def get_user_tags(
    user_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return all tag nodes of the current user.

    The caller must be authenticated and ``user_id`` must be the caller's
    own id.

    Raises:
        HTTPException: 403 when ``user_id`` is not the caller's id.
    """
    if user_id != current_user.id:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match the authenticated user",
        )

    # NOTE(review): `.astext` orders by the textual value of "level" and is
    # specific to PostgreSQL JSON column types - confirm the column type.
    result = await db.execute(
        select(KGNode).where(
            KGNode.user_id == current_user.id,
            KGNode.entity_type == "tag"
        ).order_by(KGNode.properties_["level"].astext)
    )
    return list(result.scalars().all())
@router.post("/content/related", response_model=list[KGNodeOut])
async def get_related_content(
    request: RelatedContentRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find content related to the given tags for the current user.

    The caller must be authenticated and ``request.user_id`` must be the
    caller's own id.

    Raises:
        HTTPException: 403 when ``request.user_id`` is not the caller's id.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    if request.user_id != current_user.id:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match the authenticated user",
        )

    llm_client = get_llm_client()
    tag_service = TagService(db, llm_client)
    # NOTE(review): called without await, as in the original - confirm that
    # TagService.get_related_content is a synchronous method.
    results = tag_service.get_related_content(request.tag_ids, current_user.id, request.limit)
    # Each result is indexable; element 0 is the node returned to the client.
    return [r[0] for r in results]

View File

@@ -0,0 +1,42 @@
from fastapi import APIRouter, Depends
from app.models.user import User
from app.routers.auth import get_current_user
from app.services.scheduler_service import (
get_scheduler_status,
scheduler,
daily_task_analysis,
forum_scan_task,
graph_rebuild_task,
tag_generation_task,
)
import logging
router = APIRouter(prefix="/api/scheduler", tags=["定时任务"])
logger = logging.getLogger(__name__)
@router.get("/status")
async def get_status(current_user: User = Depends(get_current_user)):
    """Return the scheduler status as reported by ``get_scheduler_status``."""
    scheduler_state = get_scheduler_status()
    return scheduler_state
@router.post("/trigger/{job_id}")
async def trigger_job(job_id: str, current_user: User = Depends(get_current_user)):
    """Run one of the scheduled jobs immediately and wait for it to finish.

    Args:
        job_id: One of ``daily_task_analysis``, ``forum_scan``,
            ``graph_rebuild`` or ``tag_generation``.

    Returns:
        ``{"error": ...}`` for an unknown job id; otherwise a dict whose
        ``status`` is ``"ok"`` or ``"error"``.
    """
    job_map = {
        "daily_task_analysis": daily_task_analysis,
        "forum_scan": forum_scan_task,
        "graph_rebuild": graph_rebuild_task,
        "tag_generation": tag_generation_task,
    }
    job = job_map.get(job_id)
    if job is None:
        return {"error": f"未知任务: {job_id}"}
    try:
        await job()
    except Exception as e:
        # Deliberately broad: a failing job is reported to the caller rather
        # than surfacing as a 500. logger.exception keeps the traceback.
        logger.exception("手动触发任务失败 %s: %s", job_id, e)
        return {"status": "error", "job": job_id, "error": str(e)}
    return {"status": "ok", "job": job_id, "message": "任务已触发执行"}

View File

@@ -0,0 +1,87 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.settings import (
SettingsOut, ProfileUpdateIn, LLMConfigIn, SchedulerConfigIn, LLMTestIn
)
from app.services.settings_service import (
get_user_settings, update_user_profile, update_llm_config,
update_scheduler_config, test_llm_connection
)
router = APIRouter(prefix="/api/settings", tags=["设置"])
@router.get("", response_model=SettingsOut)
async def get_settings(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's settings.

    Raises:
        HTTPException: 404 when ``get_user_settings`` returns nothing.
    """
    user_settings = await get_user_settings(current_user.id, db)
    if user_settings:
        return user_settings
    raise HTTPException(status_code=404, detail="用户不存在")
@router.put("/profile")
async def update_profile(
    data: ProfileUpdateIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the current user's name and/or password.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` raised
            by ``update_user_profile``.
    """
    # NOTE(review): this route declares no response_model, so whatever
    # update_user_profile returns is serialised as-is - confirm that it does
    # not expose sensitive columns.
    try:
        updated = await update_user_profile(
            current_user.id, db,
            full_name=data.full_name,
            password=data.password,
            current_password=data.current_password
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return updated
@router.put("/llm")
async def update_llm(
    data: LLMConfigIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Store the current user's LLM configuration.

    Only fields that are not ``None`` are forwarded to ``update_llm_config``.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` raised
            by ``update_llm_config``.
    """
    submitted = data.model_dump(exclude_none=True)
    try:
        saved = await update_llm_config(current_user.id, submitted, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"llm_config": saved}
@router.post("/llm/test")
async def test_llm(
    data: LLMTestIn,
    current_user: User = Depends(get_current_user),
):
    """Test an LLM connection with the submitted provider settings.

    ``data.type`` is not forwarded; only provider, model, base URL and API
    key are passed to ``test_llm_connection``.
    """
    return await test_llm_connection(
        provider=data.provider,
        model=data.model,
        base_url=data.base_url,
        api_key=data.api_key
    )
@router.put("/scheduler")
async def update_scheduler(
    data: SchedulerConfigIn,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Store the current user's scheduler configuration.

    Only fields that are not ``None`` are forwarded to
    ``update_scheduler_config``.

    Raises:
        HTTPException: 400 carrying the message of any ``ValueError`` raised
            by ``update_scheduler_config``.
    """
    submitted = data.model_dump(exclude_none=True)
    try:
        saved = await update_scheduler_config(current_user.id, submitted, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"scheduler_config": saved}

View File

@@ -0,0 +1,77 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.stats import (
SystemHealth,
ConversationStats,
KnowledgeStats,
KanbanStats,
CommunityStats,
PersonalInsights,
)
from app.services.stats_service import StatsService
router = APIRouter(prefix="/api/stats", tags=["统计"])
@router.get("/system", response_model=SystemHealth)
async def get_system_health(db: Session = Depends(get_db)):
    """Return system health metrics from ``StatsService``."""
    # NOTE(review): no get_current_user dependency on this route, unlike the
    # other stats routes - confirm that this is intended.
    return StatsService(db).get_system_health()
@router.get("/conversations", response_model=ConversationStats)
async def get_conversation_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return conversation statistics for the current user over ``days`` days."""
    return StatsService(db).get_conversation_stats(user_id=current_user.id, days=days)
@router.get("/knowledge", response_model=KnowledgeStats)
async def get_knowledge_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return knowledge-base statistics for the current user over ``days`` days."""
    return StatsService(db).get_knowledge_stats(user_id=current_user.id, days=days)
@router.get("/kanban", response_model=KanbanStats)
async def get_kanban_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return kanban statistics for the current user over ``days`` days."""
    return StatsService(db).get_kanban_stats(user_id=current_user.id, days=days)
@router.get("/community", response_model=CommunityStats)
async def get_community_stats(
    days: int = 30,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return community statistics for the current user over ``days`` days."""
    return StatsService(db).get_community_stats(user_id=current_user.id, days=days)
@router.get("/insights", response_model=PersonalInsights)
async def get_personal_insights(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return personal insights for the current user."""
    return StatsService(db).get_personal_insights(user_id=current_user.id)

View File

@@ -0,0 +1,91 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from app.database import get_db
from app.models.task import Task, TaskStatus
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.task import TaskCreate, TaskUpdate, TaskOut
router = APIRouter(prefix="/api/tasks", tags=["看板"])
@router.get("", response_model=list[TaskOut])
async def list_tasks(
    status: TaskStatus | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's tasks, newest first, optionally by status."""
    conditions = [Task.user_id == current_user.id]
    if status:
        conditions.append(Task.status == status)
    newest_first = select(Task).where(*conditions).order_by(desc(Task.created_at))
    rows = await db.execute(newest_first)
    return rows.scalars().all()
@router.post("", response_model=TaskOut, status_code=201)
async def create_task(
    data: TaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a task for the current user.

    ``tags`` is stored as a JSON string; an empty or missing list is stored
    as ``None``.
    """
    import json

    serialized_tags = json.dumps(data.tags) if data.tags else None
    new_task = Task(
        user_id=current_user.id,
        title=data.title,
        description=data.description,
        priority=data.priority,
        due_date=data.due_date,
        tags=serialized_tags,
    )
    db.add(new_task)
    await db.commit()
    await db.refresh(new_task)
    return new_task
@router.patch("/{task_id}", response_model=TaskOut)
async def update_task(
    task_id: str,
    data: TaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update one of the current user's tasks.

    Only fields that are not ``None`` in the payload are applied. ``tags``
    is stored as a JSON string. Setting ``status`` to ``DONE`` stamps
    ``completed_at`` with the current UTC time; any other status clears it,
    so a reopened task does not keep a stale completion time.

    Raises:
        HTTPException: 404 when the task does not exist or is not owned by
            the user.
    """
    import json
    from datetime import datetime, timezone

    result = await db.execute(
        select(Task).where(Task.id == task_id, Task.user_id == current_user.id)
    )
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    for field, value in data.model_dump(exclude_none=True).items():
        if field == "tags":
            setattr(task, field, json.dumps(value))
        elif field == "status":
            if value == TaskStatus.DONE:
                # Naive UTC, the same value datetime.utcnow() produced,
                # without the API deprecated since Python 3.12.
                task.completed_at = datetime.now(timezone.utc).replace(tzinfo=None)
            else:
                task.completed_at = None
            setattr(task, field, value)
        else:
            setattr(task, field, value)

    await db.commit()
    await db.refresh(task)
    return task
@router.delete("/{task_id}", status_code=204)
async def delete_task(
    task_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's tasks.

    Raises:
        HTTPException: 404 when the task does not exist or is not owned by
            the user.
    """
    owned_task_query = select(Task).where(
        Task.id == task_id, Task.user_id == current_user.id
    )
    target = (await db.execute(owned_task_query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="任务不存在")
    await db.delete(target)
    await db.commit()

154
backend/app/routers/todo.py Normal file
View File

@@ -0,0 +1,154 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from datetime import date
from app.database import get_db
from app.models.todo import DailyTodo, TodoSource
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.todo import (
TodoCreate, TodoUpdate, TodoOut, TodoListOut, TodoSummaryOut
)
from app.services.todo_service import generate_daily_todos
router = APIRouter(prefix="/api/todos", tags=["待办"])
@router.get("", response_model=TodoListOut)
async def list_todos(
    date_str: str = Query(default=None),  # YYYY-MM-DD; defaults to today
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of the current user's todos for a given day.

    Items are ordered newest first; ``total`` counts all todos of that day,
    not just the returned page.
    """
    target_date = date_str or date.today().isoformat()
    day_filter = (
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == target_date,
    )

    # Total number of todos for the day.
    count_stmt = select(func.count()).select_from(DailyTodo).where(*day_filter)
    total = (await db.execute(count_stmt)).scalar()

    # The requested page.
    page_stmt = (
        select(DailyTodo)
        .where(*day_filter)
        .order_by(DailyTodo.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    items = (await db.execute(page_stmt)).scalars().all()
    return TodoListOut(items=items, total=total, page=page, page_size=page_size)
@router.post("", response_model=TodoOut, status_code=201)
async def create_todo(
    data: TodoCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a manual todo for the current user, dated today."""
    today = date.today().isoformat()
    new_todo = DailyTodo(
        user_id=current_user.id,
        title=data.title,
        source=TodoSource.MANUAL,
        todo_date=today,
    )
    db.add(new_todo)
    await db.commit()
    await db.refresh(new_todo)
    return new_todo
@router.patch("/{todo_id}", response_model=TodoOut)
async def update_todo(
    todo_id: str,
    data: TodoUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update the title and/or completion state of one of today's todos.

    Marking a todo completed stamps ``completed_at`` with the current UTC
    time; un-completing it clears the stamp.

    Raises:
        HTTPException: 404 when the todo does not exist or is not owned by
            the user; 403 when the todo is not dated today.
    """
    result = await db.execute(
        select(DailyTodo).where(DailyTodo.id == todo_id, DailyTodo.user_id == current_user.id)
    )
    todo = result.scalar_one_or_none()
    if not todo:
        raise HTTPException(status_code=404, detail="待办不存在")

    # Todos from earlier dates are read-only.
    if todo.todo_date != date.today().isoformat():
        raise HTTPException(status_code=403, detail="历史待办不可修改")

    if data.title is not None:
        todo.title = data.title
    if data.is_completed is not None:
        from datetime import datetime, timezone
        todo.is_completed = data.is_completed
        # Naive UTC, the same value datetime.utcnow() produced, without the
        # API deprecated since Python 3.12.
        todo.completed_at = (
            datetime.now(timezone.utc).replace(tzinfo=None)
            if data.is_completed
            else None
        )

    await db.commit()
    await db.refresh(todo)
    return todo
@router.delete("/{todo_id}", status_code=204)
async def delete_todo(
    todo_id: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of today's todos.

    Raises:
        HTTPException: 404 when the todo does not exist or is not owned by
            the user; 403 when the todo is not dated today.
    """
    owned_todo_query = select(DailyTodo).where(
        DailyTodo.id == todo_id, DailyTodo.user_id == current_user.id
    )
    target = (await db.execute(owned_todo_query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="待办不存在")

    is_from_today = target.todo_date == date.today().isoformat()
    if not is_from_today:
        raise HTTPException(status_code=403, detail="历史待办不可删除")

    await db.delete(target)
    await db.commit()
@router.post("/ai-generate", response_model=TodoListOut)
async def ai_generate_todos(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate today's AI todos once and return them.

    If any AI-sourced todo (``AI_KANBAN`` or ``AI_CHAT``) already exists for
    today, nothing is generated and all of today's todos are returned
    instead. Otherwise ``generate_daily_todos`` is invoked and its result is
    returned.
    """
    today = date.today().isoformat()

    # Idempotency check: has the AI already produced todos for today?
    ai_count_stmt = select(func.count()).select_from(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == today,
        DailyTodo.source.in_([TodoSource.AI_KANBAN, TodoSource.AI_CHAT]),
    )
    existing_ai_count = (await db.execute(ai_count_stmt)).scalar()

    if not existing_ai_count > 0:
        # Nothing generated yet: run the generator.
        generated = await generate_daily_todos(current_user.id, db)
        return TodoListOut(items=generated, total=len(generated), page=1, page_size=50)

    # Already generated: hand back today's existing records.
    todays_stmt = select(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == today,
    ).order_by(DailyTodo.created_at.desc())
    existing = (await db.execute(todays_stmt)).scalars().all()
    return TodoListOut(items=existing, total=len(existing), page=1, page_size=50)
@router.get("/summary", response_model=TodoSummaryOut)
async def get_summary(
    date_str: str = Query(default=None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return total, completed and pending todo counts for a given day.

    ``date_str`` defaults to today's date when omitted.
    """
    target_date = date_str or date.today().isoformat()
    day_stmt = select(DailyTodo).where(
        DailyTodo.user_id == current_user.id,
        DailyTodo.todo_date == target_date,
    )
    day_todos = (await db.execute(day_stmt)).scalars().all()

    total = len(day_todos)
    done = len([item for item in day_todos if item.is_completed])
    return TodoSummaryOut(date=target_date, total=total, completed=done, pending=total - done)

View File

@@ -0,0 +1,2 @@
# Schemas package - import directly from submodules
# e.g.: from app.schemas.auth import UserCreate

View File

@@ -0,0 +1,55 @@
from pydantic import BaseModel
class AgentCreate(BaseModel):
    """Input fields for creating an agent; only ``description`` is optional."""
    name: str
    role: str
    description: str | None = None
    system_prompt: str
class AgentOut(BaseModel):
    """Agent summary for output; does not include the system prompt."""
    id: str
    name: str
    role: str
    description: str | None
    is_active: bool
    is_default: bool
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class AgentMessageOut(BaseModel):
    """Output representation of one agent message within a conversation."""
    id: str
    agent_id: str
    conversation_id: str
    role: str
    content: str
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class AgentStats(BaseModel):
    """Per-agent runtime statistics."""
    agent_id: str
    call_count: int
    current_task: str | None
    status: str  # active | idle | disabled
class AgentConfigUpdate(BaseModel):
    """Partial update of an agent's configuration; every field is optional."""
    name: str | None = None
    description: str | None = None
    system_prompt: str | None = None
    enabled: bool | None = None
class AgentConfigOut(BaseModel):
    """Full agent configuration for output, including the system prompt."""
    id: str
    name: str
    role: str
    description: str | None
    system_prompt: str
    enabled: bool
    is_active: bool
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel, EmailStr
class UserCreate(BaseModel):
    """Input fields for creating a user; ``email`` is validated as an address."""
    email: EmailStr
    password: str
    full_name: str | None = None
class UserOut(BaseModel):
    """User fields exposed in output; carries no password field."""
    id: str
    email: str
    full_name: str | None
    is_active: bool
    is_superuser: bool
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class Token(BaseModel):
    """Access token plus its type, which defaults to ``"bearer"``."""
    access_token: str
    token_type: str = "bearer"
class TokenPayload(BaseModel):
    """Token payload holding only the ``sub`` (subject) claim."""
    sub: str

View File

@@ -0,0 +1,45 @@
from pydantic import BaseModel
from datetime import datetime
class MessageCreate(BaseModel):
    """Input for creating a message: just its text content."""
    content: str
class MessageOut(BaseModel):
    """Output representation of a stored message."""
    id: str
    role: str
    content: str
    model: str | None
    tokens_used: int | None
    created_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class ConversationCreate(BaseModel):
    """Input for creating a conversation; the title is optional."""
    title: str | None = None
class ConversationOut(BaseModel):
    """Output representation of a conversation with its message count."""
    id: str
    title: str | None
    message_count: int
    created_at: datetime
    updated_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class ChatRequest(BaseModel):
    """Chat input: the message plus optional conversation, agent and file ids."""
    message: str
    conversation_id: str | None = None
    agent_id: str | None = None
    file_ids: list[str] = [] # newly added field
class ChatResponse(BaseModel):
    """Chat output: the reply content with its conversation, message and agent."""
    conversation_id: str
    message_id: str
    content: str
    agent_name: str

View File

@@ -0,0 +1,40 @@
from pydantic import BaseModel
from datetime import datetime
class DocumentOut(BaseModel):
    """Output representation of a document and its indexing state."""
    id: str
    title: str
    filename: str
    file_type: str
    file_size: int
    summary: str | None
    chunk_count: int
    is_indexed: bool
    created_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class DocumentChunkOut(BaseModel):
    """Output representation of one chunk of a document."""
    id: str
    chunk_index: int
    content: str
    metadata_: str | None
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class SearchRequest(BaseModel):
    """Search input: query text, result count (default 5) and a user id."""
    query: str
    top_k: int = 5
    # NOTE(review): user_id comes from the payload - verify that callers
    # check it against the authenticated user.
    user_id: str
class SearchResult(BaseModel):
    """One search hit: a chunk, its parent document and a score."""
    chunk_id: str
    document_id: str
    document_title: str
    content: str
    score: float
    metadata_: str | None

View File

@@ -0,0 +1,39 @@
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List
from datetime import datetime
class FolderCreate(BaseModel):
    """Payload for creating a folder.

    ``name`` must be 1-255 characters long and must not contain any of the
    characters ``/ \\ * ? :``.
    """
    name: str = Field(..., min_length=1, max_length=255)
    parent_id: Optional[str] = None

    @field_validator('name')
    @classmethod
    def validate_name(cls, v):
        """Reject names containing a forbidden path-like character."""
        forbidden = '/\\*?:'
        if any(ch in v for ch in forbidden):
            raise ValueError(f'Folder name cannot contain: {forbidden}')
        return v
class FolderUpdate(BaseModel):
    """Payload for renaming a folder.

    ``name`` follows the same rules as in ``FolderCreate``: 1-255 characters
    and none of the characters ``/ \\ * ? :``, so a rename cannot introduce
    a name that creation would have rejected.
    """
    name: str = Field(..., min_length=1, max_length=255)

    @field_validator('name')
    @classmethod
    def validate_name(cls, v):
        """Reject names containing a forbidden path-like character."""
        forbidden = '/\\*?:'
        if any(ch in v for ch in forbidden):
            raise ValueError(f'Folder name cannot contain: {forbidden}')
        return v
class FolderOut(BaseModel):
    """Output representation of a single folder."""
    id: str
    name: str
    parent_id: Optional[str]
    created_at: datetime
    updated_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class FolderTreeOut(BaseModel):
    """Folder node of a tree; ``children`` nests further ``FolderTreeOut`` items."""
    id: str
    name: str
    parent_id: Optional[str]
    children: List["FolderTreeOut"] = []
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}


# The self-referencing model needs its forward reference resolved.
FolderTreeOut.model_rebuild()

View File

@@ -0,0 +1,37 @@
from datetime import datetime

from pydantic import BaseModel
class ForumPostCreate(BaseModel):
    """Payload for creating a forum post; ``category`` defaults to "discussion"."""
    title: str
    content: str
    category: str | None = "discussion"
class ForumPostOut(BaseModel):
    """Output representation of a forum post."""
    id: str
    user_id: str
    title: str
    content: str
    category: str | None
    is_executed: bool
    execution_result: str | None
    reply_count: int
    # Typed as datetime like the other *Out schemas: a ``str`` field does not
    # accept a datetime attribute, while datetime still parses ISO strings
    # and is serialised to an ISO string in JSON.
    created_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class ForumReplyCreate(BaseModel):
    """Payload for creating a reply: just its text content."""
    content: str
class ForumReplyOut(BaseModel):
    """Output representation of a reply, written by a user or an agent."""
    id: str
    post_id: str
    user_id: str | None
    agent_id: str | None
    content: str
    is_ai_reply: bool
    # Typed as datetime like the other *Out schemas: a ``str`` field does not
    # accept a datetime attribute, while datetime still parses ISO strings
    # and is serialised to an ISO string in JSON.
    created_at: datetime
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
from typing import Literal
class TagProperties(BaseModel):
    """Properties of a hierarchical tag addressed by a slash-separated path."""
    tag_path: str = Field(..., description="完整标签路径,如 '编程语言/Python/异步'")
    short_name: str = Field(..., description="显示名称,如 '异步'")
    # Depth in the hierarchy; must be at least 1.
    level: int = Field(..., ge=1, description="层级深度1为顶级")
    parent_path: str | None = Field(None, description="父路径,如 '编程语言/Python'")
    description: str | None = Field(None, description="AI生成的标签描述")
    color: str | None = Field(None, description="标签颜色,如 '#FF5733'")
class KGNodeOut(BaseModel):
    """Output representation of a knowledge-graph node.

    For nodes whose ``entity_type`` is ``"tag"``, ``tag_properties`` is
    derived from ``properties_`` after validation when those properties form
    a valid ``TagProperties``.
    """
    id: str
    name: str
    entity_type: str
    description: str | None
    properties_: dict | None
    importance: float
    # NOTE(review): declared as str; if the source attribute is a datetime
    # this field will not validate - confirm against the KGNode model.
    created_at: str
    # Populated only for tag nodes.
    tag_properties: TagProperties | None = None
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}

    def model_post_init(self, __context):
        """Derive ``tag_properties`` from ``properties_`` for tag nodes."""
        if self.entity_type != "tag" or not self.properties_:
            return
        try:
            self.tag_properties = TagProperties(**self.properties_)
        except (TypeError, ValueError):
            # Incomplete or malformed tag properties (pydantic's
            # ValidationError is a ValueError) must not make the whole node
            # unserialisable; leave tag_properties as it was.
            pass
class KGEdgeOut(BaseModel):
    """Output representation of a knowledge-graph edge between two nodes."""
    id: str
    source_id: str
    target_id: str
    relation_type: str
    weight: float
    properties_: dict | None
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
class GraphOut(BaseModel):
    """A graph as a list of nodes and a list of edges."""
    nodes: list[KGNodeOut]
    edges: list[KGEdgeOut]
class KGBuildRequest(BaseModel):
    """Input for building the knowledge graph of a user."""
    user_id: str
    document_ids: list[str] | None = None # None = full rebuild
class TagExtractRequest(BaseModel):
    """Input for tag extraction: content of at least 10 characters and a user id."""
    content: str = Field(..., min_length=10)
    user_id: str
class TagExtractResponse(BaseModel):
    """Extracted tags together with their count."""
    tags: list[TagProperties]
    tag_count: int
class RelatedContentRequest(BaseModel):
    """Input for finding content by tags; ``limit`` defaults to 10."""
    tag_ids: list[str]
    user_id: str
    limit: int = 10

View File

@@ -0,0 +1,58 @@
from pydantic import BaseModel, Field
from typing import Optional, Literal, List
from app.schemas.auth import UserOut
# LLM provider and model-type literals
LLMProviderType = Literal["openai", "claude", "ollama", "deepseek", "custom"]
LLMType = Literal["chat", "vlm", "embedding", "rerank"]
# Configuration of a single model
class LLMModelConfig(BaseModel):
    """Connection settings for one model; every field has a default."""
    name: str = "" # model name/alias used to identify the entry
    provider: LLMProviderType = "openai"
    model: str = ""
    base_url: str = ""
    api_key: str = ""
    enabled: bool = True # whether the model is enabled
# LLM configuration input - each model type accepts several models
class LLMConfigIn(BaseModel):
    """LLM configuration payload: a list of model configs per model type."""
    chat: Optional[List[LLMModelConfig]] = []
    vlm: Optional[List[LLMModelConfig]] = []
    embedding: Optional[List[LLMModelConfig]] = []
    rerank: Optional[List[LLMModelConfig]] = []
# Scheduled-task configuration
class SchedulerConfigIn(BaseModel):
    """Scheduler configuration payload; times are strings such as "08:00"."""
    daily_plan_time: Optional[str] = "08:00"
    forum_scan_interval_minutes: Optional[int] = 30
    todo_ai_generate_time: Optional[str] = "08:00"
    enabled: Optional[bool] = True
# User profile update
class ProfileUpdateIn(BaseModel):
    """Profile update payload; all fields are optional.

    ``full_name`` must be 2-50 characters and ``password`` at least 8
    characters when provided.
    """
    full_name: Optional[str] = Field(None, min_length=2, max_length=50)
    password: Optional[str] = Field(None, min_length=8)
    current_password: Optional[str] = None
# Complete settings output
class SettingsOut(BaseModel):
    """Full settings response: profile plus optional LLM and scheduler config."""
    profile: UserOut
    llm_config: Optional[dict] = None
    scheduler_config: Optional[dict] = None
    # Allows building the model from attribute-bearing objects (e.g. ORM rows).
    model_config = {"from_attributes": True}
# Request for testing an LLM connection
class LLMTestIn(BaseModel):
    """Connection details of the model to test; all fields are required."""
    type: LLMType
    provider: LLMProviderType
    model: str
    base_url: str
    api_key: str

View File

@@ -0,0 +1,82 @@
from pydantic import BaseModel
from typing import Optional
# ===== System Health =====
class SystemHealth(BaseModel):
    """System health metrics: uptime, CPU, memory, disk and recent active users."""
    uptime_seconds: int
    cpu_percent: float
    memory_used_mb: float
    memory_total_mb: float
    memory_percent: float
    disk_used_gb: float
    disk_total_gb: float
    disk_percent: float
    active_users_24h: int
# ===== Daily Stats Base =====
class DailyStatItem(BaseModel):
    """One data point of a daily series: a date string and a count."""
    date: str
    count: int
class DailyTokenStatItem(BaseModel):
    """One data point of a daily token series: input and output token counts."""
    date: str
    input_tokens: int
    output_tokens: int
# ===== Conversation Stats =====
class ConversationStats(BaseModel):
    """Daily conversation, message and token series plus overall totals."""
    daily_conversations: list[DailyStatItem]
    daily_messages: list[DailyStatItem]
    # Both token series use the same item type, which carries input and
    # output counts together.
    daily_input_tokens: list[DailyTokenStatItem]
    daily_output_tokens: list[DailyTokenStatItem]
    totals: dict
# ===== Knowledge Stats =====
class KnowledgeStats(BaseModel):
    """Daily knowledge-base series (tags, documents, queries, relations) plus totals."""
    daily_new_tags: list[DailyStatItem]
    daily_documents: list[DailyStatItem]
    daily_knowledge_queries: list[DailyStatItem]
    daily_tag_relations: list[DailyStatItem]
    totals: dict
# ===== Kanban Stats =====
class KanbanStats(BaseModel):
    """Daily kanban series, the current pending-task count and totals."""
    daily_new_tasks: list[DailyStatItem]
    daily_completed_tasks: list[DailyStatItem]
    daily_completion_rate: list[DailyStatItem]
    current_pending_tasks: int
    totals: dict
# ===== Community Stats =====
class CommunityStats(BaseModel):
    """Daily community series (posts, replies, AI executions, agent calls) plus totals."""
    daily_posts: list[DailyStatItem]
    daily_replies: list[DailyStatItem]
    daily_ai_executions: list[DailyStatItem]
    daily_agent_calls: list[DailyStatItem]
    totals: dict
# ===== Personal Insights =====
class HourlyActivity(BaseModel):
    """Activity count for one hour."""
    hour: int
    count: int
class TagUsage(BaseModel):
    """Usage count of one tag, identified by its path."""
    tag_path: str
    usage_count: int
class PersonalInsights(BaseModel):
    """Per-user insights: hourly activity, top tags and monthly token figures."""
    hourly_activity: list[HourlyActivity]
    top_tags: list[TagUsage]
    token_trend_percent: float
    this_month_tokens: int
    last_month_tokens: int

View File

@@ -0,0 +1,39 @@
from pydantic import BaseModel
from datetime import datetime
from app.models.task import TaskStatus, TaskPriority
# Payload for creating a task; only `title` is required.
class TaskCreate(BaseModel):
    title: str
    description: str | None = None
    priority: TaskPriority = TaskPriority.MEDIUM  # defaults to medium priority
    due_date: datetime | None = None
    tags: list[str] | None = None
# Partial-update payload for a task; every field is optional and defaults to None.
# NOTE(review): None presumably means "leave unchanged" — confirm against the handler.
class TaskUpdate(BaseModel):
    title: str | None = None
    description: str | None = None
    status: TaskStatus | None = None
    priority: TaskPriority | None = None
    due_date: datetime | None = None
    tags: list[str] | None = None
# Response model for a task; can be built from attribute-bearing objects
# (`from_attributes=True`), e.g. ORM rows.
class TaskOut(BaseModel):
    id: str
    title: str
    description: str | None
    status: TaskStatus
    priority: TaskPriority
    due_date: datetime | None
    completed_at: datetime | None
    # NOTE(review): TaskCreate/TaskUpdate accept list[str] but the output is a
    # single string — presumably the stored value is serialized; confirm.
    tags: str | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# Request body naming the user a daily plan is requested for.
class DailyPlanRequest(BaseModel):
    user_id: str

View File

@@ -0,0 +1,40 @@
from pydantic import BaseModel
from datetime import datetime
from app.models.todo import TodoSource
# Payload for creating a todo item; only a title is supplied by the client.
class TodoCreate(BaseModel):
    title: str
# Partial-update payload for a todo item; both fields are optional.
# NOTE(review): None presumably means "leave unchanged" — confirm against the handler.
class TodoUpdate(BaseModel):
    title: str | None = None
    is_completed: bool | None = None
# Response model for a todo item; can be built from attribute-bearing objects
# (`from_attributes=True`), e.g. ORM rows.
class TodoOut(BaseModel):
    id: str
    title: str
    is_completed: bool
    source: TodoSource  # project-defined type from app.models.todo
    source_detail: str | None
    todo_date: str  # NOTE(review): string format is not fixed by this schema — confirm
    completed_at: datetime | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# Paginated list of todo items.
class TodoListOut(BaseModel):
    items: list[TodoOut]
    total: int
    page: int  # NOTE(review): whether pages are 0- or 1-based is decided by the endpoint — confirm
    page_size: int
# Per-date todo counters.
class TodoSummaryOut(BaseModel):
    date: str  # NOTE(review): string format is not fixed by this schema — confirm
    total: int
    completed: int
    pending: int

View File

@@ -0,0 +1,2 @@
# Services - import specific classes directly when needed
# e.g.: from app.services.agent_service import AgentService

View File

@@ -0,0 +1,261 @@
"""
Jarvis Agent 服务层
负责 LangGraph Agent 的调用、流式输出、对话历史管理
"""
import json
import uuid
from datetime import datetime
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_core.messages import HumanMessage, AIMessage
from app.models.conversation import Conversation, Message
from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.services import memory_service
class AgentService:
    """Conversation service that drives the LangGraph agent.

    Persists user and assistant messages, loads the memory context, runs the
    agent graph (streaming via :meth:`chat` or blocking via
    :meth:`chat_simple`) and schedules a best-effort auto-summarisation once a
    turn has been stored.
    """

    # The event loop keeps only weak references to tasks, so fire-and-forget
    # tasks are held here until they finish.
    _background_tasks: set = set()

    def __init__(self, db: AsyncSession):
        """Bind the service to a database session."""
        self.db = db

    @staticmethod
    def _logger():
        """Return this module's logger (imported lazily; the file header has no logging import)."""
        import logging

        return logging.getLogger(__name__)

    @staticmethod
    def _message_text(output) -> str:
        """Extract plain text from a chat-model output.

        Accepts a dict with a ``content`` key or any object with a ``content``
        attribute. ``content`` may be a string or a list of blocks (strings or
        dicts with a string ``text`` entry); anything else yields ``""``.
        """
        if isinstance(output, dict):
            content = output.get("content", "")
        else:
            content = getattr(output, "content", "")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return ""

    @staticmethod
    def _build_initial_state(user_id: str, conversation_id: str, content: str, memory_ctx) -> dict:
        """Build the initial LangGraph state for one user turn."""
        return {
            "messages": [HumanMessage(content=content)],  # type: ignore[arg-type]
            "user_id": user_id,
            "conversation_id": conversation_id,
            "current_agent": "master",
            "active_agents": ["master"],
            "pending_tasks": [],
            "completed_tasks": [],
            "tool_calls": [],
            "last_tool_result": None,
            "knowledge_context": None,
            "graph_context": None,
            "plan": None,
            "plan_steps": [],
            "analysis_report": None,
            "final_response": None,
            "should_respond": True,
            "memory_context": memory_ctx,
        }

    async def _get_or_create_conversation(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None,
    ) -> str:
        """Return the id of the caller's conversation, creating one if needed.

        The lookup is scoped to ``user_id`` so a caller cannot append to a
        conversation owned by someone else; an unknown or foreign id falls
        through to creating a new conversation titled with the first 50
        characters of ``message``.
        """
        conv = None
        if conversation_id:
            result = await self.db.execute(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id,
                )
            )
            conv = result.scalar_one_or_none()
        if conv is None:
            conv = Conversation(user_id=user_id, title=message[:50])
            self.db.add(conv)
            await self.db.commit()
            await self.db.refresh(conv)
        return conv.id

    async def _store_message(self, **fields) -> Message:
        """Insert a ``Message`` row, commit, and return the refreshed instance."""
        msg = Message(**fields)
        self.db.add(msg)
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def _persist_assistant_content(self, message_id: str, content: str) -> None:
        """Write the streamed text into the pre-created assistant message (best effort)."""
        if not content:
            return
        try:
            result = await self.db.execute(select(Message).where(Message.id == message_id))
            msg = result.scalar_one_or_none()
            if msg is not None:
                msg.content = content
                await self.db.commit()
        except Exception:
            # Best effort: the text has already been streamed to the client.
            self._logger().exception("failed to persist assistant message %s", message_id)

    def _schedule_auto_summarize(self, user_id: str, conversation_id: str) -> None:
        """Start the auto-summarisation in the background without blocking the response.

        NOTE(review): the task reuses ``self.db``; if that session is closed
        when the request ends, the task should open its own session — confirm.
        """
        import asyncio

        def _finished(done) -> None:
            self._background_tasks.discard(done)
            if not done.cancelled() and done.exception() is not None:
                self._logger().warning("auto-summarize failed: %r", done.exception())

        try:
            loop = asyncio.get_running_loop()
            task = loop.create_task(
                memory_service.try_auto_summarize(self.db, user_id, conversation_id)
            )
        except Exception:
            self._logger().exception("could not schedule auto-summarize")
            return
        self._background_tasks.add(task)
        task.add_done_callback(_finished)

    async def chat(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
    ) -> tuple[str, str, AsyncGenerator[str, None]]:
        """Handle a chat request with a streamed response.

        Returns:
            (conversation_id, message_id, response_stream). ``message_id`` is
            the pre-created assistant message, whose content is filled in once
            the stream finishes.
        """
        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        # Store the user message.
        await self._store_message(
            conversation_id=conversation_id,
            role="user",
            content=message,
        )

        # Pre-create the assistant message; its content is written after streaming.
        assistant_msg = await self._store_message(
            conversation_id=conversation_id,
            role="assistant",
            content="",
            model="jarvis",
        )
        # Captured now so later commits cannot expire the attribute.
        assistant_msg_id = assistant_msg.id

        # Load the memory context.
        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        async def run_agent():
            # Bound before the try so the finally block can always read it.
            collected = ""
            set_current_user(user_id)
            try:
                graph = get_agent_graph()
                state = self._build_initial_state(user_id, conversation_id, message, memory_ctx)
                async for event in graph.astream_events(state, version="v2"):
                    kind = event.get("event")
                    if kind == "on_chat_model_end":
                        text = self._message_text(event.get("data", {}).get("output"))
                        # NOTE(review): the slice treats each model output as
                        # cumulative; with several independent model calls the
                        # prefix of later outputs is dropped — confirm intent.
                        delta = text[len(collected):]
                        if delta:
                            collected += delta
                            yield delta
                    elif kind == "on_tool_end":
                        name = event.get("name", "")
                        yield f"\n[工具执行: {name}]\n"
            except Exception as e:
                yield f"\n执行出错: {str(e)}"
            finally:
                clear_current_user()
                # Write first, then schedule: the background task shares this
                # session, and a session must not be used concurrently.
                await self._persist_assistant_content(assistant_msg_id, collected)
                self._schedule_auto_summarize(user_id, conversation_id)

        return conversation_id, assistant_msg_id, run_agent()

    async def chat_simple(
        self,
        user_id: str,
        message: str,
        conversation_id: str | None = None,
        file_ids: list[str] | None = None,
    ) -> tuple[str, str, str]:
        """Handle a chat request without streaming.

        Returns:
            (conversation_id, message_id, response_content)
        """
        conversation_id = await self._get_or_create_conversation(
            user_id, message, conversation_id
        )

        # Inline the text of any attached files as extra context for the model.
        file_parts: list[str] = []
        if file_ids:
            from app.services.document_service import DocumentService

            doc_svc = DocumentService(self.db)
            for file_id in file_ids:
                content = await doc_svc.get_document_content(user_id, file_id)
                if content:
                    file_parts.append(f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]")
        file_context = "".join(file_parts)
        full_message = f"{message}\n{file_context}" if file_context else message

        # Store the user message as typed; file text is not persisted with it.
        await self._store_message(
            conversation_id=conversation_id,
            role="user",
            content=message,
            attachments=[{"file_ids": file_ids}] if file_ids else None,
        )

        # Load the memory context.
        memory_ctx = await memory_service.build_memory_context(
            self.db, user_id, conversation_id, message
        )

        # Run the agent; the user context is always cleared, even if building
        # the graph fails.
        set_current_user(user_id)
        try:
            graph = get_agent_graph()
            state = self._build_initial_state(
                user_id, conversation_id, full_message, memory_ctx
            )
            result_state = await graph.ainvoke(state)
            # `or` also covers a final_response that is present but None/empty.
            response_content = result_state.get("final_response") or "抱歉,我无法处理这个请求。"
        except Exception as e:
            response_content = f"抱歉,发生错误: {str(e)}"
        finally:
            clear_current_user()

        # Store the reply before scheduling the summary, which shares this session.
        assistant_msg = await self._store_message(
            conversation_id=conversation_id,
            role="assistant",
            content=response_content,
            model="jarvis",
        )
        assistant_msg_id = assistant_msg.id
        self._schedule_auto_summarize(user_id, conversation_id)
        return conversation_id, assistant_msg_id, response_content

View File

@@ -0,0 +1,29 @@
from datetime import datetime, timedelta
from passlib.context import CryptContext
from jose import jwt, JWTError
from app.config import settings
# Shared passlib context: bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches ``hashed_password``.

    Delegates to the module-level passlib ``pwd_context``.
    """
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module-level passlib ``pwd_context``."""
    return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Token lifetime. When omitted (or falsy), the lifetime
            is ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The token encoded with ``settings.SECRET_KEY`` and ``settings.ALGORITHM``.
    """
    # Function-scope import: the module header only imports datetime/timedelta.
    from datetime import timezone

    to_encode = data.copy()
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # value; an aware UTC timestamp denotes the same instant unambiguously.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str) -> dict | None:
    """Decode and verify ``token``; return its claims, or ``None`` on a ``JWTError``."""
    accepted_algorithms = [settings.ALGORITHM]
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=accepted_algorithms)
    except JWTError:
        return None

View File

@@ -0,0 +1,256 @@
"""
文档服务 - 上传、解析、分块、存储
支持多种文档格式 + LlamaIndex 智能分块
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import UploadFile
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
import os
import aiofiles
import uuid
# Lower-case file extensions (with leading dot) accepted for upload.
# NOTE(review): ".doc" is routed to the python-docx reader together with ".docx";
# confirm that legacy .doc files can actually be parsed that way.
ALLOWED_EXTENSIONS = {".pdf", ".md", ".txt", ".docx", ".doc"}
class DocumentService:
    """Document upload, text extraction, chunking and storage."""

    # Sentence terminator used when splitting an oversized section.
    # NOTE(review): the separator was an empty string, which makes str.split
    # raise ValueError; the ideographic full stop is assumed — confirm.
    _SENTENCE_DELIMITER = "。"

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to a session and, optionally, a user.

        ``user_id`` is only read by :meth:`_get_folder_path`.
        """
        self.db = db
        self.user_id = user_id

    @staticmethod
    def _resolve_chunk_size(chunk_size: int | None) -> int:
        """Return ``chunk_size``, defaulting to ``settings.CHUNK_SIZE``."""
        return settings.CHUNK_SIZE if chunk_size is None else chunk_size

    async def upload_document(self, user_id: str, file: UploadFile, folder_id: str | None = None) -> Document:
        """Store an uploaded file, extract its text and persist document + chunks.

        Raises:
            ValueError: unsupported extension, or the file exceeds
                ``settings.MAX_UPLOAD_SIZE``.
        """
        filename = file.filename or ""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise ValueError(f"不支持的文件类型: {ext}")

        os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
        file_id = str(uuid.uuid4())
        file_path = os.path.join(settings.UPLOAD_DIR, f"{file_id}{ext}")

        content = await file.read()
        file_size = len(content)
        if file_size > settings.MAX_UPLOAD_SIZE:
            raise ValueError(f"文件大小超过限制: {settings.MAX_UPLOAD_SIZE // 1024 // 1024}MB")

        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        try:
            text_content = await self._extract_text(file_path, ext)
            doc = Document(
                user_id=user_id,
                title=filename.rsplit('.', 1)[0],
                filename=filename,
                file_type=ext[1:],
                file_size=file_size,
                file_path=file_path,
                summary=text_content[:500],
                folder_id=folder_id,
            )
            self.db.add(doc)
            # flush assigns doc.id without committing, so the document and its
            # chunks are written in a single transaction.
            await self.db.flush()

            chunks = self._chunk_text(text_content)
            self.db.add_all(
                [
                    DocumentChunk(document_id=doc.id, chunk_index=i, content=chunk_text)
                    for i, chunk_text in enumerate(chunks)
                ]
            )
            doc.chunk_count = len(chunks)
            await self.db.commit()
        except Exception:
            # Leave neither a half-written row nor an orphaned file behind.
            await self.db.rollback()
            if os.path.exists(file_path):
                os.remove(file_path)
            raise

        # Reload so the returned instance is not expired by the commit.
        await self.db.refresh(doc)
        return doc

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the "/"-joined path of a folder owned by ``self.user_id``.

        Returns ``None`` when the folder is not among that user's folders.
        """
        result = await self.db.execute(
            select(Folder).where(Folder.user_id == self.user_id)
        )
        folder_map = {f.id: f for f in result.scalars().all()}

        names: list[str] = []
        seen: set = set()
        current_id = folder_id
        # `seen` stops the walk if parent links ever form a cycle.
        while current_id and current_id not in seen:
            seen.add(current_id)
            folder = folder_map.get(current_id)
            if folder is None:
                break
            names.append(folder.name)
            current_id = folder.parent_id
        if not names:
            return None
        return "/" + "/".join(reversed(names))

    async def delete_document(self, user_id: str, document_id: str):
        """Delete a user's document row and its stored file.

        Raises:
            ValueError: no such document for this user.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            raise ValueError("文档不存在")

        file_path = doc.file_path
        await self.db.delete(doc)
        await self.db.commit()
        # Remove the file only after the row is gone, so a failed DB delete
        # never leaves a row pointing at a missing file.
        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass

    async def _extract_text(self, file_path: str, ext: str) -> str:
        """Extract plain text from a stored file according to its extension.

        Returns a bracketed placeholder when the optional parser is not
        installed or the extension is not handled.
        """
        if ext == ".pdf":
            try:
                import pymupdf
            except ImportError:
                return "[PDF 内容需要安装 pymupdf: uv pip install pymupdf]"
            pdf = pymupdf.open(file_path)
            try:
                return "".join(page.get_text() for page in pdf)
            finally:
                # Close the handle even if a page fails to render.
                pdf.close()
        elif ext in (".md", ".txt"):
            # errors="replace" keeps files that are not valid UTF-8 readable
            # instead of aborting the upload.
            async with aiofiles.open(file_path, "r", encoding="utf-8", errors="replace") as f:
                return await f.read()
        elif ext in (".docx", ".doc"):
            try:
                from docx import Document as DocxDocument
            except ImportError:
                return "[Word 内容需要安装 python-docx: uv pip install python-docx]"
            # NOTE(review): legacy .doc files are handed to the same reader — confirm support.
            word_doc = DocxDocument(file_path)
            return "\n".join([p.text for p in word_doc.paragraphs])
        return "[暂不支持此格式]"

    def _chunk_text(self, text: str, chunk_size: int | None = None) -> list[str]:
        """Split ``text`` into chunks.

        1. If Markdown headings (H1-H3) exist, each heading starts a section;
           text before the first heading is chunked by paragraphs.
        2. Sections longer than the chunk size are split by sentences.
        3. Without headings, the text is chunked by paragraphs.

        Args:
            chunk_size: Size limit; defaults to ``settings.CHUNK_SIZE``.

        Returns:
            Non-empty stripped chunks, or ``[text[:chunk_size]]`` when nothing
            non-empty was produced.
        """
        import re

        limit = self._resolve_chunk_size(chunk_size)
        header_pattern = re.compile(r"^(#{1,3})\s+(.+)$", re.MULTILINE)
        headers = list(header_pattern.finditer(text))

        chunks: list[str] = []
        if headers:
            # Content before the first heading used to be dropped.
            preamble = text[: headers[0].start()]
            if preamble.strip():
                chunks.extend(self._chunk_by_paragraphs(preamble, limit))
            section_ends = [m.start() for m in headers[1:]] + [len(text)]
            for match, end in zip(headers, section_ends):
                section = text[match.start():end].strip()
                if len(section) > limit:
                    chunks.extend(self._split_large_chunk(section, match.group(2), limit))
                elif section:
                    chunks.append(section)
        else:
            chunks = self._chunk_by_paragraphs(text, limit)

        chunks = [c.strip() for c in chunks if c.strip()]
        return chunks if chunks else [text[:limit]]

    def _chunk_by_paragraphs(self, text: str, chunk_size: int | None = None) -> list[str]:
        """Greedily pack blank-line-separated paragraphs into chunks.

        A paragraph that alone reaches the limit becomes its own chunk.
        """
        limit = self._resolve_chunk_size(chunk_size)
        chunks: list[str] = []
        current = ""
        for para in text.split("\n\n"):
            para = para.strip()
            if not para:
                continue
            if len(current) + len(para) < limit:
                current += "\n\n" + para
            else:
                if current:
                    chunks.append(current.strip())
                current = para
        if current.strip():
            chunks.append(current.strip())
        return chunks

    def _split_large_chunk(self, text: str, title: str, chunk_size: int | None = None) -> list[str]:
        """Split an oversized section into sentence-packed chunks.

        Every chunk starts with ``title`` followed by a blank line, so each
        piece keeps its heading.
        """
        limit = self._resolve_chunk_size(chunk_size)
        delimiter = self._SENTENCE_DELIMITER
        prefix = title + "\n\n"

        chunks: list[str] = []
        current = prefix
        for sentence in text.split(delimiter):
            sentence = sentence.strip()
            if not sentence:
                continue
            # split() removed the terminator; put it back.
            full_sentence = sentence + delimiter
            if len(current) + len(full_sentence) < limit:
                current += full_sentence + " "
            else:
                if current.strip():
                    chunks.append(current.strip())
                current = prefix + full_sentence + " "
        if current.strip():
            chunks.append(current.strip())
        return chunks

    async def get_document_chunks(self, document_id: str) -> list[DocumentChunk]:
        """Return a document's chunks ordered by ``chunk_index``."""
        result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        return list(result.scalars().all())

    async def get_document_content(self, user_id: str, document_id: str) -> str | None:
        """Return the text of a user's document for use as chat context.

        Returns ``None`` when the document or its file does not exist. Only
        ``txt`` and ``md`` files are read; other types yield a bracketed
        placeholder containing the filename.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return None

        file_path = doc.file_path
        if not os.path.exists(file_path):
            return None

        ext = doc.filename.split('.')[-1].lower()
        try:
            if ext in ("txt", "md"):
                # Async read so the event loop is not blocked by file I/O.
                async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
                    return await f.read()
            if ext == "pdf":
                # Placeholder only; PDF text is not extracted on this path.
                return f"[PDF文档] {doc.filename}"
            return f"[文档] {doc.filename}"
        except (OSError, ValueError):
            # Unreadable or undecodable file: fall back to the placeholder.
            return f"[文档] {doc.filename}"

View File

@@ -0,0 +1,342 @@
"""
知识图谱服务 - 实体识别、关系抽取、图谱查询
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.document import Document, DocumentChunk
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage
import json
import logging
logger = logging.getLogger(__name__)
# Prompt template for entity/relation extraction. It is filled with
# str.format(text=...), so literal JSON braces are doubled ({{ }}).
ENTITY_EXTRACTION_PROMPT = """从以下文本中提取实体和关系,返回 JSON 格式。
实体类型:
- person人物人名、角色
- concept概念抽象概念、理论、方法
- topic主题话题、领域
- task任务要做的事情
- event事件发生的事件
- document文档文件、资料
关系类型:
- related_to相关于
- part_of隶属于
- caused_by由...导致)
- depends_on取决于
- contains包含
- located_in位于
- works_on从事
要求:
1. 识别文本中所有有意义的实体不超过10个
2. 识别实体之间的关系(每个实体至少一条关系)
3. 每个实体要有 name、type、description1-2句话
4. 关系要有 source、target、relation_type
文本内容:
{text}
请只返回 JSON不要有其他内容
{{
"entities": [
{{"name": "实体名", "type": "类型", "description": "描述"}}
],
"relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型"}}
]
}}
"""
# Prompt template for inferring relations between known entities. It has
# {question} and {entities} placeholders; literal JSON braces are doubled.
# NOTE(review): no use of this template is visible in this part of the file.
RELATION_INFERENCE_PROMPT = """根据以下实体列表和用户的问题,推断相关实体之间的关系。
用户问题:{question}
已知实体:
{entities}
请推断这些实体之间的隐含关系,返回 JSON
{{
"inferred_relations": [
{{"source": "实体A", "target": "实体B", "relation_type": "关系类型", "confidence": 0.9}}
]
}}
关系类型related_to / part_of / caused_by / depends_on / contains / works_on / located_in
confidence: 0.0-1.0,表示推断置信度
"""
class GraphService:
    """Knowledge-graph service: LLM-based extraction, storage and queries."""

    def __init__(self, db: AsyncSession):
        """Bind the service to a session and the configured LLM."""
        self.db = db
        self.llm = get_llm()

    @staticmethod
    def _parse_llm_json(raw) -> dict | None:
        """Parse a JSON object out of an LLM reply.

        Takes the span from the first "{" to the last "}", so Markdown code
        fences or surrounding prose do not break parsing. Returns ``None``
        when no JSON object can be decoded.
        """
        if not isinstance(raw, str):
            return None
        start = raw.find("{")
        end = raw.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            data = json.loads(raw[start:end + 1])
        except json.JSONDecodeError:
            return None
        return data if isinstance(data, dict) else None

    async def build_graph(self, user_id: str, document_ids: list[str] | None = None):
        """Build or update the user's graph from indexed document chunks.

        Each chunk is processed independently; a failing chunk is logged,
        its pending changes are rolled back, and processing continues.
        """
        query = (
            select(DocumentChunk)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(Document.is_indexed == True)  # noqa: E712 - SQL expression, not a Python comparison
        )
        if document_ids:
            query = query.where(DocumentChunk.document_id.in_(document_ids))
        result = await self.db.execute(query)
        chunks = list(result.scalars().all())
        logger.info(f"[GraphService] 开始构建图谱,共 {len(chunks)} 个 chunks")
        for chunk in chunks:
            # Read the id up front: a rollback expires loaded attributes.
            chunk_id = chunk.id
            try:
                await self._process_chunk(chunk, user_id)
            except Exception as e:
                logger.error(f"[GraphService] 处理 chunk {chunk_id} 失败: {e}")
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; without this every later chunk fails too.
                await self.db.rollback()
                continue
        logger.info(f"[GraphService] 图谱构建完成")

    async def _process_chunk(self, chunk: DocumentChunk, user_id: str):
        """Extract entities and relations from one chunk and store them.

        Nodes are matched per user by name. Malformed entities or relations
        in the LLM reply are skipped rather than aborting the chunk.
        """
        document_id = chunk.document_id
        prompt = ENTITY_EXTRACTION_PROMPT.format(text=chunk.content[:2000])
        response = await self.llm.invoke([HumanMessage(content=prompt)])
        data = self._parse_llm_json(response.content)
        if data is None:
            return

        raw_entities = data.get("entities")
        raw_relations = data.get("relations")
        entities = [
            e for e in (raw_entities if isinstance(raw_entities, list) else [])
            if isinstance(e, dict)
            and isinstance(e.get("name"), str) and e["name"]
            and e.get("type")
        ]
        relations = [
            r for r in (raw_relations if isinstance(raw_relations, list) else [])
            if isinstance(r, dict)
        ]
        if not entities:
            return

        # Map entity name -> node id, creating nodes that do not exist yet.
        entity_map: dict = {}
        for entity_data in entities:
            name = entity_data["name"]
            if name in entity_map:
                # The same name appeared twice in one reply.
                continue
            result = await self.db.execute(
                select(KGNode)
                .where(KGNode.user_id == user_id)
                .where(KGNode.name == name)
                .limit(1)
            )
            node = result.scalars().first()
            if node is None:
                node = KGNode(
                    user_id=user_id,
                    name=name,
                    entity_type=entity_data["type"],
                    description=entity_data.get("description", ""),
                    source_document_id=document_id,
                )
                self.db.add(node)
                await self.db.flush()
            entity_map[name] = node.id

        # Insert relations, skipping ones that already exist.
        added: set = set()
        for rel in relations:
            src = rel.get("source")
            tgt = rel.get("target")
            rel_type = rel.get("relation_type")
            if not (isinstance(src, str) and isinstance(tgt, str) and isinstance(rel_type, str) and rel_type):
                continue
            if src not in entity_map or tgt not in entity_map:
                continue
            key = (entity_map[src], entity_map[tgt], rel_type)
            if key in added:
                continue
            result = await self.db.execute(
                select(KGEdge).where(
                    KGEdge.source_id == key[0],
                    KGEdge.target_id == key[1],
                    KGEdge.relation_type == rel_type,
                ).limit(1)
            )
            if result.scalars().first() is None:
                self.db.add(
                    KGEdge(source_id=key[0], target_id=key[1], relation_type=rel_type)
                )
            added.add(key)
        await self.db.commit()

    async def get_graph_summary(self, user_id: str) -> str:
        """Return a Markdown summary of the user's graph.

        Includes node/edge totals, distributions by type, and the ten nodes
        with the highest ``importance``.
        """
        node_count = await self.db.execute(
            select(func.count()).select_from(KGNode).where(KGNode.user_id == user_id)
        )
        edge_count = await self.db.execute(
            select(func.count())
            .select_from(KGEdge)
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
        )
        node_total = node_count.scalar() or 0
        edge_total = edge_count.scalar() or 0
        if node_total == 0:
            return "知识图谱为空,请先上传文档并构建图谱。"

        # Node counts per entity type.
        type_result = await self.db.execute(
            select(KGNode.entity_type, func.count())
            .where(KGNode.user_id == user_id)
            .group_by(KGNode.entity_type)
        )
        type_stats = type_result.all()

        # Edge counts per relation type (edges are attributed via their source node).
        rel_result = await self.db.execute(
            select(KGEdge.relation_type, func.count())
            .join(KGNode, KGNode.id == KGEdge.source_id)
            .where(KGNode.user_id == user_id)
            .group_by(KGEdge.relation_type)
        )
        rel_stats = rel_result.all()

        # Top nodes by importance.
        top_nodes_result = await self.db.execute(
            select(KGNode)
            .where(KGNode.user_id == user_id)
            .order_by(KGNode.importance.desc())
            .limit(10)
        )
        top_nodes = list(top_nodes_result.scalars().all())

        lines = [
            f"## 知识图谱摘要",
            f"",
            f"**总节点数**: {node_total}",
            f"**总关系数**: {edge_total}",
            f"",
            f"### 节点类型分布",
        ]
        for etype, count in type_stats:
            lines.append(f"- {etype}: {count}")
        lines.append(f"\n### 关系类型分布")
        for rtype, count in rel_stats:
            lines.append(f"- {rtype}: {count}")
        lines.append(f"\n### 核心实体 (Top 10)")
        for node in top_nodes:
            # description may be NULL; slicing None would raise TypeError.
            description = node.description or ""
            lines.append(f"- [{node.entity_type}] {node.name}: {description[:50]}...")
        return "\n".join(lines)

    async def get_entity_context(self, entity: str, user_id: str) -> str:
        """Describe up to five of the user's nodes whose name contains ``entity``.

        For each node, up to ten outgoing and ten incoming relations are listed.
        """
        result = await self.db.execute(
            select(KGNode).where(
                KGNode.user_id == user_id,
                KGNode.name.contains(entity),
            ).limit(5)
        )
        nodes = list(result.scalars().all())
        if not nodes:
            return f"未找到实体: {entity}"

        lines = []
        for node in nodes:
            lines.append(f"### {node.name} [{node.entity_type}]")
            lines.append(f"描述: {node.description or '无描述'}")

            # Outgoing edges, joined with their target node.
            edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.target_id)
                .where(KGEdge.source_id == node.id)
                .limit(10)
            )
            out_edges = list(edges_result.all())

            # Incoming edges, joined with their source node.
            in_edges_result = await self.db.execute(
                select(KGEdge, KGNode)
                .join(KGNode, KGNode.id == KGEdge.source_id)
                .where(KGEdge.target_id == node.id)
                .limit(10)
            )
            in_edges = list(in_edges_result.all())

            if out_edges:
                lines.append("**关联到**:")
                for edge, target in out_edges:
                    lines.append(f"  - {node.name} --[{edge.relation_type}]--> {target.name}")
            if in_edges:
                lines.append("**被关联于**:")
                for edge, source in in_edges:
                    lines.append(f"  - {source.name} --[{edge.relation_type}]--> {node.name}")
            lines.append("")
        return "\n".join(lines)

    async def get_neighbors(self, node_id: str, depth: int = 1) -> dict:
        """Collect the neighbourhood of a node for graph visualisation.

        Expands ``depth`` hops from ``node_id`` along edges in both
        directions.

        Returns:
            ``{"nodes": [...], "edges": [...]}`` with the start node first.
            Every node touched by a returned edge is included, and each edge
            appears once. Both lists are empty for ``depth < 1``.
        """
        visited: set = set()
        order: list = []  # node ids in discovery order
        frontier = {node_id}
        edges_by_key: dict = {}

        def _discover(nid) -> None:
            if nid not in seen_ids:
                seen_ids.add(nid)
                order.append(nid)

        seen_ids: set = set()
        if depth >= 1:
            _discover(node_id)

        for _ in range(depth):
            frontier = frontier - visited
            if not frontier:
                break
            visited |= frontier
            frontier_ids = list(frontier)
            next_frontier: set = set()

            # Outgoing edges of the current ring.
            out_result = await self.db.execute(
                select(KGEdge).where(KGEdge.source_id.in_(frontier_ids))
            )
            for edge in out_result.scalars().all():
                edges_by_key.setdefault((edge.source_id, edge.target_id, edge.relation_type), edge)
                next_frontier.add(edge.target_id)
                _discover(edge.target_id)

            # Incoming edges of the current ring.
            in_result = await self.db.execute(
                select(KGEdge).where(KGEdge.target_id.in_(frontier_ids))
            )
            for edge in in_result.scalars().all():
                edges_by_key.setdefault((edge.source_id, edge.target_id, edge.relation_type), edge)
                next_frontier.add(edge.source_id)
                _discover(edge.source_id)

            frontier = next_frontier

        # Load every discovered node, including the unexpanded outer ring, so
        # that no returned edge points at a missing node.
        nodes_by_id: dict = {}
        if order:
            node_result = await self.db.execute(
                select(KGNode).where(KGNode.id.in_(order))
            )
            nodes_by_id = {n.id: n for n in node_result.scalars().all()}
        all_nodes = [nodes_by_id[nid] for nid in order if nid in nodes_by_id]
        return {"nodes": all_nodes, "edges": list(edges_by_key.values())}

View File

@@ -0,0 +1,308 @@
"""
知识库服务 - ChromaDB 向量检索 + 混合检索 + Rerank
检索策略:
1. 语义检索 (dense) - ChromaDB 向量相似度
2. 关键词检索 (sparse) - SQL LIKE
3. 混合检索 - 语义 + 关键词 加权融合
4. Rerank - 二次排序优化结果
5. 上下文丰富 - 自动获取前/后 chunk 提供完整语境
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.config import settings
import chromadb
from chromadb.config import Settings as ChromaSettings
from dataclasses import dataclass
@dataclass
class SearchResult:
    """One retrieved chunk together with its ranking score and context."""

    chunk_id: str
    document_id: str
    document_title: str
    content: str
    score: float  # higher is better; the producing search method defines the scale
    metadata_: str | None = None  # string form of the vector-store metadata, when available
    prev_chunk: str | None = None  # content of the preceding chunk in the same document
    next_chunk: str | None = None  # content of the following chunk in the same document
class KnowledgeService:
    """Vector knowledge-base retrieval backed by ChromaDB and SQL."""

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        """Bind the service to a session; the Chroma client is created lazily."""
        self.db = db
        self.user_id = user_id
        self._chroma_client = None

    @staticmethod
    def _logger():
        """Return this module's logger (imported lazily; the file header has no logging import)."""
        import logging

        return logging.getLogger(__name__)

    @staticmethod
    def _in_folder(path: str, folder_prefix: str) -> bool:
        """Return whether ``path`` is ``folder_prefix`` itself or lies below it.

        Matching is done on whole path segments, so "/ab" is not inside "/a".
        """
        if not isinstance(path, str):
            return False
        base = folder_prefix.rstrip("/")
        return path == folder_prefix or path == base or path.startswith(base + "/")

    @property
    def chroma_client(self):
        """Persistent Chroma client rooted at ``settings.CHROMA_PERSIST_DIR``."""
        if self._chroma_client is None:
            self._chroma_client = chromadb.PersistentClient(
                path=settings.CHROMA_PERSIST_DIR,
                settings=ChromaSettings(allow_reset=True),
            )
        return self._chroma_client

    def get_collection(self, user_id: str):
        """Return (creating if needed) the per-user collection ``user_<user_id>``."""
        return self.chroma_client.get_or_create_collection(
            name=f"user_{user_id}",
            metadata={"user_id": user_id},
        )

    async def index_document(self, document_id: str, user_id: str, folder_path: str | None = None):
        """Store a document's chunks in the user's vector collection.

        Does nothing when the document does not exist for this user or has no
        chunks. Marks the document as indexed on success.
        """
        result = await self.db.execute(
            select(Document).where(
                Document.id == document_id,
                # Only index a document into its owner's collection.
                Document.user_id == user_id,
            )
        )
        doc = result.scalar_one_or_none()
        if not doc:
            return

        chunks_result = await self.db.execute(
            select(DocumentChunk)
            .where(DocumentChunk.document_id == document_id)
            .order_by(DocumentChunk.chunk_index)
        )
        chunks = list(chunks_result.scalars().all())
        if not chunks:
            return

        collection = self.get_collection(user_id)
        ids = [chunk.id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [
            {
                "document_id": doc.id,
                "document_title": doc.title,
                "chunk_index": chunk.chunk_index,
                "file_type": doc.file_type,
                "folder_path": folder_path or "",
            }
            for chunk in chunks
        ]
        # upsert rather than add, so re-indexing a document replaces its
        # entries instead of colliding with the existing ids.
        collection.upsert(ids=ids, documents=documents, metadatas=metadatas)
        doc.is_indexed = True
        await self.db.commit()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        folder_id: str | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
    ) -> list[SearchResult]:
        """Vector retrieval with optional folder scoping and rerank.

        Steps:
        1. Query Chroma for an enlarged candidate set.
        2. Keep only candidates inside the folder, if one was resolved.
        3. Rerank (or truncate) to ``top_k``.
        4. Attach the previous/next chunk to the selected results.

        Returns an empty list when the vector query fails.
        """
        collection = self.get_collection(user_id)

        folder_prefix = await self._get_folder_path(folder_id) if folder_id else None
        # Chroma metadata filters have no prefix operator, so folder scoping
        # is done client-side; fetch more candidates when it is active.
        n_results = top_k * (10 if folder_prefix else 3)

        try:
            results = collection.query(
                query_texts=[query],
                n_results=n_results,
                include=["documents", "metadatas", "distances"],
            )
        except Exception:
            self._logger().exception("vector query failed for user %s", user_id)
            return []

        if not results or not results.get("ids"):
            return []

        ids = results["ids"][0]
        documents = (results.get("documents") or [[]])[0]
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        candidates: list[SearchResult] = []
        chunk_index_by_id: dict = {}
        for i, chunk_id in enumerate(ids):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            if folder_prefix and not self._in_folder(meta.get("folder_path", ""), folder_prefix):
                continue
            score = 1.0 - (distances[i] if i < len(distances) else 0.0)
            chunk_index_by_id[chunk_id] = meta.get("chunk_index", 0)
            candidates.append(SearchResult(
                chunk_id=chunk_id,
                document_id=meta.get("document_id", ""),
                document_title=meta.get("document_title", ""),
                content=documents[i] if i < len(documents) else "",
                score=score,
                metadata_=str(meta),
            ))

        if use_rerank:
            selected = self._rerank(query, candidates, top_k)
        else:
            selected = candidates[:top_k]

        # Sibling lookups cost one query each, so only do them for the
        # results that are actually returned.
        for item in selected:
            item.prev_chunk, item.next_chunk = await self._get_sibling_chunks(
                chunk_id=item.chunk_id,
                chunk_index=chunk_index_by_id.get(item.chunk_id, 0),
                document_id=item.document_id,
            )
        return selected

    def _rerank(
        self,
        query: str,
        results: list[SearchResult],
        top_k: int,
    ) -> list[SearchResult]:
        """Rerank: semantic score * 0.7 + keyword overlap * 0.2 + title overlap * 0.1.

        Overlaps are the fraction of query words found in the content / title.
        Ties keep their input order.
        """
        import re

        tokenize = re.compile(r"\w+").findall
        query_words = set(tokenize(query.lower()))
        denominator = max(len(query_words), 1)

        scored = []
        for r in results:
            score = r.score * 0.7
            content_words = set(tokenize(r.content.lower()))
            keyword_overlap = len(query_words & content_words) / denominator
            score += keyword_overlap * 0.2
            if r.document_title:
                title_words = set(tokenize(r.document_title.lower()))
                title_overlap = len(query_words & title_words) / denominator
                score += title_overlap * 0.1
            scored.append((score, r))
        scored.sort(key=lambda x: x[0], reverse=True)
        return [r for _, r in scored[:top_k]]

    async def _get_sibling_chunks(
        self,
        chunk_id: str,
        chunk_index: int,
        document_id: str,
    ) -> tuple[str | None, str | None]:
        """Return the contents of the previous and next chunk of the same document.

        ``chunk_id`` is accepted for interface compatibility but not used.
        """
        result = await self.db.execute(
            select(DocumentChunk).where(
                DocumentChunk.document_id == document_id,
                DocumentChunk.chunk_index.in_([chunk_index - 1, chunk_index + 1]),
            )
        )
        content_by_index: dict = {}
        for chunk in result.scalars().all():
            content_by_index.setdefault(chunk.chunk_index, chunk.content)
        return (
            content_by_index.get(chunk_index - 1),
            content_by_index.get(chunk_index + 1),
        )

    async def _get_folder_path(self, folder_id: str) -> str | None:
        """Return the "/"-joined path of a folder, or ``None`` if it does not exist."""
        result = await self.db.execute(
            select(Folder).where(Folder.id == folder_id)
        )
        folder = result.scalar_one_or_none()
        if not folder:
            return None

        names = [folder.name]
        seen = {folder.id}
        current_parent_id = folder.parent_id
        # `seen` stops the walk if parent links ever form a cycle.
        while current_parent_id and current_parent_id not in seen:
            seen.add(current_parent_id)
            parent_result = await self.db.execute(
                select(Folder).where(Folder.id == current_parent_id)
            )
            parent = parent_result.scalar_one_or_none()
            if not parent:
                break
            names.append(parent.name)
            current_parent_id = parent.parent_id
        return "/" + "/".join(reversed(names))

    async def hybrid_search(
        self,
        query: str,
        user_id: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """Hybrid search: vector results + SQL keyword results, deduplicated and reranked."""
        vector_results = await self.retrieve(query, user_id, top_k=top_k * 2, use_rerank=False)
        keyword_results = await self._keyword_search(query, user_id, top_k)

        seen = set()
        merged: list[SearchResult] = []
        for r in vector_results + keyword_results:
            if r.chunk_id not in seen:
                seen.add(r.chunk_id)
                merged.append(r)
        return self._rerank(query, merged, top_k)

    async def _keyword_search(
        self,
        query: str,
        user_id: str,
        top_k: int,
    ) -> list[SearchResult]:
        """SQL substring search over chunk content and document titles.

        Every hit gets the fixed score 0.5.
        """
        # The title is selected in the same query instead of one lookup per
        # chunk; autoescape makes "%" and "_" in the query match literally.
        result = await self.db.execute(
            select(DocumentChunk, Document.title)
            .join(Document)
            .where(Document.user_id == user_id)
            .where(
                or_(
                    DocumentChunk.content.contains(query, autoescape=True),
                    Document.title.contains(query, autoescape=True),
                )
            )
            .limit(top_k)
        )
        return [
            SearchResult(
                chunk_id=chunk.id,
                document_id=chunk.document_id,
                document_title=title or "",
                content=chunk.content,
                score=0.5,
                metadata_=None,
            )
            for chunk, title in result.all()
        ]

    async def delete_from_vectorstore(self, user_id: str, document_id: str):
        """Remove a document's entries from the user's vector collection (best effort)."""
        collection = self.get_collection(user_id)
        try:
            collection.delete(where={"document_id": document_id})
        except Exception:
            # Best effort, but leave a trace instead of failing silently.
            self._logger().exception(
                "failed to delete document %s from the vector store", document_id
            )

View File

@@ -0,0 +1,145 @@
"""
LLM 服务 - 支持多种 LLM 提供商
OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
"""
from abc import ABC, abstractmethod
from typing import AsyncIterator
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from app.config import settings
import httpx
import os
# Import-time side effect: make sure the data and Chroma directories exist.
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
class LLMService(ABC):
    """Abstract interface shared by all LLM provider wrappers."""

    @abstractmethod
    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one completion over ``messages`` and return the reply message."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the reply to ``messages`` piece by piece."""
        raise NotImplementedError

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the model this service talks to."""
        raise NotImplementedError
class OpenAICompatibleService(LLMService):
    """Wrapper for OpenAI-compatible chat endpoints.

    Covers OpenAI, DeepSeek, SiliconFlow and any other service exposing an
    OpenAI-compatible API.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Create the client; arguments left empty fall back to the ``OPENAI_*`` settings."""
        self.api_key = api_key if api_key else settings.OPENAI_API_KEY
        self.model = model if model else settings.OPENAI_MODEL
        self.base_url = base_url if base_url else settings.OPENAI_BASE_URL
        request_timeout = httpx.Timeout(60.0, connect=10.0)
        self._llm = ChatOpenAI(
            api_key=self.api_key,
            model=self.model,
            base_url=self.base_url,
            timeout=request_timeout,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one completion and return the reply message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each non-empty streamed chunk."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
class ClaudeService(LLMService):
    """Wrapper for Anthropic Claude chat models."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 8192,
    ):
        """Create the client; an empty key/model falls back to the Anthropic settings."""
        self.api_key = api_key if api_key else settings.ANTHROPIC_API_KEY
        self.model = model if model else settings.CLAUDE_MODEL
        request_timeout = httpx.Timeout(60.0, connect=10.0)
        self._llm = ChatAnthropic(
            api_key=self.api_key,
            model=self.model,
            max_tokens=max_tokens,
            timeout=request_timeout,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one completion and return the reply message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each non-empty streamed chunk."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
class OllamaService(LLMService):
    """Wrapper for chat models served by Ollama."""

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
    ):
        """Create the client; empty arguments fall back to the ``OLLAMA_*`` settings."""
        self.base_url = base_url if base_url else settings.OLLAMA_BASE_URL
        self.model = model if model else settings.OLLAMA_MODEL
        request_timeout = httpx.Timeout(120.0, connect=10.0)
        self._llm = ChatOllama(
            base_url=self.base_url,
            model=self.model,
            timeout=request_timeout,
        )

    async def invoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Run one completion and return the reply message."""
        reply = await self._llm.ainvoke(messages)
        return reply

    async def stream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Yield the content of each non-empty streamed chunk."""
        async for piece in self._llm.astream(messages):
            text = piece.content
            if not text:
                continue
            yield text

    def get_model_name(self) -> str:
        """Return the configured model name."""
        return self.model
# Process-wide singleton cache.
_llm_instance: LLMService | None = None


def get_llm() -> LLMService:
    """Return the shared LLM service for ``settings.LLM_PROVIDER``.

    The instance is created on first use and then reused.

    Raises:
        ValueError: the configured provider is not recognised.
    """
    global _llm_instance
    if _llm_instance is not None:
        return _llm_instance

    provider = settings.LLM_PROVIDER
    # Provider name -> zero-argument factory.
    factories = {
        "openai": OpenAICompatibleService,
        "deepseek": lambda: OpenAICompatibleService(
            base_url="https://api.deepseek.com/v1",
            model="deepseek-chat",
        ),
        "custom": OpenAICompatibleService,
        "claude": ClaudeService,
        "ollama": OllamaService,
    }
    factory = factories.get(provider)
    if factory is None:
        raise ValueError(f"Unknown LLM provider: {provider}")
    _llm_instance = factory()
    return _llm_instance

View File

@@ -0,0 +1,304 @@
"""
Jarvis 记忆系统
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
"""
import json
import re
from datetime import datetime
from typing import Optional
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
    db: AsyncSession,
    conversation_id: str,
    limit: int = 20,
) -> list[Message]:
    """Load the most recent messages of a conversation in chronological order.

    The previous implementation sorted ascending before applying ``LIMIT``,
    which always returned the *oldest* ``limit`` messages; once a conversation
    grew past ``limit`` messages, newer turns were never loaded.

    Args:
        db: Async database session.
        conversation_id: Conversation whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` of the newest messages, oldest first.
    """
    result = await db.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        # Newest first so that LIMIT keeps the tail of the conversation.
        .order_by(desc(Message.created_at))
        .limit(limit)
    )
    newest_first = list(result.scalars().all())
    # Callers expect chronological order (oldest -> newest).
    newest_first.reverse()
    return newest_first
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
    """Return the number of turns in a conversation, counted as user messages.

    Args:
        db: Async database session.
        conversation_id: Conversation to count.

    Returns:
        Count of messages with role ``"user"``; ``0`` when the query yields
        no value.
    """
    count_stmt = select(func.count(Message.id)).where(
        Message.conversation_id == conversation_id,
        Message.role == "user",
    )
    outcome = await db.execute(count_stmt)
    user_turns = outcome.scalar()
    return user_turns if user_turns else 0
# ———— 中期记忆: 对话摘要 ————
SUMMARIZE_THRESHOLD = 8  # summarise once this many un-summarised turns exist
MAX_HISTORY_TURNS = 10  # maximum number of history turns the agent gets to see


async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
    """Decide whether the conversation has enough new turns to be summarised.

    Compares the current user-turn count with the turn count recorded on the
    latest stored summary (or with zero when no summary exists).

    Returns:
        ``True`` when at least ``SUMMARIZE_THRESHOLD`` turns are not yet
        covered by a summary.
    """
    turn_count = await get_conversation_turn_count(db, conversation_id)

    # The summary with the highest turn_count marks how far summaries reach.
    latest_stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(desc(MemorySummary.turn_count))
        .limit(1)
    )
    latest_summary = (await db.execute(latest_stmt)).scalar_one_or_none()

    covered_turns = latest_summary.turn_count if latest_summary else 0
    return turn_count - covered_turns >= SUMMARIZE_THRESHOLD
async def generate_summary(
    db: AsyncSession,
    conversation_id: str,
    messages: list[Message],
) -> str:
    """Ask the LLM for a short summary of the given messages.

    ``db`` and ``conversation_id`` are accepted but not used by this function.

    Args:
        db: Async database session (unused).
        conversation_id: Conversation identifier (unused).
        messages: Messages to summarise; each is rendered as ``[role] content``.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    transcript = [f"[{m.role}] {m.content}" for m in messages]
    history_text = "\n".join(transcript)

    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    instruction = SystemMessage(
        content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
                "提取关键信息、用户偏好、待办事项等。不超过150字。"
    )
    reply = await llm.invoke([instruction, HumanMessage(content=history_text)])
    return reply.content.strip()
async def save_summary(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    summary_text: str,
    turn_count: int,
) -> MemorySummary:
    """Persist a conversation summary and return the refreshed row.

    Args:
        db: Async database session; committed by this function.
        user_id: Owner of the summary.
        conversation_id: Conversation the summary belongs to.
        summary_text: Summary content.
        turn_count: Number of user turns covered when the summary was made.

    Returns:
        The stored ``MemorySummary`` after ``refresh``.
    """
    record = MemorySummary(
        user_id=user_id,
        conversation_id=conversation_id,
        summary_text=summary_text,
        turn_count=turn_count,
    )
    db.add(record)
    await db.commit()
    # Reload database-generated values onto the instance.
    await db.refresh(record)
    return record
async def get_summaries(
    db: AsyncSession,
    conversation_id: str,
) -> list[MemorySummary]:
    """Return every stored summary of a conversation, ordered by ``summary_at``."""
    stmt = (
        select(MemorySummary)
        .where(MemorySummary.conversation_id == conversation_id)
        .order_by(MemorySummary.summary_at)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
# ———— 长期记忆: 用户画像 ————
# System prompt sent by extract_user_memories. It asks the model to emit one
# "[type] content" line per remembered item, with type being one of
# fact / preference / goal / habit.
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的、可能对未来对话有帮助的信息,如:
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度
每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
如果没有提取到任何记忆,回复""
"""
FACT_TYPES = {"fact", "preference", "goal", "habit"}
def _parse_fact_line(line: str) -> tuple[str, str] | None:
"""解析一行记忆: [fact] 内容 -> (type, content)"""
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
if m and m.group(1) in FACT_TYPES:
return m.group(1), m.group(2).strip()
return None
async def extract_user_memories(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    messages: list[Message],
) -> list[UserMemory]:
    """Extract user memories from a conversation via the LLM and store new ones.

    Only the last ten messages are sent to the model. Each reply line is
    parsed with ``_parse_fact_line``; lines that do not parse are ignored.
    A memory is skipped when the same content already exists for the user or
    already appeared earlier in the same reply.

    Args:
        db: Async database session; committed when at least one memory is added.
        user_id: Owner of the memories.
        conversation_id: Recorded as the source conversation of each memory.
        messages: Conversation messages; fewer than two yields no extraction.

    Returns:
        The newly created ``UserMemory`` objects (possibly empty).
    """
    if len(messages) < 2:
        return []

    history_text = "\n".join(
        f"[{m.role}] {m.content}" for m in messages[-10:]
    )
    llm = get_llm()
    from langchain_core.messages import HumanMessage, SystemMessage

    response = await llm.invoke([
        SystemMessage(content=EXTRACTION_PROMPT),
        HumanMessage(content=history_text),
    ])
    text = response.content.strip()
    if not text:
        return []

    memories: list[UserMemory] = []
    seen_contents: set[str] = set()
    for line in text.splitlines():
        parsed = _parse_fact_line(line)
        if not parsed:
            continue
        mem_type, content = parsed

        # Guard against the model repeating a fact within one reply; this
        # does not rely on the session auto-flushing pending rows.
        if content in seen_contents:
            continue
        seen_contents.add(content)

        # Skip content that is already stored for this user. ``first()`` is
        # used instead of ``scalar_one_or_none()`` so that pre-existing
        # duplicate rows do not raise MultipleResultsFound.
        existing = await db.execute(
            select(UserMemory)
            .where(
                UserMemory.user_id == user_id,
                UserMemory.content == content,
            )
            .limit(1)
        )
        if existing.scalars().first() is not None:
            continue

        mem = UserMemory(
            user_id=user_id,
            memory_type=mem_type,
            content=content,
            importance=5,
            source_conversation_id=conversation_id,
        )
        db.add(mem)
        memories.append(mem)

    if memories:
        await db.commit()
    return memories
async def recall_user_memories(
    db: AsyncSession,
    user_id: str,
    query: str,
    top_k: int = 5,
) -> list[UserMemory]:
    """Return the user's top memories and clear their ``is_recalled`` flag.

    ``query`` is currently not used: no keyword or semantic matching takes
    place. Memories are simply ranked by importance, then by recall count
    (both descending), and the first ``top_k`` are returned.

    Args:
        db: Async database session; committed after the flags are reset.
        user_id: Owner of the memories.
        query: Current user input (unused).
        top_k: Maximum number of memories to return.

    Returns:
        The selected ``UserMemory`` rows with ``is_recalled`` set to ``False``.
    """
    ranked_stmt = (
        select(UserMemory)
        .where(UserMemory.user_id == user_id)
        .order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
        .limit(top_k)
    )
    selected = list((await db.execute(ranked_stmt)).scalars().all())

    # Reset the recall marker on everything that was selected.
    for memory in selected:
        memory.is_recalled = False
    await db.commit()
    return selected
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
    """Flag a memory as recalled and bump its recall statistics.

    Does nothing when no memory with ``memory_id`` exists. Otherwise sets
    ``is_recalled``, increments ``recall_count`` (treating a missing value as
    zero), stamps ``last_recalled_at`` with the current naive UTC time, and
    commits.
    """
    lookup = await db.execute(
        select(UserMemory).where(UserMemory.id == memory_id)
    )
    memory = lookup.scalar_one_or_none()
    if not memory:
        return

    memory.is_recalled = True
    previous_count = memory.recall_count or 0
    memory.recall_count = previous_count + 1
    memory.last_recalled_at = datetime.utcnow()
    await db.commit()
# ———— 记忆组装: 供 Agent 使用的上下文 ————
async def build_memory_context(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
    current_query: str,
) -> str:
    """Assemble the memory context string injected into the agent system prompt.

    The result has up to two sections separated by a blank line: the user's
    recalled long-term memories and the two most recent conversation
    summaries. Every memory that is included is also marked as recalled.

    Returns:
        The combined context, or an empty string when there is nothing to add.
    """
    sections: list[str] = []

    # 1. Long-term memory: user profile facts.
    recalled = await recall_user_memories(db, user_id, current_query, top_k=5)
    if recalled:
        memory_lines = []
        for memory in recalled:
            memory_lines.append(f" [{memory.memory_type}] {memory.content}")
            await mark_memory_recalled(db, memory.id)
        sections.append("【用户记忆】\n" + "\n".join(memory_lines))

    # 2. Mid-term memory: only the two most recent summaries are used.
    summaries = await get_summaries(db, conversation_id)
    if summaries:
        summary_lines = []
        for position, summary in enumerate(summaries[-2:], start=1):
            summary_lines.append(f"[对话摘要{position}] {summary.summary_text}")
        sections.append("【之前对话摘要】\n" + "\n".join(summary_lines))

    # Joining an empty list yields "", matching the "nothing to add" case.
    return "\n\n".join(sections)
async def try_auto_summarize(
    db: AsyncSession,
    user_id: str,
    conversation_id: str,
) -> bool:
    """Summarise the conversation when it is due, and extract user memories.

    This is a best-effort operation: failures are logged and reported as
    ``False`` instead of being raised. Previously the exception was swallowed
    without a log entry and the session was left without a rollback.

    Args:
        db: Async database session.
        user_id: Owner of the conversation.
        conversation_id: Conversation to summarise.

    Returns:
        ``True`` when a summary was generated and stored and memory
        extraction completed; ``False`` when no summary was due, there were
        fewer than three messages, or any step failed.
    """
    import logging

    log = logging.getLogger(__name__)

    if not await should_summarize(db, conversation_id):
        return False

    messages = await load_conversation_history(db, conversation_id, limit=30)
    if len(messages) < 3:
        return False

    try:
        summary_text = await generate_summary(db, conversation_id, messages)
        turn_count = await get_conversation_turn_count(db, conversation_id)
        await save_summary(db, user_id, conversation_id, summary_text, turn_count)
        # Extract long-term user memories from the same messages.
        await extract_user_memories(db, user_id, conversation_id, messages)
        return True
    except Exception:
        log.exception("auto summarize failed for conversation %s", conversation_id)
        try:
            # Discard uncommitted work so the caller's session stays usable.
            await db.rollback()
        except Exception:
            log.exception("rollback after failed auto summarize also failed")
        return False

View File

@@ -0,0 +1,291 @@
"""
定时任务服务 - APScheduler 调度器
"""
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import select, and_
from app.database import async_session
from app.models.task import Task
from app.models.forum import ForumPost
from app.models.knowledge_graph import KGNode
from app.services.agent_service import AgentService
from app.services.graph_service import GraphService
from app.config import settings
import logging
# Module-level logger used by every scheduled job in this file.
logger = logging.getLogger(__name__)
# Shared asyncio scheduler instance, configured for the Asia/Shanghai timezone.
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
# ===================== 定时任务函数 =====================
async def daily_task_analysis():
    """Nightly task analysis job.

    Loads every task updated since yesterday's date (derived from UTC),
    publishes a Markdown report as a forum post in the ``discussion``
    category, and creates a follow-up ``todo`` task for each of the five
    highest-priority unfinished tasks.

    NOTE(review): the query is not filtered by user, and the report post is
    created without an owner - confirm this is intended for multi-user setups.
    NOTE(review): follow-up tasks are themselves unfinished tasks, so they
    will be picked up again on the next run - confirm.
    """
    from datetime import datetime, timedelta

    logger.info("[Scheduler] 开始执行每日任务分析...")
    async with async_session() as db:
        yesterday = datetime.utcnow().date() - timedelta(days=1)
        day_label = yesterday.strftime('%Y-%m-%d')

        # Collect the tasks touched since yesterday and split them by status.
        result = await db.execute(
            select(Task).where(Task.updated_at >= yesterday)
        )
        tasks = result.scalars().all()
        completed = [t for t in tasks if t.status == "done"]
        pending = [t for t in tasks if t.status != "done"]

        # Computed once and shared by the report and the follow-up tasks
        # (the original sorted the list twice).
        top_pending = sorted(pending, key=lambda t: t.priority, reverse=True)[:5]

        completed_lines = chr(10).join([f"- {t.title}" for t in completed]) or ""
        pending_lines = chr(10).join(
            [f"- {t.title} (优先级: {t.priority})" for t in pending]
        ) or ""
        suggestion_lines = chr(10).join(
            [f"{i+1}. {t.title}" for i, t in enumerate(top_pending)]
        ) or "无待处理任务"

        report = f"""## 每日任务报告 - {day_label}
### 完成情况
- 总任务数: {len(tasks)}
- 已完成: {len(completed)}
- 未完成: {len(pending)}
### 已完成任务
{completed_lines}
### 未完成任务
{pending_lines}
### 建议
根据未完成任务,建议明天优先处理:
{suggestion_lines}
"""
        # Publish the report to the forum.
        post = ForumPost(
            title=f"每日报告 - {day_label}",
            content=report,
            category="discussion",
        )
        db.add(post)

        # Create follow-up tasks for tomorrow.
        for task in top_pending:
            suggestion = Task(
                # Keep the follow-up with the owner of the original task so
                # that it stays visible in that user's task list.
                user_id=task.user_id,
                title=f"继续: {task.title}",
                description=f"昨日未完成任务,优先级: {task.priority}",
                priority=task.priority,
                status="todo",
            )
            db.add(suggestion)

        await db.commit()
        logger.info(f"[Scheduler] 每日任务分析完成,完成 {len(completed)} 个任务")
async def forum_scan_task():
    """Scan the forum for unexecuted instruction posts and run them via the agent.

    At most five posts with category ``instruction`` and ``is_executed``
    false are processed per run. For each post the agent is asked to execute
    the instruction; on success the post is flagged as executed and the
    agent's response is stored. A failure on one post is logged and does not
    stop the remaining posts. All changes are committed once at the end.

    NOTE(review): post title and content are passed verbatim to the agent as
    an instruction to execute - confirm who may create ``instruction`` posts,
    since this is effectively untrusted input driving agent actions.
    """
    logger.info("[Scheduler] 开始扫描论坛指令...")
    async with async_session() as db:
        result = await db.execute(
            select(ForumPost).where(
                ForumPost.category == "instruction",
                ForumPost.is_executed.is_(False),
            ).limit(5)
        )
        posts = result.scalars().all()
        if not posts:
            logger.info("[Scheduler] 暂无待执行指令")
            return

        agent_svc = AgentService(db)
        executed_count = 0
        for post in posts:
            try:
                # Let the agent analyse and execute the instruction; only
                # the response text is needed here.
                _conv_id, _msg_id, response = await agent_svc.chat_simple(
                    user_id=post.user_id,
                    message=f"请执行以下论坛指令: {post.title}{post.content}",
                    conversation_id=None,
                )
                post.is_executed = True
                post.executed_response = response
                executed_count += 1
                logger.info(f"[Scheduler] 执行指令: {post.title}")
            except Exception as e:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(f"[Scheduler] 执行指令失败 {post.title}: {e}")

        await db.commit()
        logger.info(f"[Scheduler] 论坛扫描完成,执行了 {executed_count} 个指令")
async def graph_rebuild_task():
    """Rebuild the knowledge graph through ``GraphService.build_graph``.

    The call is made for the fixed user id ``"default"`` with
    ``document_ids=None``. Failures are logged and not re-raised.

    NOTE(review): the original comment said only documents active in the
    last seven days are processed, but no such filter is passed here -
    confirm what ``document_ids=None`` means to ``build_graph``.
    """
    logger.info("[Scheduler] 开始重建知识图谱...")
    async with async_session() as db:
        try:
            await GraphService(db).build_graph(user_id="default", document_ids=None)
        except Exception as e:
            logger.error(f"[Scheduler] 知识图谱重建失败: {e}")
        else:
            logger.info("[Scheduler] 知识图谱重建完成")
async def tag_generation_task():
    """Daily incremental tag generation for every user that owns content nodes.

    Collects the distinct ``user_id`` values of conversation, document and
    chunk nodes, then runs ``TagService.tag_incremental_content`` for each
    user over the last day. Failures are logged and not re-raised.

    ``TagService`` uses the synchronous ORM API (``db.query``), which the
    async session opened here does not offer directly, so each per-user run
    is bridged through ``AsyncSession.run_sync``; that hands the service the
    underlying synchronous session.
    """
    from app.services.tag_service import TagService
    from app.core.llm import get_llm_client

    def _tag_user(sync_session, target_user_id):
        """Run incremental tagging for one user on the synchronous session."""
        service = TagService(sync_session, llm_client)
        return service.tag_incremental_content(target_user_id, days=1)

    logger.info("[Scheduler] 开始执行每日标签生成...")
    async with async_session() as db:
        try:
            llm_client = get_llm_client()
            user_rows = await db.execute(
                select(KGNode.user_id).distinct().where(
                    KGNode.entity_type.in_(["conversation", "document", "chunk"])
                )
            )
            user_ids = user_rows.scalars().all()

            total_tagged = 0
            for user_id in user_ids:
                outcome = await db.run_sync(_tag_user, user_id)
                total_tagged += outcome["tagged"]
            logger.info(f"[Scheduler] 每日标签生成完成,共标签化 {total_tagged} 个内容节点")
        except Exception as e:
            logger.exception(f"[Scheduler] 每日标签生成失败: {e}")
async def daily_todo_generation():
    """Generate today's todo list for every active user.

    Calls ``generate_daily_todos`` once per user whose ``is_active`` flag is
    true. A failure for one user is logged and the loop continues; a failure
    outside the per-user step is logged and ends the job.
    """
    from app.models.user import User
    from app.services.todo_service import generate_daily_todos
    from sqlalchemy import select

    logger.info("[Scheduler] 开始执行每日待办生成...")
    async with async_session() as db:
        try:
            active_stmt = select(User).where(User.is_active == True)
            users = (await db.execute(active_stmt)).scalars().all()
            for account in users:
                try:
                    await generate_daily_todos(account.id, db)
                    logger.info(f"[Scheduler] 为用户 {account.id} 生成今日待办完成")
                except Exception as e:
                    logger.error(f"[Scheduler] 用户 {account.id} 定时生成待办失败: {e}")
            logger.info(f"[Scheduler] 每日待办生成完成,共处理 {len(users)} 个用户")
        except Exception as e:
            logger.error(f"[Scheduler] 每日待办生成失败: {e}")
# ===================== 调度器管理 =====================
def start_scheduler():
    """Register all scheduled jobs and start the scheduler.

    Does nothing except log a warning when the scheduler is already running.
    Every job is added with ``replace_existing=True``.
    """
    if scheduler.running:
        logger.warning("[Scheduler] 调度器已在运行")
        return

    shanghai = "Asia/Shanghai"
    # (job function, trigger, job id, display name) in registration order.
    job_table = (
        # Daily at 00:30: task analysis.
        (
            daily_task_analysis,
            CronTrigger(hour=0, minute=30, timezone=shanghai),
            "daily_task_analysis",
            "每日任务分析",
        ),
        # Every hour: scan forum instructions.
        (
            forum_scan_task,
            IntervalTrigger(hours=1),
            "forum_scan",
            "论坛指令扫描",
        ),
        # Daily at 03:00: rebuild the knowledge graph.
        (
            graph_rebuild_task,
            CronTrigger(hour=3, minute=0, timezone=shanghai),
            "graph_rebuild",
            "知识图谱重建",
        ),
        # Daily at 00:00: generate tags.
        (
            tag_generation_task,
            CronTrigger(hour=0, minute=0, timezone=shanghai),
            "tag_generation",
            "每日标签生成",
        ),
        # Daily at 08:00: generate today's todos.
        (
            daily_todo_generation,
            CronTrigger(hour=8, minute=0, timezone=shanghai),
            "daily_todo_generation",
            "每日待办生成",
        ),
    )
    for job_func, trigger, job_id, job_name in job_table:
        scheduler.add_job(
            job_func,
            trigger,
            id=job_id,
            name=job_name,
            replace_existing=True,
        )

    scheduler.start()
    logger.info("[Scheduler] 定时任务调度器已启动")
def stop_scheduler():
    """Shut the scheduler down without waiting for jobs; no-op when not running."""
    if not scheduler.running:
        return
    scheduler.shutdown(wait=False)
    logger.info("[Scheduler] 定时任务调度器已停止")
def get_scheduler_status() -> dict:
    """Report whether the scheduler is running and list its jobs.

    Returns:
        ``{"status": "stopped", "jobs": []}`` when the scheduler is not
        running; otherwise ``{"status": "running", "jobs": [...]}`` where each
        job entry has ``id``, ``name`` and ``next_run`` (the next run time as
        a string, or ``None`` when the job has none).
    """
    if not scheduler.running:
        return {"status": "stopped", "jobs": []}

    job_entries = [
        {
            "id": job.id,
            "name": job.name,
            "next_run": str(job.next_run_time) if job.next_run_time else None,
        }
        for job in scheduler.get_jobs()
    ]
    return {"status": "running", "jobs": job_entries}

View File

@@ -0,0 +1,140 @@
import logging
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.user import User
from app.services.auth_service import verify_password, get_password_hash
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
async def get_user_settings(user_id: str, db: AsyncSession) -> dict | None:
    """Return a user's complete settings.

    Args:
        user_id: Id of the user to look up.
        db: Async database session.

    Returns:
        ``None`` when the user does not exist (the previous ``-> dict``
        annotation hid this case); otherwise a dict with ``profile`` (the
        ``User`` row), ``llm_config`` and ``scheduler_config``, where the two
        configs default to empty dicts when unset.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        return None
    return {
        "profile": user,
        "llm_config": user.llm_config or {},
        "scheduler_config": user.scheduler_config or {},
    }
async def update_user_profile(
    user_id: str,
    db: AsyncSession,
    full_name: Optional[str] = None,
    password: Optional[str] = None,
    current_password: Optional[str] = None
) -> User:
    """Update a user's display name and/or password.

    A password change requires ``current_password`` to verify against the
    stored hash. Falsy ``full_name`` / ``password`` values leave the
    corresponding field untouched.

    Raises:
        ValueError: If the user does not exist, or a new password is given
            without a correct ``current_password``.

    Returns:
        The refreshed ``User`` row after commit.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    user = lookup.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    if password:
        # Short-circuits exactly like the original: verify_password is only
        # called when a current password was supplied.
        if not (current_password and verify_password(current_password, user.hashed_password)):
            raise ValueError("当前密码错误")
        user.hashed_password = get_password_hash(password)

    if full_name:
        user.full_name = full_name

    await db.commit()
    await db.refresh(user)
    return user
async def update_llm_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Merge ``config`` into the user's stored LLM configuration.

    Merge rules per key (``None`` values are ignored):
      * dict value and the stored value is also a dict -> shallow merge;
      * anything else (lists, scalars, dict over non-dict) -> replace.

    The merged result is built in a *new* dict. The previous version mutated
    the stored dict in place and assigned the same object back, which ORM
    change detection can treat as "no change" for a plain JSON column, so the
    update could be dropped on commit.

    Raises:
        ValueError: If the user does not exist.

    Returns:
        The merged configuration that was stored.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    # Copy so the attribute receives a distinct object on assignment.
    merged = dict(user.llm_config or {})
    for key, value in config.items():
        if value is None:
            continue
        stored = merged.get(key)
        if isinstance(value, dict) and isinstance(stored, dict):
            merged[key] = {**stored, **value}
        else:
            merged[key] = value

    user.llm_config = merged
    await db.commit()
    return merged
async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession) -> dict:
    """Overlay ``config`` onto the user's stored scheduler configuration.

    Keys whose value is ``None`` are ignored; every other key replaces the
    stored value. The result is built in a new dict rather than mutating the
    stored one in place, so that assigning it back is seen as a change by the
    ORM even for a plain JSON column.

    Raises:
        ValueError: If the user does not exist.

    Returns:
        The updated configuration that was stored.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise ValueError("用户不存在")

    updates = {key: value for key, value in config.items() if value is not None}
    merged = {**(user.scheduler_config or {}), **updates}

    user.scheduler_config = merged
    await db.commit()
    return merged
async def test_llm_connection(
    provider: str,
    model: str,
    base_url: str,
    api_key: str
) -> dict:
    """Check that an LLM endpoint is reachable by sending a one-word prompt.

    A temporary client is built for the given provider with a 30 second
    timeout. ``openai``, ``deepseek`` and ``custom`` all use the
    OpenAI-compatible client; ``custom`` is accepted here because the
    application's provider setting accepts it too, and was previously
    rejected as unsupported.

    Args:
        provider: One of ``openai``, ``custom``, ``deepseek``, ``claude``,
            ``ollama``.
        model: Model name passed to the client.
        base_url: Endpoint override; an empty value selects the provider
            default (ignored for ``claude``).
        api_key: API key (ignored for ``ollama``).

    Returns:
        ``{"success": True, "message": ...}`` on success, otherwise
        ``{"success": False, "error": ...}``. Never raises.
    """
    try:
        # Build a throw-away client for the requested provider.
        if provider in ("openai", "custom", "deepseek"):
            from langchain_openai import ChatOpenAI
            default_url = "https://api.deepseek.com/v1" if provider == "deepseek" else None
            llm = ChatOpenAI(
                api_key=api_key,
                model=model,
                base_url=base_url or default_url,
                timeout=30
            )
        elif provider == "claude":
            from langchain_anthropic import ChatAnthropic
            llm = ChatAnthropic(
                api_key=api_key,
                model=model,
                timeout=30
            )
        elif provider == "ollama":
            from langchain_ollama import ChatOllama
            llm = ChatOllama(
                base_url=base_url or "http://localhost:11434",
                model=model,
                timeout=30
            )
        else:
            return {"success": False, "error": f"不支持的 provider: {provider}"}

        # Minimal round trip to prove the credentials and endpoint work.
        from langchain_core.messages import HumanMessage
        response = await llm.ainvoke([HumanMessage(content="Hi")])
        return {"success": True, "message": f"连接成功,模型响应: {response.content[:50]}..."}
    except Exception as e:
        # Deliberate catch-all: this is a diagnostic endpoint and the error
        # text is the result.
        return {"success": False, "error": str(e)}

View File

@@ -0,0 +1,278 @@
import psutil
import time
from datetime import datetime, timedelta
from sqlalchemy import select, func, and_
from sqlalchemy.orm import Session
from app.models.conversation import Conversation, Message
from app.models.knowledge_graph import KGNode, KGEdge
from app.models.task import Task, TaskStatus
from app.models.forum import ForumPost, ForumReply
from app.models.document import Document
class StatsService:
    """Aggregates dashboard statistics from the database and the host system.

    All query methods use the synchronous ``Session.query`` API of the session
    passed to the constructor.

    Result rows are unpacked positionally rather than read as ``row.count``:
    ``count`` is also the name of a tuple method, so attribute access on a row
    may resolve to the method rather than the labelled column depending on the
    SQLAlchemy version.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_system_health(self) -> dict:
        """Return host health metrics (uptime, CPU, memory, root disk usage).

        ``active_users_24h`` is a placeholder and always ``0``.
        """
        uptime_seconds = int(time.time() - psutil.boot_time())
        cpu_percent = psutil.cpu_percent(interval=0.1)
        mem = psutil.virtual_memory()
        disk = psutil.disk_usage('/')
        mib = 1024 * 1024
        gib = 1024 * 1024 * 1024

        return {
            "uptime_seconds": uptime_seconds,
            "cpu_percent": cpu_percent,
            "memory_used_mb": round(mem.used / mib, 1),
            "memory_total_mb": round(mem.total / mib, 1),
            "memory_percent": mem.percent,
            "disk_used_gb": round(disk.used / gib, 1),
            "disk_total_gb": round(disk.total / gib, 1),
            "disk_percent": disk.percent,
            "active_users_24h": 0,  # placeholder: needs an updated_at on the User table
        }

    def _get_daily_stats(self, model, date_column, user_id=None, days=30, via=None) -> list:
        """Count rows per calendar day over the last ``days`` days.

        Args:
            model: ORM model being counted.
            date_column: Timestamp column used for bucketing and the cutoff.
            user_id: When given, restrict the count to this user.
            days: Size of the look-back window.
            via: Optional related model that carries ``user_id``; when given,
                the query joins it and filters on ``via.user_id``. Use this
                for models that are owned through a parent row.

        Returns:
            ``[{"date": "YYYY-MM-DD", "count": n}, ...]`` ordered by date.

        NOTE(review): when ``user_id`` is given but ``model`` has no
        ``user_id`` attribute and no ``via`` is passed, the user filter is
        skipped and the counts cover all users.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)
        day = func.date(date_column)

        query = self.db.query(
            day.label('date'),
            func.count().label('count')
        ).filter(date_column >= cutoff)

        if user_id:
            if via is not None:
                query = query.join(via).filter(via.user_id == user_id)
            elif hasattr(model, 'user_id'):
                query = query.filter(model.user_id == user_id)

        rows = query.group_by(day).order_by(day).all()
        return [{"date": str(bucket), "count": total} for bucket, total in rows]

    def _daily_token_sums(self, role: str, cutoff, user_id=None) -> list:
        """Sum ``Message.tokens_used`` per day for one message role.

        Returns:
            ``[(date_str, tokens), ...]`` ordered by date.
        """
        day = func.date(Message.created_at)
        query = self.db.query(
            day.label('date'),
            func.coalesce(func.sum(Message.tokens_used), 0).label('tokens')
        ).filter(
            Message.created_at >= cutoff,
            Message.role == role
        )
        if user_id:
            # Messages are attributed to a user through their conversation.
            query = query.join(Conversation).filter(Conversation.user_id == user_id)
        rows = query.group_by(day).order_by(day).all()
        return [(str(bucket), tokens) for bucket, tokens in rows]

    def get_conversation_stats(self, user_id: str | None = None, days=30) -> dict:
        """Return daily conversation, message and token statistics.

        User messages are reported as input tokens and assistant messages as
        output tokens. When ``user_id`` is given, message counts are filtered
        through the owning conversation, the same way the token sums are;
        previously they could silently include every user's messages.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)

        daily_conversations = self._get_daily_stats(
            Conversation, Conversation.created_at, user_id, days
        )
        daily_messages = self._get_daily_stats(
            Message, Message.created_at, user_id, days, via=Conversation
        )

        daily_input_tokens = [
            {"date": day, "input_tokens": tokens}
            for day, tokens in self._daily_token_sums('user', cutoff, user_id)
        ]
        daily_output_tokens = [
            {"date": day, "output_tokens": tokens}
            for day, tokens in self._daily_token_sums('assistant', cutoff, user_id)
        ]

        return {
            "daily_conversations": daily_conversations,
            "daily_messages": daily_messages,
            "daily_input_tokens": daily_input_tokens,
            "daily_output_tokens": daily_output_tokens,
            "totals": {
                "conversations": sum(c["count"] for c in daily_conversations),
                "messages": sum(m["count"] for m in daily_messages),
                "input_tokens": sum(t["input_tokens"] for t in daily_input_tokens),
                "output_tokens": sum(t["output_tokens"] for t in daily_output_tokens),
            }
        }

    def get_knowledge_stats(self, user_id: str | None = None, days=30) -> dict:
        """Return daily knowledge-base statistics (new tags, documents, edges).

        ``daily_knowledge_queries`` is always an empty list.

        NOTE(review): edge counts rely on ``KGEdge`` exposing ``user_id``; if
        it does not, ``daily_tag_relations`` is not restricted to the user -
        confirm against the model.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)

        # Newly created tag nodes per day.
        tag_day = func.date(KGNode.created_at)
        tag_query = self.db.query(
            tag_day.label('date'),
            func.count().label('count')
        ).filter(
            KGNode.created_at >= cutoff,
            KGNode.entity_type == 'tag'
        )
        if user_id:
            tag_query = tag_query.filter(KGNode.user_id == user_id)
        tag_rows = tag_query.group_by(tag_day).order_by(tag_day).all()
        daily_new_tags = [{"date": str(bucket), "count": total} for bucket, total in tag_rows]

        daily_documents = self._get_daily_stats(
            Document, Document.created_at, user_id, days
        )
        daily_tag_relations = self._get_daily_stats(
            KGEdge, KGEdge.created_at, user_id, days
        )

        return {
            "daily_new_tags": daily_new_tags,
            "daily_documents": daily_documents,
            "daily_knowledge_queries": [],
            "daily_tag_relations": daily_tag_relations,
            "totals": {
                "new_tags": sum(t["count"] for t in daily_new_tags),
                "documents": sum(d["count"] for d in daily_documents),
                "tag_relations": sum(r["count"] for r in daily_tag_relations),
            }
        }

    def get_kanban_stats(self, user_id: str | None = None, days=30) -> dict:
        """Return daily kanban statistics.

        The completion rate for a day is completed / created * 100 for that
        same date, and ``0`` on days where no task was created.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)

        daily_new_tasks = self._get_daily_stats(
            Task, Task.created_at, user_id, days
        )

        # Tasks completed per day.
        done_day = func.date(Task.completed_at)
        completed_query = self.db.query(
            done_day.label('date'),
            func.count().label('count')
        ).filter(
            Task.completed_at >= cutoff,
            Task.status == TaskStatus.DONE
        )
        if user_id:
            completed_query = completed_query.filter(Task.user_id == user_id)
        completed_rows = completed_query.group_by(done_day).order_by(done_day).all()
        daily_completed_tasks = [
            {"date": str(bucket), "count": total} for bucket, total in completed_rows
        ]

        # Tasks currently in TODO state (not limited to the window).
        pending_query = self.db.query(func.count(Task.id)).filter(Task.status == TaskStatus.TODO)
        if user_id:
            pending_query = pending_query.filter(Task.user_id == user_id)
        current_pending_tasks = pending_query.scalar() or 0

        # Completion rate per date seen in either series.
        new_by_date = {d["date"]: d["count"] for d in daily_new_tasks}
        completed_by_date = {d["date"]: d["count"] for d in daily_completed_tasks}
        daily_completion_rate = []
        for date in sorted(new_by_date.keys() | completed_by_date.keys()):
            created = new_by_date.get(date, 0)
            completed = completed_by_date.get(date, 0)
            rate = (completed / created * 100) if created > 0 else 0
            daily_completion_rate.append({"date": date, "rate": round(rate, 1)})

        return {
            "daily_new_tasks": daily_new_tasks,
            "daily_completed_tasks": daily_completed_tasks,
            "daily_completion_rate": daily_completion_rate,
            "current_pending_tasks": current_pending_tasks,
            "totals": {
                "new_tasks": sum(t["count"] for t in daily_new_tasks),
                "completed_tasks": sum(c["count"] for c in daily_completed_tasks),
            }
        }

    def get_community_stats(self, user_id: str | None = None, days=30) -> dict:
        """Return daily forum statistics (posts, replies, AI executions).

        AI executions are executed posts bucketed by their ``updated_at``
        date. ``daily_agent_calls`` is always an empty list.

        NOTE(review): reply counts rely on ``ForumReply`` exposing
        ``user_id``; otherwise they are not restricted to the user - confirm.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)

        daily_posts = self._get_daily_stats(
            ForumPost, ForumPost.created_at, user_id, days
        )
        daily_replies = self._get_daily_stats(
            ForumReply, ForumReply.created_at, user_id, days
        )

        exec_day = func.date(ForumPost.updated_at)
        ai_query = self.db.query(
            exec_day.label('date'),
            func.count().label('count')
        ).filter(
            ForumPost.updated_at >= cutoff,
            ForumPost.is_executed == True
        )
        if user_id:
            ai_query = ai_query.filter(ForumPost.user_id == user_id)
        ai_rows = ai_query.group_by(exec_day).order_by(exec_day).all()
        daily_ai_executions = [{"date": str(bucket), "count": total} for bucket, total in ai_rows]

        return {
            "daily_posts": daily_posts,
            "daily_replies": daily_replies,
            "daily_ai_executions": daily_ai_executions,
            "daily_agent_calls": [],
            "totals": {
                "posts": sum(p["count"] for p in daily_posts),
                "replies": sum(r["count"] for r in daily_replies),
                "ai_executions": sum(a["count"] for a in daily_ai_executions),
            }
        }

    def get_personal_insights(self, user_id: str) -> dict:
        """Return per-user insights.

        Includes conversation counts per hour of day, the five most used
        tags, and the month-over-month change in assistant-message tokens.
        ``token_trend_percent`` is ``0`` when last month had no tokens.
        """
        # Conversations started per hour of day.
        hour = func.extract('hour', Conversation.created_at)
        hourly_rows = self.db.query(
            hour.label('hour'),
            func.count().label('count')
        ).filter(Conversation.user_id == user_id).group_by(hour).all()
        hourly_activity = [{"hour": int(h), "count": total} for h, total in hourly_rows]

        # Top tags by number of incoming has_tag edges.
        tag_path = KGNode.properties_["tag_path"].astext
        usage = func.count(KGEdge.id)
        tag_rows = self.db.query(
            tag_path.label('tag_path'),
            usage.label('usage_count')
        ).join(
            KGEdge, KGEdge.target_id == KGNode.id
        ).filter(
            KGNode.user_id == user_id,
            KGNode.entity_type == 'tag',
            KGEdge.relation_type == 'has_tag'
        ).group_by(tag_path).order_by(usage.desc()).limit(5).all()
        top_tags = [{"tag_path": path, "usage_count": used} for path, used in tag_rows]

        # Token trend: this month so far versus the whole of last month.
        now = datetime.utcnow()
        this_month_start = datetime(now.year, now.month, 1)
        last_month_end = this_month_start - timedelta(days=1)
        last_month_start = datetime(last_month_end.year, last_month_end.month, 1)

        def assistant_tokens(*time_filters):
            """Sum assistant tokens for this user within the given time filters."""
            return self.db.query(
                func.coalesce(func.sum(Message.tokens_used), 0)
            ).join(Conversation).filter(
                Conversation.user_id == user_id,
                *time_filters,
                Message.role == 'assistant'
            ).scalar() or 0

        this_month_tokens = assistant_tokens(Message.created_at >= this_month_start)
        last_month_tokens = assistant_tokens(
            Message.created_at >= last_month_start,
            Message.created_at < this_month_start,
        )

        token_trend_percent = 0
        if last_month_tokens > 0:
            token_trend_percent = round(
                (this_month_tokens - last_month_tokens) / last_month_tokens * 100, 1
            )

        return {
            "hourly_activity": hourly_activity,
            "top_tags": top_tags,
            "token_trend_percent": token_trend_percent,
            "this_month_tokens": this_month_tokens,
            "last_month_tokens": last_month_tokens,
        }

View File

@@ -0,0 +1,239 @@
import json
from sqlalchemy.orm import Session
from app.models.knowledge_graph import KGNode, KGEdge
TAG_EXTRACTION_PROMPT = """你是一个知识分类专家。从给定内容中提取标签。
要求:
1. 标签采用层级路径格式,如 "编程语言/Python""后端/框架/FastAPI"
2. 层级深度 1-4 层,避免过深
3. 每个内容提取 3-8 个标签
4. 标签应覆盖:主题、技术栈、领域、任务类型等维度
输出格式JSON数组
[
{"path": "编程语言/Python", "description": "Python编程语言相关"},
{"path": "后端/框架/FastAPI", "description": "FastAPI框架相关"}
]
内容:
{content}
"""
TAG_RELATION_PROMPT = """分析以下标签之间的关系,输出 JSON 数组:
关系类型:
- parent_of: 父子关系(上级包含下级)
- related_to: 语义相关(但不是父子)
- synonym_of: 同义词
标签列表:
{tag_paths}
输出格式:
[
{"source": "标签1", "target": "标签2", "relation": "related_to", "weight": 0.8},
{"source": "标签1", "target": "标签3", "relation": "parent_of", "weight": 1.0}
]
"""
class TagService:
def __init__(self, db: Session, llm_client):
self.db = db
self.llm_client = llm_client
def extract_tags_from_content(self, content: str, user_id: str) -> list[dict]:
"""从内容中提取标签"""
response = self.llm_client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是一个知识分类专家。"},
{"role": "user", "content": TAG_EXTRACTION_PROMPT.format(content=content)}
],
response_format={"type": "json_object"}
)
result = json.loads(response.choices[0].message.content)
return result.get("tags", [])
def parse_tag_path(self, path: str) -> tuple[str, int, str | None]:
"""解析标签路径,返回 (short_name, level, parent_path)"""
parts = path.strip("/").split("/")
short_name = parts[-1]
level = len(parts)
parent_path = "/".join(parts[:-1]) if level > 1 else None
return short_name, level, parent_path
def get_or_create_tag_node(self, tag_info: dict, user_id: str) -> KGNode:
"""获取或创建标签节点"""
path = tag_info["path"]
existing = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.properties_["tag_path"].astext == path
).first()
if existing:
return existing
short_name, level, parent_path = self.parse_tag_path(path)
node = KGNode(
user_id=user_id,
name=short_name,
entity_type="tag",
description=tag_info.get("description"),
properties_={
"tag_path": path,
"short_name": short_name,
"level": level,
"parent_path": parent_path,
"description": tag_info.get("description"),
"color": tag_info.get("color"),
},
importance=0.5
)
self.db.add(node)
self.db.flush()
return node
def ensure_parent_tags(self, path: str, user_id: str) -> list[KGNode]:
"""确保父路径标签存在"""
parts = path.strip("/").split("/")
nodes = []
for i in range(1, len(parts)):
parent_path = "/".join(parts[:i])
tag_info = {"path": parent_path, "description": None}
node = self.get_or_create_tag_node(tag_info, user_id)
nodes.append(node)
return nodes
def create_tag_relations(self, tag_paths: list[str], user_id: str) -> list[KGEdge]:
"""分析并创建标签之间的关系边"""
path_to_node = {}
for path in tag_paths:
node = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.properties_["tag_path"].astext == path,
KGNode.entity_type == "tag"
).first()
if node:
path_to_node[path] = node
response = self.llm_client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是一个知识图谱专家。"},
{"role": "user", "content": TAG_RELATION_PROMPT.format(tag_paths=json.dumps(tag_paths))}
],
response_format={"type": "json_object"}
)
result = json.loads(response.choices[0].message.content)
relations = result.get("relations", [])
edges = []
for rel in relations:
source_node = path_to_node.get(rel["source"])
target_node = path_to_node.get(rel["target"])
if source_node and target_node:
existing = self.db.query(KGEdge).filter(
KGEdge.source_id == source_node.id,
KGEdge.target_id == target_node.id
).first()
if not existing:
edge = KGEdge(
source_id=source_node.id,
target_id=target_node.id,
relation_type=rel["relation"],
weight=rel.get("weight", 0.5)
)
self.db.add(edge)
edges.append(edge)
self.db.flush()
return edges
def tag_content(self, content: str, user_id: str, content_node: KGNode) -> list[KGNode]:
"""为内容节点打标签"""
tag_infos = self.extract_tags_from_content(content, user_id)
tag_paths = [t["path"] for t in tag_infos]
tag_nodes = []
for tag_info in tag_infos:
node = self.get_or_create_tag_node(tag_info, user_id)
tag_nodes.append(node)
self.ensure_parent_tags(tag_info["path"], user_id)
# 创建 has_tag 边
for tag_node in tag_nodes:
existing_edge = self.db.query(KGEdge).filter(
KGEdge.source_id == content_node.id,
KGEdge.target_id == tag_node.id,
KGEdge.relation_type == "has_tag"
).first()
if not existing_edge:
edge = KGEdge(
source_id=content_node.id,
target_id=tag_node.id,
relation_type="has_tag",
weight=1.0
)
self.db.add(edge)
tag_node_ids = [n.id for n in tag_nodes]
current_tag_ids = content_node.properties_.get("tag_node_ids", []) if content_node.properties_ else []
content_node.properties_["tag_node_ids"] = list(set(current_tag_ids + tag_node_ids))
if len(tag_paths) >= 2:
self.create_tag_relations(tag_paths, user_id)
self.db.commit()
return tag_nodes
def tag_incremental_content(self, user_id: str, days: int = 1) -> dict:
"""
增量打标签 - 只对最近新增/更新的内容节点打标签
"""
from datetime import datetime, timedelta
cutoff_date = datetime.utcnow() - timedelta(days=days)
content_nodes = self.db.query(KGNode).filter(
KGNode.user_id == user_id,
KGNode.entity_type.in_(["conversation", "document", "chunk"]),
KGNode.updated_at >= cutoff_date
).all()
untagged = [
n for n in content_nodes
if not n.properties_.get("tag_node_ids")
]
tagged_count = 0
for node in untagged:
content = node.description or ""
try:
self.tag_content(content, user_id, node)
tagged_count += 1
except Exception as e:
pass
return {"total": len(untagged), "tagged": tagged_count}
def get_related_content(self, tag_node_ids: list[str], user_id: str, limit: int = 10) -> list[tuple[KGNode, float]]:
    """Find the user's content nodes attached to any of the given tags.

    Each content node is scored with the summed weight of its ``has_tag``
    edges pointing at ``tag_node_ids``.

    Args:
        tag_node_ids: Ids of the tag nodes to look up.
        user_id: Only content nodes owned by this user are returned.
        limit: Maximum number of results.

    Returns:
        ``(node, score)`` pairs ordered by descending score, at most
        ``limit`` of them; an empty list when nothing matches.
    """
    # Nothing to look up: avoid issuing queries with an empty IN () clause.
    if not tag_node_ids or limit <= 0:
        return []

    edges = self.db.query(KGEdge).filter(
        KGEdge.target_id.in_(tag_node_ids),
        KGEdge.relation_type == "has_tag",
    ).all()

    content_weights: dict[str, float] = {}
    for edge in edges:
        content_weights[edge.source_id] = content_weights.get(edge.source_id, 0) + (edge.weight or 0)
    if not content_weights:
        return []

    content_nodes = self.db.query(KGNode).filter(
        KGNode.id.in_(list(content_weights)),
        KGNode.user_id == user_id,  # do not leak other users' content
        KGNode.entity_type.in_(["conversation", "document", "chunk"]),
    ).all()

    ranked = sorted(
        ((node, content_weights[node.id]) for node in content_nodes),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:limit]

View File

@@ -0,0 +1,165 @@
import json
import logging
from datetime import date, datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.todo import DailyTodo, TodoSource
from app.models.task import Task, TaskStatus
from app.models.conversation import Conversation, Message
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage, SystemMessage
logger = logging.getLogger(__name__)
async def generate_daily_todos(user_id: str, db: AsyncSession) -> list[DailyTodo]:
    """Generate today's todos for a user.

    Combines two sources, in this order:
      1. todos imported from kanban tasks (``_import_kanban_tasks``);
      2. todos extracted from chat history (``_analyze_chat_history``).

    Both helpers receive yesterday's date as an ISO string.
    """
    previous_day = (date.today() - timedelta(days=1)).isoformat()

    # Kanban import runs first so its todos lead the combined list.
    from_kanban = await _import_kanban_tasks(user_id, previous_day, db)
    from_chat = await _analyze_chat_history(user_id, previous_day, db)

    return [*from_kanban, *from_chat]
async def _import_kanban_tasks(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
    """Create today's todos from the user's unfinished kanban tasks.

    Takes the 20 most recently created tasks whose status is not DONE and
    stores one ``DailyTodo`` per task, dated today. ``date_str`` is accepted
    for signature parity with the chat analysis helper; the query does not
    filter on it.

    Returns:
        The persisted (committed and refreshed) todos; empty when the user
        has no unfinished tasks.
    """
    stmt = (
        select(Task)
        .where(
            Task.user_id == user_id,
            Task.status != TaskStatus.DONE,
        )
        .order_by(Task.created_at.desc())
        .limit(20)
    )
    result = await db.execute(stmt)

    created: list[DailyTodo] = []
    for task in result.scalars().all():
        entry = DailyTodo(
            user_id=user_id,
            title=task.title,
            source=TodoSource.AI_KANBAN,
            source_detail=f"看板：{task.title}",
            source_ref_id=task.id,
            todo_date=date.today().isoformat(),
        )
        db.add(entry)
        created.append(entry)

    # Nothing staged, nothing to commit.
    if not created:
        return created

    await db.commit()
    for entry in created:
        await db.refresh(entry)
    return created
async def _analyze_chat_history(user_id: str, date_str: str, db: AsyncSession) -> list[DailyTodo]:
    """Extract up to three todos from the conversations created on ``date_str``.

    Looks at the user's 10 most recent conversations, keeps those created on
    ``date_str`` (see ``_created_on``), feeds at most 2000 characters of their
    messages to the LLM and stores the returned items as ``DailyTodo`` rows
    dated today.

    This is best effort: any failure is logged and an empty list is returned.

    Args:
        user_id: Owner of the conversations.
        date_str: Target day as ``YYYY-MM-DD``.
        db: Async session used for reads and for persisting the todos.

    Returns:
        The persisted todos, possibly empty.
    """
    try:
        # Most recent conversations first; the date filter is applied below.
        conv_q = select(Conversation).where(
            Conversation.user_id == user_id,
        ).order_by(Conversation.created_at.desc()).limit(10)
        convs = (await db.execute(conv_q)).scalars().all()
        yesterday_convs = [conv for conv in convs if _created_on(conv.created_at, date_str)]
        if not yesterday_convs:
            return []

        # Collect message text: 50 messages per conversation, 500 chars each.
        messages_content = []
        for conv in yesterday_convs:
            msg_q = select(Message).where(
                Message.conversation_id == conv.id
            ).order_by(Message.created_at.asc()).limit(50)
            msgs = (await db.execute(msg_q)).scalars().all()
            for msg in msgs:
                if msg.content:
                    messages_content.append(f"[{msg.role}]: {msg.content[:500]}")
        if not messages_content:
            return []
        full_text = "\n".join(messages_content)[:2000]

        prompt = f"""你是一个任务规划助手。请分析以下对话记录，提取其中用户想要完成但尚未明确完成的事项。
要求：
- 最多提取 3 条
- 每条格式：{{"title": "事项描述50字以内", "reason": "来源说明60字以内"}}
- 只提取用户明确表达过需求但还未完成的事项
- 如果没有可提取的内容，返回空数组 []
对话记录：
{full_text}
返回 JSON 数组："""
        llm = get_llm()
        response = await _invoke_llm(llm, [
            SystemMessage(content="你是一个任务规划助手。"),
            HumanMessage(content=prompt),
        ])
        content = response.content if hasattr(response, 'content') else str(response)
        items = _parse_todo_items(content)
        if not items:
            return []

        todos = []
        for item in items[:3]:
            # The model may omit a field or return null for it.
            title = str(item.get("title") or "").strip()
            if not title:
                continue  # a todo without a title is useless to the user
            reason = str(item.get("reason") or "")
            todo = DailyTodo(
                user_id=user_id,
                title=title[:500],
                source=TodoSource.AI_CHAT,
                source_detail=f"对话：{reason[:60]}",
                todo_date=date.today().isoformat(),
            )
            db.add(todo)
            todos.append(todo)
        if todos:
            await db.commit()
            for todo in todos:
                await db.refresh(todo)
        return todos
    except Exception as e:
        # Deliberately broad: chat analysis must never break todo generation.
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"对话分析失败: {e}")
        return []


def _created_on(created, date_str: str) -> bool:
    """Return True when ``created`` falls on ``date_str`` as stored or shifted by +8 hours.

    ``created`` may be a datetime or anything whose ``str()`` is an ISO
    timestamp; values that cannot be interpreted never match.
    """
    try:
        if not isinstance(created, datetime):
            created = datetime.fromisoformat(str(created))
        if created.date().isoformat() == date_str:
            return True
        return (created + timedelta(hours=8)).strftime('%Y-%m-%d') == date_str
    except (TypeError, ValueError):
        return False


async def _invoke_llm(llm, messages):
    """Call the chat model without assuming which invocation style it offers.

    Prefers an ``ainvoke`` coroutine method when the object has one; otherwise
    calls ``invoke`` and awaits the result only if it is awaitable.
    """
    import inspect

    ainvoke = getattr(llm, "ainvoke", None)
    if callable(ainvoke):
        return await ainvoke(messages)
    result = llm.invoke(messages)
    if inspect.isawaitable(result):
        result = await result
    return result


def _parse_todo_items(content) -> list[dict]:
    """Pull the JSON array out of the model's reply and keep its dict items.

    Returns an empty list when no array is present or it cannot be parsed.
    """
    if isinstance(content, list):
        # Some providers return a list of content parts; keep their text.
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    elif not isinstance(content, str):
        content = str(content)

    # The reply may wrap the array in prose or a code fence.
    start = content.find('[')
    end = content.rfind(']') + 1
    if start == -1 or end <= start:
        return []
    try:
        parsed = json.loads(content[start:end])
    except (json.JSONDecodeError, ValueError):
        logger.warning(f"LLM 返回格式异常，跳过对话分析: {content[:200]}")
        return []
    if not isinstance(parsed, list):
        return []
    return [item for item in parsed if isinstance(item, dict)]

Binary file not shown.

BIN
backend/data/jarvis.db Normal file

Binary file not shown.

View File

@@ -0,0 +1,119 @@
远光软件股份有限公司科技项目可行性研究报告
项目名称:大模型微调技术研究与应用
申请部门:
起止时间:年至年
项目负责人:
联系电话:
申请日期:年 月
大模型微调技术可行性研究报告
远光软件股份有限公司科技项目可行性研究报告
项目名称: 大模型微调技术研究与应用
申请部门:
起止时间: 年 月至 年 月
项目负责人:
联系电话:
申请日期: 年 月
一、目的和意义
1.1 项目背景与需求
近年来，以深度学习为基础的大型预训练语言模型（Large Language Models, LLMs），如GPT系列、BERT、LLaMA等，在自然语言处理领域取得了突破性进展。通过海量数据的预训练和超大规模参数量，这些模型展现出强大的通用语言理解与生成能力，在机器翻译、文本摘要、问答系统、内容创作等众多任务中表现出色，引领了人工智能技术的新浪潮。然而，这些通用大模型在面对特定专业领域任务时，往往存在知识覆盖不足、专业术语理解偏差、领域特定逻辑推理能力欠缺、输出风格不符合行业特点等问题，难以直接满足垂直场景的应用需求。
模型微调（Fine-tuning）技术作为将通用大模型适配到特定场景的关键手段，通过在领域相关数据上进一步训练模型参数，使模型能够吸收领域知识、适应特定任务要求，从而显著提升模型在目标任务上的性能表现。随着大模型参数规模的不断扩大，传统的全参数微调方式面临着计算资源消耗大、存储成本高、容易产生灾难性遗忘等挑战，因此参数高效微调（Parameter-Efficient Fine-Tuning, PEFT）方法，如LoRA、Adapter、Prefix-tuning等技术应运而生，为低成本、高效率的大模型领域适配提供了新的技术路径。
本项目旨在探索适合特定领域特点的高效微调策略,解决数据稀缺性、专业术语理解、领域知识融合等关键技术问题,提升模型在特定场景下的准确性、可靠性和实用性。
项目成果对该现状和技术发展的作用，主要体现在技术推动和应用落地支撑两方面。
二、国内外研究水平综述
2.1 技术发展历史简要回顾
大模型微调技术的发展历程分为四个阶段:
第一阶段（2018年前）：传统迁移学习与微调雏形阶段。模型适配多采用传统迁移学习思路，将通用数据集上训练的基础模型迁移至特定任务场景。
第二阶段（2018-2020年）：预训练-微调范式确立阶段。2018年，谷歌提出BERT模型，首次构建"预训练通用知识+下游任务微调"的技术框架。
第三阶段（2020-2022年）：高效微调技术爆发阶段。LoRA、QLoRA、Adapter等参数高效微调技术相继出现，将微调参数规模大幅降低。
第四阶段（2022年至今）：垂直领域深化与协同优化阶段。"基座模型+领域微调"的架构成为主流，微调技术与知识图谱进一步融合。
2.2 国内外研究水平现状和发展趋势
国际层面，Hugging Face、DeepSpeed等开源社区为参数高效微调技术的普及提供了重要支撑。国内层面，阿里云基于通义千问进行财税领域定制微调，验证了微调技术在财务领域的应用价值。
三、项目的理论和实践依据
3.1 项目研究内容原理简述
本项目采用"基座模型+领域适配"分层微调架构，选取开源基座模型，针对财务问答场景特性，采用LoRA参数高效微调策略。
3.2 项目研究内容理论和实践依据
理论依据包括国家战略层面的政策支持和成熟的技术理论体系。实践依据包括大模型微调技术在财务等垂直领域的成功案例。
3.3 项目研究的关键和难点
关键点包括高质量数据集构建、高效微调策略适配、知识精准注入与幻觉抑制、效果评估体系建设。难点集中在数据处理、微调策略、知识注入和评估体系四个方面。
四、项目研究内容和实施方案
4.1 项目研究内容详细说明
本项目研究内容包括数据格式研究、微调框架研究、模型微调后评估体系研究三个方面。
4.2 理论研究步骤和试验计划
包括数据处理流程、训练数据生成流程、数据验证流程三个主要环节。
4.3 项目组织方式和协作分工
本项目由项目负责人统筹协调,下设数据组、算法组、应用组三个工作小组。
五、预期目标和成果形式
5.1 项目研究预期达到的目标
技术目标：问答准确率达到85%以上。应用目标：开发财务智能知识问答原型系统。效益目标：替代财务专家70%以上的重复性咨询工作。
5.2 明确叙述提交研究成果的形式
包括技术方案文档、原型系统、训练数据集、微调模型、技术论文/报告等成果形式。
六、项目承担团队的条件
项目团队具备人工智能、大数据等领域的技术背景，具备财务信息系统开发经验，具备充足的GPU计算资源和完善的开发测试环境。
七、项目进度安排
第1-2月：项目启动、需求分析；第3-4月：数据收集、清洗；第5-7月：数据集生成；第8-10月：模型训练；第11-12月：系统开发；第13-14月：优化整理；第15-16月：验收转化。
八、项目经费预算
本项目经费预算根据实际研究工作需要编制,包括人工费、设备使用费、业务费、场地使用费、专家咨询费等科目。
分管领导审核意见:
(对经费预算是否合理,有无其他经费来源,能否保证研究计划实施所需的人力,工作时间等基本条件提出具体意见)
分管领导(签字): 年 月 日

77
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,77 @@
[project]
name = "jarvis-backend"
version = "0.1.0"
description = "Jarvis Personal AI Assistant - Backend"
readme = "README.md"
requires-python = ">=3.12"
license = { text = "MIT" }
dependencies = [
# Web 框架
"fastapi>=0.115.0",
"uvicorn[standard]>=0.30.0",
"python-multipart>=0.0.12",
"websockets>=12.0",
"aiofiles>=24.0.0",
# Agent 框架
"langgraph>=0.2.36",
"langchain-anthropic>=0.3.14",
"langchain-openai>=0.3.18",
"langchain-core>=0.3.52",
"langchain-ollama>=0.4.0",
"langsmith>=0.1.0",
# 知识库框架
"llama-index>=0.12.0",
"llama-index-vector-stores-chroma>=0.3.0",
"chromadb>=0.5.0",
# 数据库
"sqlalchemy>=2.0.0",
"aiosqlite>=0.20.0",
"alembic>=1.13.0",
# 认证 & 安全
"python-jose[cryptography]>=3.3.0",
"passlib[bcrypt]>=1.7.4",
"bcrypt>=4.0.0,<5.0.0",
# 配置 & 验证
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"email-validator>=2.0.0",
# 定时任务
"APScheduler>=3.10.0",
# 工具
"python-dotenv>=1.0.0",
"httpx>=0.27.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-cov>=4.1.0",
"ruff>=0.5.0",
"mypy>=1.10.0",
"pre-commit>=3.7.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["app"]
[tool.ruff]
target-version = "py312"
line-length = 100

# Rule selection is a linter setting: Ruff expects it under [tool.ruff.lint]
# and reports the top-level `select` key as deprecated.
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

View File

View File

View File

View File

@@ -0,0 +1,120 @@
import pytest
from unittest.mock import MagicMock, patch
from app.services.tag_service import TagService, TAG_EXTRACTION_PROMPT, TAG_RELATION_PROMPT
class TestTagService:
    """Unit tests for TagService."""

    @staticmethod
    def _make_service(db=None):
        """Build a TagService on mocks; ``db`` defaults to a fresh MagicMock."""
        session = MagicMock() if db is None else db
        return TagService(db=session, llm_client=MagicMock())

    def test_parse_tag_path_single_level(self):
        """A path without separators is a level-1 tag with no parent."""
        leaf, depth, parent = self._make_service().parse_tag_path("Python")

        assert leaf == "Python"
        assert depth == 1
        assert parent is None

    def test_parse_tag_path_nested(self):
        """A three-segment path yields the last segment, level 3 and its parent path."""
        leaf, depth, parent = self._make_service().parse_tag_path("编程语言/Python/异步")

        assert leaf == "异步"
        assert depth == 3
        assert parent == "编程语言/Python"

    def test_parse_tag_path_strips_slashes(self):
        """Leading and trailing slashes do not count as path segments."""
        leaf, depth, parent = self._make_service().parse_tag_path("/后端/框架/")

        assert leaf == "框架"
        assert depth == 2
        assert parent == "后端"

    def test_parse_tag_path_empty_parts(self):
        """A four-segment path reports level 4 and the joined parent path."""
        leaf, depth, parent = self._make_service().parse_tag_path("a/b/c/d")

        assert leaf == "d"
        assert depth == 4
        assert parent == "a/b/c"

    @patch('app.services.tag_service.KGNode')
    @patch('app.services.tag_service.KGEdge')
    def test_get_or_create_tag_node_creates_new(self, mock_edge, mock_node):
        """When the lookup finds nothing, a node is added and flushed."""
        session = MagicMock()
        session.query.return_value.filter.return_value.first.return_value = None
        service = self._make_service(session)

        created = service.get_or_create_tag_node(
            {"path": "Python", "description": "Python语言"}, "user_123"
        )

        assert created is not None
        session.add.assert_called_once()
        session.flush.assert_called_once()

    @patch('app.services.tag_service.KGNode')
    def test_get_or_create_tag_node_returns_existing(self, mock_node):
        """When the lookup finds a node, it is returned and nothing is added."""
        session = MagicMock()
        found = MagicMock()
        session.query.return_value.filter.return_value.first.return_value = found
        service = self._make_service(session)

        returned = service.get_or_create_tag_node(
            {"path": "Python", "description": "Python语言"}, "user_123"
        )

        assert returned == found
        session.add.assert_not_called()

    def test_ensure_parent_tags_creates_parents(self):
        """A three-level path triggers two get_or_create_tag_node calls."""
        session = MagicMock()
        session.query.return_value.filter.return_value.first.return_value = None
        service = self._make_service(session)

        with patch.object(service, 'get_or_create_tag_node') as fake_create:
            fake_create.return_value = MagicMock()
            service.ensure_parent_tags("a/b/c", "user_123")
            assert fake_create.call_count == 2

    def test_ensure_parent_tags_single_level(self):
        """A single-level path triggers no get_or_create_tag_node call."""
        service = self._make_service(MagicMock())

        with patch.object(service, 'get_or_create_tag_node') as fake_create:
            fake_create.return_value = MagicMock()
            service.ensure_parent_tags("Python", "user_123")
            assert fake_create.call_count == 0

    @patch('app.services.tag_service.KGNode')
    def test_get_related_content_empty_tags(self, mock_kg_node):
        """An empty tag id list produces an empty result."""
        session = MagicMock()
        session.query.return_value.filter.return_value.all.return_value = []
        service = self._make_service(session)

        assert service.get_related_content([], "user_123") == []

    def test_tag_extraction_prompt_format(self):
        """The extraction prompt keeps its key phrases and the content placeholder."""
        for fragment in ("层级路径格式", "3-8 个标签", "{content}"):
            assert fragment in TAG_EXTRACTION_PROMPT

    def test_tag_relation_prompt_format(self):
        """The relation prompt names every relation type and the tag_paths placeholder."""
        for fragment in ("parent_of", "related_to", "synonym_of", "{tag_paths}"):
            assert fragment in TAG_RELATION_PROMPT

4261
backend/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff