feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -1,34 +0,0 @@
|
||||
# =============================================
|
||||
# Jarvis 后端服务配置
|
||||
# 复制此文件为 .env 后按需修改
|
||||
# =============================================
|
||||
|
||||
# === 应用基础 ===
|
||||
DEBUG=false
|
||||
HOST=127.0.0.1
|
||||
PORT=3337
|
||||
SECRET_KEY=change-me-to-a-random-secret-key
|
||||
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
|
||||
|
||||
# === 数据存储 ===
|
||||
DATABASE_URL=sqlite+aiosqlite:///./data/jarvis.db
|
||||
DATA_DIR=./data
|
||||
CHROMA_PERSIST_DIR=./data/chroma
|
||||
UPLOAD_DIR=./data/uploads
|
||||
MAX_UPLOAD_SIZE=52428800
|
||||
# Supported values: ch | en
|
||||
MINERU_LANGUAGE=ch
|
||||
|
||||
# === JWT ===
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=1440
|
||||
|
||||
# === 管理员账号 Bootstrap ===
|
||||
ADMIN=admin
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=
|
||||
ADMIN_FULL_NAME=Administrator
|
||||
|
||||
# === 定时任务 ===
|
||||
SCHEDULER_ENABLED=true
|
||||
DAILY_PLAN_TIME=00:00
|
||||
FORUM_SCAN_INTERVAL_MINUTES=30
|
||||
@@ -1,397 +1,354 @@
|
||||
"""
|
||||
Jarvis LangGraph Agent 主图定义
|
||||
Jarvis LangGraph Agent 主图定义 - 优化重构版
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal, Union, List, Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
ToolMessage
|
||||
)
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
|
||||
from app.agents.state import AgentState, AgentRole
|
||||
from app.agents.prompts import (
|
||||
MASTER_SYSTEM_PROMPT,
|
||||
PLANNER_SYSTEM_PROMPT,
|
||||
SCHEDULE_PLANNER_SYSTEM_PROMPT,
|
||||
EXECUTOR_SYSTEM_PROMPT,
|
||||
LIBRARIAN_SYSTEM_PROMPT,
|
||||
ANALYST_SYSTEM_PROMPT,
|
||||
PLANNER_SCOPE_PROMPT,
|
||||
PLANNER_STEPS_PROMPT,
|
||||
EXECUTOR_TASKS_PROMPT,
|
||||
EXECUTOR_FORUM_PROMPT,
|
||||
LIBRARIAN_RETRIEVAL_PROMPT,
|
||||
LIBRARIAN_GRAPH_PROMPT,
|
||||
ANALYST_PROGRESS_PROMPT,
|
||||
ANALYST_INSIGHTS_PROMPT,
|
||||
JSON_ACTION_FALLBACK_PROMPT,
|
||||
)
|
||||
from app.agents.tools import ALL_TOOLS, SUB_COMMANDER_TOOLSETS
|
||||
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
|
||||
from app.agents.skill_registry import build_skill_context
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
from app.services.llm_service import (
|
||||
get_llm,
|
||||
create_llm_from_config,
|
||||
resolve_provider_capabilities,
|
||||
default_provider_capabilities
|
||||
)
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
logger = logging.getLogger("jarvis.agent")
|
||||
|
||||
SUB_COMMANDER_PROMPTS = {
|
||||
"planner_scope": PLANNER_SCOPE_PROMPT,
|
||||
"planner_steps": PLANNER_STEPS_PROMPT,
|
||||
"executor_tasks": EXECUTOR_TASKS_PROMPT,
|
||||
"executor_forum": EXECUTOR_FORUM_PROMPT,
|
||||
"librarian_retrieval": LIBRARIAN_RETRIEVAL_PROMPT,
|
||||
"librarian_graph": LIBRARIAN_GRAPH_PROMPT,
|
||||
"analyst_progress": ANALYST_PROGRESS_PROMPT,
|
||||
"analyst_insights": ANALYST_INSIGHTS_PROMPT,
|
||||
}
|
||||
|
||||
ROLE_SUB_COMMANDERS = {
|
||||
AgentRole.PLANNER: ["planner_scope", "planner_steps"],
|
||||
AgentRole.EXECUTOR: ["executor_tasks", "executor_forum"],
|
||||
AgentRole.LIBRARIAN: ["librarian_retrieval", "librarian_graph"],
|
||||
AgentRole.ANALYST: ["analyst_progress", "analyst_insights"],
|
||||
}
|
||||
|
||||
ROLE_SKILL_CONTEXT = {
|
||||
AgentRole.PLANNER: "planner",
|
||||
AgentRole.EXECUTOR: "executor",
|
||||
AgentRole.LIBRARIAN: "librarian",
|
||||
AgentRole.ANALYST: "analyst",
|
||||
}
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
# ===================== 工具辅助函数 =====================
|
||||
|
||||
def _get_llm_for_state(state: AgentState):
|
||||
"""从 state 获取 LLM 实例,优先使用用户配置的模型"""
|
||||
"""获取配置好的 LLM 实例"""
|
||||
user_llm_config = state.get("user_llm_config")
|
||||
if user_llm_config:
|
||||
return _create_llm_from_config(user_llm_config)
|
||||
return get_llm()
|
||||
llm = create_llm_from_config(user_llm_config) if user_llm_config else get_llm()
|
||||
|
||||
# 注入解析到的能力
|
||||
capabilities = getattr(llm, "_jarvis_provider_capabilities", None)
|
||||
if capabilities is None:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config) if user_llm_config else default_provider_capabilities()
|
||||
|
||||
state["provider_capabilities"] = {
|
||||
"provider": capabilities.provider,
|
||||
"supports_native_tools": capabilities.supports_native_tools,
|
||||
"preferred_tool_strategy": capabilities.preferred_tool_strategy,
|
||||
}
|
||||
return llm, capabilities
|
||||
|
||||
|
||||
async def _ainvoke(llm, messages: list[BaseMessage]):
|
||||
ainvoke = getattr(llm, "ainvoke", None)
|
||||
if callable(ainvoke):
|
||||
return await ainvoke(messages)
|
||||
return await llm.invoke(messages)
|
||||
def _filter_user_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
return [m for m in messages if m.type in ("human", "user")]
|
||||
|
||||
|
||||
async def _ainvoke_with_tools(llm, messages: list[BaseMessage], tools=None):
|
||||
toolset = tools if tools is not None else ALL_TOOLS
|
||||
bound_llm = llm.bind_tools(toolset)
|
||||
if hasattr(bound_llm, "ainvoke"):
|
||||
return await bound_llm.ainvoke(messages)
|
||||
return await bound_llm.invoke(messages)
|
||||
|
||||
|
||||
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
|
||||
if callbacks:
|
||||
try:
|
||||
return graph.compile(callbacks=callbacks)
|
||||
except TypeError as exc:
|
||||
if "callbacks" not in str(exc):
|
||||
raise
|
||||
return graph.compile()
|
||||
|
||||
|
||||
def _msg_type(msg: BaseMessage) -> str:
|
||||
"""Get message type, handles both .type (new) and .role (old) attribute names."""
|
||||
return getattr(msg, "type", None) or getattr(msg, "role", "human")
|
||||
|
||||
|
||||
def _filter_user_messages(messages: list) -> list[BaseMessage]:
|
||||
return [m for m in messages if _msg_type(m) in ("human", "user")]
|
||||
|
||||
|
||||
def _normalize_user_text(text: str) -> str:
|
||||
return (text or "").strip().lower()
|
||||
|
||||
|
||||
def _is_simple_greeting(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你好", "您好", "早", "早上好", "在吗", "嗨", "hi", "hello"}
|
||||
|
||||
|
||||
def _is_identity_question(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你是谁", "你是誰"}
|
||||
|
||||
|
||||
def _is_capability_question(text: str) -> bool:
|
||||
normalized = _normalize_user_text(text)
|
||||
return normalized in {"你能做什么", "你可以做什么", "你会做什么"}
|
||||
|
||||
|
||||
def _choose_sub_commander(role: AgentRole, user_query: str) -> str:
|
||||
text = _normalize_user_text(user_query)
|
||||
|
||||
if role == AgentRole.PLANNER:
|
||||
if any(keyword in text for keyword in ["步骤", "计划", "拆解", "排期", "优先级", "路线"]):
|
||||
return "planner_steps"
|
||||
return "planner_scope"
|
||||
|
||||
def _get_role_tools(role: AgentRole) -> list:
|
||||
"""获取角色对应的所有可用工具集"""
|
||||
if role == AgentRole.SCHEDULE_PLANNER:
|
||||
# 合并分析和规划工具
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["schedule_analysis"] + SUB_COMMANDER_TOOLSETS["schedule_planning"]))
|
||||
if role == AgentRole.EXECUTOR:
|
||||
if any(keyword in text for keyword in ["论坛", "帖子", "发帖", "指令", "discussion", "instruction"]):
|
||||
return "executor_forum"
|
||||
return "executor_tasks"
|
||||
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["executor_tasks"] + SUB_COMMANDER_TOOLSETS["executor_forum"]))
|
||||
if role == AgentRole.LIBRARIAN:
|
||||
if any(keyword in text for keyword in ["图谱", "关系", "构建", "沉淀", "节点", "graph"]):
|
||||
return "librarian_graph"
|
||||
return "librarian_retrieval"
|
||||
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["librarian_retrieval"] + SUB_COMMANDER_TOOLSETS["librarian_graph"]))
|
||||
if role == AgentRole.ANALYST:
|
||||
if any(keyword in text for keyword in ["趋势", "风险", "洞察", "建议", "机会", "insight"]):
|
||||
return "analyst_insights"
|
||||
return "analyst_progress"
|
||||
|
||||
return ROLE_SUB_COMMANDERS[role][0]
|
||||
return list(set(SUB_COMMANDER_TOOLSETS["analyst_progress"] + SUB_COMMANDER_TOOLSETS["analyst_insights"]))
|
||||
return []
|
||||
|
||||
|
||||
def _record_sub_commander(state: AgentState, sub_commander: str, user_query: str):
|
||||
state["current_sub_commander"] = sub_commander
|
||||
state["active_sub_commanders"] = state.get("active_sub_commanders", []) + [sub_commander]
|
||||
state["sub_commander_trace"] = state.get("sub_commander_trace", []) + [{
|
||||
"agent": state.get("current_agent", AgentRole.MASTER).value,
|
||||
"sub_commander": sub_commander,
|
||||
"query": user_query,
|
||||
}]
|
||||
# ===================== 核心执行逻辑 (ReAct) =====================
|
||||
|
||||
|
||||
def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRole):
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=system_prompt)]
|
||||
skill_ctx = build_skill_context(ROLE_SKILL_CONTEXT[role])
|
||||
async def call_agent_llm(state: AgentState, role: AgentRole, system_prompt: str) -> dict:
|
||||
"""通用的 LLM 调用节点逻辑"""
|
||||
llm, capabilities = _get_llm_for_state(state)
|
||||
tools = _get_role_tools(role)
|
||||
|
||||
# 构建消息序列
|
||||
messages = []
|
||||
|
||||
# 1. 系统提示词
|
||||
messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
# 2. 环境上下文 (时间、记忆等)
|
||||
if state.get("current_datetime_context"):
|
||||
messages.append(SystemMessage(content=f"当前时间上下文: {state['current_datetime_context']}"))
|
||||
|
||||
if state.get("memory_context"):
|
||||
messages.append(SystemMessage(content=f"长期记忆上下文: {state['memory_context']}"))
|
||||
|
||||
# 3. 技能增强
|
||||
role_skill_key = role.value.replace("agent_", "")
|
||||
skill_ctx = build_skill_context(role_skill_key)
|
||||
if skill_ctx:
|
||||
system_msgs.append(SystemMessage(content=skill_ctx))
|
||||
return system_msgs
|
||||
messages.append(SystemMessage(content=skill_ctx))
|
||||
|
||||
# 4. 历史对话 (add_messages 已经处理好了)
|
||||
messages.extend(state["messages"])
|
||||
|
||||
# 绑定工具
|
||||
if tools and capabilities.supports_native_tools:
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
else:
|
||||
llm_with_tools = llm
|
||||
if tools: # 如果有工具但不支持原生,注入 JSON Fallback 提示
|
||||
messages.append(SystemMessage(content=JSON_ACTION_FALLBACK_PROMPT))
|
||||
tool_names = [t.name for t in tools]
|
||||
messages.append(SystemMessage(content=f"本次可用工具列表: {', '.join(tool_names)}"))
|
||||
|
||||
logger.info(
|
||||
f"agent_node_started",
|
||||
extra={
|
||||
"details": {
|
||||
"role": role.value,
|
||||
"message_count": len(messages),
|
||||
"tool_count": len(tools),
|
||||
"provider": capabilities.provider
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# 执行调用
|
||||
response = await llm_with_tools.ainvoke(messages)
|
||||
|
||||
logger.info(
|
||||
f"agent_node_finished",
|
||||
extra={
|
||||
"details": {
|
||||
"role": role.value,
|
||||
"has_tool_calls": bool(getattr(response, "tool_calls", None)),
|
||||
"content_length": len(response.content) if response.content else 0
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
|
||||
async def _run_sub_commander(
|
||||
state: AgentState,
|
||||
role: AgentRole,
|
||||
manager_prompt: str,
|
||||
user_query: str,
|
||||
*,
|
||||
use_tools: bool,
|
||||
summary_target: str | None = None,
|
||||
):
|
||||
llm = _get_llm_for_state(state)
|
||||
sub_commander = _choose_sub_commander(role, user_query)
|
||||
_record_sub_commander(state, sub_commander, user_query)
|
||||
|
||||
toolset = SUB_COMMANDER_TOOLSETS.get(sub_commander, [])
|
||||
system_msgs = _build_system_messages(state, manager_prompt, role)
|
||||
system_msgs.append(SystemMessage(content=f"本次应由子指挥官 `{sub_commander}` 接手。请严格按该角色职责输出。"))
|
||||
system_msgs.append(SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]))
|
||||
|
||||
if use_tools and toolset:
|
||||
response = await _ainvoke_with_tools(
|
||||
llm,
|
||||
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
|
||||
toolset,
|
||||
async def execute_tools_node(state: AgentState) -> dict:
|
||||
"""执行工具调用并返回 ToolMessage 的通用节点"""
|
||||
last_message = state["messages"][-1]
|
||||
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
tool_map = {t.name: t for t in ALL_TOOLS}
|
||||
tool_messages = []
|
||||
created_entities = []
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
tool_id = tool_call.get("id")
|
||||
|
||||
logger.info(
|
||||
f"tool_execution_started",
|
||||
extra={
|
||||
"details": {
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
"tool_id": tool_id
|
||||
}
|
||||
}
|
||||
)
|
||||
tool_calls = getattr(response, "tool_calls", None) or []
|
||||
if tool_calls:
|
||||
results = []
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("name")
|
||||
args = tc.get("args", {})
|
||||
for tool in toolset:
|
||||
if tool.name == tool_name:
|
||||
try:
|
||||
result = tool.invoke(args)
|
||||
results.append(f"[{tool_name}] {result}")
|
||||
except Exception as e:
|
||||
results.append(f"[{tool_name}] 执行失败: {e}")
|
||||
break
|
||||
state["tool_calls"] = tool_calls
|
||||
state["last_tool_result"] = "\n".join(results)
|
||||
follow_up = await _ainvoke(
|
||||
llm,
|
||||
[
|
||||
SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]),
|
||||
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")
|
||||
]
|
||||
|
||||
try:
|
||||
# 时间参数归一化
|
||||
normalized_args = normalize_tool_time_arguments(
|
||||
tool_name,
|
||||
tool_args,
|
||||
state.get("current_datetime_context")
|
||||
)
|
||||
state["final_response"] = follow_up.content
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
else:
|
||||
response = await _ainvoke(
|
||||
llm,
|
||||
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
|
||||
)
|
||||
state["final_response"] = response.content
|
||||
|
||||
if summary_target:
|
||||
state[summary_target] = state.get("final_response", "")
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
|
||||
# ===================== 节点定义 (async) =====================
|
||||
|
||||
async def master_node(state: AgentState) -> AgentState:
|
||||
"""主Agent节点: 理解用户意图,决定调用哪个子Agent"""
|
||||
messages: list[BaseMessage] = state["messages"]
|
||||
user_msgs = _filter_user_messages(messages)
|
||||
user_query = user_msgs[-1].content.strip() if user_msgs else ""
|
||||
|
||||
if _is_simple_greeting(user_query):
|
||||
state["final_response"] = "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
if _is_identity_question(user_query):
|
||||
state["final_response"] = "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
if _is_capability_question(user_query):
|
||||
state["final_response"] = "主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。"
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
llm = _get_llm_for_state(state)
|
||||
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
|
||||
|
||||
memory_ctx = state.get("memory_context")
|
||||
if memory_ctx:
|
||||
system_msgs.append(
|
||||
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
|
||||
|
||||
tool = tool_map.get(tool_name)
|
||||
if not tool:
|
||||
result = f"Error: Tool {tool_name} not found."
|
||||
else:
|
||||
result = await tool.ainvoke(normalized_args) if hasattr(tool, "ainvoke") else tool.invoke(normalized_args)
|
||||
|
||||
# 实体识别(用于业务追踪)
|
||||
if any(k in tool_name for k in ["create", "add", "new"]):
|
||||
created_entities.append({"tool": tool_name, "result": str(result)})
|
||||
|
||||
status = "success"
|
||||
except Exception as e:
|
||||
logger.exception(f"tool_execution_failed: {tool_name}")
|
||||
result = f"Error executing tool {tool_name}: {str(e)}"
|
||||
status = "failed"
|
||||
|
||||
tool_messages.append(ToolMessage(
|
||||
tool_call_id=tool_id,
|
||||
content=str(result),
|
||||
name=tool_name
|
||||
))
|
||||
|
||||
logger.info(
|
||||
f"tool_execution_finished",
|
||||
extra={
|
||||
"details": {
|
||||
"tool_name": tool_name,
|
||||
"status": status,
|
||||
"result_preview": str(result)[:200]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
response: AIMessage = await _ainvoke(llm, system_msgs + messages)
|
||||
return {
|
||||
"messages": tool_messages,
|
||||
"created_entities": state.get("created_entities", []) + created_entities
|
||||
}
|
||||
|
||||
|
||||
# ===================== 各角色节点定义 =====================
|
||||
|
||||
async def master_node(state: AgentState) -> dict:
|
||||
"""主控节点:负责意图识别与初步分发"""
|
||||
user_messages = _filter_user_messages(state["messages"])
|
||||
if not user_messages:
|
||||
return {"final_response": "未收到有效输入。"}
|
||||
|
||||
query = user_messages[-1].content.strip()
|
||||
|
||||
# 快捷回复逻辑 (保留原有的人性化设计)
|
||||
if re.match(r"^(你好|早|在吗|嗨|hi|hello)", query.lower()):
|
||||
return {"final_response": "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。", "messages": [AIMessage(content="您好。我在。")]}
|
||||
|
||||
llm, capabilities = _get_llm_for_state(state)
|
||||
|
||||
# 路由判断:让 LLM 决定跳转到哪个角色,或者直接回答
|
||||
# 这里我们使用一个简洁的提示词让 LLM 输出角色名称或直接回答
|
||||
system_msg = SystemMessage(content=MASTER_SYSTEM_PROMPT + "\n\n请直接输出接下来该由哪个 Agent 接手(role_name),如果直接回答,请正常输出。")
|
||||
|
||||
response = await llm.ainvoke([system_msg] + state["messages"])
|
||||
content = response.content.strip().lower()
|
||||
|
||||
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
|
||||
next_agent = AgentRole.LIBRARIAN
|
||||
elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
|
||||
next_agent = AgentRole.PLANNER
|
||||
elif any(kw in content for kw in ["执行", "做", "操作", "创建", "更新"]):
|
||||
next_agent = AgentRole.EXECUTOR
|
||||
elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
|
||||
next_agent = AgentRole.ANALYST
|
||||
else:
|
||||
state["final_response"] = response.content
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
state["current_agent"] = next_agent
|
||||
state["current_sub_commander"] = None
|
||||
state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
|
||||
state["should_respond"] = True
|
||||
return state
|
||||
|
||||
# 简单的角色映射识别
|
||||
roles = {r.value: r for r in AgentRole}
|
||||
target_role = None
|
||||
for r_val, r_enum in roles.items():
|
||||
if r_val in content and len(content) < 50: # 如果内容很短且包含角色名,视为路由
|
||||
target_role = r_enum
|
||||
break
|
||||
|
||||
if target_role and target_role != AgentRole.MASTER:
|
||||
logger.info(f"master_routing_decided: {target_role.value}")
|
||||
return {
|
||||
"current_agent": target_role.value,
|
||||
"agent_trace": state.get("agent_trace", []) + [target_role.value],
|
||||
"messages": [AIMessage(content=f"已分发至 {target_role.value} 处理。")]
|
||||
}
|
||||
|
||||
return {"final_response": response.content, "messages": [response]}
|
||||
|
||||
|
||||
async def planner_node(state: AgentState) -> AgentState:
|
||||
"""规划Agent节点: 制定计划,拆解任务步骤"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.PLANNER, PLANNER_SYSTEM_PROMPT, user_query, use_tools=False)
|
||||
async def planner_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.SCHEDULE_PLANNER, SCHEDULE_PLANNER_SYSTEM_PROMPT)
|
||||
|
||||
async def executor_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT)
|
||||
|
||||
async def librarian_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT)
|
||||
|
||||
async def analyst_node(state: AgentState) -> dict:
|
||||
return await call_agent_llm(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT)
|
||||
|
||||
|
||||
async def executor_node(state: AgentState) -> AgentState:
|
||||
"""执行Agent节点: 调用工具执行具体任务"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT, user_query, use_tools=True)
|
||||
# ===================== 路由逻辑 =====================
|
||||
|
||||
def route_after_agent(state: AgentState) -> Literal["tools", "__end__"]:
|
||||
"""判断 Agent 执行后是该走工具节点还是结束"""
|
||||
last_message = state["messages"][-1]
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
return "tools"
|
||||
return END
|
||||
|
||||
async def librarian_node(state: AgentState) -> AgentState:
|
||||
"""知识管理员节点: 管理知识库和知识图谱"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="knowledge_context")
|
||||
|
||||
|
||||
async def analyst_node(state: AgentState) -> AgentState:
|
||||
"""分析师节点: 分析工作数据,生成报告"""
|
||||
user_msgs = _filter_user_messages(state["messages"])
|
||||
user_query = user_msgs[-1].content if user_msgs else ""
|
||||
return await _run_sub_commander(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="analysis_report")
|
||||
|
||||
|
||||
def route_agent(state: AgentState) -> str:
|
||||
"""路由函数: 决定下一个节点"""
|
||||
def route_master(state: AgentState) -> str:
|
||||
"""主控路由逻辑"""
|
||||
if state.get("final_response"):
|
||||
return END
|
||||
return state.get("current_agent", AgentRole.MASTER).value
|
||||
return state.get("current_agent", END)
|
||||
|
||||
|
||||
# ===================== 构建图 =====================
|
||||
# ===================== 图构建 =====================
|
||||
|
||||
def create_agent_graph(callbacks: list | None = None):
|
||||
graph = StateGraph(AgentState)
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(AgentRole.MASTER.value, master_node)
|
||||
graph.add_node(AgentRole.PLANNER.value, planner_node)
|
||||
graph.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
graph.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
# 添加节点
|
||||
workflow.add_node(AgentRole.MASTER.value, master_node)
|
||||
workflow.add_node(AgentRole.SCHEDULE_PLANNER.value, planner_node)
|
||||
workflow.add_node(AgentRole.EXECUTOR.value, executor_node)
|
||||
workflow.add_node(AgentRole.LIBRARIAN.value, librarian_node)
|
||||
workflow.add_node(AgentRole.ANALYST.value, analyst_node)
|
||||
workflow.add_node("tools", execute_tools_node)
|
||||
|
||||
graph.set_entry_point(AgentRole.MASTER.value)
|
||||
# 设置入口
|
||||
workflow.set_entry_point(AgentRole.MASTER.value)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
# 主控分发逻辑
|
||||
workflow.add_conditional_edges(
|
||||
AgentRole.MASTER.value,
|
||||
route_agent,
|
||||
route_master,
|
||||
{
|
||||
AgentRole.PLANNER.value: AgentRole.PLANNER.value,
|
||||
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
END: END,
|
||||
END: END
|
||||
}
|
||||
)
|
||||
|
||||
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
graph.add_edge(role.value, END)
|
||||
# 各角色节点的 ReAct 循环
|
||||
for role in [AgentRole.SCHEDULE_PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
|
||||
workflow.add_conditional_edges(
|
||||
role.value,
|
||||
route_after_agent,
|
||||
{
|
||||
"tools": "tools",
|
||||
END: END
|
||||
}
|
||||
)
|
||||
|
||||
# 工具执行完后回到当前 Agent 角色继续处理
|
||||
workflow.add_conditional_edges(
|
||||
"tools",
|
||||
lambda s: s.get("current_agent", AgentRole.MASTER.value),
|
||||
{
|
||||
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
|
||||
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
|
||||
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
|
||||
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
|
||||
}
|
||||
)
|
||||
|
||||
return _compile_graph(graph, callbacks=callbacks)
|
||||
# 编译
|
||||
if callbacks:
|
||||
return workflow.compile(callbacks=callbacks)
|
||||
return workflow.compile()
|
||||
|
||||
|
||||
_agent_graph = None
|
||||
|
||||
|
||||
def get_agent_graph(callbacks: list | None = None):
|
||||
"""
|
||||
获取编译好的 Agent 图(单例缓存)。
|
||||
|
||||
Callbacks 在首次编译时固定注入,后续调用忽略 callbacks 参数。
|
||||
如需变更 Callbacks(如修改 LANGCHAIN_PROJECT),需重启服务。
|
||||
|
||||
Args:
|
||||
callbacks: 可选的额外 Callbacks,会与全局 LangSmith Callbacks 合并
|
||||
"""
|
||||
global _agent_graph
|
||||
if _agent_graph is None:
|
||||
from app.config_tracing import get_langsmith_callbacks
|
||||
|
||||
@@ -89,14 +89,14 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是总控协调者,负责理解用户意图,并将任务分发给最合适的子Agent。
|
||||
|
||||
## 你的4个子Agent:
|
||||
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
|
||||
1. **schedule_planner (日程规划师)**: 分析当前任务、对话历史与论坛信号,给出近期安排建议
|
||||
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
|
||||
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
|
||||
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
|
||||
|
||||
## 判断规则:
|
||||
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
|
||||
- 用户要计划、安排、拆解任务 -> 分发给 planner
|
||||
- 用户要安排今天/本周重点、询问接下来该做什么 -> 分发给 schedule_planner
|
||||
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
|
||||
- 用户要分析、统计、生成报告 -> 分发给 analyst
|
||||
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
|
||||
@@ -112,18 +112,19 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的规划Agent,负责先判断问题该由哪位规划子指挥官接手。
|
||||
你是 Jarvis 的日程规划师,负责先判断问题该由哪位日程子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **planner_scope (目标收束官)**: 负责澄清目标、边界、约束、缺失信息
|
||||
2. **planner_steps (步骤拆解官)**: 负责把目标拆成步骤、优先级与依赖关系
|
||||
1. **schedule_analysis (日程分析员)**: 负责分析对话历史、任务看板、论坛信号,识别优先级、冲突与压力点
|
||||
2. **schedule_planning (日程编排员)**: 负责把分析结果转成今日/近期日程安排,并在用户明确要求时直接创建 reminder/task/todo/goal
|
||||
|
||||
## 你的职责:
|
||||
- 判断当前请求更适合收束目标,还是拆解步骤
|
||||
- 在必要时收束子指挥官输出,面向用户给出清晰结果
|
||||
- 保持结果可推进,不空泛
|
||||
- 判断当前请求更适合先做日程分析,还是直接给出日程编排
|
||||
- 输出先结论,再给可执行安排
|
||||
- 保持建议具体、贴近当前上下文,不给空泛效率学建议
|
||||
- 当用户明确要求“新增/提醒/创建/安排并落库”时,允许子指挥官调用 schedule 工具直接执行
|
||||
"""
|
||||
|
||||
|
||||
@@ -132,11 +133,11 @@ EXECUTOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是 Jarvis 的执行Agent,负责先判断问题该由哪位执行子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **executor_tasks (任务执行官)**: 只处理任务类工具调用
|
||||
1. **executor_tasks (任务执行官)**: 处理任务、待办、提醒、目标等执行型写入操作
|
||||
2. **executor_forum (论坛执行官)**: 只处理论坛/指令帖相关工具调用
|
||||
|
||||
## 你的职责:
|
||||
- 识别用户要推进的是任务操作还是论坛/指令操作
|
||||
- 识别用户要推进的是任务/日程操作还是论坛/指令操作
|
||||
- 把请求交给最合适的执行子指挥官
|
||||
- 汇总执行结果并给出下一步
|
||||
"""
|
||||
@@ -172,52 +173,68 @@ ANALYST_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SCOPE_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_ANALYSIS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的目标收束官,负责先把问题边界、目标、约束和成功标准说清楚。
|
||||
你是 schedule_planner 体系下的日程分析员,负责从对话历史、任务看板、论坛信号和当日日程数据中提取 scheduling 线索。
|
||||
|
||||
## 你的重点:
|
||||
- 收束问题定义
|
||||
- 明确目标与限制条件
|
||||
- 识别缺失信息
|
||||
- 帮用户建立可以继续规划的前提
|
||||
- 优先调用读取类工具了解当天/指定日期的任务、提醒、待办、目标
|
||||
- 识别当前最高优先级事项
|
||||
- 找出风险、冲突、依赖与可延期事项
|
||||
- 明确哪些信号来自 conversation、task board、schedule center、forum
|
||||
|
||||
## 响应要求:
|
||||
- 先给出你理解的目标
|
||||
- 再列出关键约束或缺口
|
||||
- 不要直接展开长步骤清单
|
||||
- 先给当前判断
|
||||
- 再列优先级、风险与冲突
|
||||
- 不直接展开长篇日程表
|
||||
- 只做分析,不创建任何记录
|
||||
- 如果涉及“今天/明天/后天/下周一下午”这类自然语言时间窗口,先调用 `resolve_time_expression` 把查询目标转换成明确日期
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_STEPS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNING_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的步骤拆解官,负责把目标转成有顺序的执行路径。
|
||||
你是 schedule_planner 体系下的日程编排员,负责把当前重点转成近期可执行安排。
|
||||
|
||||
## 你的重点:
|
||||
- 拆解步骤
|
||||
- 标注优先级与依赖
|
||||
- 输出清晰的行动顺序
|
||||
- 先给结论
|
||||
- 再给今天/近期的时间安排建议
|
||||
- 最后给按顺序执行的 next actions
|
||||
- 当用户明确要求新增/提醒/创建/安排并真正落库时,调用 schedule 工具创建对应 reminder/task/todo/goal
|
||||
- 当用户给出“日期 + 事项/节点/交付/会议”等记录型表达时,也应视为落库意图,直接创建相应记录,不要反问
|
||||
- 解析“今天/明天/后天/本周/下周”或“3月29日”这类日期时,必须以系统提供的当前时间为准,并把工具参数转换成明确的 ISO 日期/时间字符串
|
||||
- 只要用户输入里包含自然语言时间,优先调用 `resolve_time_expression`,先拿到明确日期/时间,再调用 `create_reminder`、`create_schedule_task`、`create_goal`、`create_todo`
|
||||
|
||||
## 响应要求:
|
||||
- 用编号列表
|
||||
- 每步具体,不要空泛
|
||||
- 必要时标注先后关系
|
||||
- 用清晰列表表达
|
||||
- 建议必须具体、可执行、贴近当前工作
|
||||
- 避免空泛的自我管理建议
|
||||
- 如果只是规划,不要创建任何记录
|
||||
- 如果已创建记录,要明确说明创建了什么、时间如何解析
|
||||
"""
|
||||
|
||||
|
||||
EXECUTOR_TASKS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 executor 体系下的任务执行官,只负责任务相关工具调用。
|
||||
你是 executor 体系下的任务执行官,负责处理任务、待办、提醒、目标等执行型工具调用。
|
||||
|
||||
## 允许使用的工具:
|
||||
- get_tasks
|
||||
- create_task
|
||||
- update_task_status
|
||||
- create_todo
|
||||
- create_schedule_task
|
||||
- create_reminder
|
||||
- create_goal
|
||||
- resolve_time_expression
|
||||
|
||||
## 要求:
|
||||
- 只处理任务类操作
|
||||
- 只处理任务/日程类操作
|
||||
- 遇到自然语言时间表达时,先调用 `resolve_time_expression`,再把解析后的明确日期/时间传给写入工具
|
||||
- 最终说明执行结果时,优先复用已经解析出的绝对时间,不要只重复“今天/明天”
|
||||
- 明确已执行动作、结果与下一步
|
||||
- 信息不足时直接指出缺口
|
||||
- 如果用户只是要分析建议,不要创建记录
|
||||
"""
|
||||
|
||||
|
||||
@@ -244,10 +261,14 @@ LIBRARIAN_RETRIEVAL_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
## 允许使用的工具:
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
- get_knowledge_graph_context
|
||||
|
||||
## 要求:
|
||||
- 优先检索与综合证据
|
||||
- 私有/项目知识优先使用 `search_knowledge` 或 `hybrid_search`
|
||||
- 当用户明确要求联网、查询外部资料或查询最新信息时,使用 `web_search`
|
||||
- 回答时区分内部知识与外部网页结果
|
||||
- 证据不足时明确说明边界
|
||||
- 以回答问题为主,不主动做图谱构建
|
||||
"""
|
||||
@@ -293,9 +314,31 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
- get_forum_posts
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
|
||||
## 要求:
|
||||
- 先给结论与判断
|
||||
- 再说明依据与建议
|
||||
- 当需要外部/最新信息时,可使用 `web_search`
|
||||
- 重点输出趋势、风险、机会点
|
||||
"""
|
||||
|
||||
|
||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||
|
||||
你的输出必须满足以下规则:
|
||||
1. 只能输出一个 JSON 对象,不要输出 markdown、解释、前后缀文字。
|
||||
2. JSON 对象字段仅允许:
|
||||
- `mode`: `final` | `tool_call` | `clarification`
|
||||
- `tool_calls`: 数组;每项包含 `name`、`arguments`,可选 `reason`
|
||||
- `final_response`: 当无需工具时填写
|
||||
- `clarification_question`: 当信息不足时填写
|
||||
3. 如果需要调用工具,返回:
|
||||
- `{ "mode": "tool_call", "tool_calls": [...] }`
|
||||
4. 如果无需工具,直接返回:
|
||||
- `{ "mode": "final", "final_response": "..." }`
|
||||
5. 如果信息不足,不要猜测参数,返回:
|
||||
- `{ "mode": "clarification", "clarification_question": "..." }`
|
||||
6. 只能使用系统消息里明确列出的工具名。
|
||||
7. `arguments` 必须是 JSON 对象。
|
||||
"""
|
||||
|
||||
@@ -1,32 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict, Annotated
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict, Annotated, Sequence
|
||||
from enum import Enum
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
PLANNER = "planner"
|
||||
SCHEDULE_PLANNER = "schedule_planner"
|
||||
EXECUTOR = "executor"
|
||||
LIBRARIAN = "librarian"
|
||||
ANALYST = "analyst"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
name: str
|
||||
role: AgentRole
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
tool: str
|
||||
args: dict
|
||||
result: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
role: str # "user" | "assistant"
|
||||
@@ -35,60 +22,41 @@ class ConversationTurn:
|
||||
model: str | None = None
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> HumanMessage:
|
||||
return HumanMessage(content=turn.content)
|
||||
|
||||
|
||||
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
|
||||
msg_type = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
|
||||
return ConversationTurn(
|
||||
role="user" if msg_type in ("human", "user") else "assistant",
|
||||
content=msg.content,
|
||||
agent=agent,
|
||||
model=getattr(msg, "model", None),
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list, None]
|
||||
# Core message history with add_messages reducer
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
# Session identifiers
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
|
||||
# Agent routing
|
||||
current_agent: AgentRole
|
||||
active_agents: list[AgentRole]
|
||||
current_sub_commander: str | None
|
||||
active_sub_commanders: list[str]
|
||||
sub_commander_trace: list[dict]
|
||||
|
||||
# Task tracking
|
||||
# Agent routing state
|
||||
current_agent: str | None
|
||||
next_step: str | None # For explicit graph routing
|
||||
|
||||
# Traceability
|
||||
agent_trace: list[str]
|
||||
|
||||
# Task & Entity Tracking (Business Logic)
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
created_entities: list[dict]
|
||||
|
||||
# Tool usage
|
||||
tool_calls: list[ToolCall]
|
||||
last_tool_result: str | None
|
||||
|
||||
# Knowledge context
|
||||
# Context summaries (for long-term or cross-agent context)
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
|
||||
# Planning
|
||||
plan: str | None
|
||||
plan_steps: list[dict]
|
||||
|
||||
# Analysis
|
||||
schedule_context_summary: str | None
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Memory context (injected at start of each conversation)
|
||||
|
||||
# Memory & Environment
|
||||
memory_context: str | None
|
||||
|
||||
# User LLM config (for using user-configured models)
|
||||
current_datetime_context: str | None
|
||||
|
||||
# Configuration
|
||||
user_llm_config: dict | None
|
||||
provider_capabilities: dict | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
@@ -96,22 +64,18 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
current_agent=AgentRole.MASTER,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
current_sub_commander=None,
|
||||
active_sub_commanders=[],
|
||||
sub_commander_trace=[],
|
||||
current_agent=AgentRole.MASTER.value,
|
||||
next_step=None,
|
||||
agent_trace=[AgentRole.MASTER.value],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
created_entities=[],
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
schedule_context_summary=None,
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
should_respond=True,
|
||||
memory_context=None,
|
||||
current_datetime_context=None,
|
||||
user_llm_config=None,
|
||||
provider_capabilities=None,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
from app.agents.tools.schedule import (
|
||||
get_schedule_day,
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
@@ -11,6 +19,19 @@ TASK_TOOLS = [
|
||||
update_task_status,
|
||||
]
|
||||
|
||||
SCHEDULE_READ_TOOLS = [
|
||||
get_schedule_day,
|
||||
get_tasks,
|
||||
resolve_time_expression,
|
||||
]
|
||||
|
||||
SCHEDULE_WRITE_TOOLS = [
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
]
|
||||
|
||||
FORUM_TOOLS = [
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
@@ -20,6 +41,7 @@ FORUM_TOOLS = [
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS = [
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
get_knowledge_graph_context,
|
||||
]
|
||||
|
||||
@@ -39,19 +61,22 @@ ANALYST_INSIGHT_TOOLS = [
|
||||
get_forum_posts,
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
]
|
||||
|
||||
ALL_TOOLS = [
|
||||
*KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
build_knowledge_graph,
|
||||
*TASK_TOOLS,
|
||||
*SCHEDULE_READ_TOOLS,
|
||||
*SCHEDULE_WRITE_TOOLS,
|
||||
*FORUM_TOOLS,
|
||||
]
|
||||
|
||||
SUB_COMMANDER_TOOLSETS = {
|
||||
"planner_scope": [],
|
||||
"planner_steps": [],
|
||||
"executor_tasks": TASK_TOOLS,
|
||||
"schedule_analysis": SCHEDULE_READ_TOOLS,
|
||||
"schedule_planning": [*SCHEDULE_READ_TOOLS, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_tasks": [*TASK_TOOLS, resolve_time_expression, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_forum": FORUM_TOOLS,
|
||||
"librarian_retrieval": KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
"librarian_graph": KNOWLEDGE_GRAPH_TOOLS,
|
||||
|
||||
@@ -6,15 +6,17 @@ from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
308
backend/app/agents/tools/schedule.py
Normal file
308
backend/app/agents/tools/schedule.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Agent 工具集 - 日程相关"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import date, datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date:
|
||||
if not value:
|
||||
return date.today()
|
||||
return date.fromisoformat(value)
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def _parse_datetime_with_timezone(value: str, time_zone: str | None) -> datetime:
|
||||
"""Parse an ISO datetime and return a tz-naive datetime in the intended local time.
|
||||
|
||||
- If value includes an offset/Z, it will be converted to `time_zone` when provided.
|
||||
- If value is naive and `time_zone` is provided, it is interpreted in that zone.
|
||||
"""
|
||||
parsed = _parse_datetime(value)
|
||||
tz = (time_zone or "").strip()
|
||||
if parsed.tzinfo is None:
|
||||
if tz:
|
||||
parsed = parsed.replace(tzinfo=ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
if tz:
|
||||
parsed = parsed.astimezone(ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_schedule_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
if not resolved:
|
||||
return None
|
||||
if "T" in resolved:
|
||||
return resolved
|
||||
return f"{resolved}T09:00:00"
|
||||
|
||||
|
||||
def _format_summary(target_date: date, todos: list[DailyTodo], tasks: list[Task], reminders: list[Reminder], goals: list[Goal]) -> str:
|
||||
lines = [f"日期: {target_date.isoformat()}"]
|
||||
|
||||
if todos:
|
||||
lines.append("待办:")
|
||||
lines.extend(f"- {item.title} | 完成:{'是' if item.is_completed else '否'}" for item in todos)
|
||||
else:
|
||||
lines.append("待办: 无")
|
||||
|
||||
if tasks:
|
||||
lines.append("任务:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status} | 优先级:{item.priority.value if hasattr(item.priority, 'value') else item.priority} | 截止:{item.due_date.isoformat() if item.due_date else '无'}"
|
||||
for item in tasks
|
||||
)
|
||||
else:
|
||||
lines.append("任务: 无")
|
||||
|
||||
if reminders:
|
||||
lines.append("提醒:")
|
||||
lines.extend(f"- {item.title} | 时间:{item.reminder_at.isoformat()}" for item in reminders)
|
||||
else:
|
||||
lines.append("提醒: 无")
|
||||
|
||||
if goals:
|
||||
lines.append("目标:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status}"
|
||||
for item in goals
|
||||
)
|
||||
else:
|
||||
lines.append("目标: 无")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
def get_schedule_day(target_date: str | None = None) -> str:
|
||||
"""获取指定日期的 todo/task/reminder/goal 聚合信息。target_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(target_date)
|
||||
date_key = parsed_date.isoformat()
|
||||
start_dt = datetime.combine(parsed_date, datetime.min.time())
|
||||
end_dt = datetime.combine(parsed_date, datetime.max.time())
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
todos = (
|
||||
await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == uid, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
tasks = (
|
||||
await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == uid,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
reminders = (
|
||||
await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == uid,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
goals = (
|
||||
await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == uid, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
return _format_summary(parsed_date, todos, tasks, reminders, goals)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as exc:
|
||||
return f"获取日程失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_todo(title: str, todo_date: str | None = None) -> str:
|
||||
"""创建指定日期的待办。todo_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(todo_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
todo = DailyTodo(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
source=TodoSource.AI_CHAT,
|
||||
todo_date=parsed_date.isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
await db.refresh(todo)
|
||||
return f"TODO创建成功: [{todo.id[:8]}] {todo.title} @ {todo.todo_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建TODO失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_schedule_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: str = "medium",
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""创建任务。priority 支持 low/medium/high/urgent;due_date 使用 ISO datetime。兼容 content/date 别名。"""
|
||||
uid = get_current_user()
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_schedule_due_date(due_date, date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=TaskPriority(priority),
|
||||
due_date=_parse_datetime(resolved_due_date) if resolved_due_date else None,
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
due_label = task.due_date.isoformat() if task.due_date else "无截止时间"
|
||||
return f"任务创建成功: [{task.id[:8]}] {task.title} | 优先级:{task.priority.value} | 截止:{due_label}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_reminder(
|
||||
title: str = "",
|
||||
reminder_at: str | None = None,
|
||||
note: str = "",
|
||||
description: str = "",
|
||||
datetime: str = "",
|
||||
at: str = "",
|
||||
remind_at: str = "",
|
||||
content: str = "",
|
||||
time_zone: str = "",
|
||||
timezone: str = "",
|
||||
time: str = "",
|
||||
) -> str:
|
||||
"""创建提醒。reminder_at 使用 ISO datetime。兼容 description/datetime/at/remind_at/time_zone 别名。"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_title = (title or content or "").strip()
|
||||
if not resolved_title:
|
||||
raise ValueError("title 不能为空")
|
||||
|
||||
resolved_at = ((reminder_at or datetime or at or remind_at or time or "").strip())
|
||||
if not resolved_at:
|
||||
raise ValueError("reminder_at 不能为空")
|
||||
|
||||
resolved_note = (note or description or "").strip()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
tz = (time_zone or timezone or "").strip()
|
||||
reminder = Reminder(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
note=resolved_note or None,
|
||||
reminder_at=_parse_datetime_with_timezone(resolved_at, tz),
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return f"提醒创建成功: [{reminder.id[:8]}] {reminder.title} @ {reminder.reminder_at.isoformat()}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建提醒失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_goal(title: str, goal_date: str | None = None, note: str = "", status: str = "active") -> str:
|
||||
"""创建指定日期目标。goal_date 格式 YYYY-MM-DD,默认今天;status 支持 active/done/archived。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(goal_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
goal = Goal(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
note=note or None,
|
||||
goal_date=parsed_date.isoformat(),
|
||||
status=GoalStatus(status),
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return f"目标创建成功: [{goal.id[:8]}] {goal.title} @ {goal.goal_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建目标失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_schedule_day",
|
||||
"create_todo",
|
||||
"create_schedule_task",
|
||||
"create_reminder",
|
||||
"create_goal",
|
||||
]
|
||||
@@ -5,12 +5,14 @@ Agent 工具集 - 知识库 & 图谱相关
|
||||
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@@ -151,9 +153,56 @@ def hybrid_search(query: str, top_k: int = 5) -> str:
|
||||
return f"混合搜索失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
通过 SearxNG 搜索外部网页信息,返回标题、链接和摘要。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
top_k: 返回结果数量,默认 5 条
|
||||
|
||||
Returns:
|
||||
适合模型综合的网页结果文本
|
||||
"""
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
async def _search():
|
||||
service = WebSearchService()
|
||||
results = await service.search(query, limit=top_k)
|
||||
if not results:
|
||||
return "未找到相关网页结果。"
|
||||
|
||||
texts = []
|
||||
for index, result in enumerate(results, 1):
|
||||
source = f"\n来源: {result.source}" if result.source else ""
|
||||
published_at = f"\n时间: {result.published_at}" if result.published_at else ""
|
||||
snippet = result.snippet or "(无摘要)"
|
||||
texts.append(
|
||||
f"[{index}] {result.title}\n"
|
||||
f"链接: {result.url}{source}{published_at}\n"
|
||||
f"摘要: {snippet}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except WebSearchConfigurationError as exc:
|
||||
return f"网页搜索不可用: {exc}"
|
||||
except WebSearchRequestError as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
except Exception as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
"web_search",
|
||||
]
|
||||
|
||||
@@ -1,22 +1,85 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
|
||||
_executor = None
|
||||
from app.models.base import utc_now
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.database import async_session
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
return resolved or None
|
||||
|
||||
|
||||
def _parse_due_date(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" not in normalized:
|
||||
normalized = f"{normalized}T00:00:00"
|
||||
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is not None:
|
||||
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_priority(priority: int | str | None) -> TaskPriority:
|
||||
if priority is None or priority == "":
|
||||
return TaskPriority.MEDIUM
|
||||
if isinstance(priority, TaskPriority):
|
||||
return priority
|
||||
if isinstance(priority, int):
|
||||
return {
|
||||
1: TaskPriority.LOW,
|
||||
2: TaskPriority.MEDIUM,
|
||||
3: TaskPriority.HIGH,
|
||||
4: TaskPriority.URGENT,
|
||||
}.get(priority, TaskPriority.MEDIUM)
|
||||
normalized = str(priority).strip().lower()
|
||||
if not normalized:
|
||||
return TaskPriority.MEDIUM
|
||||
return TaskPriority(normalized)
|
||||
|
||||
|
||||
def _normalize_status(status: str) -> TaskStatus:
|
||||
normalized = status.strip().lower()
|
||||
return TaskStatus(normalized)
|
||||
|
||||
|
||||
def _format_status(value: TaskStatus | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
def _format_priority(value: TaskPriority | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -25,7 +88,7 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
获取用户当前的任务列表。
|
||||
|
||||
Args:
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
|
||||
limit: 返回数量,默认20
|
||||
|
||||
Returns:
|
||||
@@ -33,67 +96,82 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status) if status else None
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if resolved_status:
|
||||
query = query.where(Task.status == resolved_status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
if not tasks:
|
||||
return "暂无任务"
|
||||
lines = []
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
|
||||
def create_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: int | str = 2,
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建新任务。
|
||||
|
||||
Args:
|
||||
title: 任务标题(必填)
|
||||
title: 任务标题(必填,兼容 content 作为别名)
|
||||
description: 任务描述
|
||||
priority: 优先级 1-4,数字越大优先级越高,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD
|
||||
priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
|
||||
content: title 的兼容别名
|
||||
date: due_date 的兼容别名
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
due_date=due_date,
|
||||
status="todo",
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {title}"
|
||||
|
||||
try:
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_due_date(due_date, date)
|
||||
resolved_priority = _normalize_priority(priority)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=resolved_priority,
|
||||
due_date=_parse_due_date(resolved_due_date),
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {resolved_title}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"创建任务失败: {str(e)}"
|
||||
@@ -106,34 +184,37 @@ def update_task_status(task_id: str, status: str) -> str:
|
||||
|
||||
Args:
|
||||
task_id: 任务ID(完整ID或前8位)
|
||||
status: 新状态 (todo/in_progress/done/blocked)
|
||||
status: 新状态 (todo/in_progress/done/cancelled)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = status
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {status}"
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status)
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
query = (
|
||||
select(Task)
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if len(task_id) == 8:
|
||||
query = query.where(Task.id.like(f"{task_id}%"))
|
||||
else:
|
||||
query = query.where(Task.id == task_id)
|
||||
result = await db.execute(query)
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = resolved_status
|
||||
task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {resolved_status.value}"
|
||||
|
||||
return _run_async(_update())
|
||||
except Exception as e:
|
||||
return f"更新任务失败: {str(e)}"
|
||||
|
||||
269
backend/app/agents/tools/time_reasoning.py
Normal file
269
backend/app/agents/tools/time_reasoning.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
_WEEKDAY_MAP = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6, "天": 6}
|
||||
_DEFAULT_HOUR_BY_PERIOD = {
|
||||
"morning": 9,
|
||||
"noon": 12,
|
||||
"afternoon": 15,
|
||||
"evening": 20,
|
||||
}
|
||||
_TIME_KEYWORDS = ("今天", "明天", "后天", "本周", "这周", "下周", "周", "星期", "月", "日", "早上", "上午", "中午", "下午", "晚上", "今晚", "点", ":", ":")
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def extract_reference_datetime(current_datetime_context: str | None) -> datetime:
|
||||
context = (current_datetime_context or "").strip()
|
||||
if context:
|
||||
for pattern in (r"current_time_utc:\s*(\S+)", r"CURRENT_TIME:\s*(\S+)", r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2}))"):
|
||||
match = re.search(pattern, context)
|
||||
if match:
|
||||
return _parse_datetime(match.group(1))
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _normalize_local_iso(value: datetime) -> str:
|
||||
return value.replace(tzinfo=None).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _normalize_datetime_iso(value: datetime) -> str:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat(timespec="seconds")
|
||||
return _normalize_local_iso(value)
|
||||
|
||||
|
||||
def _normalize_date_iso(value: date) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _is_iso_datetime(value: str) -> bool:
|
||||
try:
|
||||
parsed = _parse_datetime(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return isinstance(parsed, datetime)
|
||||
|
||||
|
||||
def _is_iso_date(value: str) -> bool:
|
||||
try:
|
||||
date.fromisoformat(value.strip())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _has_explicit_time(text: str) -> bool:
|
||||
return bool(
|
||||
re.search(r"\d{1,2}[::]\d{2}", text)
|
||||
or re.search(r"\d{1,2}点(?:半|(?:\d{1,2})分?)?", text)
|
||||
or any(keyword in text for keyword in ("早上", "上午", "中午", "下午", "晚上", "今晚"))
|
||||
)
|
||||
|
||||
|
||||
def _detect_period(text: str) -> str | None:
|
||||
if any(keyword in text for keyword in ("晚上", "今晚")):
|
||||
return "evening"
|
||||
if "下午" in text:
|
||||
return "afternoon"
|
||||
if "中午" in text:
|
||||
return "noon"
|
||||
if any(keyword in text for keyword in ("早上", "上午", "早晨", "清晨")):
|
||||
return "morning"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_time(text: str) -> tuple[time, bool, str | None]:
|
||||
period = _detect_period(text)
|
||||
colon_match = re.search(r"(\d{1,2})[::](\d{2})", text)
|
||||
if colon_match:
|
||||
hour = int(colon_match.group(1))
|
||||
minute = int(colon_match.group(2))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
half_match = re.search(r"(\d{1,2})点半", text)
|
||||
if half_match:
|
||||
hour = int(half_match.group(1))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=30), False, period
|
||||
|
||||
dot_match = re.search(r"(\d{1,2})点(?:(\d{1,2})分?)?", text)
|
||||
if dot_match:
|
||||
hour = int(dot_match.group(1))
|
||||
minute = int(dot_match.group(2) or 0)
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
if period == "noon" and hour < 11:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
if period:
|
||||
return time(hour=_DEFAULT_HOUR_BY_PERIOD[period], minute=0), True, period
|
||||
return time(hour=9, minute=0), True, None
|
||||
|
||||
|
||||
def _resolve_date(text: str, reference: datetime) -> tuple[date, str]:
|
||||
stripped = text.strip()
|
||||
if _is_iso_date(stripped):
|
||||
return date.fromisoformat(stripped), "explicit_date"
|
||||
|
||||
month_day_match = re.search(r"(\d{1,2})月(\d{1,2})日", stripped)
|
||||
if month_day_match:
|
||||
month = int(month_day_match.group(1))
|
||||
day = int(month_day_match.group(2))
|
||||
candidate = date(reference.year, month, day)
|
||||
if candidate < reference.date() - timedelta(days=1):
|
||||
candidate = date(reference.year + 1, month, day)
|
||||
return candidate, "explicit_month_day"
|
||||
|
||||
if "后天" in stripped:
|
||||
return reference.date() + timedelta(days=2), "relative_day"
|
||||
if "明天" in stripped:
|
||||
return reference.date() + timedelta(days=1), "relative_day"
|
||||
if "今天" in stripped:
|
||||
return reference.date(), "relative_day"
|
||||
|
||||
weekday_match = re.search(r"((?:本周|这周|下周)?)(?:周|星期)([一二三四五六日天])", stripped)
|
||||
if weekday_match:
|
||||
prefix = weekday_match.group(1)
|
||||
weekday = _WEEKDAY_MAP[weekday_match.group(2)]
|
||||
current_weekday = reference.date().weekday()
|
||||
delta = weekday - current_weekday
|
||||
if prefix == "下周":
|
||||
delta += 7 if delta <= 0 else 7
|
||||
elif prefix in {"本周", "这周"}:
|
||||
if delta < 0:
|
||||
delta += 7
|
||||
elif delta < 0:
|
||||
delta += 7
|
||||
return reference.date() + timedelta(days=delta), "relative_weekday"
|
||||
|
||||
return reference.date(), "reference_day"
|
||||
|
||||
|
||||
def resolve_time_expression_data(
|
||||
expression: str,
|
||||
*,
|
||||
current_datetime_context: str | None = None,
|
||||
prefer: str = "datetime",
|
||||
) -> dict:
|
||||
text = (expression or "").strip()
|
||||
if not text:
|
||||
raise ValueError("expression 不能为空")
|
||||
|
||||
reference = extract_reference_datetime(current_datetime_context)
|
||||
|
||||
if _is_iso_datetime(text):
|
||||
parsed = _parse_datetime(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "datetime",
|
||||
"resolved_date": _normalize_date_iso(parsed.date()),
|
||||
"resolved_datetime": _normalize_datetime_iso(parsed),
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_datetime",
|
||||
}
|
||||
|
||||
if _is_iso_date(text):
|
||||
parsed_date = date.fromisoformat(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "date",
|
||||
"resolved_date": _normalize_date_iso(parsed_date),
|
||||
"resolved_datetime": None,
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_date",
|
||||
}
|
||||
|
||||
resolved_date, date_reason = _resolve_date(text, reference)
|
||||
resolved_time, assumed_time, period = _resolve_time(text)
|
||||
has_explicit_time = _has_explicit_time(text)
|
||||
grain = "date" if prefer == "date" and not has_explicit_time else "datetime"
|
||||
resolved_dt = datetime.combine(resolved_date, resolved_time)
|
||||
note = date_reason
|
||||
if period:
|
||||
note = f"{note}:{period}"
|
||||
if assumed_time:
|
||||
note = f"{note}:assumed_time"
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": grain,
|
||||
"resolved_date": _normalize_date_iso(resolved_date),
|
||||
"resolved_datetime": None if grain == "date" else _normalize_local_iso(resolved_dt),
|
||||
"assumed_time": assumed_time,
|
||||
"reason": note,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def resolve_time_expression(
|
||||
expression: str,
|
||||
current_datetime_context: str = "",
|
||||
prefer: str = "datetime",
|
||||
) -> str:
|
||||
"""解析中文自然语言时间表达,基于当前参考时间返回明确的日期或 datetime。prefer 支持 datetime/date。"""
|
||||
try:
|
||||
payload = resolve_time_expression_data(
|
||||
expression,
|
||||
current_datetime_context=current_datetime_context or None,
|
||||
prefer=prefer,
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
except Exception as exc:
|
||||
return json.dumps(
|
||||
{
|
||||
"expression": expression,
|
||||
"error": str(exc),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_context: str | None) -> dict:
|
||||
normalized = dict(args)
|
||||
|
||||
if tool_name == "create_reminder":
|
||||
raw_value = next((normalized.get(key) for key in ("reminder_at", "datetime", "at", "remind_at", "time") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_schedule_task", "create_task"}:
|
||||
raw_value = next((normalized.get(key) for key in ("due_date", "date") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value) and not _is_iso_date(raw_value):
|
||||
prefer = "datetime" if tool_name == "create_schedule_task" or _has_explicit_time(raw_value) else "date"
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer=prefer)
|
||||
normalized["due_date"] = payload["resolved_datetime"] or payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_todo", "create_goal", "get_schedule_day"}:
|
||||
field_name = {
|
||||
"create_todo": "todo_date",
|
||||
"create_goal": "goal_date",
|
||||
"get_schedule_day": "target_date",
|
||||
}[tool_name]
|
||||
raw_value = normalized.get(field_name)
|
||||
if isinstance(raw_value, str) and raw_value.strip() and not _is_iso_date(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized[field_name] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["resolve_time_expression", "resolve_time_expression_data", "normalize_tool_time_arguments", "extract_reference_datetime"]
|
||||
@@ -15,7 +15,9 @@ def _resolve_path(value: str) -> str:
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
# === 应用基础 ===
|
||||
APP_NAME: str = "Jarvis"
|
||||
@@ -75,6 +77,8 @@ class Settings(BaseSettings):
|
||||
|
||||
# === 向量化 ===
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_BASE_URL: str = "https://api.openai.com/v1"
|
||||
EMBEDDING_API_KEY: str = ""
|
||||
CHUNK_SIZE: int = 500
|
||||
CHUNK_OVERLAP: int = 50
|
||||
|
||||
@@ -86,6 +90,17 @@ class Settings(BaseSettings):
|
||||
# === NAS 部署 ===
|
||||
NAS_DATA_ROOT: str = "/data/jarvis"
|
||||
|
||||
# === Web Search / SearxNG ===
|
||||
WEB_SEARCH_ENABLED: bool = False
|
||||
WEB_SEARCH_PROVIDER: str = "searxng"
|
||||
SEARXNG_BASE_URL: str = ""
|
||||
SEARXNG_AUTH_TYPE: Literal["none", "bearer", "basic"] = "none"
|
||||
SEARXNG_AUTH_TOKEN: str = ""
|
||||
SEARXNG_BASIC_USER: str = ""
|
||||
SEARXNG_BASIC_PASSWORD: str = ""
|
||||
WEB_SEARCH_DEFAULT_LIMIT: int = 5
|
||||
WEB_SEARCH_TIMEOUT_SECONDS: int = 10
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)
|
||||
|
||||
@@ -40,6 +40,8 @@ async def init_db():
|
||||
await ensure_document_columns(conn)
|
||||
await ensure_user_columns(conn)
|
||||
await ensure_forum_columns(conn)
|
||||
await ensure_agent_columns(conn)
|
||||
await ensure_skill_columns(conn)
|
||||
|
||||
|
||||
async def ensure_log_columns(conn):
|
||||
@@ -139,6 +141,55 @@ async def ensure_forum_columns(conn):
|
||||
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
|
||||
|
||||
|
||||
async def ensure_agent_columns(conn):
|
||||
rows = await _get_table_info(conn, 'agents')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'selected_skill_ids': "ALTER TABLE agents ADD COLUMN selected_skill_ids JSON DEFAULT '[]' NOT NULL",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_skill_columns(conn):
|
||||
rows = await _get_table_info(conn, 'skills')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'required_context': "ALTER TABLE skills ADD COLUMN required_context JSON DEFAULT '[]' NOT NULL",
|
||||
'output_format': "ALTER TABLE skills ADD COLUMN output_format TEXT",
|
||||
'is_builtin': "ALTER TABLE skills ADD COLUMN is_builtin BOOLEAN DEFAULT 0 NOT NULL",
|
||||
'team_id': "ALTER TABLE skills ADD COLUMN team_id VARCHAR(36)",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
await conn.execute(text("UPDATE skills SET agent_type = 'schedule_planner' WHERE agent_type = 'planner'"))
|
||||
builtin_names = [
|
||||
'今日重点拆解',
|
||||
'周计划编排',
|
||||
'时间冲突分析',
|
||||
'任务执行 SOP',
|
||||
'外部交互推进',
|
||||
'知识检索摘要',
|
||||
'图谱沉淀策略',
|
||||
'风险识别模板',
|
||||
'趋势洞察模板',
|
||||
]
|
||||
for name in builtin_names:
|
||||
await conn.execute(
|
||||
text("UPDATE skills SET is_builtin = 1 WHERE name = :name"),
|
||||
{'name': name},
|
||||
)
|
||||
|
||||
|
||||
async def _backfill_usernames(conn):
|
||||
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
|
||||
users = result.fetchall()
|
||||
|
||||
@@ -14,6 +14,9 @@ from app.routers import (
|
||||
graph_router,
|
||||
agent_router,
|
||||
todo_router,
|
||||
reminder_router,
|
||||
goal_router,
|
||||
schedule_center_router,
|
||||
settings_router,
|
||||
folder_router,
|
||||
skill_router,
|
||||
@@ -23,7 +26,7 @@ from app.routers import (
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user, ensure_builtin_skills
|
||||
from app.config import settings
|
||||
from app.logging_utils import (
|
||||
setup_logging,
|
||||
@@ -53,6 +56,7 @@ async def run_startup() -> None:
|
||||
await init_db()
|
||||
async with async_session() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
await ensure_builtin_skills(session)
|
||||
await persist_system_log(
|
||||
message="application_started",
|
||||
source="app",
|
||||
@@ -103,6 +107,9 @@ app.include_router(forum_router)
|
||||
app.include_router(graph_router)
|
||||
app.include_router(agent_router)
|
||||
app.include_router(todo_router)
|
||||
app.include_router(reminder_router)
|
||||
app.include_router(goal_router)
|
||||
app.include_router(schedule_center_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(folder_router)
|
||||
app.include_router(skill_router)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from app.models.base import Base
|
||||
from app.models.user import User
|
||||
from app.models.folder import Folder
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.task import Task, TaskHistory
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
@@ -17,11 +18,14 @@ from app.models.brain import (
|
||||
brain_memory_sources,
|
||||
)
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.reminder import Reminder, ReminderStatus
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.log import Log, LogType, LogLevel
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"Folder",
|
||||
"Document",
|
||||
"DocumentChunk",
|
||||
"Task",
|
||||
@@ -45,6 +49,10 @@ __all__ = [
|
||||
"brain_memory_sources",
|
||||
"DailyTodo",
|
||||
"TodoSource",
|
||||
"Reminder",
|
||||
"ReminderStatus",
|
||||
"Goal",
|
||||
"GoalStatus",
|
||||
"Log",
|
||||
"LogType",
|
||||
"LogLevel",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -7,9 +7,10 @@ class Agent(BaseModel):
|
||||
__tablename__ = "agents"
|
||||
|
||||
name = Column(String(100), nullable=False)
|
||||
role = Column(String(100), nullable=False) # master, planner, executor, librarian, analyst
|
||||
role = Column(String(100), nullable=False) # master, schedule_planner, executor, librarian, analyst
|
||||
description = Column(Text, nullable=True)
|
||||
system_prompt = Column(Text, nullable=False)
|
||||
selected_skill_ids = Column(JSON, default=list, nullable=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
|
||||
|
||||
21
backend/app/models/goal.py
Normal file
21
backend/app/models/goal.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class GoalStatus(str, PyEnum):
|
||||
ACTIVE = "active"
|
||||
DONE = "done"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class Goal(BaseModel):
|
||||
__tablename__ = "goals"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
goal_date = Column(String(10), nullable=False, index=True)
|
||||
status = Column(Enum(GoalStatus), default=GoalStatus.ACTIVE, nullable=False)
|
||||
21
backend/app/models/reminder.py
Normal file
21
backend/app/models/reminder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class ReminderStatus(str, PyEnum):
|
||||
PENDING = "pending"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class Reminder(BaseModel):
|
||||
__tablename__ = "reminders"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
reminder_at = Column(DateTime, nullable=False, index=True)
|
||||
status = Column(Enum(ReminderStatus), default=ReminderStatus.PENDING, nullable=False)
|
||||
is_dismissed = Column(Boolean, default=False, nullable=False)
|
||||
@@ -9,11 +9,12 @@ class Skill(BaseModel):
|
||||
name = Column(String(100), nullable=False, unique=True, index=True)
|
||||
description = Column(Text, nullable=True) # 供 LLM 理解用途
|
||||
instructions = Column(Text, nullable=False) # Agent 执行时的指令模板
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/planner/executor/librarian/analyst
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/schedule_planner/executor/librarian/analyst
|
||||
tools = Column(JSON, default=list) # 引用的工具名称列表
|
||||
required_context = Column(JSON, default=list) # 需要的前置数据
|
||||
output_format = Column(Text, nullable=True) # 输出格式要求
|
||||
visibility = Column(String(20), default="private") # private/team/market
|
||||
is_builtin = Column(Boolean, default=False, nullable=False)
|
||||
team_id = Column(String(36), ForeignKey("users.id"), nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
owner_id = Column(String(36), ForeignKey("users.id"), nullable=False)
|
||||
|
||||
@@ -6,6 +6,9 @@ from app.routers.forum import router as forum_router
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.settings import router as settings_router
|
||||
from app.routers.folder import router as folder_router
|
||||
from app.routers.skill import router as skill_router
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
||||
@@ -13,9 +14,9 @@ _agent_call_counts: dict[str, int] = {}
|
||||
_agent_current_tasks: dict[str, str | None] = {}
|
||||
_agent_statuses: dict[str, str] = {}
|
||||
|
||||
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
|
||||
DEFAULT_AGENT_ROLES = ["master", "schedule_planner", "executor", "librarian", "analyst"]
|
||||
SUB_COMMANDERS_BY_ROLE = {
|
||||
"planner": ["planner_scope", "planner_steps"],
|
||||
"schedule_planner": ["schedule_analysis", "schedule_planning"],
|
||||
"executor": ["executor_tasks", "executor_forum"],
|
||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||
"analyst": ["analyst_progress", "analyst_insights"],
|
||||
@@ -88,10 +89,10 @@ async def get_agent_config(
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, SCHEDULE_PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
defaults = {
|
||||
"master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
|
||||
"planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
|
||||
"schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
|
||||
"executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
|
||||
"librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
|
||||
"analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
|
||||
@@ -107,6 +108,7 @@ async def get_agent_config(
|
||||
system_prompt=prompt,
|
||||
enabled=True,
|
||||
is_active=True,
|
||||
selected_skill_ids=[],
|
||||
)
|
||||
return AgentConfigOut(
|
||||
id=agent.role,
|
||||
@@ -116,6 +118,7 @@ async def get_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -141,6 +144,19 @@ async def update_agent_config(
|
||||
if data.enabled is not None:
|
||||
agent.is_active = data.enabled
|
||||
_agent_statuses[agent_id] = "disabled" if not data.enabled else "idle"
|
||||
if data.selected_skill_ids is not None:
|
||||
if data.selected_skill_ids:
|
||||
result = await db.execute(
|
||||
select(Skill.id).where(
|
||||
Skill.id.in_(data.selected_skill_ids),
|
||||
Skill.owner_id == current_user.id,
|
||||
)
|
||||
)
|
||||
allowed_skill_ids = set(result.scalars().all())
|
||||
invalid_skill_ids = [skill_id for skill_id in data.selected_skill_ids if skill_id not in allowed_skill_ids]
|
||||
if invalid_skill_ids:
|
||||
raise HTTPException(status_code=400, detail="存在无效的技能绑定")
|
||||
agent.selected_skill_ids = data.selected_skill_ids
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
@@ -152,6 +168,7 @@ async def update_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
|
||||
from app.config import settings
|
||||
|
||||
@@ -58,6 +59,7 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
|
||||
await db.refresh(user)
|
||||
await ensure_builtin_skills(db)
|
||||
return user
|
||||
|
||||
|
||||
@@ -97,10 +99,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSessi
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
|
||||
await ensure_builtin_skills(db)
|
||||
access_token = create_access_token(data={"sub": user.id})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
async def get_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
|
||||
await ensure_builtin_skills(db)
|
||||
return current_user
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
@@ -92,13 +93,16 @@ async def chat(
|
||||
):
|
||||
"""简单版对话(非流式)"""
|
||||
agent_svc = AgentService(db)
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
try:
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
# 更新对话消息计数
|
||||
result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
|
||||
@@ -126,13 +130,17 @@ async def chat_stream(
|
||||
agent_svc = AgentService(db)
|
||||
|
||||
async def stream_generator():
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
try:
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
conversation_id=data.conversation_id,
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
|
||||
|
||||
|
||||
92
backend/app/routers/goal.py
Normal file
92
backend/app/routers/goal.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.goal import GoalCreate, GoalListOut, GoalOut, GoalUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/goals", tags=["目标"])
|
||||
|
||||
|
||||
@router.get("", response_model=GoalListOut)
|
||||
async def list_goals(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str).isoformat()
|
||||
query = (
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id)
|
||||
.where(Goal.goal_date == target_date)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return GoalListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=GoalOut, status_code=201)
|
||||
async def create_goal(
|
||||
data: GoalCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
goal = Goal(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
goal_date=data.goal_date.isoformat(),
|
||||
status=data.status,
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.patch("/{goal_id}", response_model=GoalOut)
|
||||
async def update_goal(
|
||||
goal_id: str,
|
||||
data: GoalUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
payload = data.model_dump(exclude_none=True)
|
||||
if "goal_date" in payload:
|
||||
payload["goal_date"] = payload["goal_date"].isoformat()
|
||||
|
||||
for field, value in payload.items():
|
||||
setattr(goal, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.delete("/{goal_id}", status_code=204)
|
||||
async def delete_goal(
|
||||
goal_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
await db.delete(goal)
|
||||
await db.commit()
|
||||
90
backend/app/routers/reminder.py
Normal file
90
backend/app/routers/reminder.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.reminder import ReminderCreate, ReminderListOut, ReminderOut, ReminderUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/reminders", tags=["提醒"])
|
||||
|
||||
|
||||
@router.get("", response_model=ReminderListOut)
|
||||
async def list_reminders(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str)
|
||||
start = datetime.combine(target_date, datetime.min.time())
|
||||
end = datetime.combine(target_date, datetime.max.time())
|
||||
query = (
|
||||
select(Reminder)
|
||||
.where(Reminder.user_id == current_user.id)
|
||||
.where(Reminder.reminder_at >= start)
|
||||
.where(Reminder.reminder_at <= end)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return ReminderListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=ReminderOut, status_code=201)
|
||||
async def create_reminder(
|
||||
data: ReminderCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
reminder = Reminder(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
reminder_at=data.reminder_at,
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.patch("/{reminder_id}", response_model=ReminderOut)
|
||||
async def update_reminder(
|
||||
reminder_id: str,
|
||||
data: ReminderUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
for field, value in data.model_dump(exclude_none=True).items():
|
||||
setattr(reminder, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.delete("/{reminder_id}", status_code=204)
|
||||
async def delete_reminder(
|
||||
reminder_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
await db.delete(reminder)
|
||||
await db.commit()
|
||||
160
backend/app/routers/schedule_center.py
Normal file
160
backend/app/routers/schedule_center.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from calendar import monthrange
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.schedule_center import (
|
||||
ScheduleCenterDateOut,
|
||||
ScheduleCenterDaySummary,
|
||||
ScheduleCenterMonthOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/schedule-center", tags=["调度中心"])
|
||||
|
||||
|
||||
def _build_summary(
|
||||
target_date: str,
|
||||
todos: list[DailyTodo],
|
||||
tasks: list[Task],
|
||||
reminders: list[Reminder],
|
||||
goals: list[Goal],
|
||||
) -> ScheduleCenterDaySummary:
|
||||
return ScheduleCenterDaySummary(
|
||||
date=target_date,
|
||||
todo_total=len(todos),
|
||||
todo_completed=sum(1 for item in todos if item.is_completed),
|
||||
task_due_total=len(tasks),
|
||||
high_priority_total=sum(1 for item in tasks if item.priority in {TaskPriority.HIGH, TaskPriority.URGENT}),
|
||||
reminder_total=len(reminders),
|
||||
goal_total=len(goals),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/month", response_model=ScheduleCenterMonthOut)
|
||||
async def get_month_schedule(
|
||||
year: int = Query(..., ge=2000, le=2100),
|
||||
month: int = Query(..., ge=1, le=12),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
month_start = date(year, month, 1)
|
||||
days_in_month = monthrange(month_start.year, month_start.month)[1]
|
||||
start_key = month_start.isoformat()
|
||||
end_key = month_start.replace(day=days_in_month).isoformat()
|
||||
start_dt = datetime.combine(month_start, datetime.min.time())
|
||||
end_dt = datetime.combine(month_start.replace(day=days_in_month), datetime.max.time())
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo).where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date >= start_key, DailyTodo.todo_date <= end_key)
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task).where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal).where(Goal.user_id == current_user.id, Goal.goal_date >= start_key, Goal.goal_date <= end_key)
|
||||
)).scalars().all()
|
||||
|
||||
todo_map: dict[str, list[DailyTodo]] = {}
|
||||
for item in todos:
|
||||
todo_map.setdefault(item.todo_date, []).append(item)
|
||||
|
||||
task_map: dict[str, list[Task]] = {}
|
||||
for item in tasks:
|
||||
key = item.due_date.date().isoformat()
|
||||
task_map.setdefault(key, []).append(item)
|
||||
|
||||
reminder_map: dict[str, list[Reminder]] = {}
|
||||
for item in reminders:
|
||||
key = item.reminder_at.date().isoformat()
|
||||
reminder_map.setdefault(key, []).append(item)
|
||||
|
||||
goal_map: dict[str, list[Goal]] = {}
|
||||
for item in goals:
|
||||
goal_map.setdefault(item.goal_date, []).append(item)
|
||||
|
||||
days = []
|
||||
for day in range(1, days_in_month + 1):
|
||||
date_key = month_start.replace(day=day).isoformat()
|
||||
days.append(_build_summary(
|
||||
date_key,
|
||||
todo_map.get(date_key, []),
|
||||
task_map.get(date_key, []),
|
||||
reminder_map.get(date_key, []),
|
||||
goal_map.get(date_key, []),
|
||||
))
|
||||
|
||||
return ScheduleCenterMonthOut(month=f"{year:04d}-{month:02d}", days=days)
|
||||
|
||||
|
||||
@router.get("/date", response_model=ScheduleCenterDateOut)
|
||||
async def get_date_schedule(
|
||||
date_str: date = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date_str
|
||||
start_dt = datetime.combine(target_date, datetime.min.time())
|
||||
end_dt = datetime.combine(target_date, datetime.max.time())
|
||||
date_key = target_date.isoformat()
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)).scalars().all()
|
||||
|
||||
summary = _build_summary(date_key, todos, tasks, reminders, goals)
|
||||
return ScheduleCenterDateOut(
|
||||
date=date_key,
|
||||
todos=todos,
|
||||
tasks=tasks,
|
||||
reminders=reminders,
|
||||
goals=goals,
|
||||
summary=summary,
|
||||
generated_at=datetime.now(UTC),
|
||||
)
|
||||
@@ -1,11 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.skill import SkillCreate, SkillOut, SkillUpdate
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.skill_service import SkillService
|
||||
|
||||
router = APIRouter(prefix="/api/skills", tags=["Skill"])
|
||||
|
||||
@@ -37,13 +39,23 @@ async def create_skill(
|
||||
|
||||
@router.get("", response_model=list[SkillOut])
|
||||
async def list_skills(
|
||||
agent_type: str | None = Query(default=None),
|
||||
visibility: str | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Skill).where(Skill.owner_id == current_user.id).order_by(Skill.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id, agent_type=agent_type, visibility=visibility)
|
||||
|
||||
|
||||
@router.post("/bootstrap-builtin", response_model=list[SkillOut])
|
||||
async def bootstrap_builtin_skills(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await ensure_builtin_skills(db, preferred_owner_id=current_user.id)
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id)
|
||||
|
||||
|
||||
@router.get("/{skill_id}", response_model=SkillOut)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.user import User
|
||||
@@ -13,12 +15,28 @@ router = APIRouter(prefix="/api/tasks", tags=["看板"])
|
||||
@router.get("", response_model=list[TaskOut])
|
||||
async def list_tasks(
|
||||
status: TaskStatus | None = None,
|
||||
due_date: date | None = Query(default=None),
|
||||
date_from: date | None = Query(default=None),
|
||||
date_to: date | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Task).where(Task.user_id == current_user.id)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
if due_date:
|
||||
start = datetime.combine(due_date, datetime.min.time())
|
||||
end = datetime.combine(due_date, datetime.max.time())
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start, Task.due_date <= end)
|
||||
else:
|
||||
start = datetime.combine(date_from, datetime.min.time()) if date_from else None
|
||||
end = datetime.combine(date_to, datetime.max.time()) if date_to else None
|
||||
if start and end and start > end:
|
||||
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
|
||||
if start is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start)
|
||||
if end is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date <= end)
|
||||
query = query.order_by(desc(Task.created_at))
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
@@ -64,10 +82,10 @@ async def update_task(
|
||||
if field == "tags":
|
||||
setattr(task, field, json.dumps(value))
|
||||
elif field == "status" and value == TaskStatus.DONE:
|
||||
from datetime import UTC, datetime
|
||||
task.completed_at = datetime.now(UTC)
|
||||
setattr(task, field, value)
|
||||
else:
|
||||
elif field == "status":
|
||||
task.completed_at = None
|
||||
setattr(task, field, value)
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from datetime import date
|
||||
from app.database import get_db
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
@@ -52,7 +53,7 @@ async def create_todo(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date=date.today().isoformat(),
|
||||
todo_date=(data.todo_date or date.today()).isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
@@ -74,14 +75,11 @@ async def update_todo(
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
|
||||
# 历史日期不允许修改
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可修改")
|
||||
|
||||
if data.title is not None:
|
||||
todo.title = data.title
|
||||
if data.todo_date is not None:
|
||||
todo.todo_date = data.todo_date.isoformat()
|
||||
if data.is_completed is not None:
|
||||
from datetime import UTC, datetime
|
||||
todo.is_completed = data.is_completed
|
||||
todo.completed_at = datetime.now(UTC) if data.is_completed else None
|
||||
|
||||
@@ -102,9 +100,6 @@ async def delete_todo(
|
||||
todo = result.scalar_one_or_none()
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可删除")
|
||||
|
||||
await db.delete(todo)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class AgentConfigUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
system_prompt: str | None = None
|
||||
enabled: bool | None = None
|
||||
selected_skill_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AgentConfigOut(BaseModel):
|
||||
@@ -51,5 +52,6 @@ class AgentConfigOut(BaseModel):
|
||||
system_prompt: str
|
||||
enabled: bool
|
||||
is_active: bool
|
||||
selected_skill_ids: list[str]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
35
backend/app/schemas/goal.py
Normal file
35
backend/app/schemas/goal.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.goal import GoalStatus
|
||||
|
||||
|
||||
class GoalCreate(BaseModel):
|
||||
title: str
|
||||
goal_date: date
|
||||
note: str | None = None
|
||||
status: GoalStatus = GoalStatus.ACTIVE
|
||||
|
||||
|
||||
class GoalUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
goal_date: date | None = None
|
||||
note: str | None = None
|
||||
status: GoalStatus | None = None
|
||||
|
||||
|
||||
class GoalOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
goal_date: str
|
||||
status: GoalStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class GoalListOut(BaseModel):
|
||||
items: list[GoalOut]
|
||||
40
backend/app/schemas/reminder.py
Normal file
40
backend/app/schemas/reminder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.reminder import ReminderStatus
|
||||
|
||||
|
||||
class ReminderCreate(BaseModel):
|
||||
title: str
|
||||
reminder_at: datetime
|
||||
note: str | None = None
|
||||
|
||||
|
||||
class ReminderUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
reminder_at: datetime | None = None
|
||||
note: str | None = None
|
||||
status: ReminderStatus | None = None
|
||||
is_dismissed: bool | None = None
|
||||
|
||||
|
||||
class ReminderOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
reminder_at: datetime
|
||||
status: ReminderStatus
|
||||
is_dismissed: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ReminderListOut(BaseModel):
|
||||
items: list[ReminderOut]
|
||||
|
||||
|
||||
class ReminderDateQuery(BaseModel):
|
||||
date: date
|
||||
33
backend/app/schemas/schedule_center.py
Normal file
33
backend/app/schemas/schedule_center.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas.goal import GoalOut
|
||||
from app.schemas.reminder import ReminderOut
|
||||
from app.schemas.task import TaskOut
|
||||
from app.schemas.todo import TodoOut
|
||||
|
||||
|
||||
class ScheduleCenterDaySummary(BaseModel):
|
||||
date: str
|
||||
todo_total: int
|
||||
todo_completed: int
|
||||
task_due_total: int
|
||||
high_priority_total: int
|
||||
reminder_total: int
|
||||
goal_total: int
|
||||
|
||||
|
||||
class ScheduleCenterMonthOut(BaseModel):
|
||||
month: str
|
||||
days: list[ScheduleCenterDaySummary]
|
||||
|
||||
|
||||
class ScheduleCenterDateOut(BaseModel):
|
||||
date: str
|
||||
todos: list[TodoOut]
|
||||
tasks: list[TaskOut]
|
||||
reminders: list[ReminderOut]
|
||||
goals: list[GoalOut]
|
||||
summary: ScheduleCenterDaySummary
|
||||
generated_at: datetime
|
||||
@@ -10,7 +10,8 @@ LLMType = Literal["chat", "vlm", "embedding", "rerank"]
|
||||
# 单个模型配置
|
||||
class LLMModelConfig(BaseModel):
|
||||
name: str = "" # 模型名称/别名,用于标识
|
||||
provider: LLMProviderType = "openai"
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str = ""
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
@@ -52,7 +53,8 @@ class SettingsOut(BaseModel):
|
||||
# 测试 LLM 连接请求
|
||||
class LLMTestIn(BaseModel):
|
||||
type: LLMType
|
||||
provider: LLMProviderType
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,7 +7,7 @@ class SkillCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
instructions: str
|
||||
agent_type: str # master/planner/executor/librarian/analyst
|
||||
agent_type: str # master/schedule_planner/executor/librarian/analyst
|
||||
tools: list[str] = []
|
||||
required_context: list[str] = []
|
||||
output_format: Optional[str] = None
|
||||
@@ -39,10 +40,11 @@ class SkillOut(BaseModel):
|
||||
required_context: list[str]
|
||||
output_format: Optional[str]
|
||||
visibility: str
|
||||
is_builtin: bool
|
||||
team_id: Optional[str]
|
||||
is_active: bool
|
||||
owner_id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.todo import TodoSource
|
||||
|
||||
|
||||
class TodoCreate(BaseModel):
|
||||
title: str
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
is_completed: bool | None = None
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoOut(BaseModel):
|
||||
|
||||
@@ -2,10 +2,87 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
BUILTIN_SKILLS = [
|
||||
{
|
||||
'name': '今日重点拆解',
|
||||
'description': '帮助日程规划师从上下文中提炼今天最值得推进的事项。',
|
||||
'instructions': '优先识别今天最关键的 1-3 个重点,说明原因,并给出可执行顺序。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '周计划编排',
|
||||
'description': '把本周目标整理成可落地的节奏与时间块。',
|
||||
'instructions': '将目标拆成周内节奏安排,明确先后顺序、时间块与缓冲。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '时间冲突分析',
|
||||
'description': '识别任务、日程与优先级之间的冲突。',
|
||||
'instructions': '分析冲突来源、影响和推荐取舍,必要时给出替代方案。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '任务执行 SOP',
|
||||
'description': '为执行角色提供标准执行步骤和结果回报格式。',
|
||||
'instructions': '执行前先确认目标与边界,执行中记录关键动作,执行后输出结果、风险与下一步。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['shell', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '外部交互推进',
|
||||
'description': '支持论坛、外部接口或内容发布类动作。',
|
||||
'instructions': '围绕外部交互任务,优先保证动作完整、结果清晰、反馈及时。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['api_calls', 'git'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '知识检索摘要',
|
||||
'description': '从知识中枢中提炼与当前问题最相关的信息。',
|
||||
'instructions': '检索后只保留当前决策需要的内容,输出摘要、来源与缺口。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['web_search', 'database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '图谱沉淀策略',
|
||||
'description': '帮助知识管理员把零散信息沉淀为结构化关系。',
|
||||
'instructions': '识别应沉淀的实体、关系与后续可检索维度。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '风险识别模板',
|
||||
'description': '帮助分析师快速识别当前推进中的风险点。',
|
||||
'instructions': '从进度、依赖、资源与外部信号中提炼风险,并按严重度排序。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '趋势洞察模板',
|
||||
'description': '把多源状态汇总为趋势与判断。',
|
||||
'instructions': '对比近期变化,输出趋势、证据、判断与建议动作。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'code_execution'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _is_bootstrap_enabled(settings) -> bool:
|
||||
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
|
||||
|
||||
@@ -58,3 +135,49 @@ async def ensure_admin_user(db: AsyncSession, settings) -> None:
|
||||
return
|
||||
raise
|
||||
await db.refresh(admin_user)
|
||||
|
||||
|
||||
async def ensure_builtin_skills(db: AsyncSession, preferred_owner_id: str | None = None) -> None:
|
||||
owner = None
|
||||
if preferred_owner_id:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.id == preferred_owner_id, User.is_active == True)
|
||||
)
|
||||
owner = owner_result.scalar_one_or_none()
|
||||
|
||||
if not owner:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.is_active == True).order_by(User.is_superuser.desc(), User.created_at.asc())
|
||||
)
|
||||
owner = owner_result.scalars().first()
|
||||
|
||||
if not owner:
|
||||
return
|
||||
|
||||
existing_result = await db.execute(select(Skill.name))
|
||||
existing_names = set(existing_result.scalars().all())
|
||||
|
||||
missing_skills = [
|
||||
Skill(
|
||||
owner_id=owner.id,
|
||||
name=item['name'],
|
||||
description=item['description'],
|
||||
instructions=item['instructions'],
|
||||
agent_type=item['agent_type'],
|
||||
tools=item['tools'],
|
||||
required_context=[],
|
||||
output_format=None,
|
||||
visibility=item['visibility'],
|
||||
is_builtin=True,
|
||||
team_id=None,
|
||||
is_active=True,
|
||||
)
|
||||
for item in BUILTIN_SKILLS
|
||||
if item['name'] not in existing_names
|
||||
]
|
||||
|
||||
if not missing_skills:
|
||||
return
|
||||
|
||||
db.add_all(missing_skills)
|
||||
await db.commit()
|
||||
|
||||
@@ -5,18 +5,17 @@ Jarvis Agent 服务层
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, AsyncGenerator
|
||||
import asyncio
|
||||
from openai import BadRequestError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
|
||||
from app.database import async_session
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
@@ -24,43 +23,35 @@ from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
||||
from app.agents.tools.time_reasoning import extract_reference_datetime
|
||||
from app.agents.state import initial_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||
error_text = str(error).lower()
|
||||
markers = [
|
||||
"invalid chat setting",
|
||||
"invalid params",
|
||||
"stream",
|
||||
"streaming",
|
||||
"unsupported",
|
||||
"bad_request_error",
|
||||
"http 400",
|
||||
"error code: 400",
|
||||
]
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
# 默认使用 OpenAI
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
if isinstance(error, BadRequestError):
|
||||
return (
|
||||
getattr(capabilities, "provider", None) not in {"openai", "claude"}
|
||||
and any(marker in error_text for marker in markers)
|
||||
)
|
||||
|
||||
return any(marker in error_text for marker in markers)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""对话 Agent 服务"""
|
||||
@@ -101,27 +92,18 @@ class AgentService:
|
||||
|
||||
llm_config = user.llm_config
|
||||
|
||||
# 如果指定了模型名称,查找对应的配置
|
||||
if model_name:
|
||||
for model_type in ["chat", "vlm"]:
|
||||
models = llm_config.get(model_type, [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
# 没找到,返回 None 让调用方知道配置不存在
|
||||
models = llm_config.get("chat", [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
return None
|
||||
|
||||
# 如果没指定模型名,返回默认启用的 chat 模型
|
||||
chat_models = llm_config.get("chat", [])
|
||||
for m in chat_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
vlm_models = llm_config.get("vlm", [])
|
||||
for m in vlm_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
return None
|
||||
|
||||
async def chat(
|
||||
@@ -134,11 +116,26 @@ class AgentService:
|
||||
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
|
||||
"""
|
||||
处理对话请求(流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_stream)
|
||||
"""
|
||||
# 获取或创建对话
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if model_name and not user_llm_config:
|
||||
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
logger.info(
|
||||
"agent_chat_started",
|
||||
extra={
|
||||
"details": {
|
||||
"mode": "stream",
|
||||
"requested_model_name": model_name,
|
||||
"resolved_model_name": model_name_used,
|
||||
"message_length": len(message or ""),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
@@ -156,7 +153,6 @@ class AgentService:
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
@@ -168,7 +164,6 @@ class AgentService:
|
||||
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -193,156 +188,133 @@ class AgentService:
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 预创建助手消息(后续更新内容)
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content="",
|
||||
model=model_name_used or "jarvis",
|
||||
attachments=None,
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
def _build_current_datetime_context() -> str:
|
||||
now_utc = datetime.now(UTC)
|
||||
return (
|
||||
"【当前时间】\n"
|
||||
f"- current_time_utc: {now_utc.isoformat()}\n"
|
||||
f"- current_date_utc: {now_utc.date().isoformat()}\n"
|
||||
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
|
||||
)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
async def run_agent():
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"current_sub_commander": None,
|
||||
"active_sub_commanders": [],
|
||||
"sub_commander_trace": [],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
current_datetime_context = _build_current_datetime_context()
|
||||
|
||||
# 使用 initial_state 构建状态
|
||||
state = initial_state(user_id, conversation_id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=full_message)],
|
||||
"memory_context": memory_ctx,
|
||||
"current_datetime_context": current_datetime_context,
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
})
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
|
||||
collected = ""
|
||||
async for event in graph.astream_events(langgraph_state, version="v2"):
|
||||
kind = event.get("event")
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
try:
|
||||
async for event in graph.astream_events(state, version="v2"):
|
||||
kind = event.get("event")
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
|
||||
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"planner": ("planning", "Jarvis 正在拆解步骤"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map[event_name]
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
elif kind == "on_tool_start":
|
||||
tool_input = data.get("input")
|
||||
step = None
|
||||
if isinstance(tool_input, dict) and tool_input:
|
||||
step = f"调用工具 {event_name}"
|
||||
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
|
||||
elif kind == "on_tool_end":
|
||||
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
|
||||
elif kind == "on_chain_end" and event_name == "planner":
|
||||
output = data.get("output") or {}
|
||||
plan_steps = output.get("plan_steps") or []
|
||||
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
|
||||
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chat_model_end" and not collected:
|
||||
output = data.get("output")
|
||||
content = getattr(output, "content", "") if output else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected = content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
|
||||
except Exception as e:
|
||||
fallback = f"抱歉,发生错误: {str(e)}"
|
||||
collected = fallback
|
||||
yield {"type": "error", "error": str(e)}
|
||||
yield {"type": "chunk", "content": fallback}
|
||||
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
|
||||
elif kind == "on_tool_start":
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"Jarvis 正在调用工具 {event_name}",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=f"正在执行 {event_name}",
|
||||
)
|
||||
|
||||
elif kind == "on_tool_end":
|
||||
tool_result = data.get("output")
|
||||
step = f"已完成 {event_name}"
|
||||
if isinstance(tool_result, str) and len(tool_result) > 0:
|
||||
step = tool_result[:100]
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"工具 {event_name} 已完成",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=step,
|
||||
)
|
||||
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
|
||||
elif kind == "on_chain_end" and event_name == "create_agent_graph":
|
||||
# 最终输出通常在这里
|
||||
output = data.get("output")
|
||||
if isinstance(output, dict) and "final_response" in output:
|
||||
final_resp = output["final_response"]
|
||||
# 如果还没流式输出完整,补全它
|
||||
if final_resp and not collected:
|
||||
collected = final_resp
|
||||
yield {"type": "chunk", "content": collected}
|
||||
|
||||
except Exception as e:
|
||||
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
|
||||
try:
|
||||
result_state = await graph.ainvoke(state)
|
||||
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
collected = str(fallback_content)
|
||||
yield {"type": "chunk", "content": collected}
|
||||
except Exception as fallback_error:
|
||||
logger.exception("llm_sync_fallback_failed")
|
||||
yield {"type": "error", "error": "模型服务暂不可用。"}
|
||||
else:
|
||||
logger.exception("agent_streaming_failed")
|
||||
yield {"type": "error", "error": str(e)}
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id))
|
||||
|
||||
# 最终更新数据库中的消息内容
|
||||
if collected:
|
||||
try:
|
||||
result2 = await self.db.execute(
|
||||
select(Message).where(Message.id == assistant_msg.id)
|
||||
)
|
||||
msg = result2.scalar_one_or_none()
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await self.db.commit()
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=collected[:500],
|
||||
raw_excerpt=collected[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
async with async_session() as session:
|
||||
result2 = await session.execute(select(Message).where(Message.id == assistant_msg.id))
|
||||
msg = result2.scalar_one_or_none()
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("save_assistant_message_failed")
|
||||
|
||||
return conversation_id, assistant_msg.id, run_agent()
|
||||
|
||||
@@ -355,117 +327,44 @@ class AgentService:
|
||||
model_name: str | None = None,
|
||||
) -> tuple[str, str, str, str | None]:
|
||||
"""
|
||||
简单同步版对话(无流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_content, model_name_used)
|
||||
简单同步版对话
|
||||
"""
|
||||
# 获取或创建对话
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
else:
|
||||
conv = None
|
||||
|
||||
if not conv:
|
||||
conv = Conversation(user_id=user_id, title=message[:50])
|
||||
self.db.add(conv)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(conv)
|
||||
conversation_id = conv.id
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
doc_svc = DocumentService(self.db)
|
||||
for file_id in file_ids:
|
||||
content = await doc_svc.get_document_content(user_id, file_id)
|
||||
if content:
|
||||
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
|
||||
|
||||
# 将文件上下文添加到消息
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message,
|
||||
attachments=[{"file_ids": file_ids}] if file_ids else None,
|
||||
)
|
||||
self.db.add(user_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="User message",
|
||||
content_summary=message[:500],
|
||||
raw_excerpt=message[:2000],
|
||||
metadata_={"role": "user"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
# 获取用户配置的 LLM
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
set_current_user(user_id)
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
|
||||
}
|
||||
if not conversation_id:
|
||||
conv = Conversation(user_id=user_id, title=message[:50])
|
||||
self.db.add(conv)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(conv)
|
||||
conversation_id = conv.id
|
||||
|
||||
user_msg = Message(conversation_id=conversation_id, role="user", content=message)
|
||||
self.db.add(user_msg)
|
||||
|
||||
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
|
||||
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
result_state = await graph.ainvoke(langgraph_state)
|
||||
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
|
||||
graph = get_agent_graph()
|
||||
state = initial_state(user_id, conversation_id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"memory_context": memory_ctx,
|
||||
"current_datetime_context": datetime.now(UTC).isoformat(),
|
||||
"user_llm_config": user_llm_config,
|
||||
})
|
||||
|
||||
result_state = await graph.ainvoke(state)
|
||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
except Exception as e:
|
||||
response_content = f"抱歉,发生错误: {str(e)}"
|
||||
logger.exception("agent_chat_simple_failed")
|
||||
response_content = "抱歉,发生错误。"
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 保存助手消息
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
@@ -474,19 +373,5 @@ class AgentService:
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=response_content[:500],
|
||||
raw_excerpt=response_content[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
return conversation_id, assistant_msg.id, response_content, model_name_used
|
||||
|
||||
@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Literal
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
@@ -16,8 +17,131 @@ from app.models.user import User
|
||||
import httpx
|
||||
import os
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
|
||||
|
||||
ToolStrategy = Literal["native", "json_fallback"]
|
||||
|
||||
|
||||
def _resolve_effective_base_url(config: dict | None) -> str:
|
||||
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
base_url = str((config or {}).get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
if provider in {"openai", "custom", "deepseek"}:
|
||||
return settings.OPENAI_BASE_URL
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_BASE_URL
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderCapabilities:
|
||||
provider: str
|
||||
supports_native_tools: bool
|
||||
preferred_tool_strategy: ToolStrategy
|
||||
|
||||
|
||||
def default_provider_capabilities() -> ProviderCapabilities:
|
||||
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
|
||||
|
||||
|
||||
def normalize_provider_name(config: dict | None) -> str:
|
||||
provider_raw = str((config or {}).get("provider") or "").strip().lower()
|
||||
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
model = str((config or {}).get("model") or "").strip().lower()
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
|
||||
# base_url-first inference (provider may be omitted in user config)
|
||||
if base_url:
|
||||
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
|
||||
return "ollama"
|
||||
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
|
||||
return "claude"
|
||||
if "api.deepseek.com" in base_url:
|
||||
return "deepseek"
|
||||
|
||||
# Many "openai-compatible" endpoints are configured as provider=openai.
|
||||
# We treat them as distinct providers so capability routing can stay conservative.
|
||||
if provider in {"openai", "custom"}:
|
||||
if any(key in model or key in base_url for key in {"minimax", "abab"}):
|
||||
return "minimax"
|
||||
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
|
||||
return "kimi"
|
||||
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
|
||||
return "qwen"
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
|
||||
provider = normalize_provider_name(config)
|
||||
|
||||
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
|
||||
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
|
||||
native_tool_providers = {"openai", "deepseek", "claude"}
|
||||
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
is_official_openai = (
|
||||
provider != "openai"
|
||||
or not base_url
|
||||
or "api.openai.com" in base_url
|
||||
or "openai.azure.com" in base_url
|
||||
)
|
||||
|
||||
if provider in native_tool_providers and is_official_openai:
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=True,
|
||||
preferred_tool_strategy="native",
|
||||
)
|
||||
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_from_config(config: dict | None):
|
||||
"""根据用户模型配置创建底层 LangChain LLM 实例"""
|
||||
if not config:
|
||||
return get_llm()
|
||||
|
||||
provider = normalize_provider_name(config)
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
setattr(llm, "_jarvis_user_llm_config", config)
|
||||
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
|
||||
return llm
|
||||
|
||||
|
||||
class LLMService(ABC):
|
||||
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
|
||||
_llm_instance = OllamaService()
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM provider: {provider}")
|
||||
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
|
||||
return _llm_instance
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,154 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
Jarvis 记忆系统 (基于 Mem0)
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
from app.config import settings as _settings
|
||||
|
||||
try:
|
||||
from mem0 import Memory
|
||||
|
||||
MEM0_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEM0_AVAILABLE = False
|
||||
Memory = None
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
embedding_models = user.llm_config.get("embedding", [])
|
||||
for model in embedding_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
|
||||
"api_key": model.get("api_key")
|
||||
or _settings.EMBEDDING_API_KEY
|
||||
or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 chat 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
chat_models = user.llm_config.get("chat", [])
|
||||
for model in chat_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
|
||||
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
"""Mem0 客户端 - 按用户隔离"""
|
||||
|
||||
_instances: dict[str, Memory] = {}
|
||||
_persist_dir: str = "./data/mem0"
|
||||
|
||||
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
cache_key = user_id
|
||||
|
||||
if cache_key not in self._instances:
|
||||
self._instances[cache_key] = await self._init_memory(db, user_id)
|
||||
|
||||
return self._instances[cache_key]
|
||||
|
||||
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
if not MEM0_AVAILABLE:
|
||||
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
|
||||
|
||||
os.makedirs(self._persist_dir, exist_ok=True)
|
||||
|
||||
llm_config = {
|
||||
"model": _settings.OPENAI_MODEL,
|
||||
"base_url": _settings.OPENAI_BASE_URL,
|
||||
"api_key": _settings.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
embed_config = _settings.EMBEDDING_MODEL
|
||||
embed_base_url = _settings.EMBEDDING_BASE_URL
|
||||
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
|
||||
|
||||
if db and user_id:
|
||||
try:
|
||||
user_chat = await _get_user_chat_config(db, user_id)
|
||||
if user_chat:
|
||||
llm_config = user_chat
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
user_embed = await _get_user_embedding_config(db, user_id)
|
||||
if user_embed:
|
||||
embed_config = user_embed["model"]
|
||||
embed_base_url = user_embed["base_url"]
|
||||
embed_api_key = user_embed["api_key"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": f"jarvis_memory_{user_id}",
|
||||
"path": self._persist_dir,
|
||||
},
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": llm_config["model"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"base_url": llm_config["base_url"],
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": embed_config,
|
||||
"api_key": embed_api_key,
|
||||
"base_url": embed_base_url,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return Memory.from_config(config)
|
||||
|
||||
|
||||
_mem0_client = Mem0Client()
|
||||
|
||||
|
||||
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
return await _mem0_client.get_memory(db, user_id)
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
|
||||
async def load_conversation_history(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
@@ -36,8 +167,7 @@ async def load_conversation_history(
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
|
||||
"""获取对话轮数(用户消息数)"""
|
||||
result = await db.execute(
|
||||
select(func.count(Message.id))
|
||||
.where(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
@@ -47,14 +177,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
SUMMARIZE_THRESHOLD = 8
|
||||
MAX_HISTORY_TURNS = 10
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
|
||||
"""判断当前对话是否需要摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
# 检查是否已有摘要覆盖到当前轮数
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -72,17 +203,21 @@ async def generate_summary(
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> str:
|
||||
"""调用 LLM 生成对话摘要"""
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages
|
||||
)
|
||||
llm = get_llm()
|
||||
"""生成对话摘要"""
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
|
||||
llm = get_llm()
|
||||
response = await llm.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"
|
||||
),
|
||||
HumanMessage(content=history_text),
|
||||
]
|
||||
)
|
||||
return response.content.strip()
|
||||
|
||||
|
||||
@@ -92,8 +227,10 @@ async def save_summary(
|
||||
conversation_id: str,
|
||||
summary_text: str,
|
||||
turn_count: int,
|
||||
) -> MemorySummary:
|
||||
"""保存对话摘要"""
|
||||
) -> Any:
|
||||
"""保存对话摘要到数据库"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
summary = MemorySummary(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -109,8 +246,10 @@ async def save_summary(
|
||||
async def get_summaries(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
) -> list[MemorySummary]:
|
||||
) -> list[Any]:
|
||||
"""获取某对话的所有历史摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -119,31 +258,7 @@ async def get_summaries(
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
# ———— 长期记忆: 基于 Mem0 ————
|
||||
|
||||
|
||||
async def extract_user_memories(
|
||||
@@ -151,55 +266,34 @@ async def extract_user_memories(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[UserMemory]:
|
||||
"""从对话中提取用户记忆并保存"""
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从对话中提取用户记忆并存储到 Mem0。
|
||||
Mem0 会自动处理:
|
||||
- 事实提取
|
||||
- 时间线追踪
|
||||
- 矛盾解决
|
||||
- 遗忘机制
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages[-10:]
|
||||
)
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
|
||||
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content=EXTRACTION_PROMPT),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
text = response.content.strip()
|
||||
if text == "无" or not text:
|
||||
return []
|
||||
|
||||
memories = []
|
||||
for line in text.split("\n"):
|
||||
parsed = _parse_fact_line(line)
|
||||
if not parsed:
|
||||
continue
|
||||
mem_type, content = parsed
|
||||
# 检查是否已有完全相同的记忆
|
||||
existing = await db.execute(
|
||||
select(UserMemory).where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.content == content,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
mem = UserMemory(
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.add(
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
|
||||
user_id=user_id,
|
||||
memory_type=mem_type,
|
||||
content=content,
|
||||
importance=5,
|
||||
source_conversation_id=conversation_id,
|
||||
metadata={
|
||||
"conversation_id": conversation_id,
|
||||
"source": "jarvis_memory",
|
||||
},
|
||||
)
|
||||
db.add(mem)
|
||||
memories.append(mem)
|
||||
|
||||
if memories:
|
||||
await db.commit()
|
||||
return memories
|
||||
return result.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 extract error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
@@ -207,41 +301,45 @@ async def recall_user_memories(
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[UserMemory]:
|
||||
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
|
||||
# 先尝试语义相似(通过 LLM 判断)
|
||||
# 降级: 直接从数据库取最近的重要记忆
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
|
||||
.limit(top_k)
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
|
||||
# 重置召回标记
|
||||
for m in memories:
|
||||
m.is_recalled = False
|
||||
await db.commit()
|
||||
|
||||
return memories
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
results = mem0.search(
|
||||
query=query,
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
return results.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
"""标记记忆已被召回使用"""
|
||||
result = await db.execute(
|
||||
select(UserMemory).where(UserMemory.id == memory_id)
|
||||
)
|
||||
mem = result.scalar_one_or_none()
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
"""
|
||||
获取用户画像。
|
||||
Mem0 的 profile API 会返回 static 和 dynamic facts。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.history(user_id=user_id)
|
||||
return {
|
||||
"memories": result.get("results", []),
|
||||
"static": [],
|
||||
"dynamic": [],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Mem0 profile error: {e}")
|
||||
return {"memories": [], "static": [], "dynamic": []}
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -254,25 +352,22 @@ async def build_memory_context(
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# 1. 用户画像(长期记忆)
|
||||
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if user_memories:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if memories:
|
||||
lines = []
|
||||
for m in user_memories:
|
||||
tag = f"[{m.memory_type}]"
|
||||
lines.append(f" {tag} {m.content}")
|
||||
await mark_memory_recalled(db, m.id)
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
for m in memories:
|
||||
memory_text = m.get("memory", m.get("text", ""))
|
||||
if memory_text:
|
||||
lines.append(f" - {memory_text}")
|
||||
if lines:
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
|
||||
# 2. 对话摘要(中期记忆)
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
# 只取最近2条
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
|
||||
# 3. 知识大脑(长期项目记忆)
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
@@ -292,7 +387,7 @@ async def try_auto_summarize(
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否需要摘要,如果需要则生成并保存。
|
||||
返回是否执行了摘要。
|
||||
同时将对话内容存入 Mem0 进行记忆提取。
|
||||
"""
|
||||
if not await should_summarize(db, conversation_id):
|
||||
return False
|
||||
@@ -306,8 +401,39 @@ async def try_auto_summarize(
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
|
||||
|
||||
# 同时提取用户记忆
|
||||
await extract_user_memories(db, user_id, conversation_id, messages)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(f"Auto summarize error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
|
||||
"""
|
||||
主动遗忘某条记忆。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.delete(memory_id, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 delete error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_memory(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
memory_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""
|
||||
更新某条记忆。Mem0 会自动处理矛盾检测。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.update(memory_id, content, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 update error: {e}")
|
||||
return False
|
||||
|
||||
@@ -99,46 +99,55 @@ async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession)
|
||||
|
||||
|
||||
async def test_llm_connection(
|
||||
provider: str,
|
||||
provider: str | None,
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str
|
||||
api_key: str,
|
||||
) -> dict:
|
||||
"""测试 LLM 连接"""
|
||||
try:
|
||||
# base_url-first: provider 可省略
|
||||
from app.services.llm_service import normalize_provider_name
|
||||
|
||||
effective_provider = normalize_provider_name({
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"base_url": base_url,
|
||||
})
|
||||
|
||||
# 根据不同 provider 创建临时 LLM 实例并测试
|
||||
if provider == "openai":
|
||||
if effective_provider in {"openai", "custom", "minimax", "kimi", "qwen"}:
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "claude":
|
||||
elif effective_provider == "claude":
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
elif effective_provider == "ollama":
|
||||
from langchain_ollama import ChatOllama
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
elif effective_provider == "deepseek":
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or "https://api.deepseek.com/v1",
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
return {"success": False, "error": f"不支持的 provider: {provider}"}
|
||||
return {"success": False, "error": f"不支持的 endpoint/provider: {effective_provider}"}
|
||||
|
||||
# 简单测试调用
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -50,28 +50,22 @@ class SkillService:
|
||||
"""
|
||||
列出用户可访问的技能:自己的 + 市场的 + 团队的
|
||||
"""
|
||||
# 查询条件:自己的 或者 市场公开的 或者 团队的
|
||||
conditions = [
|
||||
access_scope = or_(
|
||||
Skill.owner_id == user_id,
|
||||
Skill.visibility == "market",
|
||||
Skill.team_id == user_id,
|
||||
]
|
||||
|
||||
# 如果提供了 agent_type 过滤
|
||||
if agent_type:
|
||||
conditions.append(Skill.agent_type == agent_type)
|
||||
|
||||
# 如果提供了 visibility 过滤
|
||||
if visibility:
|
||||
conditions.append(Skill.visibility == visibility)
|
||||
|
||||
query = select(Skill).where(
|
||||
and_(
|
||||
or_(*conditions),
|
||||
Skill.is_active == True
|
||||
)
|
||||
)
|
||||
|
||||
filters = [access_scope, Skill.is_active == True]
|
||||
|
||||
if agent_type:
|
||||
filters.append(Skill.agent_type == agent_type)
|
||||
|
||||
if visibility:
|
||||
filters.append(Skill.visibility == visibility)
|
||||
|
||||
query = select(Skill).where(and_(*filters))
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from datetime import datetime, UTC
|
||||
from time import monotonic
|
||||
import platform
|
||||
import socket
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import psutil
|
||||
@@ -7,21 +11,119 @@ except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fa
|
||||
|
||||
|
||||
class SystemService:
|
||||
_last_net_bytes_sent: int | None = None
|
||||
_last_net_bytes_recv: int | None = None
|
||||
_last_net_sample_at: float | None = None
|
||||
|
||||
def _get_network_rates(self) -> tuple[float, float]:
|
||||
counters = psutil.net_io_counters()
|
||||
now = monotonic()
|
||||
|
||||
if (
|
||||
self.__class__._last_net_sample_at is None
|
||||
or self.__class__._last_net_bytes_sent is None
|
||||
or self.__class__._last_net_bytes_recv is None
|
||||
):
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
return 0.0, 0.0
|
||||
|
||||
elapsed = max(now - self.__class__._last_net_sample_at, 1e-6)
|
||||
upload_bps = max(counters.bytes_sent - self.__class__._last_net_bytes_sent, 0) / elapsed
|
||||
download_bps = max(counters.bytes_recv - self.__class__._last_net_bytes_recv, 0) / elapsed
|
||||
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
|
||||
return round(upload_bps, 1), round(download_bps, 1)
|
||||
|
||||
def _get_gpu_status(self) -> dict:
|
||||
empty = {
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
}
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
'nvidia-smi',
|
||||
'--query-gpu=name,memory.total,memory.used,utilization.gpu',
|
||||
'--format=csv,noheader,nounits',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
timeout=2,
|
||||
check=False,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.SubprocessError, OSError):
|
||||
return empty
|
||||
|
||||
if result.returncode != 0 or not result.stdout.strip():
|
||||
return empty
|
||||
|
||||
first_line = result.stdout.strip().splitlines()[0]
|
||||
parts = [part.strip() for part in first_line.split(',')]
|
||||
if len(parts) < 4:
|
||||
return empty
|
||||
|
||||
def parse_number(value: str) -> float | None:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return {
|
||||
'gpu_name': parts[0] or None,
|
||||
'gpu_memory_total_mb': parse_number(parts[1]),
|
||||
'gpu_memory_used_mb': parse_number(parts[2]),
|
||||
'gpu_util_percent': parse_number(parts[3]),
|
||||
}
|
||||
|
||||
def get_status(self) -> dict:
|
||||
if psutil is None:
|
||||
return {
|
||||
'cpu_percent': 0.0,
|
||||
'memory_percent': 0.0,
|
||||
'disk_percent': 0.0,
|
||||
'disk_used_gb': 0.0,
|
||||
'disk_total_gb': 0.0,
|
||||
'network_upload_bps': 0.0,
|
||||
'network_download_bps': 0.0,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': 0.0,
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
upload_bps, download_bps = self._get_network_rates()
|
||||
gpu_status = self._get_gpu_status()
|
||||
boot_time = psutil.boot_time()
|
||||
now_ts = datetime.now(UTC).timestamp()
|
||||
return {
|
||||
'cpu_percent': round(cpu_percent, 1),
|
||||
'memory_percent': round(memory.percent, 1),
|
||||
'disk_percent': round(disk.percent, 1),
|
||||
'disk_used_gb': round(disk.used / (1024 ** 3), 1),
|
||||
'disk_total_gb': round(disk.total / (1024 ** 3), 1),
|
||||
'network_upload_bps': upload_bps,
|
||||
'network_download_bps': download_bps,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': round(max(now_ts - boot_time, 0.0), 1),
|
||||
**gpu_status,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
124
backend/app/services/web_search_service.py
Normal file
124
backend/app/services/web_search_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WebSearchResult:
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
source: str | None = None
|
||||
published_at: str | None = None
|
||||
|
||||
|
||||
class WebSearchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchConfigurationError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchRequestError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enabled: bool | None = None,
|
||||
provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_limit: int | None = None,
|
||||
timeout_seconds: int | None = None,
|
||||
auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
|
||||
auth_token: str | None = None,
|
||||
basic_user: str | None = None,
|
||||
basic_password: str | None = None,
|
||||
):
|
||||
self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
|
||||
self.provider = (provider or settings.WEB_SEARCH_PROVIDER).strip().lower()
|
||||
self.base_url = (base_url or settings.SEARXNG_BASE_URL).strip().rstrip('/')
|
||||
self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
|
||||
self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
|
||||
self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
|
||||
self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
|
||||
self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
|
||||
self.basic_password = basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
|
||||
|
||||
async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
|
||||
normalized_query = (query or '').strip()
|
||||
if not self.enabled or not self.base_url:
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
if self.provider != 'searxng':
|
||||
raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
|
||||
if not normalized_query:
|
||||
raise WebSearchRequestError('搜索关键词不能为空')
|
||||
|
||||
parsed = urlparse(self.base_url)
|
||||
if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
|
||||
raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')
|
||||
|
||||
params = {
|
||||
'q': normalized_query,
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
}
|
||||
headers = self._build_headers()
|
||||
timeout = httpx.Timeout(float(self.timeout_seconds), connect=min(float(self.timeout_seconds), 5.0))
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(f'{self.base_url}/search', params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except httpx.HTTPError as exc:
|
||||
raise WebSearchRequestError('SearxNG 请求失败') from exc
|
||||
except ValueError as exc:
|
||||
raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc
|
||||
|
||||
raw_results = payload.get('results') if isinstance(payload, dict) else None
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
target_limit = max(1, min(limit or self.default_limit, 10))
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
title = str(item.get('title') or '').strip()
|
||||
url = str(item.get('url') or '').strip()
|
||||
snippet = str(item.get('content') or item.get('snippet') or '').strip()
|
||||
if not title or not url:
|
||||
continue
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet,
|
||||
source=str(item.get('engine') or item.get('source') or '').strip() or None,
|
||||
published_at=str(item.get('publishedDate') or item.get('published_at') or '').strip() or None,
|
||||
)
|
||||
)
|
||||
if len(results) >= target_limit:
|
||||
break
|
||||
return results
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
if self.auth_type == 'bearer' and self.auth_token:
|
||||
return {'Authorization': f'Bearer {self.auth_token}'}
|
||||
if self.auth_type == 'basic' and self.basic_user and self.basic_password:
|
||||
credentials = httpx.BasicAuth(self.basic_user, self.basic_password)
|
||||
request = httpx.Request('GET', self.base_url)
|
||||
credentials.auth_flow(request)
|
||||
return dict(request.headers)
|
||||
return {}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
|
||||
bad
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -27,6 +27,9 @@ dependencies = [
|
||||
"llama-index-vector-stores-chroma>=0.3.0",
|
||||
"chromadb>=0.5.0",
|
||||
|
||||
# Memory
|
||||
"mem0ai>=1.0.0",
|
||||
|
||||
# 数据库
|
||||
"sqlalchemy>=2.0.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
|
||||
@@ -1,9 +1,122 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.agents.graph import master_node
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.graph import (
|
||||
_choose_sub_commander,
|
||||
_parse_json_action,
|
||||
_route_agent_from_user_query,
|
||||
_run_sub_commander,
|
||||
master_node,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
from app.agents.state import AgentRole
|
||||
|
||||
|
||||
|
||||
|
||||
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
||||
return {
|
||||
'messages': [HumanMessage(content=message)],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'current_sub_commander': None,
|
||||
'active_sub_commanders': [],
|
||||
'sub_commander_trace': [],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'action_results': [],
|
||||
'created_entities': [],
|
||||
'tool_strategy_used': None,
|
||||
'provider_capabilities': None,
|
||||
'fallback_parse_error': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'schedule_context_summary': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'current_datetime_context': 'CURRENT_TIME: 2026-03-28T12:00:00+08:00',
|
||||
'current_datetime_reference': {'current_time_iso': '2026-03-28T12:00:00+08:00', 'current_date_iso': '2026-03-28', 'timezone': 'UTC'},
|
||||
'user_llm_config': user_llm_config,
|
||||
}
|
||||
|
||||
|
||||
class FakeFallbackLLM:
|
||||
def __init__(self, first_content: str, followup_content: str = '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'):
|
||||
self.first_content = first_content
|
||||
self.followup_content = followup_content
|
||||
self.calls = 0
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return AIMessage(content=self.first_content)
|
||||
return AIMessage(content=self.followup_content)
|
||||
|
||||
def bind_tools(self, tools):
|
||||
raise AssertionError('bind_tools should not be called in JSON fallback mode')
|
||||
|
||||
|
||||
class FakeNativeBoundLLM:
|
||||
async def ainvoke(self, messages):
|
||||
return AIMessage(
|
||||
content='',
|
||||
tool_calls=[
|
||||
{
|
||||
'id': 'call_1',
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FakeNativeLLM:
|
||||
def __init__(self):
|
||||
self.bound = FakeNativeBoundLLM()
|
||||
self.tool_binding_count = 0
|
||||
self.calls = 0
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(provider='openai', supports_native_tools=True, preferred_tool_strategy='native')
|
||||
|
||||
def bind_tools(self, tools):
|
||||
self.tool_binding_count += 1
|
||||
return self.bound
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
return AIMessage(content='已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。')
|
||||
|
||||
|
||||
class FakeTool:
|
||||
def __init__(self, name: str, result: str):
|
||||
self.name = name
|
||||
self.result = result
|
||||
self.invocations: list[dict] = []
|
||||
|
||||
def invoke(self, args: dict):
|
||||
self.invocations.append(args)
|
||||
return self.result
|
||||
|
||||
|
||||
class CapturingLLM:
|
||||
def __init__(self, content: str = '{"mode":"final","final_response":"好的。"}'):
|
||||
self.content = content
|
||||
self.messages = None
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(provider='ollama', supports_native_tools=False, preferred_tool_strategy='json_fallback')
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.messages = messages
|
||||
return AIMessage(content=self.content)
|
||||
|
||||
|
||||
class FailIfCalledLLM:
|
||||
async def ainvoke(self, messages):
|
||||
raise AssertionError('LLM should not be called for simple greetings')
|
||||
@@ -71,6 +184,68 @@ async def test_master_node_returns_stable_reply_for_identity_question(monkeypatc
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_identity_question_with_punctuation(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
state = {
|
||||
'messages': [HumanMessage(content='你是谁?')],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'user_llm_config': None,
|
||||
}
|
||||
|
||||
result = await master_node(state)
|
||||
|
||||
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_identity_question_with_particle(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
state = {
|
||||
'messages': [HumanMessage(content='你是谁啊')],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'user_llm_config': None,
|
||||
}
|
||||
|
||||
result = await master_node(state)
|
||||
|
||||
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_capability_question(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
@@ -100,3 +275,196 @@ async def test_master_node_returns_stable_reply_for_capability_question(monkeypa
|
||||
assert result['final_response'] == '主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_schedule_requests_to_schedule_planning():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '帮我安排一下这周计划') == 'schedule_planning'
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_focus_requests_to_schedule_analysis():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '基于最近对话帮我判断该聚焦什么') == 'schedule_analysis'
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_knowledge_requests_to_librarian():
|
||||
assert _route_agent_from_user_query('帮我搜索知识库里的项目资料') == AgentRole.LIBRARIAN
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_schedule_requests_to_schedule_planner():
|
||||
assert _route_agent_from_user_query('明天提醒我开会') == AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_explicit_month_day_milestone_to_schedule_planner():
|
||||
assert _route_agent_from_user_query('3月29日,对话系统交付节点') == AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_explicit_month_day_milestone_to_schedule_planning():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '3月29日,对话系统交付节点') == 'schedule_planning'
|
||||
|
||||
|
||||
|
||||
|
||||
def test_parse_json_action_extracts_tool_calls_from_fenced_json():
|
||||
parsed = _parse_json_action(
|
||||
'```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```',
|
||||
['create_reminder'],
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
'mode': 'tool_call',
|
||||
'tool_calls': [
|
||||
{
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
'reason': None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_parse_json_action_returns_none_for_invalid_or_unknown_payload():
|
||||
assert _parse_json_action('not json', ['create_reminder']) is None
|
||||
assert _parse_json_action('{"mode":"tool_call","tool_calls":[{"name":"unknown","arguments":{}}]}', ['create_reminder']) is None
|
||||
|
||||
|
||||
def test_parse_json_action_tolerates_prefix_and_suffix_text():
|
||||
parsed = _parse_json_action(
|
||||
'好的,下面是 JSON:\n```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```\n谢谢',
|
||||
['create_reminder'],
|
||||
)
|
||||
assert parsed is not None
|
||||
assert parsed['mode'] == 'tool_call'
|
||||
assert parsed['tool_calls'][0]['name'] == 'create_reminder'
|
||||
|
||||
|
||||
def test_parse_json_action_accepts_parameters_alias_for_tool_calls():
|
||||
parsed = _parse_json_action(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","parameters":{"title":"收被子","reminder_at":"2026-03-29T09:00:00+08:00"}}]}',
|
||||
['create_reminder'],
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
'mode': 'tool_call',
|
||||
'tool_calls': [
|
||||
{
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '收被子', 'reminder_at': '2026-03-29T09:00:00+08:00'},
|
||||
'reason': None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_run_sub_commander_uses_json_fallback_for_non_native_provider(monkeypatch):
|
||||
fake_llm = FakeFallbackLLM(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}'
|
||||
)
|
||||
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'schedule_planning',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'json_fallback'
|
||||
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
|
||||
assert result['tool_calls'][0]['name'] == 'create_reminder'
|
||||
assert result['created_entities'][0]['type'] == 'reminder'
|
||||
assert result['fallback_parse_error'] is None
|
||||
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'
|
||||
|
||||
|
||||
async def test_run_sub_commander_includes_current_datetime_context_in_system_messages(monkeypatch):
|
||||
fake_llm = CapturingLLM('{"mode":"final","final_response":"好的。"}')
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
state['current_datetime_context'] = 'CURRENT_TIME: 2026-03-28T12:00:00+08:00'
|
||||
|
||||
await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert fake_llm.messages is not None
|
||||
assert any(
|
||||
getattr(m, 'type', None) == 'system' and 'CURRENT_TIME:' in str(getattr(m, 'content', ''))
|
||||
for m in fake_llm.messages
|
||||
)
|
||||
|
||||
|
||||
async def test_run_sub_commander_uses_web_search_in_json_fallback(monkeypatch):
|
||||
fake_llm = FakeFallbackLLM(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"web_search","arguments":{"query":"Jarvis 最新模型更新","top_k":2}}]}',
|
||||
'我查了外部网页,下面是最新结果摘要。',
|
||||
)
|
||||
fake_tool = FakeTool('web_search', '成功搜索到 2 条网页结果')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'librarian_retrieval',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('帮我上网查一下 Jarvis 最新模型更新', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.LIBRARIAN
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.LIBRARIAN,
|
||||
'manager prompt',
|
||||
'帮我上网查一下 Jarvis 最新模型更新',
|
||||
use_tools=True,
|
||||
summary_target='knowledge_context',
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'json_fallback'
|
||||
assert fake_tool.invocations == [{'query': 'Jarvis 最新模型更新', 'top_k': 2}]
|
||||
assert result['tool_calls'][0]['name'] == 'web_search'
|
||||
assert result['last_tool_result'] == '[web_search] 成功搜索到 2 条网页结果'
|
||||
assert result['final_response'] == '我查了外部网页,下面是最新结果摘要。'
|
||||
|
||||
|
||||
fake_llm = FakeNativeLLM()
|
||||
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'schedule_planning',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'openai', 'model': 'gpt-4o'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'native'
|
||||
assert fake_llm.tool_binding_count == 1
|
||||
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
|
||||
assert result['created_entities'][0]['type'] == 'reminder'
|
||||
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'
|
||||
|
||||
49
backend/tests/backend/app/agents/test_search_tools.py
Normal file
49
backend/tests/backend/app/agents/test_search_tools.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.tools.search import web_search
|
||||
|
||||
|
||||
class FakeResult(SimpleNamespace):
|
||||
pass
|
||||
|
||||
|
||||
def test_web_search_tool_formats_results(monkeypatch):
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
assert query == 'Jarvis 最新更新'
|
||||
assert limit == 2
|
||||
return [
|
||||
FakeResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis 最新更新', top_k=2)
|
||||
|
||||
assert '[1] Jarvis release notes' in result
|
||||
assert '链接: https://example.com/jarvis-release' in result
|
||||
assert '来源: duckduckgo' in result
|
||||
assert '时间: 2026-03-29' in result
|
||||
assert '摘要: Latest Jarvis changes.' in result
|
||||
|
||||
|
||||
def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
|
||||
from app.services.web_search_service import WebSearchConfigurationError
|
||||
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis')
|
||||
|
||||
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
||||
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tool_env(tmp_path):
|
||||
db_path = tmp_path / "test_task_tools.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# 只创建本测试需要的表,避免全量 metadata 引入未注册的外键表。
|
||||
await conn.run_sync(User.metadata.create_all, tables=[
|
||||
User.__table__,
|
||||
Task.__table__,
|
||||
DailyTodo.__table__,
|
||||
Reminder.__table__,
|
||||
Goal.__table__,
|
||||
])
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username="tool_user",
|
||||
email="tool@example.com",
|
||||
hashed_password=get_password_hash("secret123"),
|
||||
full_name="Tool Tester",
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
try:
|
||||
yield {"session_factory": session_factory, "user_id": user.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_content_and_date_aliases_and_persists_task(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 0, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_task_accepts_content_and_date_aliases_and_sets_morning_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_schedule_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 9, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("priority_input", "expected"),
|
||||
[
|
||||
(1, TaskPriority.LOW),
|
||||
(2, TaskPriority.MEDIUM),
|
||||
(3, TaskPriority.HIGH),
|
||||
(4, TaskPriority.URGENT),
|
||||
("urgent", TaskPriority.URGENT),
|
||||
],
|
||||
)
|
||||
async def test_create_task_normalizes_legacy_and_string_priorities(tool_env, monkeypatch, priority_input, expected):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title=f"priority-{priority_input}", priority=priority_input)
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].priority == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_iso_datetime_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="timed task", due_date="2026-03-28T15:30:00Z")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.due_date == datetime(2026, 3, 28, 15, 30, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_missing_title_and_content(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func()
|
||||
|
||||
assert result == "创建任务失败: title 不能为空"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_invalid_priority(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="bad priority", priority="top")
|
||||
|
||||
assert "创建任务失败:" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_status_rejects_invalid_status(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
create_result = task_tools.create_task.func(title="status test")
|
||||
assert "任务创建成功" in create_result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tasks_filters_by_normalized_status_and_formats_values(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
task_tools.create_task.func(title="todo task", priority="high")
|
||||
task_tools.create_task.func(title="done task", priority="low")
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
rows[1].status = TaskStatus.DONE
|
||||
await session.commit()
|
||||
|
||||
result = task_tools.get_tasks.func(status="done")
|
||||
|
||||
assert "done task" in result
|
||||
assert "todo task" not in result
|
||||
assert "状态:done" in result
|
||||
assert "优先级:low" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_accepts_datetime_description_and_at_aliases(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
description="提醒收被子",
|
||||
datetime="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Reminder))).scalar_one()
|
||||
|
||||
assert saved.title == "收被子"
|
||||
assert saved.note == "提醒收被子"
|
||||
assert saved.reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
datetime="2026-03-29T09:00:00+08:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
time="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
remind_at="2026-03-29T18:00:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 18, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_returns_failure_when_time_aliases_missing(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(title="收被子")
|
||||
|
||||
assert result == "创建提醒失败: reminder_at 不能为空"
|
||||
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.agents.tools.time_reasoning import (
|
||||
extract_reference_datetime,
|
||||
normalize_tool_time_arguments,
|
||||
resolve_time_expression_data,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_reference_datetime_from_current_time_context():
|
||||
context = '【当前时间】\n- current_time_utc: 2026-03-28T12:00:00+00:00\n- current_date_utc: 2026-03-28\n说明:解析相对时间时请以 current_time_utc 为准。'
|
||||
|
||||
result = extract_reference_datetime(context)
|
||||
|
||||
assert result == datetime(2026, 3, 28, 12, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_datetime():
|
||||
payload = resolve_time_expression_data(
|
||||
'明天早上9点',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['grain'] == 'datetime'
|
||||
assert payload['resolved_date'] == '2026-03-29'
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00'
|
||||
assert payload['assumed_time'] is False
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_date_window():
|
||||
payload = resolve_time_expression_data(
|
||||
'下周一下午',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_date'] == '2026-03-30'
|
||||
assert payload['resolved_datetime'] == '2026-03-30T15:00:00'
|
||||
assert payload['assumed_time'] is True
|
||||
assert 'assumed_time' in payload['reason']
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_reminder_time_aliases():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['reminder_at'] == '2026-03-29T09:00:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_date_only_tools():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_goal',
|
||||
{'title': '交付节点', 'goal_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['goal_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_preserves_explicit_datetime_offset():
|
||||
payload = resolve_time_expression_data(
|
||||
'2026-03-29T09:00:00+08:00',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00+08:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_keeps_create_task_date_without_explicit_time():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_task',
|
||||
{'title': '写周报', 'due_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['due_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_raises_for_invalid_time_text():
|
||||
try:
|
||||
normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天25点'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
except ValueError as exc:
|
||||
assert 'hour must be in 0..23' in str(exc)
|
||||
else:
|
||||
raise AssertionError('expected ValueError for invalid time text')
|
||||
23
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
23
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from app.agents.tools import forum as forum_tools
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("module", "label"),
|
||||
[
|
||||
(task_tools, "task"),
|
||||
(schedule_tools, "schedule"),
|
||||
(forum_tools, "forum"),
|
||||
],
|
||||
)
|
||||
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
||||
async def sample():
|
||||
return f"ok:{label}"
|
||||
|
||||
result = module._run_async(sample())
|
||||
|
||||
assert result == f"ok:{label}"
|
||||
@@ -9,7 +9,7 @@ from starlette.datastructures import UploadFile
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.brain import BrainEvent, BrainMemory
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.user import User
|
||||
from app.services import agent_service, memory_service
|
||||
@@ -32,6 +32,110 @@ class FakeStreamingGraph:
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingFinalResponseGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chain_end",
|
||||
"name": "master",
|
||||
"data": {"output": {"final_response": "这是最终回答。"}},
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingBadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStreamingBadRequestError2(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeOpenAIBadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStreamingOpenAIBadRequestGraph:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeOpenAIBadRequestError('invalid_request_error: tool arguments failed validation')
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "不应触发同步回退。"}
|
||||
|
||||
|
||||
class FakeStreamingFallbackGraph:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeStreamingBadRequestError('invalid params, invalid chat setting (2013)')
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "这是回退后的同步回答。"}
|
||||
|
||||
|
||||
class FakeStreamingFallbackGraphGenericError:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeStreamingBadRequestError2("Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}")
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "这是通用异常回退后的同步回答。"}
|
||||
|
||||
|
||||
class FakeStreamingDelegationThenFinalResponseGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="现在显示收到,3月28日的任务是完成对话系统。\n\n我将这部分转给schedule_planner,他会根据这个目标结合你当前的进度和资源,给出具体的安排建议。")},
|
||||
}
|
||||
yield {
|
||||
"event": "on_chain_end",
|
||||
"name": "schedule_planner",
|
||||
"data": {"output": {"final_response": "今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。"}},
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingDelegationThenModelEndGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="我将这部分转给schedule_planner。")},
|
||||
}
|
||||
yield {
|
||||
"event": "on_chat_model_end",
|
||||
"name": "schedule_planner",
|
||||
"data": {"output": SimpleNamespace(content="最终建议:先完成对话系统,再回归验证。")},
|
||||
}
|
||||
|
||||
|
||||
class CapturingStateGraph:
|
||||
def __init__(self, final_response: str = '已记录你的请求。'):
|
||||
self.final_response = final_response
|
||||
self.captured_state = None
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.captured_state = state
|
||||
return {"final_response": self.final_response}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_brain_ingestion.db'
|
||||
@@ -43,6 +147,7 @@ async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='brain-ingestion-tester',
|
||||
email='brain-ingestion@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Ingestion Tester',
|
||||
@@ -178,6 +283,360 @@ async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_in
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingFinalResponseGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'直接给我最终回答。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert ''.join(chunks) == '这是最终回答。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == '这是最终回答。'
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_prefers_chain_end_final_response_over_delegation_chunk(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenFinalResponseGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我安排今天先做什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。' in chunks
|
||||
assert chunks[-1] == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_prefers_model_end_final_content_over_delegation_chunk(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenModelEndGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我安排今天先做什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert '最终建议:先完成对话系统,再回归验证。' in chunks
|
||||
assert chunks[-1] == '最终建议:先完成对话系统,再回归验证。'
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '最终建议:先完成对话系统,再回归验证。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
graph = FakeStreamingOpenAIBadRequestGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
monkeypatch.setattr(agent_service, 'BadRequestError', FakeOpenAIBadRequestError)
|
||||
|
||||
original_get_user_llm_config = AgentService._get_user_llm_config
|
||||
|
||||
async def fake_get_user_llm_config(self, user_id, model_name=None):
|
||||
return {
|
||||
'name': 'Official OpenAI',
|
||||
'provider': 'openai',
|
||||
'model': 'gpt-4o',
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'enabled': True,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(AgentService, '_get_user_llm_config', fake_get_user_llm_config)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'测试官方 OpenAI bad request 不应回退。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
errors = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
if event.get('type') == 'error':
|
||||
errors.append(event['error'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert graph.astream_calls == 1
|
||||
assert graph.ainvoke_calls == 0
|
||||
assert errors == ['模型服务暂不可用,请稍后再试。']
|
||||
assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。']
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
fallback_graph = FakeStreamingFallbackGraphGenericError()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: fallback_graph)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我制定一下明天的计划。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert fallback_graph.astream_calls == 1
|
||||
assert fallback_graph.ainvoke_calls == 1
|
||||
assert ''.join(chunks) == '这是通用异常回退后的同步回答。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].content_summary == '这是通用异常回退后的同步回答。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
|
||||
class PartialThenFailGraph:
|
||||
def __init__(self):
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version='v2'):
|
||||
yield {
|
||||
'event': 'on_chat_model_stream',
|
||||
'name': 'master',
|
||||
'data': {'chunk': SimpleNamespace(content='前半段')},
|
||||
}
|
||||
raise FakeStreamingBadRequestError('stream interrupted')
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {'final_response': '不应触发'}
|
||||
|
||||
graph = PartialThenFailGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
monkeypatch.setattr(agent_service, 'BadRequestError', FakeStreamingBadRequestError)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'测试部分流式输出失败。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
errors = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
if event.get('type') == 'error':
|
||||
errors.append(event['error'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
brain_event_result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(brain_event_result.scalars().all())
|
||||
|
||||
assert chunks == ['前半段']
|
||||
assert graph.ainvoke_calls == 0
|
||||
assert errors == ['stream interrupted']
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '前半段'
|
||||
assert events[1].content_summary == '前半段'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
graph = CapturingStateGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
service = AgentService(session)
|
||||
|
||||
await service.chat_simple(
|
||||
user.id,
|
||||
'3月29日,对话系统交付节点',
|
||||
)
|
||||
|
||||
assert graph.captured_state is not None
|
||||
current_context = graph.captured_state.get('current_datetime_context')
|
||||
assert isinstance(current_context, str)
|
||||
assert current_context
|
||||
assert '当前时间' in current_context
|
||||
assert '2026' in current_context
|
||||
|
||||
current_reference = graph.captured_state.get('current_datetime_reference')
|
||||
assert isinstance(current_reference, dict)
|
||||
assert 'current_time_iso' in current_reference
|
||||
assert 'current_date_iso' in current_reference
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_llm_config_defaults_to_enabled_chat_model_not_vlm(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
config = await service._get_user_llm_config(user.id)
|
||||
|
||||
assert config is not None
|
||||
assert config['name'] == 'Enabled Chat'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_llm_config_returns_none_when_only_vlm_is_enabled(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
config = await service._get_user_llm_config(user.id)
|
||||
|
||||
assert config is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
|
||||
await service.chat_simple(user.id, '测试聊天模型选择', model_name='Vision Only')
|
||||
|
||||
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
|
||||
message_result = await session.execute(select(Message))
|
||||
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
|
||||
|
||||
assert conversation_result.scalars().all() == []
|
||||
assert message_result.scalars().all() == []
|
||||
assert brain_event_result.scalars().all() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
|
||||
await service.chat(user.id, '测试流式聊天模型选择', model_name='Vision Only')
|
||||
|
||||
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
|
||||
message_result = await session.execute(select(Message))
|
||||
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
|
||||
|
||||
assert conversation_result.scalars().all() == []
|
||||
assert message_result.scalars().all() == []
|
||||
assert brain_event_result.scalars().all() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_builtin_skills_creates_default_ability_skills(tmp_path):
|
||||
db_path = tmp_path / 'test_builtin_skills.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='bootstrap_user',
|
||||
email='bootstrap@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Bootstrap User',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_builtin_skills(session)
|
||||
await ensure_builtin_skills(session)
|
||||
result = await session.execute(select(Skill).order_by(Skill.agent_type, Skill.name))
|
||||
skills = result.scalars().all()
|
||||
|
||||
assert len(skills) >= 9
|
||||
assert any(skill.agent_type == 'schedule_planner' for skill in skills)
|
||||
assert any(skill.agent_type == 'executor' for skill in skills)
|
||||
assert any(skill.agent_type == 'librarian' for skill in skills)
|
||||
librarian_skill = next(skill for skill in skills if skill.name == '知识检索摘要')
|
||||
assert 'web_search' in (librarian_skill.tools or [])
|
||||
assert any(skill.agent_type == 'analyst' for skill in skills)
|
||||
assert len({skill.name for skill in skills}) == len(skills)
|
||||
|
||||
await engine.dispose()
|
||||
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchResult,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, payload: dict, status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
'request failed',
|
||||
request=httpx.Request('GET', 'http://searx.example/search'),
|
||||
response=httpx.Response(self.status_code, request=httpx.Request('GET', 'http://searx.example/search')),
|
||||
)
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *, response=None, error=None, recorder=None, **kwargs):
|
||||
self._response = response
|
||||
self._error = error
|
||||
self._recorder = recorder if recorder is not None else []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, *, params=None, headers=None):
|
||||
self._recorder.append({'url': url, 'params': params, 'headers': headers})
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
return self._response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_normalized_results_from_searxng(monkeypatch):
|
||||
requests = []
|
||||
payload = {
|
||||
'results': [
|
||||
{
|
||||
'title': 'Jarvis release notes',
|
||||
'url': 'https://example.com/jarvis-release',
|
||||
'content': 'Latest Jarvis changes and release notes.',
|
||||
'engine': 'duckduckgo',
|
||||
'publishedDate': '2026-03-29',
|
||||
}
|
||||
]
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse(payload), recorder=requests, **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
default_limit=5,
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
results = await service.search('Jarvis 最新版本', limit=3)
|
||||
|
||||
assert results == [
|
||||
WebSearchResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes and release notes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
assert requests == [
|
||||
{
|
||||
'url': 'http://searx.example/search',
|
||||
'params': {
|
||||
'q': 'Jarvis 最新版本',
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
},
|
||||
'headers': {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_empty_list_when_searxng_has_no_results(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse({'results': []}), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
results = await service.search('不存在的话题')
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_on_searxng_http_failure(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(error=httpx.TimeoutException('timed out'), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchRequestError, match='SearxNG 请求失败'):
|
||||
await service.search('Jarvis')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_when_not_configured():
|
||||
service = WebSearchService(
|
||||
enabled=False,
|
||||
provider='searxng',
|
||||
base_url='',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchConfigurationError, match='网页搜索未启用或未配置'):
|
||||
await service.search('Jarvis')
|
||||
150
backend/tests/backend/app/test_agent_router.py
Normal file
150
backend/tests/backend/app/test_agent_router.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_env(tmp_path):
|
||||
db_path = tmp_path / 'test_agent_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='agent_user',
|
||||
email='agent@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Agent Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
skill_a = Skill(
|
||||
name='Planner skill A',
|
||||
description='planner',
|
||||
instructions='plan a',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
skill_b = Skill(
|
||||
name='Planner skill B',
|
||||
description='planner',
|
||||
instructions='plan b',
|
||||
agent_type='schedule_planner',
|
||||
tools=['tasks'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
session.add_all([
|
||||
Agent(
|
||||
name='SCHEDULE PLANNER',
|
||||
role='schedule_planner',
|
||||
description='日程规划师',
|
||||
system_prompt='prompt',
|
||||
is_active=True,
|
||||
),
|
||||
skill_a,
|
||||
skill_b,
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(skill_a)
|
||||
await session.refresh(skill_b)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(agent_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_config_returns_default_empty_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['selected_skill_ids'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_persists_selected_skill_ids(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id'], ids['skill_b_id']]},
|
||||
)
|
||||
get_response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
first_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id']]},
|
||||
)
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'description': 'updated description'},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['description'] == 'updated description'
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': ['missing-skill']},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()['detail'] == '存在无效的技能绑定'
|
||||
67
backend/tests/backend/app/test_config.py
Normal file
67
backend/tests/backend/app/test_config.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
from app import config as config_module
|
||||
from app.services.llm_service import default_provider_capabilities, normalize_provider_name, resolve_provider_capabilities
|
||||
|
||||
|
||||
def test_env_file_points_to_repo_root_env_file():
|
||||
assert config_module.ENV_FILE == Path(__file__).resolve().parents[4] / '.env'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_prefers_native_for_openai():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'gpt-4o', 'base_url': 'https://api.openai.com/v1'})
|
||||
|
||||
assert capabilities.provider == 'openai'
|
||||
assert capabilities.supports_native_tools is True
|
||||
assert capabilities.preferred_tool_strategy == 'native'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_openai_compatible_non_official_endpoint():
|
||||
capabilities = resolve_provider_capabilities(
|
||||
{'provider': 'openai', 'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}
|
||||
)
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_uses_global_openai_base_url_when_user_config_omits_it(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'OPENAI_BASE_URL', 'https://api.minimax.chat/v1')
|
||||
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'abab7.5-chat-preview'})
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_from_custom_config():
|
||||
assert normalize_provider_name({'provider': 'custom', 'model': 'MiniMax-M2.7-highspeed'}) == 'minimax'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_without_provider_when_base_url_matches():
|
||||
assert normalize_provider_name({'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}) == 'minimax'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_ollama():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_default_provider_capabilities_follows_global_settings(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
capabilities = default_provider_capabilities()
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_without_provider_uses_global_default(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
assert normalize_provider_name({'model': 'qwen2.5'}) == 'ollama'
|
||||
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import sys
|
||||
from datetime import UTC, date, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.task import router as task_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def schedule_env(tmp_path):
|
||||
db_path = tmp_path / 'test_schedule_center.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='schedule_user',
|
||||
email='schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Schedule Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_schedule_user',
|
||||
email='other-schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Schedule Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Legacy todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Done todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=True,
|
||||
completed_at=datetime(2026, 4, 10, 9, 30, tzinfo=UTC),
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=other_user.id,
|
||||
title='Other user todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='High priority task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 14, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='Urgent task next day',
|
||||
priority=TaskPriority.URGENT,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
due_date=datetime(2026, 4, 11, 10, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=other_user.id,
|
||||
title='Other user task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 15, 0, tzinfo=UTC),
|
||||
),
|
||||
Reminder(
|
||||
user_id=user.id,
|
||||
title='Doctor reminder',
|
||||
note='Bring reports',
|
||||
reminder_at=datetime(2026, 4, 10, 8, 0, tzinfo=UTC),
|
||||
),
|
||||
Goal(
|
||||
user_id=user.id,
|
||||
title='Launch calendar beta',
|
||||
note='Ship MVP',
|
||||
goal_date='2026-04-10',
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(todo_router)
|
||||
test_app.include_router(task_router)
|
||||
test_app.include_router(reminder_router)
|
||||
test_app.include_router(goal_router)
|
||||
test_app.include_router(schedule_center_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_todo_persists_explicit_todo_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post('/api/todos', json={'title': 'Plan sprint', 'todo_date': '2026-04-12'})
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Plan sprint'
|
||||
assert payload['todo_date'] == '2026-04-12'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_todo_allows_editing_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.patch(f'/api/todos/{todo_id}', json={'title': 'Updated title', 'todo_date': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Updated title'
|
||||
assert payload['todo_date'] == '2026-04-11'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_todo_allows_deleting_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.delete(f'/api/todos/{todo_id}')
|
||||
after_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 204
|
||||
assert all(item['id'] != todo_id for item in after_response.json()['items'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'due_date': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['title'] for item in payload] == ['High priority task']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date_range(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'date_from': '2026-04-10', 'date_to': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert {item['title'] for item in payload} == {'High priority task', 'Urgent task next day'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_date_returns_aggregated_resources(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/date', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['date'] == '2026-04-10'
|
||||
assert payload['summary'] == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert [item['title'] for item in payload['reminders']] == ['Doctor reminder']
|
||||
assert [item['title'] for item in payload['goals']] == ['Launch calendar beta']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_month_returns_day_summaries(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/month', params={'year': 2026, 'month': 4})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['month'] == '2026-04'
|
||||
day_10 = next(item for item in payload['days'] if item['date'] == '2026-04-10')
|
||||
day_11 = next(item for item in payload['days'] if item['date'] == '2026-04-11')
|
||||
assert day_10 == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert day_11['task_due_total'] == 1
|
||||
assert day_11['high_priority_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_reminder_with_naive_datetime_and_time_zone_appears_in_schedule_center(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
create_response = await client.post(
|
||||
'/api/reminders',
|
||||
json={'title': '收被子', 'note': '提醒收被子', 'reminder_at': '2026-03-29T09:00:00'},
|
||||
)
|
||||
detail_response = await client.get('/api/schedule-center/date', params={'date_str': '2026-03-29'})
|
||||
|
||||
assert create_response.status_code == 201
|
||||
assert detail_response.status_code == 200
|
||||
payload = detail_response.json()
|
||||
assert [item['title'] for item in payload['reminders']] == ['收被子']
|
||||
assert payload['summary']['reminder_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reminder_and_goal_crud(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
reminder_response = await client.post('/api/reminders', json={'title': 'Standup', 'note': 'Daily sync', 'reminder_at': '2026-04-12T09:00:00Z'})
|
||||
goal_response = await client.post('/api/goals', json={'title': 'Finish polish', 'note': 'UI cleanup', 'goal_date': '2026-04-12', 'status': 'active'})
|
||||
reminder_id = reminder_response.json()['id']
|
||||
goal_id = goal_response.json()['id']
|
||||
patch_reminder = await client.patch(f'/api/reminders/{reminder_id}', json={'status': 'done'})
|
||||
patch_goal = await client.patch(f'/api/goals/{goal_id}', json={'status': 'done'})
|
||||
reminders_list = await client.get('/api/reminders', params={'date_str': '2026-04-12'})
|
||||
goals_list = await client.get('/api/goals', params={'date_str': '2026-04-12'})
|
||||
delete_reminder = await client.delete(f'/api/reminders/{reminder_id}')
|
||||
delete_goal = await client.delete(f'/api/goals/{goal_id}')
|
||||
|
||||
assert reminder_response.status_code == 201
|
||||
assert goal_response.status_code == 201
|
||||
assert patch_reminder.json()['status'] == 'done'
|
||||
assert patch_goal.json()['status'] == 'done'
|
||||
assert [item['title'] for item in reminders_list.json()['items']] == ['Standup']
|
||||
assert [item['title'] for item in goals_list.json()['items']] == ['Finish polish']
|
||||
assert delete_reminder.status_code == 204
|
||||
assert delete_goal.status_code == 204
|
||||
190
backend/tests/backend/app/test_skill_router.py
Normal file
190
backend/tests/backend/app/test_skill_router.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.skill import router as skill_router
|
||||
from app.routers.auth import router as auth_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def skill_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_skill_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
# Ensure app.database.init_db() runs against the test database.
|
||||
import app.database as database_module
|
||||
|
||||
monkeypatch.setattr(database_module, "engine", engine)
|
||||
monkeypatch.setattr(database_module, "async_session", session_factory)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='skill_user',
|
||||
email='skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Skill Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_skill_user',
|
||||
email='other-skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Skill Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
session.add_all([
|
||||
Skill(
|
||||
name='Planner skill',
|
||||
description='planner',
|
||||
instructions='plan',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Executor skill',
|
||||
description='executor',
|
||||
instructions='execute',
|
||||
agent_type='executor',
|
||||
tools=['shell'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Other user planner skill',
|
||||
description='other',
|
||||
instructions='other',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=other_user.id,
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(auth_router)
|
||||
test_app.include_router(skill_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, session_factory
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_filters_by_agent_type(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'agent_type': 'schedule_planner'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_migrates_planner_skills_to_schedule_planner(skill_env):
|
||||
app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("UPDATE skills SET agent_type = 'planner' WHERE name = 'Planner skill'"))
|
||||
await session.commit()
|
||||
|
||||
from app.database import init_db
|
||||
|
||||
await init_db()
|
||||
|
||||
async with session_factory() as session:
|
||||
migrated_response = await session.execute(text("SELECT agent_type FROM skills WHERE name = 'Planner skill'"))
|
||||
assert migrated_response.scalar_one() == 'schedule_planner'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_visibility_filter_still_respects_access_scope(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'visibility': 'private'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_bootstraps_builtin_market_skills_for_current_user(skill_env):
|
||||
test_app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("DELETE FROM skills"))
|
||||
await session.commit()
|
||||
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
login_response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'skill_user', 'password': 'secret123'},
|
||||
headers={'content-type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
|
||||
assert login_response.status_code == 200
|
||||
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(select(Skill.name, Skill.is_builtin).order_by(Skill.name))
|
||||
skills = result.all()
|
||||
|
||||
names = {name for name, _is_builtin in skills}
|
||||
assert '今日重点拆解' in names
|
||||
assert '任务执行 SOP' in names
|
||||
assert any(is_builtin is True for _name, is_builtin in skills)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_without_agent_type_returns_current_user_skills(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
assert all(isinstance(item['created_at'], str) for item in payload)
|
||||
assert all(isinstance(item['updated_at'], str) for item in payload)
|
||||
assert all('is_builtin' in item for item in payload)
|
||||
assert all(item['is_builtin'] is False for item in payload)
|
||||
Reference in New Issue
Block a user