Add FastAPI backend with agent system
This commit is contained in:
24
backend/app/agents/context.py
Normal file
24
backend/app/agents/context.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Agent 运行时上下文
|
||||
用于在工具调用链中传递 user_id 等上下文信息
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
_current_user_id: ContextVar[Optional[str]] = ContextVar("current_user_id", default=None)
|
||||
|
||||
|
||||
def set_current_user(user_id: str):
|
||||
"""设置当前用户ID(线程/协程安全)"""
|
||||
_current_user_id.set(user_id)
|
||||
|
||||
|
||||
def get_current_user() -> str:
|
||||
"""获取当前用户ID"""
|
||||
return _current_user_id.get() or "default"
|
||||
|
||||
|
||||
def clear_current_user():
|
||||
"""清除当前用户上下文"""
|
||||
_current_user_id.set(None)
|
||||
265
backend/app/agents/graph.py
Normal file
265
backend/app/agents/graph.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Jarvis LangGraph Agent 主图定义
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from app.agents.state import AgentState, AgentRole
|
||||
from app.agents.prompts import (
|
||||
MASTER_SYSTEM_PROMPT,
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
EXECUTOR_SYSTEM_PROMPT,
|
||||
LIBRARIAN_SYSTEM_PROMPT,
|
||||
ANALYST_SYSTEM_PROMPT,
|
||||
)
|
||||
from app.agents.tools import ALL_TOOLS
|
||||
from app.services.llm_service import get_llm
|
||||
|
||||
|
||||
def _msg_type(msg: BaseMessage) -> str:
|
||||
"""Get message type, handles both .type (new) and .role (old) attribute names."""
|
||||
return getattr(msg, "type", None) or getattr(msg, "role", "human")
|
||||
|
||||
|
||||
def _filter_user_messages(messages: list) -> list[BaseMessage]:
|
||||
return [m for m in messages if _msg_type(m) in ("human", "user")]
|
||||
|
||||
|
||||
# ===================== 节点定义 (async) =====================
|
||||
|
||||
async def master_node(state: AgentState) -> AgentState:
|
||||
"""主Agent节点: 理解用户意图,决定调用哪个子Agent"""
|
||||
llm = get_llm()
|
||||
messages: list[BaseMessage] = state["messages"]
|
||||
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
|
||||
|
||||
# 注入记忆上下文
|
||||
memory_ctx = state.get("memory_context")
|
||||
if memory_ctx:
|
||||
system_msgs.append(
|
||||
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
|
||||
)
|
||||
|
||||
response: AIMessage = await llm.invoke(system_msgs + messages)
|
||||
content = response.content.strip().lower()
|
||||
|
||||
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
|
||||
next_agent = AgentRole.LIBRARIAN
|
||||
elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
|
||||
next_agent = AgentRole.PLANNER
|
||||
elif any(kw in content for kw in ["执行", "做", "操作", "创建", "更新"]):
|
||||
next_agent = AgentRole.EXECUTOR
|
||||
elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
|
||||
next_agent = AgentRole.ANALYST
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
state["current_agent"] = next_agent
|
||||
state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def planner_node(state: AgentState) -> AgentState:
|
||||
"""规划Agent节点: 制定计划,拆解任务步骤"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.invoke(
|
||||
[SystemMessage(content=PLANNER_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
plan_text = response.content
|
||||
steps = []
|
||||
for i, line in enumerate(plan_text.split("\n")):
|
||||
if line.strip() and (line[0].isdigit() or "- " in line):
|
||||
steps.append({"step": i + 1, "description": line.strip()})
|
||||
|
||||
state["plan"] = plan_text
|
||||
state["plan_steps"] = steps
|
||||
state["final_response"] = plan_text
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def executor_node(state: AgentState) -> AgentState:
|
||||
"""执行Agent节点: 调用工具执行具体任务"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=EXECUTOR_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=EXECUTOR_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def librarian_node(state: AgentState) -> AgentState:
|
||||
"""知识管理员节点: 管理知识库和知识图谱"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=LIBRARIAN_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["knowledge_context"] = state.get("last_tool_result", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
async def analyst_node(state: AgentState) -> AgentState:
|
||||
"""分析师节点: 分析工作数据,生成报告"""
|
||||
llm = get_llm()
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
|
||||
response = await llm.bind_tools(ALL_TOOLS).invoke(
|
||||
[SystemMessage(content=ANALYST_SYSTEM_PROMPT), HumanMessage(content=f"用户请求: {user_query}")]
|
||||
)
|
||||
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in ALL_TOOLS:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await llm.invoke(
|
||||
[SystemMessage(content=ANALYST_SYSTEM_PROMPT),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")]
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
|
||||
state["analysis_report"] = state.get("final_response", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
def route_agent(state: AgentState) -> str:
|
||||
"""路由函数: 决定下一个节点"""
|
||||
if state.get("final_response"):
|
||||
return END
|
||||
return state.get("current_agent", AgentRole.MASTER).value
|
||||
|
||||
|
||||
# ===================== 构建图 =====================
|
||||
|
||||
def create_agent_graph(callbacks: list | None = None):
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(AgentRole.MASTER.value, master_node)
|
||||
graph.add_node(AgentRole.PLANNER.value, planner_node)
|
||||
graph.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
graph.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
|
||||
graph.set_entry_point(AgentRole.MASTER.value)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
AgentRole.MASTER.value,
|
||||
route_agent,
|
||||
{
|
||||
AgentRole.PLANNER.value: AgentRole.PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
END: END,
|
||||
}
|
||||
)
|
||||
|
||||
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
graph.add_edge(role.value, END)
|
||||
|
||||
return graph.compile(callbacks=callbacks)
|
||||
|
||||
|
||||
_agent_graph = None
|
||||
|
||||
|
||||
def get_agent_graph(callbacks: list | None = None):
|
||||
"""
|
||||
获取编译好的 Agent 图(单例缓存)。
|
||||
|
||||
Callbacks 在首次编译时固定注入,后续调用忽略 callbacks 参数。
|
||||
如需变更 Callbacks(如修改 LANGCHAIN_PROJECT),需重启服务。
|
||||
|
||||
Args:
|
||||
callbacks: 可选的额外 Callbacks,会与全局 LangSmith Callbacks 合并
|
||||
"""
|
||||
global _agent_graph
|
||||
if _agent_graph is None:
|
||||
from app.config_tracing import get_langsmith_callbacks
|
||||
langsmith_callbacks = get_langsmith_callbacks()
|
||||
all_callbacks = (callbacks or []) + langsmith_callbacks
|
||||
_agent_graph = create_agent_graph(callbacks=all_callbacks or None)
|
||||
return _agent_graph
|
||||
127
backend/app/agents/prompts.py
Normal file
127
backend/app/agents/prompts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Jarvis 多Agent系统的提示词定义
|
||||
"""
|
||||
|
||||
MASTER_SYSTEM_PROMPT = """你叫 Jarvis,是用户的私人AI助理。
|
||||
|
||||
你的职责是理解用户意图,并将任务分发给最合适的子Agent。
|
||||
|
||||
## 你的4个子Agent:
|
||||
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
|
||||
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
|
||||
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
|
||||
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
|
||||
|
||||
## 判断规则:
|
||||
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
|
||||
- 用户要计划、安排、拆解任务 -> 分发给 planner
|
||||
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
|
||||
- 用户要分析、统计、生成报告 -> 分发给 analyst
|
||||
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
|
||||
|
||||
## 响应格式:
|
||||
简短回复用户,告知你将调用哪个Agent处理。如果用户不需要任何子Agent,直接给出回答。
|
||||
|
||||
注意: 你是协调者,不需要亲自执行具体任务,让专业Agent去做。
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = """你是 Jarvis 的规划Agent,负责制定计划、拆解任务。
|
||||
|
||||
## 你的能力:
|
||||
- 分析复杂请求,拆解成可执行的步骤
|
||||
- 评估任务优先级
|
||||
- 估算时间安排
|
||||
- 制定执行顺序
|
||||
|
||||
## 工作流程:
|
||||
1. 理解用户的总目标
|
||||
2. 拆解成具体步骤
|
||||
3. 标注每步的优先级
|
||||
4. 给出清晰的执行计划
|
||||
|
||||
## 响应要求:
|
||||
- 用编号列表展示计划步骤
|
||||
- 每步清晰描述要做什么
|
||||
- 可以为每步指定优先级(P1/P2/P3)
|
||||
- 如果需要执行,先输出计划,然后用户确认后再执行
|
||||
"""
|
||||
|
||||
|
||||
EXECUTOR_SYSTEM_PROMPT = """你是 Jarvis 的执行Agent,负责执行具体任务。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- create_task: 创建新任务
|
||||
- update_task_status: 更新任务状态
|
||||
- get_tasks: 查看任务列表
|
||||
- create_forum_post: 在论坛发布帖子
|
||||
- get_forum_posts: 查看论坛帖子
|
||||
- scan_forum_for_instructions: 扫描论坛指令
|
||||
|
||||
## 工作流程:
|
||||
1. 理解用户要执行什么
|
||||
2. 调用相应工具
|
||||
3. 报告执行结果
|
||||
4. 询问用户是否需要下一步操作
|
||||
|
||||
## 响应要求:
|
||||
- 明确告知用户正在执行什么
|
||||
- 工具调用结果要格式化呈现
|
||||
- 如果执行成功,给出确认
|
||||
- 如果需要更多信息,明确告知用户
|
||||
"""
|
||||
|
||||
|
||||
LIBRARIAN_SYSTEM_PROMPT = """你是 Jarvis 的知识管理员,负责管理用户的私人知识库。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- search_knowledge: 搜索知识库,返回相关文档片段
|
||||
- get_knowledge_graph_context: 获取知识图谱上下文
|
||||
- build_knowledge_graph: 从文档构建知识图谱
|
||||
|
||||
## 你的职责:
|
||||
1. 理解用户关于知识的问题
|
||||
2. 搜索相关知识
|
||||
3. 综合多篇文档给出完整回答
|
||||
4. 帮助用户整理和理解知识
|
||||
|
||||
## 工作流程:
|
||||
1. 分析用户的知识查询
|
||||
2. 搜索相关文档
|
||||
3. 综合相关信息给出回答
|
||||
4. 如果有图谱关联,可以引用图谱中的关系
|
||||
|
||||
## 响应要求:
|
||||
- 回答要有文档依据
|
||||
- 引用时标注来源
|
||||
- 如果知识不足,诚实告知用户
|
||||
- 可以补充相关知识背景
|
||||
"""
|
||||
|
||||
|
||||
ANALYST_SYSTEM_PROMPT = """你是 Jarvis 的分析师,负责分析数据和工作状态。
|
||||
|
||||
## 你可以使用的工具:
|
||||
- get_tasks: 获取任务列表,统计工作进度
|
||||
- get_forum_posts: 获取论坛帖子,分析讨论趋势
|
||||
- scan_forum_for_instructions: 检查待执行指令
|
||||
- search_knowledge: 结合知识进行分析
|
||||
|
||||
## 你的职责:
|
||||
1. 统计任务完成情况
|
||||
2. 分析工作进度和趋势
|
||||
3. 生成数据报告
|
||||
4. 识别潜在问题和风险
|
||||
|
||||
## 工作流程:
|
||||
1. 收集相关数据(任务、论坛、知识)
|
||||
2. 进行数据分析
|
||||
3. 生成结构化报告
|
||||
4. 给出建议
|
||||
|
||||
## 响应要求:
|
||||
- 用数据说话,有数字有结论
|
||||
- 报告结构清晰
|
||||
- 给出可行的改进建议
|
||||
- 识别需要关注的问题
|
||||
"""
|
||||
105
backend/app/agents/state.py
Normal file
105
backend/app/agents/state.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict, Annotated
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
PLANNER = "planner"
|
||||
EXECUTOR = "executor"
|
||||
LIBRARIAN = "librarian"
|
||||
ANALYST = "analyst"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
name: str
|
||||
role: AgentRole
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
tool: str
|
||||
args: dict
|
||||
result: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
role: str # "user" | "assistant"
|
||||
content: str
|
||||
agent: AgentRole | None = None
|
||||
model: str | None = None
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> HumanMessage:
|
||||
return HumanMessage(content=turn.content)
|
||||
|
||||
|
||||
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
|
||||
msg_type = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
|
||||
return ConversationTurn(
|
||||
role="user" if msg_type in ("human", "user") else "assistant",
|
||||
content=msg.content,
|
||||
agent=agent,
|
||||
model=getattr(msg, "model", None),
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list, None]
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
|
||||
# Agent routing
|
||||
current_agent: AgentRole
|
||||
active_agents: list[AgentRole]
|
||||
|
||||
# Task tracking
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
|
||||
# Tool usage
|
||||
tool_calls: list[ToolCall]
|
||||
last_tool_result: str | None
|
||||
|
||||
# Knowledge context
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
|
||||
# Planning
|
||||
plan: str | None
|
||||
plan_steps: list[dict]
|
||||
|
||||
# Analysis
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Memory context (injected at start of each conversation)
|
||||
memory_context: str | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
return AgentState(
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
current_agent=AgentRole.MASTER,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
should_respond=True,
|
||||
memory_context=None,
|
||||
)
|
||||
22
backend/app/agents/tools/__init__.py
Normal file
22
backend/app/agents/tools/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
|
||||
ALL_TOOLS = [
|
||||
# 知识库工具
|
||||
search_knowledge,
|
||||
get_knowledge_graph_context,
|
||||
build_knowledge_graph,
|
||||
hybrid_search,
|
||||
# 任务工具
|
||||
get_tasks,
|
||||
create_task,
|
||||
update_task_status,
|
||||
# 论坛工具
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
scan_forum_for_instructions,
|
||||
]
|
||||
134
backend/app/agents/tools/forum.py
Normal file
134
backend/app/agents/tools/forum.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Agent 工具集 - 论坛相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@tool
|
||||
def get_forum_posts(category: str | None = None, limit: int = 10) -> str:
|
||||
"""
|
||||
获取论坛帖子列表。
|
||||
|
||||
Args:
|
||||
category: 可选,筛选分类 (discussion/instruction/question)
|
||||
limit: 返回数量,默认10
|
||||
|
||||
Returns:
|
||||
帖子列表
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(ForumPost)
|
||||
.join(User, User.id == ForumPost.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if category:
|
||||
query = query.where(ForumPost.category == category)
|
||||
query = query.order_by(ForumPost.created_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
posts = result.scalars().all()
|
||||
if not posts:
|
||||
return "暂无帖子"
|
||||
lines = []
|
||||
for p in posts:
|
||||
exec_mark = " [已执行]" if p.is_executed else ""
|
||||
lines.append(
|
||||
f"- [{p.id[:8]}] [{p.category}] {p.title} | "
|
||||
f"{p.content[:50]}...{exec_mark}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取帖子失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_forum_post(title: str, content: str, category: str = "discussion") -> str:
|
||||
"""
|
||||
在论坛发布新帖子。
|
||||
|
||||
Args:
|
||||
title: 帖子标题
|
||||
content: 帖子内容
|
||||
category: 分类 (discussion/instruction/question),默认discussion
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
post = ForumPost(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
content=content,
|
||||
category=category,
|
||||
)
|
||||
db.add(post)
|
||||
await db.commit()
|
||||
await db.refresh(post)
|
||||
return f"帖子发布成功: [{post.id[:8]}] {title}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"发布帖子失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def scan_forum_for_instructions() -> str:
|
||||
"""
|
||||
扫描论坛中的指令类帖子,检查是否有待执行的指令。
|
||||
|
||||
Returns:
|
||||
待执行指令的列表
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _scan():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
result = await db.execute(
|
||||
select(ForumPost)
|
||||
.join(User, User.id == ForumPost.user_id)
|
||||
.where(ForumPost.user_id == uid)
|
||||
.where(ForumPost.category == "instruction")
|
||||
.where(ForumPost.is_executed == False)
|
||||
.order_by(ForumPost.created_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
posts = result.scalars().all()
|
||||
if not posts:
|
||||
return "暂无待执行的指令"
|
||||
lines = ["待执行的指令:"]
|
||||
for p in posts:
|
||||
lines.append(f"- [{p.id[:8]}] {p.title}\n 内容: {p.content[:100]}...")
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
return _run_async(_scan())
|
||||
except Exception as e:
|
||||
return f"扫描论坛失败: {str(e)}"
|
||||
|
||||
|
||||
__all__ = ["get_forum_posts", "create_forum_post", "scan_forum_for_instructions"]
|
||||
159
backend/app/agents/tools/search.py
Normal file
159
backend/app/agents/tools/search.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Agent 工具集 - 知识库 & 图谱相关
|
||||
|
||||
这些工具在 LangChain ToolNode 中被调用。
|
||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
"""在同步上下文中运行 async 代码"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@tool
|
||||
def search_knowledge(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
搜索用户的私人知识库。根据查询返回最相关的文档片段,支持语义检索。
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
top_k: 返回结果数量,默认5条
|
||||
|
||||
Returns:
|
||||
包含相关文档片段和来源信息的格式化文本
|
||||
"""
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
uid = get_current_user()
|
||||
|
||||
async def _search():
|
||||
async with async_session() as db:
|
||||
service = KnowledgeService(db, user_id=uid)
|
||||
results = await service.retrieve(query, user_id=uid, top_k=top_k)
|
||||
if not results:
|
||||
return "未找到相关知识。知识库可能为空,或尝试用其他关键词搜索。"
|
||||
texts = []
|
||||
for i, r in enumerate(results, 1):
|
||||
prev = f"\n上一段: {r.prev_chunk[:100]}..." if r.prev_chunk else ""
|
||||
next_ = f"\n下一段: {r.next_chunk[:100]}..." if r.next_chunk else ""
|
||||
texts.append(
|
||||
f"[{i}] 来源: {r.document_title}\n"
|
||||
f"相关度: {r.score:.2f}\n"
|
||||
f"{prev}{next_}\n"
|
||||
f"内容: {r.content[:300]}{'...' if len(r.content) > 300 else ''}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except Exception as e:
|
||||
return f"知识检索失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_knowledge_graph_context(entity: str | None = None) -> str:
|
||||
"""
|
||||
获取用户知识图谱的上下文信息。
|
||||
|
||||
Args:
|
||||
entity: 可选,指定要查询的实体名称。如果为空则返回整体图谱摘要。
|
||||
|
||||
Returns:
|
||||
知识图谱节点和关系的描述
|
||||
"""
|
||||
from app.services.graph_service import GraphService
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
service = GraphService(db)
|
||||
if entity:
|
||||
return await service.get_entity_context(entity, uid)
|
||||
return await service.get_graph_summary(uid)
|
||||
|
||||
try:
|
||||
return _run_async(_get(), timeout=30)
|
||||
except Exception as e:
|
||||
return f"图谱查询失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def build_knowledge_graph(document_ids: list[str] | None = None) -> str:
|
||||
"""
|
||||
从文档构建/更新知识图谱。
|
||||
|
||||
Args:
|
||||
document_ids: 可选,指定要处理的文档ID列表。如果为空则处理所有文档。
|
||||
|
||||
Returns:
|
||||
构建结果摘要
|
||||
"""
|
||||
from app.services.graph_service import GraphService
|
||||
uid = get_current_user()
|
||||
|
||||
async def _build():
|
||||
async with async_session() as db:
|
||||
service = GraphService(db)
|
||||
await service.build_graph(user_id=uid, document_ids=document_ids)
|
||||
return "知识图谱构建完成"
|
||||
|
||||
try:
|
||||
return _run_async(_build(), timeout=120)
|
||||
except Exception as e:
|
||||
return f"图谱构建失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def hybrid_search(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
混合搜索,结合向量语义检索和关键词匹配,返回最相关结果。
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
top_k: 返回结果数量,默认5条
|
||||
|
||||
Returns:
|
||||
混合检索结果
|
||||
"""
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
uid = get_current_user()
|
||||
|
||||
async def _search():
|
||||
async with async_session() as db:
|
||||
service = KnowledgeService(db, user_id=uid)
|
||||
results = await service.hybrid_search(query, user_id=uid, top_k=top_k)
|
||||
if not results:
|
||||
return "未找到相关知识。"
|
||||
texts = []
|
||||
for i, r in enumerate(results, 1):
|
||||
texts.append(
|
||||
f"[{i}] {r.document_title} (相关度: {r.score:.2f})\n"
|
||||
f"{r.content[:200]}{'...' if len(r.content) > 200 else ''}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except Exception as e:
|
||||
return f"混合搜索失败: {str(e)}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
]
|
||||
142
backend/app/agents/tools/task.py
Normal file
142
backend/app/agents/tools/task.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
|
||||
_executor = None
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@tool
|
||||
def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
"""
|
||||
获取用户当前的任务列表。
|
||||
|
||||
Args:
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
|
||||
limit: 返回数量,默认20
|
||||
|
||||
Returns:
|
||||
任务列表
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
|
||||
"""
|
||||
创建新任务。
|
||||
|
||||
Args:
|
||||
title: 任务标题(必填)
|
||||
description: 任务描述
|
||||
priority: 优先级 1-4,数字越大优先级越高,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
due_date=due_date,
|
||||
status="todo",
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {title}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"创建任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def update_task_status(task_id: str, status: str) -> str:
|
||||
"""
|
||||
更新任务状态。
|
||||
|
||||
Args:
|
||||
task_id: 任务ID(完整ID或前8位)
|
||||
status: 新状态 (todo/in_progress/done/blocked)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = status
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {status}"
|
||||
|
||||
try:
|
||||
return _run_async(_update())
|
||||
except Exception as e:
|
||||
return f"更新任务失败: {str(e)}"
|
||||
|
||||
|
||||
__all__ = ["get_tasks", "create_task", "update_task_status"]
|
||||
Reference in New Issue
Block a user