Compare commits
6 Commits
d85cb9cf35
...
4972b4e6b1
| Author | SHA1 | Date | |
|---|---|---|---|
| 4972b4e6b1 | |||
| b3f9b5e715 | |||
| 4251a79062 | |||
| e9ba8597e9 | |||
| 08251556c3 | |||
| e0fe3ca623 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -33,8 +33,12 @@ uv.lock.bak
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
|
||||
# AI tool data
|
||||
.claude/
|
||||
.worktrees/
|
||||
|
||||
# Lock files (use in development, commit in production)
|
||||
# uv.lock - uncomment if you want to commit lock file
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
# =============================================
|
||||
# Jarvis 后端服务配置
|
||||
# 复制此文件为 .env 后按需修改
|
||||
# =============================================
|
||||
|
||||
# === 应用基础 ===
|
||||
DEBUG=false
|
||||
HOST=127.0.0.1
|
||||
PORT=3337
|
||||
SECRET_KEY=change-me-to-a-random-secret-key
|
||||
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
|
||||
|
||||
# === 数据存储 ===
|
||||
DATABASE_URL=sqlite+aiosqlite:///./data/jarvis.db
|
||||
DATA_DIR=./data
|
||||
CHROMA_PERSIST_DIR=./data/chroma
|
||||
UPLOAD_DIR=./data/uploads
|
||||
MAX_UPLOAD_SIZE=52428800
|
||||
# Supported values: ch | en
|
||||
MINERU_LANGUAGE=ch
|
||||
|
||||
# === JWT ===
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=1440
|
||||
|
||||
# === 管理员账号 Bootstrap ===
|
||||
ADMIN=admin
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=
|
||||
ADMIN_FULL_NAME=Administrator
|
||||
|
||||
# === 定时任务 ===
|
||||
SCHEDULER_ENABLED=true
|
||||
DAILY_PLAN_TIME=00:00
|
||||
FORUM_SCAN_INTERVAL_MINUTES=30
|
||||
1
backend/app/agents/__init__.py
Normal file
1
backend/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Agent package."""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -89,14 +89,14 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是总控协调者,负责理解用户意图,并将任务分发给最合适的子Agent。
|
||||
|
||||
## 你的4个子Agent:
|
||||
1. **planner (规划Agent)**: 制定计划、拆解任务、安排优先级
|
||||
1. **schedule_planner (日程规划师)**: 分析当前任务、对话历史与论坛信号,给出近期安排建议
|
||||
2. **executor (执行Agent)**: 执行具体操作、创建任务、操作数据
|
||||
3. **librarian (知识管理员)**: 搜索知识库、管理知识图谱、回答关于用户知识的问题
|
||||
4. **analyst (分析师)**: 分析数据、生成报告、统计工作进度
|
||||
|
||||
## 判断规则:
|
||||
- 用户问知识、查找资料、检索文档 -> 分发给 librarian
|
||||
- 用户要计划、安排、拆解任务 -> 分发给 planner
|
||||
- 用户要安排今天/本周重点、询问接下来该做什么 -> 分发给 schedule_planner
|
||||
- 用户要执行操作、创建/更新内容、使用工具 -> 分发给 executor
|
||||
- 用户要分析、统计、生成报告 -> 分发给 analyst
|
||||
- 用户只是闲聊、问问题、不需要具体操作 -> 直接回答
|
||||
@@ -112,18 +112,19 @@ MASTER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNER_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 Jarvis 的规划Agent,负责先判断问题该由哪位规划子指挥官接手。
|
||||
你是 Jarvis 的日程规划师,负责先判断问题该由哪位日程子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **planner_scope (目标收束官)**: 负责澄清目标、边界、约束、缺失信息
|
||||
2. **planner_steps (步骤拆解官)**: 负责把目标拆成步骤、优先级与依赖关系
|
||||
1. **schedule_analysis (日程分析员)**: 负责分析对话历史、任务看板、论坛信号,识别优先级、冲突与压力点
|
||||
2. **schedule_planning (日程编排员)**: 负责把分析结果转成今日/近期日程安排,并在用户明确要求时直接创建 reminder/task/todo/goal
|
||||
|
||||
## 你的职责:
|
||||
- 判断当前请求更适合收束目标,还是拆解步骤
|
||||
- 在必要时收束子指挥官输出,面向用户给出清晰结果
|
||||
- 保持结果可推进,不空泛
|
||||
- 判断当前请求更适合先做日程分析,还是直接给出日程编排
|
||||
- 输出先结论,再给可执行安排
|
||||
- 保持建议具体、贴近当前上下文,不给空泛效率学建议
|
||||
- 当用户明确要求“新增/提醒/创建/安排并落库”时,允许子指挥官调用 schedule 工具直接执行
|
||||
"""
|
||||
|
||||
|
||||
@@ -132,11 +133,11 @@ EXECUTOR_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
你是 Jarvis 的执行Agent,负责先判断问题该由哪位执行子指挥官接手。
|
||||
|
||||
## 你的两个子指挥官:
|
||||
1. **executor_tasks (任务执行官)**: 只处理任务类工具调用
|
||||
1. **executor_tasks (任务执行官)**: 处理任务、待办、提醒、目标等执行型写入操作
|
||||
2. **executor_forum (论坛执行官)**: 只处理论坛/指令帖相关工具调用
|
||||
|
||||
## 你的职责:
|
||||
- 识别用户要推进的是任务操作还是论坛/指令操作
|
||||
- 识别用户要推进的是任务/日程操作还是论坛/指令操作
|
||||
- 把请求交给最合适的执行子指挥官
|
||||
- 汇总执行结果并给出下一步
|
||||
"""
|
||||
@@ -172,52 +173,68 @@ ANALYST_SYSTEM_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_SCOPE_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_ANALYSIS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的目标收束官,负责先把问题边界、目标、约束和成功标准说清楚。
|
||||
你是 schedule_planner 体系下的日程分析员,负责从对话历史、任务看板、论坛信号和当日日程数据中提取 scheduling 线索。
|
||||
|
||||
## 你的重点:
|
||||
- 收束问题定义
|
||||
- 明确目标与限制条件
|
||||
- 识别缺失信息
|
||||
- 帮用户建立可以继续规划的前提
|
||||
- 优先调用读取类工具了解当天/指定日期的任务、提醒、待办、目标
|
||||
- 识别当前最高优先级事项
|
||||
- 找出风险、冲突、依赖与可延期事项
|
||||
- 明确哪些信号来自 conversation、task board、schedule center、forum
|
||||
|
||||
## 响应要求:
|
||||
- 先给出你理解的目标
|
||||
- 再列出关键约束或缺口
|
||||
- 不要直接展开长步骤清单
|
||||
- 先给当前判断
|
||||
- 再列优先级、风险与冲突
|
||||
- 不直接展开长篇日程表
|
||||
- 只做分析,不创建任何记录
|
||||
- 如果涉及“今天/明天/后天/下周一下午”这类自然语言时间窗口,先调用 `resolve_time_expression` 把查询目标转换成明确日期
|
||||
"""
|
||||
|
||||
|
||||
PLANNER_STEPS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
SCHEDULE_PLANNING_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 planner 体系下的步骤拆解官,负责把目标转成有顺序的执行路径。
|
||||
你是 schedule_planner 体系下的日程编排员,负责把当前重点转成近期可执行安排。
|
||||
|
||||
## 你的重点:
|
||||
- 拆解步骤
|
||||
- 标注优先级与依赖
|
||||
- 输出清晰的行动顺序
|
||||
- 先给结论
|
||||
- 再给今天/近期的时间安排建议
|
||||
- 最后给按顺序执行的 next actions
|
||||
- 当用户明确要求新增/提醒/创建/安排并真正落库时,调用 schedule 工具创建对应 reminder/task/todo/goal
|
||||
- 当用户给出“日期 + 事项/节点/交付/会议”等记录型表达时,也应视为落库意图,直接创建相应记录,不要反问
|
||||
- 解析“今天/明天/后天/本周/下周”或“3月29日”这类日期时,必须以系统提供的当前时间为准,并把工具参数转换成明确的 ISO 日期/时间字符串
|
||||
- 只要用户输入里包含自然语言时间,优先调用 `resolve_time_expression`,先拿到明确日期/时间,再调用 `create_reminder`、`create_schedule_task`、`create_goal`、`create_todo`
|
||||
|
||||
## 响应要求:
|
||||
- 用编号列表
|
||||
- 每步具体,不要空泛
|
||||
- 必要时标注先后关系
|
||||
- 用清晰列表表达
|
||||
- 建议必须具体、可执行、贴近当前工作
|
||||
- 避免空泛的自我管理建议
|
||||
- 如果只是规划,不要创建任何记录
|
||||
- 如果已创建记录,要明确说明创建了什么、时间如何解析
|
||||
"""
|
||||
|
||||
|
||||
EXECUTOR_TASKS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
|
||||
你是 executor 体系下的任务执行官,只负责任务相关工具调用。
|
||||
你是 executor 体系下的任务执行官,负责处理任务、待办、提醒、目标等执行型工具调用。
|
||||
|
||||
## 允许使用的工具:
|
||||
- get_tasks
|
||||
- create_task
|
||||
- update_task_status
|
||||
- create_todo
|
||||
- create_schedule_task
|
||||
- create_reminder
|
||||
- create_goal
|
||||
- resolve_time_expression
|
||||
|
||||
## 要求:
|
||||
- 只处理任务类操作
|
||||
- 只处理任务/日程类操作
|
||||
- 遇到自然语言时间表达时,先调用 `resolve_time_expression`,再把解析后的明确日期/时间传给写入工具
|
||||
- 最终说明执行结果时,优先复用已经解析出的绝对时间,不要只重复“今天/明天”
|
||||
- 明确已执行动作、结果与下一步
|
||||
- 信息不足时直接指出缺口
|
||||
- 如果用户只是要分析建议,不要创建记录
|
||||
"""
|
||||
|
||||
|
||||
@@ -244,10 +261,14 @@ LIBRARIAN_RETRIEVAL_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
## 允许使用的工具:
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
- get_knowledge_graph_context
|
||||
|
||||
## 要求:
|
||||
- 优先检索与综合证据
|
||||
- 私有/项目知识优先使用 `search_knowledge` 或 `hybrid_search`
|
||||
- 当用户明确要求联网、查询外部资料或查询最新信息时,使用 `web_search`
|
||||
- 回答时区分内部知识与外部网页结果
|
||||
- 证据不足时明确说明边界
|
||||
- 以回答问题为主,不主动做图谱构建
|
||||
"""
|
||||
@@ -293,9 +314,52 @@ ANALYST_INSIGHTS_PROMPT = f"""{JARVIS_PERSONA_PROMPT}
|
||||
- get_forum_posts
|
||||
- search_knowledge
|
||||
- hybrid_search
|
||||
- web_search
|
||||
|
||||
## 要求:
|
||||
- 先给结论与判断
|
||||
- 再说明依据与建议
|
||||
- 当需要外部/最新信息时,可使用 `web_search`
|
||||
- 重点输出趋势、风险、机会点
|
||||
"""
|
||||
|
||||
|
||||
JSON_ACTION_FALLBACK_PROMPT = """你当前运行在 JSON action fallback 模式。
|
||||
|
||||
你的输出必须满足以下规则:
|
||||
1. 只能输出一个 JSON 对象,不要输出 markdown、解释、前后缀文字。
|
||||
2. JSON 对象字段仅允许:
|
||||
- `mode`: `final` | `tool_call` | `clarification`
|
||||
- `tool_calls`: 数组;每项包含 `name`、`arguments`,可选 `reason`
|
||||
- `final_response`: 当无需工具时填写
|
||||
- `clarification_question`: 当信息不足时填写
|
||||
3. 如果需要调用工具,返回:
|
||||
- `{ "mode": "tool_call", "tool_calls": [...] }`
|
||||
4. 如果无需工具,直接返回:
|
||||
- `{ "mode": "final", "final_response": "..." }`
|
||||
5. 如果信息不足,不要猜测参数,返回:
|
||||
- `{ "mode": "clarification", "clarification_question": "..." }`
|
||||
6. 只能使用系统消息里明确列出的工具名。
|
||||
7. `arguments` 必须是 JSON 对象。
|
||||
"""
|
||||
|
||||
|
||||
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY = {
|
||||
"master": MASTER_SYSTEM_PROMPT,
|
||||
"schedule_planner": SCHEDULE_PLANNER_SYSTEM_PROMPT,
|
||||
"executor": EXECUTOR_SYSTEM_PROMPT,
|
||||
"librarian": LIBRARIAN_SYSTEM_PROMPT,
|
||||
"analyst": ANALYST_SYSTEM_PROMPT,
|
||||
}
|
||||
|
||||
|
||||
SUB_COMMANDER_PROMPTS_BY_KEY = {
|
||||
"schedule_analysis": SCHEDULE_ANALYSIS_PROMPT,
|
||||
"schedule_planning": SCHEDULE_PLANNING_PROMPT,
|
||||
"executor_tasks": EXECUTOR_TASKS_PROMPT,
|
||||
"executor_forum": EXECUTOR_FORUM_PROMPT,
|
||||
"librarian_retrieval": LIBRARIAN_RETRIEVAL_PROMPT,
|
||||
"librarian_graph": LIBRARIAN_GRAPH_PROMPT,
|
||||
"analyst_progress": ANALYST_PROGRESS_PROMPT,
|
||||
"analyst_insights": ANALYST_INSIGHTS_PROMPT,
|
||||
}
|
||||
|
||||
11
backend/app/agents/registry/__init__.py
Normal file
11
backend/app/agents/registry/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Registry manifest models and validation helpers."""
|
||||
|
||||
from app.agents.registry.indexes import RegistryIndexes, build_registry_indexes
|
||||
from app.agents.registry.loader import RegistryBundle, load_builtin_registry_bundle
|
||||
|
||||
__all__ = [
|
||||
"RegistryBundle",
|
||||
"RegistryIndexes",
|
||||
"build_registry_indexes",
|
||||
"load_builtin_registry_bundle",
|
||||
]
|
||||
114
backend/app/agents/registry/builtins.py
Normal file
114
backend/app/agents/registry/builtins.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from app.agents.prompts import SUB_COMMANDER_PROMPTS_BY_KEY
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
from app.agents.state import AgentRole
|
||||
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
||||
|
||||
|
||||
TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS: dict[str, tuple[str, ...]] = {
|
||||
AgentRole.MASTER.value: (),
|
||||
AgentRole.SCHEDULE_PLANNER.value: (
|
||||
"schedule_analysis",
|
||||
"schedule_planning",
|
||||
),
|
||||
AgentRole.EXECUTOR.value: (
|
||||
"executor_tasks",
|
||||
"executor_forum",
|
||||
),
|
||||
AgentRole.LIBRARIAN.value: (
|
||||
"librarian_retrieval",
|
||||
"librarian_graph",
|
||||
),
|
||||
AgentRole.ANALYST.value: (
|
||||
"analyst_progress",
|
||||
"analyst_insights",
|
||||
),
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_DISPLAY_NAMES: dict[str, str] = {
|
||||
AgentRole.MASTER.value: "Master",
|
||||
AgentRole.SCHEDULE_PLANNER.value: "Schedule Planner",
|
||||
AgentRole.EXECUTOR.value: "Executor",
|
||||
AgentRole.LIBRARIAN.value: "Librarian",
|
||||
AgentRole.ANALYST.value: "Analyst",
|
||||
}
|
||||
|
||||
TOP_LEVEL_AGENT_ROUTING_HINTS: dict[str, tuple[str, ...]] = {
|
||||
AgentRole.MASTER.value: (
|
||||
"Route user requests to the most suitable top-level runtime agent or answer directly.",
|
||||
),
|
||||
AgentRole.SCHEDULE_PLANNER.value: (
|
||||
"Handle planning-oriented requests using schedule analysis and schedule planning sub-commanders.",
|
||||
),
|
||||
AgentRole.EXECUTOR.value: (
|
||||
"Handle execution-oriented requests using task and forum sub-commanders.",
|
||||
),
|
||||
AgentRole.LIBRARIAN.value: (
|
||||
"Handle knowledge retrieval and graph-context requests using librarian sub-commanders.",
|
||||
),
|
||||
AgentRole.ANALYST.value: (
|
||||
"Handle reporting and insight requests using analyst sub-commanders.",
|
||||
),
|
||||
}
|
||||
|
||||
SUB_COMMANDER_PARENT_AGENT_IDS: dict[str, str] = {
|
||||
"schedule_analysis": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"schedule_planning": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"executor_tasks": AgentRole.EXECUTOR.value,
|
||||
"executor_forum": AgentRole.EXECUTOR.value,
|
||||
"librarian_retrieval": AgentRole.LIBRARIAN.value,
|
||||
"librarian_graph": AgentRole.LIBRARIAN.value,
|
||||
"analyst_progress": AgentRole.ANALYST.value,
|
||||
"analyst_insights": AgentRole.ANALYST.value,
|
||||
}
|
||||
|
||||
|
||||
BUILTIN_AGENT_MANIFESTS: tuple[AgentManifest, ...] = tuple(
|
||||
AgentManifest(
|
||||
agent_id=role.value,
|
||||
display_name=TOP_LEVEL_AGENT_DISPLAY_NAMES[role.value],
|
||||
role_value=role.value,
|
||||
system_prompt_key=role.value,
|
||||
routing_hints=list(TOP_LEVEL_AGENT_ROUTING_HINTS[role.value]),
|
||||
default_sub_commanders=list(TOP_LEVEL_AGENT_DEFAULT_SUB_COMMANDERS[role.value]),
|
||||
skill_context_key=role.value.replace("agent_", ""),
|
||||
)
|
||||
for role in AgentRole
|
||||
)
|
||||
|
||||
|
||||
_capability_tool_names = tuple(
|
||||
dict.fromkeys(
|
||||
tool.name
|
||||
for tools in SUB_COMMANDER_TOOLSETS.values()
|
||||
for tool in tools
|
||||
)
|
||||
)
|
||||
|
||||
BUILTIN_CAPABILITY_MANIFESTS: tuple[CapabilityManifest, ...] = tuple(
|
||||
CapabilityManifest(
|
||||
capability_id=tool_name,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
for tool_name in _capability_tool_names
|
||||
)
|
||||
|
||||
|
||||
BUILTIN_SUB_COMMANDER_MANIFESTS: tuple[SubCommanderManifest, ...] = tuple(
|
||||
SubCommanderManifest(
|
||||
sub_commander_id=sub_commander_id,
|
||||
parent_agent_id=SUB_COMMANDER_PARENT_AGENT_IDS[sub_commander_id],
|
||||
prompt_text=SUB_COMMANDER_PROMPTS_BY_KEY[sub_commander_id],
|
||||
capability_ids=list(
|
||||
dict.fromkeys(tool.name for tool in tools)
|
||||
),
|
||||
)
|
||||
for sub_commander_id, tools in SUB_COMMANDER_TOOLSETS.items()
|
||||
)
|
||||
|
||||
|
||||
BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS: tuple[SpecialistTemplateManifest, ...] = ()
|
||||
76
backend/app/agents/registry/indexes.py
Normal file
76
backend/app/agents/registry/indexes.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
|
||||
from app.agents.registry.loader import RegistryBundle
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryIndexes:
|
||||
agent_by_id: Mapping[str, AgentManifest]
|
||||
sub_commander_by_id: Mapping[str, SubCommanderManifest]
|
||||
capability_by_id: Mapping[str, CapabilityManifest]
|
||||
specialist_template_by_id: Mapping[str, SpecialistTemplateManifest]
|
||||
agent_prompt_key_by_id: Mapping[str, str]
|
||||
sub_commander_prompt_key_by_id: Mapping[str, str]
|
||||
skill_context_key_by_agent_id: Mapping[str, str]
|
||||
capability_id_by_tool_name: Mapping[str, str]
|
||||
capability_ids_by_sub_commander_id: Mapping[str, tuple[str, ...]]
|
||||
|
||||
|
||||
def summarize_registry_indexes(indexes: RegistryIndexes) -> dict[str, int]:
|
||||
return {
|
||||
"agent_count": len(indexes.agent_by_id),
|
||||
"sub_commander_count": len(indexes.sub_commander_by_id),
|
||||
"capability_count": len(indexes.capability_by_id),
|
||||
"specialist_template_count": len(indexes.specialist_template_by_id),
|
||||
}
|
||||
|
||||
|
||||
def build_registry_indexes(bundle: RegistryBundle) -> RegistryIndexes:
|
||||
agent_by_id = {agent.agent_id: agent for agent in bundle.agents}
|
||||
sub_commander_by_id = {
|
||||
sub_commander.sub_commander_id: sub_commander
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
capability_by_id = {
|
||||
capability.capability_id: capability for capability in bundle.capabilities
|
||||
}
|
||||
specialist_template_by_id = {
|
||||
template.template_id: template for template in bundle.specialist_templates
|
||||
}
|
||||
|
||||
return RegistryIndexes(
|
||||
agent_by_id=MappingProxyType(agent_by_id),
|
||||
sub_commander_by_id=MappingProxyType(sub_commander_by_id),
|
||||
capability_by_id=MappingProxyType(capability_by_id),
|
||||
specialist_template_by_id=MappingProxyType(specialist_template_by_id),
|
||||
agent_prompt_key_by_id=MappingProxyType({
|
||||
agent.agent_id: agent.system_prompt_key for agent in bundle.agents
|
||||
}),
|
||||
sub_commander_prompt_key_by_id=MappingProxyType({
|
||||
sub_commander.sub_commander_id: sub_commander.sub_commander_id
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}),
|
||||
skill_context_key_by_agent_id=MappingProxyType({
|
||||
agent.agent_id: agent.skill_context_key
|
||||
for agent in bundle.agents
|
||||
if agent.skill_context_key is not None
|
||||
}),
|
||||
capability_id_by_tool_name=MappingProxyType({
|
||||
capability.tool_name: capability.capability_id
|
||||
for capability in bundle.capabilities
|
||||
}),
|
||||
capability_ids_by_sub_commander_id=MappingProxyType({
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}),
|
||||
)
|
||||
33
backend/app/agents/registry/loader.py
Normal file
33
backend/app/agents/registry/loader.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.agents.registry.builtins import (
|
||||
BUILTIN_AGENT_MANIFESTS,
|
||||
BUILTIN_CAPABILITY_MANIFESTS,
|
||||
BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS,
|
||||
BUILTIN_SUB_COMMANDER_MANIFESTS,
|
||||
)
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryBundle:
|
||||
agents: tuple[AgentManifest, ...]
|
||||
sub_commanders: tuple[SubCommanderManifest, ...]
|
||||
capabilities: tuple[CapabilityManifest, ...]
|
||||
specialist_templates: tuple[SpecialistTemplateManifest, ...]
|
||||
|
||||
|
||||
def load_builtin_registry_bundle() -> RegistryBundle:
|
||||
return RegistryBundle(
|
||||
agents=BUILTIN_AGENT_MANIFESTS,
|
||||
sub_commanders=BUILTIN_SUB_COMMANDER_MANIFESTS,
|
||||
capabilities=BUILTIN_CAPABILITY_MANIFESTS,
|
||||
specialist_templates=BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS,
|
||||
)
|
||||
32
backend/app/agents/registry/models.py
Normal file
32
backend/app/agents/registry/models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentManifest(BaseModel):
|
||||
agent_id: str
|
||||
display_name: str
|
||||
role_value: str
|
||||
system_prompt_key: str
|
||||
routing_hints: list[str]
|
||||
default_sub_commanders: list[str]
|
||||
skill_context_key: str | None = None
|
||||
continuity_policy: str | None = None
|
||||
clarification_policy: str | None = None
|
||||
|
||||
|
||||
class SubCommanderManifest(BaseModel):
|
||||
sub_commander_id: str
|
||||
parent_agent_id: str
|
||||
prompt_text: str
|
||||
capability_ids: list[str]
|
||||
|
||||
|
||||
class CapabilityManifest(BaseModel):
|
||||
capability_id: str
|
||||
tool_name: str
|
||||
|
||||
|
||||
class SpecialistTemplateManifest(BaseModel):
|
||||
template_id: str
|
||||
display_name: str
|
||||
description: str
|
||||
allowed_capability_ids: list[str] | None = None
|
||||
55
backend/app/agents/registry/validator.py
Normal file
55
backend/app/agents/registry/validator.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
|
||||
|
||||
def _validate_unique_ids(values: Iterable[str], label: str) -> set[str]:
|
||||
unique_values: set[str] = set()
|
||||
for value in values:
|
||||
if value in unique_values:
|
||||
raise ValueError(f"duplicate {label}: {value}")
|
||||
unique_values.add(value)
|
||||
return unique_values
|
||||
|
||||
|
||||
def validate_registry_bundle(
|
||||
*,
|
||||
agents: list[AgentManifest],
|
||||
sub_commanders: list[SubCommanderManifest],
|
||||
capabilities: list[CapabilityManifest],
|
||||
specialist_templates: list[SpecialistTemplateManifest],
|
||||
) -> None:
|
||||
agent_ids = _validate_unique_ids((agent.agent_id for agent in agents), "agent id")
|
||||
_validate_unique_ids(
|
||||
(sub_commander.sub_commander_id for sub_commander in sub_commanders),
|
||||
"sub commander id",
|
||||
)
|
||||
capability_ids = _validate_unique_ids(
|
||||
(capability.capability_id for capability in capabilities),
|
||||
"capability id",
|
||||
)
|
||||
_validate_unique_ids(
|
||||
(specialist_template.template_id for specialist_template in specialist_templates),
|
||||
"template id",
|
||||
)
|
||||
|
||||
for sub_commander in sub_commanders:
|
||||
if sub_commander.parent_agent_id not in agent_ids:
|
||||
raise ValueError(f"unknown parent agent id: {sub_commander.parent_agent_id}")
|
||||
|
||||
for capability_id in sub_commander.capability_ids:
|
||||
if capability_id not in capability_ids:
|
||||
raise ValueError(f"unknown capability id: {capability_id}")
|
||||
|
||||
for specialist_template in specialist_templates:
|
||||
if specialist_template.allowed_capability_ids is None:
|
||||
continue
|
||||
|
||||
for capability_id in specialist_template.allowed_capability_ids:
|
||||
if capability_id not in capability_ids:
|
||||
raise ValueError(f"unknown capability id: {capability_id}")
|
||||
@@ -1,32 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict, Annotated
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
MASTER = "master"
|
||||
PLANNER = "planner"
|
||||
SCHEDULE_PLANNER = "schedule_planner"
|
||||
EXECUTOR = "executor"
|
||||
LIBRARIAN = "librarian"
|
||||
ANALYST = "analyst"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
name: str
|
||||
role: AgentRole
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
tool: str
|
||||
args: dict
|
||||
result: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
role: str # "user" | "assistant"
|
||||
@@ -35,60 +22,66 @@ class ConversationTurn:
|
||||
model: str | None = None
|
||||
|
||||
|
||||
def turn_to_message(turn: ConversationTurn) -> HumanMessage:
|
||||
return HumanMessage(content=turn.content)
|
||||
|
||||
|
||||
def message_to_turn(msg, agent: AgentRole | None = None) -> ConversationTurn:
|
||||
msg_type = getattr(msg, "type", None) or getattr(msg, "role", "assistant")
|
||||
return ConversationTurn(
|
||||
role="user" if msg_type in ("human", "user") else "assistant",
|
||||
content=msg.content,
|
||||
agent=agent,
|
||||
model=getattr(msg, "model", None),
|
||||
)
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list, None]
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
|
||||
# Agent routing
|
||||
current_agent: AgentRole
|
||||
current_agent: str | None
|
||||
next_step: str | None
|
||||
active_agents: list[AgentRole]
|
||||
current_sub_commander: str | None
|
||||
active_sub_commanders: list[str]
|
||||
sub_commander_trace: list[dict]
|
||||
sub_commander_trace: list[dict[str, Any]]
|
||||
agent_trace: list[str]
|
||||
|
||||
# Task tracking
|
||||
pending_tasks: list[dict]
|
||||
completed_tasks: list[dict]
|
||||
|
||||
# Tool usage
|
||||
tool_calls: list[ToolCall]
|
||||
pending_tasks: list[dict[str, Any]]
|
||||
completed_tasks: list[dict[str, Any]]
|
||||
tool_calls: list[dict[str, Any]]
|
||||
last_tool_result: str | None
|
||||
action_results: list[dict[str, Any]]
|
||||
created_entities: list[dict[str, Any]]
|
||||
tool_outcomes: list[dict[str, Any]]
|
||||
|
||||
# Knowledge context
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
tool_strategy_used: str | None
|
||||
tool_round_count: int
|
||||
max_tool_rounds: int
|
||||
retry_count: int
|
||||
max_retries: int
|
||||
iteration_count: int
|
||||
max_iterations: int
|
||||
routing_hops: int
|
||||
max_routing_hops: int
|
||||
terminated_due_to_loop_guard: bool
|
||||
retrieval_trace: list[dict[str, Any]]
|
||||
stop_reason: str | None
|
||||
|
||||
# Planning
|
||||
plan: str | None
|
||||
plan_steps: list[dict]
|
||||
|
||||
# Analysis
|
||||
analysis_report: str | None
|
||||
|
||||
# Output control
|
||||
final_response: str | None
|
||||
clarification_needed: bool
|
||||
clarification_question: str | None
|
||||
fallback_parse_error: str | None
|
||||
should_respond: bool
|
||||
|
||||
# Memory context (injected at start of each conversation)
|
||||
memory_context: str | None
|
||||
knowledge_context: str | None
|
||||
graph_context: str | None
|
||||
schedule_context_summary: str | None
|
||||
plan: str | None
|
||||
plan_steps: list[dict[str, Any]]
|
||||
analysis_report: str | None
|
||||
final_response: str | None
|
||||
|
||||
# User LLM config (for using user-configured models)
|
||||
user_llm_config: dict | None
|
||||
memory_context: str | None
|
||||
current_datetime_context: str | None
|
||||
current_datetime_reference: dict[str, str] | None
|
||||
|
||||
turn_context: dict[str, Any] | None
|
||||
routing_decision: dict[str, Any] | None
|
||||
continuity_state: dict[str, Any] | None
|
||||
pending_action: dict[str, Any] | None
|
||||
last_completed_action: dict[str, Any] | None
|
||||
clarification_context: dict[str, Any] | None
|
||||
|
||||
user_llm_config: dict[str, Any] | None
|
||||
provider_capabilities: dict[str, Any] | None
|
||||
|
||||
|
||||
def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
@@ -96,22 +89,52 @@ def initial_state(user_id: str, conversation_id: str) -> AgentState:
|
||||
messages=[],
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
current_agent=AgentRole.MASTER,
|
||||
current_agent=AgentRole.MASTER.value,
|
||||
next_step=None,
|
||||
active_agents=[AgentRole.MASTER],
|
||||
current_sub_commander=None,
|
||||
active_sub_commanders=[],
|
||||
sub_commander_trace=[],
|
||||
agent_trace=[AgentRole.MASTER.value],
|
||||
pending_tasks=[],
|
||||
completed_tasks=[],
|
||||
tool_calls=[],
|
||||
last_tool_result=None,
|
||||
action_results=[],
|
||||
created_entities=[],
|
||||
tool_outcomes=[],
|
||||
tool_strategy_used=None,
|
||||
tool_round_count=0,
|
||||
max_tool_rounds=2,
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
iteration_count=0,
|
||||
max_iterations=3,
|
||||
routing_hops=0,
|
||||
max_routing_hops=2,
|
||||
terminated_due_to_loop_guard=False,
|
||||
retrieval_trace=[],
|
||||
stop_reason=None,
|
||||
clarification_needed=False,
|
||||
clarification_question=None,
|
||||
fallback_parse_error=None,
|
||||
should_respond=True,
|
||||
knowledge_context=None,
|
||||
graph_context=None,
|
||||
schedule_context_summary=None,
|
||||
plan=None,
|
||||
plan_steps=[],
|
||||
analysis_report=None,
|
||||
final_response=None,
|
||||
should_respond=True,
|
||||
memory_context=None,
|
||||
current_datetime_context=None,
|
||||
current_datetime_reference=None,
|
||||
turn_context=None,
|
||||
routing_decision=None,
|
||||
continuity_state=None,
|
||||
pending_action=None,
|
||||
last_completed_action=None,
|
||||
clarification_context=None,
|
||||
user_llm_config=None,
|
||||
provider_capabilities=None,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from app.agents.tools.search import (
|
||||
search_knowledge, get_knowledge_graph_context,
|
||||
build_knowledge_graph, hybrid_search,
|
||||
build_knowledge_graph, hybrid_search, web_search,
|
||||
)
|
||||
from app.agents.tools.task import get_tasks, create_task, update_task_status
|
||||
from app.agents.tools.forum import get_forum_posts, create_forum_post, scan_forum_for_instructions
|
||||
from app.agents.tools.schedule import (
|
||||
get_schedule_day,
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
|
||||
TASK_TOOLS = [
|
||||
get_tasks,
|
||||
@@ -11,6 +19,19 @@ TASK_TOOLS = [
|
||||
update_task_status,
|
||||
]
|
||||
|
||||
SCHEDULE_READ_TOOLS = [
|
||||
get_schedule_day,
|
||||
get_tasks,
|
||||
resolve_time_expression,
|
||||
]
|
||||
|
||||
SCHEDULE_WRITE_TOOLS = [
|
||||
create_todo,
|
||||
create_schedule_task,
|
||||
create_reminder,
|
||||
create_goal,
|
||||
]
|
||||
|
||||
FORUM_TOOLS = [
|
||||
get_forum_posts,
|
||||
create_forum_post,
|
||||
@@ -20,6 +41,7 @@ FORUM_TOOLS = [
|
||||
KNOWLEDGE_RETRIEVAL_TOOLS = [
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
get_knowledge_graph_context,
|
||||
]
|
||||
|
||||
@@ -39,19 +61,22 @@ ANALYST_INSIGHT_TOOLS = [
|
||||
get_forum_posts,
|
||||
search_knowledge,
|
||||
hybrid_search,
|
||||
web_search,
|
||||
]
|
||||
|
||||
ALL_TOOLS = [
|
||||
*KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
build_knowledge_graph,
|
||||
*TASK_TOOLS,
|
||||
*SCHEDULE_READ_TOOLS,
|
||||
*SCHEDULE_WRITE_TOOLS,
|
||||
*FORUM_TOOLS,
|
||||
]
|
||||
|
||||
SUB_COMMANDER_TOOLSETS = {
|
||||
"planner_scope": [],
|
||||
"planner_steps": [],
|
||||
"executor_tasks": TASK_TOOLS,
|
||||
"schedule_analysis": SCHEDULE_READ_TOOLS,
|
||||
"schedule_planning": [*SCHEDULE_READ_TOOLS, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_tasks": [*TASK_TOOLS, resolve_time_expression, *SCHEDULE_WRITE_TOOLS],
|
||||
"executor_forum": FORUM_TOOLS,
|
||||
"librarian_retrieval": KNOWLEDGE_RETRIEVAL_TOOLS,
|
||||
"librarian_graph": KNOWLEDGE_GRAPH_TOOLS,
|
||||
|
||||
18
backend/app/agents/tools/async_bridge.py
Normal file
18
backend/app/agents/tools/async_bridge.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def run_async(coro: Any, timeout: int = 30):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return _executor.submit(asyncio.run, coro).result(timeout=timeout)
|
||||
|
||||
|
||||
__all__ = ["run_async"]
|
||||
@@ -4,17 +4,12 @@ from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(__import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
301
backend/app/agents/tools/schedule.py
Normal file
301
backend/app/agents/tools/schedule.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Agent 工具集 - 日程相关"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
def _parse_date(value: str | None) -> date:
|
||||
if not value:
|
||||
return date.today()
|
||||
return date.fromisoformat(value)
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def _parse_datetime_with_timezone(value: str, time_zone: str | None) -> datetime:
|
||||
"""Parse an ISO datetime and return a tz-naive datetime in the intended local time.
|
||||
|
||||
- If value includes an offset/Z, it will be converted to `time_zone` when provided.
|
||||
- If value is naive and `time_zone` is provided, it is interpreted in that zone.
|
||||
"""
|
||||
parsed = _parse_datetime(value)
|
||||
tz = (time_zone or "").strip()
|
||||
if parsed.tzinfo is None:
|
||||
if tz:
|
||||
parsed = parsed.replace(tzinfo=ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
if tz:
|
||||
parsed = parsed.astimezone(ZoneInfo(tz))
|
||||
return parsed.replace(tzinfo=None)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_schedule_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
if not resolved:
|
||||
return None
|
||||
if "T" in resolved:
|
||||
return resolved
|
||||
return f"{resolved}T09:00:00"
|
||||
|
||||
|
||||
def _format_summary(target_date: date, todos: list[DailyTodo], tasks: list[Task], reminders: list[Reminder], goals: list[Goal]) -> str:
|
||||
lines = [f"日期: {target_date.isoformat()}"]
|
||||
|
||||
if todos:
|
||||
lines.append("待办:")
|
||||
lines.extend(f"- {item.title} | 完成:{'是' if item.is_completed else '否'}" for item in todos)
|
||||
else:
|
||||
lines.append("待办: 无")
|
||||
|
||||
if tasks:
|
||||
lines.append("任务:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status} | 优先级:{item.priority.value if hasattr(item.priority, 'value') else item.priority} | 截止:{item.due_date.isoformat() if item.due_date else '无'}"
|
||||
for item in tasks
|
||||
)
|
||||
else:
|
||||
lines.append("任务: 无")
|
||||
|
||||
if reminders:
|
||||
lines.append("提醒:")
|
||||
lines.extend(f"- {item.title} | 时间:{item.reminder_at.isoformat()}" for item in reminders)
|
||||
else:
|
||||
lines.append("提醒: 无")
|
||||
|
||||
if goals:
|
||||
lines.append("目标:")
|
||||
lines.extend(
|
||||
f"- {item.title} | 状态:{item.status.value if hasattr(item.status, 'value') else item.status}"
|
||||
for item in goals
|
||||
)
|
||||
else:
|
||||
lines.append("目标: 无")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
def get_schedule_day(target_date: str | None = None) -> str:
|
||||
"""获取指定日期的 todo/task/reminder/goal 聚合信息。target_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(target_date)
|
||||
date_key = parsed_date.isoformat()
|
||||
start_dt = datetime.combine(parsed_date, datetime.min.time())
|
||||
end_dt = datetime.combine(parsed_date, datetime.max.time())
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
todos = (
|
||||
await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == uid, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
tasks = (
|
||||
await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == uid,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
reminders = (
|
||||
await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == uid,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
goals = (
|
||||
await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == uid, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
).scalars().all()
|
||||
return _format_summary(parsed_date, todos, tasks, reminders, goals)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as exc:
|
||||
return f"获取日程失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_todo(title: str, todo_date: str | None = None) -> str:
|
||||
"""创建指定日期的待办。todo_date 格式 YYYY-MM-DD,默认今天。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(todo_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
todo = DailyTodo(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
source=TodoSource.AI_CHAT,
|
||||
todo_date=parsed_date.isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
await db.refresh(todo)
|
||||
return f"TODO创建成功: [{todo.id[:8]}] {todo.title} @ {todo.todo_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建TODO失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_schedule_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: str = "medium",
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""创建任务。priority 支持 low/medium/high/urgent;due_date 使用 ISO datetime。兼容 content/date 别名。"""
|
||||
uid = get_current_user()
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_schedule_due_date(due_date, date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=TaskPriority(priority),
|
||||
due_date=_parse_datetime(resolved_due_date) if resolved_due_date else None,
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
due_label = task.due_date.isoformat() if task.due_date else "无截止时间"
|
||||
return f"任务创建成功: [{task.id[:8]}] {task.title} | 优先级:{task.priority.value} | 截止:{due_label}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建任务失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_reminder(
|
||||
title: str = "",
|
||||
reminder_at: str | None = None,
|
||||
note: str = "",
|
||||
description: str = "",
|
||||
datetime: str = "",
|
||||
at: str = "",
|
||||
remind_at: str = "",
|
||||
content: str = "",
|
||||
time_zone: str = "",
|
||||
timezone: str = "",
|
||||
time: str = "",
|
||||
) -> str:
|
||||
"""创建提醒。reminder_at 使用 ISO datetime。兼容 description/datetime/at/remind_at/time_zone 别名。"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_title = (title or content or "").strip()
|
||||
if not resolved_title:
|
||||
raise ValueError("title 不能为空")
|
||||
|
||||
resolved_at = ((reminder_at or datetime or at or remind_at or time or "").strip())
|
||||
if not resolved_at:
|
||||
raise ValueError("reminder_at 不能为空")
|
||||
|
||||
resolved_note = (note or description or "").strip()
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
tz = (time_zone or timezone or "").strip()
|
||||
reminder = Reminder(
|
||||
user_id=uid,
|
||||
title=resolved_title,
|
||||
note=resolved_note or None,
|
||||
reminder_at=_parse_datetime_with_timezone(resolved_at, tz),
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return f"提醒创建成功: [{reminder.id[:8]}] {reminder.title} @ {reminder.reminder_at.isoformat()}"
|
||||
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建提醒失败: {exc}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_goal(title: str, goal_date: str | None = None, note: str = "", status: str = "active") -> str:
|
||||
"""创建指定日期目标。goal_date 格式 YYYY-MM-DD,默认今天;status 支持 active/done/archived。"""
|
||||
uid = get_current_user()
|
||||
parsed_date = _parse_date(goal_date)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
goal = Goal(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
note=note or None,
|
||||
goal_date=parsed_date.isoformat(),
|
||||
status=GoalStatus(status),
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return f"目标创建成功: [{goal.id[:8]}] {goal.title} @ {goal.goal_date}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as exc:
|
||||
return f"创建目标失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_schedule_day",
|
||||
"create_todo",
|
||||
"create_schedule_task",
|
||||
"create_reminder",
|
||||
"create_goal",
|
||||
]
|
||||
@@ -6,22 +6,15 @@ Agent 工具集 - 知识库 & 图谱相关
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from app.database import async_session
|
||||
from app.agents.context import get_current_user
|
||||
import asyncio
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
"""在同步上下文中运行 async 代码"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor, lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -151,9 +144,56 @@ def hybrid_search(query: str, top_k: int = 5) -> str:
|
||||
return f"混合搜索失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str, top_k: int = 5) -> str:
|
||||
"""
|
||||
通过 SearxNG 搜索外部网页信息,返回标题、链接和摘要。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
top_k: 返回结果数量,默认 5 条
|
||||
|
||||
Returns:
|
||||
适合模型综合的网页结果文本
|
||||
"""
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
async def _search():
|
||||
service = WebSearchService()
|
||||
results = await service.search(query, limit=top_k)
|
||||
if not results:
|
||||
return "未找到相关网页结果。"
|
||||
|
||||
texts = []
|
||||
for index, result in enumerate(results, 1):
|
||||
source = f"\n来源: {result.source}" if result.source else ""
|
||||
published_at = f"\n时间: {result.published_at}" if result.published_at else ""
|
||||
snippet = result.snippet or "(无摘要)"
|
||||
texts.append(
|
||||
f"[{index}] {result.title}\n"
|
||||
f"链接: {result.url}{source}{published_at}\n"
|
||||
f"摘要: {snippet}"
|
||||
)
|
||||
return "\n\n---\n\n".join(texts)
|
||||
|
||||
try:
|
||||
return _run_async(_search(), timeout=30)
|
||||
except WebSearchConfigurationError as exc:
|
||||
return f"网页搜索不可用: {exc}"
|
||||
except WebSearchRequestError as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
except Exception as exc:
|
||||
return f"网页搜索失败: {exc}"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"search_knowledge",
|
||||
"get_knowledge_graph_context",
|
||||
"build_knowledge_graph",
|
||||
"hybrid_search",
|
||||
"web_search",
|
||||
]
|
||||
|
||||
@@ -1,22 +1,77 @@
|
||||
"""Agent 工具集 - 任务相关"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from app.database import async_session
|
||||
from app.models.task import Task
|
||||
from app.agents.context import get_current_user
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
|
||||
_executor = None
|
||||
from app.models.base import utc_now
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.context import get_current_user
|
||||
from app.agents.tools.async_bridge import run_async
|
||||
from app.database import async_session
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
|
||||
|
||||
def _run_async(coro, timeout: int = 30):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.run_in_executor(_executor or __import__("concurrent.futures").ThreadPoolExecutor(), lambda: asyncio.run(coro))
|
||||
return future.result(timeout=timeout)
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
return run_async(coro, timeout=timeout)
|
||||
|
||||
|
||||
def _normalize_title(title: str | None, content: str | None) -> str:
|
||||
resolved = (title or content or "").strip()
|
||||
if not resolved:
|
||||
raise ValueError("title 不能为空")
|
||||
return resolved
|
||||
|
||||
|
||||
def _normalize_due_date(due_date: str | None, date_value: str | None) -> str | None:
|
||||
resolved = (due_date or date_value or "").strip()
|
||||
return resolved or None
|
||||
|
||||
|
||||
def _parse_due_date(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" not in normalized:
|
||||
normalized = f"{normalized}T00:00:00"
|
||||
parsed = datetime.fromisoformat(normalized.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is not None:
|
||||
return parsed.astimezone(UTC).replace(tzinfo=None)
|
||||
return parsed
|
||||
|
||||
|
||||
def _normalize_priority(priority: int | str | None) -> TaskPriority:
|
||||
if priority is None or priority == "":
|
||||
return TaskPriority.MEDIUM
|
||||
if isinstance(priority, TaskPriority):
|
||||
return priority
|
||||
if isinstance(priority, int):
|
||||
return {
|
||||
1: TaskPriority.LOW,
|
||||
2: TaskPriority.MEDIUM,
|
||||
3: TaskPriority.HIGH,
|
||||
4: TaskPriority.URGENT,
|
||||
}.get(priority, TaskPriority.MEDIUM)
|
||||
normalized = str(priority).strip().lower()
|
||||
if not normalized:
|
||||
return TaskPriority.MEDIUM
|
||||
return TaskPriority(normalized)
|
||||
|
||||
|
||||
def _normalize_status(status: str) -> TaskStatus:
|
||||
normalized = status.strip().lower()
|
||||
return TaskStatus(normalized)
|
||||
|
||||
|
||||
def _format_status(value: TaskStatus | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
def _format_priority(value: TaskPriority | str) -> str:
|
||||
return value.value if hasattr(value, "value") else str(value)
|
||||
|
||||
|
||||
@tool
|
||||
@@ -25,7 +80,7 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
获取用户当前的任务列表。
|
||||
|
||||
Args:
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/blocked)
|
||||
status: 可选,筛选任务状态 (todo/in_progress/done/cancelled)
|
||||
limit: 返回数量,默认20
|
||||
|
||||
Returns:
|
||||
@@ -33,6 +88,9 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status) if status else None
|
||||
|
||||
async def _get():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
@@ -41,8 +99,8 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
.join(User, User.id == Task.user_id)
|
||||
.where(User.id == uid)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
if resolved_status:
|
||||
query = query.where(Task.status == resolved_status)
|
||||
query = query.order_by(Task.priority.desc(), Task.updated_at.desc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
tasks = result.scalars().all()
|
||||
@@ -52,48 +110,60 @@ def get_tasks(status: str | None = None, limit: int = 20) -> str:
|
||||
for t in tasks:
|
||||
lines.append(
|
||||
f"- [{t.id[:8]}] {t.title} | "
|
||||
f"状态:{t.status} | 优先级:{t.priority} | 截止:{t.due_date or '无'}"
|
||||
f"状态:{_format_status(t.status)} | 优先级:{_format_priority(t.priority)} | 截止:{t.due_date or '无'}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
try:
|
||||
return _run_async(_get())
|
||||
except Exception as e:
|
||||
return f"获取任务失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def create_task(title: str, description: str = "", priority: int = 2, due_date: str | None = None) -> str:
|
||||
def create_task(
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
priority: int | str = 2,
|
||||
due_date: str | None = None,
|
||||
content: str = "",
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建新任务。
|
||||
|
||||
Args:
|
||||
title: 任务标题(必填)
|
||||
title: 任务标题(必填,兼容 content 作为别名)
|
||||
description: 任务描述
|
||||
priority: 优先级 1-4,数字越大优先级越高,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD
|
||||
priority: 优先级,支持 1-4 或 low/medium/high/urgent,默认2
|
||||
due_date: 截止日期,格式 YYYY-MM-DD 或 ISO datetime
|
||||
content: title 的兼容别名
|
||||
date: due_date 的兼容别名
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_title = _normalize_title(title, content)
|
||||
resolved_due_date = _normalize_due_date(due_date, date)
|
||||
resolved_priority = _normalize_priority(priority)
|
||||
|
||||
async def _create():
|
||||
async with async_session() as db:
|
||||
task = Task(
|
||||
user_id=uid,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
due_date=due_date,
|
||||
status="todo",
|
||||
title=resolved_title,
|
||||
description=description or content or None,
|
||||
priority=resolved_priority,
|
||||
due_date=_parse_due_date(resolved_due_date),
|
||||
status=TaskStatus.TODO,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return f"任务创建成功: [{task.id[:8]}] {title}"
|
||||
return f"任务创建成功: [{task.id[:8]}] {resolved_title}"
|
||||
|
||||
try:
|
||||
return _run_async(_create())
|
||||
except Exception as e:
|
||||
return f"创建任务失败: {str(e)}"
|
||||
@@ -106,13 +176,16 @@ def update_task_status(task_id: str, status: str) -> str:
|
||||
|
||||
Args:
|
||||
task_id: 任务ID(完整ID或前8位)
|
||||
status: 新状态 (todo/in_progress/done/blocked)
|
||||
status: 新状态 (todo/in_progress/done/cancelled)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
uid = get_current_user()
|
||||
|
||||
try:
|
||||
resolved_status = _normalize_status(status)
|
||||
|
||||
async def _update():
|
||||
async with async_session() as db:
|
||||
from app.models.user import User
|
||||
@@ -129,11 +202,11 @@ def update_task_status(task_id: str, status: str) -> str:
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return f"任务不存在: {task_id}"
|
||||
task.status = status
|
||||
task.status = resolved_status
|
||||
task.completed_at = utc_now() if resolved_status == TaskStatus.DONE else None
|
||||
await db.commit()
|
||||
return f"任务状态已更新: {task.title} -> {status}"
|
||||
return f"任务状态已更新: {task.title} -> {resolved_status.value}"
|
||||
|
||||
try:
|
||||
return _run_async(_update())
|
||||
except Exception as e:
|
||||
return f"更新任务失败: {str(e)}"
|
||||
|
||||
273
backend/app/agents/tools/time_reasoning.py
Normal file
273
backend/app/agents/tools/time_reasoning.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
_WEEKDAY_MAP = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6, "天": 6}
|
||||
_DEFAULT_HOUR_BY_PERIOD = {
|
||||
"morning": 9,
|
||||
"noon": 12,
|
||||
"afternoon": 15,
|
||||
"evening": 20,
|
||||
}
|
||||
_TIME_KEYWORDS = ("今天", "明天", "后天", "本周", "这周", "下周", "周", "星期", "月", "日", "早上", "上午", "中午", "下午", "晚上", "今晚", "点", ":", ":")
|
||||
|
||||
|
||||
def _parse_datetime(value: str) -> datetime:
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized)
|
||||
|
||||
|
||||
def extract_reference_datetime(current_datetime_context: str | None) -> datetime:
|
||||
context = (current_datetime_context or "").strip()
|
||||
if context:
|
||||
for pattern in (r"current_time_utc:\s*(\S+)", r"CURRENT_TIME:\s*(\S+)", r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2}))"):
|
||||
match = re.search(pattern, context)
|
||||
if match:
|
||||
return _parse_datetime(match.group(1))
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def _normalize_local_iso(value: datetime) -> str:
|
||||
return value.replace(tzinfo=None).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _normalize_datetime_iso(value: datetime) -> str:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat(timespec="seconds")
|
||||
return _normalize_local_iso(value)
|
||||
|
||||
|
||||
def _normalize_date_iso(value: date) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _is_iso_datetime(value: str) -> bool:
|
||||
try:
|
||||
parsed = _parse_datetime(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return isinstance(parsed, datetime)
|
||||
|
||||
|
||||
def _is_iso_date(value: str) -> bool:
|
||||
try:
|
||||
date.fromisoformat(value.strip())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _has_explicit_time(text: str) -> bool:
|
||||
return bool(
|
||||
re.search(r"\d{1,2}[::]\d{2}", text)
|
||||
or re.search(r"\d{1,2}点(?:半|(?:\d{1,2})分?)?", text)
|
||||
or any(keyword in text for keyword in ("早上", "上午", "中午", "下午", "晚上", "今晚"))
|
||||
)
|
||||
|
||||
|
||||
def _detect_period(text: str) -> str | None:
|
||||
if any(keyword in text for keyword in ("晚上", "今晚")):
|
||||
return "evening"
|
||||
if "下午" in text:
|
||||
return "afternoon"
|
||||
if "中午" in text:
|
||||
return "noon"
|
||||
if any(keyword in text for keyword in ("早上", "上午", "早晨", "清晨")):
|
||||
return "morning"
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_time(text: str) -> tuple[time, bool, str | None]:
|
||||
period = _detect_period(text)
|
||||
colon_match = re.search(r"(\d{1,2})[::](\d{2})", text)
|
||||
if colon_match:
|
||||
hour = int(colon_match.group(1))
|
||||
minute = int(colon_match.group(2))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
half_match = re.search(r"(\d{1,2})点半", text)
|
||||
if half_match:
|
||||
hour = int(half_match.group(1))
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=30), False, period
|
||||
|
||||
dot_match = re.search(r"(\d{1,2})点(?:(\d{1,2})分?)?", text)
|
||||
if dot_match:
|
||||
hour = int(dot_match.group(1))
|
||||
minute = int(dot_match.group(2) or 0)
|
||||
if period in {"afternoon", "evening"} and hour < 12:
|
||||
hour += 12
|
||||
if period == "noon" and hour < 11:
|
||||
hour += 12
|
||||
return time(hour=hour, minute=minute), False, period
|
||||
|
||||
if period:
|
||||
return time(hour=_DEFAULT_HOUR_BY_PERIOD[period], minute=0), True, period
|
||||
return time(hour=9, minute=0), True, None
|
||||
|
||||
|
||||
def _resolve_date(text: str, reference: datetime) -> tuple[date, str]:
|
||||
stripped = text.strip()
|
||||
if _is_iso_date(stripped):
|
||||
return date.fromisoformat(stripped), "explicit_date"
|
||||
|
||||
month_day_match = re.search(r"(\d{1,2})月(\d{1,2})日", stripped)
|
||||
if month_day_match:
|
||||
month = int(month_day_match.group(1))
|
||||
day = int(month_day_match.group(2))
|
||||
candidate = date(reference.year, month, day)
|
||||
if candidate < reference.date() - timedelta(days=1):
|
||||
candidate = date(reference.year + 1, month, day)
|
||||
return candidate, "explicit_month_day"
|
||||
|
||||
if "后天" in stripped:
|
||||
return reference.date() + timedelta(days=2), "relative_day"
|
||||
if "明天" in stripped:
|
||||
return reference.date() + timedelta(days=1), "relative_day"
|
||||
if "今天" in stripped:
|
||||
return reference.date(), "relative_day"
|
||||
|
||||
weekday_match = re.search(r"((?:本周|这周|下周)?)(?:周|星期)([一二三四五六日天])", stripped)
|
||||
if weekday_match:
|
||||
prefix = weekday_match.group(1)
|
||||
weekday = _WEEKDAY_MAP[weekday_match.group(2)]
|
||||
current_weekday = reference.date().weekday()
|
||||
delta = weekday - current_weekday
|
||||
if prefix == "下周":
|
||||
delta += 7 if delta <= 0 else 7
|
||||
elif prefix in {"本周", "这周"}:
|
||||
if delta < 0:
|
||||
delta += 7
|
||||
elif delta < 0:
|
||||
delta += 7
|
||||
return reference.date() + timedelta(days=delta), "relative_weekday"
|
||||
|
||||
return reference.date(), "reference_day"
|
||||
|
||||
|
||||
def resolve_time_expression_data(
|
||||
expression: str,
|
||||
*,
|
||||
current_datetime_context: str | None = None,
|
||||
prefer: str = "datetime",
|
||||
) -> dict:
|
||||
text = (expression or "").strip()
|
||||
if not text:
|
||||
raise ValueError("expression 不能为空")
|
||||
|
||||
reference = extract_reference_datetime(current_datetime_context)
|
||||
|
||||
if _is_iso_datetime(text):
|
||||
parsed = _parse_datetime(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "datetime",
|
||||
"resolved_date": _normalize_date_iso(parsed.date()),
|
||||
"resolved_datetime": _normalize_datetime_iso(parsed),
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_datetime",
|
||||
}
|
||||
|
||||
if _is_iso_date(text):
|
||||
parsed_date = date.fromisoformat(text)
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": "date",
|
||||
"resolved_date": _normalize_date_iso(parsed_date),
|
||||
"resolved_datetime": None,
|
||||
"assumed_time": False,
|
||||
"reason": "explicit_date",
|
||||
}
|
||||
|
||||
resolved_date, date_reason = _resolve_date(text, reference)
|
||||
resolved_time, assumed_time, period = _resolve_time(text)
|
||||
has_explicit_time = _has_explicit_time(text)
|
||||
grain = "date" if prefer == "date" and not has_explicit_time else "datetime"
|
||||
resolved_dt = datetime.combine(resolved_date, resolved_time)
|
||||
note = date_reason
|
||||
if period:
|
||||
note = f"{note}:{period}"
|
||||
if assumed_time:
|
||||
note = f"{note}:assumed_time"
|
||||
return {
|
||||
"expression": text,
|
||||
"reference_time": reference.isoformat(),
|
||||
"grain": grain,
|
||||
"resolved_date": _normalize_date_iso(resolved_date),
|
||||
"resolved_datetime": None if grain == "date" else _normalize_local_iso(resolved_dt),
|
||||
"assumed_time": assumed_time,
|
||||
"reason": note,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def resolve_time_expression(
|
||||
expression: str,
|
||||
current_datetime_context: str = "",
|
||||
prefer: str = "datetime",
|
||||
) -> str:
|
||||
"""解析中文自然语言时间表达,基于当前参考时间返回明确的日期或 datetime。prefer 支持 datetime/date。"""
|
||||
try:
|
||||
payload = resolve_time_expression_data(
|
||||
expression,
|
||||
current_datetime_context=current_datetime_context or None,
|
||||
prefer=prefer,
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
except Exception as exc:
|
||||
return json.dumps(
|
||||
{
|
||||
"expression": expression,
|
||||
"error": str(exc),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def normalize_tool_time_arguments(tool_name: str, args: dict, current_datetime_context: str | None) -> dict:
|
||||
normalized = dict(args)
|
||||
|
||||
if tool_name == "create_reminder":
|
||||
raw_value = next((normalized.get(key) for key in ("reminder_at", "datetime", "at", "remind_at", "time") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="datetime")
|
||||
normalized["reminder_at"] = payload["resolved_datetime"]
|
||||
raw_date = normalized.get("date")
|
||||
if isinstance(raw_date, str) and raw_date.strip() and not _is_iso_date(raw_date):
|
||||
payload = resolve_time_expression_data(raw_date, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized["date"] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_schedule_task", "create_task"}:
|
||||
raw_value = next((normalized.get(key) for key in ("due_date", "date") if isinstance(normalized.get(key), str) and normalized.get(key).strip()), None)
|
||||
if raw_value and not _is_iso_datetime(raw_value) and not _is_iso_date(raw_value):
|
||||
prefer = "datetime" if tool_name == "create_schedule_task" or _has_explicit_time(raw_value) else "date"
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer=prefer)
|
||||
normalized["due_date"] = payload["resolved_datetime"] or payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
if tool_name in {"create_todo", "create_goal", "get_schedule_day"}:
|
||||
field_name = {
|
||||
"create_todo": "todo_date",
|
||||
"create_goal": "goal_date",
|
||||
"get_schedule_day": "target_date",
|
||||
}[tool_name]
|
||||
raw_value = normalized.get(field_name)
|
||||
if isinstance(raw_value, str) and raw_value.strip() and not _is_iso_date(raw_value):
|
||||
payload = resolve_time_expression_data(raw_value, current_datetime_context=current_datetime_context, prefer="date")
|
||||
normalized[field_name] = payload["resolved_date"]
|
||||
return normalized
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["resolve_time_expression", "resolve_time_expression_data", "normalize_tool_time_arguments", "extract_reference_datetime"]
|
||||
@@ -15,7 +15,9 @@ def _resolve_path(value: str) -> str:
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore")
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=str(ENV_FILE), env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
# === 应用基础 ===
|
||||
APP_NAME: str = "Jarvis"
|
||||
@@ -75,6 +77,8 @@ class Settings(BaseSettings):
|
||||
|
||||
# === 向量化 ===
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
EMBEDDING_BASE_URL: str = "https://api.openai.com/v1"
|
||||
EMBEDDING_API_KEY: str = ""
|
||||
CHUNK_SIZE: int = 500
|
||||
CHUNK_OVERLAP: int = 50
|
||||
|
||||
@@ -86,6 +90,17 @@ class Settings(BaseSettings):
|
||||
# === NAS 部署 ===
|
||||
NAS_DATA_ROOT: str = "/data/jarvis"
|
||||
|
||||
# === Web Search / SearxNG ===
|
||||
WEB_SEARCH_ENABLED: bool = False
|
||||
WEB_SEARCH_PROVIDER: str = "searxng"
|
||||
SEARXNG_BASE_URL: str = ""
|
||||
SEARXNG_AUTH_TYPE: Literal["none", "bearer", "basic"] = "none"
|
||||
SEARXNG_AUTH_TOKEN: str = ""
|
||||
SEARXNG_BASIC_USER: str = ""
|
||||
SEARXNG_BASIC_PASSWORD: str = ""
|
||||
WEB_SEARCH_DEFAULT_LIMIT: int = 5
|
||||
WEB_SEARCH_TIMEOUT_SECONDS: int = 10
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings.DATABASE_URL = settings.DATABASE_URL.replace("./data", _resolve_path("./data"), 1)
|
||||
|
||||
@@ -40,6 +40,8 @@ async def init_db():
|
||||
await ensure_document_columns(conn)
|
||||
await ensure_user_columns(conn)
|
||||
await ensure_forum_columns(conn)
|
||||
await ensure_agent_columns(conn)
|
||||
await ensure_skill_columns(conn)
|
||||
|
||||
|
||||
async def ensure_log_columns(conn):
|
||||
@@ -139,6 +141,55 @@ async def ensure_forum_columns(conn):
|
||||
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_forum_posts_board ON forum_posts (board)"))
|
||||
|
||||
|
||||
async def ensure_agent_columns(conn):
|
||||
rows = await _get_table_info(conn, 'agents')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'selected_skill_ids': "ALTER TABLE agents ADD COLUMN selected_skill_ids JSON DEFAULT '[]' NOT NULL",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
|
||||
async def ensure_skill_columns(conn):
|
||||
rows = await _get_table_info(conn, 'skills')
|
||||
if not rows:
|
||||
return
|
||||
|
||||
columns = {row[1] for row in rows}
|
||||
required_columns = {
|
||||
'required_context': "ALTER TABLE skills ADD COLUMN required_context JSON DEFAULT '[]' NOT NULL",
|
||||
'output_format': "ALTER TABLE skills ADD COLUMN output_format TEXT",
|
||||
'is_builtin': "ALTER TABLE skills ADD COLUMN is_builtin BOOLEAN DEFAULT 0 NOT NULL",
|
||||
'team_id': "ALTER TABLE skills ADD COLUMN team_id VARCHAR(36)",
|
||||
}
|
||||
for column, ddl in required_columns.items():
|
||||
if column not in columns:
|
||||
await conn.execute(text(ddl))
|
||||
|
||||
await conn.execute(text("UPDATE skills SET agent_type = 'schedule_planner' WHERE agent_type = 'planner'"))
|
||||
builtin_names = [
|
||||
'今日重点拆解',
|
||||
'周计划编排',
|
||||
'时间冲突分析',
|
||||
'任务执行 SOP',
|
||||
'外部交互推进',
|
||||
'知识检索摘要',
|
||||
'图谱沉淀策略',
|
||||
'风险识别模板',
|
||||
'趋势洞察模板',
|
||||
]
|
||||
for name in builtin_names:
|
||||
await conn.execute(
|
||||
text("UPDATE skills SET is_builtin = 1 WHERE name = :name"),
|
||||
{'name': name},
|
||||
)
|
||||
|
||||
|
||||
async def _backfill_usernames(conn):
|
||||
result = await conn.execute(text("SELECT id, email, username FROM users ORDER BY created_at, id"))
|
||||
users = result.fetchall()
|
||||
|
||||
@@ -14,6 +14,9 @@ from app.routers import (
|
||||
graph_router,
|
||||
agent_router,
|
||||
todo_router,
|
||||
reminder_router,
|
||||
goal_router,
|
||||
schedule_center_router,
|
||||
settings_router,
|
||||
folder_router,
|
||||
skill_router,
|
||||
@@ -23,7 +26,7 @@ from app.routers import (
|
||||
)
|
||||
from app.routers.scheduler import router as scheduler_router
|
||||
from app.services.scheduler_service import start_scheduler, stop_scheduler, get_scheduler_status
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user
|
||||
from app.services.admin_bootstrap_service import ensure_admin_user, ensure_builtin_skills
|
||||
from app.config import settings
|
||||
from app.logging_utils import (
|
||||
setup_logging,
|
||||
@@ -53,6 +56,7 @@ async def run_startup() -> None:
|
||||
await init_db()
|
||||
async with async_session() as session:
|
||||
await ensure_admin_user(session, settings)
|
||||
await ensure_builtin_skills(session)
|
||||
await persist_system_log(
|
||||
message="application_started",
|
||||
source="app",
|
||||
@@ -103,6 +107,9 @@ app.include_router(forum_router)
|
||||
app.include_router(graph_router)
|
||||
app.include_router(agent_router)
|
||||
app.include_router(todo_router)
|
||||
app.include_router(reminder_router)
|
||||
app.include_router(goal_router)
|
||||
app.include_router(schedule_center_router)
|
||||
app.include_router(settings_router)
|
||||
app.include_router(folder_router)
|
||||
app.include_router(skill_router)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from app.models.base import Base
|
||||
from app.models.user import User
|
||||
from app.models.folder import Folder
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.task import Task, TaskHistory
|
||||
from app.models.forum import ForumPost, ForumReply
|
||||
@@ -17,11 +18,14 @@ from app.models.brain import (
|
||||
brain_memory_sources,
|
||||
)
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.reminder import Reminder, ReminderStatus
|
||||
from app.models.goal import Goal, GoalStatus
|
||||
from app.models.log import Log, LogType, LogLevel
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"Folder",
|
||||
"Document",
|
||||
"DocumentChunk",
|
||||
"Task",
|
||||
@@ -45,6 +49,10 @@ __all__ = [
|
||||
"brain_memory_sources",
|
||||
"DailyTodo",
|
||||
"TodoSource",
|
||||
"Reminder",
|
||||
"ReminderStatus",
|
||||
"Goal",
|
||||
"GoalStatus",
|
||||
"Log",
|
||||
"LogType",
|
||||
"LogLevel",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -7,9 +7,10 @@ class Agent(BaseModel):
|
||||
__tablename__ = "agents"
|
||||
|
||||
name = Column(String(100), nullable=False)
|
||||
role = Column(String(100), nullable=False) # master, planner, executor, librarian, analyst
|
||||
role = Column(String(100), nullable=False) # master, schedule_planner, executor, librarian, analyst
|
||||
description = Column(Text, nullable=True)
|
||||
system_prompt = Column(Text, nullable=False)
|
||||
selected_skill_ids = Column(JSON, default=list, nullable=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ class Conversation(BaseModel):
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(500), nullable=True)
|
||||
message_count = Column(Integer, default=0)
|
||||
agent_state = Column(JSON, nullable=True)
|
||||
|
||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
21
backend/app/models/goal.py
Normal file
21
backend/app/models/goal.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class GoalStatus(str, PyEnum):
|
||||
ACTIVE = "active"
|
||||
DONE = "done"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class Goal(BaseModel):
|
||||
__tablename__ = "goals"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
goal_date = Column(String(10), nullable=False, index=True)
|
||||
status = Column(Enum(GoalStatus), default=GoalStatus.ACTIVE, nullable=False)
|
||||
21
backend/app/models/reminder.py
Normal file
21
backend/app/models/reminder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, String, Text
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class ReminderStatus(str, PyEnum):
|
||||
PENDING = "pending"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class Reminder(BaseModel):
|
||||
__tablename__ = "reminders"
|
||||
|
||||
user_id = Column(String(36), ForeignKey("users.id"), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
note = Column(Text, nullable=True)
|
||||
reminder_at = Column(DateTime, nullable=False, index=True)
|
||||
status = Column(Enum(ReminderStatus), default=ReminderStatus.PENDING, nullable=False)
|
||||
is_dismissed = Column(Boolean, default=False, nullable=False)
|
||||
@@ -9,11 +9,12 @@ class Skill(BaseModel):
|
||||
name = Column(String(100), nullable=False, unique=True, index=True)
|
||||
description = Column(Text, nullable=True) # 供 LLM 理解用途
|
||||
instructions = Column(Text, nullable=False) # Agent 执行时的指令模板
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/planner/executor/librarian/analyst
|
||||
agent_type = Column(String(50), nullable=False, index=True) # master/schedule_planner/executor/librarian/analyst
|
||||
tools = Column(JSON, default=list) # 引用的工具名称列表
|
||||
required_context = Column(JSON, default=list) # 需要的前置数据
|
||||
output_format = Column(Text, nullable=True) # 输出格式要求
|
||||
visibility = Column(String(20), default="private") # private/team/market
|
||||
is_builtin = Column(Boolean, default=False, nullable=False)
|
||||
team_id = Column(String(36), ForeignKey("users.id"), nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
owner_id = Column(String(36), ForeignKey("users.id"), nullable=False)
|
||||
|
||||
@@ -6,6 +6,9 @@ from app.routers.forum import router as forum_router
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.settings import router as settings_router
|
||||
from app.routers.folder import router as folder_router
|
||||
from app.routers.skill import router as skill_router
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.agent import AgentCreate, AgentOut, AgentStats, AgentConfigUpdate, AgentConfigOut
|
||||
@@ -13,9 +14,9 @@ _agent_call_counts: dict[str, int] = {}
|
||||
_agent_current_tasks: dict[str, str | None] = {}
|
||||
_agent_statuses: dict[str, str] = {}
|
||||
|
||||
DEFAULT_AGENT_ROLES = ["master", "planner", "executor", "librarian", "analyst"]
|
||||
DEFAULT_AGENT_ROLES = ["master", "schedule_planner", "executor", "librarian", "analyst"]
|
||||
SUB_COMMANDERS_BY_ROLE = {
|
||||
"planner": ["planner_scope", "planner_steps"],
|
||||
"schedule_planner": ["schedule_analysis", "schedule_planning"],
|
||||
"executor": ["executor_tasks", "executor_forum"],
|
||||
"librarian": ["librarian_retrieval", "librarian_graph"],
|
||||
"analyst": ["analyst_progress", "analyst_insights"],
|
||||
@@ -88,10 +89,10 @@ async def get_agent_config(
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if not agent:
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
from app.agents.prompts import MASTER_SYSTEM_PROMPT, SCHEDULE_PLANNER_SYSTEM_PROMPT, EXECUTOR_SYSTEM_PROMPT, LIBRARIAN_SYSTEM_PROMPT, ANALYST_SYSTEM_PROMPT
|
||||
defaults = {
|
||||
"master": ("JARVIS", "主控制核心", MASTER_SYSTEM_PROMPT),
|
||||
"planner": ("PLANNER", "规划专家", PLANNER_SYSTEM_PROMPT),
|
||||
"schedule_planner": ("SCHEDULE PLANNER", "日程规划师", SCHEDULE_PLANNER_SYSTEM_PROMPT),
|
||||
"executor": ("EXECUTOR", "执行专家", EXECUTOR_SYSTEM_PROMPT),
|
||||
"librarian": ("LIBRARIAN", "知识管理员", LIBRARIAN_SYSTEM_PROMPT),
|
||||
"analyst": ("ANALYST", "数据分析师", ANALYST_SYSTEM_PROMPT),
|
||||
@@ -107,6 +108,7 @@ async def get_agent_config(
|
||||
system_prompt=prompt,
|
||||
enabled=True,
|
||||
is_active=True,
|
||||
selected_skill_ids=[],
|
||||
)
|
||||
return AgentConfigOut(
|
||||
id=agent.role,
|
||||
@@ -116,6 +118,7 @@ async def get_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -141,6 +144,19 @@ async def update_agent_config(
|
||||
if data.enabled is not None:
|
||||
agent.is_active = data.enabled
|
||||
_agent_statuses[agent_id] = "disabled" if not data.enabled else "idle"
|
||||
if data.selected_skill_ids is not None:
|
||||
if data.selected_skill_ids:
|
||||
result = await db.execute(
|
||||
select(Skill.id).where(
|
||||
Skill.id.in_(data.selected_skill_ids),
|
||||
Skill.owner_id == current_user.id,
|
||||
)
|
||||
)
|
||||
allowed_skill_ids = set(result.scalars().all())
|
||||
invalid_skill_ids = [skill_id for skill_id in data.selected_skill_ids if skill_id not in allowed_skill_ids]
|
||||
if invalid_skill_ids:
|
||||
raise HTTPException(status_code=400, detail="存在无效的技能绑定")
|
||||
agent.selected_skill_ids = data.selected_skill_ids
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
@@ -152,6 +168,7 @@ async def update_agent_config(
|
||||
system_prompt=agent.system_prompt,
|
||||
enabled=agent.is_active,
|
||||
is_active=agent.is_active,
|
||||
selected_skill_ids=agent.selected_skill_ids or [],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import UserCreate, UserOut, Token
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import verify_password, get_password_hash, create_access_token, decode_token
|
||||
from app.config import settings
|
||||
|
||||
@@ -58,6 +59,7 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或邮箱已被注册")
|
||||
await db.refresh(user)
|
||||
await ensure_builtin_skills(db)
|
||||
return user
|
||||
|
||||
|
||||
@@ -97,10 +99,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSessi
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名、邮箱或密码错误")
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用")
|
||||
await ensure_builtin_skills(db)
|
||||
access_token = create_access_token(data={"sub": user.id})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
async def get_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
|
||||
await ensure_builtin_skills(db)
|
||||
return current_user
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
@@ -92,6 +93,7 @@ async def chat(
|
||||
):
|
||||
"""简单版对话(非流式)"""
|
||||
agent_svc = AgentService(db)
|
||||
try:
|
||||
conv_id, msg_id, content, model_name = await agent_svc.chat_simple(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
@@ -99,6 +101,8 @@ async def chat(
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
# 更新对话消息计数
|
||||
result = await db.execute(select(Conversation).where(Conversation.id == conv_id))
|
||||
@@ -126,6 +130,11 @@ async def chat_stream(
|
||||
agent_svc = AgentService(db)
|
||||
|
||||
async def stream_generator():
|
||||
stream = None
|
||||
msg_id = None
|
||||
should_emit_done = False
|
||||
try:
|
||||
try:
|
||||
conv_id, msg_id, stream = await agent_svc.chat(
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
@@ -133,6 +142,9 @@ async def chat_stream(
|
||||
file_ids=data.file_ids,
|
||||
model_name=data.model_name,
|
||||
)
|
||||
except ValueError as exc:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(exc)}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
yield f"event: metadata\ndata: {json.dumps({'conversation_id': conv_id, 'message_id': msg_id})}\n\n"
|
||||
|
||||
@@ -148,8 +160,12 @@ async def chat_stream(
|
||||
yield f"event: progress\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
should_emit_done = msg_id is not None
|
||||
if should_emit_done:
|
||||
yield f"event: done\ndata: {json.dumps({'message_id': msg_id})}\n\n"
|
||||
finally:
|
||||
if stream is not None:
|
||||
await stream.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
|
||||
92
backend/app/routers/goal.py
Normal file
92
backend/app/routers/goal.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.goal import GoalCreate, GoalListOut, GoalOut, GoalUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/goals", tags=["目标"])
|
||||
|
||||
|
||||
@router.get("", response_model=GoalListOut)
|
||||
async def list_goals(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str).isoformat()
|
||||
query = (
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id)
|
||||
.where(Goal.goal_date == target_date)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return GoalListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=GoalOut, status_code=201)
|
||||
async def create_goal(
|
||||
data: GoalCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
goal = Goal(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
goal_date=data.goal_date.isoformat(),
|
||||
status=data.status,
|
||||
)
|
||||
db.add(goal)
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.patch("/{goal_id}", response_model=GoalOut)
|
||||
async def update_goal(
|
||||
goal_id: str,
|
||||
data: GoalUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
payload = data.model_dump(exclude_none=True)
|
||||
if "goal_date" in payload:
|
||||
payload["goal_date"] = payload["goal_date"].isoformat()
|
||||
|
||||
for field, value in payload.items():
|
||||
setattr(goal, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.delete("/{goal_id}", status_code=204)
|
||||
async def delete_goal(
|
||||
goal_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Goal).where(Goal.id == goal_id, Goal.user_id == current_user.id)
|
||||
)
|
||||
goal = result.scalar_one_or_none()
|
||||
if not goal:
|
||||
raise HTTPException(status_code=404, detail="目标不存在")
|
||||
|
||||
await db.delete(goal)
|
||||
await db.commit()
|
||||
90
backend/app/routers/reminder.py
Normal file
90
backend/app/routers/reminder.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.reminder import ReminderCreate, ReminderListOut, ReminderOut, ReminderUpdate
|
||||
|
||||
router = APIRouter(prefix="/api/reminders", tags=["提醒"])
|
||||
|
||||
|
||||
@router.get("", response_model=ReminderListOut)
|
||||
async def list_reminders(
|
||||
date_str: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date.fromisoformat(date_str)
|
||||
start = datetime.combine(target_date, datetime.min.time())
|
||||
end = datetime.combine(target_date, datetime.max.time())
|
||||
query = (
|
||||
select(Reminder)
|
||||
.where(Reminder.user_id == current_user.id)
|
||||
.where(Reminder.reminder_at >= start)
|
||||
.where(Reminder.reminder_at <= end)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)
|
||||
items = (await db.execute(query)).scalars().all()
|
||||
return ReminderListOut(items=items)
|
||||
|
||||
|
||||
@router.post("", response_model=ReminderOut, status_code=201)
|
||||
async def create_reminder(
|
||||
data: ReminderCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
reminder = Reminder(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
note=data.note,
|
||||
reminder_at=data.reminder_at,
|
||||
)
|
||||
db.add(reminder)
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.patch("/{reminder_id}", response_model=ReminderOut)
|
||||
async def update_reminder(
|
||||
reminder_id: str,
|
||||
data: ReminderUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
for field, value in data.model_dump(exclude_none=True).items():
|
||||
setattr(reminder, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(reminder)
|
||||
return reminder
|
||||
|
||||
|
||||
@router.delete("/{reminder_id}", status_code=204)
|
||||
async def delete_reminder(
|
||||
reminder_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Reminder).where(Reminder.id == reminder_id, Reminder.user_id == current_user.id)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
if not reminder:
|
||||
raise HTTPException(status_code=404, detail="提醒不存在")
|
||||
|
||||
await db.delete(reminder)
|
||||
await db.commit()
|
||||
160
backend/app/routers/schedule_center.py
Normal file
160
backend/app/routers/schedule_center.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from calendar import monthrange
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.schedule_center import (
|
||||
ScheduleCenterDateOut,
|
||||
ScheduleCenterDaySummary,
|
||||
ScheduleCenterMonthOut,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/schedule-center", tags=["调度中心"])
|
||||
|
||||
|
||||
def _build_summary(
|
||||
target_date: str,
|
||||
todos: list[DailyTodo],
|
||||
tasks: list[Task],
|
||||
reminders: list[Reminder],
|
||||
goals: list[Goal],
|
||||
) -> ScheduleCenterDaySummary:
|
||||
return ScheduleCenterDaySummary(
|
||||
date=target_date,
|
||||
todo_total=len(todos),
|
||||
todo_completed=sum(1 for item in todos if item.is_completed),
|
||||
task_due_total=len(tasks),
|
||||
high_priority_total=sum(1 for item in tasks if item.priority in {TaskPriority.HIGH, TaskPriority.URGENT}),
|
||||
reminder_total=len(reminders),
|
||||
goal_total=len(goals),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/month", response_model=ScheduleCenterMonthOut)
|
||||
async def get_month_schedule(
|
||||
year: int = Query(..., ge=2000, le=2100),
|
||||
month: int = Query(..., ge=1, le=12),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
month_start = date(year, month, 1)
|
||||
days_in_month = monthrange(month_start.year, month_start.month)[1]
|
||||
start_key = month_start.isoformat()
|
||||
end_key = month_start.replace(day=days_in_month).isoformat()
|
||||
start_dt = datetime.combine(month_start, datetime.min.time())
|
||||
end_dt = datetime.combine(month_start.replace(day=days_in_month), datetime.max.time())
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo).where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date >= start_key, DailyTodo.todo_date <= end_key)
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task).where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal).where(Goal.user_id == current_user.id, Goal.goal_date >= start_key, Goal.goal_date <= end_key)
|
||||
)).scalars().all()
|
||||
|
||||
todo_map: dict[str, list[DailyTodo]] = {}
|
||||
for item in todos:
|
||||
todo_map.setdefault(item.todo_date, []).append(item)
|
||||
|
||||
task_map: dict[str, list[Task]] = {}
|
||||
for item in tasks:
|
||||
key = item.due_date.date().isoformat()
|
||||
task_map.setdefault(key, []).append(item)
|
||||
|
||||
reminder_map: dict[str, list[Reminder]] = {}
|
||||
for item in reminders:
|
||||
key = item.reminder_at.date().isoformat()
|
||||
reminder_map.setdefault(key, []).append(item)
|
||||
|
||||
goal_map: dict[str, list[Goal]] = {}
|
||||
for item in goals:
|
||||
goal_map.setdefault(item.goal_date, []).append(item)
|
||||
|
||||
days = []
|
||||
for day in range(1, days_in_month + 1):
|
||||
date_key = month_start.replace(day=day).isoformat()
|
||||
days.append(_build_summary(
|
||||
date_key,
|
||||
todo_map.get(date_key, []),
|
||||
task_map.get(date_key, []),
|
||||
reminder_map.get(date_key, []),
|
||||
goal_map.get(date_key, []),
|
||||
))
|
||||
|
||||
return ScheduleCenterMonthOut(month=f"{year:04d}-{month:02d}", days=days)
|
||||
|
||||
|
||||
@router.get("/date", response_model=ScheduleCenterDateOut)
|
||||
async def get_date_schedule(
|
||||
date_str: date = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
target_date = date_str
|
||||
start_dt = datetime.combine(target_date, datetime.min.time())
|
||||
end_dt = datetime.combine(target_date, datetime.max.time())
|
||||
date_key = target_date.isoformat()
|
||||
|
||||
todos = (await db.execute(
|
||||
select(DailyTodo)
|
||||
.where(DailyTodo.user_id == current_user.id, DailyTodo.todo_date == date_key)
|
||||
.order_by(DailyTodo.created_at.desc())
|
||||
)).scalars().all()
|
||||
tasks = (await db.execute(
|
||||
select(Task)
|
||||
.where(
|
||||
Task.user_id == current_user.id,
|
||||
Task.due_date.is_not(None),
|
||||
Task.due_date >= start_dt,
|
||||
Task.due_date <= end_dt,
|
||||
)
|
||||
.order_by(Task.created_at.desc())
|
||||
)).scalars().all()
|
||||
reminders = (await db.execute(
|
||||
select(Reminder)
|
||||
.where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.reminder_at >= start_dt,
|
||||
Reminder.reminder_at <= end_dt,
|
||||
)
|
||||
.order_by(Reminder.reminder_at.asc(), Reminder.created_at.asc())
|
||||
)).scalars().all()
|
||||
goals = (await db.execute(
|
||||
select(Goal)
|
||||
.where(Goal.user_id == current_user.id, Goal.goal_date == date_key)
|
||||
.order_by(Goal.created_at.desc())
|
||||
)).scalars().all()
|
||||
|
||||
summary = _build_summary(date_key, todos, tasks, reminders, goals)
|
||||
return ScheduleCenterDateOut(
|
||||
date=date_key,
|
||||
todos=todos,
|
||||
tasks=tasks,
|
||||
reminders=reminders,
|
||||
goals=goals,
|
||||
summary=summary,
|
||||
generated_at=datetime.now(UTC),
|
||||
)
|
||||
@@ -1,11 +1,13 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.schemas.skill import SkillCreate, SkillOut, SkillUpdate
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.skill_service import SkillService
|
||||
|
||||
router = APIRouter(prefix="/api/skills", tags=["Skill"])
|
||||
|
||||
@@ -37,13 +39,23 @@ async def create_skill(
|
||||
|
||||
@router.get("", response_model=list[SkillOut])
|
||||
async def list_skills(
|
||||
agent_type: str | None = Query(default=None),
|
||||
visibility: str | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Skill).where(Skill.owner_id == current_user.id).order_by(Skill.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id, agent_type=agent_type, visibility=visibility)
|
||||
|
||||
|
||||
@router.post("/bootstrap-builtin", response_model=list[SkillOut])
|
||||
async def bootstrap_builtin_skills(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await ensure_builtin_skills(db, preferred_owner_id=current_user.id)
|
||||
service = SkillService(db)
|
||||
return await service.list_for_user(current_user.id)
|
||||
|
||||
|
||||
@router.get("/{skill_id}", response_model=SkillOut)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from app.database import get_db
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.user import User
|
||||
@@ -13,12 +15,28 @@ router = APIRouter(prefix="/api/tasks", tags=["看板"])
|
||||
@router.get("", response_model=list[TaskOut])
|
||||
async def list_tasks(
|
||||
status: TaskStatus | None = None,
|
||||
due_date: date | None = Query(default=None),
|
||||
date_from: date | None = Query(default=None),
|
||||
date_to: date | None = Query(default=None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Task).where(Task.user_id == current_user.id)
|
||||
if status:
|
||||
query = query.where(Task.status == status)
|
||||
if due_date:
|
||||
start = datetime.combine(due_date, datetime.min.time())
|
||||
end = datetime.combine(due_date, datetime.max.time())
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start, Task.due_date <= end)
|
||||
else:
|
||||
start = datetime.combine(date_from, datetime.min.time()) if date_from else None
|
||||
end = datetime.combine(date_to, datetime.max.time()) if date_to else None
|
||||
if start and end and start > end:
|
||||
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
|
||||
if start is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date >= start)
|
||||
if end is not None:
|
||||
query = query.where(Task.due_date.is_not(None), Task.due_date <= end)
|
||||
query = query.order_by(desc(Task.created_at))
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
@@ -64,10 +82,10 @@ async def update_task(
|
||||
if field == "tags":
|
||||
setattr(task, field, json.dumps(value))
|
||||
elif field == "status" and value == TaskStatus.DONE:
|
||||
from datetime import UTC, datetime
|
||||
task.completed_at = datetime.now(UTC)
|
||||
setattr(task, field, value)
|
||||
else:
|
||||
elif field == "status":
|
||||
task.completed_at = None
|
||||
setattr(task, field, value)
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from datetime import date
|
||||
from app.database import get_db
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
@@ -52,7 +53,7 @@ async def create_todo(
|
||||
user_id=current_user.id,
|
||||
title=data.title,
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date=date.today().isoformat(),
|
||||
todo_date=(data.todo_date or date.today()).isoformat(),
|
||||
)
|
||||
db.add(todo)
|
||||
await db.commit()
|
||||
@@ -74,14 +75,11 @@ async def update_todo(
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
|
||||
# 历史日期不允许修改
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可修改")
|
||||
|
||||
if data.title is not None:
|
||||
todo.title = data.title
|
||||
if data.todo_date is not None:
|
||||
todo.todo_date = data.todo_date.isoformat()
|
||||
if data.is_completed is not None:
|
||||
from datetime import UTC, datetime
|
||||
todo.is_completed = data.is_completed
|
||||
todo.completed_at = datetime.now(UTC) if data.is_completed else None
|
||||
|
||||
@@ -102,9 +100,6 @@ async def delete_todo(
|
||||
todo = result.scalar_one_or_none()
|
||||
if not todo:
|
||||
raise HTTPException(status_code=404, detail="待办不存在")
|
||||
if todo.todo_date != date.today().isoformat():
|
||||
raise HTTPException(status_code=403, detail="历史待办不可删除")
|
||||
|
||||
await db.delete(todo)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class AgentConfigUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
system_prompt: str | None = None
|
||||
enabled: bool | None = None
|
||||
selected_skill_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AgentConfigOut(BaseModel):
|
||||
@@ -51,5 +52,6 @@ class AgentConfigOut(BaseModel):
|
||||
system_prompt: str
|
||||
enabled: bool
|
||||
is_active: bool
|
||||
selected_skill_ids: list[str]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
35
backend/app/schemas/goal.py
Normal file
35
backend/app/schemas/goal.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.goal import GoalStatus
|
||||
|
||||
|
||||
class GoalCreate(BaseModel):
|
||||
title: str
|
||||
goal_date: date
|
||||
note: str | None = None
|
||||
status: GoalStatus = GoalStatus.ACTIVE
|
||||
|
||||
|
||||
class GoalUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
goal_date: date | None = None
|
||||
note: str | None = None
|
||||
status: GoalStatus | None = None
|
||||
|
||||
|
||||
class GoalOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
goal_date: str
|
||||
status: GoalStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class GoalListOut(BaseModel):
|
||||
items: list[GoalOut]
|
||||
40
backend/app/schemas/reminder.py
Normal file
40
backend/app/schemas/reminder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.reminder import ReminderStatus
|
||||
|
||||
|
||||
class ReminderCreate(BaseModel):
|
||||
title: str
|
||||
reminder_at: datetime
|
||||
note: str | None = None
|
||||
|
||||
|
||||
class ReminderUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
reminder_at: datetime | None = None
|
||||
note: str | None = None
|
||||
status: ReminderStatus | None = None
|
||||
is_dismissed: bool | None = None
|
||||
|
||||
|
||||
class ReminderOut(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
note: str | None
|
||||
reminder_at: datetime
|
||||
status: ReminderStatus
|
||||
is_dismissed: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ReminderListOut(BaseModel):
|
||||
items: list[ReminderOut]
|
||||
|
||||
|
||||
class ReminderDateQuery(BaseModel):
|
||||
date: date
|
||||
33
backend/app/schemas/schedule_center.py
Normal file
33
backend/app/schemas/schedule_center.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas.goal import GoalOut
|
||||
from app.schemas.reminder import ReminderOut
|
||||
from app.schemas.task import TaskOut
|
||||
from app.schemas.todo import TodoOut
|
||||
|
||||
|
||||
class ScheduleCenterDaySummary(BaseModel):
|
||||
date: str
|
||||
todo_total: int
|
||||
todo_completed: int
|
||||
task_due_total: int
|
||||
high_priority_total: int
|
||||
reminder_total: int
|
||||
goal_total: int
|
||||
|
||||
|
||||
class ScheduleCenterMonthOut(BaseModel):
|
||||
month: str
|
||||
days: list[ScheduleCenterDaySummary]
|
||||
|
||||
|
||||
class ScheduleCenterDateOut(BaseModel):
|
||||
date: str
|
||||
todos: list[TodoOut]
|
||||
tasks: list[TaskOut]
|
||||
reminders: list[ReminderOut]
|
||||
goals: list[GoalOut]
|
||||
summary: ScheduleCenterDaySummary
|
||||
generated_at: datetime
|
||||
@@ -10,7 +10,8 @@ LLMType = Literal["chat", "vlm", "embedding", "rerank"]
|
||||
# 单个模型配置
|
||||
class LLMModelConfig(BaseModel):
|
||||
name: str = "" # 模型名称/别名,用于标识
|
||||
provider: LLMProviderType = "openai"
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str = ""
|
||||
base_url: str = ""
|
||||
api_key: str = ""
|
||||
@@ -52,7 +53,8 @@ class SettingsOut(BaseModel):
|
||||
# 测试 LLM 连接请求
|
||||
class LLMTestIn(BaseModel):
|
||||
type: LLMType
|
||||
provider: LLMProviderType
|
||||
# provider 已废弃为必填字段:优先通过 base_url + model 推断。
|
||||
provider: Optional[LLMProviderType] = None
|
||||
model: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,7 +7,7 @@ class SkillCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
instructions: str
|
||||
agent_type: str # master/planner/executor/librarian/analyst
|
||||
agent_type: str # master/schedule_planner/executor/librarian/analyst
|
||||
tools: list[str] = []
|
||||
required_context: list[str] = []
|
||||
output_format: Optional[str] = None
|
||||
@@ -39,10 +40,11 @@ class SkillOut(BaseModel):
|
||||
required_context: list[str]
|
||||
output_format: Optional[str]
|
||||
visibility: str
|
||||
is_builtin: bool
|
||||
team_id: Optional[str]
|
||||
is_active: bool
|
||||
owner_id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.todo import TodoSource
|
||||
|
||||
|
||||
class TodoCreate(BaseModel):
|
||||
title: str
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
|
||||
title: str | None = None
|
||||
is_completed: bool | None = None
|
||||
todo_date: date | None = None
|
||||
|
||||
|
||||
class TodoOut(BaseModel):
|
||||
|
||||
@@ -2,10 +2,87 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
BUILTIN_SKILLS = [
|
||||
{
|
||||
'name': '今日重点拆解',
|
||||
'description': '帮助日程规划师从上下文中提炼今天最值得推进的事项。',
|
||||
'instructions': '优先识别今天最关键的 1-3 个重点,说明原因,并给出可执行顺序。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '周计划编排',
|
||||
'description': '把本周目标整理成可落地的节奏与时间块。',
|
||||
'instructions': '将目标拆成周内节奏安排,明确先后顺序、时间块与缓冲。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '时间冲突分析',
|
||||
'description': '识别任务、日程与优先级之间的冲突。',
|
||||
'instructions': '分析冲突来源、影响和推荐取舍,必要时给出替代方案。',
|
||||
'agent_type': 'schedule_planner',
|
||||
'tools': ['calendar', 'tasks'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '任务执行 SOP',
|
||||
'description': '为执行角色提供标准执行步骤和结果回报格式。',
|
||||
'instructions': '执行前先确认目标与边界,执行中记录关键动作,执行后输出结果、风险与下一步。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['shell', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '外部交互推进',
|
||||
'description': '支持论坛、外部接口或内容发布类动作。',
|
||||
'instructions': '围绕外部交互任务,优先保证动作完整、结果清晰、反馈及时。',
|
||||
'agent_type': 'executor',
|
||||
'tools': ['api_calls', 'git'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '知识检索摘要',
|
||||
'description': '从知识中枢中提炼与当前问题最相关的信息。',
|
||||
'instructions': '检索后只保留当前决策需要的内容,输出摘要、来源与缺口。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['web_search', 'database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '图谱沉淀策略',
|
||||
'description': '帮助知识管理员把零散信息沉淀为结构化关系。',
|
||||
'instructions': '识别应沉淀的实体、关系与后续可检索维度。',
|
||||
'agent_type': 'librarian',
|
||||
'tools': ['database'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '风险识别模板',
|
||||
'description': '帮助分析师快速识别当前推进中的风险点。',
|
||||
'instructions': '从进度、依赖、资源与外部信号中提炼风险,并按严重度排序。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'api_calls'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
{
|
||||
'name': '趋势洞察模板',
|
||||
'description': '把多源状态汇总为趋势与判断。',
|
||||
'instructions': '对比近期变化,输出趋势、证据、判断与建议动作。',
|
||||
'agent_type': 'analyst',
|
||||
'tools': ['database', 'code_execution'],
|
||||
'visibility': 'market',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _is_bootstrap_enabled(settings) -> bool:
|
||||
return bool(settings.ADMIN.strip() and settings.ADMIN_EMAIL.strip() and settings.ADMIN_PASSWORD.strip())
|
||||
|
||||
@@ -58,3 +135,49 @@ async def ensure_admin_user(db: AsyncSession, settings) -> None:
|
||||
return
|
||||
raise
|
||||
await db.refresh(admin_user)
|
||||
|
||||
|
||||
async def ensure_builtin_skills(db: AsyncSession, preferred_owner_id: str | None = None) -> None:
|
||||
owner = None
|
||||
if preferred_owner_id:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.id == preferred_owner_id, User.is_active == True)
|
||||
)
|
||||
owner = owner_result.scalar_one_or_none()
|
||||
|
||||
if not owner:
|
||||
owner_result = await db.execute(
|
||||
select(User).where(User.is_active == True).order_by(User.is_superuser.desc(), User.created_at.asc())
|
||||
)
|
||||
owner = owner_result.scalars().first()
|
||||
|
||||
if not owner:
|
||||
return
|
||||
|
||||
existing_result = await db.execute(select(Skill.name))
|
||||
existing_names = set(existing_result.scalars().all())
|
||||
|
||||
missing_skills = [
|
||||
Skill(
|
||||
owner_id=owner.id,
|
||||
name=item['name'],
|
||||
description=item['description'],
|
||||
instructions=item['instructions'],
|
||||
agent_type=item['agent_type'],
|
||||
tools=item['tools'],
|
||||
required_context=[],
|
||||
output_format=None,
|
||||
visibility=item['visibility'],
|
||||
is_builtin=True,
|
||||
team_id=None,
|
||||
is_active=True,
|
||||
)
|
||||
for item in BUILTIN_SKILLS
|
||||
if item['name'] not in existing_names
|
||||
]
|
||||
|
||||
if not missing_skills:
|
||||
return
|
||||
|
||||
db.add_all(missing_skills)
|
||||
await db.commit()
|
||||
|
||||
@@ -5,18 +5,17 @@ Jarvis Agent 服务层
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, AsyncGenerator
|
||||
import asyncio
|
||||
from openai import BadRequestError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_ollama import ChatOllama
|
||||
import httpx
|
||||
|
||||
from app.database import async_session
|
||||
from app.logging_utils import summarize_llm_config
|
||||
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.user import User
|
||||
@@ -24,42 +23,262 @@ from app.agents.graph import get_agent_graph
|
||||
from app.agents.context import set_current_user, clear_current_user
|
||||
from app.services import memory_service
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import create_llm_from_config, resolve_provider_capabilities
|
||||
from app.agents.tools.time_reasoning import extract_reference_datetime
|
||||
from app.agents.state import initial_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_llm_from_config(config: dict):
|
||||
"""根据用户模型配置创建 LLM 实例"""
|
||||
provider = config.get("provider", "openai")
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
MEMORY_SECTION_HEADERS = (
|
||||
"【用户记忆】",
|
||||
"【之前对话摘要】",
|
||||
"【知识大脑】",
|
||||
)
|
||||
|
||||
if provider == "openai" or provider == "deepseek" or provider == "custom":
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
|
||||
def _split_memory_context_sections(memory_context: str | None) -> dict[str, str]:
|
||||
text = (memory_context or "").strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
sections: dict[str, str] = {}
|
||||
current_header: str | None = None
|
||||
current_lines: list[str] = []
|
||||
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped in MEMORY_SECTION_HEADERS:
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
current_header = stripped
|
||||
current_lines = [stripped]
|
||||
continue
|
||||
if current_header:
|
||||
current_lines.append(line)
|
||||
|
||||
if current_header and current_lines:
|
||||
sections[current_header] = "\n".join(current_lines).strip()
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def _derive_role_memory_contexts(memory_context: str | None) -> dict[str, str | None]:
|
||||
sections = _split_memory_context_sections(memory_context)
|
||||
user_memory = sections.get("【用户记忆】")
|
||||
summaries = sections.get("【之前对话摘要】")
|
||||
knowledge = sections.get("【知识大脑】")
|
||||
|
||||
def _join_parts(*parts: str | None) -> str | None:
|
||||
values = [part for part in parts if part]
|
||||
return "\n\n".join(values) if values else None
|
||||
|
||||
return {
|
||||
"schedule_context_summary": _join_parts(user_memory, summaries),
|
||||
"knowledge_context": knowledge,
|
||||
"analysis_report": _join_parts(summaries, knowledge),
|
||||
}
|
||||
|
||||
|
||||
def _is_streaming_rejection_error(error: Exception, user_llm_config: dict | None) -> bool:
|
||||
capabilities = resolve_provider_capabilities(user_llm_config)
|
||||
error_text = str(error).lower()
|
||||
markers = [
|
||||
"invalid chat setting",
|
||||
"invalid params",
|
||||
"stream",
|
||||
"streaming",
|
||||
"unsupported",
|
||||
"bad_request_error",
|
||||
"http 400",
|
||||
"error code: 400",
|
||||
]
|
||||
|
||||
if isinstance(error, BadRequestError):
|
||||
return (
|
||||
getattr(capabilities, "provider", None) not in {"openai", "claude"}
|
||||
and any(marker in error_text for marker in markers)
|
||||
)
|
||||
elif provider == "claude":
|
||||
return ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
|
||||
return any(marker in error_text for marker in markers)
|
||||
|
||||
|
||||
def _coerce_event_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content) if content else ""
|
||||
|
||||
|
||||
_CONTINUITY_STATE_VERSION = 1
|
||||
_CONTINUITY_SNAPSHOT_FIELDS = (
|
||||
"turn_context",
|
||||
"routing_decision",
|
||||
"continuity_state",
|
||||
"pending_action",
|
||||
"last_completed_action",
|
||||
"clarification_context",
|
||||
"tool_outcomes",
|
||||
"pending_tasks",
|
||||
"completed_tasks",
|
||||
"created_entities",
|
||||
"current_agent",
|
||||
"next_step",
|
||||
"agent_trace",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_legacy_turn_context(turn_context: Any, current_agent: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(turn_context, dict):
|
||||
return None
|
||||
normalized = dict(turn_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
active_sub_flow = normalized.pop("active_sub_flow", None)
|
||||
if isinstance(active_agent, str) and active_agent and "active_agent" not in normalized:
|
||||
normalized["active_agent"] = active_agent
|
||||
if isinstance(active_sub_flow, str) and active_sub_flow and "active_sub_commander" not in normalized:
|
||||
normalized["active_sub_commander"] = active_sub_flow
|
||||
if not normalized.get("active_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["active_agent"] = current_agent
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_pending_action(pending_action: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(pending_action, dict):
|
||||
return None
|
||||
normalized = dict(pending_action)
|
||||
legacy_action_type = normalized.pop("action_type", None)
|
||||
if legacy_action_type and "type" not in normalized:
|
||||
normalized["type"] = legacy_action_type
|
||||
legacy_agent = normalized.pop("agent", None)
|
||||
legacy_sub_flow = normalized.pop("sub_flow", None)
|
||||
if legacy_agent and "owner_agent" not in normalized:
|
||||
normalized["owner_agent"] = legacy_agent
|
||||
if legacy_sub_flow and "owner_sub_commander" not in normalized:
|
||||
normalized["owner_sub_commander"] = legacy_sub_flow
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_confirmation":
|
||||
normalized["status"] = "pending"
|
||||
elif legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "blocked_on_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_clarification_context(
|
||||
clarification_context: Any,
|
||||
pending_action: dict[str, Any] | None,
|
||||
current_agent: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(clarification_context, dict):
|
||||
return None
|
||||
normalized = dict(clarification_context)
|
||||
active_agent = normalized.pop("active_agent", None)
|
||||
sub_flow = normalized.pop("sub_flow", None)
|
||||
awaiting_user_input = normalized.pop("awaiting_user_input", None)
|
||||
if isinstance(active_agent, str) and active_agent and "owning_agent" not in normalized:
|
||||
normalized["owning_agent"] = active_agent
|
||||
if isinstance(sub_flow, str) and sub_flow and "owning_sub_commander" not in normalized:
|
||||
normalized["owning_sub_commander"] = sub_flow
|
||||
if "target_action" not in normalized:
|
||||
target_action = None
|
||||
if pending_action:
|
||||
pending_type = pending_action.get("type")
|
||||
if isinstance(pending_type, str) and pending_type and pending_type != "clarification":
|
||||
target_action = pending_type
|
||||
if target_action is None and isinstance(sub_flow, str) and sub_flow.startswith("create_"):
|
||||
target_action = sub_flow
|
||||
if target_action:
|
||||
normalized["target_action"] = target_action
|
||||
if not normalized.get("owning_agent") and isinstance(current_agent, str) and current_agent:
|
||||
normalized["owning_agent"] = current_agent
|
||||
if awaiting_user_input is True and "status" not in normalized:
|
||||
normalized["status"] = "pending"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_legacy_continuity_state(
|
||||
continuity_state: Any,
|
||||
clarification_context: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(continuity_state, dict):
|
||||
return None
|
||||
normalized = dict(continuity_state)
|
||||
normalized.pop("active_agent", None)
|
||||
normalized.pop("active_sub_flow", None)
|
||||
legacy_status = normalized.get("status")
|
||||
if legacy_status == "awaiting_clarification":
|
||||
normalized["status"] = "fresh"
|
||||
if clarification_context and "mode" not in normalized:
|
||||
normalized["mode"] = "resume_after_clarification"
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _normalize_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(state)
|
||||
current_agent = normalized.get("current_agent")
|
||||
pending_action = _normalize_legacy_pending_action(normalized.get("pending_action"))
|
||||
clarification_context = _normalize_legacy_clarification_context(
|
||||
normalized.get("clarification_context"),
|
||||
pending_action,
|
||||
current_agent,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
return ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
# 默认使用 OpenAI
|
||||
return ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
continuity_state = _normalize_legacy_continuity_state(
|
||||
normalized.get("continuity_state"),
|
||||
clarification_context,
|
||||
)
|
||||
turn_context = _normalize_legacy_turn_context(normalized.get("turn_context"), current_agent)
|
||||
if pending_action is not None:
|
||||
normalized["pending_action"] = pending_action
|
||||
if clarification_context is not None:
|
||||
normalized["clarification_context"] = clarification_context
|
||||
if continuity_state is not None:
|
||||
normalized["continuity_state"] = continuity_state
|
||||
if turn_context is not None:
|
||||
normalized["turn_context"] = turn_context
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_continuity_snapshot(state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
normalized_state = _normalize_continuity_snapshot(state)
|
||||
snapshot = {
|
||||
field: normalized_state.get(field)
|
||||
for field in _CONTINUITY_SNAPSHOT_FIELDS
|
||||
if normalized_state.get(field) is not None
|
||||
}
|
||||
if not snapshot:
|
||||
return None
|
||||
return {
|
||||
"version": _CONTINUITY_STATE_VERSION,
|
||||
"state": snapshot,
|
||||
}
|
||||
|
||||
|
||||
def _extract_continuity_snapshot(payload: Any) -> dict[str, Any] | None:
|
||||
if isinstance(payload, list):
|
||||
for item in payload:
|
||||
snapshot = _extract_continuity_snapshot(item)
|
||||
if snapshot:
|
||||
return snapshot
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
if payload.get("kind") != "agent_continuity_state":
|
||||
return None
|
||||
if payload.get("version") != _CONTINUITY_STATE_VERSION:
|
||||
return None
|
||||
state = payload.get("state")
|
||||
if isinstance(state, dict):
|
||||
return _normalize_continuity_snapshot(state)
|
||||
return None
|
||||
|
||||
|
||||
class AgentService:
|
||||
@@ -92,38 +311,83 @@ class AgentService:
|
||||
"steps": steps or [],
|
||||
}
|
||||
|
||||
def _build_current_datetime_context(self) -> tuple[str, dict[str, str]]:
|
||||
now_utc = datetime.now(UTC)
|
||||
reference = {
|
||||
"current_time_iso": now_utc.isoformat(),
|
||||
"current_date_iso": now_utc.date().isoformat(),
|
||||
}
|
||||
context = (
|
||||
"【当前时间】\n"
|
||||
f"- current_time_utc: {reference['current_time_iso']}\n"
|
||||
f"- current_date_utc: {reference['current_date_iso']}\n"
|
||||
"说明:解析‘今天/明天/后天/本周/下周’等相对时间时,请以 current_time_utc 为准。"
|
||||
)
|
||||
return context, reference
|
||||
|
||||
async def _get_user_llm_config(self, user_id: str, model_name: str | None = None) -> dict | None:
|
||||
"""获取用户的 LLM 模型配置"""
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
user = await self.db.get(User, user_id)
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
llm_config = user.llm_config
|
||||
|
||||
# 如果指定了模型名称,查找对应的配置
|
||||
if model_name:
|
||||
for model_type in ["chat", "vlm"]:
|
||||
models = llm_config.get(model_type, [])
|
||||
models = llm_config.get("chat", [])
|
||||
for m in models:
|
||||
if m.get("name") == model_name:
|
||||
return m
|
||||
# 没找到,返回 None 让调用方知道配置不存在
|
||||
return None
|
||||
|
||||
# 如果没指定模型名,返回默认启用的 chat 模型
|
||||
chat_models = llm_config.get("chat", [])
|
||||
for m in chat_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
vlm_models = llm_config.get("vlm", [])
|
||||
for m in vlm_models:
|
||||
if m.get("enabled"):
|
||||
return m
|
||||
|
||||
return None
|
||||
|
||||
async def _load_continuity_snapshot(self, conversation: Conversation) -> dict[str, Any] | None:
|
||||
snapshot = _extract_continuity_snapshot(getattr(conversation, "agent_state", None))
|
||||
if snapshot:
|
||||
return snapshot
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.role == "assistant")
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
for message in result.scalars():
|
||||
snapshot = _extract_continuity_snapshot(message.attachments)
|
||||
if snapshot:
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
async def _build_agent_state(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
conversation: Conversation,
|
||||
full_message: str,
|
||||
memory_context: str | None,
|
||||
current_datetime_context: str,
|
||||
current_datetime_reference: dict[str, str],
|
||||
user_llm_config: dict | None,
|
||||
) -> dict[str, Any]:
|
||||
state = initial_state(user_id, conversation.id)
|
||||
state.update({
|
||||
"messages": [HumanMessage(content=full_message)],
|
||||
"memory_context": memory_context,
|
||||
"current_datetime_context": current_datetime_context,
|
||||
"current_datetime_reference": current_datetime_reference,
|
||||
"user_llm_config": user_llm_config,
|
||||
})
|
||||
previous_snapshot = await self._load_continuity_snapshot(conversation)
|
||||
if previous_snapshot:
|
||||
state.update(previous_snapshot)
|
||||
state["messages"] = [HumanMessage(content=full_message)]
|
||||
return state
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -134,16 +398,36 @@ class AgentService:
|
||||
) -> tuple[str, str, AsyncGenerator[dict[str, Any], None]]:
|
||||
"""
|
||||
处理对话请求(流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_stream)
|
||||
"""
|
||||
# 获取或创建对话
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if model_name and not user_llm_config:
|
||||
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
logger.info(
|
||||
"agent_chat_started",
|
||||
extra={
|
||||
"details": {
|
||||
"mode": "stream",
|
||||
"requested_model_name": model_name,
|
||||
"resolved_model_name": model_name_used,
|
||||
"message_length": len(message or ""),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
if conv is None:
|
||||
raise ValueError("会话不存在或无权访问")
|
||||
else:
|
||||
conv = None
|
||||
|
||||
@@ -156,7 +440,6 @@ class AgentService:
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
@@ -168,7 +451,6 @@ class AgentService:
|
||||
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -193,156 +475,170 @@ class AgentService:
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 预创建助手消息(后续更新内容)
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content="",
|
||||
model=model_name_used or "jarvis",
|
||||
attachments=None,
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
def _build_assistant_event_payload(content: str) -> dict[str, Any]:
|
||||
return {
|
||||
"source_type": "conversation",
|
||||
"source_id": conversation_id,
|
||||
"event_type": "message_created",
|
||||
"title": "Assistant message",
|
||||
"content_summary": content[:500],
|
||||
"raw_excerpt": content[:2000],
|
||||
"metadata_": {"role": "assistant"},
|
||||
"importance_signal": 0.8,
|
||||
}
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
async def run_agent():
|
||||
collected = ""
|
||||
state: dict[str, Any] | None = None
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"current_sub_commander": None,
|
||||
"active_sub_commanders": [],
|
||||
"sub_commander_trace": [],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
||||
|
||||
state = await self._build_agent_state(
|
||||
user_id=user_id,
|
||||
conversation=conv,
|
||||
full_message=full_message,
|
||||
memory_context=memory_ctx,
|
||||
current_datetime_context=current_datetime_context,
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
|
||||
yield self._build_progress_event("thinking", "Jarvis 正在分析请求", agent="master", step="理解你的问题")
|
||||
|
||||
collected = ""
|
||||
async for event in graph.astream_events(langgraph_state, version="v2"):
|
||||
try:
|
||||
async for event in graph.astream_events(state, version="v2"):
|
||||
kind = event.get("event")
|
||||
event_name = event.get("name", "")
|
||||
metadata = event.get("metadata", {})
|
||||
data = event.get("data", {})
|
||||
|
||||
if kind == "on_chain_start" and event_name in {"master", "planner", "executor", "librarian", "analyst"}:
|
||||
if kind == "on_chain_start" and event_name in {"master", "schedule_planner", "executor", "librarian", "analyst"}:
|
||||
stage_map = {
|
||||
"master": ("thinking", "Jarvis 正在理解请求"),
|
||||
"planner": ("planning", "Jarvis 正在拆解步骤"),
|
||||
"schedule_planner": ("planning", "Jarvis 正在编排日程"),
|
||||
"executor": ("tool", "Jarvis 正在执行操作"),
|
||||
"librarian": ("tool", "Jarvis 正在检索知识"),
|
||||
"analyst": ("thinking", "Jarvis 正在分析信息"),
|
||||
}
|
||||
stage, label = stage_map[event_name]
|
||||
stage, label = stage_map.get(event_name, ("thinking", "Jarvis 正在思考"))
|
||||
yield self._build_progress_event(stage, label, agent=event_name, step=label)
|
||||
|
||||
elif kind == "on_tool_start":
|
||||
tool_input = data.get("input")
|
||||
step = None
|
||||
if isinstance(tool_input, dict) and tool_input:
|
||||
step = f"调用工具 {event_name}"
|
||||
yield self._build_progress_event("tool", f"Jarvis 正在调用工具 {event_name}", agent="executor", tool_name=event_name, step=step)
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"Jarvis 正在调用工具 {event_name}",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=f"正在执行 {event_name}",
|
||||
)
|
||||
|
||||
elif kind == "on_tool_end":
|
||||
yield self._build_progress_event("tool", f"工具 {event_name} 已完成", agent="executor", tool_name=event_name, step=f"已获得 {event_name} 结果")
|
||||
elif kind == "on_chain_end" and event_name == "planner":
|
||||
output = data.get("output") or {}
|
||||
plan_steps = output.get("plan_steps") or []
|
||||
steps = [item.get("description", "") for item in plan_steps if item.get("description")]
|
||||
yield self._build_progress_event("planning", "Jarvis 已生成处理步骤", agent="planner", step=steps[0] if steps else "正在整理计划", steps=steps[:4])
|
||||
tool_result = data.get("output")
|
||||
step = f"已完成 {event_name}"
|
||||
if isinstance(tool_result, str) and len(tool_result) > 0:
|
||||
step = tool_result[:100]
|
||||
yield self._build_progress_event(
|
||||
"tool",
|
||||
f"工具 {event_name} 已完成",
|
||||
agent="executor",
|
||||
tool_name=event_name,
|
||||
step=step,
|
||||
)
|
||||
|
||||
elif kind == "on_chat_model_stream":
|
||||
chunk = data.get("chunk")
|
||||
content = getattr(chunk, "content", "") if chunk else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
content = _coerce_event_text(getattr(chunk, "content", "") if chunk else "")
|
||||
if content:
|
||||
collected += content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chat_model_end" and not collected:
|
||||
|
||||
elif kind == "on_chain_end":
|
||||
output = data.get("output")
|
||||
content = getattr(output, "content", "") if output else ""
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text_parts.append(item.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
content = "".join(text_parts)
|
||||
if content:
|
||||
collected = content
|
||||
yield {"type": "chunk", "content": content}
|
||||
elif kind == "on_chain_end" and event_name in {"executor", "librarian", "analyst"}:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在整理最终回答", agent=event_name, step="生成回复")
|
||||
final_resp = None
|
||||
if isinstance(output, dict):
|
||||
state.update(output)
|
||||
final_resp = output.get("final_response")
|
||||
if final_resp:
|
||||
final_text = str(final_resp)
|
||||
if final_text != collected:
|
||||
collected = final_text
|
||||
yield {"type": "chunk", "content": final_text}
|
||||
|
||||
elif kind == "on_chat_model_end":
|
||||
output = data.get("output")
|
||||
final_content = _coerce_event_text(getattr(output, "content", "") if output else "")
|
||||
if final_content:
|
||||
final_text = final_content
|
||||
if final_text != collected:
|
||||
collected = final_text
|
||||
yield {"type": "chunk", "content": final_text}
|
||||
|
||||
except Exception as e:
|
||||
fallback = f"抱歉,发生错误: {str(e)}"
|
||||
collected = fallback
|
||||
if _is_streaming_rejection_error(e, user_llm_config) and not collected:
|
||||
yield self._build_progress_event("responding", "Jarvis 正在生成回复", agent="master", step="fallback")
|
||||
try:
|
||||
result_state = await graph.ainvoke(state)
|
||||
if isinstance(result_state, dict):
|
||||
state.update(result_state)
|
||||
fallback_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
collected = str(fallback_content)
|
||||
yield {"type": "chunk", "content": collected}
|
||||
except Exception:
|
||||
logger.exception("llm_sync_fallback_failed")
|
||||
safe_error = "模型服务暂不可用,请稍后再试。"
|
||||
yield {"type": "error", "error": safe_error}
|
||||
collected = f"抱歉,发生错误: {safe_error}"
|
||||
yield {"type": "chunk", "content": collected}
|
||||
else:
|
||||
logger.exception("agent_streaming_failed")
|
||||
if not collected:
|
||||
safe_error = "模型服务暂不可用,请稍后再试。"
|
||||
yield {"type": "error", "error": safe_error}
|
||||
collected = f"抱歉,发生错误: {safe_error}"
|
||||
yield {"type": "chunk", "content": collected}
|
||||
else:
|
||||
yield {"type": "error", "error": str(e)}
|
||||
yield {"type": "chunk", "content": fallback}
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 最终更新数据库中的消息内容
|
||||
if collected:
|
||||
try:
|
||||
result2 = await self.db.execute(
|
||||
select(Message).where(Message.id == assistant_msg.id)
|
||||
)
|
||||
msg = result2.scalar_one_or_none()
|
||||
if msg:
|
||||
msg.content = collected
|
||||
await self.db.commit()
|
||||
await brain_service.create_event(
|
||||
assistant_msg.content = collected
|
||||
continuity_snapshot = _build_continuity_snapshot(state or {})
|
||||
assistant_msg.attachments = ([{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = ({
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
} if continuity_snapshot else None)
|
||||
await BrainService(self.db).create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
source_id=conversation_id,
|
||||
event_type="message_created",
|
||||
title="Assistant message",
|
||||
content_summary=collected[:500],
|
||||
raw_excerpt=collected[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
**_build_assistant_event_payload(collected),
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("save_assistant_message_failed")
|
||||
asyncio.create_task(self._try_auto_summarize_background(user_id, conversation_id))
|
||||
|
||||
return conversation_id, assistant_msg.id, run_agent()
|
||||
|
||||
@@ -355,17 +651,25 @@ class AgentService:
|
||||
model_name: str | None = None,
|
||||
) -> tuple[str, str, str, str | None]:
|
||||
"""
|
||||
简单同步版对话(无流式)
|
||||
|
||||
Returns:
|
||||
(conversation_id, message_id, response_content, model_name_used)
|
||||
简单同步版对话
|
||||
"""
|
||||
# 获取或创建对话
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if model_name and not user_llm_config:
|
||||
raise ValueError("所选模型不可用于聊天,请切换到聊天模型")
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
if conversation_id:
|
||||
result = await self.db.execute(
|
||||
select(Conversation).where(Conversation.id == conversation_id)
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
conv = result.scalar_one_or_none()
|
||||
if conv is None:
|
||||
raise ValueError("会话不存在或无权访问")
|
||||
else:
|
||||
conv = None
|
||||
|
||||
@@ -378,29 +682,17 @@ class AgentService:
|
||||
else:
|
||||
conversation_id = conv.id
|
||||
|
||||
# 如果有文件,读取内容作为上下文
|
||||
file_context = ""
|
||||
if file_ids:
|
||||
from app.services.document_service import DocumentService
|
||||
doc_svc = DocumentService(self.db)
|
||||
for file_id in file_ids:
|
||||
content = await doc_svc.get_document_content(user_id, file_id)
|
||||
if content:
|
||||
file_context += f"\n\n[用户上传文件内容]\n{content}\n[/文件内容]"
|
||||
|
||||
# 将文件上下文添加到消息
|
||||
full_message = f"{message}\n{file_context}" if file_context else message
|
||||
|
||||
# 存储用户消息
|
||||
user_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message,
|
||||
attachments=[{"file_ids": file_ids}] if file_ids else None,
|
||||
)
|
||||
user_msg = Message(conversation_id=conversation_id, role="user", content=message)
|
||||
self.db.add(user_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(user_msg)
|
||||
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content="",
|
||||
model=model_name_used or "jarvis",
|
||||
attachments=None,
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
@@ -414,68 +706,32 @@ class AgentService:
|
||||
metadata_={"role": "user"},
|
||||
importance_signal=1.0,
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 加载记忆上下文
|
||||
memory_ctx = await memory_service.build_memory_context(
|
||||
self.db, user_id, conversation_id, message
|
||||
)
|
||||
memory_ctx = await memory_service.build_memory_context(self.db, user_id, conversation_id, message)
|
||||
|
||||
# 获取用户配置的 LLM
|
||||
user_llm_config = await self._get_user_llm_config(user_id, model_name)
|
||||
model_name_used = model_name
|
||||
if user_llm_config:
|
||||
model_name_used = user_llm_config.get("name", model_name)
|
||||
|
||||
# 调用 LangGraph Agent
|
||||
set_current_user(user_id)
|
||||
graph = get_agent_graph()
|
||||
langgraph_state = {
|
||||
"messages": [HumanMessage(content=full_message)], # type: ignore[arg-type]
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent": "master",
|
||||
"active_agents": ["master"],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": memory_ctx,
|
||||
"user_llm_config": user_llm_config, # 传递用户 LLM 配置
|
||||
}
|
||||
|
||||
try:
|
||||
result_state = await graph.ainvoke(langgraph_state)
|
||||
response_content = result_state.get("final_response", "抱歉,我无法处理这个请求。")
|
||||
graph = get_agent_graph()
|
||||
current_datetime_context, current_datetime_reference = self._build_current_datetime_context()
|
||||
state = await self._build_agent_state(
|
||||
user_id=user_id,
|
||||
conversation=conv,
|
||||
full_message=message,
|
||||
memory_context=memory_ctx,
|
||||
current_datetime_context=current_datetime_context,
|
||||
current_datetime_reference=current_datetime_reference,
|
||||
user_llm_config=user_llm_config,
|
||||
)
|
||||
state.update(_derive_role_memory_contexts(memory_ctx))
|
||||
result_state = await graph.ainvoke(state)
|
||||
response_content = result_state.get("final_response") or str(result_state.get("messages", [AIMessage(content="")])[-1].content)
|
||||
except Exception as e:
|
||||
response_content = f"抱歉,发生错误: {str(e)}"
|
||||
logger.exception("agent_chat_simple_failed")
|
||||
response_content = "抱歉,发生错误。"
|
||||
finally:
|
||||
clear_current_user()
|
||||
try:
|
||||
asyncio.get_running_loop().create_task(
|
||||
self._try_auto_summarize_background(user_id, conversation_id)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 保存助手消息
|
||||
assistant_msg = Message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=response_content,
|
||||
model=model_name_used or "jarvis",
|
||||
)
|
||||
self.db.add(assistant_msg)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
brain_service = BrainService(self.db)
|
||||
await brain_service.create_event(
|
||||
user_id,
|
||||
source_type="conversation",
|
||||
@@ -485,8 +741,20 @@ class AgentService:
|
||||
content_summary=response_content[:500],
|
||||
raw_excerpt=response_content[:2000],
|
||||
metadata_={"role": "assistant"},
|
||||
importance_signal=1.0,
|
||||
importance_signal=0.8,
|
||||
)
|
||||
|
||||
assistant_msg.content = response_content
|
||||
continuity_snapshot = _build_continuity_snapshot(result_state) if 'result_state' in locals() else None
|
||||
assistant_msg.attachments = ([{
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
}] if continuity_snapshot else None)
|
||||
conv.agent_state = ({
|
||||
"kind": "agent_continuity_state",
|
||||
**continuity_snapshot,
|
||||
} if continuity_snapshot else None)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(assistant_msg)
|
||||
|
||||
return conversation_id, assistant_msg.id, response_content, model_name_used
|
||||
|
||||
@@ -4,7 +4,8 @@ OpenAI / Claude / Ollama / DeepSeek / 任意 OpenAI 兼容接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncIterator, Literal
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
@@ -16,8 +17,131 @@ from app.models.user import User
|
||||
import httpx
|
||||
import os
|
||||
|
||||
os.makedirs(settings.DATA_DIR, exist_ok=True)
|
||||
os.makedirs(settings.CHROMA_PERSIST_DIR, exist_ok=True)
|
||||
|
||||
|
||||
ToolStrategy = Literal["native", "json_fallback"]
|
||||
|
||||
|
||||
def _resolve_effective_base_url(config: dict | None) -> str:
|
||||
provider = str((config or {}).get("provider") or settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
base_url = str((config or {}).get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
if provider in {"openai", "custom", "deepseek"}:
|
||||
return settings.OPENAI_BASE_URL
|
||||
if provider == "ollama":
|
||||
return settings.OLLAMA_BASE_URL
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderCapabilities:
|
||||
provider: str
|
||||
supports_native_tools: bool
|
||||
preferred_tool_strategy: ToolStrategy
|
||||
|
||||
|
||||
def default_provider_capabilities() -> ProviderCapabilities:
|
||||
return resolve_provider_capabilities({"provider": settings.LLM_PROVIDER})
|
||||
|
||||
|
||||
def normalize_provider_name(config: dict | None) -> str:
|
||||
provider_raw = str((config or {}).get("provider") or "").strip().lower()
|
||||
provider = provider_raw or str(settings.LLM_PROVIDER or "openai").strip().lower()
|
||||
model = str((config or {}).get("model") or "").strip().lower()
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
|
||||
# base_url-first inference (provider may be omitted in user config)
|
||||
if base_url:
|
||||
if any(key in base_url for key in {"localhost:11434", "127.0.0.1:11434"}):
|
||||
return "ollama"
|
||||
if any(key in base_url for key in {"api.anthropic.com", "anthropic"}):
|
||||
return "claude"
|
||||
if "api.deepseek.com" in base_url:
|
||||
return "deepseek"
|
||||
|
||||
# Many "openai-compatible" endpoints are configured as provider=openai.
|
||||
# We treat them as distinct providers so capability routing can stay conservative.
|
||||
if provider in {"openai", "custom"}:
|
||||
if any(key in model or key in base_url for key in {"minimax", "abab"}):
|
||||
return "minimax"
|
||||
if any(key in model or key in base_url for key in {"kimi", "moonshot"}):
|
||||
return "kimi"
|
||||
if any(key in model or key in base_url for key in {"qwen", "dashscope", "aliyuncs"}):
|
||||
return "qwen"
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def resolve_provider_capabilities(config: dict | None) -> ProviderCapabilities:
|
||||
provider = normalize_provider_name(config)
|
||||
|
||||
# Conservative default: only treat official OpenAI + DeepSeek + Claude as reliable native tool providers.
|
||||
# Many OpenAI-compatible endpoints reject tool / response_format / other chat params.
|
||||
native_tool_providers = {"openai", "deepseek", "claude"}
|
||||
|
||||
base_url = _resolve_effective_base_url(config).strip().lower()
|
||||
is_official_openai = (
|
||||
provider != "openai"
|
||||
or not base_url
|
||||
or "api.openai.com" in base_url
|
||||
or "openai.azure.com" in base_url
|
||||
)
|
||||
|
||||
if provider in native_tool_providers and is_official_openai:
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=True,
|
||||
preferred_tool_strategy="native",
|
||||
)
|
||||
|
||||
return ProviderCapabilities(
|
||||
provider=provider,
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_from_config(config: dict | None):
|
||||
"""根据用户模型配置创建底层 LangChain LLM 实例"""
|
||||
if not config:
|
||||
return get_llm()
|
||||
|
||||
provider = normalize_provider_name(config)
|
||||
model = config.get("model", "")
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
|
||||
if provider in {"openai", "deepseek", "custom", "minimax", "kimi", "qwen"}:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "claude":
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
elif provider == "ollama":
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=httpx.Timeout(120.0, connect=10.0),
|
||||
)
|
||||
else:
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=httpx.Timeout(60.0, connect=10.0),
|
||||
)
|
||||
|
||||
setattr(llm, "_jarvis_user_llm_config", config)
|
||||
setattr(llm, "_jarvis_provider_capabilities", resolve_provider_capabilities(config))
|
||||
return llm
|
||||
|
||||
|
||||
class LLMService(ABC):
|
||||
@@ -145,4 +269,7 @@ def get_llm() -> LLMService:
|
||||
_llm_instance = OllamaService()
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM provider: {provider}")
|
||||
setattr(_llm_instance, "_jarvis_provider_capabilities", default_provider_capabilities())
|
||||
return _llm_instance
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,160 @@
|
||||
"""
|
||||
Jarvis 记忆系统
|
||||
Jarvis 记忆系统 (基于 Mem0)
|
||||
三层记忆: 短期(对话历史) → 中期(摘要) → 长期(用户画像)
|
||||
底层使用 Mem0 实现事实提取、时间线、矛盾解决和遗忘机制
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.memory import UserMemory
|
||||
from app.models.user import User
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.llm_service import get_llm
|
||||
from app.agents.context import get_current_user
|
||||
from app.config import settings as _settings
|
||||
|
||||
try:
|
||||
from mem0 import Memory
|
||||
|
||||
MEM0_AVAILABLE = True
|
||||
except ImportError:
|
||||
MEM0_AVAILABLE = False
|
||||
Memory = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_user_embedding_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 embedding 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
embedding_models = user.llm_config.get("embedding", [])
|
||||
for model in embedding_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.EMBEDDING_BASE_URL,
|
||||
"api_key": model.get("api_key")
|
||||
or _settings.EMBEDDING_API_KEY
|
||||
or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def _get_user_chat_config(db: AsyncSession, user_id: str) -> dict | None:
|
||||
"""从用户配置中获取 chat 模型配置"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.llm_config:
|
||||
return None
|
||||
|
||||
chat_models = user.llm_config.get("chat", [])
|
||||
for model in chat_models:
|
||||
if model.get("enabled") and model.get("model"):
|
||||
return {
|
||||
"model": model.get("model"),
|
||||
"base_url": model.get("base_url") or _settings.OPENAI_BASE_URL,
|
||||
"api_key": model.get("api_key") or _settings.OPENAI_API_KEY,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
"""Mem0 客户端 - 按用户隔离"""
|
||||
|
||||
_instances: dict[str, Memory] = {}
|
||||
_persist_dir: str = "./data/mem0"
|
||||
|
||||
async def get_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
cache_key = user_id
|
||||
|
||||
if cache_key not in self._instances:
|
||||
self._instances[cache_key] = await self._init_memory(db, user_id)
|
||||
|
||||
return self._instances[cache_key]
|
||||
|
||||
async def _init_memory(self, db: AsyncSession, user_id: str) -> Memory:
|
||||
if not MEM0_AVAILABLE:
|
||||
raise RuntimeError("mem0ai 未安装,请运行: pip install mem0ai")
|
||||
|
||||
os.makedirs(self._persist_dir, exist_ok=True)
|
||||
|
||||
llm_config = {
|
||||
"model": _settings.OPENAI_MODEL,
|
||||
"base_url": _settings.OPENAI_BASE_URL,
|
||||
"api_key": _settings.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
embed_config = _settings.EMBEDDING_MODEL
|
||||
embed_base_url = _settings.EMBEDDING_BASE_URL
|
||||
embed_api_key = _settings.EMBEDDING_API_KEY or _settings.OPENAI_API_KEY
|
||||
|
||||
if db and user_id:
|
||||
try:
|
||||
user_chat = await _get_user_chat_config(db, user_id)
|
||||
if user_chat:
|
||||
llm_config = user_chat
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
user_embed = await _get_user_embedding_config(db, user_id)
|
||||
if user_embed:
|
||||
embed_config = user_embed["model"]
|
||||
embed_base_url = user_embed["base_url"]
|
||||
embed_api_key = user_embed["api_key"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": f"jarvis_memory_{user_id}",
|
||||
"path": self._persist_dir,
|
||||
},
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": llm_config["model"],
|
||||
"api_key": llm_config["api_key"],
|
||||
"base_url": llm_config["base_url"],
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": embed_config,
|
||||
"api_key": embed_api_key,
|
||||
"base_url": embed_base_url,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return Memory.from_config(config)
|
||||
|
||||
|
||||
_mem0_client = Mem0Client()
|
||||
|
||||
|
||||
async def get_mem0(db: AsyncSession, user_id: str) -> Memory:
|
||||
"""获取指定用户的 Mem0 实例"""
|
||||
return await _mem0_client.get_memory(db, user_id)
|
||||
|
||||
|
||||
# ———— 短期记忆: 对话历史 ————
|
||||
|
||||
|
||||
async def load_conversation_history(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
@@ -36,8 +173,7 @@ async def load_conversation_history(
|
||||
async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) -> int:
|
||||
"""获取对话轮数(用户消息数)"""
|
||||
result = await db.execute(
|
||||
select(func.count(Message.id))
|
||||
.where(
|
||||
select(func.count(Message.id)).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
@@ -47,14 +183,15 @@ async def get_conversation_turn_count(db: AsyncSession, conversation_id: str) ->
|
||||
|
||||
# ———— 中期记忆: 对话摘要 ————
|
||||
|
||||
SUMMARIZE_THRESHOLD = 8 # 超过此轮数则摘要
|
||||
MAX_HISTORY_TURNS = 10 # Agent 最多看到的对话历史轮数
|
||||
SUMMARIZE_THRESHOLD = 8
|
||||
MAX_HISTORY_TURNS = 10
|
||||
|
||||
|
||||
async def should_summarize(db: AsyncSession, conversation_id: str) -> bool:
|
||||
"""判断当前对话是否需要摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
# 检查是否已有摘要覆盖到当前轮数
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -72,17 +209,21 @@ async def generate_summary(
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> str:
|
||||
"""调用 LLM 生成对话摘要"""
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages
|
||||
)
|
||||
llm = get_llm()
|
||||
"""生成对话摘要"""
|
||||
from app.services.llm_service import get_llm
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"),
|
||||
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages)
|
||||
llm = get_llm()
|
||||
response = await llm.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="你是一个记忆助手。请用简洁的中文总结以下对话的核心内容,"
|
||||
"提取关键信息、用户偏好、待办事项等。不超过150字。"
|
||||
),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
]
|
||||
)
|
||||
return response.content.strip()
|
||||
|
||||
|
||||
@@ -92,8 +233,10 @@ async def save_summary(
|
||||
conversation_id: str,
|
||||
summary_text: str,
|
||||
turn_count: int,
|
||||
) -> MemorySummary:
|
||||
"""保存对话摘要"""
|
||||
) -> Any:
|
||||
"""保存对话摘要到数据库"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
summary = MemorySummary(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -109,8 +252,10 @@ async def save_summary(
|
||||
async def get_summaries(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
) -> list[MemorySummary]:
|
||||
) -> list[Any]:
|
||||
"""获取某对话的所有历史摘要"""
|
||||
from app.models.memory import MemorySummary
|
||||
|
||||
result = await db.execute(
|
||||
select(MemorySummary)
|
||||
.where(MemorySummary.conversation_id == conversation_id)
|
||||
@@ -119,31 +264,7 @@ async def get_summaries(
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ———— 长期记忆: 用户画像 ————
|
||||
|
||||
EXTRACTION_PROMPT = """从以下对话中提取关于用户的关键信息。
|
||||
只提取事实性的、可能对未来对话有帮助的信息,如:
|
||||
- 用户的身份/职业/背景
|
||||
- 用户的偏好和习惯
|
||||
- 用户的目标和计划
|
||||
- 重要的事件和日期
|
||||
- 用户的观点和态度
|
||||
|
||||
每条记忆格式: [类型] 内容
|
||||
类型: fact(事实) | preference(偏好) | goal(目标) | habit(习惯)
|
||||
|
||||
如果没有提取到任何记忆,回复"无"。
|
||||
"""
|
||||
|
||||
FACT_TYPES = {"fact", "preference", "goal", "habit"}
|
||||
|
||||
|
||||
def _parse_fact_line(line: str) -> tuple[str, str] | None:
|
||||
"""解析一行记忆: [fact] 内容 -> (type, content)"""
|
||||
m = re.match(r"\[(\w+)\]\s*(.+)", line.strip())
|
||||
if m and m.group(1) in FACT_TYPES:
|
||||
return m.group(1), m.group(2).strip()
|
||||
return None
|
||||
# ———— 长期记忆: 基于 Mem0 ————
|
||||
|
||||
|
||||
async def extract_user_memories(
|
||||
@@ -151,55 +272,51 @@ async def extract_user_memories(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
messages: list[Message],
|
||||
) -> list[UserMemory]:
|
||||
"""从对话中提取用户记忆并保存"""
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从对话中提取用户记忆并存储到 Mem0。
|
||||
Mem0 会自动处理:
|
||||
- 事实提取
|
||||
- 时间线追踪
|
||||
- 矛盾解决
|
||||
- 遗忘机制
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return []
|
||||
|
||||
history_text = "\n".join(
|
||||
f"[{m.role}] {m.content}" for m in messages[-10:]
|
||||
history_text = "\n".join(f"[{m.role}] {m.content}" for m in messages[-10:])
|
||||
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.add(
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages[-10:]],
|
||||
user_id=user_id,
|
||||
metadata={
|
||||
"conversation_id": conversation_id,
|
||||
"source": "jarvis_memory",
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm()
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
response = await llm.invoke([
|
||||
SystemMessage(content=EXTRACTION_PROMPT),
|
||||
HumanMessage(content=history_text),
|
||||
])
|
||||
|
||||
text = response.content.strip()
|
||||
if text == "无" or not text:
|
||||
return result.get("results", [])
|
||||
except Exception as e:
|
||||
print(f"Mem0 extract error: {e}")
|
||||
return []
|
||||
|
||||
memories = []
|
||||
for line in text.split("\n"):
|
||||
parsed = _parse_fact_line(line)
|
||||
if not parsed:
|
||||
continue
|
||||
mem_type, content = parsed
|
||||
# 检查是否已有完全相同的记忆
|
||||
existing = await db.execute(
|
||||
select(UserMemory).where(
|
||||
UserMemory.user_id == user_id,
|
||||
UserMemory.content == content,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
mem = UserMemory(
|
||||
user_id=user_id,
|
||||
memory_type=mem_type,
|
||||
content=content,
|
||||
importance=5,
|
||||
source_conversation_id=conversation_id,
|
||||
)
|
||||
db.add(mem)
|
||||
memories.append(mem)
|
||||
def _extract_memory_query_tokens(query: str) -> list[str]:
|
||||
normalized_query = (query or "").lower()
|
||||
tokens = [token for token in re.findall(r"[a-z0-9]+", normalized_query) if len(token) >= 3]
|
||||
|
||||
if memories:
|
||||
await db.commit()
|
||||
return memories
|
||||
for chunk in re.findall(r"[\u4e00-\u9fff]+", query or ""):
|
||||
stripped_chunk = chunk.strip()
|
||||
if len(stripped_chunk) >= 4:
|
||||
tokens.append(stripped_chunk)
|
||||
if len(stripped_chunk) > 6:
|
||||
tokens.extend(
|
||||
stripped_chunk[index:index + 4]
|
||||
for index in range(len(stripped_chunk) - 3)
|
||||
)
|
||||
|
||||
return list(dict.fromkeys(tokens))
|
||||
|
||||
|
||||
async def recall_user_memories(
|
||||
@@ -207,41 +324,216 @@ async def recall_user_memories(
|
||||
user_id: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[UserMemory]:
|
||||
"""根据当前输入召回相关的用户记忆(简单关键词匹配)"""
|
||||
# 先尝试语义相似(通过 LLM 判断)
|
||||
# 降级: 直接从数据库取最近的重要记忆
|
||||
result = await db.execute(
|
||||
select(UserMemory)
|
||||
.where(UserMemory.user_id == user_id)
|
||||
.order_by(desc(UserMemory.importance), desc(UserMemory.recall_count))
|
||||
.limit(top_k)
|
||||
) -> list[dict]:
|
||||
"""
|
||||
根据当前输入召回相关的用户记忆。
|
||||
使用 Mem0 的语义搜索;如果 Mem0 不可用或失败,则回退到本地 UserMemory。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
results = mem0.search(
|
||||
query=query,
|
||||
filters={"user_id": user_id},
|
||||
limit=top_k,
|
||||
)
|
||||
memories = list(result.scalars().all())
|
||||
mem0_results = results.get("results", [])
|
||||
if mem0_results:
|
||||
return mem0_results
|
||||
except Exception as e:
|
||||
print(f"Mem0 search error: {e}")
|
||||
|
||||
# 重置召回标记
|
||||
for m in memories:
|
||||
m.is_recalled = False
|
||||
query_tokens = _extract_memory_query_tokens(query)
|
||||
statement = select(UserMemory).where(UserMemory.user_id == user_id)
|
||||
result = await db.execute(statement.order_by(UserMemory.importance.desc(), UserMemory.created_at.desc()))
|
||||
fallback_memories = list(result.scalars().all())
|
||||
|
||||
if _contains_hint(_normalize_query(query), MEMORY_QUERY_HINTS) or _matches_memory_query_pattern(_normalize_query(query)):
|
||||
return fallback_memories[:top_k]
|
||||
|
||||
if query_tokens:
|
||||
matched_memories = [
|
||||
memory for memory in fallback_memories
|
||||
if any(token in (memory.content or '').lower() for token in query_tokens)
|
||||
]
|
||||
return matched_memories[:top_k]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def _mark_memories_recalled(db: AsyncSession, memories: list[UserMemory]) -> None:
|
||||
recalled_at = datetime.now(UTC)
|
||||
updated = False
|
||||
for memory in memories:
|
||||
memory.is_recalled = True
|
||||
memory.recall_count = (memory.recall_count or 0) + 1
|
||||
memory.last_recalled_at = recalled_at
|
||||
updated = True
|
||||
if updated:
|
||||
await db.commit()
|
||||
|
||||
return memories
|
||||
|
||||
|
||||
async def mark_memory_recalled(db: AsyncSession, memory_id: str):
|
||||
"""标记记忆已被召回使用"""
|
||||
result = await db.execute(
|
||||
select(UserMemory).where(UserMemory.id == memory_id)
|
||||
async def _run_tolerated_section(
|
||||
db: AsyncSession,
|
||||
section_name: str,
|
||||
builder,
|
||||
) -> str:
|
||||
try:
|
||||
return await builder()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[MemoryService] %s失败,继续构建剩余上下文",
|
||||
section_name,
|
||||
exc_info=True,
|
||||
)
|
||||
mem = result.scalar_one_or_none()
|
||||
if mem:
|
||||
mem.is_recalled = True
|
||||
mem.recall_count = (mem.recall_count or 0) + 1
|
||||
mem.last_recalled_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
return ""
|
||||
|
||||
|
||||
async def get_user_profile(db: AsyncSession, user_id: str) -> dict:
|
||||
"""
|
||||
获取用户画像。
|
||||
Mem0 的 profile API 会返回 static 和 dynamic facts。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
result = mem0.history(user_id=user_id)
|
||||
return {
|
||||
"memories": result.get("results", []),
|
||||
"static": [],
|
||||
"dynamic": [],
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Mem0 profile error: {e}")
|
||||
return {"memories": [], "static": [], "dynamic": []}
|
||||
|
||||
|
||||
# ———— 记忆组装: 供 Agent 使用的上下文 ————
|
||||
|
||||
MEMORY_QUERY_HINTS = (
|
||||
"记住",
|
||||
"记下",
|
||||
"记一下",
|
||||
"记着",
|
||||
"提醒",
|
||||
"偏好",
|
||||
"习惯",
|
||||
)
|
||||
MEMORY_QUERY_PATTERNS = (
|
||||
re.compile(r"\bremember\s+(?:that\s+)?i\b"),
|
||||
)
|
||||
GROUNDING_QUERY_HINTS = (
|
||||
"根据文档",
|
||||
"严格根据",
|
||||
"只根据",
|
||||
"文档内容",
|
||||
"grounded",
|
||||
"strictly based on",
|
||||
"based on the document",
|
||||
"based on the docs",
|
||||
"document only",
|
||||
"docs only",
|
||||
"only use the document",
|
||||
"only use the docs",
|
||||
)
|
||||
AVOID_USER_MEMORY_HINTS = (
|
||||
"不要结合我的个人偏好",
|
||||
"不要结合个人偏好",
|
||||
"不要结合偏好",
|
||||
"不要结合我的记忆",
|
||||
"不要结合记忆",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_query(text: str) -> str:
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def _contains_hint(text: str, hints: tuple[str, ...]) -> bool:
|
||||
return any(hint in text for hint in hints)
|
||||
|
||||
|
||||
def _matches_memory_query_pattern(text: str) -> bool:
|
||||
return any(pattern.search(text) for pattern in MEMORY_QUERY_PATTERNS)
|
||||
|
||||
|
||||
def _should_include_user_memories(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, AVOID_USER_MEMORY_HINTS):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _should_include_summaries(query: str) -> bool:
|
||||
normalized_query = _normalize_query(query)
|
||||
if _contains_hint(normalized_query, GROUNDING_QUERY_HINTS):
|
||||
return False
|
||||
if _contains_hint(normalized_query, MEMORY_QUERY_HINTS):
|
||||
return False
|
||||
if _matches_memory_query_pattern(normalized_query):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _build_user_memory_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
recalled_user_memories: list[UserMemory] = []
|
||||
for memory in memories:
|
||||
if isinstance(memory, UserMemory):
|
||||
memory_text = memory.content
|
||||
memory_type = memory.memory_type
|
||||
recalled_user_memories.append(memory)
|
||||
else:
|
||||
memory_text = memory.get("memory", memory.get("text", ""))
|
||||
memory_type = memory.get("memory_type")
|
||||
|
||||
if not memory_text:
|
||||
continue
|
||||
|
||||
if memory_type:
|
||||
lines.append(f" [{memory_type}] {memory_text}")
|
||||
else:
|
||||
lines.append(f" - {memory_text}")
|
||||
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
if recalled_user_memories:
|
||||
await _mark_memories_recalled(db, recalled_user_memories)
|
||||
return "【用户记忆】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_summary_section(db: AsyncSession, conversation_id: str) -> str:
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if not summaries:
|
||||
return ""
|
||||
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i + 1}] {summary.summary_text}" for i, summary in enumerate(recent)]
|
||||
return "【之前对话摘要】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _build_brain_section(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
current_query: str,
|
||||
) -> str:
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if not brain_memories:
|
||||
return ""
|
||||
|
||||
lines = [f"- {memory.title}: {memory.content}" for memory in brain_memories]
|
||||
return "【知识大脑】\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def build_memory_context(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
@@ -252,33 +544,33 @@ async def build_memory_context(
|
||||
构建完整的记忆上下文字符串,
|
||||
供注入到 Agent system prompt 中使用。
|
||||
"""
|
||||
parts = []
|
||||
parts: list[str] = []
|
||||
|
||||
# 1. 用户画像(长期记忆)
|
||||
user_memories = await recall_user_memories(db, user_id, current_query, top_k=5)
|
||||
if user_memories:
|
||||
lines = []
|
||||
for m in user_memories:
|
||||
tag = f"[{m.memory_type}]"
|
||||
lines.append(f" {tag} {m.content}")
|
||||
await mark_memory_recalled(db, m.id)
|
||||
parts.append("【用户记忆】\n" + "\n".join(lines))
|
||||
if _should_include_user_memories(current_query):
|
||||
user_memory_section = await _run_tolerated_section(
|
||||
db,
|
||||
"用户记忆召回",
|
||||
lambda: _build_user_memory_section(db, user_id, current_query),
|
||||
)
|
||||
if user_memory_section:
|
||||
parts.append(user_memory_section)
|
||||
|
||||
# 2. 对话摘要(中期记忆)
|
||||
summaries = await get_summaries(db, conversation_id)
|
||||
if summaries:
|
||||
# 只取最近2条
|
||||
recent = summaries[-2:]
|
||||
lines = [f"[对话摘要{i+1}] {s.summary_text}" for i, s in enumerate(recent)]
|
||||
parts.append("【之前对话摘要】\n" + "\n".join(lines))
|
||||
if _should_include_summaries(current_query):
|
||||
summary_section = await _run_tolerated_section(
|
||||
db,
|
||||
"对话摘要加载",
|
||||
lambda: _build_summary_section(db, conversation_id),
|
||||
)
|
||||
if summary_section:
|
||||
parts.append(summary_section)
|
||||
|
||||
# 3. 知识大脑(长期项目记忆)
|
||||
brain_memories = await BrainService(db).recall_memories(user_id, current_query, top_k=3)
|
||||
if brain_memories:
|
||||
lines = []
|
||||
for memory in brain_memories:
|
||||
lines.append(f"- {memory.title}: {memory.content}")
|
||||
parts.append("【知识大脑】\n" + "\n".join(lines))
|
||||
brain_section = await _run_tolerated_section(
|
||||
db,
|
||||
"知识大脑召回",
|
||||
lambda: _build_brain_section(db, user_id, current_query),
|
||||
)
|
||||
if brain_section:
|
||||
parts.append(brain_section)
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
@@ -292,7 +584,7 @@ async def try_auto_summarize(
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否需要摘要,如果需要则生成并保存。
|
||||
返回是否执行了摘要。
|
||||
同时将对话内容存入 Mem0 进行记忆提取。
|
||||
"""
|
||||
if not await should_summarize(db, conversation_id):
|
||||
return False
|
||||
@@ -306,8 +598,39 @@ async def try_auto_summarize(
|
||||
turn_count = await get_conversation_turn_count(db, conversation_id)
|
||||
await save_summary(db, user_id, conversation_id, summary_text, turn_count)
|
||||
|
||||
# 同时提取用户记忆
|
||||
await extract_user_memories(db, user_id, conversation_id, messages)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(f"Auto summarize error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def forget_memory(db: AsyncSession, user_id: str, memory_id: str) -> bool:
|
||||
"""
|
||||
主动遗忘某条记忆。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.delete(memory_id, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 delete error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def update_memory(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
memory_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""
|
||||
更新某条记忆。Mem0 会自动处理矛盾检测。
|
||||
"""
|
||||
try:
|
||||
mem0 = await get_mem0(db, user_id)
|
||||
mem0.update(memory_id, content, user_id=user_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Mem0 update error: {e}")
|
||||
return False
|
||||
|
||||
@@ -99,46 +99,55 @@ async def update_scheduler_config(user_id: str, config: dict, db: AsyncSession)
|
||||
|
||||
|
||||
async def test_llm_connection(
|
||||
provider: str,
|
||||
provider: str | None,
|
||||
model: str,
|
||||
base_url: str,
|
||||
api_key: str
|
||||
api_key: str,
|
||||
) -> dict:
|
||||
"""测试 LLM 连接"""
|
||||
try:
|
||||
# base_url-first: provider 可省略
|
||||
from app.services.llm_service import normalize_provider_name
|
||||
|
||||
effective_provider = normalize_provider_name({
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"base_url": base_url,
|
||||
})
|
||||
|
||||
# 根据不同 provider 创建临时 LLM 实例并测试
|
||||
if provider == "openai":
|
||||
if effective_provider in {"openai", "custom", "minimax", "kimi", "qwen"}:
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or None,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "claude":
|
||||
elif effective_provider == "claude":
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
llm = ChatAnthropic(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
elif effective_provider == "ollama":
|
||||
from langchain_ollama import ChatOllama
|
||||
llm = ChatOllama(
|
||||
base_url=base_url or "http://localhost:11434",
|
||||
model=model,
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
elif effective_provider == "deepseek":
|
||||
from langchain_openai import ChatOpenAI
|
||||
llm = ChatOpenAI(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
base_url=base_url or "https://api.deepseek.com/v1",
|
||||
timeout=30
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
return {"success": False, "error": f"不支持的 provider: {provider}"}
|
||||
return {"success": False, "error": f"不支持的 endpoint/provider: {effective_provider}"}
|
||||
|
||||
# 简单测试调用
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -50,27 +50,21 @@ class SkillService:
|
||||
"""
|
||||
列出用户可访问的技能:自己的 + 市场的 + 团队的
|
||||
"""
|
||||
# 查询条件:自己的 或者 市场公开的 或者 团队的
|
||||
conditions = [
|
||||
access_scope = or_(
|
||||
Skill.owner_id == user_id,
|
||||
Skill.visibility == "market",
|
||||
Skill.team_id == user_id,
|
||||
]
|
||||
)
|
||||
|
||||
filters = [access_scope, Skill.is_active == True]
|
||||
|
||||
# 如果提供了 agent_type 过滤
|
||||
if agent_type:
|
||||
conditions.append(Skill.agent_type == agent_type)
|
||||
filters.append(Skill.agent_type == agent_type)
|
||||
|
||||
# 如果提供了 visibility 过滤
|
||||
if visibility:
|
||||
conditions.append(Skill.visibility == visibility)
|
||||
filters.append(Skill.visibility == visibility)
|
||||
|
||||
query = select(Skill).where(
|
||||
and_(
|
||||
or_(*conditions),
|
||||
Skill.is_active == True
|
||||
)
|
||||
)
|
||||
query = select(Skill).where(and_(*filters))
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from datetime import datetime, UTC
|
||||
from time import monotonic
|
||||
import platform
|
||||
import socket
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import psutil
|
||||
@@ -7,21 +11,119 @@ except ModuleNotFoundError: # pragma: no cover - optional runtime dependency fa
|
||||
|
||||
|
||||
class SystemService:
|
||||
_last_net_bytes_sent: int | None = None
|
||||
_last_net_bytes_recv: int | None = None
|
||||
_last_net_sample_at: float | None = None
|
||||
|
||||
def _get_network_rates(self) -> tuple[float, float]:
|
||||
counters = psutil.net_io_counters()
|
||||
now = monotonic()
|
||||
|
||||
if (
|
||||
self.__class__._last_net_sample_at is None
|
||||
or self.__class__._last_net_bytes_sent is None
|
||||
or self.__class__._last_net_bytes_recv is None
|
||||
):
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
return 0.0, 0.0
|
||||
|
||||
elapsed = max(now - self.__class__._last_net_sample_at, 1e-6)
|
||||
upload_bps = max(counters.bytes_sent - self.__class__._last_net_bytes_sent, 0) / elapsed
|
||||
download_bps = max(counters.bytes_recv - self.__class__._last_net_bytes_recv, 0) / elapsed
|
||||
|
||||
self.__class__._last_net_bytes_sent = counters.bytes_sent
|
||||
self.__class__._last_net_bytes_recv = counters.bytes_recv
|
||||
self.__class__._last_net_sample_at = now
|
||||
|
||||
return round(upload_bps, 1), round(download_bps, 1)
|
||||
|
||||
def _get_gpu_status(self) -> dict:
|
||||
empty = {
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
}
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
'nvidia-smi',
|
||||
'--query-gpu=name,memory.total,memory.used,utilization.gpu',
|
||||
'--format=csv,noheader,nounits',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
timeout=2,
|
||||
check=False,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.SubprocessError, OSError):
|
||||
return empty
|
||||
|
||||
if result.returncode != 0 or not result.stdout.strip():
|
||||
return empty
|
||||
|
||||
first_line = result.stdout.strip().splitlines()[0]
|
||||
parts = [part.strip() for part in first_line.split(',')]
|
||||
if len(parts) < 4:
|
||||
return empty
|
||||
|
||||
def parse_number(value: str) -> float | None:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return {
|
||||
'gpu_name': parts[0] or None,
|
||||
'gpu_memory_total_mb': parse_number(parts[1]),
|
||||
'gpu_memory_used_mb': parse_number(parts[2]),
|
||||
'gpu_util_percent': parse_number(parts[3]),
|
||||
}
|
||||
|
||||
def get_status(self) -> dict:
|
||||
if psutil is None:
|
||||
return {
|
||||
'cpu_percent': 0.0,
|
||||
'memory_percent': 0.0,
|
||||
'disk_percent': 0.0,
|
||||
'disk_used_gb': 0.0,
|
||||
'disk_total_gb': 0.0,
|
||||
'network_upload_bps': 0.0,
|
||||
'network_download_bps': 0.0,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': 0.0,
|
||||
'gpu_name': None,
|
||||
'gpu_memory_total_mb': None,
|
||||
'gpu_memory_used_mb': None,
|
||||
'gpu_util_percent': None,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
upload_bps, download_bps = self._get_network_rates()
|
||||
gpu_status = self._get_gpu_status()
|
||||
boot_time = psutil.boot_time()
|
||||
now_ts = datetime.now(UTC).timestamp()
|
||||
return {
|
||||
'cpu_percent': round(cpu_percent, 1),
|
||||
'memory_percent': round(memory.percent, 1),
|
||||
'disk_percent': round(disk.percent, 1),
|
||||
'disk_used_gb': round(disk.used / (1024 ** 3), 1),
|
||||
'disk_total_gb': round(disk.total / (1024 ** 3), 1),
|
||||
'network_upload_bps': upload_bps,
|
||||
'network_download_bps': download_bps,
|
||||
'system_name': platform.system(),
|
||||
'system_version': platform.version(),
|
||||
'hostname': socket.gethostname(),
|
||||
'uptime_seconds': round(max(now_ts - boot_time, 0.0), 1),
|
||||
**gpu_status,
|
||||
'timestamp': datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
124
backend/app/services/web_search_service.py
Normal file
124
backend/app/services/web_search_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WebSearchResult:
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
source: str | None = None
|
||||
published_at: str | None = None
|
||||
|
||||
|
||||
class WebSearchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchConfigurationError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchRequestError(WebSearchError):
|
||||
pass
|
||||
|
||||
|
||||
class WebSearchService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enabled: bool | None = None,
|
||||
provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_limit: int | None = None,
|
||||
timeout_seconds: int | None = None,
|
||||
auth_type: Literal['none', 'bearer', 'basic'] | str | None = None,
|
||||
auth_token: str | None = None,
|
||||
basic_user: str | None = None,
|
||||
basic_password: str | None = None,
|
||||
):
|
||||
self.enabled = settings.WEB_SEARCH_ENABLED if enabled is None else enabled
|
||||
self.provider = (provider or settings.WEB_SEARCH_PROVIDER).strip().lower()
|
||||
self.base_url = (base_url or settings.SEARXNG_BASE_URL).strip().rstrip('/')
|
||||
self.default_limit = max(1, min(default_limit or settings.WEB_SEARCH_DEFAULT_LIMIT, 10))
|
||||
self.timeout_seconds = max(1, timeout_seconds or settings.WEB_SEARCH_TIMEOUT_SECONDS)
|
||||
self.auth_type = str(auth_type or settings.SEARXNG_AUTH_TYPE or 'none').strip().lower()
|
||||
self.auth_token = auth_token if auth_token is not None else settings.SEARXNG_AUTH_TOKEN
|
||||
self.basic_user = basic_user if basic_user is not None else settings.SEARXNG_BASIC_USER
|
||||
self.basic_password = basic_password if basic_password is not None else settings.SEARXNG_BASIC_PASSWORD
|
||||
|
||||
async def search(self, query: str, limit: int | None = None) -> list[WebSearchResult]:
|
||||
normalized_query = (query or '').strip()
|
||||
if not self.enabled or not self.base_url:
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
if self.provider != 'searxng':
|
||||
raise WebSearchConfigurationError(f'不支持的网页搜索 provider: {self.provider}')
|
||||
if not normalized_query:
|
||||
raise WebSearchRequestError('搜索关键词不能为空')
|
||||
|
||||
parsed = urlparse(self.base_url)
|
||||
if parsed.scheme not in {'http', 'https'} or not parsed.netloc:
|
||||
raise WebSearchConfigurationError('SEARXNG_BASE_URL 配置无效')
|
||||
|
||||
params = {
|
||||
'q': normalized_query,
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
}
|
||||
headers = self._build_headers()
|
||||
timeout = httpx.Timeout(float(self.timeout_seconds), connect=min(float(self.timeout_seconds), 5.0))
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(f'{self.base_url}/search', params=params, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
except httpx.HTTPError as exc:
|
||||
raise WebSearchRequestError('SearxNG 请求失败') from exc
|
||||
except ValueError as exc:
|
||||
raise WebSearchRequestError('SearxNG 返回了无效 JSON') from exc
|
||||
|
||||
raw_results = payload.get('results') if isinstance(payload, dict) else None
|
||||
if not isinstance(raw_results, list):
|
||||
return []
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
target_limit = max(1, min(limit or self.default_limit, 10))
|
||||
for item in raw_results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
title = str(item.get('title') or '').strip()
|
||||
url = str(item.get('url') or '').strip()
|
||||
snippet = str(item.get('content') or item.get('snippet') or '').strip()
|
||||
if not title or not url:
|
||||
continue
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet,
|
||||
source=str(item.get('engine') or item.get('source') or '').strip() or None,
|
||||
published_at=str(item.get('publishedDate') or item.get('published_at') or '').strip() or None,
|
||||
)
|
||||
)
|
||||
if len(results) >= target_limit:
|
||||
break
|
||||
return results
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
if self.auth_type == 'bearer' and self.auth_token:
|
||||
return {'Authorization': f'Bearer {self.auth_token}'}
|
||||
if self.auth_type == 'basic' and self.basic_user and self.basic_password:
|
||||
credentials = httpx.BasicAuth(self.basic_user, self.basic_password)
|
||||
request = httpx.Request('GET', self.base_url)
|
||||
credentials.auth_flow(request)
|
||||
return dict(request.headers)
|
||||
return {}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
|
||||
bad
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -1 +0,0 @@
|
||||
%PDF-1.4 bad
|
||||
@@ -27,6 +27,9 @@ dependencies = [
|
||||
"llama-index-vector-stores-chroma>=0.3.0",
|
||||
"chromadb>=0.5.0",
|
||||
|
||||
# Memory
|
||||
"mem0ai>=1.0.0",
|
||||
|
||||
# 数据库
|
||||
"sqlalchemy>=2.0.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
317
backend/tests/backend/app/agents/test_graph_system_messages.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
sys.modules.setdefault("trafilatura", Mock())
|
||||
|
||||
from app.agents.graph import _build_system_messages, _run_sub_commander
|
||||
from app.agents.state import AgentRole
|
||||
|
||||
|
||||
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
||||
return {
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"user_id": "u1",
|
||||
"conversation_id": "c1",
|
||||
"current_agent": AgentRole.MASTER,
|
||||
"active_agents": [AgentRole.MASTER],
|
||||
"current_sub_commander": None,
|
||||
"active_sub_commanders": [],
|
||||
"sub_commander_trace": [],
|
||||
"pending_tasks": [],
|
||||
"completed_tasks": [],
|
||||
"tool_calls": [],
|
||||
"last_tool_result": None,
|
||||
"action_results": [],
|
||||
"created_entities": [],
|
||||
"tool_strategy_used": None,
|
||||
"provider_capabilities": None,
|
||||
"fallback_parse_error": None,
|
||||
"knowledge_context": None,
|
||||
"graph_context": None,
|
||||
"schedule_context_summary": None,
|
||||
"plan": None,
|
||||
"plan_steps": [],
|
||||
"analysis_report": None,
|
||||
"final_response": None,
|
||||
"should_respond": True,
|
||||
"memory_context": "memory context",
|
||||
"current_datetime_context": "CURRENT_TIME: 2026-03-28T12:00:00+08:00",
|
||||
"current_datetime_reference": {
|
||||
"current_time_iso": "2026-03-28T12:00:00+08:00",
|
||||
"current_date_iso": "2026-03-28",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
"user_llm_config": user_llm_config,
|
||||
}
|
||||
|
||||
|
||||
class FakeTool:
|
||||
def __init__(self, name: str, result: str):
|
||||
self.name = name
|
||||
self.result = result
|
||||
self.invocations: list[dict] = []
|
||||
|
||||
def invoke(self, args: dict):
|
||||
self.invocations.append(args)
|
||||
return self.result
|
||||
|
||||
|
||||
class SingleSystemMessageLLM:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
self.system_message_counts: list[int] = []
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(
|
||||
provider="minimax",
|
||||
supports_native_tools=False,
|
||||
preferred_tool_strategy="json_fallback",
|
||||
)
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
self.system_message_counts.append(
|
||||
sum(1 for message in messages if getattr(message, "type", None) == "system")
|
||||
)
|
||||
if self.system_message_counts[-1] != 1:
|
||||
raise AssertionError(
|
||||
f"expected exactly one system message, got {self.system_message_counts[-1]}"
|
||||
)
|
||||
if self.calls == 1:
|
||||
return AIMessage(
|
||||
content=(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder",'
|
||||
'"arguments":{"title":"blanket","reminder_at":"\\u660e\\u5929 09:00"}}]}'
|
||||
)
|
||||
)
|
||||
return AIMessage(content="created reminder for blanket")
|
||||
|
||||
|
||||
def test_build_system_messages_includes_structured_continuity_summary():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" in system_text
|
||||
assert "schedule_creation" in system_text
|
||||
assert "continue_pending_action" in system_text
|
||||
assert "为周报安排明天下午提醒" in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_pending_action_is_not_pending():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "completed",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_routing_reason_is_not_continuation():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "initial_schedule_detection",
|
||||
}
|
||||
state["continuity_state"] = {"status": "fresh"}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_structured_continuity_when_routing_decision_missing():
|
||||
state = _base_state("创建")
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
}
|
||||
state["routing_decision"] = None
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" not in system_text
|
||||
assert "schedule_creation" not in system_text
|
||||
assert "为周报安排明天下午提醒" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_skips_stale_structured_continuity_for_unrelated_new_request():
|
||||
state = _base_state(
|
||||
"帮我搜索 Rust 异步 trait 最佳实践",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {
|
||||
"status": "stale",
|
||||
"override_reason": "new_explicit_request",
|
||||
}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "structured_continuity" not in system_text
|
||||
assert "pending_action" not in system_text
|
||||
assert "continue_pending_action" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_uses_role_scoped_context_instead_of_raw_memory_blob():
|
||||
state = _base_state("帮我搜索 Rust 异步 trait 最佳实践")
|
||||
state["memory_context"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
state["schedule_context_summary"] = "【用户记忆】\n- 用户喜欢燕麦拿铁。\n\n【之前对话摘要】\n[对话摘要1] 之前聊过提醒。"
|
||||
state["knowledge_context"] = "【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
state["analysis_report"] = "【之前对话摘要】\n[对话摘要1] 之前聊过提醒。\n\n【知识大脑】\n- Rust Async: trait object 需要 pin。"
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.LIBRARIAN,
|
||||
"librarian_retrieval",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "角色上下文" in system_text
|
||||
assert "【知识大脑】" in system_text
|
||||
assert "Rust Async" in system_text
|
||||
assert "用户喜欢燕麦拿铁" not in system_text
|
||||
assert "记忆上下文" not in system_text
|
||||
|
||||
|
||||
def test_build_system_messages_keeps_fresh_structured_continuity_for_matching_followup():
|
||||
state = _base_state(
|
||||
"创建",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
state["pending_action"] = {
|
||||
"type": "schedule_creation",
|
||||
"summary": "为周报安排明天下午提醒",
|
||||
"status": "pending",
|
||||
}
|
||||
state["routing_decision"] = {
|
||||
"target_agent": AgentRole.SCHEDULE_PLANNER.value,
|
||||
"reason": "continue_pending_action",
|
||||
}
|
||||
state["continuity_state"] = {
|
||||
"status": "fresh",
|
||||
}
|
||||
|
||||
messages = _build_system_messages(
|
||||
state,
|
||||
"manager prompt",
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"schedule_planning",
|
||||
)
|
||||
|
||||
system_text = "\n\n".join(str(getattr(message, "content", "")) for message in messages)
|
||||
assert "pending_action" in system_text
|
||||
assert "continue_pending_action" in system_text
|
||||
|
||||
|
||||
async def test_run_sub_commander_coalesces_system_messages_for_openai_compatible_provider(
|
||||
monkeypatch,
|
||||
):
|
||||
fake_llm = SingleSystemMessageLLM()
|
||||
fake_tool = FakeTool("create_reminder", "created reminder: blanket @ tomorrow 09:00")
|
||||
|
||||
monkeypatch.setattr("app.agents.graph._get_llm_for_state", lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__("app.agents.graph", fromlist=["SUB_COMMANDER_TOOLSETS"]).SUB_COMMANDER_TOOLSETS,
|
||||
"schedule_planning",
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state(
|
||||
"给我设置明天的提醒,提醒我收被子",
|
||||
{
|
||||
"provider": "openai",
|
||||
"model": "MiniMax-M2.7-highspeed",
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
},
|
||||
)
|
||||
state["current_agent"] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
"manager prompt",
|
||||
"给我设置明天的提醒,提醒我收被子",
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert fake_llm.system_message_counts == [1, 1]
|
||||
assert result["tool_strategy_used"] == "json_fallback"
|
||||
assert fake_tool.invocations == [{"title": "blanket", "reminder_at": "2026-03-29T09:00:00"}]
|
||||
assert result["final_response"] == "created reminder for blanket"
|
||||
360
backend/tests/backend/app/agents/test_registry.py
Normal file
360
backend/tests/backend/app/agents/test_registry.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import pytest
|
||||
from collections.abc import Mapping
|
||||
|
||||
from app.agents.prompts import (
|
||||
SUB_COMMANDER_PROMPTS_BY_KEY,
|
||||
TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY,
|
||||
)
|
||||
from app.agents.registry import build_registry_indexes, load_builtin_registry_bundle
|
||||
from app.agents.registry.indexes import summarize_registry_indexes
|
||||
from app.agents.registry.models import (
|
||||
AgentManifest,
|
||||
CapabilityManifest,
|
||||
SpecialistTemplateManifest,
|
||||
SubCommanderManifest,
|
||||
)
|
||||
from app.agents.registry.validator import validate_registry_bundle
|
||||
from app.agents.registry.builtins import (
|
||||
BUILTIN_AGENT_MANIFESTS,
|
||||
BUILTIN_CAPABILITY_MANIFESTS,
|
||||
BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS,
|
||||
BUILTIN_SUB_COMMANDER_MANIFESTS,
|
||||
)
|
||||
from app.agents.state import AgentRole
|
||||
from app.agents.tools import SUB_COMMANDER_TOOLSETS
|
||||
|
||||
|
||||
def make_agent(
|
||||
agent_id: str = "master",
|
||||
*,
|
||||
display_name: str = "Master",
|
||||
role_value: str = "master",
|
||||
system_prompt_key: str = "master",
|
||||
default_sub_commanders: list[str] | None = None,
|
||||
) -> AgentManifest:
|
||||
return AgentManifest(
|
||||
agent_id=agent_id,
|
||||
display_name=display_name,
|
||||
role_value=role_value,
|
||||
system_prompt_key=system_prompt_key,
|
||||
routing_hints=["route"],
|
||||
default_sub_commanders=default_sub_commanders or [],
|
||||
)
|
||||
|
||||
|
||||
def make_sub_commander(
|
||||
sub_commander_id: str = "planner",
|
||||
*,
|
||||
parent_agent_id: str = "master",
|
||||
capability_ids: list[str] | None = None,
|
||||
) -> SubCommanderManifest:
|
||||
return SubCommanderManifest(
|
||||
sub_commander_id=sub_commander_id,
|
||||
parent_agent_id=parent_agent_id,
|
||||
prompt_text="Plan the work.",
|
||||
capability_ids=capability_ids or [],
|
||||
)
|
||||
|
||||
|
||||
def make_capability(capability_id: str = "calendar") -> CapabilityManifest:
|
||||
return CapabilityManifest(capability_id=capability_id, tool_name=f"{capability_id}_tool")
|
||||
|
||||
|
||||
def make_specialist_template(
|
||||
template_id: str = "researcher",
|
||||
*,
|
||||
allowed_capability_ids: list[str] | None = None,
|
||||
) -> SpecialistTemplateManifest:
|
||||
return SpecialistTemplateManifest(
|
||||
template_id=template_id,
|
||||
display_name="Researcher",
|
||||
description="Research specialist",
|
||||
allowed_capability_ids=allowed_capability_ids,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_accepts_valid_bundle() -> None:
|
||||
validate_registry_bundle(
|
||||
agents=[make_agent(default_sub_commanders=["planner"])],
|
||||
sub_commanders=[make_sub_commander(capability_ids=["calendar"])],
|
||||
capabilities=[make_capability()],
|
||||
specialist_templates=[make_specialist_template(allowed_capability_ids=["calendar"])],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_duplicate_agent_ids() -> None:
|
||||
agents = [
|
||||
make_agent(default_sub_commanders=["planner"]),
|
||||
make_agent(
|
||||
display_name="Duplicate Master",
|
||||
role_value="master_duplicate",
|
||||
system_prompt_key="master_duplicate",
|
||||
),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="duplicate agent id: master"):
|
||||
validate_registry_bundle(
|
||||
agents=agents,
|
||||
sub_commanders=[],
|
||||
capabilities=[],
|
||||
specialist_templates=[],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_duplicate_sub_commander_ids() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate sub commander id: planner"):
|
||||
validate_registry_bundle(
|
||||
agents=[make_agent()],
|
||||
sub_commanders=[make_sub_commander(), make_sub_commander()],
|
||||
capabilities=[],
|
||||
specialist_templates=[],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_duplicate_capability_ids() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate capability id: calendar"):
|
||||
validate_registry_bundle(
|
||||
agents=[],
|
||||
sub_commanders=[],
|
||||
capabilities=[make_capability(), make_capability()],
|
||||
specialist_templates=[],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_duplicate_template_ids() -> None:
|
||||
with pytest.raises(ValueError, match="duplicate template id: researcher"):
|
||||
validate_registry_bundle(
|
||||
agents=[],
|
||||
sub_commanders=[],
|
||||
capabilities=[],
|
||||
specialist_templates=[make_specialist_template(), make_specialist_template()],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_unknown_sub_commander_parent_agent_ids() -> None:
|
||||
sub_commanders = [make_sub_commander(parent_agent_id="missing-agent")]
|
||||
|
||||
with pytest.raises(ValueError, match="unknown parent agent id: missing-agent"):
|
||||
validate_registry_bundle(
|
||||
agents=[],
|
||||
sub_commanders=sub_commanders,
|
||||
capabilities=[],
|
||||
specialist_templates=[],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_unknown_sub_commander_capability_references() -> None:
|
||||
with pytest.raises(ValueError, match="unknown capability id: search"):
|
||||
validate_registry_bundle(
|
||||
agents=[make_agent(default_sub_commanders=["planner"])],
|
||||
sub_commanders=[make_sub_commander(capability_ids=["search"])],
|
||||
capabilities=[make_capability()],
|
||||
specialist_templates=[],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_registry_bundle_rejects_unknown_specialist_template_capability_references() -> None:
|
||||
with pytest.raises(ValueError, match="unknown capability id: missing-capability"):
|
||||
validate_registry_bundle(
|
||||
agents=[],
|
||||
sub_commanders=[],
|
||||
capabilities=[make_capability()],
|
||||
specialist_templates=[
|
||||
make_specialist_template(allowed_capability_ids=["missing-capability"])
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_registry_bundle_agent_roles_match_runtime_agent_role_enum_values() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert set(indexes.agent_by_id) == {role.value for role in AgentRole}
|
||||
assert {agent.role_value for agent in bundle.agents} == {role.value for role in AgentRole}
|
||||
|
||||
|
||||
def test_registry_bundle_agent_system_prompt_keys_match_runtime_top_level_prompt_surface() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
expected_prompt_keys_by_agent_id = {
|
||||
role.value: role.value for role in AgentRole if role.value in TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY
|
||||
}
|
||||
|
||||
assert set(TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY) == {role.value for role in AgentRole}
|
||||
assert indexes.agent_prompt_key_by_id == expected_prompt_keys_by_agent_id
|
||||
assert {
|
||||
agent.agent_id: TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY[agent.system_prompt_key]
|
||||
for agent in bundle.agents
|
||||
} == {
|
||||
role.value: TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY[role.value]
|
||||
for role in AgentRole
|
||||
}
|
||||
|
||||
|
||||
def test_registry_bundle_skill_context_keys_match_graph_role_derivation_rule() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
expected_skill_context_keys = {
|
||||
role.value: role.value.replace("agent_", "")
|
||||
for role in AgentRole
|
||||
}
|
||||
|
||||
assert indexes.skill_context_key_by_agent_id == expected_skill_context_keys
|
||||
assert {
|
||||
agent.agent_id: agent.skill_context_key for agent in bundle.agents
|
||||
} == expected_skill_context_keys
|
||||
|
||||
|
||||
def test_registry_bundle_sub_commander_prompt_texts_match_runtime_prompt_map() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert set(indexes.sub_commander_by_id) == set(SUB_COMMANDER_PROMPTS_BY_KEY)
|
||||
assert indexes.sub_commander_prompt_key_by_id == {
|
||||
sub_commander_id: sub_commander_id
|
||||
for sub_commander_id in SUB_COMMANDER_PROMPTS_BY_KEY
|
||||
}
|
||||
assert {
|
||||
sub_commander.sub_commander_id: sub_commander.prompt_text
|
||||
for sub_commander in bundle.sub_commanders
|
||||
} == SUB_COMMANDER_PROMPTS_BY_KEY
|
||||
|
||||
|
||||
def test_registry_bundle_sub_commander_tool_membership_and_order_match_runtime_toolsets() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert set(indexes.sub_commander_by_id) == set(SUB_COMMANDER_TOOLSETS)
|
||||
assert indexes.capability_ids_by_sub_commander_id == {
|
||||
sub_commander_id: tuple(tool.name for tool in tools)
|
||||
for sub_commander_id, tools in SUB_COMMANDER_TOOLSETS.items()
|
||||
}
|
||||
assert {
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
} == {
|
||||
sub_commander_id: tuple(tool.name for tool in tools)
|
||||
for sub_commander_id, tools in SUB_COMMANDER_TOOLSETS.items()
|
||||
}
|
||||
|
||||
|
||||
def test_builtin_capabilities_reference_actual_runtime_tool_names() -> None:
|
||||
expected_tool_names = {
|
||||
tool.name
|
||||
for tools in SUB_COMMANDER_TOOLSETS.values()
|
||||
for tool in tools
|
||||
}
|
||||
manifest_tool_names = {manifest.tool_name for manifest in BUILTIN_CAPABILITY_MANIFESTS}
|
||||
|
||||
assert manifest_tool_names == expected_tool_names
|
||||
|
||||
|
||||
def test_builtin_sub_commander_capabilities_match_runtime_toolsets() -> None:
|
||||
capabilities_by_tool_name = {
|
||||
manifest.tool_name: manifest.capability_id for manifest in BUILTIN_CAPABILITY_MANIFESTS
|
||||
}
|
||||
|
||||
for sub_commander in BUILTIN_SUB_COMMANDER_MANIFESTS:
|
||||
expected_capability_ids = {
|
||||
capabilities_by_tool_name[tool.name]
|
||||
for tool in SUB_COMMANDER_TOOLSETS[sub_commander.sub_commander_id]
|
||||
}
|
||||
assert set(sub_commander.capability_ids) == expected_capability_ids
|
||||
|
||||
|
||||
def test_builtin_manifests_form_a_valid_registry_bundle() -> None:
|
||||
validate_registry_bundle(
|
||||
agents=list(BUILTIN_AGENT_MANIFESTS),
|
||||
sub_commanders=list(BUILTIN_SUB_COMMANDER_MANIFESTS),
|
||||
capabilities=list(BUILTIN_CAPABILITY_MANIFESTS),
|
||||
specialist_templates=list(BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS),
|
||||
)
|
||||
|
||||
|
||||
def test_load_builtin_registry_bundle_returns_non_empty_manifest_sets() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
|
||||
assert bundle.agents
|
||||
assert bundle.sub_commanders
|
||||
assert bundle.capabilities
|
||||
assert isinstance(bundle.specialist_templates, tuple)
|
||||
|
||||
|
||||
def test_build_registry_indexes_exposes_manifest_lookups_by_id() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert indexes.agent_by_id
|
||||
assert indexes.sub_commander_by_id
|
||||
assert indexes.capability_by_id
|
||||
assert isinstance(indexes.specialist_template_by_id, Mapping)
|
||||
assert set(indexes.agent_by_id) == {agent.agent_id for agent in bundle.agents}
|
||||
assert set(indexes.sub_commander_by_id) == {
|
||||
sub_commander.sub_commander_id for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
assert set(indexes.capability_by_id) == {
|
||||
capability.capability_id for capability in bundle.capabilities
|
||||
}
|
||||
assert set(indexes.specialist_template_by_id) == {
|
||||
template.template_id for template in bundle.specialist_templates
|
||||
}
|
||||
|
||||
|
||||
def test_summarize_registry_indexes_returns_read_only_debug_counts() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert summarize_registry_indexes(indexes) == {
|
||||
"agent_count": len(bundle.agents),
|
||||
"sub_commander_count": len(bundle.sub_commanders),
|
||||
"capability_count": len(bundle.capabilities),
|
||||
"specialist_template_count": len(bundle.specialist_templates),
|
||||
}
|
||||
|
||||
|
||||
def test_build_registry_indexes_exposes_prompt_keys_skill_context_keys_and_capability_mappings() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
|
||||
indexes = build_registry_indexes(bundle)
|
||||
|
||||
assert indexes.agent_prompt_key_by_id == {
|
||||
agent.agent_id: agent.system_prompt_key for agent in bundle.agents
|
||||
}
|
||||
assert indexes.agent_prompt_key_by_id == {
|
||||
agent.agent_id: agent.system_prompt_key for agent in BUILTIN_AGENT_MANIFESTS
|
||||
}
|
||||
assert set(indexes.agent_prompt_key_by_id.values()) == set(TOP_LEVEL_SYSTEM_PROMPTS_BY_KEY)
|
||||
assert indexes.sub_commander_prompt_key_by_id == {
|
||||
sub_commander.sub_commander_id: sub_commander.sub_commander_id
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
assert set(indexes.sub_commander_prompt_key_by_id.values()) == {
|
||||
sub_commander.sub_commander_id for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
assert indexes.skill_context_key_by_agent_id == {
|
||||
agent.agent_id: agent.skill_context_key
|
||||
for agent in bundle.agents
|
||||
if agent.skill_context_key is not None
|
||||
}
|
||||
assert indexes.capability_ids_by_sub_commander_id == {
|
||||
sub_commander.sub_commander_id: tuple(sub_commander.capability_ids)
|
||||
for sub_commander in bundle.sub_commanders
|
||||
}
|
||||
|
||||
|
||||
def test_validate_registry_bundle_accepts_loaded_builtin_registry_bundle() -> None:
|
||||
bundle = load_builtin_registry_bundle()
|
||||
|
||||
validate_registry_bundle(
|
||||
agents=list(bundle.agents),
|
||||
sub_commanders=list(bundle.sub_commanders),
|
||||
capabilities=list(bundle.capabilities),
|
||||
specialist_templates=list(bundle.specialist_templates),
|
||||
)
|
||||
|
||||
|
||||
def test_phase_one_still_declares_specialist_template_surface_even_if_runtime_is_deferred() -> None:
|
||||
assert isinstance(BUILTIN_SPECIALIST_TEMPLATE_MANIFESTS, tuple)
|
||||
73
backend/tests/backend/app/agents/test_search_tools.py
Normal file
73
backend/tests/backend/app/agents/test_search_tools.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.tools.search import web_search
|
||||
|
||||
|
||||
class FakeResult(SimpleNamespace):
|
||||
pass
|
||||
|
||||
|
||||
def test_web_search_tool_formats_results(monkeypatch):
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
assert query == 'Jarvis 最新更新'
|
||||
assert limit == 2
|
||||
return [
|
||||
FakeResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis 最新更新', top_k=2)
|
||||
|
||||
assert '[1] Jarvis release notes' in result
|
||||
assert '链接: https://example.com/jarvis-release' in result
|
||||
assert '来源: duckduckgo' in result
|
||||
assert '时间: 2026-03-29' in result
|
||||
assert '摘要: Latest Jarvis changes.' in result
|
||||
|
||||
|
||||
def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
|
||||
from app.services.web_search_service import WebSearchConfigurationError
|
||||
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis')
|
||||
|
||||
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_tool_runs_from_active_event_loop(monkeypatch):
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
assert query == 'Jarvis 最新更新'
|
||||
assert limit == 1
|
||||
return [
|
||||
FakeResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis 最新更新', top_k=1)
|
||||
|
||||
assert '[1] Jarvis release notes' in result
|
||||
assert '链接: https://example.com/jarvis-release' in result
|
||||
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tool_env(tmp_path):
|
||||
db_path = tmp_path / "test_task_tools.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# 只创建本测试需要的表,避免全量 metadata 引入未注册的外键表。
|
||||
await conn.run_sync(User.metadata.create_all, tables=[
|
||||
User.__table__,
|
||||
Task.__table__,
|
||||
DailyTodo.__table__,
|
||||
Reminder.__table__,
|
||||
Goal.__table__,
|
||||
])
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username="tool_user",
|
||||
email="tool@example.com",
|
||||
hashed_password=get_password_hash("secret123"),
|
||||
full_name="Tool Tester",
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
try:
|
||||
yield {"session_factory": session_factory, "user_id": user.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_content_and_date_aliases_and_persists_task(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 0, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_task_accepts_content_and_date_aliases_and_sets_morning_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_schedule_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 9, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("priority_input", "expected"),
|
||||
[
|
||||
(1, TaskPriority.LOW),
|
||||
(2, TaskPriority.MEDIUM),
|
||||
(3, TaskPriority.HIGH),
|
||||
(4, TaskPriority.URGENT),
|
||||
("urgent", TaskPriority.URGENT),
|
||||
],
|
||||
)
|
||||
async def test_create_task_normalizes_legacy_and_string_priorities(tool_env, monkeypatch, priority_input, expected):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title=f"priority-{priority_input}", priority=priority_input)
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].priority == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_iso_datetime_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="timed task", due_date="2026-03-28T15:30:00Z")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.due_date == datetime(2026, 3, 28, 15, 30, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_missing_title_and_content(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func()
|
||||
|
||||
assert result == "创建任务失败: title 不能为空"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_invalid_priority(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="bad priority", priority="top")
|
||||
|
||||
assert "创建任务失败:" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_status_rejects_invalid_status(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
create_result = task_tools.create_task.func(title="status test")
|
||||
assert "任务创建成功" in create_result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tasks_filters_by_normalized_status_and_formats_values(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
task_tools.create_task.func(title="todo task", priority="high")
|
||||
task_tools.create_task.func(title="done task", priority="low")
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
rows[1].status = TaskStatus.DONE
|
||||
await session.commit()
|
||||
|
||||
result = task_tools.get_tasks.func(status="done")
|
||||
|
||||
assert "done task" in result
|
||||
assert "todo task" not in result
|
||||
assert "状态:done" in result
|
||||
assert "优先级:low" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_accepts_datetime_description_and_at_aliases(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
description="提醒收被子",
|
||||
datetime="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Reminder))).scalar_one()
|
||||
|
||||
assert saved.title == "收被子"
|
||||
assert saved.note == "提醒收被子"
|
||||
assert saved.reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
datetime="2026-03-29T09:00:00+08:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
time="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
remind_at="2026-03-29T18:00:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 18, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_returns_failure_when_time_aliases_missing(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(title="收被子")
|
||||
|
||||
assert result == "创建提醒失败: reminder_at 不能为空"
|
||||
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.agents.tools.time_reasoning import (
|
||||
extract_reference_datetime,
|
||||
normalize_tool_time_arguments,
|
||||
resolve_time_expression_data,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_reference_datetime_from_current_time_context():
|
||||
context = '【当前时间】\n- current_time_utc: 2026-03-28T12:00:00+00:00\n- current_date_utc: 2026-03-28\n说明:解析相对时间时请以 current_time_utc 为准。'
|
||||
|
||||
result = extract_reference_datetime(context)
|
||||
|
||||
assert result == datetime(2026, 3, 28, 12, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_datetime():
|
||||
payload = resolve_time_expression_data(
|
||||
'明天早上9点',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['grain'] == 'datetime'
|
||||
assert payload['resolved_date'] == '2026-03-29'
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00'
|
||||
assert payload['assumed_time'] is False
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_date_window():
|
||||
payload = resolve_time_expression_data(
|
||||
'下周一下午',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_date'] == '2026-03-30'
|
||||
assert payload['resolved_datetime'] == '2026-03-30T15:00:00'
|
||||
assert payload['assumed_time'] is True
|
||||
assert 'assumed_time' in payload['reason']
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_reminder_time_aliases():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['reminder_at'] == '2026-03-29T09:00:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_date_only_tools():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_goal',
|
||||
{'title': '交付节点', 'goal_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['goal_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_preserves_explicit_datetime_offset():
|
||||
payload = resolve_time_expression_data(
|
||||
'2026-03-29T09:00:00+08:00',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00+08:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_keeps_create_task_date_without_explicit_time():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_task',
|
||||
{'title': '写周报', 'due_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['due_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_raises_for_invalid_time_text():
|
||||
try:
|
||||
normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天25点'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
except ValueError as exc:
|
||||
assert 'hour must be in 0..23' in str(exc)
|
||||
else:
|
||||
raise AssertionError('expected ValueError for invalid time text')
|
||||
25
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
25
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
|
||||
from app.agents.tools import forum as forum_tools
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
from app.agents.tools import search as search_tools
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("module", "label"),
|
||||
[
|
||||
(task_tools, "task"),
|
||||
(schedule_tools, "schedule"),
|
||||
(forum_tools, "forum"),
|
||||
(search_tools, "search"),
|
||||
],
|
||||
)
|
||||
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
||||
async def sample():
|
||||
return f"ok:{label}"
|
||||
|
||||
result = module._run_async(sample())
|
||||
|
||||
assert result == f"ok:{label}"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_builtin_skills_creates_default_ability_skills(tmp_path):
|
||||
db_path = tmp_path / 'test_builtin_skills.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='bootstrap_user',
|
||||
email='bootstrap@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Bootstrap User',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_builtin_skills(session)
|
||||
await ensure_builtin_skills(session)
|
||||
result = await session.execute(select(Skill).order_by(Skill.agent_type, Skill.name))
|
||||
skills = result.scalars().all()
|
||||
|
||||
assert len(skills) >= 9
|
||||
assert any(skill.agent_type == 'schedule_planner' for skill in skills)
|
||||
assert any(skill.agent_type == 'executor' for skill in skills)
|
||||
assert any(skill.agent_type == 'librarian' for skill in skills)
|
||||
librarian_skill = next(skill for skill in skills if skill.name == '知识检索摘要')
|
||||
assert 'web_search' in (librarian_skill.tools or [])
|
||||
assert any(skill.agent_type == 'analyst' for skill in skills)
|
||||
assert len({skill.name for skill in skills}) == len(skills)
|
||||
|
||||
await engine.dispose()
|
||||
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchResult,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, payload: dict, status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
'request failed',
|
||||
request=httpx.Request('GET', 'http://searx.example/search'),
|
||||
response=httpx.Response(self.status_code, request=httpx.Request('GET', 'http://searx.example/search')),
|
||||
)
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *, response=None, error=None, recorder=None, **kwargs):
|
||||
self._response = response
|
||||
self._error = error
|
||||
self._recorder = recorder if recorder is not None else []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, *, params=None, headers=None):
|
||||
self._recorder.append({'url': url, 'params': params, 'headers': headers})
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
return self._response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_normalized_results_from_searxng(monkeypatch):
|
||||
requests = []
|
||||
payload = {
|
||||
'results': [
|
||||
{
|
||||
'title': 'Jarvis release notes',
|
||||
'url': 'https://example.com/jarvis-release',
|
||||
'content': 'Latest Jarvis changes and release notes.',
|
||||
'engine': 'duckduckgo',
|
||||
'publishedDate': '2026-03-29',
|
||||
}
|
||||
]
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse(payload), recorder=requests, **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
default_limit=5,
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
results = await service.search('Jarvis 最新版本', limit=3)
|
||||
|
||||
assert results == [
|
||||
WebSearchResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes and release notes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
assert requests == [
|
||||
{
|
||||
'url': 'http://searx.example/search',
|
||||
'params': {
|
||||
'q': 'Jarvis 最新版本',
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
},
|
||||
'headers': {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_empty_list_when_searxng_has_no_results(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse({'results': []}), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
results = await service.search('不存在的话题')
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_on_searxng_http_failure(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(error=httpx.TimeoutException('timed out'), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchRequestError, match='SearxNG 请求失败'):
|
||||
await service.search('Jarvis')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_when_not_configured():
|
||||
service = WebSearchService(
|
||||
enabled=False,
|
||||
provider='searxng',
|
||||
base_url='',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchConfigurationError, match='网页搜索未启用或未配置'):
|
||||
await service.search('Jarvis')
|
||||
150
backend/tests/backend/app/test_agent_router.py
Normal file
150
backend/tests/backend/app/test_agent_router.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_env(tmp_path):
|
||||
db_path = tmp_path / 'test_agent_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='agent_user',
|
||||
email='agent@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Agent Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
skill_a = Skill(
|
||||
name='Planner skill A',
|
||||
description='planner',
|
||||
instructions='plan a',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
skill_b = Skill(
|
||||
name='Planner skill B',
|
||||
description='planner',
|
||||
instructions='plan b',
|
||||
agent_type='schedule_planner',
|
||||
tools=['tasks'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
session.add_all([
|
||||
Agent(
|
||||
name='SCHEDULE PLANNER',
|
||||
role='schedule_planner',
|
||||
description='日程规划师',
|
||||
system_prompt='prompt',
|
||||
is_active=True,
|
||||
),
|
||||
skill_a,
|
||||
skill_b,
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(skill_a)
|
||||
await session.refresh(skill_b)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(agent_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_config_returns_default_empty_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['selected_skill_ids'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_persists_selected_skill_ids(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id'], ids['skill_b_id']]},
|
||||
)
|
||||
get_response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
first_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id']]},
|
||||
)
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'description': 'updated description'},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['description'] == 'updated description'
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': ['missing-skill']},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()['detail'] == '存在无效的技能绑定'
|
||||
67
backend/tests/backend/app/test_config.py
Normal file
67
backend/tests/backend/app/test_config.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
from app import config as config_module
|
||||
from app.services.llm_service import default_provider_capabilities, normalize_provider_name, resolve_provider_capabilities
|
||||
|
||||
|
||||
def test_env_file_points_to_repo_root_env_file():
|
||||
assert config_module.ENV_FILE == Path(__file__).resolve().parents[4] / '.env'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_prefers_native_for_openai():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'gpt-4o', 'base_url': 'https://api.openai.com/v1'})
|
||||
|
||||
assert capabilities.provider == 'openai'
|
||||
assert capabilities.supports_native_tools is True
|
||||
assert capabilities.preferred_tool_strategy == 'native'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_openai_compatible_non_official_endpoint():
|
||||
capabilities = resolve_provider_capabilities(
|
||||
{'provider': 'openai', 'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}
|
||||
)
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_uses_global_openai_base_url_when_user_config_omits_it(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'OPENAI_BASE_URL', 'https://api.minimax.chat/v1')
|
||||
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'abab7.5-chat-preview'})
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_from_custom_config():
|
||||
assert normalize_provider_name({'provider': 'custom', 'model': 'MiniMax-M2.7-highspeed'}) == 'minimax'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_without_provider_when_base_url_matches():
|
||||
assert normalize_provider_name({'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}) == 'minimax'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_ollama():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_default_provider_capabilities_follows_global_settings(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
capabilities = default_provider_capabilities()
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_without_provider_uses_global_default(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
assert normalize_provider_name({'model': 'qwen2.5'}) == 'ollama'
|
||||
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import sys
|
||||
from datetime import UTC, date, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.task import router as task_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def schedule_env(tmp_path):
|
||||
db_path = tmp_path / 'test_schedule_center.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='schedule_user',
|
||||
email='schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Schedule Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_schedule_user',
|
||||
email='other-schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Schedule Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Legacy todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Done todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=True,
|
||||
completed_at=datetime(2026, 4, 10, 9, 30, tzinfo=UTC),
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=other_user.id,
|
||||
title='Other user todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='High priority task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 14, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='Urgent task next day',
|
||||
priority=TaskPriority.URGENT,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
due_date=datetime(2026, 4, 11, 10, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=other_user.id,
|
||||
title='Other user task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 15, 0, tzinfo=UTC),
|
||||
),
|
||||
Reminder(
|
||||
user_id=user.id,
|
||||
title='Doctor reminder',
|
||||
note='Bring reports',
|
||||
reminder_at=datetime(2026, 4, 10, 8, 0, tzinfo=UTC),
|
||||
),
|
||||
Goal(
|
||||
user_id=user.id,
|
||||
title='Launch calendar beta',
|
||||
note='Ship MVP',
|
||||
goal_date='2026-04-10',
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(todo_router)
|
||||
test_app.include_router(task_router)
|
||||
test_app.include_router(reminder_router)
|
||||
test_app.include_router(goal_router)
|
||||
test_app.include_router(schedule_center_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_todo_persists_explicit_todo_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post('/api/todos', json={'title': 'Plan sprint', 'todo_date': '2026-04-12'})
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Plan sprint'
|
||||
assert payload['todo_date'] == '2026-04-12'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_todo_allows_editing_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.patch(f'/api/todos/{todo_id}', json={'title': 'Updated title', 'todo_date': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Updated title'
|
||||
assert payload['todo_date'] == '2026-04-11'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_todo_allows_deleting_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.delete(f'/api/todos/{todo_id}')
|
||||
after_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 204
|
||||
assert all(item['id'] != todo_id for item in after_response.json()['items'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'due_date': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['title'] for item in payload] == ['High priority task']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date_range(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'date_from': '2026-04-10', 'date_to': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert {item['title'] for item in payload} == {'High priority task', 'Urgent task next day'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_date_returns_aggregated_resources(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/date', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['date'] == '2026-04-10'
|
||||
assert payload['summary'] == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert [item['title'] for item in payload['reminders']] == ['Doctor reminder']
|
||||
assert [item['title'] for item in payload['goals']] == ['Launch calendar beta']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_month_returns_day_summaries(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/month', params={'year': 2026, 'month': 4})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['month'] == '2026-04'
|
||||
day_10 = next(item for item in payload['days'] if item['date'] == '2026-04-10')
|
||||
day_11 = next(item for item in payload['days'] if item['date'] == '2026-04-11')
|
||||
assert day_10 == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert day_11['task_due_total'] == 1
|
||||
assert day_11['high_priority_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_reminder_with_naive_datetime_and_time_zone_appears_in_schedule_center(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
create_response = await client.post(
|
||||
'/api/reminders',
|
||||
json={'title': '收被子', 'note': '提醒收被子', 'reminder_at': '2026-03-29T09:00:00'},
|
||||
)
|
||||
detail_response = await client.get('/api/schedule-center/date', params={'date_str': '2026-03-29'})
|
||||
|
||||
assert create_response.status_code == 201
|
||||
assert detail_response.status_code == 200
|
||||
payload = detail_response.json()
|
||||
assert [item['title'] for item in payload['reminders']] == ['收被子']
|
||||
assert payload['summary']['reminder_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reminder_and_goal_crud(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
reminder_response = await client.post('/api/reminders', json={'title': 'Standup', 'note': 'Daily sync', 'reminder_at': '2026-04-12T09:00:00Z'})
|
||||
goal_response = await client.post('/api/goals', json={'title': 'Finish polish', 'note': 'UI cleanup', 'goal_date': '2026-04-12', 'status': 'active'})
|
||||
reminder_id = reminder_response.json()['id']
|
||||
goal_id = goal_response.json()['id']
|
||||
patch_reminder = await client.patch(f'/api/reminders/{reminder_id}', json={'status': 'done'})
|
||||
patch_goal = await client.patch(f'/api/goals/{goal_id}', json={'status': 'done'})
|
||||
reminders_list = await client.get('/api/reminders', params={'date_str': '2026-04-12'})
|
||||
goals_list = await client.get('/api/goals', params={'date_str': '2026-04-12'})
|
||||
delete_reminder = await client.delete(f'/api/reminders/{reminder_id}')
|
||||
delete_goal = await client.delete(f'/api/goals/{goal_id}')
|
||||
|
||||
assert reminder_response.status_code == 201
|
||||
assert goal_response.status_code == 201
|
||||
assert patch_reminder.json()['status'] == 'done'
|
||||
assert patch_goal.json()['status'] == 'done'
|
||||
assert [item['title'] for item in reminders_list.json()['items']] == ['Standup']
|
||||
assert [item['title'] for item in goals_list.json()['items']] == ['Finish polish']
|
||||
assert delete_reminder.status_code == 204
|
||||
assert delete_goal.status_code == 204
|
||||
190
backend/tests/backend/app/test_skill_router.py
Normal file
190
backend/tests/backend/app/test_skill_router.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.skill import router as skill_router
|
||||
from app.routers.auth import router as auth_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def skill_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_skill_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
# Ensure app.database.init_db() runs against the test database.
|
||||
import app.database as database_module
|
||||
|
||||
monkeypatch.setattr(database_module, "engine", engine)
|
||||
monkeypatch.setattr(database_module, "async_session", session_factory)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='skill_user',
|
||||
email='skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Skill Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_skill_user',
|
||||
email='other-skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Skill Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
session.add_all([
|
||||
Skill(
|
||||
name='Planner skill',
|
||||
description='planner',
|
||||
instructions='plan',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Executor skill',
|
||||
description='executor',
|
||||
instructions='execute',
|
||||
agent_type='executor',
|
||||
tools=['shell'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Other user planner skill',
|
||||
description='other',
|
||||
instructions='other',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=other_user.id,
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(auth_router)
|
||||
test_app.include_router(skill_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, session_factory
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_filters_by_agent_type(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'agent_type': 'schedule_planner'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_migrates_planner_skills_to_schedule_planner(skill_env):
|
||||
app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("UPDATE skills SET agent_type = 'planner' WHERE name = 'Planner skill'"))
|
||||
await session.commit()
|
||||
|
||||
from app.database import init_db
|
||||
|
||||
await init_db()
|
||||
|
||||
async with session_factory() as session:
|
||||
migrated_response = await session.execute(text("SELECT agent_type FROM skills WHERE name = 'Planner skill'"))
|
||||
assert migrated_response.scalar_one() == 'schedule_planner'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_visibility_filter_still_respects_access_scope(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'visibility': 'private'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_bootstraps_builtin_market_skills_for_current_user(skill_env):
|
||||
test_app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("DELETE FROM skills"))
|
||||
await session.commit()
|
||||
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
login_response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'skill_user', 'password': 'secret123'},
|
||||
headers={'content-type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
|
||||
assert login_response.status_code == 200
|
||||
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(select(Skill.name, Skill.is_builtin).order_by(Skill.name))
|
||||
skills = result.all()
|
||||
|
||||
names = {name for name, _is_builtin in skills}
|
||||
assert '今日重点拆解' in names
|
||||
assert '任务执行 SOP' in names
|
||||
assert any(is_builtin is True for _name, is_builtin in skills)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_without_agent_type_returns_current_user_skills(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
assert all(isinstance(item['created_at'], str) for item in payload)
|
||||
assert all(isinstance(item['updated_at'], str) for item in payload)
|
||||
assert all('is_builtin' in item for item in payload)
|
||||
assert all(item['is_builtin'] is False for item in payload)
|
||||
143
backend/uv.lock
generated
143
backend/uv.lock
generated
@@ -4,7 +4,8 @@ requires-python = ">=3.11"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14'",
|
||||
"python_full_version == '3.13.*'",
|
||||
"python_full_version < '3.13'",
|
||||
"python_full_version == '3.12.*'",
|
||||
"python_full_version < '3.12'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -258,6 +259,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/41/0a/0896b829a39b5669a2d811e1a79598de661693685cd62b31f11d0c18e65b/av-17.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dba98603fc4665b4f750de86fbaf6c0cfaece970671a9b529e0e3d1711e8367e", size = 22071058, upload-time = "2026-03-14T14:38:43.663Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "2.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001, upload-time = "2022-10-05T19:19:32.061Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "banks"
|
||||
version = "2.4.1"
|
||||
@@ -1297,6 +1307,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "4.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "hpack" },
|
||||
{ name = "hyperframe" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
version = "1.4.2"
|
||||
@@ -1329,6 +1352,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hpack"
|
||||
version = "4.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
@@ -1393,6 +1425,11 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
http2 = [
|
||||
{ name = "h2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-retries"
|
||||
version = "0.4.6"
|
||||
@@ -1425,6 +1462,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyperframe"
|
||||
version = "6.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "identify"
|
||||
version = "2.6.18"
|
||||
@@ -1508,6 +1554,7 @@ dependencies = [
|
||||
{ name = "langsmith" },
|
||||
{ name = "llama-index" },
|
||||
{ name = "llama-index-vector-stores-chroma" },
|
||||
{ name = "mem0ai" },
|
||||
{ name = "mineru" },
|
||||
{ name = "openpyxl" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
@@ -1552,6 +1599,7 @@ requires-dist = [
|
||||
{ name = "langsmith", specifier = ">=0.1.0" },
|
||||
{ name = "llama-index", specifier = ">=0.12.0" },
|
||||
{ name = "llama-index-vector-stores-chroma", specifier = ">=0.3.0" },
|
||||
{ name = "mem0ai", specifier = ">=1.0.0" },
|
||||
{ name = "mineru", specifier = ">=2.0.3" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.10.0" },
|
||||
{ name = "openpyxl", specifier = ">=3.1.0" },
|
||||
@@ -2431,6 +2479,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mem0ai"
|
||||
version = "1.0.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "openai" },
|
||||
{ name = "posthog" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pytz" },
|
||||
{ name = "qdrant-client" },
|
||||
{ name = "sqlalchemy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/d9/1fbf24b055f9f14ec61d593bc95b75b033fd1a69bfddee33d6453b21e5d0/mem0ai-1.0.10.tar.gz", hash = "sha256:f3e22c9aff695ca6c66631c4e79ceef92457c8d3355c56359ab4257fa031c046", size = 200416, upload-time = "2026-04-01T18:23:27.018Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/e6/2c6ea68c404757e683da23b942dfff6987fe283ccbf2fa1fb0c128ddbdc6/mem0ai-1.0.10-py3-none-any.whl", hash = "sha256:9ff586c3a39a834042ce6755fc9da2315e284fb622ee773cd344ecca756ccad5", size = 295374, upload-time = "2026-04-01T18:23:25.022Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mineru"
|
||||
version = "2.7.6"
|
||||
@@ -3372,6 +3438,35 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portalocker"
|
||||
version = "3.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5e/77/65b857a69ed876e1951e88aaba60f5ce6120c33703f7cb61a3c894b8c1b6/portalocker-3.2.0.tar.gz", hash = "sha256:1f3002956a54a8c3730586c5c77bf18fae4149e07eaf1c29fc3faf4d5a3f89ac", size = 95644, upload-time = "2025-06-14T13:20:40.03Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl", hash = "sha256:3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968", size = 22424, upload-time = "2025-06-14T13:20:38.083Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "posthog"
|
||||
version = "7.9.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "backoff" },
|
||||
{ name = "distro" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
{ name = "six" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1c/a7/2865487853061fbd62383492237b546d2d8f7c1846272350d2b9e14138cd/posthog-7.9.12.tar.gz", hash = "sha256:ebabf2eb2e1c1fbf22b0759df4644623fa43cc6c9dcbe9fd429b7937d14251ec", size = 176828, upload-time = "2026-03-12T09:01:15.184Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/a9/7a803aed5a5649cf78ea7b31e90d0080181ba21f739243e1741a1e607f1f/posthog-7.9.12-py3-none-any.whl", hash = "sha256:7175bd1698a566bfea98a016c64e3456399f8046aeeca8f1d04ae5bf6c5a38d0", size = 202469, upload-time = "2026-03-12T09:01:13.38Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pre-commit"
|
||||
version = "4.5.1"
|
||||
@@ -3996,6 +4091,34 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2026.1.post1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/56/db/b8721d71d945e6a8ac63c0fc900b2067181dbb50805958d4d4661cf7d277/pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1", size = 321088, upload-time = "2026-03-03T07:47:50.683Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a", size = 510489, upload-time = "2026-03-03T07:47:49.167Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pywin32"
|
||||
version = "311"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.3"
|
||||
@@ -4051,6 +4174,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qdrant-client"
|
||||
version = "1.17.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "grpcio" },
|
||||
{ name = "httpx", extra = ["http2"] },
|
||||
{ name = "numpy" },
|
||||
{ name = "portalocker" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/30/dd/f8a8261b83946af3cd65943c93c4f83e044f01184e8525404989d22a81a5/qdrant_client-1.17.1.tar.gz", hash = "sha256:22f990bbd63485ed97ba551a4c498181fcb723f71dcab5d6e4e43fe1050a2bc0", size = 344979, upload-time = "2026-03-13T17:13:44.678Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/69/77d1a971c4b933e8c79403e99bcbb790463da5e48333cc4fd5d412c63c98/qdrant_client-1.17.1-py3-none-any.whl", hash = "sha256:6cda4064adfeaf211c751f3fbc00edbbdb499850918c7aff4855a9a759d56cbd", size = 389947, upload-time = "2026-03-13T17:13:43.156Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qwen-vl-utils"
|
||||
version = "0.0.14"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user