feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -1,397 +1,354 @@
|
||||
"""
|
||||
Jarvis LangGraph Agent 主图定义
|
||||
Jarvis LangGraph Agent 主图定义 - 优化重构版
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal, Union, List, Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
ToolMessage
|
||||
)
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
|
||||
from app.agents.state import AgentState, AgentRole
|
||||
from app.agents.prompts import (
|
||||
MASTER_SYSTEM_PROMPT,
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
SCHEDULE_PLANNER_SYSTEM_PROMPT,
|
||||
EXECUTOR_SYSTEM_PROMPT,
|
||||
LIBRARIAN_SYSTEM_PROMPT,
|
||||
ANALYST_SYSTEM_PROMPT,
|
||||
PLANNER_SCOPE_PROMPT,
|
||||
PLANNER_STEPS_PROMPT,
|
||||
EXECUTOR_TASKS_PROMPT,
|
||||
EXECUTOR_FORUM_PROMPT,
|
||||
LIBRARIAN_RETRIEVAL_PROMPT,
|
||||
LIBRARIAN_GRAPH_PROMPT,
|
||||
ANALYST_PROGRESS_PROMPT,
|
||||
ANALYST_INSIGHTS_PROMPT,
|
||||
JSON_ACTION_FALLBACK_PROMPT,
|
||||
)
|
||||
from app.agents.tools import ALL_TOOLS, SUB_COMMANDER_TOOLSETS
|
||||
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
|
||||
from app.agents.skill_registry import build_skill_context
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
from app.services.llm_service import (
|
||||
get_llm,
|
||||
create_llm_from_config,
|
||||
resolve_provider_capabilities,
|
||||
default_provider_capabilities
|
||||
)
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
logger = logging.getLogger("jarvis.agent")
|
||||
|
||||
SUB_COMMANDER_PROMPTS = {
|
||||
"planner_scope": PLANNER_SCOPE_PROMPT,
|
||||
"planner_steps": PLANNER_STEPS_PROMPT,
|
||||
"executor_tasks": EXECUTOR_TASKS_PROMPT,
|
||||
"executor_forum": EXECUTOR_FORUM_PROMPT,
|
||||
"librarian_retrieval": LIBRARIAN_RETRIEVAL_PROMPT,
|
||||
"librarian_graph": LIBRARIAN_GRAPH_PROMPT,
|
||||
"analyst_progress": ANALYST_PROGRESS_PROMPT,
|
||||
"analyst_insights": ANALYST_INSIGHTS_PROMPT,
|
||||
}
|
||||
|
||||
ROLE_SUB_COMMANDERS = {
|
||||
AgentRole.PLANNER: ["planner_scope", "planner_steps"],
|
||||
AgentRole.EXECUTOR: ["executor_tasks", "executor_forum"],
|
||||
AgentRole.LIBRARIAN: ["librarian_retrieval", "librarian_graph"],
|
||||
AgentRole.ANALYST: ["analyst_progress", "analyst_insights"],
|
||||
}
|
||||
|
||||
ROLE_SKILL_CONTEXT = {
|
||||
AgentRole.PLANNER: "planner",
|
||||
AgentRole.EXECUTOR: "executor",
|
||||
AgentRole.LIBRARIAN: "librarian",
|
||||
AgentRole.ANALYST: "analyst",
|
||||
}
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
# ===================== 工具辅助函数 =====================
|
||||
|
||||
def _get_llm_for_state(state: AgentState):
|
||||
"""从 state 获取 LLM 实例,优先使用用户配置的模型"""
|
||||
"""获取配置好的 LLM 实例"""
|
||||
user_llm_config = state.get("user_llm_config")
|
||||
if user_llm_config:
|
||||
return _create_llm_from_config(user_llm_config)
|
||||
return get_llm()
|
||||
llm = create_llm_from_config(user_llm_config) if user_llm_config else get_llm()
|
||||
|
||||
# 注入解析到的能力
|
||||
capabilities = getattr(llm, "_jarvis_provider_capabilities", None)
|
||||
if capabilities is None:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config) if user_llm_config else default_provider_capabilities()
|
||||
|
||||
state["provider_capabilities"] = {
|
||||
"provider": capabilities.provider,
|
||||
"supports_native_tools": capabilities.supports_native_tools,
|
||||
"preferred_tool_strategy": capabilities.preferred_tool_strategy,
|
||||
}
|
||||
return llm, capabilities
|
||||
|
||||
|
||||
async def _ainvoke(llm, messages: list[BaseMessage]):
|
||||
ainvoke = getattr(llm, "ainvoke", None)
|
||||
if callable(ainvoke):
|
||||
return await ainvoke(messages)
|
||||
return await llm.invoke(messages)
|
||||
def _filter_user_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
return [m for m in messages if m.type in ("human", "user")]
|
||||
|
||||
|
||||
async def _ainvoke_with_tools(llm, messages: list[BaseMessage], tools=None):
|
||||
toolset = tools if tools is not None else ALL_TOOLS
|
||||
bound_llm = llm.bind_tools(toolset)
|
||||
if hasattr(bound_llm, "ainvoke"):
|
||||
return await bound_llm.ainvoke(messages)
|
||||
return await bound_llm.invoke(messages)
|
||||
|
||||
|
||||
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
|
||||
if callbacks:
|
||||
try:
|
||||
return graph.compile(callbacks=callbacks)
|
||||
except TypeError as exc:
|
||||
if "callbacks" not in str(exc):
|
||||
raise
|
||||
return graph.compile()
|
||||
|
||||
|
||||
def _msg_type(msg: BaseMessage) -> str:
|
||||
"""Get message type, handles both .type (new) and .role (old) attribute names."""
|
||||
return getattr(msg, "type", None) or getattr(msg, "role", "human")
|
||||
|
||||
|
||||
def _filter_user_messages(messages: list) -> list[BaseMessage]:
|
||||
return [m for m in messages if _msg_type(m) in ("human", "user")]
|
||||
|
||||
|
||||
def _normalize_user_text(text: str) -> str:
|
||||
return (text or "").strip().lower()
|
||||
|
||||
|
||||
def _is_simple_greeting(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你好", "您好", "早", "早上好", "在吗", "嗨", "hi", "hello"}
|
||||
|
||||
|
||||
def _is_identity_question(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你是谁", "你是誰"}
|
||||
|
||||
|
||||
def _is_capability_question(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你能做什么", "你可以做什么", "你会做什么"}
|
||||
|
||||
|
||||
def _choose_sub_commander(role: AgentRole, user_query: str) -> str:
|
||||
text = _normalize_user_text(user_query)
|
||||
|
||||
if role == AgentRole.PLANNER:
|
||||
if any(keyword in text for keyword in ["步骤", "计划", "拆解", "排期", "优先级", "路线"]):
|
||||
return "planner_steps"
|
||||
return "planner_scope"
|
||||
|
||||
def _get_role_tools(role: AgentRole) -> list:
|
||||
"""获取角色对应的所有可用工具集"""
|
||||
if role == AgentRole.SCHEDULE_PLANNER:
|
||||
# 合并分析和规划工具
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["schedule_analysis"] + SUB_COMMANDER_TOOLSETS["schedule_planning"]))
|
||||
if role == AgentRole.EXECUTOR:
|
||||
if any(keyword in text for keyword in ["论坛", "帖子", "发帖", "指令", "discussion", "instruction"]):
|
||||
return "executor_forum"
|
||||
return "executor_tasks"
|
||||
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["executor_tasks"] + SUB_COMMANDER_TOOLSETS["executor_forum"]))
|
||||
if role == AgentRole.LIBRARIAN:
|
||||
if any(keyword in text for keyword in ["图谱", "关系", "构建", "沉淀", "节点", "graph"]):
|
||||
return "librarian_graph"
|
||||
return "librarian_retrieval"
|
||||
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["librarian_retrieval"] + SUB_COMMANDER_TOOLSETS["librarian_graph"]))
|
||||
if role == AgentRole.ANALYST:
|
||||
if any(keyword in text for keyword in ["趋势", "风险", "洞察", "建议", "机会", "insight"]):
|
||||
return "analyst_insights"
|
||||
return "analyst_progress"
|
||||
|
||||
return ROLE_SUB_COMMANDERS[role][0]
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["analyst_progress"] + SUB_COMMANDER_TOOLSETS["analyst_insights"]))
|
||||
return []
|
||||
|
||||
|
||||
def _record_sub_commander(state: AgentState, sub_commander: str, user_query: str):
|
||||
state["current_sub_commander"] = sub_commander
|
||||
state["active_sub_commanders"] = state.get("active_sub_commanders", []) + [sub_commander]
|
||||
state["sub_commander_trace"] = state.get("sub_commander_trace", []) + [{
|
||||
"agent": state.get("current_agent", AgentRole.MASTER).value,
|
||||
"sub_commander": sub_commander,
|
||||
"query": user_query,
|
||||
}]
|
||||
# ===================== 核心执行逻辑 (ReAct) =====================
|
||||
|
||||
|
||||
def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRole):
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=system_prompt)]
|
||||
skill_ctx = build_skill_context(ROLE_SKILL_CONTEXT[role])
|
||||
async def call_agent_llm(state: AgentState, role: AgentRole, system_prompt: str) -> dict:
|
||||
"""通用的 LLM 调用节点逻辑"""
|
||||
llm, capabilities = _get_llm_for_state(state)
|
||||
tools = _get_role_tools(role)
|
||||
|
||||
# 构建消息序列
|
||||
messages = []
|
||||
|
||||
# 1. 系统提示词
|
||||
messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 2. 环境上下文 (时间、记忆等)
|
||||
if state.get("current_datetime_context"):
|
||||
messages.append(SystemMessage(content=f"当前时间上下文: {state['current_datetime_context']}"))
|
||||
|
||||
if state.get("memory_context"):
|
||||
messages.append(SystemMessage(content=f"长期记忆上下文: {state['memory_context']}"))
|
||||
|
||||
# 3. 技能增强
|
||||
role_skill_key = role.value.replace("agent_", "")
|
||||
skill_ctx = build_skill_context(role_skill_key)
|
||||
if skill_ctx:
|
||||
system_msgs.append(SystemMessage(content=skill_ctx))
|
||||
return system_msgs
|
||||
messages.append(SystemMessage(content=skill_ctx))
|
||||
|
||||
# 4. 历史对话 (add_messages 已经处理好了)
|
||||
messages.extend(state["messages"])
|
||||
|
||||
# 绑定工具
|
||||
if tools and capabilities.supports_native_tools:
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
else:
|
||||
llm_with_tools = llm
|
||||
if tools: # 如果有工具但不支持原生,注入 JSON Fallback 提示
|
||||
messages.append(SystemMessage(content=JSON_ACTION_FALLBACK_PROMPT))
|
||||
tool_names = [t.name for t in tools]
|
||||
messages.append(SystemMessage(content=f"本次可用工具列表: {', '.join(tool_names)}"))
|
||||
|
||||
logger.info(
|
||||
f"agent_node_started",
|
||||
extra={
|
||||
"details": {
|
||||
"role": role.value,
|
||||
"message_count": len(messages),
|
||||
"tool_count": len(tools),
|
||||
"provider": capabilities.provider
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# 执行调用
|
||||
response = await llm_with_tools.ainvoke(messages)
|
||||
|
||||
logger.info(
|
||||
f"agent_node_finished",
|
||||
extra={
|
||||
"details": {
|
||||
"role": role.value,
|
||||
"has_tool_calls": bool(getattr(response, "tool_calls", None)),
|
||||
"content_length": len(response.content) if response.content else 0
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
async def _run_sub_commander(
|
||||
state: AgentState,
|
||||
role: AgentRole,
|
||||
manager_prompt: str,
|
||||
user_query: str,
|
||||
*,
|
||||
use_tools: bool,
|
||||
summary_target: str | None = None,
|
||||
):
|
||||
llm = _get_llm_for_state(state)
|
||||
sub_commander = _choose_sub_commander(role, user_query)
|
||||
_record_sub_commander(state, sub_commander, user_query)
|
||||
|
||||
toolset = SUB_COMMANDER_TOOLSETS.get(sub_commander, [])
|
||||
system_msgs = _build_system_messages(state, manager_prompt, role)
|
||||
system_msgs.append(SystemMessage(content=f"本次应由子指挥官 `{sub_commander}` 接手。请严格按该角色职责输出。"))
|
||||
system_msgs.append(SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]))
|
||||
|
||||
if use_tools and toolset:
|
||||
response = await _ainvoke_with_tools(
|
||||
llm,
|
||||
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
|
||||
toolset,
|
||||
async def execute_tools_node(state: AgentState) -> dict:
|
||||
"""执行工具调用并返回 ToolMessage 的通用节点"""
|
||||
last_message = state["messages"][-1]
|
||||
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
tool_map = {t.name: t for t in ALL_TOOLS}
|
||||
tool_messages = []
|
||||
created_entities = []
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
tool_id = tool_call.get("id")
|
||||
|
||||
logger.info(
|
||||
f"tool_execution_started",
|
||||
extra={
|
||||
"details": {
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"tool_id": tool_id
|
||||
}
|
||||
}
|
||||
)
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in toolset:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await _ainvoke(
|
||||
llm,
|
||||
[
|
||||
SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")
|
||||
]
|
||||
|
||||
try:
|
||||
# 时间参数归一化
|
||||
normalized_args = normalize_tool_time_arguments(
|
||||
tool_name,
|
||||
tool_args,
|
||||
state.get("current_datetime_context")
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
else:
|
||||
response = await _ainvoke(
|
||||
llm,
|
||||
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
|
||||
)
|
||||
state["final_response"] = response.content
|
||||
|
||||
if summary_target:
|
||||
state[summary_target] = state.get("final_response", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
# ===================== 节点定义 (async) =====================
|
||||
|
||||
async def master_node(state: AgentState) -> AgentState:
|
||||
"""主Agent节点: 理解用户意图,决定调用哪个子Agent"""
|
||||
messages: list[BaseMessage] = state["messages"]
|
||||
user_msgs = _filter_user_messages(messages)
|
||||
user_query = user_msgs[-1].content.strip() if user_msgs else ""
|
||||
|
||||
if _is_simple_greeting(user_query):
|
||||
state["final_response"] = "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
if _is_identity_question(user_query):
|
||||
state["final_response"] = "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
if _is_capability_question(user_query):
|
||||
state["final_response"] = "主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
llm = _get_llm_for_state(state)
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
|
||||
|
||||
memory_ctx = state.get("memory_context")
|
||||
if memory_ctx:
|
||||
system_msgs.append(
|
||||
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
|
||||
|
||||
tool = tool_map.get(tool_name)
|
||||
if not tool:
|
||||
result = f"Error: Tool {tool_name} not found."
|
||||
else:
|
||||
result = await tool.ainvoke(normalized_args) if hasattr(tool, "ainvoke") else tool.invoke(normalized_args)
|
||||
|
||||
# 实体识别(用于业务追踪)
|
||||
if any(k in tool_name for k in ["create", "add", "new"]):
|
||||
created_entities.append({"tool": tool_name, "result": str(result)})
|
||||
|
||||
status = "success"
|
||||
except Exception as e:
|
||||
logger.exception(f"tool_execution_failed: {tool_name}")
|
||||
result = f"Error executing tool {tool_name}: {str(e)}"
|
||||
status = "failed"
|
||||
|
||||
tool_messages.append(ToolMessage(
|
||||
tool_call_id=tool_id,
|
||||
content=str(result),
|
||||
name=tool_name
|
||||
))
|
||||
|
||||
logger.info(
|
||||
f"tool_execution_finished",
|
||||
extra={
|
||||
"details": {
|
||||
"tool_name": tool_name,
|
||||
"status": status,
|
||||
"result_preview": str(result)[:200]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
response: AIMessage = await _ainvoke(llm, system_msgs + messages)
|
||||
return {
|
||||
"messages": tool_messages,
|
||||
"created_entities": state.get("created_entities", []) + created_entities
|
||||
}
|
||||
|
||||
|
||||
# ===================== 各角色节点定义 =====================
|
||||
|
||||
async def master_node(state: AgentState) -> dict:
|
||||
"""主控节点:负责意图识别与初步分发"""
|
||||
user_messages = _filter_user_messages(state["messages"])
|
||||
if not user_messages:
|
||||
return {"final_response": "未收到有效输入。"}
|
||||
|
||||
query = user_messages[-1].content.strip()
|
||||
|
||||
# 快捷回复逻辑 (保留原有的人性化设计)
|
||||
if re.match(r"^(你好|早|在吗|嗨|hi|hello)", query.lower()):
|
||||
return {"final_response": "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。", "messages": [AIMessage(content="您好。我在。")]}
|
||||
|
||||
llm, capabilities = _get_llm_for_state(state)
|
||||
|
||||
# 路由判断:让 LLM 决定跳转到哪个角色,或者直接回答
|
||||
# 这里我们使用一个简洁的提示词让 LLM 输出角色名称或直接回答
|
||||
system_msg = SystemMessage(content=MASTER_SYSTEM_PROMPT + "\n\n请直接输出接下来该由哪个 Agent 接手(role_name),如果直接回答,请正常输出。")
|
||||
|
||||
response = await llm.ainvoke([system_msg] + state["messages"])
|
||||
content = response.content.strip().lower()
|
||||
|
||||
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
|
||||
next_agent = AgentRole.LIBRARIAN
|
||||
elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
|
||||
next_agent = AgentRole.PLANNER
|
||||
elif any(kw in content for kw in ["执行", "做", "操作", "创建", "更新"]):
|
||||
next_agent = AgentRole.EXECUTOR
|
||||
elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
|
||||
next_agent = AgentRole.ANALYST
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
state["current_agent"] = next_agent
|
||||
state["current_sub_commander"] = None
|
||||
state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
# 简单的角色映射识别
|
||||
roles = {r.value: r for r in AgentRole}
|
||||
target_role = None
|
||||
for r_val, r_enum in roles.items():
|
||||
if r_val in content and len(content) < 50: # 如果内容很短且包含角色名,视为路由
|
||||
target_role = r_enum
|
||||
break
|
||||
|
||||
if target_role and target_role != AgentRole.MASTER:
|
||||
logger.info(f"master_routing_decided: {target_role.value}")
|
||||
return {
|
||||
"current_agent": target_role.value,
|
||||
"agent_trace": state.get("agent_trace", []) + [target_role.value],
|
||||
"messages": [AIMessage(content=f"已分发至 {target_role.value} 处理。")]
|
||||
}
|
||||
|
||||
return {"final_response": response.content, "messages": [response]}
|
||||
|
||||
|
||||
async def planner_node(state: AgentState) -> AgentState:
|
||||
"""规划Agent节点: 制定计划,拆解任务步骤"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.PLANNER, PLANNER_SYSTEM_PROMPT, user_query, use_tools=False)
|
||||
async def planner_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.SCHEDULE_PLANNER, SCHEDULE_PLANNER_SYSTEM_PROMPT)
|
||||
|
||||
async def executor_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT)
|
||||
|
||||
async def librarian_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT)
|
||||
|
||||
async def analyst_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT)
|
||||
|
||||
|
||||
async def executor_node(state: AgentState) -> AgentState:
|
||||
"""执行Agent节点: 调用工具执行具体任务"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT, user_query, use_tools=True)
|
||||
# ===================== 路由逻辑 =====================
|
||||
|
||||
def route_after_agent(state: AgentState) -> Literal["tools", "__end__"]:
|
||||
"""判断 Agent 执行后是该走工具节点还是结束"""
|
||||
last_message = state["messages"][-1]
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
return "tools"
|
||||
return END
|
||||
|
||||
async def librarian_node(state: AgentState) -> AgentState:
|
||||
"""知识管理员节点: 管理知识库和知识图谱"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="knowledge_context")
|
||||
|
||||
|
||||
async def analyst_node(state: AgentState) -> AgentState:
|
||||
"""分析师节点: 分析工作数据,生成报告"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="analysis_report")
|
||||
|
||||
|
||||
def route_agent(state: AgentState) -> str:
|
||||
"""路由函数: 决定下一个节点"""
|
||||
def route_master(state: AgentState) -> str:
|
||||
"""主控路由逻辑"""
|
||||
if state.get("final_response"):
|
||||
return END
|
||||
return state.get("current_agent", AgentRole.MASTER).value
|
||||
return state.get("current_agent", END)
|
||||
|
||||
|
||||
# ===================== 构建图 =====================
|
||||
# ===================== 图构建 =====================
|
||||
|
||||
def create_agent_graph(callbacks: list | None = None):
|
||||
graph = StateGraph(AgentState)
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(AgentRole.MASTER.value, master_node)
|
||||
graph.add_node(AgentRole.PLANNER.value, planner_node)
|
||||
graph.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
graph.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
# 添加节点
|
||||
workflow.add_node(AgentRole.MASTER.value, master_node)
|
||||
workflow.add_node(AgentRole.SCHEDULE_PLANNER.value, planner_node)
|
||||
workflow.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
workflow.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
workflow.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
workflow.add_node("tools", execute_tools_node)
|
||||
|
||||
graph.set_entry_point(AgentRole.MASTER.value)
|
||||
# 设置入口
|
||||
workflow.set_entry_point(AgentRole.MASTER.value)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
# 主控分发逻辑
|
||||
workflow.add_conditional_edges(
|
||||
AgentRole.MASTER.value,
|
||||
route_agent,
|
||||
route_master,
|
||||
{
|
||||
AgentRole.PLANNER.value: AgentRole.PLANNER.value,
|
||||
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
END: END,
|
||||
END: END
|
||||
}
|
||||
)
|
||||
|
||||
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
graph.add_edge(role.value, END)
|
||||
# 各角色节点的 ReAct 循环
|
||||
for role in [AgentRole.SCHEDULE_PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
workflow.add_conditional_edges(
|
||||
role.value,
|
||||
route_after_agent,
|
||||
{
|
||||
"tools": "tools",
|
||||
END: END
|
||||
}
|
||||
)
|
||||
|
||||
# 工具执行完后回到当前 Agent 角色继续处理
|
||||
workflow.add_conditional_edges(
|
||||
"tools",
|
||||
lambda s: s.get("current_agent", AgentRole.MASTER.value),
|
||||
{
|
||||
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
}
|
||||
)
|
||||
|
||||
return _compile_graph(graph, callbacks=callbacks)
|
||||
# 编译
|
||||
if callbacks:
|
||||
return workflow.compile(callbacks=callbacks)
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
_agent_graph = None
|
||||
|
||||
|
||||
def get_agent_graph(callbacks: list | None = None):
|
||||
"""
|
||||
获取编译好的 Agent 图(单例缓存)。
|
||||
|
||||
Callbacks 在首次编译时固定注入,后续调用忽略 callbacks 参数。
|
||||
如需变更 Callbacks(如修改 LANGCHAIN_PROJECT),需重启服务。
|
||||
|
||||
Args:
|
||||
callbacks: 可选的额外 Callbacks,会与全局 LangSmith Callbacks 合并
|
||||
"""
|
||||
global _agent_graph
|
||||
if _agent_graph is None:
|
||||
from app.config_tracing import get_langsmith_callbacks
|
||||
|
||||
@@ -89,14 +89,14 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是总控协调者,负责理解用户意图,并将任务分发给最合适的子Agent。
|
||||
|
||||
## 你的4个子Agent:
|
||||
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
|
||||
1. **schedule_planner (日程规划师)**: 分析当前任务、对话历史与论坛信号,给出近期安排建议
|
||||
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
|
||||
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
|
||||
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
|
||||
|
||||
## 判断规则:
|
||||
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
|
||||
- 用户要计划、安排、拆解任务 -> 分发给 planner
|
||||
- 用户要安排今天/本周重点、询问接下来该做什么 -> 分发给 schedule_planner
|
||||
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
|
||||
- 用户要分析、统计、生成报告 -> 分发给 analyst
|
||||
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
|
||||
@@ -112,18 +112,19 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的规划Agent,负责先判断问题该由哪位规划子指挥官接手。
|
||||
你是 Jarvis 的日程规划师,负责先判断问题该由哪位日程子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **planner_scope (目标收束官)**: 负责澄清目标、边界、约束、缺失信息
|
||||
2. **planner_steps (步骤拆解官)**: 负责把目标拆成步骤、优先级与依赖关系
|
||||
1. **schedule_analysis (日程分析员)**: 负责分析对话历史、任务看板、论坛信号,识别优先级、冲突与压力点
|
||||
2. **schedule_planning (日程编排员)**: 负责把分析结果转成今日/近期日程安排,并在用户明确要求时直接创建 reminder/task/todo/goal
|
||||
|
||||
## 你的职责:
|
||||
- 判断当前请求更适合收束目标,还是拆解步骤
|
||||
- 在必要时收束子指挥官输出,面向用户给出清晰结果
|
||||
- 保持结果可推进,不空泛
|
||||
- 判断当前请求更适合先做日程分析,还是直接给出日程编排
|
||||
- 输出先结论,再给可执行安排
|
||||
- 保持建议具体、贴近当前上下文,不给空泛效率学建议
|
||||
- 当用户明确要求“新增/提醒/创建/安排并落库”时,允许子指挥官调用 schedule 工具直接执行
|
||||
"""
|
||||
|
||||
|
||||
@@ -132,11 +133,11 @@ EXECUTOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是 Jarvis 的执行Agent,负责先判断问题该由哪位执行子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **executor_tasks (任务执行官)**: 只处理任务类工具调用
|
||||
1. **executor_tasks (任务执行官)**: 处理任务、待办、提醒、目标等执行型写入操作
|
||||
2. **executor_forum (论坛执行官)**: 只处理论坛/指令帖相关工具调用
|
||||
|
||||
## 你的职责:
|
||||
- 识别用户要推进的是任务操作还是论坛/指令操作
|
||||
- 识别用户要推进的是任务/日程操作还是论坛/指令操作
|
||||
- 把请求交给最合适的执行子指挥官
|
||||
- 汇总执行结果并给出下一步
|
||||
"""
|
||||
@@ -172,52 +173,68 @@ ANALYST_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SCOPE_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_ANALYSIS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的目标收束官,负责先把问题边界、目标、约束和成功标准说清楚。
|
||||
你是 schedule_planner 体系下的日程分析员,负责从对话历史、任务看板、论坛信号和当日日程数据中提取 scheduling 线索。
|
||||
|
||||
## 你的重点:
|
||||
- 收束问题定义
|
||||
- 明确目标与限制条件
|
||||
- 识别缺失信息
|
||||
- 帮用户建立可以继续规划的前提
|
||||
- 优先调用读取类工具了解当天/指定日期的任务、提醒、待办、目标
|
||||
- 识别当前最高优先级事项
|
||||
- 找出风险、冲突、依赖与可延期事项
|
||||
- 明确哪些信号来自 conversation、task board、schedule center、forum
|
||||
|
||||
## 响应要求:
|
||||
- 先给出你理解的目标
|
||||
- 再列出关键约束或缺口
|
||||
- 不要直接展开长步骤清单
|
||||
- 先给当前判断
|
||||
- 再列优先级、风险与冲突
|
||||
- 不直接展开长篇日程表
|
||||
- 只做分析,不创建任何记录
|
||||
- 如果涉及“今天/明天/后天/下周一下午”这类自然语言时间窗口,先调用 `resolve_time_expression` 把查询目标转换成明确日期
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_STEPS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNING_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的步骤拆解官,负责把目标转成有顺序的执行路径。
|
||||
你是 schedule_planner 体系下的日程编排员,负责把当前重点转成近期可执行安排。
|
||||
|
||||
## 你的重点:
|
||||
- 拆解步骤
|
||||
- 标注优先级与依赖
|
||||
- 输出清晰的行动顺序
|
||||
- 先给结论
|
||||
- 再给今天/近期的时间安排建议
|
||||
- 最后给按顺序执行的 next actions
|
||||
- 当用户明确要求新增/提醒/创建/安排并真正落库时,调用 schedule 工具创建对应 reminder/task/todo/goal
|
||||
- 当用户给出“日期 + 事项/节点/交付/会议”等记录型表达时,也应视为落库意图,直接创建相应记录,不要反问
|
||||
- 解析“今天/明天/后天/本周/下周”或“3月29日”这类日期时,必须以系统提供的当前时间为准,并把工具参数转换成明确的 ISO 日期/时间字符串
|
||||
- 只要用户输入里包含自然语言时间,优先调用 `resolve_time_expression`,先拿到明确日期/时间,再调用 `create_reminder`、`create_schedule_task`、`create_goal`、`create_todo`
|
||||
|
||||
## 响应要求:
|
||||
- 用编号列表
|
||||
- 每步具体,不要空泛
|
||||
- 必要时标注先后关系
|
||||
- 用清晰列表表达
|
||||
- 建议必须具体、可执行、贴近当前工作
|
||||
- 避免空泛的自我管理建议
|
||||
- 如果只是规划,不要创建任何记录
|
||||
- 如果已创建记录,要明确说明创建了什么、时间如何解析
|
||||
"""
|
||||
|
||||
|
||||
EXECUTOR_TASKS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 executor 体系下的任务执行官,只负责任务相关工具调用。
|
||||
你是 executor 体系下的任务执行官,负责处理任务、待办、提醒、目标等执行型工具调用。
|
||||
|
||||
## 允许使用的工具:
|
||||
- get_tasks
|
||||
- create_task
|
||||
- update_task_status
|
||||
- create_todo
|
||||
- create_schedule_task
|
||||
- create_reminder
|
||||
- create_goal
|
||||
- resolve_time_expression
|
||||
|
||||
## 要求:
|
||||
- 只处理任务类操作
|
||||
- 只处理任务/日程类操作
|
||||
- 遇到自然语言时间表达时,先调用 `resolve_time_expression`,再把解析后的明确日期/时间传给写入工具
|
||||
- 最终说明执行结果时,优先复用已经解析出的绝对时间,不要只重复“今天/明天”
|
||||
- 明确已执行动作、结果与下一步
|
||||
- 信息不足时直接指出缺口
|
||||
- 如果用户只是要分析建议,不要创建记录
|
||||
"""
|
||||
|
||||
|
||||
@@ -244,10 +261,14 @@ LIBRARIAN_RETRIEVAL_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
## 允许使用的工具:
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
- get_knowledge_graph_context
|
||||
|
||||
## 要求:
|
||||
- 优先检索与综合证据
|
||||
- 私有/项目知识优先使用 `search_knowledge` 或 `hybrid_search`
|
||||
- 当用户明确要求联网、查询外部资料或查询最新信息时,使用 `web_search`
|
||||
- 回答时区分内部知识与外部网页结果
|
||||
- 证据不足时明确说明边界
|
||||
- 以回答问题为主,不主动做图谱构建
|
||||
"""
|
||||
@@ -293,9 +314,31 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
- get_forum_posts
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
|
||||
## 要求:
|
||||
- 先给结论与判断
|
||||
- 再说明依据与建议
|
||||
- 当需要外部/最新信息时,可使用 `web_search`
|
||||
- 重点输出趋势、风险、机会点
|
||||
"""
|
||||
|
||||
|
||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||
|
||||
你的输出必须满足以下规则:
|
||||
1. 只能输出一个 JSON 对象,不要输出 markdown、解释、前后缀文字。
|
||||
2. JSON 对象字段仅允许:
|
||||
- `mode`: `final` | `tool_call` | `clarification`
|
||||
- `tool_calls`: 数组;每项包含 `name`、`arguments`,可选 `reason`
|
||||
- `final_response`: 当无需工具时填写
|
||||
- `clarification_question`: 当信息不足时填写
|
||||
3. 如果需要调用工具,返回:
|
||||
- `{ "mode": "tool_call", "tool_calls": [...] }`
|
||||
4. 如果无需工具,直接返回:
|
||||
- `{ "mode": "final", "final_response": "..." }`
|
||||
5. 如果信息不足,不要猜测参数,返回:
|
||||
- `{ "mode": "clarification", "clarification_question": "..." }`
|
||||
6. 只能使用系统消息里明确列出的工具名。
|
||||
7. `arguments` 必须是 JSON 对象。
|
||||
"""
|
||||
|
||||
@@ -1,32 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict, Annotated
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Annotated, Sequence
|
||||
from enum import Enum
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
PLANNER = "planner"
|
||||
SCHEDULE_PLANNER = "schedule_planner"
|
||||
EXECUTOR = "executor"
|
||||
LIBRARIAN = "librarian"
|
||||
ANALYST = "analyst"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
name: str
|
||||
role: AgentRole
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
tool: str
|
||||
args: dict
|
||||
result: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
role: str # "user" | "assistant"
|
||||
@@ -35,60 +22,41 @@ class ConversationTurn:
|
||||
model: str | None = None
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> HumanMessage:
|
||||
return HumanMessage(content=turn.content)
|
||||
|
||||
|
||||
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
|
||||
msg_type = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
|
||||
return ConversationTurn(
|
||||
role="user" if msg_type in ("human", "user") else "assistant",
|
||||
content=msg.content,
|
||||
agent=agent,
|
||||
model=getattr(msg, "model", None),
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list, None]
|
||||
# Core message history with add_messages reducer
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
# Session identifiers
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
|
||||
# Agent routing
|
||||
current_agent: AgentRole
|
||||
active_agents: list[AgentRole]
|
||||
current_sub_commander: str | None
|
||||
active_sub_commanders: list[str]
|
||||
sub_commander_trace: list[dict]
|
||||
|
||||
# Task tracking
|
||||
# Agent routing state
|
||||
current_agent: str | None
|
||||
next_step: str | None # For explicit graph routing
|
||||
|
||||
# Traceability
|
||||
agent_trace: list[str]
|
||||
|
||||
# Task & Entity Tracking (Business Logic)
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
created_entities: list[dict]
|
||||
|
||||
# Tool usage
|
||||
tool_calls: list[ToolCall]
|
||||
last_tool_result: str | None
|
||||
|
||||
# Knowledge context
|
||||
# Context summaries (for long-term or cross-agent context)
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
|
||||
# Planning
|
||||
plan: str | None
|
||||
plan_steps: list[dict]
|
||||
|
||||
# Analysis
|
||||
schedule_context_summary: str | None
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Memory context (injected at start of each conversation)
|
||||
|
||||
# Memory & Environment
|
||||
memory_context: str | None
|
||||
|
||||
# User LLM config (for using user-configured models)
|
||||
current_datetime_context: str | None
|
||||
|
||||
# Configuration
|
||||
user_llm_config: dict | None
|
||||
provider_capabilities: dict | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
@@ -96,22 +64,18 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
current_agent=AgentRole.MASTER,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
current_sub_commander=None,
|
||||
active_sub_commanders=[],
|
||||
sub_commander_trace=[],
|
||||
current_agent=AgentRole.MASTER.value,
|
||||
next_step=None,
|
||||
agent_trace=[AgentRole.MASTER.value],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
created_entities=[],
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
schedule_context_summary=None,
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
should_respond=True,
|
||||
memory_context=None,
|
||||
current_datetime_context=None,
|
||||
user_llm_config=None,
|
||||
provider_capabilities=None,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
from app.agents.tools.schedule import (
|
||||
get_schedule_day,
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
@@ -11,6 +19,19 @@ TASK_TOOLS = [
|
||||
update_task_status,
|
||||
]
|
||||
|
||||
SCHEDULE_READ_TOOLS = [
|
||||
get_schedule_day,
|
||||
get_tasks,
|
||||
resolve_time_expression,
|
||||
]
|
||||
|
||||
SCHEDULE_WRITE_TOOLS = [
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
]
|
||||
|
||||
FORUM_TOOLS = [
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
@@ -20,6 +41,7 @@ FORUM_TOOLS = [
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS = [
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
get_knowledge_graph_context,
|
||||
]
|
||||
|
||||
@@ -39,19 +61,22 @@ ANALYST_INSIGHT_TOOLS = [
|
||||
get_forum_posts,
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
]
|
||||
|
||||
ALL_TOOLS = [
|
||||
*KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
build_knowledge_graph,
|
||||
*TASK_TOOLS,
|
||||
*SCHEDULE_READ_TOOLS,
|
||||
*SCHEDULE_WRITE_TOOLS,
|
||||
*FORUM_TOOLS,
|
||||
]
|
||||
|
||||
SUB_COMMANDER_TOOLSETS = {
|
||||
"planner_scope": [],
|
||||
"planner_steps": [],
|
||||
"executor_tasks": TASK_TOOLS,
|
||||
"schedule_analysis": SCHEDULE_READ_TOOLS,
|
||||
"schedule_planning": [*SCHEDULE_READ_TOOLS, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_tasks": [*TASK_TOOLS, resolve_time_expression, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_forum": FORUM_TOOLS,
|
||||
"librarian_retrieval": KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
"librarian_graph": KNOWLEDGE_GRAPH_TOOLS,
|
||||
|
||||
@@ -6,15 +6,17 @@ from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
308
backend/app/agents/tools/schedule.py
Normal file
308
backend/app/agents/tools/schedule.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Agent 工具集 - 日程相关"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import date, datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date:
|
||||
if not value:
|
||||
return date.today()
|
||||
return date.fromisoformat(value)
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def _parse_datetime_with_timezone(value: str, time_zone: str | None) -> datetime:
|
||||
"""Parse an ISO datetime and return a tz-naive datetime in the intended local time.
|
||||
|
||||
- If value includes an offset/Z, it will be converted to `time_zone` when provided.
|
||||
- If value is naive and `time_zone` is provided, it is interpreted in that zone.
|
||||
"""
|
||||
parsed = _parse_datetime(value)
|
||||
tz = (time_zone or "").strip()
|
||||
if parsed.tzinfo is None:
|
||||
if tz:
|
||||
parsed = parsed.replace(tzinfo=ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
if tz:
|
||||
parsed = parsed.astimezone(ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_schedule_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
if not resolved:
|
||||
return None
|
||||
if "T" in resolved:
|
||||
return resolved
|
||||
return f"{resolved}T09:00:00"
|
||||
|
||||
|
||||
def _format_summary(target_date: date, todos: list[DailyTodo], tasks: list[Task], reminders: list[Reminder], goals: list[Goal]) -> str:
|
||||
lines = [f"日期: {target_date.isoformat()}"]
|
||||
|
||||
if todos:
|
||||
lines.append("待办:")
|
||||
lines.extend(f"- {item.title} | 完成:{'是' if item.is_completed else '否'}" for item in todos)
|
||||
else:
|
||||
lines.append("待办: 无")
|
||||
|
||||
if tasks:
|
||||
lines.append("任务:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status} | 优先级:{item.priority.value if hasattr(item.priority, 'value') else item.priority} | 截止:{item.due_date.isoformat() if item.due_date else '无'}"
|
||||
for item in tasks
|
||||
)
|
||||
else:
|
||||
lines.append("任务: 无")
|
||||
|
||||
if reminders:
|
||||
lines.append("提醒:")
|
||||
lines.extend(f"- {item.title} | 时间:{item.reminder_at.isoformat()}" for item in reminders)
|
||||
else:
|
||||
lines.append("提醒: 无")
|
||||
|
||||
if goals:
|
||||
lines.append("目标:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status}"
|
||||
for item in goals
|
||||
)
|
||||
else:
|
||||
lines.append("目标: 无")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
def get_schedule_day(target_date: str | None = None) -> str:
|
||||
"""获取指定日期的 todo/task/reminder/goal 聚合信息。target_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(target_date)
|
||||
date_key = parsed_date.isoformat()
|
||||
start_dt = datetime.combine(parsed_date, datetime.min.time())
|
||||
end_dt = datetime.combine(parsed_date, datetime.max.time())
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
todos = (
|
||||
await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == uid, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
tasks = (
|
||||
await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == uid,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
reminders = (
|
||||
await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == uid,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
goals = (
|
||||
await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == uid, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
return _format_summary(parsed_date, todos, tasks, reminders, goals)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as exc:
|
||||
return f"获取日程失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_todo(title: str, todo_date: str | None = None) -> str:
|
||||
"""创建指定日期的待办。todo_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(todo_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
todo = DailyTodo(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
source=TodoSource.AI_CHAT,
|
||||
todo_date=parsed_date.isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
await db.refresh(todo)
|
||||
return f"TODO创建成功: [{todo.id[:8]}] {todo.title} @ {todo.todo_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建TODO失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_schedule_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: str = "medium",
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""创建任务。priority 支持 low/medium/high/urgent;due_date 使用 ISO datetime。兼容 content/date 别名。"""
|
||||
uid = get_current_user()
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_schedule_due_date(due_date, date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=TaskPriority(priority),
|
||||
due_date=_parse_datetime(resolved_due_date) if resolved_due_date else None,
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
due_label = task.due_date.isoformat() if task.due_date else "无截止时间"
|
||||
return f"任务创建成功: [{task.id[:8]}] {task.title} | 优先级:{task.priority.value} | 截止:{due_label}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_reminder(
|
||||
title: str = "",
|
||||
reminder_at: str | None = None,
|
||||
note: str = "",
|
||||
description: str = "",
|
||||
datetime: str = "",
|
||||
at: str = "",
|
||||
remind_at: str = "",
|
||||
content: str = "",
|
||||
time_zone: str = "",
|
||||
timezone: str = "",
|
||||
time: str = "",
|
||||
) -> str:
|
||||
"""创建提醒。reminder_at 使用 ISO datetime。兼容 description/datetime/at/remind_at/time_zone 别名。"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_title = (title or content or "").strip()
|
||||
if not resolved_title:
|
||||
raise ValueError("title 不能为空")
|
||||
|
||||
resolved_at = ((reminder_at or datetime or at or remind_at or time or "").strip())
|
||||
if not resolved_at:
|
||||
raise ValueError("reminder_at 不能为空")
|
||||
|
||||
resolved_note = (note or description or "").strip()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
tz = (time_zone or timezone or "").strip()
|
||||
reminder = Reminder(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
note=resolved_note or None,
|
||||
reminder_at=_parse_datetime_with_timezone(resolved_at, tz),
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return f"提醒创建成功: [{reminder.id[:8]}] {reminder.title} @ {reminder.reminder_at.isoformat()}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建提醒失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_goal(title: str, goal_date: str | None = None, note: str = "", status: str = "active") -> str:
|
||||
"""创建指定日期目标。goal_date 格式 YYYY-MM-DD,默认今天;status 支持 active/done/archived。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(goal_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
goal = Goal(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
note=note or None,
|
||||
goal_date=parsed_date.isoformat(),
|
||||
status=GoalStatus(status),
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return f"目标创建成功: [{goal.id[:8]}] {goal.title} @ {goal.goal_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建目标失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_schedule_day",
|
||||
"create_todo",
|
||||
"create_schedule_task",
|
||||
"create_reminder",
|
||||
"create_goal",
|
||||
]
|
||||
@@ -5,12 +5,14 @@ Agent 工具集 - 知识库 & 图谱相关
|
||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@@ -151,9 +153,56 @@ def hybrid_search(query: str, top_k: int = 5) -> str:
|
||||
return f"混合搜索失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
通过 SearxNG 搜索外部网页信息,返回标题、链接和摘要。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
top_k: 返回结果数量,默认 5 条
|
||||
|
||||
Returns:
|
||||
适合模型综合的网页结果文本
|
||||
"""
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
async def _search():
|
||||
service = WebSearchService()
|
||||
results = await service.search(query, limit=top_k)
|
||||
if not results:
|
||||
return "未找到相关网页结果。"
|
||||
|
||||
texts = []
|
||||
for index, result in enumerate(results, 1):
|
||||
source = f"\n来源: {result.source}" if result.source else ""
|
||||
published_at = f"\n时间: {result.published_at}" if result.published_at else ""
|
||||
snippet = result.snippet or "(无摘要)"
|
||||
texts.append(
|
||||
f"[{index}] {result.title}\n"
|
||||
f"链接: {result.url}{source}{published_at}\n"
|
||||
f"摘要: {snippet}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except WebSearchConfigurationError as exc:
|
||||
return f"网页搜索不可用: {exc}"
|
||||
except WebSearchRequestError as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
except Exception as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
"web_search",
|
||||
]
|
||||
|
||||
@@ -1,22 +1,85 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
|
||||
_executor = None
|
||||
from app.models.base import utc_now
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
return resolved or None
|
||||
|
||||
|
||||
def _parse_due_date(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" not in normalized:
|
||||
normalized = f"{normalized}T00:00:00"
|
||||
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is not None:
|
||||
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_priority(priority: int | str | None) -> TaskPriority:
|
||||
if priority is None or priority == "":
|
||||
return TaskPriority.MEDIUM
|
||||
if isinstance(priority, TaskPriority):
|
||||
return priority
|
||||
if isinstance(priority, int):
|
||||
return {
|
||||
1: TaskPriority.LOW,
|
||||
2: TaskPriority.MEDIUM,
|
||||
3: TaskPriority.HIGH,
|
||||
4: TaskPriority.URGENT,
|
||||
}.get(priority, TaskPriority.MEDIUM)
|
||||
normalized = str(priority).strip().lower()
|
||||
if not normalized:
|
||||
return TaskPriority.MEDIUM
|
||||
return TaskPriority(normalized)
|
||||
|
||||
|
||||
def _normalize_status(status: str) -> TaskStatus:
|
||||
normalized = status.strip().lower()
|
||||
return TaskStatus(normalized)
|
||||
|
||||
|
||||
def _format_status(value: TaskStatus | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
def _format_priority(value: TaskPriority | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -25,7 +88,7 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
获取用户当前的任务列表。
|
||||
|
||||
Args:
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
|
||||
limit: 返回数量,默认20
|
||||
|
||||
Returns:
|
||||
@@ -33,67 +96,82 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status) if status else None
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if resolved_status:
|
||||
query = query.where(Task.status == resolved_status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
|
||||
def create_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: int | str = 2,
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建新任务。
|
||||
|
||||
Args:
|
||||
title: 任务标题(必填)
|
||||
title: 任务标题(必填,兼容 content 作为别名)
|
||||
description: 任务描述
|
||||
priority: 优先级 1-4,数字越大优先级越高,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD
|
||||
priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
|
||||
content: title 的兼容别名
|
||||
date: due_date 的兼容别名
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
due_date=due_date,
|
||||
status="todo",
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {title}"
|
||||
|
||||
try:
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_due_date(due_date, date)
|
||||
resolved_priority = _normalize_priority(priority)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=resolved_priority,
|
||||
due_date=_parse_due_date(resolved_due_date),
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {resolved_title}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"创建任务失败: {str(e)}"
|
||||
@@ -106,34 +184,37 @@ def update_task_status(task_id: str, status: str) -> str:
|
||||
|
||||
Args:
|
||||
task_id: 任务ID(完整ID或前8位)
|
||||
status: 新状态 (todo/in_progress/done/blocked)
|
||||
status: 新状态 (todo/in_progress/done/cancelled)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = status
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {status}"
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status)
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = resolved_status
|
||||
task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {resolved_status.value}"
|
||||
|
||||
return _run_async(_update())
|
||||
except Exception as e:
|
||||
return f"更新任务失败: {str(e)}"
|
||||
|
||||
269
backend/app/agents/tools/time_reasoning.py
Normal file
269
backend/app/agents/tools/time_reasoning.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
_WEEKDAY_MAP = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6, "天": 6}
|
||||
_DEFAULT_HOUR_BY_PERIOD = {
|
||||
"morning": 9,
|
||||
"noon": 12,
|
||||
"afternoon": 15,
|
||||
"evening": 20,
|
||||
}
|
||||
_TIME_KEYWORDS = ("今天", "明天", "后天", "本周", "这周", "下周", "周", "星期", "月", "日", "早上", "上午", "中午", "下午", "晚上", "今晚", "点", ":", ":")
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def extract_reference_datetime(current_datetime_context: str | None) -> datetime:
|
||||
context = (current_datetime_context or "").strip()
|
||||
if context:
|
||||
for pattern in (r"current_time_utc:\s*(\S+)", r"CURRENT_TIME:\s*(\S+)", r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2}))"):
|
||||
match = re.search(pattern, context)
|
||||
if match:
|
||||
return _parse_datetime(match.group(1))
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _normalize_local_iso(value: datetime) -> str:
|
||||
return value.replace(tzinfo=None).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _normalize_datetime_iso(value: datetime) -> str:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat(timespec="seconds")
|
||||
return _normalize_local_iso(value)
|
||||
|
||||
|
||||
def _normalize_date_iso(value: date) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _is_iso_datetime(value: str) -> bool:
|
||||
try:
|
||||
parsed = _parse_datetime(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return isinstance(parsed, datetime)
|
||||
|
||||
|
||||
def _is_iso_date(value: str) -> bool:
|
||||
try:
|
||||
date.fromisoformat(value.strip())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _has_explicit_time(text: str) -> bool:
|
||||
return bool(
|
||||
re.search(r"\d{1,2}[::]\d{2}", text)
|
||||
or re.search(r"\d{1,2}点(?:半|(?:\d{1,2})分?)?", text)
|
||||
or any(keyword in text for keyword in ("早上", "上午", "中午", "下午", "晚上", "今晚"))
|
||||
)
|
||||
|
||||
|
||||
def _detect_period(text: str) -> str | None:
|
||||
if any(keyword in text for keyword in ("晚上", "今晚")):
|
||||
return "evening"
|
||||
if "下午" in text:
|
||||
return "afternoon"
|
||||
if "中午" in text:
|
||||
return "noon"
|
||||
if any(keyword in text for keyword in ("早上", "上午", "早晨", "清晨")):
|
||||
return "morning"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_time(text: str) -> tuple[time, bool, str | None]:
|
||||
period = _detect_period(text)
|
||||
colon_match = re.search(r"(\d{1,2})[::](\d{2})", text)
|
||||
if colon_match:
|
||||
hour = int(colon_match.group(1))
|
||||
minute = int(colon_match.group(2))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
half_match = re.search(r"(\d{1,2})点半", text)
|
||||
if half_match:
|
||||
hour = int(half_match.group(1))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=30), False, period
|
||||
|
||||
dot_match = re.search(r"(\d{1,2})点(?:(\d{1,2})分?)?", text)
|
||||
if dot_match:
|
||||
hour = int(dot_match.group(1))
|
||||
minute = int(dot_match.group(2) or 0)
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
if period == "noon" and hour < 11:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
if period:
|
||||
return time(hour=_DEFAULT_HOUR_BY_PERIOD[period], minute=0), True, period
|
||||
return time(hour=9, minute=0), True, None
|
||||
|
||||
|
||||
def _resolve_date(text: str, reference: datetime) -> tuple[date, str]:
|
||||
stripped = text.strip()
|
||||
if _is_iso_date(stripped):
|
||||
return date.fromisoformat(stripped), "explicit_date"
|
||||
|
||||
month_day_match = re.search(r"(\d{1,2})月(\d{1,2})日", stripped)
|
||||
if month_day_match:
|
||||
month = int(month_day_match.group(1))
|
||||
day = int(month_day_match.group(2))
|
||||
candidate = date(reference.year, month, day)
|
||||
if candidate < reference.date() - timedelta(days=1):
|
||||
candidate = date(reference.year + 1, month, day)
|
||||
return candidate, "explicit_month_day"
|
||||
|
||||
if "后天" in stripped:
|
||||
return reference.date() + timedelta(days=2), "relative_day"
|
||||
if "明天" in stripped:
|
||||
return reference.date() + timedelta(days=1), "relative_day"
|
||||
if "今天" in stripped:
|
||||
return reference.date(), "relative_day"
|
||||
|
||||
weekday_match = re.search(r"((?:本周|这周|下周)?)(?:周|星期)([一二三四五六日天])", stripped)
|
||||
if weekday_match:
|
||||
prefix = weekday_match.group(1)
|
||||
weekday = _WEEKDAY_MAP[weekday_match.group(2)]
|
||||
current_weekday = reference.date().weekday()
|
||||
delta = weekday - current_weekday
|
||||
if prefix == "下周":
|
||||
delta += 7 if delta <= 0 else 7
|
||||
elif prefix in {"本周", "这周"}:
|
||||
if delta < 0:
|
||||
delta += 7
|
||||
elif delta < 0:
|
||||
delta += 7
|
||||
return reference.date() + timedelta(days=delta), "relative_weekday"
|
||||
|
||||
return reference.date(), "reference_day"
|
||||
|
||||
|
||||
def resolve_time_expression_data(
|
||||
expression: str,
|
||||
*,
|
||||
current_datetime_context: str | None = None,
|
||||
prefer: str = "datetime",
|
||||
) -> dict:
|
||||
text = (expression or "").strip()
|
||||
if not text:
|
||||
raise ValueError("expression 不能为空")
|
||||
|
||||
reference = extract_reference_datetime(current_datetime_context)
|
||||
|
||||
if _is_iso_datetime(text):
|
||||
parsed = _parse_datetime(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "datetime",
|
||||
"resolved_date": _normalize_date_iso(parsed.date()),
|
||||
"resolved_datetime": _normalize_datetime_iso(parsed),
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_datetime",
|
||||
}
|
||||
|
||||
if _is_iso_date(text):
|
||||
parsed_date = date.fromisoformat(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "date",
|
||||
"resolved_date": _normalize_date_iso(parsed_date),
|
||||
"resolved_datetime": None,
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_date",
|
||||
}
|
||||
|
||||
resolved_date, date_reason = _resolve_date(text, reference)
|
||||
resolved_time, assumed_time, period = _resolve_time(text)
|
||||
has_explicit_time = _has_explicit_time(text)
|
||||
grain = "date" if prefer == "date" and not has_explicit_time else "datetime"
|
||||
resolved_dt = datetime.combine(resolved_date, resolved_time)
|
||||
note = date_reason
|
||||
if period:
|
||||
note = f"{note}:{period}"
|
||||
if assumed_time:
|
||||
note = f"{note}:assumed_time"
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": grain,
|
||||
"resolved_date": _normalize_date_iso(resolved_date),
|
||||
"resolved_datetime": None if grain == "date" else _normalize_local_iso(resolved_dt),
|
||||
"assumed_time": assumed_time,
|
||||
"reason": note,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def resolve_time_expression(
|
||||
expression: str,
|
||||
current_datetime_context: str = "",
|
||||
prefer: str = "datetime",
|
||||
) -> str:
|
||||
"""解析中文自然语言时间表达,基于当前参考时间返回明确的日期或 datetime。prefer 支持 datetime/date。"""
|
||||
try:
|
||||
payload = resolve_time_expression_data(
|
||||
expression,
|
||||
current_datetime_context=current_datetime_context or None,
|
||||
prefer=prefer,
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
except Exception as exc:
|
||||
return json.dumps(
|
||||
{
|
||||
"expression": expression,
|
||||
"error": str(exc),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_context: str | None) -> dict:
|
||||
normalized = dict(args)
|
||||
|
||||
if tool_name == "create_reminder":
|
||||
raw_value = next((normalized.get(key) for key in ("reminder_at", "datetime", "at", "remind_at", "time") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_schedule_task", "create_task"}:
|
||||
raw_value = next((normalized.get(key) for key in ("due_date", "date") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value) and not _is_iso_date(raw_value):
|
||||
prefer = "datetime" if tool_name == "create_schedule_task" or _has_explicit_time(raw_value) else "date"
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer=prefer)
|
||||
normalized["due_date"] = payload["resolved_datetime"] or payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_todo", "create_goal", "get_schedule_day"}:
|
||||
field_name = {
|
||||
"create_todo": "todo_date",
|
||||
"create_goal": "goal_date",
|
||||
"get_schedule_day": "target_date",
|
||||
}[tool_name]
|
||||
raw_value = normalized.get(field_name)
|
||||
if isinstance(raw_value, str) and raw_value.strip() and not _is_iso_date(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized[field_name] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["resolve_time_expression", "resolve_time_expression_data", "normalize_tool_time_arguments", "extract_reference_datetime"]
|
||||
@@ -15,7 +15,9 @@ def _resolve_path(value: str) -> str:
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
# === 应用基础 ===
|
||||
APP_NAME: str = "Jarvis"
|
||||
@@ -75,6 +77,8 @@ class Settings(BaseSettings):
|
||||
|
||||
# === 向量化 ===
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_BASE_URL: str = "https://api.openai.com/v1"
|
||||
EMBEDDING_API_KEY: str = ""
|
||||
CHUNK_SIZE: int = 500
|
||||
CHUNK_OVERLAP: int = 50
|
||||
|
||||
@@ -86,6 +90,17 @@ class Settings(BaseSettings):
|
||||
# === NAS 部署 ===
|
||||
NAS_DATA_ROOT: str = "/data/jarvis"
|
||||
|
||||
# === Web Search / SearxNG ===
|
||||
WEB_SEARCH_ENABLED: bool = False
|
||||
WEB_SEARCH_PROVIDER: str = "searxng"
|
||||
SEARXNG_BASE_URL: str = ""
|
||||
SEARXNG_AUTH_TYPE: Literal["none", "bearer", "basic"] = "none"
|
||||
SEARXNG_AUTH_TOKEN: str = ""
|
||||
SEARXNG_BASIC_USER: str = ""
|
||||
SEARXNG_BASIC_PASSWORD: str = ""
|
||||
WEB_SEARCH_DEFAULT_LIMIT: int = 5
|
||||
WEB_SEARCH_TIMEOUT_SECONDS: int = 10
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)
|
||||
|
||||
@@ -40,6 +40,8 @@ async def init_db():
|
||||
await ensure_document_columns(conn)
|
||||
await ensure_user_columns(conn)
|
||||
await ensure_forum_columns(conn)
|
||||
await ensure_agent_columns(conn)
|
||||
await ensure_skill_columns(conn)
|
||||
|
||||
|
||||
async def ensure_log_columns(conn):
|
||||
@@ -139,6 +141,55 @@ async def ensure_forum_columns(conn):
|
||||
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
|
||||
|
||||
|
||||
async def ensure_agent_columns(conn):
|
||||
rows = await _get_table_info(conn, 'agents')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'selected_skill_ids': "ALTER TABLE agents ADD COLUMN selected_skill_ids JSON DEFAULT '[]' NOT NULL",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_skill_columns(conn):
|
||||
rows = await _get_table_info(conn, 'skills')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'required_context': "ALTER TABLE skills ADD COLUMN required_context JSON DEFAULT '[]' NOT NULL",
|
||||
'output_format': "ALTER TABLE skills ADD COLUMN output_format TEXT",
|
||||
'is_builtin': "ALTER TABLE skills ADD COLUMN is_builtin BOOLEAN DEFAULT 0 NOT NULL",
|
||||
'team_id': "ALTER TABLE skills ADD COLUMN team_id VARCHAR(36)",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
await conn.execute(text("UPDATE skills SET agent_type = 'schedule_planner' WHERE agent_type = 'planner'"))
|
||||
builtin_names = [
|
||||
'今日重点拆解',
|
||||
'周计划编排',
|
||||
'时间冲突分析',
|
||||
'任务执行 SOP',
|
||||
'外部交互推进',
|
||||
'知识检索摘要',
|
||||
'图谱沉淀策略',
|
||||
'风险识别模板',
|
||||
'趋势洞察模板',
|
||||
]
|
||||
for name in builtin_names:
|
||||
await conn.execute(
|
||||
text("UPDATE skills SET is_builtin = 1 WHERE name = :name"),
|
||||
{'name': name},
|
||||
)
|
||||
|
||||
|
||||
async def _backfill_usernames(conn):
|
||||
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
|
||||
users = result.fetchall()
|
||||
|
||||
@@ -14,6 +14,9 @@ from app.routers import (
|
||||
graph_router,
|
||||
agent_router,
|
||||
todo_router,
|
||||
reminder_router,
|
||||
goal_router,
|
||||
schedule_center_router,
|
||||
settings_router,
|
||||
folder_router,
|
||||
skill_router,
|
||||
@@ -23,7 +26,7 @@ from app.routers import (
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user, ensure_builtin_skills
|
||||
from app.config import settings
|
||||
from app.logging_utils import (
|
||||
setup_logging,
|
||||
@@ -53,6 +56,7 @@ async def run_startup() -> None:
|
||||
await init_db()
|
||||
async with async_session() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
await ensure_builtin_skills(session)
|
||||
await persist_system_log(
|
||||
message="application_started",
|
||||
source="app",
|
||||
@@ -103,6 +107,9 @@ app.include_router(forum_router)
|
||||
app.include_router(graph_router)
|
||||
app.include_router(agent_router)
|
||||
app.include_router(todo_router)
|
||||
app.include_router(reminder_router)
|
||||
app.include_router(goal_router)
|
||||
app.include_router(schedule_center_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(folder_router)
|
||||
app.include_router(skill_router)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from app.models.base import Base
|
||||
from app.models.user import User
|
||||
from app.models.folder import Folder
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.task import Task, TaskHistory
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
@@ -17,11 +18,14 @@ from app.models.brain import (
|
||||
brain_memory_sources,
|
||||
)
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.reminder import Reminder, ReminderStatus
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.log import Log, LogType, LogLevel
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"Folder",
|
||||
"Document",
|
||||
"DocumentChunk",
|
||||
"Task",
|
||||
@@ -45,6 +49,10 @@ __all__ = [
|
||||
"brain_memory_sources",
|
||||
"DailyTodo",
|
||||
"TodoSource",
|
||||
"Reminder",
|
||||
"ReminderStatus",
|
||||
"Goal",
|
||||
"GoalStatus",
|
||||
"Log",
|
||||
"LogType",
|
||||
"LogLevel",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -7,9 +7,10 @@ class Agent(BaseModel):
|
||||
__tablename__ = "agents"
|
||||
|
||||
name = Column(String(100), nullable=False)
|
||||
role = Column(String(100), nullable=False) # master, planner, executor, librarian, analyst
|
||||
role = Column(String(100), nullable=False) # master, schedule_planner, executor, librarian, analyst
|
||||
description = Column(Text, nullable=True)
|
||||
system_prompt = Column(Text, nullable=False)
|
||||
selected_skill_ids = Column(JSON, default=list, nullable=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
|
||||
|
||||
21
backend/app/models/goal.py
Normal file
21
backend/app/models/goal.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class GoalStatus(str, PyEnum):
|
||||
ACTIVE = "active"
|
||||
DONE = "done"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class Goal(BaseModel):
|
||||
__tablename__ = "goals"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
goal_date = Column(String(10), nullable=False, index=True)
|
||||
status = Column(Enum(GoalStatus), default=GoalStatus.ACTIVE, nullable=False)
|
||||
21
backend/app/models/reminder.py
Normal file
21
backend/app/models/reminder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class ReminderStatus(str, PyEnum):
|
||||
PENDING = "pending"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class Reminder(BaseModel):
|
||||
__tablename__ = "reminders"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
reminder_at = Column(DateTime, nullable=False, index=True)
|
||||
status = Column(Enum(ReminderStatus), default=ReminderStatus.PENDING, nullable=False)
|
||||
is_dismissed = Column(Boolean, default=False, nullable=False)
|
||||
@@ -9,11 +9,12 @@ class Skill(BaseModel):
|
||||
name = Column(String(100), nullable=False, unique=True, index=True)
|
||||
description = Column(Text, nullable=True) # 供 LLM 理解用途
|
||||
instructions = Column(Text, nullable=False) # Agent 执行时的指令模板
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/planner/executor/librarian/analyst
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/schedule_planner/executor/librarian/analyst
|
||||
tools = Column(JSON, default=list) # 引用的工具名称列表
|
||||
required_context = Column(JSON, default=list) # 需要的前置数据
|
||||
output_format = Column(Text, nullable=True) # 输出格式要求
|
||||
visibility = Column(String(20), default="private") # private/team/market
|
||||
is_builtin = Column(Boolean, default=False, nullable=False)
|
||||
team_id = Column(String(36), ForeignKey("users.id"), nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
owner_id = Column(String(36), ForeignKey("users.id"), nullable=False)
|
||||
|
||||
@@ -6,6 +6,9 @@ from app.routers.forum import router as forum_router
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.settings import router as settings_router
|
||||
from app.routers.folder import router as folder_router
|
||||
from app.routers.skill import router as skill_router
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
||||
@@ -13,9 +14,9 @@ _agent_call_counts: dict[str, int] = {}
|
||||
_agent_current_tasks: dict[str, str | None] = {}
|
||||
_agent_statuses: dict[str, str] = {}
|
||||
|
||||
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
|
||||
DEFAULT_AGENT_ROLES = ["master", "schedule_planner", "executor", "librarian", "analyst"]
|
||||
SUB_COMMANDERS_BY_ROLE = {
|
||||
"planner": ["planner_scope", "planner_steps"],
|
||||
"schedule_planner": ["schedule_analysis", "schedule_planning"],
|
||||
"executor": ["executor_tasks", "executor_forum"],
|
||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||
"analyst": ["analyst_progress", "analyst_insights"],
|
||||
@@ -88,10 +89,10 @@ async def get_agent_config(
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, SCHEDULE_PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
defaults = {
|
||||
"master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
|
||||
"planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
|
||||
"schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
|
||||
"executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
|
||||
"librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
|
||||
"analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
|
||||
@@ -107,6 +108,7 @@ async def get_agent_config(
|
||||
system_prompt=prompt,
|
||||
enabled=True,
|
||||
is_active=True,
|
||||
selected_skill_ids=[],
|
||||
)
|
||||
return AgentConfigOut(
|
||||
id=agent.role,
|
||||
@@ -116,6 +118,7 @@ async def get_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -141,6 +144,19 @@ async def update_agent_config(
|
||||
if data.enabled is not None:
|
||||
agent.is_active = data.enabled
|
||||
_agent_statuses[agent_id] = "disabled" if not data.enabled else "idle"
|
||||
if data.selected_skill_ids is not None:
|
||||
if data.selected_skill_ids:
|
||||
result = await db.execute(
|
||||
select(Skill.id).where(
|
||||
Skill.id.in_(data.selected_skill_ids),
|
||||
Skill.owner_id == current_user.id,
|
||||
)
|
||||
)
|
||||
allowed_skill_ids = set(result.scalars().all())
|
||||
invalid_skill_ids = [skill_id for skill_id in data.selected_skill_ids if skill_id not in allowed_skill_ids]
|
||||
if invalid_skill_ids:
|
||||
raise HTTPException(status_code=400, detail="存在无效的技能绑定")
|
||||
agent.selected_skill_ids = data.selected_skill_ids
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
@@ -152,6 +168,7 @@ async def update_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
|
||||
from app.config import settings
|
||||
|
||||
@@ -58,6 +59,7 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
|
||||
await db.refresh(user)
|
||||
await ensure_builtin_skills(db)
|
||||
return user
|
||||
|
||||
|
||||
@@ -97,10 +99,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSessi
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
|
||||
await ensure_builtin_skills(db)
|
||||
access_token = create_access_token(data={"sub": user.id})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
async def get_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
|
||||
await ensure_builtin_skills(db)
|
||||
return current_user
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
@@ -92,13 +93,16 @@ async def chat(
|
||||
):
|
||||
"""简单版对话(非流式)"""
|
||||
agent_svc = AgentService(db)
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
try:
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
# 更新对话消息计数
|
||||
result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
|
||||
@@ -126,13 +130,17 @@ async def chat_stream(
|
||||
agent_svc = AgentService(db)
|
||||
|
||||
async def stream_generator():
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
try:
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
|
||||
|
||||
|
||||
92
backend/app/routers/goal.py
Normal file
92
backend/app/routers/goal.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.goal import GoalCreate, GoalListOut, GoalOut, GoalUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/goals", tags=["目标"])
|
||||
|
||||
|
||||
@router.get("", response_model=GoalListOut)
|
||||
async def list_goals(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str).isoformat()
|
||||
query = (
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id)
|
||||
.where(Goal.goal_date == target_date)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return GoalListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=GoalOut, status_code=201)
|
||||
async def create_goal(
|
||||
data: GoalCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
goal = Goal(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
goal_date=data.goal_date.isoformat(),
|
||||
status=data.status,
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.patch("/{goal_id}", response_model=GoalOut)
|
||||
async def update_goal(
|
||||
goal_id: str,
|
||||
data: GoalUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
payload = data.model_dump(exclude_none=True)
|
||||
if "goal_date" in payload:
|
||||
payload["goal_date"] = payload["goal_date"].isoformat()
|
||||
|
||||
for field, value in payload.items():
|
||||
setattr(goal, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.delete("/{goal_id}", status_code=204)
|
||||
async def delete_goal(
|
||||
goal_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
await db.delete(goal)
|
||||
await db.commit()
|
||||
90
backend/app/routers/reminder.py
Normal file
90
backend/app/routers/reminder.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.reminder import ReminderCreate, ReminderListOut, ReminderOut, ReminderUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/reminders", tags=["提醒"])
|
||||
|
||||
|
||||
@router.get("", response_model=ReminderListOut)
|
||||
async def list_reminders(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str)
|
||||
start = datetime.combine(target_date, datetime.min.time())
|
||||
end = datetime.combine(target_date, datetime.max.time())
|
||||
query = (
|
||||
select(Reminder)
|
||||
.where(Reminder.user_id == current_user.id)
|
||||
.where(Reminder.reminder_at >= start)
|
||||
.where(Reminder.reminder_at <= end)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return ReminderListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=ReminderOut, status_code=201)
|
||||
async def create_reminder(
|
||||
data: ReminderCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
reminder = Reminder(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
reminder_at=data.reminder_at,
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.patch("/{reminder_id}", response_model=ReminderOut)
|
||||
async def update_reminder(
|
||||
reminder_id: str,
|
||||
data: ReminderUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
for field, value in data.model_dump(exclude_none=True).items():
|
||||
setattr(reminder, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.delete("/{reminder_id}", status_code=204)
|
||||
async def delete_reminder(
|
||||
reminder_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
await db.delete(reminder)
|
||||
await db.commit()
|
||||
160
backend/app/routers/schedule_center.py
Normal file
160
backend/app/routers/schedule_center.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from calendar import monthrange
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.schedule_center import (
|
||||
ScheduleCenterDateOut,
|
||||
ScheduleCenterDaySummary,
|
||||
ScheduleCenterMonthOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/schedule-center", tags=["调度中心"])
|
||||
|
||||
|
||||
def _build_summary(
|
||||
target_date: str,
|
||||
todos: list[DailyTodo],
|
||||
tasks: list[Task],
|
||||
reminders: list[Reminder],
|
||||
goals: list[Goal],
|
||||
) -> ScheduleCenterDaySummary:
|
||||
return ScheduleCenterDaySummary(
|
||||
date=target_date,
|
||||
todo_total=len(todos),
|
||||
todo_completed=sum(1 for item in todos if item.is_completed),
|
||||
task_due_total=len(tasks),
|
||||
high_priority_total=sum(1 for item in tasks if item.priority in {TaskPriority.HIGH, TaskPriority.URGENT}),
|
||||
reminder_total=len(reminders),
|
||||
goal_total=len(goals),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/month", response_model=ScheduleCenterMonthOut)
|
||||
async def get_month_schedule(
|
||||
year: int = Query(..., ge=2000, le=2100),
|
||||
month: int = Query(..., ge=1, le=12),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
month_start = date(year, month, 1)
|
||||
days_in_month = monthrange(month_start.year, month_start.month)[1]
|
||||
start_key = month_start.isoformat()
|
||||
end_key = month_start.replace(day=days_in_month).isoformat()
|
||||
start_dt = datetime.combine(month_start, datetime.min.time())
|
||||
end_dt = datetime.combine(month_start.replace(day=days_in_month), datetime.max.time())
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo).where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date >= start_key, DailyTodo.todo_date <= end_key)
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task).where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal).where(Goal.user_id == current_user.id, Goal.goal_date >= start_key, Goal.goal_date <= end_key)
|
||||
)).scalars().all()
|
||||
|
||||
todo_map: dict[str, list[DailyTodo]] = {}
|
||||
for item in todos:
|
||||
todo_map.setdefault(item.todo_date, []).append(item)
|
||||
|
||||
task_map: dict[str, list[Task]] = {}
|
||||
for item in tasks:
|
||||
key = item.due_date.date().isoformat()
|
||||
task_map.setdefault(key, []).append(item)
|
||||
|
||||
reminder_map: dict[str, list[Reminder]] = {}
|
||||
for item in reminders:
|
||||
key = item.reminder_at.date().isoformat()
|
||||
reminder_map.setdefault(key, []).append(item)
|
||||
|
||||
goal_map: dict[str, list[Goal]] = {}
|
||||
for item in goals:
|
||||
goal_map.setdefault(item.goal_date, []).append(item)
|
||||
|
||||
days = []
|
||||
for day in range(1, days_in_month + 1):
|
||||
date_key = month_start.replace(day=day).isoformat()
|
||||
days.append(_build_summary(
|
||||
date_key,
|
||||
todo_map.get(date_key, []),
|
||||
task_map.get(date_key, []),
|
||||
reminder_map.get(date_key, []),
|
||||
goal_map.get(date_key, []),
|
||||
))
|
||||
|
||||
return ScheduleCenterMonthOut(month=f"{year:04d}-{month:02d}", days=days)
|
||||
|
||||
|
||||
@router.get("/date", response_model=ScheduleCenterDateOut)
|
||||
async def get_date_schedule(
|
||||
date_str: date = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date_str
|
||||
start_dt = datetime.combine(target_date, datetime.min.time())
|
||||
end_dt = datetime.combine(target_date, datetime.max.time())
|
||||
date_key = target_date.isoformat()
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)).scalars().all()
|
||||
|
||||
summary = _build_summary(date_key, todos, tasks, reminders, goals)
|
||||
return ScheduleCenterDateOut(
|
||||
date=date_key,
|
||||
todos=todos,
|
||||
tasks=tasks,
|
||||
reminders=reminders,
|
||||
goals=goals,
|
||||
summary=summary,
|
||||
generated_at=datetime.now(UTC),
|
||||
)
|
||||
@@ -1,11 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.skill import SkillCreate, SkillOut, SkillUpdate
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.skill_service import SkillService
|
||||
|
||||
router = APIRouter(prefix="/api/skills", tags=["Skill"])
|
||||
|
||||
@@ -37,13 +39,23 @@ async def create_skill(
|
||||
|
||||
@router.get("", response_model=list[SkillOut])
|
||||
async def list_skills(
|
||||
agent_type: str | None = Query(default=None),
|
||||
visibility: str | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Skill).where(Skill.owner_id == current_user.id).order_by(Skill.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id, agent_type=agent_type, visibility=visibility)
|
||||
|
||||
|
||||
@router.post("/bootstrap-builtin", response_model=list[SkillOut])
|
||||
async def bootstrap_builtin_skills(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await ensure_builtin_skills(db, preferred_owner_id=current_user.id)
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id)
|
||||
|
||||
|
||||
@router.get("/{skill_id}", response_model=SkillOut)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.user import User
|
||||
@@ -13,12 +15,28 @@ router = APIRouter(prefix="/api/tasks", tags=["看板"])
|
||||
@router.get("", response_model=list[TaskOut])
|
||||
async def list_tasks(
|
||||
status: TaskStatus | None = None,
|
||||
due_date: date | None = Query(default=None),
|
||||
date_from: date | None = Query(default=None),
|
||||
date_to: date | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Task).where(Task.user_id == current_user.id)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
if due_date:
|
||||
start = datetime.combine(due_date, datetime.min.time())
|
||||
end = datetime.combine(due_date, datetime.max.time())
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start, Task.due_date <= end)
|
||||
else:
|
||||
start = datetime.combine(date_from, datetime.min.time()) if date_from else None
|
||||
end = datetime.combine(date_to, datetime.max.time()) if date_to else None
|
||||
if start and end and start > end:
|
||||
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
|
||||
if start is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start)
|
||||
if end is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date <= end)
|
||||
query = query.order_by(desc(Task.created_at))
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
@@ -64,10 +82,10 @@ async def update_task(
|
||||
if field == "tags":
|
||||
setattr(task, field, json.dumps(value))
|
||||
elif field == "status" and value == TaskStatus.DONE:
|
||||
from datetime import UTC, datetime
|
||||
task.completed_at = datetime.now(UTC)
|
||||
setattr(task, field, value)
|
||||
else:
|
||||
elif field == "status":
|
||||
task.completed_at = None
|
||||
setattr(task, field, value)
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from datetime import date
|
||||
from app.database import get_db
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
@@ -52,7 +53,7 @@ async def create_todo(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date=date.today().isoformat(),
|
||||
todo_date=(data.todo_date or date.today()).isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
@@ -74,14 +75,11 @@ async def update_todo(
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
|
||||
# 历史日期不允许修改
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可修改")
|
||||
|
||||
if data.title is not None:
|
||||
todo.title = data.title
|
||||
if data.todo_date is not None:
|
||||
todo.todo_date = data.todo_date.isoformat()
|
||||
if data.is_completed is not None:
|
||||
from datetime import UTC, datetime
|
||||
todo.is_completed = data.is_completed
|
||||
todo.completed_at = datetime.now(UTC) if data.is_completed else None
|
||||
|
||||
@@ -102,9 +100,6 @@ async def delete_todo(
|
||||
todo = result.scalar_one_or_none()
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可删除")
|
||||
|
||||
await db.delete(todo)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class AgentConfigUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
system_prompt: str | None = None
|
||||
enabled: bool | None = None
|
||||
selected_skill_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AgentConfigOut(BaseModel):
|
||||
@@ -51,5 +52,6 @@ class AgentConfigOut(BaseModel):
|
||||
system_prompt: str
|
||||
enabled: bool
|
||||
is_active: bool
|
||||
selected_skill_ids: list[str]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
35
backend/app/schemas/goal.py
Normal file
35
backend/app/schemas/goal.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.goal import GoalStatus
|
||||
|
||||
|
||||
class GoalCreate(BaseModel):
|
||||
title: str
|
||||
goal_date: date
|
||||
note: str | None = None
|
||||
status: GoalStatus = GoalStatus.ACTIVE
|
||||
|
||||
|
||||
class GoalUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
goal_date: date | None = None
|
||||
note: str | None = None
|
||||
status: GoalStatus | None = None
|
||||
|
||||
|
||||
class GoalOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
goal_date: str
|
||||
status: GoalStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class GoalListOut(BaseModel):
|
||||
items: list[GoalOut]
|
||||
40
backend/app/schemas/reminder.py
Normal file
40
backend/app/schemas/reminder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.reminder import ReminderStatus
|
||||
|
||||
|
||||
class ReminderCreate(BaseModel):
|
||||
title: str
|
||||
reminder_at: datetime
|
||||
note: str | None = None
|
||||
|
||||
|
||||
class ReminderUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
reminder_at: datetime | None = None
|
||||
note: str | None = None
|
||||
status: ReminderStatus | None = None
|
||||
is_dismissed: bool | None = None
|
||||
|
||||
|
||||
class ReminderOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
reminder_at: datetime
|
||||
status: ReminderStatus
|
||||
is_dismissed: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ReminderListOut(BaseModel):
|
||||
items: list[ReminderOut]
|
||||
|
||||
|
||||
class ReminderDateQuery(BaseModel):
|
||||
date: date
|
||||
33
backend/app/schemas/schedule_center.py
Normal file
33
backend/app/schemas/schedule_center.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas.goal import GoalOut
|
||||
from app.schemas.reminder import ReminderOut
|
||||
from app.schemas.task import TaskOut
|
||||
from app.schemas.todo import TodoOut
|
||||
|
||||
|
||||
class ScheduleCenterDaySummary(BaseModel):
|
||||
date: str
|
||||
todo_total: int
|
||||
todo_completed: int
|
||||
task_due_total: int
|
||||
high_priority_total: int
|
||||
reminder_total: int
|
||||
goal_total: int
|
||||
|
||||
|
||||
class ScheduleCenterMonthOut(BaseModel):
|
||||
month: str
|
||||
days: list[ScheduleCenterDaySummary]
|
||||
|
||||
|
||||
class ScheduleCenterDateOut(BaseModel):
|
||||
date: str
|
||||
todos: list[TodoOut]
|
||||
tasks: list[TaskOut]
|
||||
reminders: list[ReminderOut]
|
||||
goals: list[GoalOut]
|
||||
summary: ScheduleCenterDaySummary
|
||||
generated_at: datetime
|
||||
@@ -10,7 +10,8 @@ LLMType = Literal["chat", "vlm", "embedding", "rerank"]
|
||||
# 单个模型配置
|
||||
class LLMModelConfig(BaseModel):
|
||||
name: str = "" # 模型名称/别名,用于标识
|
||||
provider: LLMProviderType = "openai"
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str = ""
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
@@ -52,7 +53,8 @@ class SettingsOut(BaseModel):
|
||||
# 测试 LLM 连接请求
|
||||
class LLMTestIn(BaseModel):
|
||||
type: LLMType
|
||||
provider: LLMProviderType
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,7 +7,7 @@ class SkillCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
instructions: str
|
||||
agent_type: str # master/planner/executor/librarian/analyst
|
||||
agent_type: str # master/schedule_planner/executor/librarian/analyst
|
||||
tools: list[str] = []
|
||||
required_context: list[str] = []
|
||||
output_format: Optional[str] = None
|
||||
@@ -39,10 +40,11 @@ class SkillOut(BaseModel):
|
||||
required_context: list[str]
|
||||
output_format: Optional[str]
|
||||
visibility: str
|
||||
is_builtin: bool
|
||||
team_id: Optional[str]
|
||||
is_active: bool
|
||||
owner_id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.todo import TodoSource
|
||||
|
||||
|
||||
class TodoCreate(BaseModel):
|
||||
title: str
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
is_completed: bool | None = None
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoOut(BaseModel):
|
||||
|
||||
@@ -2,10 +2,87 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
BUILTIN_SKILLS = [
|
||||
{
|
||||
'name': '今日重点拆解',
|
||||
'description': '帮助日程规划师从上下文中提炼今天最值得推进的事项。',
|
||||
'instructions': '优先识别今天最关键的 1-3 个重点,说明原因,并给出可执行顺序。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '周计划编排',
|
||||
'description': '把本周目标整理成可落地的节奏与时间块。',
|
||||
'instructions': '将目标拆成周内节奏安排,明确先后顺序、时间块与缓冲。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '时间冲突分析',
|
||||
'description': '识别任务、日程与优先级之间的冲突。',
|
||||
'instructions': '分析冲突来源、影响和推荐取舍,必要时给出替代方案。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '任务执行 SOP',
|
||||
'description': '为执行角色提供标准执行步骤和结果回报格式。',
|
||||
'instructions': '执行前先确认目标与边界,执行中记录关键动作,执行后输出结果、风险与下一步。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['shell', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '外部交互推进',
|
||||
'description': '支持论坛、外部接口或内容发布类动作。',
|
||||
'instructions': '围绕外部交互任务,优先保证动作完整、结果清晰、反馈及时。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['api_calls', 'git'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '知识检索摘要',
|
||||
'description': '从知识中枢中提炼与当前问题最相关的信息。',
|
||||
'instructions': '检索后只保留当前决策需要的内容,输出摘要、来源与缺口。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['web_search', 'database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '图谱沉淀策略',
|
||||
'description': '帮助知识管理员把零散信息沉淀为结构化关系。',
|
||||
'instructions': '识别应沉淀的实体、关系与后续可检索维度。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '风险识别模板',
|
||||
'description': '帮助分析师快速识别当前推进中的风险点。',
|
||||
'instructions': '从进度、依赖、资源与外部信号中提炼风险,并按严重度排序。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '趋势洞察模板',
|
||||
'description': '把多源状态汇总为趋势与判断。',
|
||||
'instructions': '对比近期变化,输出趋势、证据、判断与建议动作。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'code_execution'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _is_bootstrap_enabled(settings) -> bool:
|
||||
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
|
||||
|
||||
@@ -58,3 +135,49 @@ async def ensure_admin_user(db: AsyncSession, settings) -> None:
|
||||
return
|
||||
raise
|
||||
await db.refresh(admin_user)
|
||||
|
||||
|
||||
async def ensure_builtin_skills(db: AsyncSession, preferred_owner_id: str | None = None) -> None:
|
||||
owner = None
|
||||
if preferred_owner_id:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.id == preferred_owner_id, User.is_active == True)
|
||||
)
|
||||
owner = owner_result.scalar_one_or_none()
|
||||
|
||||
if not owner:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.is_active == True).order_by(User.is_superuser.desc(), User.created_at.asc())
|
||||
)
|
||||
owner = owner_result.scalars().first()
|
||||
|
||||
if not owner:
|
||||
return
|
||||
|
||||
existing_result = await db.execute(select(Skill.name))
|
||||
existing_names = set(existing_result.scalars().all())
|
||||
|
||||
missing_skills = [
|
||||
Skill(
|
||||
owner_id=owner.id,
|
||||
name=item['name'],
|
||||
description=item['description'],
|
||||
instructions=item['instructions'],
|
||||
agent_type=item['agent_type'],
|
||||
tools=item['tools'],
|
||||
required_context=[],
|
||||
output_format=None,
|
||||
visibility=item['visibility'],
|
||||
is_builtin=True,
|
||||
team_id=None,
|
||||
is_active=True,
|
||||
)
|
||||
for item in BUILTIN_SKILLS
|
||||
if item['name'] not in existing_names
|
||||
]
|
||||
|
||||
if not missing_skills:
|
||||
return
|
||||
|
||||
db.add_all(missing_skills)
|
||||
await db.commit()
|
||||
|
||||
@@ -5,18 +5,17 @@ Jarvis Agent 服务层
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, AsyncGenerator
|
||||
import asyncio
|
||||
from openai import BadRequestError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
|
||||
from app.database import async_session
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
@@ -24,43 +23,35 @@ from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
||||
from app.agents.tools.time_reasoning import extract_reference_datetime
|
||||
from app.agents.state import initial_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||
error_text = str(error).lower()
|
||||
markers = [
|
||||
"invalid chat setting",
|
||||
"invalid params",
|
||||
"stream",
|
||||
"streaming",
|
||||
"unsupported",
|
||||
"bad_request_error",
|
||||
"http 400",
|
||||
"error code: 400",
|
||||
]
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
# 默认使用 OpenAI
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
if isinstance(error, BadRequestError):
|
||||
return (
|
||||
getattr(capabilities, "provider", None) not in {"openai", "claude"}
|
||||
and any(marker in error_text for marker in markers)
|
||||
)
|
||||
|
||||
return any(marker in error_text for marker in markers)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""对话 Agent 服务"""
|
||||
@@ -101,27 +92,18 @@ class AgentService:
|
||||
|
||||
llm_config = user.llm_config
|
||||
|
||||
# 如果指定了模型名称,查找对应的配置
|
||||
if model_name:
|
||||
for model_type in ["chat", "vlm"]:
|
||||
models = llm_config.get(model_type, [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
# 没找到,返回 None 让调用方知道配置不存在
|
||||
models = llm_config.get("chat", [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
return None
|
||||
|
||||
# 如果没指定模型名,返回默认启用的 chat 模型
|
||||
chat_models = llm_config.get("chat", [])
|
||||
for m in chat_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
vlm_models = llm_config.get("vlm", [])
|
||||
for m in vlm_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
@@ -134,11 +116,26 @@ class AgentService:
|
||||
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
|
||||
"""
|
||||
处理对话请求(流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_stream)
|
||||
"""
|
||||
# 获取或创建对话
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if model_name and not user_llm_config:
|
||||
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
logger.info(
|
||||
"agent_chat_started",
|
||||
extra={
|
||||
"details": {
|
||||
"mode": "stream",
|
||||
"requested_model_name": model_name,
|
||||
"resolved_model_name": model_name_used,
|
||||
"message_length": len(message or ""),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
@@ -156,7 +153,6 @@ class AgentService:
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
@@ -168,7 +164,6 @@ class AgentService:
|
||||
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -193,156 +188,133 @@ class AgentService:
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 预创建助手消息(后续更新内容)
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content="",
|
||||
model=model_name_used or "jarvis",
|
||||
attachments=None,
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
def _build_current_datetime_context() -> str:
|
||||
now_utc = datetime.now(UTC)
|
||||
return (
|
||||
"【当前时间】\n"
|
||||
f"- current_time_utc: {now_utc.isoformat()}\n"
|
||||
f"- current_date_utc: {now_utc.date().isoformat()}\n"
|
||||
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
|
||||
)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
async def run_agent():
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"current_sub_commander": None,
|
||||
"active_sub_commanders": [],
|
||||
"sub_commander_trace": [],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
current_datetime_context = _build_current_datetime_context()
|
||||
|
||||
# 使用 initial_state 构建状态
|
||||
state = initial_state(user_id, conversation_id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=full_message)],
|
||||
"memory_context": memory_ctx,
|
||||
"current_datetime_context": current_datetime_context,
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
})
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
|
||||
collected = ""
|
||||
async for event in graph.astream_events(langgraph_state, version="v2"):
|
||||
kind = event.get("event")
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
try:
|
||||
async for event in graph.astream_events(state, version="v2"):
|
||||
kind = event.get("event")
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
|
||||
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"planner": ("planning", "Jarvis 正在拆解步骤"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map[event_name]
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
elif kind == "on_tool_start":
|
||||
tool_input = data.get("input")
|
||||
step = None
|
||||
if isinstance(tool_input, dict) and tool_input:
|
||||
step = f"调用工具 {event_name}"
|
||||
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
|
||||
elif kind == "on_tool_end":
|
||||
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
|
||||
elif kind == "on_chain_end" and event_name == "planner":
|
||||
output = data.get("output") or {}
|
||||
plan_steps = output.get("plan_steps") or []
|
||||
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
|
||||
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chat_model_end" and not collected:
|
||||
output = data.get("output")
|
||||
content = getattr(output, "content", "") if output else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected = content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
|
||||
except Exception as e:
|
||||
fallback = f"抱歉,发生错误: {str(e)}"
|
||||
collected = fallback
|
||||
yield {"type": "error", "error": str(e)}
|
||||
yield {"type": "chunk", "content": fallback}
|
||||
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
|
||||
elif kind == "on_tool_start":
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"Jarvis 正在调用工具 {event_name}",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=f"正在执行 {event_name}",
|
||||
)
|
||||
|
||||
elif kind == "on_tool_end":
|
||||
tool_result = data.get("output")
|
||||
step = f"已完成 {event_name}"
|
||||
if isinstance(tool_result, str) and len(tool_result) > 0:
|
||||
step = tool_result[:100]
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"工具 {event_name} 已完成",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=step,
|
||||
)
|
||||
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
|
||||
elif kind == "on_chain_end" and event_name == "create_agent_graph":
|
||||
# 最终输出通常在这里
|
||||
output = data.get("output")
|
||||
if isinstance(output, dict) and "final_response" in output:
|
||||
final_resp = output["final_response"]
|
||||
# 如果还没流式输出完整,补全它
|
||||
if final_resp and not collected:
|
||||
collected = final_resp
|
||||
yield {"type": "chunk", "content": collected}
|
||||
|
||||
except Exception as e:
|
||||
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
|
||||
try:
|
||||
result_state = await graph.ainvoke(state)
|
||||
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
collected = str(fallback_content)
|
||||
yield {"type": "chunk", "content": collected}
|
||||
except Exception as fallback_error:
|
||||
logger.exception("llm_sync_fallback_failed")
|
||||
yield {"type": "error", "error": "模型服务暂不可用。"}
|
||||
else:
|
||||
logger.exception("agent_streaming_failed")
|
||||
yield {"type": "error", "error": str(e)}
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id))
|
||||
|
||||
# 最终更新数据库中的消息内容
|
||||
if collected:
|
||||
try:
|
||||
result2 = await self.db.execute(
|
||||
select(Message).where(Message.id == assistant_msg.id)
|
||||
)
|
||||
msg = result2.scalar_one_or_none()
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await self.db.commit()
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=collected[:500],
|
||||
raw_excerpt=collected[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
async with async_session() as session:
|
||||
result2 = await session.execute(select(Message).where(Message.id == assistant_msg.id))
|
||||
msg = result2.scalar_one_or_none()
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("save_assistant_message_failed")
|
||||
|
||||
return conversation_id, assistant_msg.id, run_agent()
|
||||
|
||||
@@ -355,117 +327,44 @@ class AgentService:
|
||||
model_name: str | None = None,
|
||||
) -> tuple[str, str, str, str | None]:
|
||||
"""
|
||||
简单同步版对话(无流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_content, model_name_used)
|
||||
简单同步版对话
|
||||
"""
|
||||
# 获取或创建对话
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
else:
|
||||
conv = None
|
||||
|
||||
if not conv:
|
||||
conv = Conversation(user_id=user_id, title=message[:50])
|
||||
self.db.add(conv)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(conv)
|
||||
conversation_id = conv.id
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
doc_svc = DocumentService(self.db)
|
||||
for file_id in file_ids:
|
||||
content = await doc_svc.get_document_content(user_id, file_id)
|
||||
if content:
|
||||
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
|
||||
|
||||
# 将文件上下文添加到消息
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message,
|
||||
attachments=[{"file_ids": file_ids}] if file_ids else None,
|
||||
)
|
||||
self.db.add(user_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="User message",
|
||||
content_summary=message[:500],
|
||||
raw_excerpt=message[:2000],
|
||||
metadata_={"role": "user"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
# 获取用户配置的 LLM
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
set_current_user(user_id)
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
|
||||
}
|
||||
if not conversation_id:
|
||||
conv = Conversation(user_id=user_id, title=message[:50])
|
||||
self.db.add(conv)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(conv)
|
||||
conversation_id = conv.id
|
||||
|
||||
user_msg = Message(conversation_id=conversation_id, role="user", content=message)
|
||||
self.db.add(user_msg)
|
||||
|
||||
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
|
||||
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
result_state = await graph.ainvoke(langgraph_state)
|
||||
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
|
||||
graph = get_agent_graph()
|
||||
state = initial_state(user_id, conversation_id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"memory_context": memory_ctx,
|
||||
"current_datetime_context": datetime.now(UTC).isoformat(),
|
||||
"user_llm_config": user_llm_config,
|
||||
})
|
||||
|
||||
result_state = await graph.ainvoke(state)
|
||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
except Exception as e:
|
||||
response_content = f"抱歉,发生错误: {str(e)}"
|
||||
logger.exception("agent_chat_simple_failed")
|
||||
response_content = "抱歉,发生错误。"
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 保存助手消息
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
@@ -474,19 +373,5 @@ class AgentService:
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=response_content[:500],
|
||||
raw_excerpt=response_content[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
return conversation_id, assistant_msg.id, response_content, model_name_used
|
||||
|
||||
@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Literal
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
@@ -16,8 +17,131 @@ from app.models.user import User
|
||||
import httpx
|
||||
import os
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
|
||||
|
||||
ToolStrategy = Literal["native", "json_fallback"]
|
||||
|
||||
|
||||
def _resolve_effective_base_url(config: dict | None) -> str:
|
||||
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
base_url = str((config or {}).get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
if provider in {"openai", "custom", "deepseek"}:
|
||||
return settings.OPENAI_BASE_URL
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_BASE_URL
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderCapabilities:
|
||||
provider: str
|
||||
supports_native_tools: bool
|
||||
preferred_tool_strategy: ToolStrategy
|
||||
|
||||
|
||||
def default_provider_capabilities() -> ProviderCapabilities:
|
||||
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
|
||||
|
||||
|
||||
def normalize_provider_name(config: dict | None) -> str:
|
||||
provider_raw = str((config or {}).get("provider") or "").strip().lower()
|
||||
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
model = str((config or {}).get("model") or "").strip().lower()
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
|
||||
# base_url-first inference (provider may be omitted in user config)
|
||||
if base_url:
|
||||
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
|
||||
return "ollama"
|
||||
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
|
||||
return "claude"
|
||||
if "api.deepseek.com" in base_url:
|
||||
return "deepseek"
|
||||
|
||||
# Many "openai-compatible" endpoints are configured as provider=openai.
|
||||
# We treat them as distinct providers so capability routing can stay conservative.
|
||||
if provider in {"openai", "custom"}:
|
||||
if any(key in model or key in base_url for key in {"minimax", "abab"}):
|
||||
return "minimax"
|
||||
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
|
||||
return "kimi"
|
||||
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
|
||||
return "qwen"
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
|
||||
provider = normalize_provider_name(config)
|
||||
|
||||
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
|
||||
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
|
||||
native_tool_providers = {"openai", "deepseek", "claude"}
|
||||
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
is_official_openai = (
|
||||
provider != "openai"
|
||||
or not base_url
|
||||
or "api.openai.com" in base_url
|
||||
or "openai.azure.com" in base_url
|
||||
)
|
||||
|
||||
if provider in native_tool_providers and is_official_openai:
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=True,
|
||||
preferred_tool_strategy="native",
|
||||
)
|
||||
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_from_config(config: dict | None):
|
||||
"""根据用户模型配置创建底层 LangChain LLM 实例"""
|
||||
if not config:
|
||||
return get_llm()
|
||||
|
||||
provider = normalize_provider_name(config)
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
setattr(llm, "_jarvis_user_llm_config", config)
|
||||
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
|
||||
return llm
|
||||
|
||||
|
||||
class LLMService(ABC):
|
||||
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
|
||||
_llm_instance = OllamaService()
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM provider: {provider}")
|
||||
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
|
||||
return _llm_instance
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,154 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
Jarvis 记忆系统 (基于 Mem0)
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
from app.config import settings as _settings
|
||||
|
||||
try:
|
||||
from mem0 import Memory
|
||||
|
||||
MEM0_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEM0_AVAILABLE = False
|
||||
Memory = None
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
embedding_models = user.llm_config.get("embedding", [])
|
||||
for model in embedding_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
|
||||
"api_key": model.get("api_key")
|
||||
or _settings.EMBEDDING_API_KEY
|
||||
or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 chat 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
chat_models = user.llm_config.get("chat", [])
|
||||
for model in chat_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
|
||||
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
"""Mem0 客户端 - 按用户隔离"""
|
||||
|
||||
_instances: dict[str, Memory] = {}
|
||||
_persist_dir: str = "./data/mem0"
|
||||
|
||||
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
cache_key = user_id
|
||||
|
||||
if cache_key not in self._instances:
|
||||
self._instances[cache_key] = await self._init_memory(db, user_id)
|
||||
|
||||
return self._instances[cache_key]
|
||||
|
||||
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
if not MEM0_AVAILABLE:
|
||||
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
|
||||
|
||||
os.makedirs(self._persist_dir, exist_ok=True)
|
||||
|
||||
llm_config = {
|
||||
"model": _settings.OPENAI_MODEL,
|
||||
"base_url": _settings.OPENAI_BASE_URL,
|
||||
"api_key": _settings.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
embed_config = _settings.EMBEDDING_MODEL
|
||||
embed_base_url = _settings.EMBEDDING_BASE_URL
|
||||
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
|
||||
|
||||
if db and user_id:
|
||||
try:
|
||||
user_chat = await _get_user_chat_config(db, user_id)
|
||||
if user_chat:
|
||||
llm_config = user_chat
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
user_embed = await _get_user_embedding_config(db, user_id)
|
||||
if user_embed:
|
||||
embed_config = user_embed["model"]
|
||||
embed_base_url = user_embed["base_url"]
|
||||
embed_api_key = user_embed["api_key"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": f"jarvis_memory_{user_id}",
|
||||
"path": self._persist_dir,
|
||||
},
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": llm_config["model"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"base_url": llm_config["base_url"],
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": embed_config,
|
||||
"api_key": embed_api_key,
|
||||
"base_url": embed_base_url,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return Memory.from_config(config)
|
||||
|
||||
|
||||
_mem0_client = Mem0Client()
|
||||
|
||||
|
||||
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
return await _mem0_client.get_memory(db, user_id)
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
|
||||
async def load_conversation_history(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
@@ -36,8 +167,7 @@ async def load_conversation_history(
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
|
||||
"""获取对话轮数(用户消息数)"""
|
||||
result = await db.execute(
|
||||
select(func.count(Message.id))
|
||||
.where(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
@@ -47,14 +177,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
SUMMARIZE_THRESHOLD = 8
|
||||
MAX_HISTORY_TURNS = 10
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
|
||||
"""判断当前对话是否需要摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
# 检查是否已有摘要覆盖到当前轮数
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -72,17 +203,21 @@ async def generate_summary(
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> str:
|
||||
"""调用 LLM 生成对话摘要"""
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages
|
||||
)
|
||||
llm = get_llm()
|
||||
"""生成对话摘要"""
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
|
||||
llm = get_llm()
|
||||
response = await llm.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"
|
||||
),
|
||||
HumanMessage(content=history_text),
|
||||
]
|
||||
)
|
||||
return response.content.strip()
|
||||
|
||||
|
||||
@@ -92,8 +227,10 @@ async def save_summary(
|
||||
conversation_id: str,
|
||||
summary_text: str,
|
||||
turn_count: int,
|
||||
) -> MemorySummary:
|
||||
"""保存对话摘要"""
|
||||
) -> Any:
|
||||
"""保存对话摘要到数据库"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
summary = MemorySummary(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -109,8 +246,10 @@ async def save_summary(
|
||||
async def get_summaries(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
) -> list[MemorySummary]:
|
||||
) -> list[Any]:
|
||||
"""获取某对话的所有历史摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -119,31 +258,7 @@ async def get_summaries(
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
# ———— 长期记忆: 基于 Mem0 ————
|
||||
|
||||
|
||||
async def extract_user_memories(
|
||||
@@ -151,55 +266,34 @@ async def extract_user_memories(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[UserMemory]:
|
||||
"""从对话中提取用户记忆并保存"""
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从对话中提取用户记忆并存储到 Mem0。
|
||||
Mem0 会自动处理:
|
||||
- 事实提取
|
||||
- 时间线追踪
|
||||
- 矛盾解决
|
||||
- 遗忘机制
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages[-10:]
|
||||
)
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
|
||||
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content=EXTRACTION_PROMPT),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
text = response.content.strip()
|
||||
if text == "无" or not text:
|
||||
return []
|
||||
|
||||
memories = []
|
||||
for line in text.split("\n"):
|
||||
parsed = _parse_fact_line(line)
|
||||
if not parsed:
|
||||
continue
|
||||
mem_type, content = parsed
|
||||
# 检查是否已有完全相同的记忆
|
||||
existing = await db.execute(
|
||||
select(UserMemory).where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.content == content,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
mem = UserMemory(
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.add(
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
|
||||
user_id=user_id,
|
||||
memory_type=mem_type,
|
||||
content=content,
|
||||
importance=5,
|
||||
source_conversation_id=conversation_id,
|
||||
metadata={
|
||||
"conversation_id": conversation_id,
|
||||
"source": "jarvis_memory",
|
||||
},
|
||||
)
|
||||
db.add(mem)
|
||||
memories.append(mem)
|
||||
|
||||
if memories:
|
||||
await db.commit()
|
||||
return memories
|
||||
return result.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 extract error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
@@ -207,41 +301,45 @@ async def recall_user_memories(
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[UserMemory]:
|
||||
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
|
||||
# 先尝试语义相似(通过 LLM 判断)
|
||||
# 降级: 直接从数据库取最近的重要记忆
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
|
||||
.limit(top_k)
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
|
||||
# 重置召回标记
|
||||
for m in memories:
|
||||
m.is_recalled = False
|
||||
await db.commit()
|
||||
|
||||
return memories
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
results = mem0.search(
|
||||
query=query,
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
return results.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
"""标记记忆已被召回使用"""
|
||||
result = await db.execute(
|
||||
select(UserMemory).where(UserMemory.id == memory_id)
|
||||
)
|
||||
mem = result.scalar_one_or_none()
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
"""
|
||||
获取用户画像。
|
||||
Mem0 的 profile API 会返回 static 和 dynamic facts。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.history(user_id=user_id)
|
||||
return {
|
||||
"memories": result.get("results", []),
|
||||
"static": [],
|
||||
"dynamic": [],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Mem0 profile error: {e}")
|
||||
return {"memories": [], "static": [], "dynamic": []}
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -254,25 +352,22 @@ async def build_memory_context(
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# 1. 用户画像(长期记忆)
|
||||
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if user_memories:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if memories:
|
||||
lines = []
|
||||
for m in user_memories:
|
||||
tag = f"[{m.memory_type}]"
|
||||
lines.append(f" {tag} {m.content}")
|
||||
await mark_memory_recalled(db, m.id)
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
for m in memories:
|
||||
memory_text = m.get("memory", m.get("text", ""))
|
||||
if memory_text:
|
||||
lines.append(f" - {memory_text}")
|
||||
if lines:
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
|
||||
# 2. 对话摘要(中期记忆)
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
# 只取最近2条
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
|
||||
# 3. 知识大脑(长期项目记忆)
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
@@ -292,7 +387,7 @@ async def try_auto_summarize(
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否需要摘要,如果需要则生成并保存。
|
||||
返回是否执行了摘要。
|
||||
同时将对话内容存入 Mem0 进行记忆提取。
|
||||
"""
|
||||
if not await should_summarize(db, conversation_id):
|
||||
return False
|
||||
@@ -306,8 +401,39 @@ async def try_auto_summarize(
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
|
||||
|
||||
# 同时提取用户记忆
|
||||
await extract_user_memories(db, user_id, conversation_id, messages)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(f"Auto summarize error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
|
||||
"""
|
||||
主动遗忘某条记忆。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.delete(memory_id, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 delete error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_memory(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
memory_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""
|
||||
更新某条记忆。Mem0 会自动处理矛盾检测。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.update(memory_id, content, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 update error: {e}")
|
||||
return False
|
||||
|
||||
@@ -99,46 +99,55 @@ async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession)
|
||||
|
||||
|
||||
async def test_llm_connection(
|
||||
provider: str,
|
||||
provider: str | None,
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str
|
||||
api_key: str,
|
||||
) -> dict:
|
||||
"""测试 LLM 连接"""
|
||||
try:
|
||||
# base_url-first: provider 可省略
|
||||
from app.services.llm_service import normalize_provider_name
|
||||
|
||||
effective_provider = normalize_provider_name({
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"base_url": base_url,
|
||||
})
|
||||
|
||||
# 根据不同 provider 创建临时 LLM 实例并测试
|
||||
if provider == "openai":
|
||||
if effective_provider in {"openai", "custom", "minimax", "kimi", "qwen"}:
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "claude":
|
||||
elif effective_provider == "claude":
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
elif effective_provider == "ollama":
|
||||
from langchain_ollama import ChatOllama
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
elif effective_provider == "deepseek":
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or "https://api.deepseek.com/v1",
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
return {"success": False, "error": f"不支持的 provider: {provider}"}
|
||||
return {"success": False, "error": f"不支持的 endpoint/provider: {effective_provider}"}
|
||||
|
||||
# 简单测试调用
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -50,28 +50,22 @@ class SkillService:
|
||||
"""
|
||||
列出用户可访问的技能:自己的 + 市场的 + 团队的
|
||||
"""
|
||||
# 查询条件:自己的 或者 市场公开的 或者 团队的
|
||||
conditions = [
|
||||
access_scope = or_(
|
||||
Skill.owner_id == user_id,
|
||||
Skill.visibility == "market",
|
||||
Skill.team_id == user_id,
|
||||
]
|
||||
|
||||
# 如果提供了 agent_type 过滤
|
||||
if agent_type:
|
||||
conditions.append(Skill.agent_type == agent_type)
|
||||
|
||||
# 如果提供了 visibility 过滤
|
||||
if visibility:
|
||||
conditions.append(Skill.visibility == visibility)
|
||||
|
||||
query = select(Skill).where(
|
||||
and_(
|
||||
or_(*conditions),
|
||||
Skill.is_active == True
|
||||
)
|
||||
)
|
||||
|
||||
filters = [access_scope, Skill.is_active == True]
|
||||
|
||||
if agent_type:
|
||||
filters.append(Skill.agent_type == agent_type)
|
||||
|
||||
if visibility:
|
||||
filters.append(Skill.visibility == visibility)
|
||||
|
||||
query = select(Skill).where(and_(*filters))
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from datetime import datetime, UTC
|
||||
from time import monotonic
|
||||
import platform
|
||||
import socket
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import psutil
|
||||
@@ -7,21 +11,119 @@ except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fa
|
||||
|
||||
|
||||
class SystemService:
|
||||
_last_net_bytes_sent: int | None = None
|
||||
_last_net_bytes_recv: int | None = None
|
||||
_last_net_sample_at: float | None = None
|
||||
|
||||
def _get_network_rates(self) -> tuple[float, float]:
|
||||
counters = psutil.net_io_counters()
|
||||
now = monotonic()
|
||||
|
||||
if (
|
||||
self.__class__._last_net_sample_at is None
|
||||
or self.__class__._last_net_bytes_sent is None
|
||||
or self.__class__._last_net_bytes_recv is None
|
||||
):
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
return 0.0, 0.0
|
||||
|
||||
elapsed = max(now - self.__class__._last_net_sample_at, 1e-6)
|
||||
upload_bps = max(counters.bytes_sent - self.__class__._last_net_bytes_sent, 0) / elapsed
|
||||
download_bps = max(counters.bytes_recv - self.__class__._last_net_bytes_recv, 0) / elapsed
|
||||
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
|
||||
return round(upload_bps, 1), round(download_bps, 1)
|
||||
|
||||
def _get_gpu_status(self) -> dict:
|
||||
empty = {
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
}
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
'nvidia-smi',
|
||||
'--query-gpu=name,memory.total,memory.used,utilization.gpu',
|
||||
'--format=csv,noheader,nounits',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
timeout=2,
|
||||
check=False,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.SubprocessError, OSError):
|
||||
return empty
|
||||
|
||||
if result.returncode != 0 or not result.stdout.strip():
|
||||
return empty
|
||||
|
||||
first_line = result.stdout.strip().splitlines()[0]
|
||||
parts = [part.strip() for part in first_line.split(',')]
|
||||
if len(parts) < 4:
|
||||
return empty
|
||||
|
||||
def parse_number(value: str) -> float | None:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return {
|
||||
'gpu_name': parts[0] or None,
|
||||
'gpu_memory_total_mb': parse_number(parts[1]),
|
||||
'gpu_memory_used_mb': parse_number(parts[2]),
|
||||
'gpu_util_percent': parse_number(parts[3]),
|
||||
}
|
||||
|
||||
def get_status(self) -> dict:
|
||||
if psutil is None:
|
||||
return {
|
||||
'cpu_percent': 0.0,
|
||||
'memory_percent': 0.0,
|
||||
'disk_percent': 0.0,
|
||||
'disk_used_gb': 0.0,
|
||||
'disk_total_gb': 0.0,
|
||||
'network_upload_bps': 0.0,
|
||||
'network_download_bps': 0.0,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': 0.0,
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
upload_bps, download_bps = self._get_network_rates()
|
||||
gpu_status = self._get_gpu_status()
|
||||
boot_time = psutil.boot_time()
|
||||
now_ts = datetime.now(UTC).timestamp()
|
||||
return {
|
||||
'cpu_percent': round(cpu_percent, 1),
|
||||
'memory_percent': round(memory.percent, 1),
|
||||
'disk_percent': round(disk.percent, 1),
|
||||
'disk_used_gb': round(disk.used / (1024 ** 3), 1),
|
||||
'disk_total_gb': round(disk.total / (1024 ** 3), 1),
|
||||
'network_upload_bps': upload_bps,
|
||||
'network_download_bps': download_bps,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': round(max(now_ts - boot_time, 0.0), 1),
|
||||
**gpu_status,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
124
backend/app/services/web_search_service.py
Normal file
124
backend/app/services/web_search_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WebSearchResult:
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
source: str | None = None
|
||||
published_at: str | None = None
|
||||
|
||||
|
||||
class WebSearchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchConfigurationError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchRequestError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enabled: bool | None = None,
|
||||
provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_limit: int | None = None,
|
||||
timeout_seconds: int | None = None,
|
||||
auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
|
||||
auth_token: str | None = None,
|
||||
basic_user: str | None = None,
|
||||
basic_password: str | None = None,
|
||||
):
|
||||
self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
|
||||
self.provider = (provider or settings.WEB_SEARCH_PROVIDER).strip().lower()
|
||||
self.base_url = (base_url or settings.SEARXNG_BASE_URL).strip().rstrip('/')
|
||||
self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
|
||||
self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
|
||||
self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
|
||||
self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
|
||||
self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
|
||||
self.basic_password = basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
|
||||
|
||||
async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
|
||||
normalized_query = (query or '').strip()
|
||||
if not self.enabled or not self.base_url:
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
if self.provider != 'searxng':
|
||||
raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
|
||||
if not normalized_query:
|
||||
raise WebSearchRequestError('搜索关键词不能为空')
|
||||
|
||||
parsed = urlparse(self.base_url)
|
||||
if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
|
||||
raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')
|
||||
|
||||
params = {
|
||||
'q': normalized_query,
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
}
|
||||
headers = self._build_headers()
|
||||
timeout = httpx.Timeout(float(self.timeout_seconds), connect=min(float(self.timeout_seconds), 5.0))
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(f'{self.base_url}/search', params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except httpx.HTTPError as exc:
|
||||
raise WebSearchRequestError('SearxNG 请求失败') from exc
|
||||
except ValueError as exc:
|
||||
raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc
|
||||
|
||||
raw_results = payload.get('results') if isinstance(payload, dict) else None
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
target_limit = max(1, min(limit or self.default_limit, 10))
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
title = str(item.get('title') or '').strip()
|
||||
url = str(item.get('url') or '').strip()
|
||||
snippet = str(item.get('content') or item.get('snippet') or '').strip()
|
||||
if not title or not url:
|
||||
continue
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet,
|
||||
source=str(item.get('engine') or item.get('source') or '').strip() or None,
|
||||
published_at=str(item.get('publishedDate') or item.get('published_at') or '').strip() or None,
|
||||
)
|
||||
)
|
||||
if len(results) >= target_limit:
|
||||
break
|
||||
return results
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
if self.auth_type == 'bearer' and self.auth_token:
|
||||
return {'Authorization': f'Bearer {self.auth_token}'}
|
||||
if self.auth_type == 'basic' and self.basic_user and self.basic_password:
|
||||
credentials = httpx.BasicAuth(self.basic_user, self.basic_password)
|
||||
request = httpx.Request('GET', self.base_url)
|
||||
credentials.auth_flow(request)
|
||||
return dict(request.headers)
|
||||
return {}
|
||||
Reference in New Issue
Block a user