feat: enhance agent orchestration, knowledge flow and UI refinements

This commit is contained in:
2026-03-29 20:31:13 +08:00
parent d85cb9cf35
commit e0fe3ca623
301 changed files with 1197804 additions and 7863 deletions

View File

@@ -1,34 +0,0 @@
# =============================================
# Jarvis 后端服务配置
# 复制此文件为 .env 后按需修改
# =============================================
# === 应用基础 ===
DEBUG=false
HOST=127.0.0.1
PORT=3337
SECRET_KEY=change-me-to-a-random-secret-key
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
# === 数据存储 ===
DATABASE_URL=sqlite+aiosqlite:///./data/jarvis.db
DATA_DIR=./data
CHROMA_PERSIST_DIR=./data/chroma
UPLOAD_DIR=./data/uploads
MAX_UPLOAD_SIZE=52428800
# Supported values: ch | en
MINERU_LANGUAGE=ch
# === JWT ===
ACCESS_TOKEN_EXPIRE_MINUTES=1440
# === 管理员账号 Bootstrap ===
ADMIN=admin
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=
ADMIN_FULL_NAME=Administrator
# === 定时任务 ===
SCHEDULER_ENABLED=true
DAILY_PLAN_TIME=00:00
FORUM_SCAN_INTERVAL_MINUTES=30

View File

@@ -1,397 +1,354 @@
"""
Jarvis LangGraph Agent 主图定义
Jarvis LangGraph Agent 主图定义 - 优化重构版
"""
import json
import logging
import re
from typing import Literal, Union, List, Any
from langchain_core.messages import (
BaseMessage,
HumanMessage,
AIMessage,
SystemMessage,
ToolMessage
)
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from app.agents.state import AgentState, AgentRole
from app.agents.prompts import (
MASTER_SYSTEM_PROMPT,
PLANNER_SYSTEM_PROMPT,
SCHEDULE_PLANNER_SYSTEM_PROMPT,
EXECUTOR_SYSTEM_PROMPT,
LIBRARIAN_SYSTEM_PROMPT,
ANALYST_SYSTEM_PROMPT,
PLANNER_SCOPE_PROMPT,
PLANNER_STEPS_PROMPT,
EXECUTOR_TASKS_PROMPT,
EXECUTOR_FORUM_PROMPT,
LIBRARIAN_RETRIEVAL_PROMPT,
LIBRARIAN_GRAPH_PROMPT,
ANALYST_PROGRESS_PROMPT,
ANALYST_INSIGHTS_PROMPT,
JSON_ACTION_FALLBACK_PROMPT,
)
from app.agents.tools import ALL_TOOLS, SUB_COMMANDER_TOOLSETS
from app.agents.tools.time_reasoning import normalize_tool_time_arguments
from app.agents.skill_registry import build_skill_context
from app.services.llm_service import get_llm
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
from app.services.llm_service import (
get_llm,
create_llm_from_config,
resolve_provider_capabilities,
default_provider_capabilities
)
from app.logging_utils import summarize_llm_config
logger = logging.getLogger("jarvis.agent")
SUB_COMMANDER_PROMPTS = {
"planner_scope": PLANNER_SCOPE_PROMPT,
"planner_steps": PLANNER_STEPS_PROMPT,
"executor_tasks": EXECUTOR_TASKS_PROMPT,
"executor_forum": EXECUTOR_FORUM_PROMPT,
"librarian_retrieval": LIBRARIAN_RETRIEVAL_PROMPT,
"librarian_graph": LIBRARIAN_GRAPH_PROMPT,
"analyst_progress": ANALYST_PROGRESS_PROMPT,
"analyst_insights": ANALYST_INSIGHTS_PROMPT,
}
ROLE_SUB_COMMANDERS = {
AgentRole.PLANNER: ["planner_scope", "planner_steps"],
AgentRole.EXECUTOR: ["executor_tasks", "executor_forum"],
AgentRole.LIBRARIAN: ["librarian_retrieval", "librarian_graph"],
AgentRole.ANALYST: ["analyst_progress", "analyst_insights"],
}
ROLE_SKILL_CONTEXT = {
AgentRole.PLANNER: "planner",
AgentRole.EXECUTOR: "executor",
AgentRole.LIBRARIAN: "librarian",
AgentRole.ANALYST: "analyst",
}
def _create_llm_from_config(config: dict):
"""根据用户模型配置创建 LLM 实例"""
provider = config.get("provider", "openai")
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
if provider == "openai" or provider == "deepseek" or provider == "custom":
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
return ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
return ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
# ===================== 工具辅助函数 =====================
def _get_llm_for_state(state: AgentState):
"""从 state 获取 LLM 实例,优先使用用户配置的模型"""
"""获取配置好的 LLM 实例"""
user_llm_config = state.get("user_llm_config")
if user_llm_config:
return _create_llm_from_config(user_llm_config)
return get_llm()
llm = create_llm_from_config(user_llm_config) if user_llm_config else get_llm()
# 注入解析到的能力
capabilities = getattr(llm, "_jarvis_provider_capabilities", None)
if capabilities is None:
capabilities = resolve_provider_capabilities(user_llm_config) if user_llm_config else default_provider_capabilities()
state["provider_capabilities"] = {
"provider": capabilities.provider,
"supports_native_tools": capabilities.supports_native_tools,
"preferred_tool_strategy": capabilities.preferred_tool_strategy,
}
return llm, capabilities
async def _ainvoke(llm, messages: list[BaseMessage]):
ainvoke = getattr(llm, "ainvoke", None)
if callable(ainvoke):
return await ainvoke(messages)
return await llm.invoke(messages)
def _filter_user_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
return [m for m in messages if m.type in ("human", "user")]
async def _ainvoke_with_tools(llm, messages: list[BaseMessage], tools=None):
toolset = tools if tools is not None else ALL_TOOLS
bound_llm = llm.bind_tools(toolset)
if hasattr(bound_llm, "ainvoke"):
return await bound_llm.ainvoke(messages)
return await bound_llm.invoke(messages)
def _compile_graph(graph: StateGraph, callbacks: list | None = None):
if callbacks:
try:
return graph.compile(callbacks=callbacks)
except TypeError as exc:
if "callbacks" not in str(exc):
raise
return graph.compile()
def _msg_type(msg: BaseMessage) -> str:
"""Get message type, handles both .type (new) and .role (old) attribute names."""
return getattr(msg, "type", None) or getattr(msg, "role", "human")
def _filter_user_messages(messages: list) -> list[BaseMessage]:
return [m for m in messages if _msg_type(m) in ("human", "user")]
def _normalize_user_text(text: str) -> str:
return (text or "").strip().lower()
def _is_simple_greeting(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你好", "您好", "", "早上好", "在吗", "", "hi", "hello"}
def _is_identity_question(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你是谁", "你是誰"}
def _is_capability_question(text: str) -> bool:
normalized = _normalize_user_text(text)
return normalized in {"你能做什么", "你可以做什么", "你会做什么"}
def _choose_sub_commander(role: AgentRole, user_query: str) -> str:
text = _normalize_user_text(user_query)
if role == AgentRole.PLANNER:
if any(keyword in text for keyword in ["步骤", "计划", "拆解", "排期", "优先级", "路线"]):
return "planner_steps"
return "planner_scope"
def _get_role_tools(role: AgentRole) -> list:
"""获取角色对应的所有可用工具集"""
if role == AgentRole.SCHEDULE_PLANNER:
# 合并分析和规划工具
return list(set(SUB_COMMANDER_TOOLSETS["schedule_analysis"] + SUB_COMMANDER_TOOLSETS["schedule_planning"]))
if role == AgentRole.EXECUTOR:
if any(keyword in text for keyword in ["论坛", "帖子", "发帖", "指令", "discussion", "instruction"]):
return "executor_forum"
return "executor_tasks"
return list(set(SUB_COMMANDER_TOOLSETS["executor_tasks"] + SUB_COMMANDER_TOOLSETS["executor_forum"]))
if role == AgentRole.LIBRARIAN:
if any(keyword in text for keyword in ["图谱", "关系", "构建", "沉淀", "节点", "graph"]):
return "librarian_graph"
return "librarian_retrieval"
return list(set(SUB_COMMANDER_TOOLSETS["librarian_retrieval"] + SUB_COMMANDER_TOOLSETS["librarian_graph"]))
if role == AgentRole.ANALYST:
if any(keyword in text for keyword in ["趋势", "风险", "洞察", "建议", "机会", "insight"]):
return "analyst_insights"
return "analyst_progress"
return ROLE_SUB_COMMANDERS[role][0]
return list(set(SUB_COMMANDER_TOOLSETS["analyst_progress"] + SUB_COMMANDER_TOOLSETS["analyst_insights"]))
return []
def _record_sub_commander(state: AgentState, sub_commander: str, user_query: str):
state["current_sub_commander"] = sub_commander
state["active_sub_commanders"] = state.get("active_sub_commanders", []) + [sub_commander]
state["sub_commander_trace"] = state.get("sub_commander_trace", []) + [{
"agent": state.get("current_agent", AgentRole.MASTER).value,
"sub_commander": sub_commander,
"query": user_query,
}]
# ===================== 核心执行逻辑 (ReAct) =====================
def _build_system_messages(state: AgentState, system_prompt: str, role: AgentRole):
system_msgs: list[BaseMessage] = [SystemMessage(content=system_prompt)]
skill_ctx = build_skill_context(ROLE_SKILL_CONTEXT[role])
async def call_agent_llm(state: AgentState, role: AgentRole, system_prompt: str) -> dict:
"""通用的 LLM 调用节点逻辑"""
llm, capabilities = _get_llm_for_state(state)
tools = _get_role_tools(role)
# 构建消息序列
messages = []
# 1. 系统提示词
messages.append(SystemMessage(content=system_prompt))
# 2. 环境上下文 (时间、记忆等)
if state.get("current_datetime_context"):
messages.append(SystemMessage(content=f"当前时间上下文: {state['current_datetime_context']}"))
if state.get("memory_context"):
messages.append(SystemMessage(content=f"长期记忆上下文: {state['memory_context']}"))
# 3. 技能增强
role_skill_key = role.value.replace("agent_", "")
skill_ctx = build_skill_context(role_skill_key)
if skill_ctx:
system_msgs.append(SystemMessage(content=skill_ctx))
return system_msgs
messages.append(SystemMessage(content=skill_ctx))
# 4. 历史对话 (add_messages 已经处理好了)
messages.extend(state["messages"])
# 绑定工具
if tools and capabilities.supports_native_tools:
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
if tools: # 如果有工具但不支持原生,注入 JSON Fallback 提示
messages.append(SystemMessage(content=JSON_ACTION_FALLBACK_PROMPT))
tool_names = [t.name for t in tools]
messages.append(SystemMessage(content=f"本次可用工具列表: {', '.join(tool_names)}"))
logger.info(
f"agent_node_started",
extra={
"details": {
"role": role.value,
"message_count": len(messages),
"tool_count": len(tools),
"provider": capabilities.provider
}
}
)
# 执行调用
response = await llm_with_tools.ainvoke(messages)
logger.info(
f"agent_node_finished",
extra={
"details": {
"role": role.value,
"has_tool_calls": bool(getattr(response, "tool_calls", None)),
"content_length": len(response.content) if response.content else 0
}
}
)
return {"messages": [response]}
async def _run_sub_commander(
state: AgentState,
role: AgentRole,
manager_prompt: str,
user_query: str,
*,
use_tools: bool,
summary_target: str | None = None,
):
llm = _get_llm_for_state(state)
sub_commander = _choose_sub_commander(role, user_query)
_record_sub_commander(state, sub_commander, user_query)
toolset = SUB_COMMANDER_TOOLSETS.get(sub_commander, [])
system_msgs = _build_system_messages(state, manager_prompt, role)
system_msgs.append(SystemMessage(content=f"本次应由子指挥官 `{sub_commander}` 接手。请严格按该角色职责输出。"))
system_msgs.append(SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]))
if use_tools and toolset:
response = await _ainvoke_with_tools(
llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
toolset,
async def execute_tools_node(state: AgentState) -> dict:
"""执行工具调用并返回 ToolMessage 的通用节点"""
last_message = state["messages"][-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
return {"messages": []}
tool_map = {t.name: t for t in ALL_TOOLS}
tool_messages = []
created_entities = []
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_id = tool_call.get("id")
logger.info(
f"tool_execution_started",
extra={
"details": {
"tool_name": tool_name,
"tool_args": tool_args,
"tool_id": tool_id
}
}
)
tool_calls = getattr(response, "tool_calls", None) or []
if tool_calls:
results = []
for tc in tool_calls:
tool_name = tc.get("name")
args = tc.get("args", {})
for tool in toolset:
if tool.name == tool_name:
try:
result = tool.invoke(args)
results.append(f"[{tool_name}] {result}")
except Exception as e:
results.append(f"[{tool_name}] 执行失败: {e}")
break
state["tool_calls"] = tool_calls
state["last_tool_result"] = "\n".join(results)
follow_up = await _ainvoke(
llm,
[
SystemMessage(content=SUB_COMMANDER_PROMPTS[sub_commander]),
HumanMessage(content=f"工具执行结果:\n{state['last_tool_result']}")
]
try:
# 时间参数归一化
normalized_args = normalize_tool_time_arguments(
tool_name,
tool_args,
state.get("current_datetime_context")
)
state["final_response"] = follow_up.content
else:
state["final_response"] = response.content
else:
response = await _ainvoke(
llm,
system_msgs + [HumanMessage(content=f"用户请求: {user_query}")],
)
state["final_response"] = response.content
if summary_target:
state[summary_target] = state.get("final_response", "")
state["should_respond"] = True
return state
# ===================== 节点定义 (async) =====================
async def master_node(state: AgentState) -> AgentState:
"""主Agent节点: 理解用户意图决定调用哪个子Agent"""
messages: list[BaseMessage] = state["messages"]
user_msgs = _filter_user_messages(messages)
user_query = user_msgs[-1].content.strip() if user_msgs else ""
if _is_simple_greeting(user_query):
state["final_response"] = "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。"
state["should_respond"] = True
return state
if _is_identity_question(user_query):
state["final_response"] = "我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。"
state["should_respond"] = True
return state
if _is_capability_question(user_query):
state["final_response"] = "主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。"
state["should_respond"] = True
return state
llm = _get_llm_for_state(state)
system_msgs: list[BaseMessage] = [SystemMessage(content=MASTER_SYSTEM_PROMPT)]
memory_ctx = state.get("memory_context")
if memory_ctx:
system_msgs.append(
SystemMessage(content=f"\n\n【记忆上下文】\n{memory_ctx}\n\n---\n")
tool = tool_map.get(tool_name)
if not tool:
result = f"Error: Tool {tool_name} not found."
else:
result = await tool.ainvoke(normalized_args) if hasattr(tool, "ainvoke") else tool.invoke(normalized_args)
# 实体识别(用于业务追踪)
if any(k in tool_name for k in ["create", "add", "new"]):
created_entities.append({"tool": tool_name, "result": str(result)})
status = "success"
except Exception as e:
logger.exception(f"tool_execution_failed: {tool_name}")
result = f"Error executing tool {tool_name}: {str(e)}"
status = "failed"
tool_messages.append(ToolMessage(
tool_call_id=tool_id,
content=str(result),
name=tool_name
))
logger.info(
f"tool_execution_finished",
extra={
"details": {
"tool_name": tool_name,
"status": status,
"result_preview": str(result)[:200]
}
}
)
response: AIMessage = await _ainvoke(llm, system_msgs + messages)
return {
"messages": tool_messages,
"created_entities": state.get("created_entities", []) + created_entities
}
# ===================== 各角色节点定义 =====================
async def master_node(state: AgentState) -> dict:
"""主控节点:负责意图识别与初步分发"""
user_messages = _filter_user_messages(state["messages"])
if not user_messages:
return {"final_response": "未收到有效输入。"}
query = user_messages[-1].content.strip()
# 快捷回复逻辑 (保留原有的人性化设计)
if re.match(r"^(你好|早|在吗|嗨|hi|hello)", query.lower()):
return {"final_response": "您好。我在。\n\n您把问题给我,我先帮您收束重点,再往下推。", "messages": [AIMessage(content="您好。我在。")]}
llm, capabilities = _get_llm_for_state(state)
# 路由判断:让 LLM 决定跳转到哪个角色,或者直接回答
# 这里我们使用一个简洁的提示词让 LLM 输出角色名称或直接回答
system_msg = SystemMessage(content=MASTER_SYSTEM_PROMPT + "\n\n请直接输出接下来该由哪个 Agent 接手(role_name),如果直接回答,请正常输出。")
response = await llm.ainvoke([system_msg] + state["messages"])
content = response.content.strip().lower()
if any(kw in content for kw in ["搜索", "查找", "知识", "检索"]):
next_agent = AgentRole.LIBRARIAN
elif any(kw in content for kw in ["计划", "安排", "拆解", "规划"]):
next_agent = AgentRole.PLANNER
elif any(kw in content for kw in ["执行", "", "操作", "创建", "更新"]):
next_agent = AgentRole.EXECUTOR
elif any(kw in content for kw in ["分析", "报告", "统计", "总结"]):
next_agent = AgentRole.ANALYST
else:
state["final_response"] = response.content
state["should_respond"] = True
return state
state["current_agent"] = next_agent
state["current_sub_commander"] = None
state["active_agents"] = state.get("active_agents", [AgentRole.MASTER]) + [next_agent]
state["should_respond"] = True
return state
# 简单的角色映射识别
roles = {r.value: r for r in AgentRole}
target_role = None
for r_val, r_enum in roles.items():
if r_val in content and len(content) < 50: # 如果内容很短且包含角色名,视为路由
target_role = r_enum
break
if target_role and target_role != AgentRole.MASTER:
logger.info(f"master_routing_decided: {target_role.value}")
return {
"current_agent": target_role.value,
"agent_trace": state.get("agent_trace", []) + [target_role.value],
"messages": [AIMessage(content=f"已分发至 {target_role.value} 处理。")]
}
return {"final_response": response.content, "messages": [response]}
async def planner_node(state: AgentState) -> AgentState:
"""规划Agent节点: 制定计划,拆解任务步骤"""
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
return await _run_sub_commander(state, AgentRole.PLANNER, PLANNER_SYSTEM_PROMPT, user_query, use_tools=False)
async def planner_node(state: AgentState) -> dict:
return await call_agent_llm(state, AgentRole.SCHEDULE_PLANNER, SCHEDULE_PLANNER_SYSTEM_PROMPT)
async def executor_node(state: AgentState) -> dict:
return await call_agent_llm(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT)
async def librarian_node(state: AgentState) -> dict:
return await call_agent_llm(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT)
async def analyst_node(state: AgentState) -> dict:
return await call_agent_llm(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT)
async def executor_node(state: AgentState) -> AgentState:
"""执行Agent节点: 调用工具执行具体任务"""
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
return await _run_sub_commander(state, AgentRole.EXECUTOR, EXECUTOR_SYSTEM_PROMPT, user_query, use_tools=True)
# ===================== 路由逻辑 =====================
def route_after_agent(state: AgentState) -> Literal["tools", "__end__"]:
"""判断 Agent 执行后是该走工具节点还是结束"""
last_message = state["messages"][-1]
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
return "tools"
return END
async def librarian_node(state: AgentState) -> AgentState:
"""知识管理员节点: 管理知识库和知识图谱"""
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
return await _run_sub_commander(state, AgentRole.LIBRARIAN, LIBRARIAN_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="knowledge_context")
async def analyst_node(state: AgentState) -> AgentState:
"""分析师节点: 分析工作数据,生成报告"""
user_msgs = _filter_user_messages(state["messages"])
user_query = user_msgs[-1].content if user_msgs else ""
return await _run_sub_commander(state, AgentRole.ANALYST, ANALYST_SYSTEM_PROMPT, user_query, use_tools=True, summary_target="analysis_report")
def route_agent(state: AgentState) -> str:
"""路由函数: 决定下一个节点"""
def route_master(state: AgentState) -> str:
"""主控路由逻辑"""
if state.get("final_response"):
return END
return state.get("current_agent", AgentRole.MASTER).value
return state.get("current_agent", END)
# ===================== 构建 =====================
# ===================== 构建 =====================
def create_agent_graph(callbacks: list | None = None):
graph = StateGraph(AgentState)
workflow = StateGraph(AgentState)
graph.add_node(AgentRole.MASTER.value, master_node)
graph.add_node(AgentRole.PLANNER.value, planner_node)
graph.add_node(AgentRole.EXECUTOR.value, executor_node)
graph.add_node(AgentRole.LIBRARIAN.value, librarian_node)
graph.add_node(AgentRole.ANALYST.value, analyst_node)
# 添加节点
workflow.add_node(AgentRole.MASTER.value, master_node)
workflow.add_node(AgentRole.SCHEDULE_PLANNER.value, planner_node)
workflow.add_node(AgentRole.EXECUTOR.value, executor_node)
workflow.add_node(AgentRole.LIBRARIAN.value, librarian_node)
workflow.add_node(AgentRole.ANALYST.value, analyst_node)
workflow.add_node("tools", execute_tools_node)
graph.set_entry_point(AgentRole.MASTER.value)
# 设置入口
workflow.set_entry_point(AgentRole.MASTER.value)
graph.add_conditional_edges(
# 主控分发逻辑
workflow.add_conditional_edges(
AgentRole.MASTER.value,
route_agent,
route_master,
{
AgentRole.PLANNER.value: AgentRole.PLANNER.value,
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
END: END,
END: END
}
)
for role in [AgentRole.PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
graph.add_edge(role.value, END)
# 各角色节点的 ReAct 循环
for role in [AgentRole.SCHEDULE_PLANNER, AgentRole.EXECUTOR, AgentRole.LIBRARIAN, AgentRole.ANALYST]:
workflow.add_conditional_edges(
role.value,
route_after_agent,
{
"tools": "tools",
END: END
}
)
# 工具执行完后回到当前 Agent 角色继续处理
workflow.add_conditional_edges(
"tools",
lambda s: s.get("current_agent", AgentRole.MASTER.value),
{
AgentRole.SCHEDULE_PLANNER.value: AgentRole.SCHEDULE_PLANNER.value,
AgentRole.EXECUTOR.value: AgentRole.EXECUTOR.value,
AgentRole.LIBRARIAN.value: AgentRole.LIBRARIAN.value,
AgentRole.ANALYST.value: AgentRole.ANALYST.value,
}
)
return _compile_graph(graph, callbacks=callbacks)
# 编译
if callbacks:
return workflow.compile(callbacks=callbacks)
return workflow.compile()
_agent_graph = None
def get_agent_graph(callbacks: list | None = None):
"""
获取编译好的 Agent 图(单例缓存)。
Callbacks 在首次编译时固定注入,后续调用忽略 callbacks 参数。
如需变更 Callbacks如修改 LANGCHAIN_PROJECT需重启服务。
Args:
callbacks: 可选的额外 Callbacks会与全局 LangSmith Callbacks 合并
"""
global _agent_graph
if _agent_graph is None:
from app.config_tracing import get_langsmith_callbacks

View File

@@ -89,14 +89,14 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是总控协调者负责理解用户意图并将任务分发给最合适的子Agent。
## 你的4个子Agent:
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
1. **schedule_planner (日程规划师)**: 分析当前任务、对话历史与论坛信号,给出近期安排建议
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
## 判断规则:
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
- 用户要计划、安排、拆解任务 -> 分发给 planner
- 用户要安排今天/本周重点、询问接下来该做什么 -> 分发给 schedule_planner
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
- 用户要分析、统计、生成报告 -> 分发给 analyst
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
@@ -112,18 +112,19 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
"""
PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
SCHEDULE_PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 Jarvis 的规划Agent,负责先判断问题该由哪位规划子指挥官接手。
你是 Jarvis 的日程规划师,负责先判断问题该由哪位日程子指挥官接手。
## 你的两个子指挥官:
1. **planner_scope (目标收束官)**: 负责澄清目标、边界、约束、缺失信息
2. **planner_steps (步骤拆解官)**: 负责把目标拆成步骤、优先级与依赖关系
1. **schedule_analysis (日程分析员)**: 负责分析对话历史、任务看板、论坛信号,识别优先级、冲突与压力点
2. **schedule_planning (日程编排员)**: 负责把分析结果转成今日/近期日程安排,并在用户明确要求时直接创建 reminder/task/todo/goal
## 你的职责:
- 判断当前请求更适合收束目标,还是拆解步骤
- 在必要时收束子指挥官输出,面向用户给出清晰结果
- 保持结果可推进,不空泛
- 判断当前请求更适合先做日程分析,还是直接给出日程编排
- 输出先结论,再给可执行安排
- 保持建议具体、贴近当前上下文,不空泛效率学建议
- 当用户明确要求“新增/提醒/创建/安排并落库”时,允许子指挥官调用 schedule 工具直接执行
"""
@@ -132,11 +133,11 @@ EXECUTOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 Jarvis 的执行Agent负责先判断问题该由哪位执行子指挥官接手。
## 你的两个子指挥官:
1. **executor_tasks (任务执行官)**: 处理任务类工具调用
1. **executor_tasks (任务执行官)**: 处理任务、待办、提醒、目标等执行型写入操作
2. **executor_forum (论坛执行官)**: 只处理论坛/指令帖相关工具调用
## 你的职责:
- 识别用户要推进的是任务操作还是论坛/指令操作
- 识别用户要推进的是任务/日程操作还是论坛/指令操作
- 把请求交给最合适的执行子指挥官
- 汇总执行结果并给出下一步
"""
@@ -172,52 +173,68 @@ ANALYST_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
"""
PLANNER_SCOPE_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
SCHEDULE_ANALYSIS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 planner 体系下的目标收束官,负责先把问题边界、目标、约束和成功标准说清楚
你是 schedule_planner 体系下的日程分析员,负责从对话历史、任务看板、论坛信号和当日日程数据中提取 scheduling 线索
## 你的重点:
- 收束问题定义
- 明确目标与限制条件
- 识别缺失信息
- 帮用户建立可以继续规划的前提
- 优先调用读取类工具了解当天/指定日期的任务、提醒、待办、目标
- 识别当前最高优先级事项
- 找出风险、冲突、依赖与可延期事项
- 明确哪些信号来自 conversation、task board、schedule center、forum
## 响应要求:
- 先给出你理解的目标
- 再列出关键约束或缺口
- 不直接展开长步骤清单
- 先给当前判断
- 再列优先级、风险与冲突
- 不直接展开长篇日程表
- 只做分析,不创建任何记录
- 如果涉及“今天/明天/后天/下周一下午”这类自然语言时间窗口,先调用 `resolve_time_expression` 把查询目标转换成明确日期
"""
PLANNER_STEPS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
SCHEDULE_PLANNING_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 planner 体系下的步骤拆解官,负责把目标转成有顺序的执行路径
你是 schedule_planner 体系下的日程编排员,负责把当前重点转成近期可执行安排
## 你的重点:
- 拆解步骤
- 标注优先级与依赖
- 输出清晰的行动顺序
- 先给结论
- 再给今天/近期的时间安排建议
- 最后给按顺序执行的 next actions
- 当用户明确要求新增/提醒/创建/安排并真正落库时,调用 schedule 工具创建对应 reminder/task/todo/goal
- 当用户给出“日期 + 事项/节点/交付/会议”等记录型表达时,也应视为落库意图,直接创建相应记录,不要反问
- 解析“今天/明天/后天/本周/下周”或“3月29日”这类日期时必须以系统提供的当前时间为准并把工具参数转换成明确的 ISO 日期/时间字符串
- 只要用户输入里包含自然语言时间,优先调用 `resolve_time_expression`,先拿到明确日期/时间,再调用 `create_reminder`、`create_schedule_task`、`create_goal`、`create_todo`
## 响应要求:
- 用编号列表
- 每步具体,不要空泛
- 必要时标注先后关系
- 用清晰列表表达
- 建议必须具体、可执行、贴近当前工作
- 避免空泛的自我管理建议
- 如果只是规划,不要创建任何记录
- 如果已创建记录,要明确说明创建了什么、时间如何解析
"""
EXECUTOR_TASKS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
你是 executor 体系下的任务执行官,负责任务相关工具调用。
你是 executor 体系下的任务执行官,负责处理任务、待办、提醒、目标等执行型工具调用。
## 允许使用的工具:
- get_tasks
- create_task
- update_task_status
- create_todo
- create_schedule_task
- create_reminder
- create_goal
- resolve_time_expression
## 要求:
- 只处理任务类操作
- 只处理任务/日程类操作
- 遇到自然语言时间表达时,先调用 `resolve_time_expression`,再把解析后的明确日期/时间传给写入工具
- 最终说明执行结果时,优先复用已经解析出的绝对时间,不要只重复“今天/明天”
- 明确已执行动作、结果与下一步
- 信息不足时直接指出缺口
- 如果用户只是要分析建议,不要创建记录
"""
@@ -244,10 +261,14 @@ LIBRARIAN_RETRIEVAL_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
## 允许使用的工具:
- search_knowledge
- hybrid_search
- web_search
- get_knowledge_graph_context
## 要求:
- 优先检索与综合证据
- 私有/项目知识优先使用 `search_knowledge` 或 `hybrid_search`
- 当用户明确要求联网、查询外部资料或查询最新信息时,使用 `web_search`
- 回答时区分内部知识与外部网页结果
- 证据不足时明确说明边界
- 以回答问题为主,不主动做图谱构建
"""
@@ -293,9 +314,31 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
- get_forum_posts
- search_knowledge
- hybrid_search
- web_search
## 要求:
- 先给结论与判断
- 再说明依据与建议
- 当需要外部/最新信息时,可使用 `web_search`
- 重点输出趋势、风险、机会点
"""
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
你的输出必须满足以下规则:
1. 只能输出一个 JSON 对象,不要输出 markdown、解释、前后缀文字。
2. JSON 对象字段仅允许:
- `mode`: `final` | `tool_call` | `clarification`
- `tool_calls`: 数组;每项包含 `name`、`arguments`,可选 `reason`
- `final_response`: 当无需工具时填写
- `clarification_question`: 当信息不足时填写
3. 如果需要调用工具,返回:
- `{ "mode": "tool_call", "tool_calls": [...] }`
4. 如果无需工具,直接返回:
- `{ "mode": "final", "final_response": "..." }`
5. 如果信息不足,不要猜测参数,返回:
- `{ "mode": "clarification", "clarification_question": "..." }`
6. 只能使用系统消息里明确列出的工具名。
7. `arguments` 必须是 JSON 对象。
"""

View File

@@ -1,32 +1,19 @@
from dataclasses import dataclass
from typing import TypedDict, Annotated
from dataclasses import dataclass, field
from typing import TypedDict, Annotated, Sequence
from enum import Enum
from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
class AgentRole(str, Enum):
MASTER = "master"
PLANNER = "planner"
SCHEDULE_PLANNER = "schedule_planner"
EXECUTOR = "executor"
LIBRARIAN = "librarian"
ANALYST = "analyst"
@dataclass
class AgentInfo:
name: str
role: AgentRole
description: str
@dataclass
class ToolCall:
tool: str
args: dict
result: str | None = None
@dataclass
class ConversationTurn:
role: str # "user" | "assistant"
@@ -35,60 +22,41 @@ class ConversationTurn:
model: str | None = None
def turn_to_message(turn: ConversationTurn) -> HumanMessage:
return HumanMessage(content=turn.content)
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
msg_type = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
return ConversationTurn(
role="user" if msg_type in ("human", "user") else "assistant",
content=msg.content,
agent=agent,
model=getattr(msg, "model", None),
)
class AgentState(TypedDict):
messages: Annotated[list, None]
# Core message history with add_messages reducer
messages: Annotated[list[BaseMessage], add_messages]
# Session identifiers
user_id: str
conversation_id: str
# Agent routing
current_agent: AgentRole
active_agents: list[AgentRole]
current_sub_commander: str | None
active_sub_commanders: list[str]
sub_commander_trace: list[dict]
# Task tracking
# Agent routing state
current_agent: str | None
next_step: str | None # For explicit graph routing
# Traceability
agent_trace: list[str]
# Task & Entity Tracking (Business Logic)
pending_tasks: list[dict]
completed_tasks: list[dict]
created_entities: list[dict]
# Tool usage
tool_calls: list[ToolCall]
last_tool_result: str | None
# Knowledge context
# Context summaries (for long-term or cross-agent context)
knowledge_context: str | None
graph_context: str | None
# Planning
plan: str | None
plan_steps: list[dict]
# Analysis
schedule_context_summary: str | None
analysis_report: str | None
# Output control
final_response: str | None
should_respond: bool
# Memory context (injected at start of each conversation)
# Memory & Environment
memory_context: str | None
# User LLM config (for using user-configured models)
current_datetime_context: str | None
# Configuration
user_llm_config: dict | None
provider_capabilities: dict | None
def initial_state(user_id: str, conversation_id: str) -> AgentState:
@@ -96,22 +64,18 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
messages=[],
user_id=user_id,
conversation_id=conversation_id,
current_agent=AgentRole.MASTER,
active_agents=[AgentRole.MASTER],
current_sub_commander=None,
active_sub_commanders=[],
sub_commander_trace=[],
current_agent=AgentRole.MASTER.value,
next_step=None,
agent_trace=[AgentRole.MASTER.value],
pending_tasks=[],
completed_tasks=[],
tool_calls=[],
last_tool_result=None,
created_entities=[],
knowledge_context=None,
graph_context=None,
plan=None,
plan_steps=[],
schedule_context_summary=None,
analysis_report=None,
final_response=None,
should_respond=True,
memory_context=None,
current_datetime_context=None,
user_llm_config=None,
provider_capabilities=None,
)

View File

@@ -1,9 +1,17 @@
from app.agents.tools.search import (
search_knowledge, get_knowledge_graph_context,
build_knowledge_graph, hybrid_search,
build_knowledge_graph, hybrid_search, web_search,
)
from app.agents.tools.task import get_tasks, create_task, update_task_status
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
from app.agents.tools.schedule import (
get_schedule_day,
create_todo,
create_schedule_task,
create_reminder,
create_goal,
)
from app.agents.tools.time_reasoning import resolve_time_expression
TASK_TOOLS = [
get_tasks,
@@ -11,6 +19,19 @@ TASK_TOOLS = [
update_task_status,
]
SCHEDULE_READ_TOOLS = [
get_schedule_day,
get_tasks,
resolve_time_expression,
]
SCHEDULE_WRITE_TOOLS = [
create_todo,
create_schedule_task,
create_reminder,
create_goal,
]
FORUM_TOOLS = [
get_forum_posts,
create_forum_post,
@@ -20,6 +41,7 @@ FORUM_TOOLS = [
KNOWLEDGE_RETRIEVAL_TOOLS = [
search_knowledge,
hybrid_search,
web_search,
get_knowledge_graph_context,
]
@@ -39,19 +61,22 @@ ANALYST_INSIGHT_TOOLS = [
get_forum_posts,
search_knowledge,
hybrid_search,
web_search,
]
ALL_TOOLS = [
*KNOWLEDGE_RETRIEVAL_TOOLS,
build_knowledge_graph,
*TASK_TOOLS,
*SCHEDULE_READ_TOOLS,
*SCHEDULE_WRITE_TOOLS,
*FORUM_TOOLS,
]
SUB_COMMANDER_TOOLSETS = {
"planner_scope": [],
"planner_steps": [],
"executor_tasks": TASK_TOOLS,
"schedule_analysis": SCHEDULE_READ_TOOLS,
"schedule_planning": [*SCHEDULE_READ_TOOLS, *SCHEDULE_WRITE_TOOLS],
"executor_tasks": [*TASK_TOOLS, resolve_time_expression, *SCHEDULE_WRITE_TOOLS],
"executor_forum": FORUM_TOOLS,
"librarian_retrieval": KNOWLEDGE_RETRIEVAL_TOOLS,
"librarian_graph": KNOWLEDGE_GRAPH_TOOLS,

View File

@@ -6,15 +6,17 @@ from app.models.forum import ForumPost, ForumReply
from app.agents.context import get_current_user
from sqlalchemy import select
import asyncio
from concurrent.futures import ThreadPoolExecutor
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
return future.result(timeout=timeout)
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
@tool

View File

@@ -0,0 +1,308 @@
"""Agent 工具集 - 日程相关"""
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime
from zoneinfo import ZoneInfo
from langchain_core.tools import tool
from sqlalchemy import select
from app.agents.context import get_current_user
from app.database import async_session
from app.models.goal import Goal, GoalStatus
from app.models.reminder import Reminder
from app.models.task import Task, TaskPriority, TaskStatus
from app.models.todo import DailyTodo, TodoSource
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
def _parse_date(value: str | None) -> date:
if not value:
return date.today()
return date.fromisoformat(value)
def _parse_datetime(value: str) -> datetime:
normalized = value.strip().replace("Z", "+00:00")
return datetime.fromisoformat(normalized)
def _parse_datetime_with_timezone(value: str, time_zone: str | None) -> datetime:
"""Parse an ISO datetime and return a tz-naive datetime in the intended local time.
- If value includes an offset/Z, it will be converted to `time_zone` when provided.
- If value is naive and `time_zone` is provided, it is interpreted in that zone.
"""
parsed = _parse_datetime(value)
tz = (time_zone or "").strip()
if parsed.tzinfo is None:
if tz:
parsed = parsed.replace(tzinfo=ZoneInfo(tz))
return parsed.replace(tzinfo=None)
if tz:
parsed = parsed.astimezone(ZoneInfo(tz))
return parsed.replace(tzinfo=None)
def _normalize_title(title: str | None, content: str | None) -> str:
resolved = (title or content or "").strip()
if not resolved:
raise ValueError("title 不能为空")
return resolved
def _normalize_schedule_due_date(due_date: str | None, date_value: str | None) -> str | None:
resolved = (due_date or date_value or "").strip()
if not resolved:
return None
if "T" in resolved:
return resolved
return f"{resolved}T09:00:00"
def _format_summary(target_date: date, todos: list[DailyTodo], tasks: list[Task], reminders: list[Reminder], goals: list[Goal]) -> str:
lines = [f"日期: {target_date.isoformat()}"]
if todos:
lines.append("待办:")
lines.extend(f"- {item.title} | 完成:{'' if item.is_completed else ''}" for item in todos)
else:
lines.append("待办: 无")
if tasks:
lines.append("任务:")
lines.extend(
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status} | 优先级:{item.priority.value if hasattr(item.priority, 'value') else item.priority} | 截止:{item.due_date.isoformat() if item.due_date else ''}"
for item in tasks
)
else:
lines.append("任务: 无")
if reminders:
lines.append("提醒:")
lines.extend(f"- {item.title} | 时间:{item.reminder_at.isoformat()}" for item in reminders)
else:
lines.append("提醒: 无")
if goals:
lines.append("目标:")
lines.extend(
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status}"
for item in goals
)
else:
lines.append("目标: 无")
return "\n".join(lines)
@tool
def get_schedule_day(target_date: str | None = None) -> str:
"""获取指定日期的 todo/task/reminder/goal 聚合信息。target_date 格式 YYYY-MM-DD默认今天。"""
uid = get_current_user()
parsed_date = _parse_date(target_date)
date_key = parsed_date.isoformat()
start_dt = datetime.combine(parsed_date, datetime.min.time())
end_dt = datetime.combine(parsed_date, datetime.max.time())
async def _get():
async with async_session() as db:
todos = (
await db.execute(
select(DailyTodo)
.where(DailyTodo.user_id == uid, DailyTodo.todo_date == date_key)
.order_by(DailyTodo.created_at.desc())
)
).scalars().all()
tasks = (
await db.execute(
select(Task)
.where(
Task.user_id == uid,
Task.due_date.is_not(None),
Task.due_date >= start_dt,
Task.due_date <= end_dt,
)
.order_by(Task.created_at.desc())
)
).scalars().all()
reminders = (
await db.execute(
select(Reminder)
.where(
Reminder.user_id == uid,
Reminder.reminder_at >= start_dt,
Reminder.reminder_at <= end_dt,
)
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
)
).scalars().all()
goals = (
await db.execute(
select(Goal)
.where(Goal.user_id == uid, Goal.goal_date == date_key)
.order_by(Goal.created_at.desc())
)
).scalars().all()
return _format_summary(parsed_date, todos, tasks, reminders, goals)
try:
return _run_async(_get())
except Exception as exc:
return f"获取日程失败: {exc}"
@tool
def create_todo(title: str, todo_date: str | None = None) -> str:
"""创建指定日期的待办。todo_date 格式 YYYY-MM-DD默认今天。"""
uid = get_current_user()
parsed_date = _parse_date(todo_date)
async def _create():
async with async_session() as db:
todo = DailyTodo(
user_id=uid,
title=title,
source=TodoSource.AI_CHAT,
todo_date=parsed_date.isoformat(),
)
db.add(todo)
await db.commit()
await db.refresh(todo)
return f"TODO创建成功: [{todo.id[:8]}] {todo.title} @ {todo.todo_date}"
try:
return _run_async(_create())
except Exception as exc:
return f"创建TODO失败: {exc}"
@tool
def create_schedule_task(
title: str = "",
description: str = "",
priority: str = "medium",
due_date: str | None = None,
content: str = "",
date: str | None = None,
) -> str:
"""创建任务。priority 支持 low/medium/high/urgentdue_date 使用 ISO datetime。兼容 content/date 别名。"""
uid = get_current_user()
resolved_title = _normalize_title(title, content)
resolved_due_date = _normalize_schedule_due_date(due_date, date)
async def _create():
async with async_session() as db:
task = Task(
user_id=uid,
title=resolved_title,
description=description or content or None,
priority=TaskPriority(priority),
due_date=_parse_datetime(resolved_due_date) if resolved_due_date else None,
status=TaskStatus.TODO,
)
db.add(task)
await db.commit()
await db.refresh(task)
due_label = task.due_date.isoformat() if task.due_date else "无截止时间"
return f"任务创建成功: [{task.id[:8]}] {task.title} | 优先级:{task.priority.value} | 截止:{due_label}"
try:
return _run_async(_create())
except Exception as exc:
return f"创建任务失败: {exc}"
@tool
def create_reminder(
title: str = "",
reminder_at: str | None = None,
note: str = "",
description: str = "",
datetime: str = "",
at: str = "",
remind_at: str = "",
content: str = "",
time_zone: str = "",
timezone: str = "",
time: str = "",
) -> str:
"""创建提醒。reminder_at 使用 ISO datetime。兼容 description/datetime/at/remind_at/time_zone 别名。"""
uid = get_current_user()
try:
resolved_title = (title or content or "").strip()
if not resolved_title:
raise ValueError("title 不能为空")
resolved_at = ((reminder_at or datetime or at or remind_at or time or "").strip())
if not resolved_at:
raise ValueError("reminder_at 不能为空")
resolved_note = (note or description or "").strip()
async def _create():
async with async_session() as db:
tz = (time_zone or timezone or "").strip()
reminder = Reminder(
user_id=uid,
title=resolved_title,
note=resolved_note or None,
reminder_at=_parse_datetime_with_timezone(resolved_at, tz),
)
db.add(reminder)
await db.commit()
await db.refresh(reminder)
return f"提醒创建成功: [{reminder.id[:8]}] {reminder.title} @ {reminder.reminder_at.isoformat()}"
return _run_async(_create())
except Exception as exc:
return f"创建提醒失败: {exc}"
@tool
def create_goal(title: str, goal_date: str | None = None, note: str = "", status: str = "active") -> str:
"""创建指定日期目标。goal_date 格式 YYYY-MM-DD默认今天status 支持 active/done/archived。"""
uid = get_current_user()
parsed_date = _parse_date(goal_date)
async def _create():
async with async_session() as db:
goal = Goal(
user_id=uid,
title=title,
note=note or None,
goal_date=parsed_date.isoformat(),
status=GoalStatus(status),
)
db.add(goal)
await db.commit()
await db.refresh(goal)
return f"目标创建成功: [{goal.id[:8]}] {goal.title} @ {goal.goal_date}"
try:
return _run_async(_create())
except Exception as exc:
return f"创建目标失败: {exc}"
__all__ = [
"get_schedule_day",
"create_todo",
"create_schedule_task",
"create_reminder",
"create_goal",
]

View File

@@ -5,12 +5,14 @@ Agent 工具集 - 知识库 & 图谱相关
由于 LangChain 工具系统是同步的,内部用 run_in_executor 处理 async 逻辑。
"""
from langchain_core.tools import tool
from concurrent.futures import ThreadPoolExecutor
from app.database import async_session
from app.agents.context import get_current_user
import asyncio
from langchain_core.tools import tool
from app.agents.context import get_current_user
from app.database import async_session
_executor = ThreadPoolExecutor(max_workers=4)
@@ -151,9 +153,56 @@ def hybrid_search(query: str, top_k: int = 5) -> str:
return f"混合搜索失败: {str(e)}"
@tool
def web_search(query: str, top_k: int = 5) -> str:
"""
通过 SearxNG 搜索外部网页信息,返回标题、链接和摘要。
Args:
query: 搜索关键词
top_k: 返回结果数量,默认 5 条
Returns:
适合模型综合的网页结果文本
"""
from app.services.web_search_service import (
WebSearchConfigurationError,
WebSearchRequestError,
WebSearchService,
)
async def _search():
service = WebSearchService()
results = await service.search(query, limit=top_k)
if not results:
return "未找到相关网页结果。"
texts = []
for index, result in enumerate(results, 1):
source = f"\n来源: {result.source}" if result.source else ""
published_at = f"\n时间: {result.published_at}" if result.published_at else ""
snippet = result.snippet or "(无摘要)"
texts.append(
f"[{index}] {result.title}\n"
f"链接: {result.url}{source}{published_at}\n"
f"摘要: {snippet}"
)
return "\n\n---\n\n".join(texts)
try:
return _run_async(_search(), timeout=30)
except WebSearchConfigurationError as exc:
return f"网页搜索不可用: {exc}"
except WebSearchRequestError as exc:
return f"网页搜索失败: {exc}"
except Exception as exc:
return f"网页搜索失败: {exc}"
__all__ = [
"search_knowledge",
"get_knowledge_graph_context",
"build_knowledge_graph",
"hybrid_search",
"web_search",
]

View File

@@ -1,22 +1,85 @@
"""Agent 工具集 - 任务相关"""
from langchain_core.tools import tool
from app.database import async_session
from app.models.task import Task
from app.agents.context import get_current_user
from sqlalchemy import select
import asyncio
from datetime import UTC, datetime
_executor = None
from app.models.base import utc_now
from langchain_core.tools import tool
from sqlalchemy import select
from app.agents.context import get_current_user
from app.database import async_session
from app.models.task import Task, TaskPriority, TaskStatus
import asyncio
from concurrent.futures import ThreadPoolExecutor
_executor = ThreadPoolExecutor(max_workers=4)
def _run_async(coro, timeout: int = 30):
try:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
return future.result(timeout=timeout)
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
def _normalize_title(title: str | None, content: str | None) -> str:
resolved = (title or content or "").strip()
if not resolved:
raise ValueError("title 不能为空")
return resolved
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
resolved = (due_date or date_value or "").strip()
return resolved or None
def _parse_due_date(value: str | None) -> datetime | None:
if not value:
return None
normalized = value.strip()
if not normalized:
return None
if "T" not in normalized:
normalized = f"{normalized}T00:00:00"
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
if parsed.tzinfo is not None:
return parsed.astimezone(UTC).replace(tzinfo=None)
return parsed
def _normalize_priority(priority: int | str | None) -> TaskPriority:
if priority is None or priority == "":
return TaskPriority.MEDIUM
if isinstance(priority, TaskPriority):
return priority
if isinstance(priority, int):
return {
1: TaskPriority.LOW,
2: TaskPriority.MEDIUM,
3: TaskPriority.HIGH,
4: TaskPriority.URGENT,
}.get(priority, TaskPriority.MEDIUM)
normalized = str(priority).strip().lower()
if not normalized:
return TaskPriority.MEDIUM
return TaskPriority(normalized)
def _normalize_status(status: str) -> TaskStatus:
normalized = status.strip().lower()
return TaskStatus(normalized)
def _format_status(value: TaskStatus | str) -> str:
return value.value if hasattr(value, "value") else str(value)
def _format_priority(value: TaskPriority | str) -> str:
return value.value if hasattr(value, "value") else str(value)
@tool
@@ -25,7 +88,7 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
获取用户当前的任务列表。
Args:
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
limit: 返回数量默认20
Returns:
@@ -33,67 +96,82 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
"""
uid = get_current_user()
async def _get():
async with async_session() as db:
from app.models.user import User
query = (
select(Task)
.join(User, User.id == Task.user_id)
.where(User.id == uid)
)
if status:
query = query.where(Task.status == status)
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
result = await db.execute(query)
tasks = result.scalars().all()
if not tasks:
return "暂无任务"
lines = []
for t in tasks:
lines.append(
f"- [{t.id[:8]}] {t.title} | "
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or ''}"
)
return "\n".join(lines)
try:
resolved_status = _normalize_status(status) if status else None
async def _get():
async with async_session() as db:
from app.models.user import User
query = (
select(Task)
.join(User, User.id == Task.user_id)
.where(User.id == uid)
)
if resolved_status:
query = query.where(Task.status == resolved_status)
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
result = await db.execute(query)
tasks = result.scalars().all()
if not tasks:
return "暂无任务"
lines = []
for t in tasks:
lines.append(
f"- [{t.id[:8]}] {t.title} | "
f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or ''}"
)
return "\n".join(lines)
return _run_async(_get())
except Exception as e:
return f"获取任务失败: {str(e)}"
@tool
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
def create_task(
title: str = "",
description: str = "",
priority: int | str = 2,
due_date: str | None = None,
content: str = "",
date: str | None = None,
) -> str:
"""
创建新任务。
Args:
title: 任务标题(必填)
title: 任务标题(必填,兼容 content 作为别名
description: 任务描述
priority: 优先级 1-4,数字越大优先级越高默认2
due_date: 截止日期,格式 YYYY-MM-DD
priority: 优先级,支持 1-4 或 low/medium/high/urgent默认2
due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
content: title 的兼容别名
date: due_date 的兼容别名
Returns:
创建结果
"""
uid = get_current_user()
async def _create():
async with async_session() as db:
task = Task(
user_id=uid,
title=title,
description=description,
priority=priority,
due_date=due_date,
status="todo",
)
db.add(task)
await db.commit()
await db.refresh(task)
return f"任务创建成功: [{task.id[:8]}] {title}"
try:
resolved_title = _normalize_title(title, content)
resolved_due_date = _normalize_due_date(due_date, date)
resolved_priority = _normalize_priority(priority)
async def _create():
async with async_session() as db:
task = Task(
user_id=uid,
title=resolved_title,
description=description or content or None,
priority=resolved_priority,
due_date=_parse_due_date(resolved_due_date),
status=TaskStatus.TODO,
)
db.add(task)
await db.commit()
await db.refresh(task)
return f"任务创建成功: [{task.id[:8]}] {resolved_title}"
return _run_async(_create())
except Exception as e:
return f"创建任务失败: {str(e)}"
@@ -106,34 +184,37 @@ def update_task_status(task_id: str, status: str) -> str:
Args:
task_id: 任务ID完整ID或前8位
status: 新状态 (todo/in_progress/done/blocked)
status: 新状态 (todo/in_progress/done/cancelled)
Returns:
更新结果
"""
uid = get_current_user()
async def _update():
async with async_session() as db:
from app.models.user import User
query = (
select(Task)
.join(User, User.id == Task.user_id)
.where(User.id == uid)
)
if len(task_id) == 8:
query = query.where(Task.id.like(f"{task_id}%"))
else:
query = query.where(Task.id == task_id)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return f"任务不存在: {task_id}"
task.status = status
await db.commit()
return f"任务状态已更新: {task.title} -> {status}"
try:
resolved_status = _normalize_status(status)
async def _update():
async with async_session() as db:
from app.models.user import User
query = (
select(Task)
.join(User, User.id == Task.user_id)
.where(User.id == uid)
)
if len(task_id) == 8:
query = query.where(Task.id.like(f"{task_id}%"))
else:
query = query.where(Task.id == task_id)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return f"任务不存在: {task_id}"
task.status = resolved_status
task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
await db.commit()
return f"任务状态已更新: {task.title} -> {resolved_status.value}"
return _run_async(_update())
except Exception as e:
return f"更新任务失败: {str(e)}"

View File

@@ -0,0 +1,269 @@
from __future__ import annotations
import json
import re
from datetime import UTC, date, datetime, time, timedelta
from langchain_core.tools import tool
_WEEKDAY_MAP = {"": 0, "": 1, "": 2, "": 3, "": 4, "": 5, "": 6, "": 6}
_DEFAULT_HOUR_BY_PERIOD = {
"morning": 9,
"noon": 12,
"afternoon": 15,
"evening": 20,
}
_TIME_KEYWORDS = ("今天", "明天", "后天", "本周", "这周", "下周", "", "星期", "", "", "早上", "上午", "中午", "下午", "晚上", "今晚", "", ":", "")
def _parse_datetime(value: str) -> datetime:
normalized = value.strip().replace("Z", "+00:00")
return datetime.fromisoformat(normalized)
def extract_reference_datetime(current_datetime_context: str | None) -> datetime:
context = (current_datetime_context or "").strip()
if context:
for pattern in (r"current_time_utc:\s*(\S+)", r"CURRENT_TIME:\s*(\S+)", r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2}))"):
match = re.search(pattern, context)
if match:
return _parse_datetime(match.group(1))
return datetime.now(UTC)
def _normalize_local_iso(value: datetime) -> str:
return value.replace(tzinfo=None).isoformat(timespec="seconds")
def _normalize_datetime_iso(value: datetime) -> str:
if value.tzinfo is not None:
return value.isoformat(timespec="seconds")
return _normalize_local_iso(value)
def _normalize_date_iso(value: date) -> str:
return value.isoformat()
def _is_iso_datetime(value: str) -> bool:
try:
parsed = _parse_datetime(value)
except ValueError:
return False
return isinstance(parsed, datetime)
def _is_iso_date(value: str) -> bool:
try:
date.fromisoformat(value.strip())
return True
except ValueError:
return False
def _has_explicit_time(text: str) -> bool:
return bool(
re.search(r"\d{1,2}[:]\d{2}", text)
or re.search(r"\d{1,2}点(?:半|(?:\d{1,2})分?)?", text)
or any(keyword in text for keyword in ("早上", "上午", "中午", "下午", "晚上", "今晚"))
)
def _detect_period(text: str) -> str | None:
if any(keyword in text for keyword in ("晚上", "今晚")):
return "evening"
if "下午" in text:
return "afternoon"
if "中午" in text:
return "noon"
if any(keyword in text for keyword in ("早上", "上午", "早晨", "清晨")):
return "morning"
return None
def _resolve_time(text: str) -> tuple[time, bool, str | None]:
period = _detect_period(text)
colon_match = re.search(r"(\d{1,2})[:](\d{2})", text)
if colon_match:
hour = int(colon_match.group(1))
minute = int(colon_match.group(2))
if period in {"afternoon", "evening"} and hour < 12:
hour += 12
return time(hour=hour, minute=minute), False, period
half_match = re.search(r"(\d{1,2})点半", text)
if half_match:
hour = int(half_match.group(1))
if period in {"afternoon", "evening"} and hour < 12:
hour += 12
return time(hour=hour, minute=30), False, period
dot_match = re.search(r"(\d{1,2})点(?:(\d{1,2})分?)?", text)
if dot_match:
hour = int(dot_match.group(1))
minute = int(dot_match.group(2) or 0)
if period in {"afternoon", "evening"} and hour < 12:
hour += 12
if period == "noon" and hour < 11:
hour += 12
return time(hour=hour, minute=minute), False, period
if period:
return time(hour=_DEFAULT_HOUR_BY_PERIOD[period], minute=0), True, period
return time(hour=9, minute=0), True, None
def _resolve_date(text: str, reference: datetime) -> tuple[date, str]:
stripped = text.strip()
if _is_iso_date(stripped):
return date.fromisoformat(stripped), "explicit_date"
month_day_match = re.search(r"(\d{1,2})月(\d{1,2})日", stripped)
if month_day_match:
month = int(month_day_match.group(1))
day = int(month_day_match.group(2))
candidate = date(reference.year, month, day)
if candidate < reference.date() - timedelta(days=1):
candidate = date(reference.year + 1, month, day)
return candidate, "explicit_month_day"
if "后天" in stripped:
return reference.date() + timedelta(days=2), "relative_day"
if "明天" in stripped:
return reference.date() + timedelta(days=1), "relative_day"
if "今天" in stripped:
return reference.date(), "relative_day"
weekday_match = re.search(r"((?:本周|这周|下周)?)(?:周|星期)([一二三四五六日天])", stripped)
if weekday_match:
prefix = weekday_match.group(1)
weekday = _WEEKDAY_MAP[weekday_match.group(2)]
current_weekday = reference.date().weekday()
delta = weekday - current_weekday
if prefix == "下周":
delta += 7 if delta <= 0 else 7
elif prefix in {"本周", "这周"}:
if delta < 0:
delta += 7
elif delta < 0:
delta += 7
return reference.date() + timedelta(days=delta), "relative_weekday"
return reference.date(), "reference_day"
def resolve_time_expression_data(
expression: str,
*,
current_datetime_context: str | None = None,
prefer: str = "datetime",
) -> dict:
text = (expression or "").strip()
if not text:
raise ValueError("expression 不能为空")
reference = extract_reference_datetime(current_datetime_context)
if _is_iso_datetime(text):
parsed = _parse_datetime(text)
return {
"expression": text,
"reference_time": reference.isoformat(),
"grain": "datetime",
"resolved_date": _normalize_date_iso(parsed.date()),
"resolved_datetime": _normalize_datetime_iso(parsed),
"assumed_time": False,
"reason": "explicit_datetime",
}
if _is_iso_date(text):
parsed_date = date.fromisoformat(text)
return {
"expression": text,
"reference_time": reference.isoformat(),
"grain": "date",
"resolved_date": _normalize_date_iso(parsed_date),
"resolved_datetime": None,
"assumed_time": False,
"reason": "explicit_date",
}
resolved_date, date_reason = _resolve_date(text, reference)
resolved_time, assumed_time, period = _resolve_time(text)
has_explicit_time = _has_explicit_time(text)
grain = "date" if prefer == "date" and not has_explicit_time else "datetime"
resolved_dt = datetime.combine(resolved_date, resolved_time)
note = date_reason
if period:
note = f"{note}:{period}"
if assumed_time:
note = f"{note}:assumed_time"
return {
"expression": text,
"reference_time": reference.isoformat(),
"grain": grain,
"resolved_date": _normalize_date_iso(resolved_date),
"resolved_datetime": None if grain == "date" else _normalize_local_iso(resolved_dt),
"assumed_time": assumed_time,
"reason": note,
}
@tool
def resolve_time_expression(
expression: str,
current_datetime_context: str = "",
prefer: str = "datetime",
) -> str:
"""解析中文自然语言时间表达,基于当前参考时间返回明确的日期或 datetime。prefer 支持 datetime/date。"""
try:
payload = resolve_time_expression_data(
expression,
current_datetime_context=current_datetime_context or None,
prefer=prefer,
)
return json.dumps(payload, ensure_ascii=False)
except Exception as exc:
return json.dumps(
{
"expression": expression,
"error": str(exc),
},
ensure_ascii=False,
)
def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_context: str | None) -> dict:
normalized = dict(args)
if tool_name == "create_reminder":
raw_value = next((normalized.get(key) for key in ("reminder_at", "datetime", "at", "remind_at", "time") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
if raw_value and not _is_iso_datetime(raw_value):
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
normalized["reminder_at"] = payload["resolved_datetime"]
return normalized
if tool_name in {"create_schedule_task", "create_task"}:
raw_value = next((normalized.get(key) for key in ("due_date", "date") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
if raw_value and not _is_iso_datetime(raw_value) and not _is_iso_date(raw_value):
prefer = "datetime" if tool_name == "create_schedule_task" or _has_explicit_time(raw_value) else "date"
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer=prefer)
normalized["due_date"] = payload["resolved_datetime"] or payload["resolved_date"]
return normalized
if tool_name in {"create_todo", "create_goal", "get_schedule_day"}:
field_name = {
"create_todo": "todo_date",
"create_goal": "goal_date",
"get_schedule_day": "target_date",
}[tool_name]
raw_value = normalized.get(field_name)
if isinstance(raw_value, str) and raw_value.strip() and not _is_iso_date(raw_value):
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="date")
normalized[field_name] = payload["resolved_date"]
return normalized
return normalized
__all__ = ["resolve_time_expression", "resolve_time_expression_data", "normalize_tool_time_arguments", "extract_reference_datetime"]

View File

@@ -15,7 +15,9 @@ def _resolve_path(value: str) -> str:
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
model_config = SettingsConfigDict(
env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore"
)
# === 应用基础 ===
APP_NAME: str = "Jarvis"
@@ -75,6 +77,8 @@ class Settings(BaseSettings):
# === 向量化 ===
EMBEDDING_MODEL: str = "text-embedding-3-small"
EMBEDDING_BASE_URL: str = "https://api.openai.com/v1"
EMBEDDING_API_KEY: str = ""
CHUNK_SIZE: int = 500
CHUNK_OVERLAP: int = 50
@@ -86,6 +90,17 @@ class Settings(BaseSettings):
# === NAS 部署 ===
NAS_DATA_ROOT: str = "/data/jarvis"
# === Web Search / SearxNG ===
WEB_SEARCH_ENABLED: bool = False
WEB_SEARCH_PROVIDER: str = "searxng"
SEARXNG_BASE_URL: str = ""
SEARXNG_AUTH_TYPE: Literal["none", "bearer", "basic"] = "none"
SEARXNG_AUTH_TOKEN: str = ""
SEARXNG_BASIC_USER: str = ""
SEARXNG_BASIC_PASSWORD: str = ""
WEB_SEARCH_DEFAULT_LIMIT: int = 5
WEB_SEARCH_TIMEOUT_SECONDS: int = 10
settings = Settings()
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)

View File

@@ -40,6 +40,8 @@ async def init_db():
await ensure_document_columns(conn)
await ensure_user_columns(conn)
await ensure_forum_columns(conn)
await ensure_agent_columns(conn)
await ensure_skill_columns(conn)
async def ensure_log_columns(conn):
@@ -139,6 +141,55 @@ async def ensure_forum_columns(conn):
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
async def ensure_agent_columns(conn):
rows = await _get_table_info(conn, 'agents')
if not rows:
return
columns = {row[1] for row in rows}
required_columns = {
'selected_skill_ids': "ALTER TABLE agents ADD COLUMN selected_skill_ids JSON DEFAULT '[]' NOT NULL",
}
for column, ddl in required_columns.items():
if column not in columns:
await conn.execute(text(ddl))
async def ensure_skill_columns(conn):
rows = await _get_table_info(conn, 'skills')
if not rows:
return
columns = {row[1] for row in rows}
required_columns = {
'required_context': "ALTER TABLE skills ADD COLUMN required_context JSON DEFAULT '[]' NOT NULL",
'output_format': "ALTER TABLE skills ADD COLUMN output_format TEXT",
'is_builtin': "ALTER TABLE skills ADD COLUMN is_builtin BOOLEAN DEFAULT 0 NOT NULL",
'team_id': "ALTER TABLE skills ADD COLUMN team_id VARCHAR(36)",
}
for column, ddl in required_columns.items():
if column not in columns:
await conn.execute(text(ddl))
await conn.execute(text("UPDATE skills SET agent_type = 'schedule_planner' WHERE agent_type = 'planner'"))
builtin_names = [
'今日重点拆解',
'周计划编排',
'时间冲突分析',
'任务执行 SOP',
'外部交互推进',
'知识检索摘要',
'图谱沉淀策略',
'风险识别模板',
'趋势洞察模板',
]
for name in builtin_names:
await conn.execute(
text("UPDATE skills SET is_builtin = 1 WHERE name = :name"),
{'name': name},
)
async def _backfill_usernames(conn):
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
users = result.fetchall()

View File

@@ -14,6 +14,9 @@ from app.routers import (
graph_router,
agent_router,
todo_router,
reminder_router,
goal_router,
schedule_center_router,
settings_router,
folder_router,
skill_router,
@@ -23,7 +26,7 @@ from app.routers import (
)
from app.routers.scheduler import router as scheduler_router
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
from app.services.admin_bootstrap_service import ensure_admin_user
from app.services.admin_bootstrap_service import ensure_admin_user, ensure_builtin_skills
from app.config import settings
from app.logging_utils import (
setup_logging,
@@ -53,6 +56,7 @@ async def run_startup() -> None:
await init_db()
async with async_session() as session:
await ensure_admin_user(session, settings)
await ensure_builtin_skills(session)
await persist_system_log(
message="application_started",
source="app",
@@ -103,6 +107,9 @@ app.include_router(forum_router)
app.include_router(graph_router)
app.include_router(agent_router)
app.include_router(todo_router)
app.include_router(reminder_router)
app.include_router(goal_router)
app.include_router(schedule_center_router)
app.include_router(settings_router)
app.include_router(folder_router)
app.include_router(skill_router)

View File

@@ -1,5 +1,6 @@
from app.models.base import Base
from app.models.user import User
from app.models.folder import Folder
from app.models.document import Document, DocumentChunk
from app.models.task import Task, TaskHistory
from app.models.forum import ForumPost, ForumReply
@@ -17,11 +18,14 @@ from app.models.brain import (
brain_memory_sources,
)
from app.models.todo import DailyTodo, TodoSource
from app.models.reminder import Reminder, ReminderStatus
from app.models.goal import Goal, GoalStatus
from app.models.log import Log, LogType, LogLevel
__all__ = [
"Base",
"User",
"Folder",
"Document",
"DocumentChunk",
"Task",
@@ -45,6 +49,10 @@ __all__ = [
"brain_memory_sources",
"DailyTodo",
"TodoSource",
"Reminder",
"ReminderStatus",
"Goal",
"GoalStatus",
"Log",
"LogType",
"LogLevel",

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer, JSON
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
@@ -7,9 +7,10 @@ class Agent(BaseModel):
__tablename__ = "agents"
name = Column(String(100), nullable=False)
role = Column(String(100), nullable=False) # master, planner, executor, librarian, analyst
role = Column(String(100), nullable=False) # master, schedule_planner, executor, librarian, analyst
description = Column(Text, nullable=True)
system_prompt = Column(Text, nullable=False)
selected_skill_ids = Column(JSON, default=list, nullable=False)
is_active = Column(Boolean, default=True)
is_default = Column(Boolean, default=False)

View File

@@ -0,0 +1,21 @@
from enum import Enum as PyEnum
from sqlalchemy import Column, Enum, ForeignKey, String, Text
from app.models.base import BaseModel
class GoalStatus(str, PyEnum):
ACTIVE = "active"
DONE = "done"
ARCHIVED = "archived"
class Goal(BaseModel):
__tablename__ = "goals"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
title = Column(String(255), nullable=False)
note = Column(Text, nullable=True)
goal_date = Column(String(10), nullable=False, index=True)
status = Column(Enum(GoalStatus), default=GoalStatus.ACTIVE, nullable=False)

View File

@@ -0,0 +1,21 @@
from enum import Enum as PyEnum
from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, String, Text
from app.models.base import BaseModel
class ReminderStatus(str, PyEnum):
PENDING = "pending"
DONE = "done"
class Reminder(BaseModel):
__tablename__ = "reminders"
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
title = Column(String(255), nullable=False)
note = Column(Text, nullable=True)
reminder_at = Column(DateTime, nullable=False, index=True)
status = Column(Enum(ReminderStatus), default=ReminderStatus.PENDING, nullable=False)
is_dismissed = Column(Boolean, default=False, nullable=False)

View File

@@ -9,11 +9,12 @@ class Skill(BaseModel):
name = Column(String(100), nullable=False, unique=True, index=True)
description = Column(Text, nullable=True) # 供 LLM 理解用途
instructions = Column(Text, nullable=False) # Agent 执行时的指令模板
agent_type = Column(String(50), nullable=False, index=True) # master/planner/executor/librarian/analyst
agent_type = Column(String(50), nullable=False, index=True) # master/schedule_planner/executor/librarian/analyst
tools = Column(JSON, default=list) # 引用的工具名称列表
required_context = Column(JSON, default=list) # 需要的前置数据
output_format = Column(Text, nullable=True) # 输出格式要求
visibility = Column(String(20), default="private") # private/team/market
is_builtin = Column(Boolean, default=False, nullable=False)
team_id = Column(String(36), ForeignKey("users.id"), nullable=True)
is_active = Column(Boolean, default=True)
owner_id = Column(String(36), ForeignKey("users.id"), nullable=False)

View File

@@ -6,6 +6,9 @@ from app.routers.forum import router as forum_router
from app.routers.graph import router as graph_router
from app.routers.agent import router as agent_router
from app.routers.todo import router as todo_router
from app.routers.reminder import router as reminder_router
from app.routers.goal import router as goal_router
from app.routers.schedule_center import router as schedule_center_router
from app.routers.settings import router as settings_router
from app.routers.folder import router as folder_router
from app.routers.skill import router as skill_router

View File

@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models.agent import Agent
from app.models.skill import Skill
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
@@ -13,9 +14,9 @@ _agent_call_counts: dict[str, int] = {}
_agent_current_tasks: dict[str, str | None] = {}
_agent_statuses: dict[str, str] = {}
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
DEFAULT_AGENT_ROLES = ["master", "schedule_planner", "executor", "librarian", "analyst"]
SUB_COMMANDERS_BY_ROLE = {
"planner": ["planner_scope", "planner_steps"],
"schedule_planner": ["schedule_analysis", "schedule_planning"],
"executor": ["executor_tasks", "executor_forum"],
"librarian": ["librarian_retrieval", "librarian_graph"],
"analyst": ["analyst_progress", "analyst_insights"],
@@ -88,10 +89,10 @@ async def get_agent_config(
agent = result.scalar_one_or_none()
if not agent:
from app.agents.prompts import MASTER_SYSTEM_PROMPT, PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
from app.agents.prompts import MASTER_SYSTEM_PROMPT, SCHEDULE_PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
defaults = {
"master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
"planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
"schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
"executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
"librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
"analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
@@ -107,6 +108,7 @@ async def get_agent_config(
system_prompt=prompt,
enabled=True,
is_active=True,
selected_skill_ids=[],
)
return AgentConfigOut(
id=agent.role,
@@ -116,6 +118,7 @@ async def get_agent_config(
system_prompt=agent.system_prompt,
enabled=agent.is_active,
is_active=agent.is_active,
selected_skill_ids=agent.selected_skill_ids or [],
)
@@ -141,6 +144,19 @@ async def update_agent_config(
if data.enabled is not None:
agent.is_active = data.enabled
_agent_statuses[agent_id] = "disabled" if not data.enabled else "idle"
if data.selected_skill_ids is not None:
if data.selected_skill_ids:
result = await db.execute(
select(Skill.id).where(
Skill.id.in_(data.selected_skill_ids),
Skill.owner_id == current_user.id,
)
)
allowed_skill_ids = set(result.scalars().all())
invalid_skill_ids = [skill_id for skill_id in data.selected_skill_ids if skill_id not in allowed_skill_ids]
if invalid_skill_ids:
raise HTTPException(status_code=400, detail="存在无效的技能绑定")
agent.selected_skill_ids = data.selected_skill_ids
await db.commit()
await db.refresh(agent)
@@ -152,6 +168,7 @@ async def update_agent_config(
system_prompt=agent.system_prompt,
enabled=agent.is_active,
is_active=agent.is_active,
selected_skill_ids=agent.selected_skill_ids or [],
)

View File

@@ -6,6 +6,7 @@ from sqlalchemy.exc import IntegrityError
from app.database import get_db
from app.models.user import User
from app.schemas.auth import UserCreate, UserOut, Token
from app.services.admin_bootstrap_service import ensure_builtin_skills
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
from app.config import settings
@@ -58,6 +59,7 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
await db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
await db.refresh(user)
await ensure_builtin_skills(db)
return user
@@ -97,10 +99,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSessi
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
if not user.is_active:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
await ensure_builtin_skills(db)
access_token = create_access_token(data={"sub": user.id})
return Token(access_token=access_token)
@router.get("/me", response_model=UserOut)
async def get_me(current_user: User = Depends(get_current_user)):
async def get_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
await ensure_builtin_skills(db)
return current_user

View File

@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
@@ -92,13 +93,16 @@ async def chat(
):
"""简单版对话(非流式)"""
agent_svc = AgentService(db)
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
try:
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
# 更新对话消息计数
result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
@@ -126,13 +130,17 @@ async def chat_stream(
agent_svc = AgentService(db)
async def stream_generator():
conv_id, msg_id, stream = await agent_svc.chat(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
try:
conv_id, msg_id, stream = await agent_svc.chat(
user_id=current_user.id,
message=data.message,
conversation_id=data.conversation_id,
file_ids=data.file_ids,
model_name=data.model_name,
)
except ValueError as exc:
yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n"
return
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"

View File

@@ -0,0 +1,92 @@
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.goal import Goal
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.goal import GoalCreate, GoalListOut, GoalOut, GoalUpdate
router = APIRouter(prefix="/api/goals", tags=["目标"])
@router.get("", response_model=GoalListOut)
async def list_goals(
date_str: str = Query(...),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
target_date = date.fromisoformat(date_str).isoformat()
query = (
select(Goal)
.where(Goal.user_id == current_user.id)
.where(Goal.goal_date == target_date)
.order_by(Goal.created_at.desc())
)
items = (await db.execute(query)).scalars().all()
return GoalListOut(items=items)
@router.post("", response_model=GoalOut, status_code=201)
async def create_goal(
data: GoalCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
goal = Goal(
user_id=current_user.id,
title=data.title,
note=data.note,
goal_date=data.goal_date.isoformat(),
status=data.status,
)
db.add(goal)
await db.commit()
await db.refresh(goal)
return goal
@router.patch("/{goal_id}", response_model=GoalOut)
async def update_goal(
goal_id: str,
data: GoalUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
)
goal = result.scalar_one_or_none()
if not goal:
raise HTTPException(status_code=404, detail="目标不存在")
payload = data.model_dump(exclude_none=True)
if "goal_date" in payload:
payload["goal_date"] = payload["goal_date"].isoformat()
for field, value in payload.items():
setattr(goal, field, value)
await db.commit()
await db.refresh(goal)
return goal
@router.delete("/{goal_id}", status_code=204)
async def delete_goal(
goal_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
)
goal = result.scalar_one_or_none()
if not goal:
raise HTTPException(status_code=404, detail="目标不存在")
await db.delete(goal)
await db.commit()

View File

@@ -0,0 +1,90 @@
from datetime import date, datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.reminder import Reminder
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.reminder import ReminderCreate, ReminderListOut, ReminderOut, ReminderUpdate
router = APIRouter(prefix="/api/reminders", tags=["提醒"])
@router.get("", response_model=ReminderListOut)
async def list_reminders(
date_str: str = Query(...),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
target_date = date.fromisoformat(date_str)
start = datetime.combine(target_date, datetime.min.time())
end = datetime.combine(target_date, datetime.max.time())
query = (
select(Reminder)
.where(Reminder.user_id == current_user.id)
.where(Reminder.reminder_at >= start)
.where(Reminder.reminder_at <= end)
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
)
items = (await db.execute(query)).scalars().all()
return ReminderListOut(items=items)
@router.post("", response_model=ReminderOut, status_code=201)
async def create_reminder(
data: ReminderCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
reminder = Reminder(
user_id=current_user.id,
title=data.title,
note=data.note,
reminder_at=data.reminder_at,
)
db.add(reminder)
await db.commit()
await db.refresh(reminder)
return reminder
@router.patch("/{reminder_id}", response_model=ReminderOut)
async def update_reminder(
reminder_id: str,
data: ReminderUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
)
reminder = result.scalar_one_or_none()
if not reminder:
raise HTTPException(status_code=404, detail="提醒不存在")
for field, value in data.model_dump(exclude_none=True).items():
setattr(reminder, field, value)
await db.commit()
await db.refresh(reminder)
return reminder
@router.delete("/{reminder_id}", status_code=204)
async def delete_reminder(
reminder_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
)
reminder = result.scalar_one_or_none()
if not reminder:
raise HTTPException(status_code=404, detail="提醒不存在")
await db.delete(reminder)
await db.commit()

View File

@@ -0,0 +1,160 @@
from calendar import monthrange
from datetime import UTC, date, datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.goal import Goal
from app.models.reminder import Reminder
from app.models.task import Task, TaskPriority
from app.models.todo import DailyTodo
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.schedule_center import (
ScheduleCenterDateOut,
ScheduleCenterDaySummary,
ScheduleCenterMonthOut,
)
router = APIRouter(prefix="/api/schedule-center", tags=["调度中心"])
def _build_summary(
target_date: str,
todos: list[DailyTodo],
tasks: list[Task],
reminders: list[Reminder],
goals: list[Goal],
) -> ScheduleCenterDaySummary:
return ScheduleCenterDaySummary(
date=target_date,
todo_total=len(todos),
todo_completed=sum(1 for item in todos if item.is_completed),
task_due_total=len(tasks),
high_priority_total=sum(1 for item in tasks if item.priority in {TaskPriority.HIGH, TaskPriority.URGENT}),
reminder_total=len(reminders),
goal_total=len(goals),
)
@router.get("/month", response_model=ScheduleCenterMonthOut)
async def get_month_schedule(
year: int = Query(..., ge=2000, le=2100),
month: int = Query(..., ge=1, le=12),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
month_start = date(year, month, 1)
days_in_month = monthrange(month_start.year, month_start.month)[1]
start_key = month_start.isoformat()
end_key = month_start.replace(day=days_in_month).isoformat()
start_dt = datetime.combine(month_start, datetime.min.time())
end_dt = datetime.combine(month_start.replace(day=days_in_month), datetime.max.time())
todos = (await db.execute(
select(DailyTodo).where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date >= start_key, DailyTodo.todo_date <= end_key)
)).scalars().all()
tasks = (await db.execute(
select(Task).where(
Task.user_id == current_user.id,
Task.due_date.is_not(None),
Task.due_date >= start_dt,
Task.due_date <= end_dt,
)
)).scalars().all()
reminders = (await db.execute(
select(Reminder).where(
Reminder.user_id == current_user.id,
Reminder.reminder_at >= start_dt,
Reminder.reminder_at <= end_dt,
)
)).scalars().all()
goals = (await db.execute(
select(Goal).where(Goal.user_id == current_user.id, Goal.goal_date >= start_key, Goal.goal_date <= end_key)
)).scalars().all()
todo_map: dict[str, list[DailyTodo]] = {}
for item in todos:
todo_map.setdefault(item.todo_date, []).append(item)
task_map: dict[str, list[Task]] = {}
for item in tasks:
key = item.due_date.date().isoformat()
task_map.setdefault(key, []).append(item)
reminder_map: dict[str, list[Reminder]] = {}
for item in reminders:
key = item.reminder_at.date().isoformat()
reminder_map.setdefault(key, []).append(item)
goal_map: dict[str, list[Goal]] = {}
for item in goals:
goal_map.setdefault(item.goal_date, []).append(item)
days = []
for day in range(1, days_in_month + 1):
date_key = month_start.replace(day=day).isoformat()
days.append(_build_summary(
date_key,
todo_map.get(date_key, []),
task_map.get(date_key, []),
reminder_map.get(date_key, []),
goal_map.get(date_key, []),
))
return ScheduleCenterMonthOut(month=f"{year:04d}-{month:02d}", days=days)
@router.get("/date", response_model=ScheduleCenterDateOut)
async def get_date_schedule(
date_str: date = Query(...),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
target_date = date_str
start_dt = datetime.combine(target_date, datetime.min.time())
end_dt = datetime.combine(target_date, datetime.max.time())
date_key = target_date.isoformat()
todos = (await db.execute(
select(DailyTodo)
.where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date == date_key)
.order_by(DailyTodo.created_at.desc())
)).scalars().all()
tasks = (await db.execute(
select(Task)
.where(
Task.user_id == current_user.id,
Task.due_date.is_not(None),
Task.due_date >= start_dt,
Task.due_date <= end_dt,
)
.order_by(Task.created_at.desc())
)).scalars().all()
reminders = (await db.execute(
select(Reminder)
.where(
Reminder.user_id == current_user.id,
Reminder.reminder_at >= start_dt,
Reminder.reminder_at <= end_dt,
)
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
)).scalars().all()
goals = (await db.execute(
select(Goal)
.where(Goal.user_id == current_user.id, Goal.goal_date == date_key)
.order_by(Goal.created_at.desc())
)).scalars().all()
summary = _build_summary(date_key, todos, tasks, reminders, goals)
return ScheduleCenterDateOut(
date=date_key,
todos=todos,
tasks=tasks,
reminders=reminders,
goals=goals,
summary=summary,
generated_at=datetime.now(UTC),
)

View File

@@ -1,11 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.skill import Skill
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.skill import SkillCreate, SkillOut, SkillUpdate
from app.services.admin_bootstrap_service import ensure_builtin_skills
from app.services.skill_service import SkillService
router = APIRouter(prefix="/api/skills", tags=["Skill"])
@@ -37,13 +39,23 @@ async def create_skill(
@router.get("", response_model=list[SkillOut])
async def list_skills(
agent_type: str | None = Query(default=None),
visibility: str | None = Query(default=None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(Skill).where(Skill.owner_id == current_user.id).order_by(Skill.created_at.desc())
)
return result.scalars().all()
service = SkillService(db)
return await service.list_for_user(current_user.id, agent_type=agent_type, visibility=visibility)
@router.post("/bootstrap-builtin", response_model=list[SkillOut])
async def bootstrap_builtin_skills(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await ensure_builtin_skills(db, preferred_owner_id=current_user.id)
service = SkillService(db)
return await service.list_for_user(current_user.id)
@router.get("/{skill_id}", response_model=SkillOut)

View File

@@ -1,6 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException
from datetime import UTC, date, datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from app.database import get_db
from app.models.task import Task, TaskStatus
from app.models.user import User
@@ -13,12 +15,28 @@ router = APIRouter(prefix="/api/tasks", tags=["看板"])
@router.get("", response_model=list[TaskOut])
async def list_tasks(
status: TaskStatus | None = None,
due_date: date | None = Query(default=None),
date_from: date | None = Query(default=None),
date_to: date | None = Query(default=None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
query = select(Task).where(Task.user_id == current_user.id)
if status:
query = query.where(Task.status == status)
if due_date:
start = datetime.combine(due_date, datetime.min.time())
end = datetime.combine(due_date, datetime.max.time())
query = query.where(Task.due_date.is_not(None), Task.due_date >= start, Task.due_date <= end)
else:
start = datetime.combine(date_from, datetime.min.time()) if date_from else None
end = datetime.combine(date_to, datetime.max.time()) if date_to else None
if start and end and start > end:
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
if start is not None:
query = query.where(Task.due_date.is_not(None), Task.due_date >= start)
if end is not None:
query = query.where(Task.due_date.is_not(None), Task.due_date <= end)
query = query.order_by(desc(Task.created_at))
result = await db.execute(query)
return result.scalars().all()
@@ -64,10 +82,10 @@ async def update_task(
if field == "tags":
setattr(task, field, json.dumps(value))
elif field == "status" and value == TaskStatus.DONE:
from datetime import UTC, datetime
task.completed_at = datetime.now(UTC)
setattr(task, field, value)
else:
elif field == "status":
task.completed_at = None
setattr(task, field, value)
await db.commit()

View File

@@ -1,7 +1,8 @@
from datetime import UTC, date, datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from datetime import date
from app.database import get_db
from app.models.todo import DailyTodo, TodoSource
from app.models.user import User
@@ -52,7 +53,7 @@ async def create_todo(
user_id=current_user.id,
title=data.title,
source=TodoSource.MANUAL,
todo_date=date.today().isoformat(),
todo_date=(data.todo_date or date.today()).isoformat(),
)
db.add(todo)
await db.commit()
@@ -74,14 +75,11 @@ async def update_todo(
if not todo:
raise HTTPException(status_code=404, detail="待办不存在")
# 历史日期不允许修改
if todo.todo_date != date.today().isoformat():
raise HTTPException(status_code=403, detail="历史待办不可修改")
if data.title is not None:
todo.title = data.title
if data.todo_date is not None:
todo.todo_date = data.todo_date.isoformat()
if data.is_completed is not None:
from datetime import UTC, datetime
todo.is_completed = data.is_completed
todo.completed_at = datetime.now(UTC) if data.is_completed else None
@@ -102,9 +100,6 @@ async def delete_todo(
todo = result.scalar_one_or_none()
if not todo:
raise HTTPException(status_code=404, detail="待办不存在")
if todo.todo_date != date.today().isoformat():
raise HTTPException(status_code=403, detail="历史待办不可删除")
await db.delete(todo)
await db.commit()

View File

@@ -41,6 +41,7 @@ class AgentConfigUpdate(BaseModel):
description: str | None = None
system_prompt: str | None = None
enabled: bool | None = None
selected_skill_ids: list[str] | None = None
class AgentConfigOut(BaseModel):
@@ -51,5 +52,6 @@ class AgentConfigOut(BaseModel):
system_prompt: str
enabled: bool
is_active: bool
selected_skill_ids: list[str]
model_config = {"from_attributes": True}

View File

@@ -0,0 +1,35 @@
from datetime import date, datetime
from pydantic import BaseModel
from app.models.goal import GoalStatus
class GoalCreate(BaseModel):
title: str
goal_date: date
note: str | None = None
status: GoalStatus = GoalStatus.ACTIVE
class GoalUpdate(BaseModel):
title: str | None = None
goal_date: date | None = None
note: str | None = None
status: GoalStatus | None = None
class GoalOut(BaseModel):
id: str
title: str
note: str | None
goal_date: str
status: GoalStatus
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class GoalListOut(BaseModel):
items: list[GoalOut]

View File

@@ -0,0 +1,40 @@
from datetime import date, datetime
from pydantic import BaseModel
from app.models.reminder import ReminderStatus
class ReminderCreate(BaseModel):
title: str
reminder_at: datetime
note: str | None = None
class ReminderUpdate(BaseModel):
title: str | None = None
reminder_at: datetime | None = None
note: str | None = None
status: ReminderStatus | None = None
is_dismissed: bool | None = None
class ReminderOut(BaseModel):
id: str
title: str
note: str | None
reminder_at: datetime
status: ReminderStatus
is_dismissed: bool
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class ReminderListOut(BaseModel):
items: list[ReminderOut]
class ReminderDateQuery(BaseModel):
date: date

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from pydantic import BaseModel
from app.schemas.goal import GoalOut
from app.schemas.reminder import ReminderOut
from app.schemas.task import TaskOut
from app.schemas.todo import TodoOut
class ScheduleCenterDaySummary(BaseModel):
date: str
todo_total: int
todo_completed: int
task_due_total: int
high_priority_total: int
reminder_total: int
goal_total: int
class ScheduleCenterMonthOut(BaseModel):
month: str
days: list[ScheduleCenterDaySummary]
class ScheduleCenterDateOut(BaseModel):
date: str
todos: list[TodoOut]
tasks: list[TaskOut]
reminders: list[ReminderOut]
goals: list[GoalOut]
summary: ScheduleCenterDaySummary
generated_at: datetime

View File

@@ -10,7 +10,8 @@ LLMType = Literal["chat", "vlm", "embedding", "rerank"]
# 单个模型配置
class LLMModelConfig(BaseModel):
name: str = "" # 模型名称/别名,用于标识
provider: LLMProviderType = "openai"
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
provider: Optional[LLMProviderType] = None
model: str = ""
base_url: str = ""
api_key: str = ""
@@ -52,7 +53,8 @@ class SettingsOut(BaseModel):
# 测试 LLM 连接请求
class LLMTestIn(BaseModel):
type: LLMType
provider: LLMProviderType
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
provider: Optional[LLMProviderType] = None
model: str
base_url: str
api_key: str

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from pydantic import BaseModel
from typing import Optional
@@ -6,7 +7,7 @@ class SkillCreate(BaseModel):
name: str
description: Optional[str] = None
instructions: str
agent_type: str # master/planner/executor/librarian/analyst
agent_type: str # master/schedule_planner/executor/librarian/analyst
tools: list[str] = []
required_context: list[str] = []
output_format: Optional[str] = None
@@ -39,10 +40,11 @@ class SkillOut(BaseModel):
required_context: list[str]
output_format: Optional[str]
visibility: str
is_builtin: bool
team_id: Optional[str]
is_active: bool
owner_id: str
created_at: str
updated_at: str
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}

View File

@@ -1,15 +1,19 @@
from datetime import date, datetime
from pydantic import BaseModel
from datetime import datetime
from app.models.todo import TodoSource
class TodoCreate(BaseModel):
title: str
todo_date: date | None = None
class TodoUpdate(BaseModel):
title: str | None = None
is_completed: bool | None = None
todo_date: date | None = None
class TodoOut(BaseModel):

View File

@@ -2,10 +2,87 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.skill import Skill
from app.models.user import User
from app.services.auth_service import get_password_hash
BUILTIN_SKILLS = [
{
'name': '今日重点拆解',
'description': '帮助日程规划师从上下文中提炼今天最值得推进的事项。',
'instructions': '优先识别今天最关键的 1-3 个重点,说明原因,并给出可执行顺序。',
'agent_type': 'schedule_planner',
'tools': ['calendar', 'tasks'],
'visibility': 'market',
},
{
'name': '周计划编排',
'description': '把本周目标整理成可落地的节奏与时间块。',
'instructions': '将目标拆成周内节奏安排,明确先后顺序、时间块与缓冲。',
'agent_type': 'schedule_planner',
'tools': ['calendar'],
'visibility': 'market',
},
{
'name': '时间冲突分析',
'description': '识别任务、日程与优先级之间的冲突。',
'instructions': '分析冲突来源、影响和推荐取舍,必要时给出替代方案。',
'agent_type': 'schedule_planner',
'tools': ['calendar', 'tasks'],
'visibility': 'market',
},
{
'name': '任务执行 SOP',
'description': '为执行角色提供标准执行步骤和结果回报格式。',
'instructions': '执行前先确认目标与边界,执行中记录关键动作,执行后输出结果、风险与下一步。',
'agent_type': 'executor',
'tools': ['shell', 'api_calls'],
'visibility': 'market',
},
{
'name': '外部交互推进',
'description': '支持论坛、外部接口或内容发布类动作。',
'instructions': '围绕外部交互任务,优先保证动作完整、结果清晰、反馈及时。',
'agent_type': 'executor',
'tools': ['api_calls', 'git'],
'visibility': 'market',
},
{
'name': '知识检索摘要',
'description': '从知识中枢中提炼与当前问题最相关的信息。',
'instructions': '检索后只保留当前决策需要的内容,输出摘要、来源与缺口。',
'agent_type': 'librarian',
'tools': ['web_search', 'database'],
'visibility': 'market',
},
{
'name': '图谱沉淀策略',
'description': '帮助知识管理员把零散信息沉淀为结构化关系。',
'instructions': '识别应沉淀的实体、关系与后续可检索维度。',
'agent_type': 'librarian',
'tools': ['database'],
'visibility': 'market',
},
{
'name': '风险识别模板',
'description': '帮助分析师快速识别当前推进中的风险点。',
'instructions': '从进度、依赖、资源与外部信号中提炼风险,并按严重度排序。',
'agent_type': 'analyst',
'tools': ['database', 'api_calls'],
'visibility': 'market',
},
{
'name': '趋势洞察模板',
'description': '把多源状态汇总为趋势与判断。',
'instructions': '对比近期变化,输出趋势、证据、判断与建议动作。',
'agent_type': 'analyst',
'tools': ['database', 'code_execution'],
'visibility': 'market',
},
]
def _is_bootstrap_enabled(settings) -> bool:
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
@@ -58,3 +135,49 @@ async def ensure_admin_user(db: AsyncSession, settings) -> None:
return
raise
await db.refresh(admin_user)
async def ensure_builtin_skills(db: AsyncSession, preferred_owner_id: str | None = None) -> None:
owner = None
if preferred_owner_id:
owner_result = await db.execute(
select(User).where(User.id == preferred_owner_id, User.is_active == True)
)
owner = owner_result.scalar_one_or_none()
if not owner:
owner_result = await db.execute(
select(User).where(User.is_active == True).order_by(User.is_superuser.desc(), User.created_at.asc())
)
owner = owner_result.scalars().first()
if not owner:
return
existing_result = await db.execute(select(Skill.name))
existing_names = set(existing_result.scalars().all())
missing_skills = [
Skill(
owner_id=owner.id,
name=item['name'],
description=item['description'],
instructions=item['instructions'],
agent_type=item['agent_type'],
tools=item['tools'],
required_context=[],
output_format=None,
visibility=item['visibility'],
is_builtin=True,
team_id=None,
is_active=True,
)
for item in BUILTIN_SKILLS
if item['name'] not in existing_names
]
if not missing_skills:
return
db.add_all(missing_skills)
await db.commit()

View File

@@ -5,18 +5,17 @@ Jarvis Agent 服务层
import json
import uuid
from datetime import datetime
import logging
from datetime import UTC, datetime
from typing import Any, AsyncGenerator
import asyncio
from openai import BadRequestError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
import httpx
from app.database import async_session
from app.logging_utils import summarize_llm_config
from app.models.conversation import Conversation, Message
from app.models.user import User
@@ -24,43 +23,35 @@ from app.agents.graph import get_agent_graph
from app.agents.context import set_current_user, clear_current_user
from app.services import memory_service
from app.services.brain_service import BrainService
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
from app.agents.tools.time_reasoning import extract_reference_datetime
from app.agents.state import initial_state
logger = logging.getLogger(__name__)
def _create_llm_from_config(config: dict):
"""根据用户模型配置创建 LLM 实例"""
provider = config.get("provider", "openai")
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
capabilities = resolve_provider_capabilities(user_llm_config)
error_text = str(error).lower()
markers = [
"invalid chat setting",
"invalid params",
"stream",
"streaming",
"unsupported",
"bad_request_error",
"http 400",
"error code: 400",
]
if provider == "openai" or provider == "deepseek" or provider == "custom":
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
return ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
return ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
# 默认使用 OpenAI
return ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
if isinstance(error, BadRequestError):
return (
getattr(capabilities, "provider", None) not in {"openai", "claude"}
and any(marker in error_text for marker in markers)
)
return any(marker in error_text for marker in markers)
class AgentService:
"""对话 Agent 服务"""
@@ -101,27 +92,18 @@ class AgentService:
llm_config = user.llm_config
# 如果指定了模型名称,查找对应的配置
if model_name:
for model_type in ["chat", "vlm"]:
models = llm_config.get(model_type, [])
for m in models:
if m.get("name") == model_name:
return m
# 没找到,返回 None 让调用方知道配置不存在
models = llm_config.get("chat", [])
for m in models:
if m.get("name") == model_name:
return m
return None
# 如果没指定模型名,返回默认启用的 chat 模型
chat_models = llm_config.get("chat", [])
for m in chat_models:
if m.get("enabled"):
return m
vlm_models = llm_config.get("vlm", [])
for m in vlm_models:
if m.get("enabled"):
return m
return None
async def chat(
@@ -134,11 +116,26 @@ class AgentService:
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
"""
处理对话请求(流式)
Returns:
(conversation_id, message_id, response_stream)
"""
# 获取或创建对话
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if model_name and not user_llm_config:
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
logger.info(
"agent_chat_started",
extra={
"details": {
"mode": "stream",
"requested_model_name": model_name,
"resolved_model_name": model_name_used,
"message_length": len(message or ""),
}
},
)
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
@@ -156,7 +153,6 @@ class AgentService:
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
@@ -168,7 +164,6 @@ class AgentService:
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
@@ -193,156 +188,133 @@ class AgentService:
)
await self.db.commit()
# 预创建助手消息(后续更新内容)
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
content="",
model=model_name_used or "jarvis",
attachments=None,
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
def _build_current_datetime_context() -> str:
now_utc = datetime.now(UTC)
return (
"【当前时间】\n"
f"- current_time_utc: {now_utc.isoformat()}\n"
f"- current_date_utc: {now_utc.date().isoformat()}\n"
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
)
# 调用 LangGraph Agent
async def run_agent():
set_current_user(user_id)
try:
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"current_sub_commander": None,
"active_sub_commanders": [],
"sub_commander_trace": [],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
current_datetime_context = _build_current_datetime_context()
# 使用 initial_state 构建状态
state = initial_state(user_id, conversation_id)
state.update({
"messages": [HumanMessage(content=full_message)],
"memory_context": memory_ctx,
"current_datetime_context": current_datetime_context,
"user_llm_config": user_llm_config,
}
})
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
collected = ""
async for event in graph.astream_events(langgraph_state, version="v2"):
kind = event.get("event")
event_name = event.get("name", "")
metadata = event.get("metadata", {})
data = event.get("data", {})
try:
async for event in graph.astream_events(state, version="v2"):
kind = event.get("event")
event_name = event.get("name", "")
metadata = event.get("metadata", {})
data = event.get("data", {})
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"planner": ("planning", "Jarvis 正在拆解步骤"),
"executor": ("tool", "Jarvis 正在执行操作"),
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map[event_name]
yield self._build_progress_event(stage, label, agent=event_name, step=label)
elif kind == "on_tool_start":
tool_input = data.get("input")
step = None
if isinstance(tool_input, dict) and tool_input:
step = f"调用工具 {event_name}"
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
elif kind == "on_tool_end":
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
elif kind == "on_chain_end" and event_name == "planner":
output = data.get("output") or {}
plan_steps = output.get("plan_steps") or []
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = getattr(chunk, "content", "") if chunk else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chat_model_end" and not collected:
output = data.get("output")
content = getattr(output, "content", "") if output else ""
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
text_parts.append(item.get("text", ""))
else:
text_parts.append(str(item))
content = "".join(text_parts)
if content:
collected = content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
except Exception as e:
fallback = f"抱歉,发生错误: {str(e)}"
collected = fallback
yield {"type": "error", "error": str(e)}
yield {"type": "chunk", "content": fallback}
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
stage_map = {
"master": ("thinking", "Jarvis 正在理解请求"),
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
"executor": ("tool", "Jarvis 正在执行操作"),
"librarian": ("tool", "Jarvis 正在检索知识"),
"analyst": ("thinking", "Jarvis 正在分析信息"),
}
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
yield self._build_progress_event(stage, label, agent=event_name, step=label)
elif kind == "on_tool_start":
yield self._build_progress_event(
"tool",
f"Jarvis 正在调用工具 {event_name}",
agent="executor",
tool_name=event_name,
step=f"正在执行 {event_name}",
)
elif kind == "on_tool_end":
tool_result = data.get("output")
step = f"已完成 {event_name}"
if isinstance(tool_result, str) and len(tool_result) > 0:
step = tool_result[:100]
yield self._build_progress_event(
"tool",
f"工具 {event_name} 已完成",
agent="executor",
tool_name=event_name,
step=step,
)
elif kind == "on_chat_model_stream":
chunk = data.get("chunk")
content = getattr(chunk, "content", "") if chunk else ""
if content:
collected += content
yield {"type": "chunk", "content": content}
elif kind == "on_chain_end" and event_name == "create_agent_graph":
# 最终输出通常在这里
output = data.get("output")
if isinstance(output, dict) and "final_response" in output:
final_resp = output["final_response"]
# 如果还没流式输出完整,补全它
if final_resp and not collected:
collected = final_resp
yield {"type": "chunk", "content": collected}
except Exception as e:
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
try:
result_state = await graph.ainvoke(state)
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
collected = str(fallback_content)
yield {"type": "chunk", "content": collected}
except Exception as fallback_error:
logger.exception("llm_sync_fallback_failed")
yield {"type": "error", "error": "模型服务暂不可用。"}
else:
logger.exception("agent_streaming_failed")
yield {"type": "error", "error": str(e)}
finally:
clear_current_user()
try:
asyncio.get_running_loop().create_task(
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id))
# 最终更新数据库中的消息内容
if collected:
try:
result2 = await self.db.execute(
select(Message).where(Message.id == assistant_msg.id)
)
msg = result2.scalar_one_or_none()
if msg:
msg.content = collected
await self.db.commit()
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=collected[:500],
raw_excerpt=collected[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
async with async_session() as session:
result2 = await session.execute(select(Message).where(Message.id == assistant_msg.id))
msg = result2.scalar_one_or_none()
if msg:
msg.content = collected
await session.commit()
except Exception:
pass
logger.exception("save_assistant_message_failed")
return conversation_id, assistant_msg.id, run_agent()
@@ -355,117 +327,44 @@ class AgentService:
model_name: str | None = None,
) -> tuple[str, str, str, str | None]:
"""
简单同步版对话(无流式)
Returns:
(conversation_id, message_id, response_content, model_name_used)
简单同步版对话
"""
# 获取或创建对话
if conversation_id:
result = await self.db.execute(
select(Conversation).where(Conversation.id == conversation_id)
)
conv = result.scalar_one_or_none()
else:
conv = None
if not conv:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
else:
conversation_id = conv.id
# 如果有文件,读取内容作为上下文
file_context = ""
if file_ids:
from app.services.document_service import DocumentService
doc_svc = DocumentService(self.db)
for file_id in file_ids:
content = await doc_svc.get_document_content(user_id, file_id)
if content:
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
# 将文件上下文添加到消息
full_message = f"{message}\n{file_context}" if file_context else message
# 存储用户消息
user_msg = Message(
conversation_id=conversation_id,
role="user",
content=message,
attachments=[{"file_ids": file_ids}] if file_ids else None,
)
self.db.add(user_msg)
await self.db.commit()
await self.db.refresh(user_msg)
brain_service = BrainService(self.db)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="User message",
content_summary=message[:500],
raw_excerpt=message[:2000],
metadata_={"role": "user"},
importance_signal=1.0,
)
await self.db.commit()
# 加载记忆上下文
memory_ctx = await memory_service.build_memory_context(
self.db, user_id, conversation_id, message
)
# 获取用户配置的 LLM
user_llm_config = await self._get_user_llm_config(user_id, model_name)
model_name_used = model_name
if user_llm_config:
model_name_used = user_llm_config.get("name", model_name)
# 调用 LangGraph Agent
set_current_user(user_id)
graph = get_agent_graph()
langgraph_state = {
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
"user_id": user_id,
"conversation_id": conversation_id,
"current_agent": "master",
"active_agents": ["master"],
"pending_tasks": [],
"completed_tasks": [],
"tool_calls": [],
"last_tool_result": None,
"knowledge_context": None,
"graph_context": None,
"plan": None,
"plan_steps": [],
"analysis_report": None,
"final_response": None,
"should_respond": True,
"memory_context": memory_ctx,
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
}
if not conversation_id:
conv = Conversation(user_id=user_id, title=message[:50])
self.db.add(conv)
await self.db.commit()
await self.db.refresh(conv)
conversation_id = conv.id
user_msg = Message(conversation_id=conversation_id, role="user", content=message)
self.db.add(user_msg)
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
set_current_user(user_id)
try:
result_state = await graph.ainvoke(langgraph_state)
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
graph = get_agent_graph()
state = initial_state(user_id, conversation_id)
state.update({
"messages": [HumanMessage(content=message)],
"memory_context": memory_ctx,
"current_datetime_context": datetime.now(UTC).isoformat(),
"user_llm_config": user_llm_config,
})
result_state = await graph.ainvoke(state)
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
except Exception as e:
response_content = f"抱歉,发生错误: {str(e)}"
logger.exception("agent_chat_simple_failed")
response_content = "抱歉,发生错误。"
finally:
clear_current_user()
try:
asyncio.get_running_loop().create_task(
self._try_auto_summarize_background(user_id, conversation_id)
)
except Exception:
pass
# 保存助手消息
assistant_msg = Message(
conversation_id=conversation_id,
role="assistant",
@@ -474,19 +373,5 @@ class AgentService:
)
self.db.add(assistant_msg)
await self.db.commit()
await self.db.refresh(assistant_msg)
await brain_service.create_event(
user_id,
source_type="conversation",
source_id=conversation_id,
event_type="message_created",
title="Assistant message",
content_summary=response_content[:500],
raw_excerpt=response_content[:2000],
metadata_={"role": "assistant"},
importance_signal=1.0,
)
await self.db.commit()
return conversation_id, assistant_msg.id, response_content, model_name_used

View File

@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
"""
from abc import ABC, abstractmethod
from typing import AsyncIterator
from dataclasses import dataclass
from typing import AsyncIterator, Literal
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from langchain_core.messages import BaseMessage, AIMessage
@@ -16,8 +17,131 @@ from app.models.user import User
import httpx
import os
os.makedirs(settings.DATA_DIR, exist_ok=True)
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
ToolStrategy = Literal["native", "json_fallback"]
def _resolve_effective_base_url(config: dict | None) -> str:
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
base_url = str((config or {}).get("base_url") or "").strip()
if base_url:
return base_url
if provider in {"openai", "custom", "deepseek"}:
return settings.OPENAI_BASE_URL
if provider == "ollama":
return settings.OLLAMA_BASE_URL
return ""
@dataclass(frozen=True)
class ProviderCapabilities:
provider: str
supports_native_tools: bool
preferred_tool_strategy: ToolStrategy
def default_provider_capabilities() -> ProviderCapabilities:
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
def normalize_provider_name(config: dict | None) -> str:
provider_raw = str((config or {}).get("provider") or "").strip().lower()
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
model = str((config or {}).get("model") or "").strip().lower()
base_url = _resolve_effective_base_url(config).strip().lower()
# base_url-first inference (provider may be omitted in user config)
if base_url:
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
return "ollama"
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
return "claude"
if "api.deepseek.com" in base_url:
return "deepseek"
# Many "openai-compatible" endpoints are configured as provider=openai.
# We treat them as distinct providers so capability routing can stay conservative.
if provider in {"openai", "custom"}:
if any(key in model or key in base_url for key in {"minimax", "abab"}):
return "minimax"
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
return "kimi"
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
return "qwen"
return provider
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
provider = normalize_provider_name(config)
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
native_tool_providers = {"openai", "deepseek", "claude"}
base_url = _resolve_effective_base_url(config).strip().lower()
is_official_openai = (
provider != "openai"
or not base_url
or "api.openai.com" in base_url
or "openai.azure.com" in base_url
)
if provider in native_tool_providers and is_official_openai:
return ProviderCapabilities(
provider=provider,
supports_native_tools=True,
preferred_tool_strategy="native",
)
return ProviderCapabilities(
provider=provider,
supports_native_tools=False,
preferred_tool_strategy="json_fallback",
)
def create_llm_from_config(config: dict | None):
"""根据用户模型配置创建底层 LangChain LLM 实例"""
if not config:
return get_llm()
provider = normalize_provider_name(config)
model = config.get("model", "")
api_key = config.get("api_key", "")
base_url = config.get("base_url", "")
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "claude":
llm = ChatAnthropic(
api_key=api_key,
model=model,
timeout=httpx.Timeout(60.0, connect=10.0),
)
elif provider == "ollama":
llm = ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=httpx.Timeout(120.0, connect=10.0),
)
else:
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=httpx.Timeout(60.0, connect=10.0),
)
setattr(llm, "_jarvis_user_llm_config", config)
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
return llm
class LLMService(ABC):
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
_llm_instance = OllamaService()
else:
raise ValueError(f"Unknown LLM provider: {provider}")
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
return _llm_instance

View File

@@ -1,23 +1,154 @@
"""
Jarvis 记忆系统
Jarvis 记忆系统 (基于 Mem0)
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
"""
import json
import re
import os
from datetime import datetime
from typing import Optional
from typing import Optional, Any
from sqlalchemy import select, desc, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory import MemorySummary, UserMemory
from app.models.conversation import Conversation, Message
from app.models.user import User
from app.services.brain_service import BrainService
from app.services.llm_service import get_llm
from app.agents.context import get_current_user
from app.config import settings as _settings
try:
from mem0 import Memory
MEM0_AVAILABLE = True
except ImportError:
MEM0_AVAILABLE = False
Memory = None
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 embedding 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.llm_config:
return None
embedding_models = user.llm_config.get("embedding", [])
for model in embedding_models:
if model.get("enabled") and model.get("model"):
return {
"model": model.get("model"),
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
"api_key": model.get("api_key")
or _settings.EMBEDDING_API_KEY
or _settings.OPENAI_API_KEY,
}
return None
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
"""从用户配置中获取 chat 模型配置"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.llm_config:
return None
chat_models = user.llm_config.get("chat", [])
for model in chat_models:
if model.get("enabled") and model.get("model"):
return {
"model": model.get("model"),
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
}
return None
class Mem0Client:
"""Mem0 客户端 - 按用户隔离"""
_instances: dict[str, Memory] = {}
_persist_dir: str = "./data/mem0"
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
"""获取指定用户的 Mem0 实例"""
cache_key = user_id
if cache_key not in self._instances:
self._instances[cache_key] = await self._init_memory(db, user_id)
return self._instances[cache_key]
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
if not MEM0_AVAILABLE:
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
os.makedirs(self._persist_dir, exist_ok=True)
llm_config = {
"model": _settings.OPENAI_MODEL,
"base_url": _settings.OPENAI_BASE_URL,
"api_key": _settings.OPENAI_API_KEY,
}
embed_config = _settings.EMBEDDING_MODEL
embed_base_url = _settings.EMBEDDING_BASE_URL
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
if db and user_id:
try:
user_chat = await _get_user_chat_config(db, user_id)
if user_chat:
llm_config = user_chat
except Exception:
pass
try:
user_embed = await _get_user_embedding_config(db, user_id)
if user_embed:
embed_config = user_embed["model"]
embed_base_url = user_embed["base_url"]
embed_api_key = user_embed["api_key"]
except Exception:
pass
config = {
"vector_store": {
"provider": "chroma",
"config": {
"collection_name": f"jarvis_memory_{user_id}",
"path": self._persist_dir,
},
},
"llm": {
"provider": "openai",
"config": {
"model": llm_config["model"],
"api_key": llm_config["api_key"],
"base_url": llm_config["base_url"],
},
},
"embedder": {
"provider": "openai",
"config": {
"model": embed_config,
"api_key": embed_api_key,
"base_url": embed_base_url,
},
},
}
return Memory.from_config(config)
_mem0_client = Mem0Client()
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
"""获取指定用户的 Mem0 实例"""
return await _mem0_client.get_memory(db, user_id)
# ———— 短期记忆: 对话历史 ————
async def load_conversation_history(
db: AsyncSession,
conversation_id: str,
@@ -36,8 +167,7 @@ async def load_conversation_history(
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
"""获取对话轮数(用户消息数)"""
result = await db.execute(
select(func.count(Message.id))
.where(
select(func.count(Message.id)).where(
Message.conversation_id == conversation_id,
Message.role == "user",
)
@@ -47,14 +177,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
# ———— 中期记忆: 对话摘要 ————
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
SUMMARIZE_THRESHOLD = 8
MAX_HISTORY_TURNS = 10
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
"""判断当前对话是否需要摘要"""
from app.models.memory import MemorySummary
turn_count = await get_conversation_turn_count(db, conversation_id)
# 检查是否已有摘要覆盖到当前轮数
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
@@ -72,17 +203,21 @@ async def generate_summary(
conversation_id: str,
messages: list[Message],
) -> str:
"""调用 LLM 生成对话摘要"""
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages
)
llm = get_llm()
"""生成对话摘要"""
from app.services.llm_service import get_llm
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
HumanMessage(content=history_text),
])
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
llm = get_llm()
response = await llm.invoke(
[
SystemMessage(
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
"提取关键信息、用户偏好、待办事项等。不超过150字。"
),
HumanMessage(content=history_text),
]
)
return response.content.strip()
@@ -92,8 +227,10 @@ async def save_summary(
conversation_id: str,
summary_text: str,
turn_count: int,
) -> MemorySummary:
"""保存对话摘要"""
) -> Any:
"""保存对话摘要到数据库"""
from app.models.memory import MemorySummary
summary = MemorySummary(
user_id=user_id,
conversation_id=conversation_id,
@@ -109,8 +246,10 @@ async def save_summary(
async def get_summaries(
db: AsyncSession,
conversation_id: str,
) -> list[MemorySummary]:
) -> list[Any]:
"""获取某对话的所有历史摘要"""
from app.models.memory import MemorySummary
result = await db.execute(
select(MemorySummary)
.where(MemorySummary.conversation_id == conversation_id)
@@ -119,31 +258,7 @@ async def get_summaries(
return list(result.scalars().all())
# ———— 长期记忆: 用户画像 ————
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
只提取事实性的、可能对未来对话有帮助的信息,如:
- 用户的身份/职业/背景
- 用户的偏好和习惯
- 用户的目标和计划
- 重要的事件和日期
- 用户的观点和态度
每条记忆格式: [类型] 内容
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
如果没有提取到任何记忆,回复""
"""
FACT_TYPES = {"fact", "preference", "goal", "habit"}
def _parse_fact_line(line: str) -> tuple[str, str] | None:
"""解析一行记忆: [fact] 内容 -> (type, content)"""
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
if m and m.group(1) in FACT_TYPES:
return m.group(1), m.group(2).strip()
return None
# ———— 长期记忆: 基于 Mem0 ————
async def extract_user_memories(
@@ -151,55 +266,34 @@ async def extract_user_memories(
user_id: str,
conversation_id: str,
messages: list[Message],
) -> list[UserMemory]:
"""从对话中提取用户记忆并保存"""
) -> list[dict]:
"""
从对话中提取用户记忆并存储到 Mem0。
Mem0 会自动处理:
- 事实提取
- 时间线追踪
- 矛盾解决
- 遗忘机制
"""
if len(messages) < 2:
return []
history_text = "\n".join(
f"[{m.role}] {m.content}" for m in messages[-10:]
)
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
llm = get_llm()
from langchain_core.messages import HumanMessage, SystemMessage
response = await llm.invoke([
SystemMessage(content=EXTRACTION_PROMPT),
HumanMessage(content=history_text),
])
text = response.content.strip()
if text == "" or not text:
return []
memories = []
for line in text.split("\n"):
parsed = _parse_fact_line(line)
if not parsed:
continue
mem_type, content = parsed
# 检查是否已有完全相同的记忆
existing = await db.execute(
select(UserMemory).where(
UserMemory.user_id == user_id,
UserMemory.content == content,
)
)
if existing.scalar_one_or_none():
continue
mem = UserMemory(
try:
mem0 = await get_mem0(db, user_id)
result = mem0.add(
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
user_id=user_id,
memory_type=mem_type,
content=content,
importance=5,
source_conversation_id=conversation_id,
metadata={
"conversation_id": conversation_id,
"source": "jarvis_memory",
},
)
db.add(mem)
memories.append(mem)
if memories:
await db.commit()
return memories
return result.get("results", [])
except Exception as e:
print(f"Mem0 extract error: {e}")
return []
async def recall_user_memories(
@@ -207,41 +301,45 @@ async def recall_user_memories(
user_id: str,
query: str,
top_k: int = 5,
) -> list[UserMemory]:
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
# 先尝试语义相似(通过 LLM 判断)
# 降级: 直接从数据库取最近的重要记忆
result = await db.execute(
select(UserMemory)
.where(UserMemory.user_id == user_id)
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
.limit(top_k)
)
memories = list(result.scalars().all())
# 重置召回标记
for m in memories:
m.is_recalled = False
await db.commit()
return memories
) -> list[dict]:
"""
根据当前输入召回相关的用户记忆。
使用 Mem0 的语义搜索。
"""
try:
mem0 = await get_mem0(db, user_id)
results = mem0.search(
query=query,
filters={"user_id": user_id},
limit=top_k,
)
return results.get("results", [])
except Exception as e:
print(f"Mem0 search error: {e}")
return []
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
"""标记记忆已被召回使用"""
result = await db.execute(
select(UserMemory).where(UserMemory.id == memory_id)
)
mem = result.scalar_one_or_none()
if mem:
mem.is_recalled = True
mem.recall_count = (mem.recall_count or 0) + 1
mem.last_recalled_at = datetime.now(UTC)
await db.commit()
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
"""
获取用户画像。
Mem0 的 profile API 会返回 static 和 dynamic facts。
"""
try:
mem0 = await get_mem0(db, user_id)
result = mem0.history(user_id=user_id)
return {
"memories": result.get("results", []),
"static": [],
"dynamic": [],
}
except Exception as e:
print(f"Mem0 profile error: {e}")
return {"memories": [], "static": [], "dynamic": []}
# ———— 记忆组装: 供 Agent 使用的上下文 ————
async def build_memory_context(
db: AsyncSession,
user_id: str,
@@ -254,25 +352,22 @@ async def build_memory_context(
"""
parts = []
# 1. 用户画像(长期记忆)
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if user_memories:
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
if memories:
lines = []
for m in user_memories:
tag = f"[{m.memory_type}]"
lines.append(f" {tag} {m.content}")
await mark_memory_recalled(db, m.id)
parts.append("【用户记忆】\n" + "\n".join(lines))
for m in memories:
memory_text = m.get("memory", m.get("text", ""))
if memory_text:
lines.append(f" - {memory_text}")
if lines:
parts.append("【用户记忆】\n" + "\n".join(lines))
# 2. 对话摘要(中期记忆)
summaries = await get_summaries(db, conversation_id)
if summaries:
# 只取最近2条
recent = summaries[-2:]
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
lines = [f"[对话摘要{i + 1}] {s.summary_text}" for i, s in enumerate(recent)]
parts.append("【之前对话摘要】\n" + "\n".join(lines))
# 3. 知识大脑(长期项目记忆)
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
if brain_memories:
lines = []
@@ -292,7 +387,7 @@ async def try_auto_summarize(
) -> bool:
"""
检查是否需要摘要,如果需要则生成并保存。
返回是否执行了摘要
同时将对话内容存入 Mem0 进行记忆提取
"""
if not await should_summarize(db, conversation_id):
return False
@@ -306,8 +401,39 @@ async def try_auto_summarize(
turn_count = await get_conversation_turn_count(db, conversation_id)
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
# 同时提取用户记忆
await extract_user_memories(db, user_id, conversation_id, messages)
return True
except Exception:
except Exception as e:
print(f"Auto summarize error: {e}")
return False
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
"""
主动遗忘某条记忆。
"""
try:
mem0 = await get_mem0(db, user_id)
mem0.delete(memory_id, user_id=user_id)
return True
except Exception as e:
print(f"Mem0 delete error: {e}")
return False
async def update_memory(
db: AsyncSession,
user_id: str,
memory_id: str,
content: str,
) -> bool:
"""
更新某条记忆。Mem0 会自动处理矛盾检测。
"""
try:
mem0 = await get_mem0(db, user_id)
mem0.update(memory_id, content, user_id=user_id)
return True
except Exception as e:
print(f"Mem0 update error: {e}")
return False

View File

@@ -99,46 +99,55 @@ async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession)
async def test_llm_connection(
provider: str,
provider: str | None,
model: str,
base_url: str,
api_key: str
api_key: str,
) -> dict:
"""测试 LLM 连接"""
try:
# base_url-first: provider 可省略
from app.services.llm_service import normalize_provider_name
effective_provider = normalize_provider_name({
"provider": provider,
"model": model,
"base_url": base_url,
})
# 根据不同 provider 创建临时 LLM 实例并测试
if provider == "openai":
if effective_provider in {"openai", "custom", "minimax", "kimi", "qwen"}:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or None,
timeout=30
timeout=30,
)
elif provider == "claude":
elif effective_provider == "claude":
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(
api_key=api_key,
model=model,
timeout=30
timeout=30,
)
elif provider == "ollama":
elif effective_provider == "ollama":
from langchain_ollama import ChatOllama
llm = ChatOllama(
base_url=base_url or "http://localhost:11434",
model=model,
timeout=30
timeout=30,
)
elif provider == "deepseek":
elif effective_provider == "deepseek":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
api_key=api_key,
model=model,
base_url=base_url or "https://api.deepseek.com/v1",
timeout=30
timeout=30,
)
else:
return {"success": False, "error": f"不支持的 provider: {provider}"}
return {"success": False, "error": f"不支持的 endpoint/provider: {effective_provider}"}
# 简单测试调用
from langchain_core.messages import HumanMessage

View File

@@ -50,28 +50,22 @@ class SkillService:
"""
列出用户可访问的技能:自己的 + 市场的 + 团队的
"""
# 查询条件:自己的 或者 市场公开的 或者 团队的
conditions = [
access_scope = or_(
Skill.owner_id == user_id,
Skill.visibility == "market",
Skill.team_id == user_id,
]
# 如果提供了 agent_type 过滤
if agent_type:
conditions.append(Skill.agent_type == agent_type)
# 如果提供了 visibility 过滤
if visibility:
conditions.append(Skill.visibility == visibility)
query = select(Skill).where(
and_(
or_(*conditions),
Skill.is_active == True
)
)
filters = [access_scope, Skill.is_active == True]
if agent_type:
filters.append(Skill.agent_type == agent_type)
if visibility:
filters.append(Skill.visibility == visibility)
query = select(Skill).where(and_(*filters))
result = await self.db.execute(query)
return list(result.scalars().all())

View File

@@ -1,4 +1,8 @@
from datetime import datetime, UTC
from time import monotonic
import platform
import socket
import subprocess
try:
import psutil
@@ -7,21 +11,119 @@ except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fa
class SystemService:
_last_net_bytes_sent: int | None = None
_last_net_bytes_recv: int | None = None
_last_net_sample_at: float | None = None
def _get_network_rates(self) -> tuple[float, float]:
counters = psutil.net_io_counters()
now = monotonic()
if (
self.__class__._last_net_sample_at is None
or self.__class__._last_net_bytes_sent is None
or self.__class__._last_net_bytes_recv is None
):
self.__class__._last_net_bytes_sent = counters.bytes_sent
self.__class__._last_net_bytes_recv = counters.bytes_recv
self.__class__._last_net_sample_at = now
return 0.0, 0.0
elapsed = max(now - self.__class__._last_net_sample_at, 1e-6)
upload_bps = max(counters.bytes_sent - self.__class__._last_net_bytes_sent, 0) / elapsed
download_bps = max(counters.bytes_recv - self.__class__._last_net_bytes_recv, 0) / elapsed
self.__class__._last_net_bytes_sent = counters.bytes_sent
self.__class__._last_net_bytes_recv = counters.bytes_recv
self.__class__._last_net_sample_at = now
return round(upload_bps, 1), round(download_bps, 1)
def _get_gpu_status(self) -> dict:
empty = {
'gpu_name': None,
'gpu_memory_total_mb': None,
'gpu_memory_used_mb': None,
'gpu_util_percent': None,
}
try:
result = subprocess.run(
[
'nvidia-smi',
'--query-gpu=name,memory.total,memory.used,utilization.gpu',
'--format=csv,noheader,nounits',
],
capture_output=True,
text=True,
encoding='utf-8',
timeout=2,
check=False,
)
except (FileNotFoundError, subprocess.SubprocessError, OSError):
return empty
if result.returncode != 0 or not result.stdout.strip():
return empty
first_line = result.stdout.strip().splitlines()[0]
parts = [part.strip() for part in first_line.split(',')]
if len(parts) < 4:
return empty
def parse_number(value: str) -> float | None:
try:
return float(value)
except (TypeError, ValueError):
return None
return {
'gpu_name': parts[0] or None,
'gpu_memory_total_mb': parse_number(parts[1]),
'gpu_memory_used_mb': parse_number(parts[2]),
'gpu_util_percent': parse_number(parts[3]),
}
def get_status(self) -> dict:
if psutil is None:
return {
'cpu_percent': 0.0,
'memory_percent': 0.0,
'disk_percent': 0.0,
'disk_used_gb': 0.0,
'disk_total_gb': 0.0,
'network_upload_bps': 0.0,
'network_download_bps': 0.0,
'system_name': platform.system(),
'system_version': platform.version(),
'hostname': socket.gethostname(),
'uptime_seconds': 0.0,
'gpu_name': None,
'gpu_memory_total_mb': None,
'gpu_memory_used_mb': None,
'gpu_util_percent': None,
'timestamp': datetime.now(UTC).isoformat(),
}
cpu_percent = psutil.cpu_percent(interval=None)
memory = psutil.virtual_memory()
disk = psutil.disk_usage('/')
upload_bps, download_bps = self._get_network_rates()
gpu_status = self._get_gpu_status()
boot_time = psutil.boot_time()
now_ts = datetime.now(UTC).timestamp()
return {
'cpu_percent': round(cpu_percent, 1),
'memory_percent': round(memory.percent, 1),
'disk_percent': round(disk.percent, 1),
'disk_used_gb': round(disk.used / (1024 ** 3), 1),
'disk_total_gb': round(disk.total / (1024 ** 3), 1),
'network_upload_bps': upload_bps,
'network_download_bps': download_bps,
'system_name': platform.system(),
'system_version': platform.version(),
'hostname': socket.gethostname(),
'uptime_seconds': round(max(now_ts - boot_time, 0.0), 1),
**gpu_status,
'timestamp': datetime.now(UTC).isoformat(),
}

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse
import httpx
from app.config import settings
@dataclass(frozen=True)
class WebSearchResult:
title: str
url: str
snippet: str
source: str | None = None
published_at: str | None = None
class WebSearchError(Exception):
pass
class WebSearchConfigurationError(WebSearchError):
pass
class WebSearchRequestError(WebSearchError):
pass
class WebSearchService:
def __init__(
self,
*,
enabled: bool | None = None,
provider: str | None = None,
base_url: str | None = None,
default_limit: int | None = None,
timeout_seconds: int | None = None,
auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
auth_token: str | None = None,
basic_user: str | None = None,
basic_password: str | None = None,
):
self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
self.provider = (provider or settings.WEB_SEARCH_PROVIDER).strip().lower()
self.base_url = (base_url or settings.SEARXNG_BASE_URL).strip().rstrip('/')
self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
self.basic_password = basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
normalized_query = (query or '').strip()
if not self.enabled or not self.base_url:
raise WebSearchConfigurationError('网页搜索未启用或未配置')
if self.provider != 'searxng':
raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
if not normalized_query:
raise WebSearchRequestError('搜索关键词不能为空')
parsed = urlparse(self.base_url)
if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')
params = {
'q': normalized_query,
'format': 'json',
'language': 'zh-CN',
'safesearch': 1,
}
headers = self._build_headers()
timeout = httpx.Timeout(float(self.timeout_seconds), connect=min(float(self.timeout_seconds), 5.0))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(f'{self.base_url}/search', params=params, headers=headers)
response.raise_for_status()
payload = response.json()
except httpx.HTTPError as exc:
raise WebSearchRequestError('SearxNG 请求失败') from exc
except ValueError as exc:
raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc
raw_results = payload.get('results') if isinstance(payload, dict) else None
if not isinstance(raw_results, list):
return []
results: list[WebSearchResult] = []
target_limit = max(1, min(limit or self.default_limit, 10))
for item in raw_results:
if not isinstance(item, dict):
continue
title = str(item.get('title') or '').strip()
url = str(item.get('url') or '').strip()
snippet = str(item.get('content') or item.get('snippet') or '').strip()
if not title or not url:
continue
results.append(
WebSearchResult(
title=title,
url=url,
snippet=snippet,
source=str(item.get('engine') or item.get('source') or '').strip() or None,
published_at=str(item.get('publishedDate') or item.get('published_at') or '').strip() or None,
)
)
if len(results) >= target_limit:
break
return results
def _build_headers(self) -> dict[str, str]:
if self.auth_type == 'bearer' and self.auth_token:
return {'Authorization': f'Bearer {self.auth_token}'}
if self.auth_type == 'basic' and self.basic_user and self.basic_password:
credentials = httpx.BasicAuth(self.basic_user, self.basic_password)
request = httpx.Request('GET', self.base_url)
credentials.auth_flow(request)
return dict(request.headers)
return {}

Binary file not shown.

Binary file not shown.

View File

@@ -1 +0,0 @@
%PDF-1.4 bad

View File

@@ -1 +0,0 @@
%PDF-1.4 bad

View File

@@ -1 +0,0 @@
%PDF-1.4 bad

View File

@@ -1 +0,0 @@
%PDF-1.4 bad

View File

@@ -1 +0,0 @@
%PDF-1.4 bad

View File

@@ -27,6 +27,9 @@ dependencies = [
"llama-index-vector-stores-chroma>=0.3.0",
"chromadb>=0.5.0",
# Memory
"mem0ai>=1.0.0",
# 数据库
"sqlalchemy>=2.0.0",
"aiosqlite>=0.20.0",

View File

@@ -1,9 +1,122 @@
from langchain_core.messages import HumanMessage
from types import SimpleNamespace
from app.agents.graph import master_node
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.graph import (
_choose_sub_commander,
_parse_json_action,
_route_agent_from_user_query,
_run_sub_commander,
master_node,
)
from app.agents.tools.time_reasoning import resolve_time_expression
from app.agents.state import AgentRole
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
return {
'messages': [HumanMessage(content=message)],
'user_id': 'u1',
'conversation_id': 'c1',
'current_agent': AgentRole.MASTER,
'active_agents': [AgentRole.MASTER],
'current_sub_commander': None,
'active_sub_commanders': [],
'sub_commander_trace': [],
'pending_tasks': [],
'completed_tasks': [],
'tool_calls': [],
'last_tool_result': None,
'action_results': [],
'created_entities': [],
'tool_strategy_used': None,
'provider_capabilities': None,
'fallback_parse_error': None,
'knowledge_context': None,
'graph_context': None,
'schedule_context_summary': None,
'plan': None,
'plan_steps': [],
'analysis_report': None,
'final_response': None,
'should_respond': True,
'memory_context': None,
'current_datetime_context': 'CURRENT_TIME: 2026-03-28T12:00:00+08:00',
'current_datetime_reference': {'current_time_iso': '2026-03-28T12:00:00+08:00', 'current_date_iso': '2026-03-28', 'timezone': 'UTC'},
'user_llm_config': user_llm_config,
}
class FakeFallbackLLM:
def __init__(self, first_content: str, followup_content: str = '已创建提醒:开会,时间为 2026-03-29 09:00按当前时间理解为“明天早上9点”'):
self.first_content = first_content
self.followup_content = followup_content
self.calls = 0
async def ainvoke(self, messages):
self.calls += 1
if self.calls == 1:
return AIMessage(content=self.first_content)
return AIMessage(content=self.followup_content)
def bind_tools(self, tools):
raise AssertionError('bind_tools should not be called in JSON fallback mode')
class FakeNativeBoundLLM:
async def ainvoke(self, messages):
return AIMessage(
content='',
tool_calls=[
{
'id': 'call_1',
'name': 'create_reminder',
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
}
],
)
class FakeNativeLLM:
def __init__(self):
self.bound = FakeNativeBoundLLM()
self.tool_binding_count = 0
self.calls = 0
self._jarvis_provider_capabilities = SimpleNamespace(provider='openai', supports_native_tools=True, preferred_tool_strategy='native')
def bind_tools(self, tools):
self.tool_binding_count += 1
return self.bound
async def ainvoke(self, messages):
self.calls += 1
return AIMessage(content='已创建提醒:开会,时间为 2026-03-29 09:00按当前时间理解为“明天早上9点”')
class FakeTool:
def __init__(self, name: str, result: str):
self.name = name
self.result = result
self.invocations: list[dict] = []
def invoke(self, args: dict):
self.invocations.append(args)
return self.result
class CapturingLLM:
def __init__(self, content: str = '{"mode":"final","final_response":"好的。"}'):
self.content = content
self.messages = None
self._jarvis_provider_capabilities = SimpleNamespace(provider='ollama', supports_native_tools=False, preferred_tool_strategy='json_fallback')
async def ainvoke(self, messages):
self.messages = messages
return AIMessage(content=self.content)
class FailIfCalledLLM:
async def ainvoke(self, messages):
raise AssertionError('LLM should not be called for simple greetings')
@@ -71,6 +184,68 @@ async def test_master_node_returns_stable_reply_for_identity_question(monkeypatc
assert result['active_agents'] == [AgentRole.MASTER]
async def test_master_node_returns_stable_reply_for_identity_question_with_punctuation(monkeypatch):
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
state = {
'messages': [HumanMessage(content='你是谁?')],
'user_id': 'u1',
'conversation_id': 'c1',
'current_agent': AgentRole.MASTER,
'active_agents': [AgentRole.MASTER],
'pending_tasks': [],
'completed_tasks': [],
'tool_calls': [],
'last_tool_result': None,
'knowledge_context': None,
'graph_context': None,
'plan': None,
'plan_steps': [],
'analysis_report': None,
'final_response': None,
'should_respond': True,
'memory_context': None,
'user_llm_config': None,
}
result = await master_node(state)
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
assert result['current_agent'] == AgentRole.MASTER
assert result['active_agents'] == [AgentRole.MASTER]
async def test_master_node_returns_stable_reply_for_identity_question_with_particle(monkeypatch):
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
state = {
'messages': [HumanMessage(content='你是谁啊')],
'user_id': 'u1',
'conversation_id': 'c1',
'current_agent': AgentRole.MASTER,
'active_agents': [AgentRole.MASTER],
'pending_tasks': [],
'completed_tasks': [],
'tool_calls': [],
'last_tool_result': None,
'knowledge_context': None,
'graph_context': None,
'plan': None,
'plan_steps': [],
'analysis_report': None,
'final_response': None,
'should_respond': True,
'memory_context': None,
'user_llm_config': None,
}
result = await master_node(state)
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
assert result['current_agent'] == AgentRole.MASTER
assert result['active_agents'] == [AgentRole.MASTER]
async def test_master_node_returns_stable_reply_for_capability_question(monkeypatch):
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
@@ -100,3 +275,196 @@ async def test_master_node_returns_stable_reply_for_capability_question(monkeypa
assert result['final_response'] == '主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。'
assert result['current_agent'] == AgentRole.MASTER
assert result['active_agents'] == [AgentRole.MASTER]
def test_choose_sub_commander_routes_schedule_requests_to_schedule_planning():
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '帮我安排一下这周计划') == 'schedule_planning'
def test_choose_sub_commander_routes_focus_requests_to_schedule_analysis():
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '基于最近对话帮我判断该聚焦什么') == 'schedule_analysis'
def test_route_agent_from_user_query_routes_knowledge_requests_to_librarian():
assert _route_agent_from_user_query('帮我搜索知识库里的项目资料') == AgentRole.LIBRARIAN
def test_route_agent_from_user_query_routes_schedule_requests_to_schedule_planner():
assert _route_agent_from_user_query('明天提醒我开会') == AgentRole.SCHEDULE_PLANNER
def test_route_agent_from_user_query_routes_explicit_month_day_milestone_to_schedule_planner():
assert _route_agent_from_user_query('3月29日对话系统交付节点') == AgentRole.SCHEDULE_PLANNER
def test_choose_sub_commander_routes_explicit_month_day_milestone_to_schedule_planning():
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '3月29日对话系统交付节点') == 'schedule_planning'
def test_parse_json_action_extracts_tool_calls_from_fenced_json():
parsed = _parse_json_action(
'```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```',
['create_reminder'],
)
assert parsed == {
'mode': 'tool_call',
'tool_calls': [
{
'name': 'create_reminder',
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
'reason': None,
}
],
}
def test_parse_json_action_returns_none_for_invalid_or_unknown_payload():
assert _parse_json_action('not json', ['create_reminder']) is None
assert _parse_json_action('{"mode":"tool_call","tool_calls":[{"name":"unknown","arguments":{}}]}', ['create_reminder']) is None
def test_parse_json_action_tolerates_prefix_and_suffix_text():
parsed = _parse_json_action(
'好的,下面是 JSON\n```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```\n谢谢',
['create_reminder'],
)
assert parsed is not None
assert parsed['mode'] == 'tool_call'
assert parsed['tool_calls'][0]['name'] == 'create_reminder'
def test_parse_json_action_accepts_parameters_alias_for_tool_calls():
parsed = _parse_json_action(
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","parameters":{"title":"收被子","reminder_at":"2026-03-29T09:00:00+08:00"}}]}',
['create_reminder'],
)
assert parsed == {
'mode': 'tool_call',
'tool_calls': [
{
'name': 'create_reminder',
'args': {'title': '收被子', 'reminder_at': '2026-03-29T09:00:00+08:00'},
'reason': None,
}
],
}
async def test_run_sub_commander_uses_json_fallback_for_non_native_provider(monkeypatch):
fake_llm = FakeFallbackLLM(
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}'
)
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
monkeypatch.setitem(
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
'schedule_planning',
[fake_tool],
)
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
result = await _run_sub_commander(
state,
AgentRole.SCHEDULE_PLANNER,
'manager prompt',
'明天 9 点提醒我开会',
use_tools=True,
)
assert result['tool_strategy_used'] == 'json_fallback'
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
assert result['tool_calls'][0]['name'] == 'create_reminder'
assert result['created_entities'][0]['type'] == 'reminder'
assert result['fallback_parse_error'] is None
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00按当前时间理解为“明天早上9点”'
async def test_run_sub_commander_includes_current_datetime_context_in_system_messages(monkeypatch):
fake_llm = CapturingLLM('{"mode":"final","final_response":"好的。"}')
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
state['current_datetime_context'] = 'CURRENT_TIME: 2026-03-28T12:00:00+08:00'
await _run_sub_commander(
state,
AgentRole.SCHEDULE_PLANNER,
'manager prompt',
'明天 9 点提醒我开会',
use_tools=True,
)
assert fake_llm.messages is not None
assert any(
getattr(m, 'type', None) == 'system' and 'CURRENT_TIME:' in str(getattr(m, 'content', ''))
for m in fake_llm.messages
)
async def test_run_sub_commander_uses_web_search_in_json_fallback(monkeypatch):
fake_llm = FakeFallbackLLM(
'{"mode":"tool_call","tool_calls":[{"name":"web_search","arguments":{"query":"Jarvis 最新模型更新","top_k":2}}]}',
'我查了外部网页,下面是最新结果摘要。',
)
fake_tool = FakeTool('web_search', '成功搜索到 2 条网页结果')
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
monkeypatch.setitem(
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
'librarian_retrieval',
[fake_tool],
)
state = _base_state('帮我上网查一下 Jarvis 最新模型更新', {'provider': 'ollama', 'model': 'qwen2.5'})
state['current_agent'] = AgentRole.LIBRARIAN
result = await _run_sub_commander(
state,
AgentRole.LIBRARIAN,
'manager prompt',
'帮我上网查一下 Jarvis 最新模型更新',
use_tools=True,
summary_target='knowledge_context',
)
assert result['tool_strategy_used'] == 'json_fallback'
assert fake_tool.invocations == [{'query': 'Jarvis 最新模型更新', 'top_k': 2}]
assert result['tool_calls'][0]['name'] == 'web_search'
assert result['last_tool_result'] == '[web_search] 成功搜索到 2 条网页结果'
assert result['final_response'] == '我查了外部网页,下面是最新结果摘要。'
fake_llm = FakeNativeLLM()
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
monkeypatch.setitem(
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
'schedule_planning',
[fake_tool],
)
state = _base_state('明天 9 点提醒我开会', {'provider': 'openai', 'model': 'gpt-4o'})
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
result = await _run_sub_commander(
state,
AgentRole.SCHEDULE_PLANNER,
'manager prompt',
'明天 9 点提醒我开会',
use_tools=True,
)
assert result['tool_strategy_used'] == 'native'
assert fake_llm.tool_binding_count == 1
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
assert result['created_entities'][0]['type'] == 'reminder'
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00按当前时间理解为“明天早上9点”'

View File

@@ -0,0 +1,49 @@
from types import SimpleNamespace
import pytest
from app.agents.tools.search import web_search
class FakeResult(SimpleNamespace):
pass
def test_web_search_tool_formats_results(monkeypatch):
class FakeService:
async def search(self, query: str, limit: int | None = None):
assert query == 'Jarvis 最新更新'
assert limit == 2
return [
FakeResult(
title='Jarvis release notes',
url='https://example.com/jarvis-release',
snippet='Latest Jarvis changes.',
source='duckduckgo',
published_at='2026-03-29',
)
]
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
result = web_search.func('Jarvis 最新更新', top_k=2)
assert '[1] Jarvis release notes' in result
assert '链接: https://example.com/jarvis-release' in result
assert '来源: duckduckgo' in result
assert '时间: 2026-03-29' in result
assert '摘要: Latest Jarvis changes.' in result
def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
from app.services.web_search_service import WebSearchConfigurationError
class FakeService:
async def search(self, query: str, limit: int | None = None):
raise WebSearchConfigurationError('网页搜索未启用或未配置')
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
result = web_search.func('Jarvis')
assert result == '网页搜索不可用: 网页搜索未启用或未配置'

View File

@@ -0,0 +1,277 @@
import sys
from datetime import datetime
from unittest.mock import Mock
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault("psutil", Mock())
import app.models # noqa: F401
from app.models.goal import Goal
from app.models.reminder import Reminder
from app.models.task import Task, TaskPriority, TaskStatus
from app.models.todo import DailyTodo
from app.models.user import User
from app.services.auth_service import get_password_hash
@pytest.fixture
async def tool_env(tmp_path):
db_path = tmp_path / "test_task_tools.db"
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
# 只创建本测试需要的表,避免全量 metadata 引入未注册的外键表。
await conn.run_sync(User.metadata.create_all, tables=[
User.__table__,
Task.__table__,
DailyTodo.__table__,
Reminder.__table__,
Goal.__table__,
])
async with session_factory() as session:
user = User(
username="tool_user",
email="tool@example.com",
hashed_password=get_password_hash("secret123"),
full_name="Tool Tester",
)
session.add(user)
await session.commit()
await session.refresh(user)
try:
yield {"session_factory": session_factory, "user_id": user.id}
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_create_task_accepts_content_and_date_aliases_and_persists_task(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
result = task_tools.create_task.func(content="完成对话系统", date="2026-03-28")
assert "任务创建成功" in result
async with tool_env["session_factory"]() as session:
saved = (await session.execute(select(Task))).scalar_one()
assert saved.title == "完成对话系统"
assert saved.description == "完成对话系统"
assert saved.priority == TaskPriority.MEDIUM
assert saved.status == TaskStatus.TODO
assert saved.due_date == datetime(2026, 3, 28, 0, 0)
@pytest.mark.asyncio
async def test_create_schedule_task_accepts_content_and_date_aliases_and_sets_morning_due_date(tool_env, monkeypatch):
from app.agents.tools import schedule as schedule_tools
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
result = schedule_tools.create_schedule_task.func(content="完成对话系统", date="2026-03-28")
assert "任务创建成功" in result
async with tool_env["session_factory"]() as session:
saved = (await session.execute(select(Task))).scalar_one()
assert saved.title == "完成对话系统"
assert saved.description == "完成对话系统"
assert saved.priority == TaskPriority.MEDIUM
assert saved.status == TaskStatus.TODO
assert saved.due_date == datetime(2026, 3, 28, 9, 0)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("priority_input", "expected"),
[
(1, TaskPriority.LOW),
(2, TaskPriority.MEDIUM),
(3, TaskPriority.HIGH),
(4, TaskPriority.URGENT),
("urgent", TaskPriority.URGENT),
],
)
async def test_create_task_normalizes_legacy_and_string_priorities(tool_env, monkeypatch, priority_input, expected):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
result = task_tools.create_task.func(title=f"priority-{priority_input}", priority=priority_input)
assert "任务创建成功" in result
async with tool_env["session_factory"]() as session:
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
assert rows[-1].priority == expected
@pytest.mark.asyncio
async def test_create_task_accepts_iso_datetime_due_date(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
result = task_tools.create_task.func(title="timed task", due_date="2026-03-28T15:30:00Z")
assert "任务创建成功" in result
async with tool_env["session_factory"]() as session:
saved = (await session.execute(select(Task))).scalar_one()
assert saved.due_date == datetime(2026, 3, 28, 15, 30, 0)
@pytest.mark.asyncio
async def test_create_task_returns_failure_for_missing_title_and_content(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
result = task_tools.create_task.func()
assert result == "创建任务失败: title 不能为空"
@pytest.mark.asyncio
async def test_create_task_returns_failure_for_invalid_priority(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
result = task_tools.create_task.func(title="bad priority", priority="top")
assert "创建任务失败:" in result
@pytest.mark.asyncio
async def test_update_task_status_rejects_invalid_status(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
create_result = task_tools.create_task.func(title="status test")
assert "任务创建成功" in create_result
async with tool_env["session_factory"]() as session:
saved = (await session.execute(select(Task))).scalar_one()
@pytest.mark.asyncio
async def test_get_tasks_filters_by_normalized_status_and_formats_values(tool_env, monkeypatch):
from app.agents.tools import task as task_tools
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
task_tools.create_task.func(title="todo task", priority="high")
task_tools.create_task.func(title="done task", priority="low")
async with tool_env["session_factory"]() as session:
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
rows[1].status = TaskStatus.DONE
await session.commit()
result = task_tools.get_tasks.func(status="done")
assert "done task" in result
assert "todo task" not in result
assert "状态:done" in result
assert "优先级:low" in result
@pytest.mark.asyncio
async def test_create_schedule_reminder_accepts_datetime_description_and_at_aliases(tool_env, monkeypatch):
from app.agents.tools import schedule as schedule_tools
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
result = schedule_tools.create_reminder.func(
title="收被子",
description="提醒收被子",
datetime="2026-03-29T09:00:00",
time_zone="Asia/Shanghai",
)
assert "提醒创建成功" in result
async with tool_env["session_factory"]() as session:
saved = (await session.execute(select(Reminder))).scalar_one()
assert saved.title == "收被子"
assert saved.note == "提醒收被子"
assert saved.reminder_at == datetime(2026, 3, 29, 9, 0)
result = schedule_tools.create_reminder.func(
content="收被子",
datetime="2026-03-29T09:00:00+08:00",
)
assert "提醒创建成功" in result
async with tool_env["session_factory"]() as session:
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
assert rows[-1].title == "收被子"
assert rows[-1].note is None
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
result = schedule_tools.create_reminder.func(
content="收被子",
time="2026-03-29T09:00:00",
time_zone="Asia/Shanghai",
)
assert "提醒创建成功" in result
async with tool_env["session_factory"]() as session:
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
assert rows[-1].title == "收被子"
assert rows[-1].note is None
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
result = schedule_tools.create_reminder.func(
title="收被子",
remind_at="2026-03-29T18:00:00",
)
assert "提醒创建成功" in result
async with tool_env["session_factory"]() as session:
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
assert rows[-1].title == "收被子"
assert rows[-1].note is None
assert rows[-1].reminder_at == datetime(2026, 3, 29, 18, 0)
@pytest.mark.asyncio
async def test_create_schedule_reminder_returns_failure_when_time_aliases_missing(tool_env, monkeypatch):
from app.agents.tools import schedule as schedule_tools
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
result = schedule_tools.create_reminder.func(title="收被子")
assert result == "创建提醒失败: reminder_at 不能为空"

View File

@@ -0,0 +1,94 @@
from datetime import UTC, datetime
from app.agents.tools.time_reasoning import (
extract_reference_datetime,
normalize_tool_time_arguments,
resolve_time_expression_data,
)
def test_extract_reference_datetime_from_current_time_context():
context = '【当前时间】\n- current_time_utc: 2026-03-28T12:00:00+00:00\n- current_date_utc: 2026-03-28\n说明:解析相对时间时请以 current_time_utc 为准。'
result = extract_reference_datetime(context)
assert result == datetime(2026, 3, 28, 12, 0, tzinfo=UTC)
def test_resolve_time_expression_data_normalizes_relative_datetime():
payload = resolve_time_expression_data(
'明天早上9点',
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
prefer='datetime',
)
assert payload['grain'] == 'datetime'
assert payload['resolved_date'] == '2026-03-29'
assert payload['resolved_datetime'] == '2026-03-29T09:00:00'
assert payload['assumed_time'] is False
def test_resolve_time_expression_data_normalizes_relative_date_window():
payload = resolve_time_expression_data(
'下周一下午',
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
prefer='datetime',
)
assert payload['resolved_date'] == '2026-03-30'
assert payload['resolved_datetime'] == '2026-03-30T15:00:00'
assert payload['assumed_time'] is True
assert 'assumed_time' in payload['reason']
def test_normalize_tool_time_arguments_converts_reminder_time_aliases():
normalized = normalize_tool_time_arguments(
'create_reminder',
{'title': '开会', 'reminder_at': '明天 09:00'},
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
)
assert normalized['reminder_at'] == '2026-03-29T09:00:00'
def test_normalize_tool_time_arguments_converts_date_only_tools():
normalized = normalize_tool_time_arguments(
'create_goal',
{'title': '交付节点', 'goal_date': '明天'},
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
)
assert normalized['goal_date'] == '2026-03-29'
def test_resolve_time_expression_data_preserves_explicit_datetime_offset():
payload = resolve_time_expression_data(
'2026-03-29T09:00:00+08:00',
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
prefer='datetime',
)
assert payload['resolved_datetime'] == '2026-03-29T09:00:00+08:00'
def test_normalize_tool_time_arguments_keeps_create_task_date_without_explicit_time():
normalized = normalize_tool_time_arguments(
'create_task',
{'title': '写周报', 'due_date': '明天'},
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
)
assert normalized['due_date'] == '2026-03-29'
def test_normalize_tool_time_arguments_raises_for_invalid_time_text():
try:
normalize_tool_time_arguments(
'create_reminder',
{'title': '开会', 'reminder_at': '明天25点'},
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
)
except ValueError as exc:
assert 'hour must be in 0..23' in str(exc)
else:
raise AssertionError('expected ValueError for invalid time text')

View File

@@ -0,0 +1,23 @@
import pytest
from app.agents.tools import forum as forum_tools
from app.agents.tools import schedule as schedule_tools
from app.agents.tools import task as task_tools
@pytest.mark.asyncio
@pytest.mark.parametrize(
("module", "label"),
[
(task_tools, "task"),
(schedule_tools, "schedule"),
(forum_tools, "forum"),
],
)
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
async def sample():
return f"ok:{label}"
result = module._run_async(sample())
assert result == f"ok:{label}"

View File

@@ -9,7 +9,7 @@ from starlette.datastructures import UploadFile
import app.models # noqa: F401
from app.database import Base
from app.models.brain import BrainEvent, BrainMemory
from app.models.conversation import Conversation
from app.models.conversation import Conversation, Message
from app.models.memory import MemorySummary, UserMemory
from app.models.user import User
from app.services import agent_service, memory_service
@@ -32,6 +32,110 @@ class FakeStreamingGraph:
}
class FakeStreamingFinalResponseGraph:
async def astream_events(self, state, version="v2"):
yield {
"event": "on_chain_end",
"name": "master",
"data": {"output": {"final_response": "这是最终回答。"}},
}
class FakeStreamingBadRequestError(Exception):
pass
class FakeStreamingBadRequestError2(Exception):
pass
class FakeOpenAIBadRequestError(Exception):
pass
class FakeStreamingOpenAIBadRequestGraph:
def __init__(self):
self.astream_calls = 0
self.ainvoke_calls = 0
async def astream_events(self, state, version="v2"):
self.astream_calls += 1
raise FakeOpenAIBadRequestError('invalid_request_error: tool arguments failed validation')
yield
async def ainvoke(self, state):
self.ainvoke_calls += 1
return {"final_response": "不应触发同步回退。"}
class FakeStreamingFallbackGraph:
def __init__(self):
self.astream_calls = 0
self.ainvoke_calls = 0
async def astream_events(self, state, version="v2"):
self.astream_calls += 1
raise FakeStreamingBadRequestError('invalid params, invalid chat setting (2013)')
yield
async def ainvoke(self, state):
self.ainvoke_calls += 1
return {"final_response": "这是回退后的同步回答。"}
class FakeStreamingFallbackGraphGenericError:
def __init__(self):
self.astream_calls = 0
self.ainvoke_calls = 0
async def astream_events(self, state, version="v2"):
self.astream_calls += 1
raise FakeStreamingBadRequestError2("Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}")
yield
async def ainvoke(self, state):
self.ainvoke_calls += 1
return {"final_response": "这是通用异常回退后的同步回答。"}
class FakeStreamingDelegationThenFinalResponseGraph:
async def astream_events(self, state, version="v2"):
yield {
"event": "on_chat_model_stream",
"name": "master",
"data": {"chunk": SimpleNamespace(content="现在显示收到3月28日的任务是完成对话系统。\n\n我将这部分转给schedule_planner他会根据这个目标结合你当前的进度和资源给出具体的安排建议。")},
}
yield {
"event": "on_chain_end",
"name": "schedule_planner",
"data": {"output": {"final_response": "今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。"}},
}
class FakeStreamingDelegationThenModelEndGraph:
async def astream_events(self, state, version="v2"):
yield {
"event": "on_chat_model_stream",
"name": "master",
"data": {"chunk": SimpleNamespace(content="我将这部分转给schedule_planner。")},
}
yield {
"event": "on_chat_model_end",
"name": "schedule_planner",
"data": {"output": SimpleNamespace(content="最终建议:先完成对话系统,再回归验证。")},
}
class CapturingStateGraph:
def __init__(self, final_response: str = '已记录你的请求。'):
self.final_response = final_response
self.captured_state = None
async def ainvoke(self, state):
self.captured_state = state
return {"final_response": self.final_response}
@pytest.fixture
async def brain_ingestion_env(tmp_path, monkeypatch):
db_path = tmp_path / 'test_brain_ingestion.db'
@@ -43,6 +147,7 @@ async def brain_ingestion_env(tmp_path, monkeypatch):
async with session_factory() as session:
user = User(
username='brain-ingestion-tester',
email='brain-ingestion@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Ingestion Tester',
@@ -178,6 +283,360 @@ async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_in
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingFinalResponseGraph())
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'直接给我最终回答。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert ''.join(chunks) == '这是最终回答。'
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].event_type == 'message_created'
assert events[1].title == 'Assistant message'
assert events[1].content_summary == '这是最终回答。'
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_prefers_chain_end_final_response_over_delegation_chunk(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenFinalResponseGraph())
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'帮我安排今天先做什么。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
message_result = await session.execute(
select(Message)
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
.order_by(Message.created_at.desc())
)
assistant_message = message_result.scalars().first()
assert '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。' in chunks
assert chunks[-1] == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
assert assistant_message is not None
assert assistant_message.content == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
@pytest.mark.asyncio
async def test_streaming_chat_prefers_model_end_final_content_over_delegation_chunk(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenModelEndGraph())
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'帮我安排今天先做什么。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
message_result = await session.execute(
select(Message)
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
.order_by(Message.created_at.desc())
)
assistant_message = message_result.scalars().first()
assert '最终建议:先完成对话系统,再回归验证。' in chunks
assert chunks[-1] == '最终建议:先完成对话系统,再回归验证。'
assert assistant_message is not None
assert assistant_message.content == '最终建议:先完成对话系统,再回归验证。'
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
graph = FakeStreamingOpenAIBadRequestGraph()
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
monkeypatch.setattr(agent_service, 'BadRequestError', FakeOpenAIBadRequestError)
original_get_user_llm_config = AgentService._get_user_llm_config
async def fake_get_user_llm_config(self, user_id, model_name=None):
return {
'name': 'Official OpenAI',
'provider': 'openai',
'model': 'gpt-4o',
'base_url': 'https://api.openai.com/v1',
'enabled': True,
}
monkeypatch.setattr(AgentService, '_get_user_llm_config', fake_get_user_llm_config)
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'测试官方 OpenAI bad request 不应回退。',
)
chunks = []
errors = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
if event.get('type') == 'error':
errors.append(event['error'])
message_result = await session.execute(
select(Message)
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
.order_by(Message.created_at.desc())
)
assistant_message = message_result.scalars().first()
assert graph.astream_calls == 1
assert graph.ainvoke_calls == 0
assert errors == ['模型服务暂不可用,请稍后再试。']
assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。']
assert assistant_message is not None
assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。'
@pytest.mark.asyncio
async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
fallback_graph = FakeStreamingFallbackGraphGenericError()
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: fallback_graph)
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'帮我制定一下明天的计划。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert fallback_graph.astream_calls == 1
assert fallback_graph.ainvoke_calls == 1
assert ''.join(chunks) == '这是通用异常回退后的同步回答。'
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].content_summary == '这是通用异常回退后的同步回答。'
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
class PartialThenFailGraph:
def __init__(self):
self.ainvoke_calls = 0
async def astream_events(self, state, version='v2'):
yield {
'event': 'on_chat_model_stream',
'name': 'master',
'data': {'chunk': SimpleNamespace(content='前半段')},
}
raise FakeStreamingBadRequestError('stream interrupted')
async def ainvoke(self, state):
self.ainvoke_calls += 1
return {'final_response': '不应触发'}
graph = PartialThenFailGraph()
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
monkeypatch.setattr(agent_service, 'BadRequestError', FakeStreamingBadRequestError)
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'测试部分流式输出失败。',
)
chunks = []
errors = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
if event.get('type') == 'error':
errors.append(event['error'])
message_result = await session.execute(
select(Message)
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
.order_by(Message.created_at.desc())
)
assistant_message = message_result.scalars().first()
brain_event_result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(brain_event_result.scalars().all())
assert chunks == ['前半段']
assert graph.ainvoke_calls == 0
assert errors == ['stream interrupted']
assert assistant_message is not None
assert assistant_message.content == '前半段'
assert events[1].content_summary == '前半段'
@pytest.mark.asyncio
async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
graph = CapturingStateGraph()
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
service = AgentService(session)
await service.chat_simple(
user.id,
'3月29日对话系统交付节点',
)
assert graph.captured_state is not None
current_context = graph.captured_state.get('current_datetime_context')
assert isinstance(current_context, str)
assert current_context
assert '当前时间' in current_context
assert '2026' in current_context
current_reference = graph.captured_state.get('current_datetime_reference')
assert isinstance(current_reference, dict)
assert 'current_time_iso' in current_reference
assert 'current_date_iso' in current_reference
@pytest.mark.asyncio
async def test_get_user_llm_config_defaults_to_enabled_chat_model_not_vlm(brain_ingestion_env):
session, user = brain_ingestion_env
user.llm_config = {
'chat': [
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
],
'vlm': [
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
],
}
await session.commit()
service = AgentService(session)
config = await service._get_user_llm_config(user.id)
assert config is not None
assert config['name'] == 'Enabled Chat'
@pytest.mark.asyncio
async def test_get_user_llm_config_returns_none_when_only_vlm_is_enabled(brain_ingestion_env):
session, user = brain_ingestion_env
user.llm_config = {
'chat': [
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
],
'vlm': [
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
],
}
await session.commit()
service = AgentService(session)
config = await service._get_user_llm_config(user.id)
assert config is None
@pytest.mark.asyncio
async def test_chat_simple_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
session, user = brain_ingestion_env
user.llm_config = {
'chat': [
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
],
'vlm': [
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
],
}
await session.commit()
service = AgentService(session)
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
await service.chat_simple(user.id, '测试聊天模型选择', model_name='Vision Only')
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
message_result = await session.execute(select(Message))
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
assert conversation_result.scalars().all() == []
assert message_result.scalars().all() == []
assert brain_event_result.scalars().all() == []
@pytest.mark.asyncio
async def test_streaming_chat_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
session, user = brain_ingestion_env
user.llm_config = {
'chat': [
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
],
'vlm': [
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
],
}
await session.commit()
service = AgentService(session)
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
await service.chat(user.id, '测试流式聊天模型选择', model_name='Vision Only')
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
message_result = await session.execute(select(Message))
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
assert conversation_result.scalars().all() == []
assert message_result.scalars().all() == []
assert brain_event_result.scalars().all() == []
@pytest.mark.asyncio
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
session, user = brain_ingestion_env

View File

@@ -0,0 +1,49 @@
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
import app.models # noqa: F401
from app.database import Base
from app.models.skill import Skill
from app.models.user import User
from app.services.admin_bootstrap_service import ensure_builtin_skills
from app.services.auth_service import get_password_hash
@pytest.mark.asyncio
async def test_ensure_builtin_skills_creates_default_ability_skills(tmp_path):
db_path = tmp_path / 'test_builtin_skills.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
username='bootstrap_user',
email='bootstrap@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Bootstrap User',
is_active=True,
is_superuser=True,
)
session.add(user)
await session.commit()
async with session_factory() as session:
await ensure_builtin_skills(session)
await ensure_builtin_skills(session)
result = await session.execute(select(Skill).order_by(Skill.agent_type, Skill.name))
skills = result.scalars().all()
assert len(skills) >= 9
assert any(skill.agent_type == 'schedule_planner' for skill in skills)
assert any(skill.agent_type == 'executor' for skill in skills)
assert any(skill.agent_type == 'librarian' for skill in skills)
librarian_skill = next(skill for skill in skills if skill.name == '知识检索摘要')
assert 'web_search' in (librarian_skill.tools or [])
assert any(skill.agent_type == 'analyst' for skill in skills)
assert len({skill.name for skill in skills}) == len(skills)
await engine.dispose()

View File

@@ -0,0 +1,144 @@
import httpx
import pytest
from app.services.web_search_service import (
WebSearchConfigurationError,
WebSearchRequestError,
WebSearchResult,
WebSearchService,
)
class FakeResponse:
def __init__(self, payload: dict, status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self):
if self.status_code >= 400:
raise httpx.HTTPStatusError(
'request failed',
request=httpx.Request('GET', 'http://searx.example/search'),
response=httpx.Response(self.status_code, request=httpx.Request('GET', 'http://searx.example/search')),
)
def json(self):
return self._payload
class FakeAsyncClient:
def __init__(self, *, response=None, error=None, recorder=None, **kwargs):
self._response = response
self._error = error
self._recorder = recorder if recorder is not None else []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url, *, params=None, headers=None):
self._recorder.append({'url': url, 'params': params, 'headers': headers})
if self._error is not None:
raise self._error
return self._response
@pytest.mark.asyncio
async def test_web_search_service_returns_normalized_results_from_searxng(monkeypatch):
requests = []
payload = {
'results': [
{
'title': 'Jarvis release notes',
'url': 'https://example.com/jarvis-release',
'content': 'Latest Jarvis changes and release notes.',
'engine': 'duckduckgo',
'publishedDate': '2026-03-29',
}
]
}
monkeypatch.setattr(
'app.services.web_search_service.httpx.AsyncClient',
lambda **kwargs: FakeAsyncClient(response=FakeResponse(payload), recorder=requests, **kwargs),
)
service = WebSearchService(
enabled=True,
provider='searxng',
base_url='http://searx.example',
default_limit=5,
timeout_seconds=10,
)
results = await service.search('Jarvis 最新版本', limit=3)
assert results == [
WebSearchResult(
title='Jarvis release notes',
url='https://example.com/jarvis-release',
snippet='Latest Jarvis changes and release notes.',
source='duckduckgo',
published_at='2026-03-29',
)
]
assert requests == [
{
'url': 'http://searx.example/search',
'params': {
'q': 'Jarvis 最新版本',
'format': 'json',
'language': 'zh-CN',
'safesearch': 1,
},
'headers': {},
}
]
@pytest.mark.asyncio
async def test_web_search_service_returns_empty_list_when_searxng_has_no_results(monkeypatch):
monkeypatch.setattr(
'app.services.web_search_service.httpx.AsyncClient',
lambda **kwargs: FakeAsyncClient(response=FakeResponse({'results': []}), **kwargs),
)
service = WebSearchService(
enabled=True,
provider='searxng',
base_url='http://searx.example',
)
results = await service.search('不存在的话题')
assert results == []
@pytest.mark.asyncio
async def test_web_search_service_raises_clear_error_on_searxng_http_failure(monkeypatch):
monkeypatch.setattr(
'app.services.web_search_service.httpx.AsyncClient',
lambda **kwargs: FakeAsyncClient(error=httpx.TimeoutException('timed out'), **kwargs),
)
service = WebSearchService(
enabled=True,
provider='searxng',
base_url='http://searx.example',
)
with pytest.raises(WebSearchRequestError, match='SearxNG 请求失败'):
await service.search('Jarvis')
@pytest.mark.asyncio
async def test_web_search_service_raises_clear_error_when_not_configured():
service = WebSearchService(
enabled=False,
provider='searxng',
base_url='',
)
with pytest.raises(WebSearchConfigurationError, match='网页搜索未启用或未配置'):
await service.search('Jarvis')

View File

@@ -0,0 +1,150 @@
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.agent import Agent
from app.models.skill import Skill
from app.models.user import User
from app.routers.agent import router as agent_router
from app.routers.auth import get_current_user
from app.services.auth_service import get_password_hash
@pytest.fixture
async def agent_env(tmp_path):
db_path = tmp_path / 'test_agent_router.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
username='agent_user',
email='agent@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Agent Tester',
)
session.add(user)
await session.flush()
skill_a = Skill(
name='Planner skill A',
description='planner',
instructions='plan a',
agent_type='schedule_planner',
tools=['calendar'],
required_context=[],
visibility='private',
is_active=True,
owner_id=user.id,
)
skill_b = Skill(
name='Planner skill B',
description='planner',
instructions='plan b',
agent_type='schedule_planner',
tools=['tasks'],
required_context=[],
visibility='private',
is_active=True,
owner_id=user.id,
)
session.add_all([
Agent(
name='SCHEDULE PLANNER',
role='schedule_planner',
description='日程规划师',
system_prompt='prompt',
is_active=True,
),
skill_a,
skill_b,
])
await session.commit()
await session.refresh(user)
await session.refresh(skill_a)
await session.refresh(skill_b)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
test_app = FastAPI()
test_app.include_router(agent_router)
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_get_agent_config_returns_default_empty_selected_skill_ids(agent_env):
app, _ids = agent_env
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/agents/config/schedule_planner')
assert response.status_code == 200
payload = response.json()
assert payload['selected_skill_ids'] == []
@pytest.mark.asyncio
async def test_update_agent_config_persists_selected_skill_ids(agent_env):
app, ids = agent_env
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
update_response = await client.put(
'/api/agents/config/schedule_planner',
json={'selected_skill_ids': [ids['skill_a_id'], ids['skill_b_id']]},
)
get_response = await client.get('/api/agents/config/schedule_planner')
assert update_response.status_code == 200
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
assert get_response.status_code == 200
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
@pytest.mark.asyncio
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
app, ids = agent_env
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
first_response = await client.put(
'/api/agents/config/schedule_planner',
json={'selected_skill_ids': [ids['skill_a_id']]},
)
update_response = await client.put(
'/api/agents/config/schedule_planner',
json={'description': 'updated description'},
)
assert first_response.status_code == 200
assert update_response.status_code == 200
assert update_response.json()['description'] == 'updated description'
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id']]
@pytest.mark.asyncio
async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env):
app, _ids = agent_env
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.put(
'/api/agents/config/schedule_planner',
json={'selected_skill_ids': ['missing-skill']},
)
assert response.status_code == 400
assert response.json()['detail'] == '存在无效的技能绑定'

View File

@@ -0,0 +1,67 @@
from pathlib import Path
from app import config as config_module
from app.services.llm_service import default_provider_capabilities, normalize_provider_name, resolve_provider_capabilities
def test_env_file_points_to_repo_root_env_file():
assert config_module.ENV_FILE == Path(__file__).resolve().parents[4] / '.env'
def test_resolve_provider_capabilities_prefers_native_for_openai():
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'gpt-4o', 'base_url': 'https://api.openai.com/v1'})
assert capabilities.provider == 'openai'
assert capabilities.supports_native_tools is True
assert capabilities.preferred_tool_strategy == 'native'
def test_resolve_provider_capabilities_falls_back_for_openai_compatible_non_official_endpoint():
capabilities = resolve_provider_capabilities(
{'provider': 'openai', 'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}
)
assert capabilities.provider == 'minimax'
assert capabilities.supports_native_tools is False
assert capabilities.preferred_tool_strategy == 'json_fallback'
def test_resolve_provider_capabilities_uses_global_openai_base_url_when_user_config_omits_it(monkeypatch):
monkeypatch.setattr(config_module.settings, 'OPENAI_BASE_URL', 'https://api.minimax.chat/v1')
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'abab7.5-chat-preview'})
assert capabilities.provider == 'minimax'
assert capabilities.supports_native_tools is False
assert capabilities.preferred_tool_strategy == 'json_fallback'
def test_normalize_provider_name_recognizes_minimax_from_custom_config():
assert normalize_provider_name({'provider': 'custom', 'model': 'MiniMax-M2.7-highspeed'}) == 'minimax'
def test_normalize_provider_name_recognizes_minimax_without_provider_when_base_url_matches():
assert normalize_provider_name({'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}) == 'minimax'
def test_resolve_provider_capabilities_falls_back_for_ollama():
capabilities = resolve_provider_capabilities({'provider': 'ollama', 'model': 'qwen2.5'})
assert capabilities.provider == 'ollama'
assert capabilities.supports_native_tools is False
assert capabilities.preferred_tool_strategy == 'json_fallback'
def test_default_provider_capabilities_follows_global_settings(monkeypatch):
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
capabilities = default_provider_capabilities()
assert capabilities.provider == 'ollama'
assert capabilities.preferred_tool_strategy == 'json_fallback'
def test_normalize_provider_name_without_provider_uses_global_default(monkeypatch):
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
assert normalize_provider_name({'model': 'qwen2.5'}) == 'ollama'

View File

@@ -0,0 +1,281 @@
import sys
from datetime import UTC, date, datetime
from unittest.mock import Mock
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault('psutil', Mock())
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.goal import Goal
from app.models.reminder import Reminder
from app.models.task import Task, TaskPriority, TaskStatus
from app.models.todo import DailyTodo, TodoSource
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.goal import router as goal_router
from app.routers.reminder import router as reminder_router
from app.routers.schedule_center import router as schedule_center_router
from app.routers.task import router as task_router
from app.routers.todo import router as todo_router
from app.services.auth_service import get_password_hash
@pytest.fixture
async def schedule_env(tmp_path):
db_path = tmp_path / 'test_schedule_center.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
username='schedule_user',
email='schedule@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Schedule Tester',
)
other_user = User(
username='other_schedule_user',
email='other-schedule@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Other Schedule Tester',
)
session.add_all([user, other_user])
await session.flush()
session.add_all([
DailyTodo(
user_id=user.id,
title='Legacy todo',
source=TodoSource.MANUAL,
todo_date='2026-04-10',
is_completed=False,
),
DailyTodo(
user_id=user.id,
title='Done todo',
source=TodoSource.MANUAL,
todo_date='2026-04-10',
is_completed=True,
completed_at=datetime(2026, 4, 10, 9, 30, tzinfo=UTC),
),
DailyTodo(
user_id=other_user.id,
title='Other user todo',
source=TodoSource.MANUAL,
todo_date='2026-04-10',
is_completed=False,
),
Task(
user_id=user.id,
title='High priority task',
priority=TaskPriority.HIGH,
status=TaskStatus.TODO,
due_date=datetime(2026, 4, 10, 14, 0, tzinfo=UTC),
),
Task(
user_id=user.id,
title='Urgent task next day',
priority=TaskPriority.URGENT,
status=TaskStatus.IN_PROGRESS,
due_date=datetime(2026, 4, 11, 10, 0, tzinfo=UTC),
),
Task(
user_id=other_user.id,
title='Other user task',
priority=TaskPriority.HIGH,
status=TaskStatus.TODO,
due_date=datetime(2026, 4, 10, 15, 0, tzinfo=UTC),
),
Reminder(
user_id=user.id,
title='Doctor reminder',
note='Bring reports',
reminder_at=datetime(2026, 4, 10, 8, 0, tzinfo=UTC),
),
Goal(
user_id=user.id,
title='Launch calendar beta',
note='Ship MVP',
goal_date='2026-04-10',
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
test_app = FastAPI()
test_app.include_router(todo_router)
test_app.include_router(task_router)
test_app.include_router(reminder_router)
test_app.include_router(goal_router)
test_app.include_router(schedule_center_router)
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield test_app
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_create_todo_persists_explicit_todo_date(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.post('/api/todos', json={'title': 'Plan sprint', 'todo_date': '2026-04-12'})
assert response.status_code == 201
payload = response.json()
assert payload['title'] == 'Plan sprint'
assert payload['todo_date'] == '2026-04-12'
@pytest.mark.asyncio
async def test_update_todo_allows_editing_non_today_todo(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
todo_id = todos_response.json()['items'][0]['id']
response = await client.patch(f'/api/todos/{todo_id}', json={'title': 'Updated title', 'todo_date': '2026-04-11'})
assert response.status_code == 200
payload = response.json()
assert payload['title'] == 'Updated title'
assert payload['todo_date'] == '2026-04-11'
@pytest.mark.asyncio
async def test_delete_todo_allows_deleting_non_today_todo(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
todo_id = todos_response.json()['items'][0]['id']
response = await client.delete(f'/api/todos/{todo_id}')
after_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
assert response.status_code == 204
assert all(item['id'] != todo_id for item in after_response.json()['items'])
@pytest.mark.asyncio
async def test_list_tasks_filters_by_due_date(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/tasks', params={'due_date': '2026-04-10'})
assert response.status_code == 200
payload = response.json()
assert [item['title'] for item in payload] == ['High priority task']
@pytest.mark.asyncio
async def test_list_tasks_filters_by_due_date_range(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/tasks', params={'date_from': '2026-04-10', 'date_to': '2026-04-11'})
assert response.status_code == 200
payload = response.json()
assert {item['title'] for item in payload} == {'High priority task', 'Urgent task next day'}
@pytest.mark.asyncio
async def test_get_schedule_center_date_returns_aggregated_resources(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/schedule-center/date', params={'date_str': '2026-04-10'})
assert response.status_code == 200
payload = response.json()
assert payload['date'] == '2026-04-10'
assert payload['summary'] == {
'date': '2026-04-10',
'todo_total': 2,
'todo_completed': 1,
'task_due_total': 1,
'high_priority_total': 1,
'reminder_total': 1,
'goal_total': 1,
}
assert [item['title'] for item in payload['reminders']] == ['Doctor reminder']
assert [item['title'] for item in payload['goals']] == ['Launch calendar beta']
@pytest.mark.asyncio
async def test_get_schedule_center_month_returns_day_summaries(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/schedule-center/month', params={'year': 2026, 'month': 4})
assert response.status_code == 200
payload = response.json()
assert payload['month'] == '2026-04'
day_10 = next(item for item in payload['days'] if item['date'] == '2026-04-10')
day_11 = next(item for item in payload['days'] if item['date'] == '2026-04-11')
assert day_10 == {
'date': '2026-04-10',
'todo_total': 2,
'todo_completed': 1,
'task_due_total': 1,
'high_priority_total': 1,
'reminder_total': 1,
'goal_total': 1,
}
assert day_11['task_due_total'] == 1
assert day_11['high_priority_total'] == 1
@pytest.mark.asyncio
async def test_create_reminder_with_naive_datetime_and_time_zone_appears_in_schedule_center(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
create_response = await client.post(
'/api/reminders',
json={'title': '收被子', 'note': '提醒收被子', 'reminder_at': '2026-03-29T09:00:00'},
)
detail_response = await client.get('/api/schedule-center/date', params={'date_str': '2026-03-29'})
assert create_response.status_code == 201
assert detail_response.status_code == 200
payload = detail_response.json()
assert [item['title'] for item in payload['reminders']] == ['收被子']
assert payload['summary']['reminder_total'] == 1
@pytest.mark.asyncio
async def test_reminder_and_goal_crud(schedule_env):
transport = ASGITransport(app=schedule_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
reminder_response = await client.post('/api/reminders', json={'title': 'Standup', 'note': 'Daily sync', 'reminder_at': '2026-04-12T09:00:00Z'})
goal_response = await client.post('/api/goals', json={'title': 'Finish polish', 'note': 'UI cleanup', 'goal_date': '2026-04-12', 'status': 'active'})
reminder_id = reminder_response.json()['id']
goal_id = goal_response.json()['id']
patch_reminder = await client.patch(f'/api/reminders/{reminder_id}', json={'status': 'done'})
patch_goal = await client.patch(f'/api/goals/{goal_id}', json={'status': 'done'})
reminders_list = await client.get('/api/reminders', params={'date_str': '2026-04-12'})
goals_list = await client.get('/api/goals', params={'date_str': '2026-04-12'})
delete_reminder = await client.delete(f'/api/reminders/{reminder_id}')
delete_goal = await client.delete(f'/api/goals/{goal_id}')
assert reminder_response.status_code == 201
assert goal_response.status_code == 201
assert patch_reminder.json()['status'] == 'done'
assert patch_goal.json()['status'] == 'done'
assert [item['title'] for item in reminders_list.json()['items']] == ['Standup']
assert [item['title'] for item in goals_list.json()['items']] == ['Finish polish']
assert delete_reminder.status_code == 204
assert delete_goal.status_code == 204

View File

@@ -0,0 +1,190 @@
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.skill import Skill
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.skill import router as skill_router
from app.routers.auth import router as auth_router
from app.services.auth_service import get_password_hash
@pytest.fixture
async def skill_env(tmp_path, monkeypatch):
db_path = tmp_path / 'test_skill_router.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
# Ensure app.database.init_db() runs against the test database.
import app.database as database_module
monkeypatch.setattr(database_module, "engine", engine)
monkeypatch.setattr(database_module, "async_session", session_factory)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
username='skill_user',
email='skill@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Skill Tester',
)
other_user = User(
username='other_skill_user',
email='other-skill@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Other Skill Tester',
)
session.add_all([user, other_user])
await session.flush()
session.add_all([
Skill(
name='Planner skill',
description='planner',
instructions='plan',
agent_type='schedule_planner',
tools=['calendar'],
required_context=[],
visibility='private',
is_active=True,
owner_id=user.id,
),
Skill(
name='Executor skill',
description='executor',
instructions='execute',
agent_type='executor',
tools=['shell'],
required_context=[],
visibility='private',
is_active=True,
owner_id=user.id,
),
Skill(
name='Other user planner skill',
description='other',
instructions='other',
agent_type='schedule_planner',
tools=['calendar'],
required_context=[],
visibility='private',
is_active=True,
owner_id=other_user.id,
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
test_app = FastAPI()
test_app.include_router(auth_router)
test_app.include_router(skill_router)
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield test_app, session_factory
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_list_skills_filters_by_agent_type(skill_env):
test_app, _session_factory = skill_env
transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/skills', params={'agent_type': 'schedule_planner'})
assert response.status_code == 200
payload = response.json()
names = {item['name'] for item in payload}
assert names == {'Planner skill'}
assert 'Other user planner skill' not in names
@pytest.mark.asyncio
async def test_init_db_migrates_planner_skills_to_schedule_planner(skill_env):
app, session_factory = skill_env
async with session_factory() as session:
await session.execute(text("UPDATE skills SET agent_type = 'planner' WHERE name = 'Planner skill'"))
await session.commit()
from app.database import init_db
await init_db()
async with session_factory() as session:
migrated_response = await session.execute(text("SELECT agent_type FROM skills WHERE name = 'Planner skill'"))
assert migrated_response.scalar_one() == 'schedule_planner'
@pytest.mark.asyncio
async def test_list_skills_visibility_filter_still_respects_access_scope(skill_env):
test_app, _session_factory = skill_env
transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/skills', params={'visibility': 'private'})
assert response.status_code == 200
payload = response.json()
names = {item['name'] for item in payload}
assert names == {'Planner skill', 'Executor skill'}
assert 'Other user planner skill' not in names
@pytest.mark.asyncio
async def test_list_skills_bootstraps_builtin_market_skills_for_current_user(skill_env):
test_app, session_factory = skill_env
async with session_factory() as session:
await session.execute(text("DELETE FROM skills"))
await session.commit()
transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
login_response = await client.post(
'/api/auth/login',
data={'username': 'skill_user', 'password': 'secret123'},
headers={'content-type': 'application/x-www-form-urlencoded'},
)
assert login_response.status_code == 200
async with session_factory() as session:
result = await session.execute(select(Skill.name, Skill.is_builtin).order_by(Skill.name))
skills = result.all()
names = {name for name, _is_builtin in skills}
assert '今日重点拆解' in names
assert '任务执行 SOP' in names
assert any(is_builtin is True for _name, is_builtin in skills)
@pytest.mark.asyncio
async def test_list_skills_without_agent_type_returns_current_user_skills(skill_env):
test_app, _session_factory = skill_env
transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/skills')
assert response.status_code == 200
payload = response.json()
names = {item['name'] for item in payload}
assert names == {'Planner skill', 'Executor skill'}
assert 'Other user planner skill' not in names
assert all(isinstance(item['created_at'], str) for item in payload)
assert all(isinstance(item['updated_at'], str) for item in payload)
assert all('is_builtin' in item for item in payload)
assert all(item['is_builtin'] is False for item in payload)