feat: enhance agent orchestration, knowledge flow and UI refinements
This commit is contained in:
@@ -1,9 +1,122 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.agents.graph import master_node
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.graph import (
|
||||
_choose_sub_commander,
|
||||
_parse_json_action,
|
||||
_route_agent_from_user_query,
|
||||
_run_sub_commander,
|
||||
master_node,
|
||||
)
|
||||
from app.agents.tools.time_reasoning import resolve_time_expression
|
||||
from app.agents.state import AgentRole
|
||||
|
||||
|
||||
|
||||
|
||||
def _base_state(message: str, user_llm_config: dict | None = None) -> dict:
|
||||
return {
|
||||
'messages': [HumanMessage(content=message)],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'current_sub_commander': None,
|
||||
'active_sub_commanders': [],
|
||||
'sub_commander_trace': [],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'action_results': [],
|
||||
'created_entities': [],
|
||||
'tool_strategy_used': None,
|
||||
'provider_capabilities': None,
|
||||
'fallback_parse_error': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'schedule_context_summary': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'current_datetime_context': 'CURRENT_TIME: 2026-03-28T12:00:00+08:00',
|
||||
'current_datetime_reference': {'current_time_iso': '2026-03-28T12:00:00+08:00', 'current_date_iso': '2026-03-28', 'timezone': 'UTC'},
|
||||
'user_llm_config': user_llm_config,
|
||||
}
|
||||
|
||||
|
||||
class FakeFallbackLLM:
|
||||
def __init__(self, first_content: str, followup_content: str = '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'):
|
||||
self.first_content = first_content
|
||||
self.followup_content = followup_content
|
||||
self.calls = 0
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return AIMessage(content=self.first_content)
|
||||
return AIMessage(content=self.followup_content)
|
||||
|
||||
def bind_tools(self, tools):
|
||||
raise AssertionError('bind_tools should not be called in JSON fallback mode')
|
||||
|
||||
|
||||
class FakeNativeBoundLLM:
|
||||
async def ainvoke(self, messages):
|
||||
return AIMessage(
|
||||
content='',
|
||||
tool_calls=[
|
||||
{
|
||||
'id': 'call_1',
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FakeNativeLLM:
|
||||
def __init__(self):
|
||||
self.bound = FakeNativeBoundLLM()
|
||||
self.tool_binding_count = 0
|
||||
self.calls = 0
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(provider='openai', supports_native_tools=True, preferred_tool_strategy='native')
|
||||
|
||||
def bind_tools(self, tools):
|
||||
self.tool_binding_count += 1
|
||||
return self.bound
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.calls += 1
|
||||
return AIMessage(content='已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。')
|
||||
|
||||
|
||||
class FakeTool:
|
||||
def __init__(self, name: str, result: str):
|
||||
self.name = name
|
||||
self.result = result
|
||||
self.invocations: list[dict] = []
|
||||
|
||||
def invoke(self, args: dict):
|
||||
self.invocations.append(args)
|
||||
return self.result
|
||||
|
||||
|
||||
class CapturingLLM:
|
||||
def __init__(self, content: str = '{"mode":"final","final_response":"好的。"}'):
|
||||
self.content = content
|
||||
self.messages = None
|
||||
self._jarvis_provider_capabilities = SimpleNamespace(provider='ollama', supports_native_tools=False, preferred_tool_strategy='json_fallback')
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.messages = messages
|
||||
return AIMessage(content=self.content)
|
||||
|
||||
|
||||
class FailIfCalledLLM:
|
||||
async def ainvoke(self, messages):
|
||||
raise AssertionError('LLM should not be called for simple greetings')
|
||||
@@ -71,6 +184,68 @@ async def test_master_node_returns_stable_reply_for_identity_question(monkeypatc
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_identity_question_with_punctuation(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
state = {
|
||||
'messages': [HumanMessage(content='你是谁?')],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'user_llm_config': None,
|
||||
}
|
||||
|
||||
result = await master_node(state)
|
||||
|
||||
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_identity_question_with_particle(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
state = {
|
||||
'messages': [HumanMessage(content='你是谁啊')],
|
||||
'user_id': 'u1',
|
||||
'conversation_id': 'c1',
|
||||
'current_agent': AgentRole.MASTER,
|
||||
'active_agents': [AgentRole.MASTER],
|
||||
'pending_tasks': [],
|
||||
'completed_tasks': [],
|
||||
'tool_calls': [],
|
||||
'last_tool_result': None,
|
||||
'knowledge_context': None,
|
||||
'graph_context': None,
|
||||
'plan': None,
|
||||
'plan_steps': [],
|
||||
'analysis_report': None,
|
||||
'final_response': None,
|
||||
'should_respond': True,
|
||||
'memory_context': None,
|
||||
'user_llm_config': None,
|
||||
}
|
||||
|
||||
result = await master_node(state)
|
||||
|
||||
assert result['final_response'] == '我是 Jarvis。\n\n比起做一个泛泛的助手,我更像您的判断型协作伙伴:帮您看清问题、压缩路径、把事情往前推进。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
async def test_master_node_returns_stable_reply_for_capability_question(monkeypatch):
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: FailIfCalledLLM())
|
||||
|
||||
@@ -100,3 +275,196 @@ async def test_master_node_returns_stable_reply_for_capability_question(monkeypa
|
||||
assert result['final_response'] == '主要做三件事。\n- 帮您判断:看问题本质、梳理取舍、给出方向\n- 帮您收束:把复杂内容理顺,把重点拎出来\n- 帮您推进:拆任务、定步骤、把下一步变清楚\n\n如果您现在有具体目标,我可以直接进入处理。'
|
||||
assert result['current_agent'] == AgentRole.MASTER
|
||||
assert result['active_agents'] == [AgentRole.MASTER]
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_schedule_requests_to_schedule_planning():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '帮我安排一下这周计划') == 'schedule_planning'
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_focus_requests_to_schedule_analysis():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '基于最近对话帮我判断该聚焦什么') == 'schedule_analysis'
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_knowledge_requests_to_librarian():
|
||||
assert _route_agent_from_user_query('帮我搜索知识库里的项目资料') == AgentRole.LIBRARIAN
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_schedule_requests_to_schedule_planner():
|
||||
assert _route_agent_from_user_query('明天提醒我开会') == AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
|
||||
def test_route_agent_from_user_query_routes_explicit_month_day_milestone_to_schedule_planner():
|
||||
assert _route_agent_from_user_query('3月29日,对话系统交付节点') == AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
|
||||
def test_choose_sub_commander_routes_explicit_month_day_milestone_to_schedule_planning():
|
||||
assert _choose_sub_commander(AgentRole.SCHEDULE_PLANNER, '3月29日,对话系统交付节点') == 'schedule_planning'
|
||||
|
||||
|
||||
|
||||
|
||||
def test_parse_json_action_extracts_tool_calls_from_fenced_json():
|
||||
parsed = _parse_json_action(
|
||||
'```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```',
|
||||
['create_reminder'],
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
'mode': 'tool_call',
|
||||
'tool_calls': [
|
||||
{
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
'reason': None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_parse_json_action_returns_none_for_invalid_or_unknown_payload():
|
||||
assert _parse_json_action('not json', ['create_reminder']) is None
|
||||
assert _parse_json_action('{"mode":"tool_call","tool_calls":[{"name":"unknown","arguments":{}}]}', ['create_reminder']) is None
|
||||
|
||||
|
||||
def test_parse_json_action_tolerates_prefix_and_suffix_text():
|
||||
parsed = _parse_json_action(
|
||||
'好的,下面是 JSON:\n```json\n{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}\n```\n谢谢',
|
||||
['create_reminder'],
|
||||
)
|
||||
assert parsed is not None
|
||||
assert parsed['mode'] == 'tool_call'
|
||||
assert parsed['tool_calls'][0]['name'] == 'create_reminder'
|
||||
|
||||
|
||||
def test_parse_json_action_accepts_parameters_alias_for_tool_calls():
|
||||
parsed = _parse_json_action(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","parameters":{"title":"收被子","reminder_at":"2026-03-29T09:00:00+08:00"}}]}',
|
||||
['create_reminder'],
|
||||
)
|
||||
|
||||
assert parsed == {
|
||||
'mode': 'tool_call',
|
||||
'tool_calls': [
|
||||
{
|
||||
'name': 'create_reminder',
|
||||
'args': {'title': '收被子', 'reminder_at': '2026-03-29T09:00:00+08:00'},
|
||||
'reason': None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_run_sub_commander_uses_json_fallback_for_non_native_provider(monkeypatch):
|
||||
fake_llm = FakeFallbackLLM(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"create_reminder","arguments":{"title":"开会","reminder_at":"明天 09:00"}}]}'
|
||||
)
|
||||
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'schedule_planning',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'json_fallback'
|
||||
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
|
||||
assert result['tool_calls'][0]['name'] == 'create_reminder'
|
||||
assert result['created_entities'][0]['type'] == 'reminder'
|
||||
assert result['fallback_parse_error'] is None
|
||||
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'
|
||||
|
||||
|
||||
async def test_run_sub_commander_includes_current_datetime_context_in_system_messages(monkeypatch):
|
||||
fake_llm = CapturingLLM('{"mode":"final","final_response":"好的。"}')
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
state['current_datetime_context'] = 'CURRENT_TIME: 2026-03-28T12:00:00+08:00'
|
||||
|
||||
await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert fake_llm.messages is not None
|
||||
assert any(
|
||||
getattr(m, 'type', None) == 'system' and 'CURRENT_TIME:' in str(getattr(m, 'content', ''))
|
||||
for m in fake_llm.messages
|
||||
)
|
||||
|
||||
|
||||
async def test_run_sub_commander_uses_web_search_in_json_fallback(monkeypatch):
|
||||
fake_llm = FakeFallbackLLM(
|
||||
'{"mode":"tool_call","tool_calls":[{"name":"web_search","arguments":{"query":"Jarvis 最新模型更新","top_k":2}}]}',
|
||||
'我查了外部网页,下面是最新结果摘要。',
|
||||
)
|
||||
fake_tool = FakeTool('web_search', '成功搜索到 2 条网页结果')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'librarian_retrieval',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('帮我上网查一下 Jarvis 最新模型更新', {'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
state['current_agent'] = AgentRole.LIBRARIAN
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.LIBRARIAN,
|
||||
'manager prompt',
|
||||
'帮我上网查一下 Jarvis 最新模型更新',
|
||||
use_tools=True,
|
||||
summary_target='knowledge_context',
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'json_fallback'
|
||||
assert fake_tool.invocations == [{'query': 'Jarvis 最新模型更新', 'top_k': 2}]
|
||||
assert result['tool_calls'][0]['name'] == 'web_search'
|
||||
assert result['last_tool_result'] == '[web_search] 成功搜索到 2 条网页结果'
|
||||
assert result['final_response'] == '我查了外部网页,下面是最新结果摘要。'
|
||||
|
||||
|
||||
fake_llm = FakeNativeLLM()
|
||||
fake_tool = FakeTool('create_reminder', '成功创建 reminder: 开会 @ 明天 09:00')
|
||||
|
||||
monkeypatch.setattr('app.agents.graph._get_llm_for_state', lambda state: fake_llm)
|
||||
monkeypatch.setitem(
|
||||
__import__('app.agents.graph', fromlist=['SUB_COMMANDER_TOOLSETS']).SUB_COMMANDER_TOOLSETS,
|
||||
'schedule_planning',
|
||||
[fake_tool],
|
||||
)
|
||||
|
||||
state = _base_state('明天 9 点提醒我开会', {'provider': 'openai', 'model': 'gpt-4o'})
|
||||
state['current_agent'] = AgentRole.SCHEDULE_PLANNER
|
||||
|
||||
result = await _run_sub_commander(
|
||||
state,
|
||||
AgentRole.SCHEDULE_PLANNER,
|
||||
'manager prompt',
|
||||
'明天 9 点提醒我开会',
|
||||
use_tools=True,
|
||||
)
|
||||
|
||||
assert result['tool_strategy_used'] == 'native'
|
||||
assert fake_llm.tool_binding_count == 1
|
||||
assert fake_tool.invocations == [{'title': '开会', 'reminder_at': '2026-03-29T09:00:00'}]
|
||||
assert result['created_entities'][0]['type'] == 'reminder'
|
||||
assert result['final_response'] == '已创建提醒:开会,时间为 2026-03-29 09:00(按当前时间理解为“明天早上9点”)。'
|
||||
|
||||
49
backend/tests/backend/app/agents/test_search_tools.py
Normal file
49
backend/tests/backend/app/agents/test_search_tools.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.tools.search import web_search
|
||||
|
||||
|
||||
class FakeResult(SimpleNamespace):
|
||||
pass
|
||||
|
||||
|
||||
def test_web_search_tool_formats_results(monkeypatch):
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
assert query == 'Jarvis 最新更新'
|
||||
assert limit == 2
|
||||
return [
|
||||
FakeResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis 最新更新', top_k=2)
|
||||
|
||||
assert '[1] Jarvis release notes' in result
|
||||
assert '链接: https://example.com/jarvis-release' in result
|
||||
assert '来源: duckduckgo' in result
|
||||
assert '时间: 2026-03-29' in result
|
||||
assert '摘要: Latest Jarvis changes.' in result
|
||||
|
||||
|
||||
def test_web_search_tool_returns_stable_message_when_unavailable(monkeypatch):
|
||||
from app.services.web_search_service import WebSearchConfigurationError
|
||||
|
||||
class FakeService:
|
||||
async def search(self, query: str, limit: int | None = None):
|
||||
raise WebSearchConfigurationError('网页搜索未启用或未配置')
|
||||
|
||||
monkeypatch.setattr('app.services.web_search_service.WebSearchService', FakeService)
|
||||
|
||||
result = web_search.func('Jarvis')
|
||||
|
||||
assert result == '网页搜索不可用: 网页搜索未启用或未配置'
|
||||
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
277
backend/tests/backend/app/agents/test_task_tools.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tool_env(tmp_path):
|
||||
db_path = tmp_path / "test_task_tools.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
# 只创建本测试需要的表,避免全量 metadata 引入未注册的外键表。
|
||||
await conn.run_sync(User.metadata.create_all, tables=[
|
||||
User.__table__,
|
||||
Task.__table__,
|
||||
DailyTodo.__table__,
|
||||
Reminder.__table__,
|
||||
Goal.__table__,
|
||||
])
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username="tool_user",
|
||||
email="tool@example.com",
|
||||
hashed_password=get_password_hash("secret123"),
|
||||
full_name="Tool Tester",
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
try:
|
||||
yield {"session_factory": session_factory, "user_id": user.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_content_and_date_aliases_and_persists_task(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 0, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_task_accepts_content_and_date_aliases_and_sets_morning_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_schedule_task.func(content="完成对话系统", date="2026-03-28")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.title == "完成对话系统"
|
||||
assert saved.description == "完成对话系统"
|
||||
assert saved.priority == TaskPriority.MEDIUM
|
||||
assert saved.status == TaskStatus.TODO
|
||||
assert saved.due_date == datetime(2026, 3, 28, 9, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("priority_input", "expected"),
|
||||
[
|
||||
(1, TaskPriority.LOW),
|
||||
(2, TaskPriority.MEDIUM),
|
||||
(3, TaskPriority.HIGH),
|
||||
(4, TaskPriority.URGENT),
|
||||
("urgent", TaskPriority.URGENT),
|
||||
],
|
||||
)
|
||||
async def test_create_task_normalizes_legacy_and_string_priorities(tool_env, monkeypatch, priority_input, expected):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title=f"priority-{priority_input}", priority=priority_input)
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].priority == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_accepts_iso_datetime_due_date(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="timed task", due_date="2026-03-28T15:30:00Z")
|
||||
|
||||
assert "任务创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
assert saved.due_date == datetime(2026, 3, 28, 15, 30, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_missing_title_and_content(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func()
|
||||
|
||||
assert result == "创建任务失败: title 不能为空"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_failure_for_invalid_priority(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = task_tools.create_task.func(title="bad priority", priority="top")
|
||||
|
||||
assert "创建任务失败:" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_status_rejects_invalid_status(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
create_result = task_tools.create_task.func(title="status test")
|
||||
assert "任务创建成功" in create_result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Task))).scalar_one()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tasks_filters_by_normalized_status_and_formats_values(tool_env, monkeypatch):
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
monkeypatch.setattr(task_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(task_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
task_tools.create_task.func(title="todo task", priority="high")
|
||||
task_tools.create_task.func(title="done task", priority="low")
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Task).order_by(Task.created_at.asc()))).scalars().all()
|
||||
rows[1].status = TaskStatus.DONE
|
||||
await session.commit()
|
||||
|
||||
result = task_tools.get_tasks.func(status="done")
|
||||
|
||||
assert "done task" in result
|
||||
assert "todo task" not in result
|
||||
assert "状态:done" in result
|
||||
assert "优先级:low" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_accepts_datetime_description_and_at_aliases(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
description="提醒收被子",
|
||||
datetime="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
saved = (await session.execute(select(Reminder))).scalar_one()
|
||||
|
||||
assert saved.title == "收被子"
|
||||
assert saved.note == "提醒收被子"
|
||||
assert saved.reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
datetime="2026-03-29T09:00:00+08:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
content="收被子",
|
||||
time="2026-03-29T09:00:00",
|
||||
time_zone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 9, 0)
|
||||
|
||||
result = schedule_tools.create_reminder.func(
|
||||
title="收被子",
|
||||
remind_at="2026-03-29T18:00:00",
|
||||
)
|
||||
|
||||
assert "提醒创建成功" in result
|
||||
|
||||
async with tool_env["session_factory"]() as session:
|
||||
rows = (await session.execute(select(Reminder).order_by(Reminder.created_at.asc()))).scalars().all()
|
||||
|
||||
assert rows[-1].title == "收被子"
|
||||
assert rows[-1].note is None
|
||||
assert rows[-1].reminder_at == datetime(2026, 3, 29, 18, 0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule_reminder_returns_failure_when_time_aliases_missing(tool_env, monkeypatch):
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
|
||||
monkeypatch.setattr(schedule_tools, "async_session", tool_env["session_factory"])
|
||||
monkeypatch.setattr(schedule_tools, "get_current_user", lambda: tool_env["user_id"])
|
||||
|
||||
result = schedule_tools.create_reminder.func(title="收被子")
|
||||
|
||||
assert result == "创建提醒失败: reminder_at 不能为空"
|
||||
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
94
backend/tests/backend/app/agents/test_time_reasoning_tool.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.agents.tools.time_reasoning import (
|
||||
extract_reference_datetime,
|
||||
normalize_tool_time_arguments,
|
||||
resolve_time_expression_data,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_reference_datetime_from_current_time_context():
|
||||
context = '【当前时间】\n- current_time_utc: 2026-03-28T12:00:00+00:00\n- current_date_utc: 2026-03-28\n说明:解析相对时间时请以 current_time_utc 为准。'
|
||||
|
||||
result = extract_reference_datetime(context)
|
||||
|
||||
assert result == datetime(2026, 3, 28, 12, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_datetime():
|
||||
payload = resolve_time_expression_data(
|
||||
'明天早上9点',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['grain'] == 'datetime'
|
||||
assert payload['resolved_date'] == '2026-03-29'
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00'
|
||||
assert payload['assumed_time'] is False
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_normalizes_relative_date_window():
|
||||
payload = resolve_time_expression_data(
|
||||
'下周一下午',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_date'] == '2026-03-30'
|
||||
assert payload['resolved_datetime'] == '2026-03-30T15:00:00'
|
||||
assert payload['assumed_time'] is True
|
||||
assert 'assumed_time' in payload['reason']
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_reminder_time_aliases():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天 09:00'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['reminder_at'] == '2026-03-29T09:00:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_converts_date_only_tools():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_goal',
|
||||
{'title': '交付节点', 'goal_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['goal_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_resolve_time_expression_data_preserves_explicit_datetime_offset():
|
||||
payload = resolve_time_expression_data(
|
||||
'2026-03-29T09:00:00+08:00',
|
||||
current_datetime_context='CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
prefer='datetime',
|
||||
)
|
||||
|
||||
assert payload['resolved_datetime'] == '2026-03-29T09:00:00+08:00'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_keeps_create_task_date_without_explicit_time():
|
||||
normalized = normalize_tool_time_arguments(
|
||||
'create_task',
|
||||
{'title': '写周报', 'due_date': '明天'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
|
||||
assert normalized['due_date'] == '2026-03-29'
|
||||
|
||||
|
||||
def test_normalize_tool_time_arguments_raises_for_invalid_time_text():
|
||||
try:
|
||||
normalize_tool_time_arguments(
|
||||
'create_reminder',
|
||||
{'title': '开会', 'reminder_at': '明天25点'},
|
||||
'CURRENT_TIME: 2026-03-28T12:00:00+00:00',
|
||||
)
|
||||
except ValueError as exc:
|
||||
assert 'hour must be in 0..23' in str(exc)
|
||||
else:
|
||||
raise AssertionError('expected ValueError for invalid time text')
|
||||
23
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
23
backend/tests/backend/app/agents/test_tool_async_bridge.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from app.agents.tools import forum as forum_tools
|
||||
from app.agents.tools import schedule as schedule_tools
|
||||
from app.agents.tools import task as task_tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("module", "label"),
|
||||
[
|
||||
(task_tools, "task"),
|
||||
(schedule_tools, "schedule"),
|
||||
(forum_tools, "forum"),
|
||||
],
|
||||
)
|
||||
async def test_run_async_bridge_works_inside_running_event_loop(module, label):
|
||||
async def sample():
|
||||
return f"ok:{label}"
|
||||
|
||||
result = module._run_async(sample())
|
||||
|
||||
assert result == f"ok:{label}"
|
||||
@@ -9,7 +9,7 @@ from starlette.datastructures import UploadFile
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.brain import BrainEvent, BrainMemory
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.conversation import Conversation, Message
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.user import User
|
||||
from app.services import agent_service, memory_service
|
||||
@@ -32,6 +32,110 @@ class FakeStreamingGraph:
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingFinalResponseGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chain_end",
|
||||
"name": "master",
|
||||
"data": {"output": {"final_response": "这是最终回答。"}},
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingBadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStreamingBadRequestError2(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeOpenAIBadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStreamingOpenAIBadRequestGraph:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeOpenAIBadRequestError('invalid_request_error: tool arguments failed validation')
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "不应触发同步回退。"}
|
||||
|
||||
|
||||
class FakeStreamingFallbackGraph:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeStreamingBadRequestError('invalid params, invalid chat setting (2013)')
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "这是回退后的同步回答。"}
|
||||
|
||||
|
||||
class FakeStreamingFallbackGraphGenericError:
|
||||
def __init__(self):
|
||||
self.astream_calls = 0
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version="v2"):
|
||||
self.astream_calls += 1
|
||||
raise FakeStreamingBadRequestError2("Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}")
|
||||
yield
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {"final_response": "这是通用异常回退后的同步回答。"}
|
||||
|
||||
|
||||
class FakeStreamingDelegationThenFinalResponseGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="现在显示收到,3月28日的任务是完成对话系统。\n\n我将这部分转给schedule_planner,他会根据这个目标结合你当前的进度和资源,给出具体的安排建议。")},
|
||||
}
|
||||
yield {
|
||||
"event": "on_chain_end",
|
||||
"name": "schedule_planner",
|
||||
"data": {"output": {"final_response": "今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。"}},
|
||||
}
|
||||
|
||||
|
||||
class FakeStreamingDelegationThenModelEndGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="我将这部分转给schedule_planner。")},
|
||||
}
|
||||
yield {
|
||||
"event": "on_chat_model_end",
|
||||
"name": "schedule_planner",
|
||||
"data": {"output": SimpleNamespace(content="最终建议:先完成对话系统,再回归验证。")},
|
||||
}
|
||||
|
||||
|
||||
class CapturingStateGraph:
|
||||
def __init__(self, final_response: str = '已记录你的请求。'):
|
||||
self.final_response = final_response
|
||||
self.captured_state = None
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.captured_state = state
|
||||
return {"final_response": self.final_response}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_brain_ingestion.db'
|
||||
@@ -43,6 +147,7 @@ async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='brain-ingestion-tester',
|
||||
email='brain-ingestion@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Ingestion Tester',
|
||||
@@ -178,6 +283,360 @@ async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_in
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingFinalResponseGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'直接给我最终回答。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert ''.join(chunks) == '这是最终回答。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == '这是最终回答。'
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_prefers_chain_end_final_response_over_delegation_chunk(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenFinalResponseGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我安排今天先做什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。' in chunks
|
||||
assert chunks[-1] == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_prefers_model_end_final_content_over_delegation_chunk(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenModelEndGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我安排今天先做什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert '最终建议:先完成对话系统,再回归验证。' in chunks
|
||||
assert chunks[-1] == '最终建议:先完成对话系统,再回归验证。'
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '最终建议:先完成对话系统,再回归验证。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
graph = FakeStreamingOpenAIBadRequestGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
monkeypatch.setattr(agent_service, 'BadRequestError', FakeOpenAIBadRequestError)
|
||||
|
||||
original_get_user_llm_config = AgentService._get_user_llm_config
|
||||
|
||||
async def fake_get_user_llm_config(self, user_id, model_name=None):
|
||||
return {
|
||||
'name': 'Official OpenAI',
|
||||
'provider': 'openai',
|
||||
'model': 'gpt-4o',
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'enabled': True,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(AgentService, '_get_user_llm_config', fake_get_user_llm_config)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'测试官方 OpenAI bad request 不应回退。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
errors = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
if event.get('type') == 'error':
|
||||
errors.append(event['error'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
assert graph.astream_calls == 1
|
||||
assert graph.ainvoke_calls == 0
|
||||
assert errors == ['模型服务暂不可用,请稍后再试。']
|
||||
assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。']
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
fallback_graph = FakeStreamingFallbackGraphGenericError()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: fallback_graph)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'帮我制定一下明天的计划。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert fallback_graph.astream_calls == 1
|
||||
assert fallback_graph.ainvoke_calls == 1
|
||||
assert ''.join(chunks) == '这是通用异常回退后的同步回答。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].content_summary == '这是通用异常回退后的同步回答。'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
|
||||
class PartialThenFailGraph:
|
||||
def __init__(self):
|
||||
self.ainvoke_calls = 0
|
||||
|
||||
async def astream_events(self, state, version='v2'):
|
||||
yield {
|
||||
'event': 'on_chat_model_stream',
|
||||
'name': 'master',
|
||||
'data': {'chunk': SimpleNamespace(content='前半段')},
|
||||
}
|
||||
raise FakeStreamingBadRequestError('stream interrupted')
|
||||
|
||||
async def ainvoke(self, state):
|
||||
self.ainvoke_calls += 1
|
||||
return {'final_response': '不应触发'}
|
||||
|
||||
graph = PartialThenFailGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
monkeypatch.setattr(agent_service, 'BadRequestError', FakeStreamingBadRequestError)
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'测试部分流式输出失败。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
errors = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
if event.get('type') == 'error':
|
||||
errors.append(event['error'])
|
||||
|
||||
message_result = await session.execute(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id, Message.role == 'assistant')
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
assistant_message = message_result.scalars().first()
|
||||
|
||||
brain_event_result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(brain_event_result.scalars().all())
|
||||
|
||||
assert chunks == ['前半段']
|
||||
assert graph.ainvoke_calls == 0
|
||||
assert errors == ['stream interrupted']
|
||||
assert assistant_message is not None
|
||||
assert assistant_message.content == '前半段'
|
||||
assert events[1].content_summary == '前半段'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
graph = CapturingStateGraph()
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
|
||||
service = AgentService(session)
|
||||
|
||||
await service.chat_simple(
|
||||
user.id,
|
||||
'3月29日,对话系统交付节点',
|
||||
)
|
||||
|
||||
assert graph.captured_state is not None
|
||||
current_context = graph.captured_state.get('current_datetime_context')
|
||||
assert isinstance(current_context, str)
|
||||
assert current_context
|
||||
assert '当前时间' in current_context
|
||||
assert '2026' in current_context
|
||||
|
||||
current_reference = graph.captured_state.get('current_datetime_reference')
|
||||
assert isinstance(current_reference, dict)
|
||||
assert 'current_time_iso' in current_reference
|
||||
assert 'current_date_iso' in current_reference
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_llm_config_defaults_to_enabled_chat_model_not_vlm(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
config = await service._get_user_llm_config(user.id)
|
||||
|
||||
assert config is not None
|
||||
assert config['name'] == 'Enabled Chat'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_llm_config_returns_none_when_only_vlm_is_enabled(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
config = await service._get_user_llm_config(user.id)
|
||||
|
||||
assert config is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
|
||||
await service.chat_simple(user.id, '测试聊天模型选择', model_name='Vision Only')
|
||||
|
||||
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
|
||||
message_result = await session.execute(select(Message))
|
||||
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
|
||||
|
||||
assert conversation_result.scalars().all() == []
|
||||
assert message_result.scalars().all() == []
|
||||
assert brain_event_result.scalars().all() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
user.llm_config = {
|
||||
'chat': [
|
||||
{'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
|
||||
],
|
||||
'vlm': [
|
||||
{'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
|
||||
],
|
||||
}
|
||||
await session.commit()
|
||||
|
||||
service = AgentService(session)
|
||||
|
||||
with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
|
||||
await service.chat(user.id, '测试流式聊天模型选择', model_name='Vision Only')
|
||||
|
||||
conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
|
||||
message_result = await session.execute(select(Message))
|
||||
brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
|
||||
|
||||
assert conversation_result.scalars().all() == []
|
||||
assert message_result.scalars().all() == []
|
||||
assert brain_event_result.scalars().all() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.services.admin_bootstrap_service import ensure_builtin_skills
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_builtin_skills_creates_default_ability_skills(tmp_path):
|
||||
db_path = tmp_path / 'test_builtin_skills.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='bootstrap_user',
|
||||
email='bootstrap@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Bootstrap User',
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
await ensure_builtin_skills(session)
|
||||
await ensure_builtin_skills(session)
|
||||
result = await session.execute(select(Skill).order_by(Skill.agent_type, Skill.name))
|
||||
skills = result.scalars().all()
|
||||
|
||||
assert len(skills) >= 9
|
||||
assert any(skill.agent_type == 'schedule_planner' for skill in skills)
|
||||
assert any(skill.agent_type == 'executor' for skill in skills)
|
||||
assert any(skill.agent_type == 'librarian' for skill in skills)
|
||||
librarian_skill = next(skill for skill in skills if skill.name == '知识检索摘要')
|
||||
assert 'web_search' in (librarian_skill.tools or [])
|
||||
assert any(skill.agent_type == 'analyst' for skill in skills)
|
||||
assert len({skill.name for skill in skills}) == len(skills)
|
||||
|
||||
await engine.dispose()
|
||||
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
144
backend/tests/backend/app/services/test_web_search_service.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.web_search_service import (
|
||||
WebSearchConfigurationError,
|
||||
WebSearchRequestError,
|
||||
WebSearchResult,
|
||||
WebSearchService,
|
||||
)
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, payload: dict, status_code: int = 200):
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
raise httpx.HTTPStatusError(
|
||||
'request failed',
|
||||
request=httpx.Request('GET', 'http://searx.example/search'),
|
||||
response=httpx.Response(self.status_code, request=httpx.Request('GET', 'http://searx.example/search')),
|
||||
)
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *, response=None, error=None, recorder=None, **kwargs):
|
||||
self._response = response
|
||||
self._error = error
|
||||
self._recorder = recorder if recorder is not None else []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, *, params=None, headers=None):
|
||||
self._recorder.append({'url': url, 'params': params, 'headers': headers})
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
return self._response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_normalized_results_from_searxng(monkeypatch):
|
||||
requests = []
|
||||
payload = {
|
||||
'results': [
|
||||
{
|
||||
'title': 'Jarvis release notes',
|
||||
'url': 'https://example.com/jarvis-release',
|
||||
'content': 'Latest Jarvis changes and release notes.',
|
||||
'engine': 'duckduckgo',
|
||||
'publishedDate': '2026-03-29',
|
||||
}
|
||||
]
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse(payload), recorder=requests, **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
default_limit=5,
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
results = await service.search('Jarvis 最新版本', limit=3)
|
||||
|
||||
assert results == [
|
||||
WebSearchResult(
|
||||
title='Jarvis release notes',
|
||||
url='https://example.com/jarvis-release',
|
||||
snippet='Latest Jarvis changes and release notes.',
|
||||
source='duckduckgo',
|
||||
published_at='2026-03-29',
|
||||
)
|
||||
]
|
||||
assert requests == [
|
||||
{
|
||||
'url': 'http://searx.example/search',
|
||||
'params': {
|
||||
'q': 'Jarvis 最新版本',
|
||||
'format': 'json',
|
||||
'language': 'zh-CN',
|
||||
'safesearch': 1,
|
||||
},
|
||||
'headers': {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_returns_empty_list_when_searxng_has_no_results(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(response=FakeResponse({'results': []}), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
results = await service.search('不存在的话题')
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_on_searxng_http_failure(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
'app.services.web_search_service.httpx.AsyncClient',
|
||||
lambda **kwargs: FakeAsyncClient(error=httpx.TimeoutException('timed out'), **kwargs),
|
||||
)
|
||||
|
||||
service = WebSearchService(
|
||||
enabled=True,
|
||||
provider='searxng',
|
||||
base_url='http://searx.example',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchRequestError, match='SearxNG 请求失败'):
|
||||
await service.search('Jarvis')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_search_service_raises_clear_error_when_not_configured():
|
||||
service = WebSearchService(
|
||||
enabled=False,
|
||||
provider='searxng',
|
||||
base_url='',
|
||||
)
|
||||
|
||||
with pytest.raises(WebSearchConfigurationError, match='网页搜索未启用或未配置'):
|
||||
await service.search('Jarvis')
|
||||
150
backend/tests/backend/app/test_agent_router.py
Normal file
150
backend/tests/backend/app/test_agent_router.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.agent import router as agent_router
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_env(tmp_path):
|
||||
db_path = tmp_path / 'test_agent_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='agent_user',
|
||||
email='agent@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Agent Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
skill_a = Skill(
|
||||
name='Planner skill A',
|
||||
description='planner',
|
||||
instructions='plan a',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
skill_b = Skill(
|
||||
name='Planner skill B',
|
||||
description='planner',
|
||||
instructions='plan b',
|
||||
agent_type='schedule_planner',
|
||||
tools=['tasks'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
)
|
||||
session.add_all([
|
||||
Agent(
|
||||
name='SCHEDULE PLANNER',
|
||||
role='schedule_planner',
|
||||
description='日程规划师',
|
||||
system_prompt='prompt',
|
||||
is_active=True,
|
||||
),
|
||||
skill_a,
|
||||
skill_b,
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(skill_a)
|
||||
await session.refresh(skill_b)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(agent_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, {'skill_a_id': skill_a.id, 'skill_b_id': skill_b.id}
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_config_returns_default_empty_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['selected_skill_ids'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_persists_selected_skill_ids(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id'], ids['skill_b_id']]},
|
||||
)
|
||||
get_response = await client.get('/api/agents/config/schedule_planner')
|
||||
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()['selected_skill_ids'] == [ids['skill_a_id'], ids['skill_b_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_preserves_selected_skill_ids_when_omitted(agent_env):
|
||||
app, ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
first_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': [ids['skill_a_id']]},
|
||||
)
|
||||
update_response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'description': 'updated description'},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()['description'] == 'updated description'
|
||||
assert update_response.json()['selected_skill_ids'] == [ids['skill_a_id']]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_config_rejects_invalid_selected_skill_ids(agent_env):
|
||||
app, _ids = agent_env
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.put(
|
||||
'/api/agents/config/schedule_planner',
|
||||
json={'selected_skill_ids': ['missing-skill']},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()['detail'] == '存在无效的技能绑定'
|
||||
67
backend/tests/backend/app/test_config.py
Normal file
67
backend/tests/backend/app/test_config.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
from app import config as config_module
|
||||
from app.services.llm_service import default_provider_capabilities, normalize_provider_name, resolve_provider_capabilities
|
||||
|
||||
|
||||
def test_env_file_points_to_repo_root_env_file():
|
||||
assert config_module.ENV_FILE == Path(__file__).resolve().parents[4] / '.env'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_prefers_native_for_openai():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'gpt-4o', 'base_url': 'https://api.openai.com/v1'})
|
||||
|
||||
assert capabilities.provider == 'openai'
|
||||
assert capabilities.supports_native_tools is True
|
||||
assert capabilities.preferred_tool_strategy == 'native'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_openai_compatible_non_official_endpoint():
|
||||
capabilities = resolve_provider_capabilities(
|
||||
{'provider': 'openai', 'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}
|
||||
)
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_uses_global_openai_base_url_when_user_config_omits_it(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'OPENAI_BASE_URL', 'https://api.minimax.chat/v1')
|
||||
|
||||
capabilities = resolve_provider_capabilities({'provider': 'openai', 'model': 'abab7.5-chat-preview'})
|
||||
|
||||
assert capabilities.provider == 'minimax'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_from_custom_config():
|
||||
assert normalize_provider_name({'provider': 'custom', 'model': 'MiniMax-M2.7-highspeed'}) == 'minimax'
|
||||
|
||||
|
||||
def test_normalize_provider_name_recognizes_minimax_without_provider_when_base_url_matches():
|
||||
assert normalize_provider_name({'model': 'abab7.5-chat-preview', 'base_url': 'https://api.minimax.chat/v1'}) == 'minimax'
|
||||
|
||||
|
||||
def test_resolve_provider_capabilities_falls_back_for_ollama():
|
||||
capabilities = resolve_provider_capabilities({'provider': 'ollama', 'model': 'qwen2.5'})
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.supports_native_tools is False
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_default_provider_capabilities_follows_global_settings(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
capabilities = default_provider_capabilities()
|
||||
|
||||
assert capabilities.provider == 'ollama'
|
||||
assert capabilities.preferred_tool_strategy == 'json_fallback'
|
||||
|
||||
|
||||
def test_normalize_provider_name_without_provider_uses_global_default(monkeypatch):
|
||||
monkeypatch.setattr(config_module.settings, 'LLM_PROVIDER', 'ollama')
|
||||
|
||||
assert normalize_provider_name({'model': 'qwen2.5'}) == 'ollama'
|
||||
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
281
backend/tests/backend/app/test_schedule_center_router.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import sys
|
||||
from datetime import UTC, date, datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.goal import Goal
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.task import Task, TaskPriority, TaskStatus
|
||||
from app.models.todo import DailyTodo, TodoSource
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.goal import router as goal_router
|
||||
from app.routers.reminder import router as reminder_router
|
||||
from app.routers.schedule_center import router as schedule_center_router
|
||||
from app.routers.task import router as task_router
|
||||
from app.routers.todo import router as todo_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def schedule_env(tmp_path):
|
||||
db_path = tmp_path / 'test_schedule_center.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='schedule_user',
|
||||
email='schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Schedule Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_schedule_user',
|
||||
email='other-schedule@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Schedule Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Legacy todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=user.id,
|
||||
title='Done todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=True,
|
||||
completed_at=datetime(2026, 4, 10, 9, 30, tzinfo=UTC),
|
||||
),
|
||||
DailyTodo(
|
||||
user_id=other_user.id,
|
||||
title='Other user todo',
|
||||
source=TodoSource.MANUAL,
|
||||
todo_date='2026-04-10',
|
||||
is_completed=False,
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='High priority task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 14, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=user.id,
|
||||
title='Urgent task next day',
|
||||
priority=TaskPriority.URGENT,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
due_date=datetime(2026, 4, 11, 10, 0, tzinfo=UTC),
|
||||
),
|
||||
Task(
|
||||
user_id=other_user.id,
|
||||
title='Other user task',
|
||||
priority=TaskPriority.HIGH,
|
||||
status=TaskStatus.TODO,
|
||||
due_date=datetime(2026, 4, 10, 15, 0, tzinfo=UTC),
|
||||
),
|
||||
Reminder(
|
||||
user_id=user.id,
|
||||
title='Doctor reminder',
|
||||
note='Bring reports',
|
||||
reminder_at=datetime(2026, 4, 10, 8, 0, tzinfo=UTC),
|
||||
),
|
||||
Goal(
|
||||
user_id=user.id,
|
||||
title='Launch calendar beta',
|
||||
note='Ship MVP',
|
||||
goal_date='2026-04-10',
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(todo_router)
|
||||
test_app.include_router(task_router)
|
||||
test_app.include_router(reminder_router)
|
||||
test_app.include_router(goal_router)
|
||||
test_app.include_router(schedule_center_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_todo_persists_explicit_todo_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post('/api/todos', json={'title': 'Plan sprint', 'todo_date': '2026-04-12'})
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Plan sprint'
|
||||
assert payload['todo_date'] == '2026-04-12'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_todo_allows_editing_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.patch(f'/api/todos/{todo_id}', json={'title': 'Updated title', 'todo_date': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['title'] == 'Updated title'
|
||||
assert payload['todo_date'] == '2026-04-11'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_todo_allows_deleting_non_today_todo(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
todos_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
todo_id = todos_response.json()['items'][0]['id']
|
||||
response = await client.delete(f'/api/todos/{todo_id}')
|
||||
after_response = await client.get('/api/todos', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 204
|
||||
assert all(item['id'] != todo_id for item in after_response.json()['items'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'due_date': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['title'] for item in payload] == ['High priority task']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filters_by_due_date_range(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/tasks', params={'date_from': '2026-04-10', 'date_to': '2026-04-11'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert {item['title'] for item in payload} == {'High priority task', 'Urgent task next day'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_date_returns_aggregated_resources(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/date', params={'date_str': '2026-04-10'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['date'] == '2026-04-10'
|
||||
assert payload['summary'] == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert [item['title'] for item in payload['reminders']] == ['Doctor reminder']
|
||||
assert [item['title'] for item in payload['goals']] == ['Launch calendar beta']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schedule_center_month_returns_day_summaries(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/schedule-center/month', params={'year': 2026, 'month': 4})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['month'] == '2026-04'
|
||||
day_10 = next(item for item in payload['days'] if item['date'] == '2026-04-10')
|
||||
day_11 = next(item for item in payload['days'] if item['date'] == '2026-04-11')
|
||||
assert day_10 == {
|
||||
'date': '2026-04-10',
|
||||
'todo_total': 2,
|
||||
'todo_completed': 1,
|
||||
'task_due_total': 1,
|
||||
'high_priority_total': 1,
|
||||
'reminder_total': 1,
|
||||
'goal_total': 1,
|
||||
}
|
||||
assert day_11['task_due_total'] == 1
|
||||
assert day_11['high_priority_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_reminder_with_naive_datetime_and_time_zone_appears_in_schedule_center(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
create_response = await client.post(
|
||||
'/api/reminders',
|
||||
json={'title': '收被子', 'note': '提醒收被子', 'reminder_at': '2026-03-29T09:00:00'},
|
||||
)
|
||||
detail_response = await client.get('/api/schedule-center/date', params={'date_str': '2026-03-29'})
|
||||
|
||||
assert create_response.status_code == 201
|
||||
assert detail_response.status_code == 200
|
||||
payload = detail_response.json()
|
||||
assert [item['title'] for item in payload['reminders']] == ['收被子']
|
||||
assert payload['summary']['reminder_total'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reminder_and_goal_crud(schedule_env):
|
||||
transport = ASGITransport(app=schedule_env)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
reminder_response = await client.post('/api/reminders', json={'title': 'Standup', 'note': 'Daily sync', 'reminder_at': '2026-04-12T09:00:00Z'})
|
||||
goal_response = await client.post('/api/goals', json={'title': 'Finish polish', 'note': 'UI cleanup', 'goal_date': '2026-04-12', 'status': 'active'})
|
||||
reminder_id = reminder_response.json()['id']
|
||||
goal_id = goal_response.json()['id']
|
||||
patch_reminder = await client.patch(f'/api/reminders/{reminder_id}', json={'status': 'done'})
|
||||
patch_goal = await client.patch(f'/api/goals/{goal_id}', json={'status': 'done'})
|
||||
reminders_list = await client.get('/api/reminders', params={'date_str': '2026-04-12'})
|
||||
goals_list = await client.get('/api/goals', params={'date_str': '2026-04-12'})
|
||||
delete_reminder = await client.delete(f'/api/reminders/{reminder_id}')
|
||||
delete_goal = await client.delete(f'/api/goals/{goal_id}')
|
||||
|
||||
assert reminder_response.status_code == 201
|
||||
assert goal_response.status_code == 201
|
||||
assert patch_reminder.json()['status'] == 'done'
|
||||
assert patch_goal.json()['status'] == 'done'
|
||||
assert [item['title'] for item in reminders_list.json()['items']] == ['Standup']
|
||||
assert [item['title'] for item in goals_list.json()['items']] == ['Finish polish']
|
||||
assert delete_reminder.status_code == 204
|
||||
assert delete_goal.status_code == 204
|
||||
190
backend/tests/backend/app/test_skill_router.py
Normal file
190
backend/tests/backend/app/test_skill_router.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.skill import Skill
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.skill import router as skill_router
|
||||
from app.routers.auth import router as auth_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def skill_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_skill_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
# Ensure app.database.init_db() runs against the test database.
|
||||
import app.database as database_module
|
||||
|
||||
monkeypatch.setattr(database_module, "engine", engine)
|
||||
monkeypatch.setattr(database_module, "async_session", session_factory)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
username='skill_user',
|
||||
email='skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Skill Tester',
|
||||
)
|
||||
other_user = User(
|
||||
username='other_skill_user',
|
||||
email='other-skill@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Other Skill Tester',
|
||||
)
|
||||
session.add_all([user, other_user])
|
||||
await session.flush()
|
||||
session.add_all([
|
||||
Skill(
|
||||
name='Planner skill',
|
||||
description='planner',
|
||||
instructions='plan',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Executor skill',
|
||||
description='executor',
|
||||
instructions='execute',
|
||||
agent_type='executor',
|
||||
tools=['shell'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=user.id,
|
||||
),
|
||||
Skill(
|
||||
name='Other user planner skill',
|
||||
description='other',
|
||||
instructions='other',
|
||||
agent_type='schedule_planner',
|
||||
tools=['calendar'],
|
||||
required_context=[],
|
||||
visibility='private',
|
||||
is_active=True,
|
||||
owner_id=other_user.id,
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(auth_router)
|
||||
test_app.include_router(skill_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app, session_factory
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_filters_by_agent_type(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'agent_type': 'schedule_planner'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_migrates_planner_skills_to_schedule_planner(skill_env):
|
||||
app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("UPDATE skills SET agent_type = 'planner' WHERE name = 'Planner skill'"))
|
||||
await session.commit()
|
||||
|
||||
from app.database import init_db
|
||||
|
||||
await init_db()
|
||||
|
||||
async with session_factory() as session:
|
||||
migrated_response = await session.execute(text("SELECT agent_type FROM skills WHERE name = 'Planner skill'"))
|
||||
assert migrated_response.scalar_one() == 'schedule_planner'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_visibility_filter_still_respects_access_scope(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills', params={'visibility': 'private'})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_bootstraps_builtin_market_skills_for_current_user(skill_env):
|
||||
test_app, session_factory = skill_env
|
||||
async with session_factory() as session:
|
||||
await session.execute(text("DELETE FROM skills"))
|
||||
await session.commit()
|
||||
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
login_response = await client.post(
|
||||
'/api/auth/login',
|
||||
data={'username': 'skill_user', 'password': 'secret123'},
|
||||
headers={'content-type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
|
||||
assert login_response.status_code == 200
|
||||
|
||||
async with session_factory() as session:
|
||||
result = await session.execute(select(Skill.name, Skill.is_builtin).order_by(Skill.name))
|
||||
skills = result.all()
|
||||
|
||||
names = {name for name, _is_builtin in skills}
|
||||
assert '今日重点拆解' in names
|
||||
assert '任务执行 SOP' in names
|
||||
assert any(is_builtin is True for _name, is_builtin in skills)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_skills_without_agent_type_returns_current_user_skills(skill_env):
|
||||
test_app, _session_factory = skill_env
|
||||
transport = ASGITransport(app=test_app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/skills')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
names = {item['name'] for item in payload}
|
||||
assert names == {'Planner skill', 'Executor skill'}
|
||||
assert 'Other user planner skill' not in names
|
||||
assert all(isinstance(item['created_at'], str) for item in payload)
|
||||
assert all(isinstance(item['updated_at'], str) for item in payload)
|
||||
assert all('is_builtin' in item for item in payload)
|
||||
assert all(item['is_builtin'] is False for item in payload)
|
||||
Reference in New Issue
Block a user