697 lines
25 KiB
Python
697 lines
25 KiB
Python
from io import BytesIO
|
||
from types import SimpleNamespace
|
||
|
||
import pytest
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||
from starlette.datastructures import UploadFile
|
||
|
||
import app.models # noqa: F401
|
||
from app.database import Base
|
||
from app.models.brain import BrainEvent, BrainMemory
|
||
from app.models.conversation import Conversation, Message
|
||
from app.models.memory import MemorySummary, UserMemory
|
||
from app.models.user import User
|
||
from app.services import agent_service, memory_service
|
||
from app.services.agent_service import AgentService
|
||
from app.services.auth_service import get_password_hash
|
||
from app.services.document_service import DocumentService
|
||
|
||
|
||
class FakeGraph:
    """Minimal stand-in for the compiled agent graph used by ``chat_simple``.

    ``ainvoke`` ignores the incoming state and always answers with the same
    canned ``final_response``.
    """

    async def ainvoke(self, state):
        canned_reply = "已记录你的请求。"
        return {"final_response": canned_reply}
|
||
|
||
|
||
class FakeStreamingGraph:
    """Fake graph whose event stream is a single ``master`` model-stream chunk."""

    async def astream_events(self, state, version="v2"):
        chunk = SimpleNamespace(content="这是流式回复。")
        stream_event = {
            "event": "on_chat_model_stream",
            "name": "master",
            "data": {"chunk": chunk},
        }
        yield stream_event
|
||
|
||
|
||
class FakeStreamingFinalResponseGraph:
    """Fake graph that emits no model chunks, only a ``master`` chain-end event.

    The answer text exists only in the chain-end output's ``final_response``.
    """

    async def astream_events(self, state, version="v2"):
        chain_output = {"final_response": "这是最终回答。"}
        yield {
            "event": "on_chain_end",
            "name": "master",
            "data": {"output": chain_output},
        }
|
||
|
||
|
||
class FakeStreamingBadRequestError(Exception):
    """Stand-in for a provider "bad request" error raised while streaming.

    Raised by ``FakeStreamingFallbackGraph`` and by the partial-output graph in
    ``test_streaming_chat_does_not_fall_back_after_partial_stream_output``.
    That test also patches it in as ``agent_service.BadRequestError``.
    """
|
||
|
||
|
||
class FakeStreamingBadRequestError2(Exception):
    """Second, independent exception type for streaming failures.

    ``FakeStreamingFallbackGraphGenericError`` raises it with a message shaped
    like an HTTP 400 error body. The tests in this file never patch it in as
    ``agent_service.BadRequestError``.
    """
|
||
|
||
|
||
class FakeOpenAIBadRequestError(Exception):
    """Substitute for the ``BadRequestError`` class referenced by ``agent_service``.

    Presumably that class is the OpenAI SDK's error; confirm against
    ``agent_service``. Tests patch this class in as
    ``agent_service.BadRequestError``, and ``FakeStreamingOpenAIBadRequestGraph``
    raises it.
    """
|
||
|
||
|
||
class FakeStreamingOpenAIBadRequestGraph:
    """Fake graph whose stream fails at once with an OpenAI-style bad request.

    Both entry points count their invocations, so tests can check that the
    service did not fall back to ``ainvoke``.
    """

    def __init__(self):
        self.astream_calls, self.ainvoke_calls = 0, 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        raise FakeOpenAIBadRequestError('invalid_request_error: tool arguments failed validation')
        # Unreachable, but the ``yield`` makes this an async generator. The
        # counter bump and the raise therefore happen on first iteration,
        # not when the method is called.
        yield

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return {"final_response": "不应触发同步回退。"}
|
||
|
||
|
||
class FakeStreamingFallbackGraph:
    """Fake graph whose stream fails with an "invalid chat setting (2013)" error.

    ``ainvoke`` returns a canned answer. Both entry points count their calls.
    """

    def __init__(self):
        self.astream_calls = self.ainvoke_calls = 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        raise FakeStreamingBadRequestError('invalid params, invalid chat setting (2013)')
        yield  # unreachable; its presence turns this method into an async generator

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return {"final_response": "这是回退后的同步回答。"}
|
||
|
||
|
||
class FakeStreamingFallbackGraphGenericError:
    """Fake graph whose stream fails with a generic exception carrying a 400 body.

    The message imitates a provider's HTTP 400 error payload. ``ainvoke``
    returns a canned answer, and both entry points count their calls.
    """

    def __init__(self):
        self.astream_calls, self.ainvoke_calls = 0, 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        raise FakeStreamingBadRequestError2("Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}")
        yield  # unreachable; keeps this method an async generator

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return {"final_response": "这是通用异常回退后的同步回答。"}
|
||
|
||
|
||
class FakeStreamingDelegationThenFinalResponseGraph:
    """Fake graph where the master hands off and a delegate supplies the answer.

    The stream has two events:
    1. A ``master`` ``on_chat_model_stream`` chunk announcing delegation to
       ``schedule_planner``.
    2. A ``schedule_planner`` ``on_chain_end`` event whose output carries the
       real ``final_response``.
    """

    async def astream_events(self, state, version="v2"):
        handoff_chunk = SimpleNamespace(content="现在显示收到,3月28日的任务是完成对话系统。\n\n我将这部分转给schedule_planner,他会根据这个目标结合你当前的进度和资源,给出具体的安排建议。")
        yield {"event": "on_chat_model_stream", "name": "master", "data": {"chunk": handoff_chunk}}

        planner_output = {"final_response": "今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。"}
        yield {"event": "on_chain_end", "name": "schedule_planner", "data": {"output": planner_output}}
|
||
|
||
|
||
class FakeStreamingDelegationThenModelEndGraph:
    """Fake graph where the master hands off and the delegate's model-end has the answer.

    The stream has two events:
    1. A ``master`` ``on_chat_model_stream`` hand-off chunk.
    2. A ``schedule_planner`` ``on_chat_model_end`` event whose output message
       ``content`` is the final suggestion.
    """

    async def astream_events(self, state, version="v2"):
        handoff = SimpleNamespace(content="我将这部分转给schedule_planner。")
        yield {"event": "on_chat_model_stream", "name": "master", "data": {"chunk": handoff}}

        final_message = SimpleNamespace(content="最终建议:先完成对话系统,再回归验证。")
        yield {"event": "on_chat_model_end", "name": "schedule_planner", "data": {"output": final_message}}
|
||
|
||
|
||
class CapturingStateGraph:
    """Fake graph that remembers the state passed to its most recent ``ainvoke``.

    Args:
        final_response: Text returned as ``final_response`` on every call.
    """

    def __init__(self, final_response: str = '已记录你的请求。'):
        self.captured_state = None
        self.final_response = final_response

    async def ainvoke(self, state):
        # Store the state so a test can inspect what the service built.
        self.captured_state = state
        reply = self.final_response
        return {"final_response": reply}
|
||
|
||
|
||
@pytest.fixture
async def brain_ingestion_env(tmp_path, monkeypatch):
    """Yield ``(session, user)`` backed by a fresh per-test SQLite database.

    Setup steps:
    - Create every table in ``Base.metadata`` inside a database file under
      ``tmp_path``.
    - Insert one ``User``.
    - Replace ``get_agent_graph``, ``set_current_user`` and
      ``clear_current_user`` in ``agent_service`` with fakes or no-ops.
    - Point the document service's ``UPLOAD_DIR`` at ``tmp_path / 'uploads'``.

    Note: ``user`` is created and committed in a separate session that is
    closed before the yielded session opens, so it is detached from the
    yielded ``session``. Assigning attributes on it and then calling
    ``session.commit()`` does not persist them. Load an attached copy
    (e.g. via ``session.get``) before modifying the row.

    The engine is disposed in a ``finally`` block. This releases it even when
    setup fails before the ``yield``; pytest runs no post-yield teardown in
    that case.
    """
    db_path = tmp_path / 'test_brain_ingestion.db'
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with session_factory() as session:
            user = User(
                username='brain-ingestion-tester',
                email='brain-ingestion@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Brain Ingestion Tester',
            )
            session.add(user)
            await session.commit()
            await session.refresh(user)

        # Default fakes. Individual tests re-patch get_agent_graph when they
        # need a different graph.
        monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
        monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
        monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
        monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))

        async with session_factory() as session:
            yield session, user
    finally:
        await engine.dispose()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
    """``chat_simple`` records exactly one pending brain event for the user's turn."""
    session, user = brain_ingestion_env
    prompt = '请记住我这周要完成知识大脑第一阶段。'
    service = AgentService(session)

    conversation_id, _, _, _ = await service.chat_simple(user.id, prompt)

    conversation_events_query = (
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    rows = (await session.execute(conversation_events_query)).scalars().all()
    user_events = [row for row in rows if row.metadata_ == {'role': 'user'}]

    assert len(user_events) == 1
    event = user_events[0]
    assert event.source_id == conversation_id
    assert event.event_type == 'message_created'
    assert event.title == 'User message'
    assert '知识大脑第一阶段' in (event.content_summary or '')
    assert event.status == 'pending'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
    """Uploading a document records one pending ``document_uploaded`` brain event."""
    session, user = brain_ingestion_env
    service = DocumentService(session)
    payload = '# Brain\n\nCapture important product knowledge.'.encode('utf-8')
    upload = UploadFile(filename='brain-notes.md', file=BytesIO(payload))

    document = await service.upload_document(user.id, upload)

    lookup = select(BrainEvent).where(
        BrainEvent.user_id == user.id,
        BrainEvent.source_type == 'document',
        BrainEvent.source_id == document.id,
    )
    event = (await session.execute(lookup)).scalar_one_or_none()

    expected_metadata = {
        'document_id': document.id,
        'file_type': 'md',
        'ingestion_status': 'uploaded',
    }
    assert event is not None
    assert event.event_type == 'document_uploaded'
    assert event.title == 'brain-notes.md'
    assert 'Capture important product knowledge.' in (event.content_summary or '')
    assert event.metadata_ == expected_metadata
    assert event.status == 'pending'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
    """``chat_simple`` records a brain event for the assistant's reply.

    The conversation should produce exactly two events: the user turn and the
    assistant turn. The assistant event is picked by its ``role`` metadata,
    not by list position, so the test does not depend on how ``created_at``
    ties are ordered if both events share a timestamp.
    """
    session, user = brain_ingestion_env
    service = AgentService(session)

    conversation_id, _message_id, response, _model_name = await service.chat_simple(
        user.id,
        '帮我总结今天知识大脑的进展。',
    )

    result = await session.execute(
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list(result.scalars().all())
    assistant_events = [event for event in events if event.metadata_ == {'role': 'assistant'}]

    assert len(events) == 2
    assert len(assistant_events) == 1
    assistant_event = assistant_events[0]
    assert assistant_event.source_id == conversation_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == response
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
    """Streaming ``chat`` records a brain event containing the streamed reply.

    The assistant event is picked by its ``role`` metadata rather than by
    position, so a ``created_at`` tie between the user and assistant events
    cannot make the test flaky.
    """
    session, user = brain_ingestion_env
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(
        user.id,
        '用流式回复告诉我今天知识大脑学到了什么。',
    )

    chunks = []
    async for event in stream:
        if event.get('type') == 'chunk':
            chunks.append(event['content'])

    result = await session.execute(
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list(result.scalars().all())
    assistant_events = [event for event in events if event.metadata_ == {'role': 'assistant'}]

    assert ''.join(chunks) == '这是流式回复。'
    assert len(events) == 2
    assert len(assistant_events) == 1
    assistant_event = assistant_events[0]
    assert assistant_event.source_id == conversation_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == '这是流式回复。'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch):
    """With no model chunks, the chain-end ``final_response`` is streamed and recorded.

    The assistant event is picked by its ``role`` metadata rather than by
    position, so a ``created_at`` tie cannot reorder the events under test.
    """
    session, user = brain_ingestion_env
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingFinalResponseGraph())
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(
        user.id,
        '直接给我最终回答。',
    )

    chunks = []
    async for event in stream:
        if event.get('type') == 'chunk':
            chunks.append(event['content'])

    result = await session.execute(
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list(result.scalars().all())
    assistant_events = [event for event in events if event.metadata_ == {'role': 'assistant'}]

    assert ''.join(chunks) == '这是最终回答。'
    assert len(events) == 2
    assert len(assistant_events) == 1
    assistant_event = assistant_events[0]
    assert assistant_event.source_id == conversation_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == '这是最终回答。'
|
||
|
||
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_prefers_chain_end_final_response_over_delegation_chunk(brain_ingestion_env, monkeypatch):
    """A delegate's chain-end ``final_response`` wins over the master's hand-off chunk.

    That text must be the last streamed chunk and the content of the newest
    persisted assistant message.
    """
    session, user = brain_ingestion_env
    expected = '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenFinalResponseGraph())
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(user.id, '帮我安排今天先做什么。')
    chunks = [event['content'] async for event in stream if event.get('type') == 'chunk']

    newest_assistant_query = (
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    latest_assistant = (await session.execute(newest_assistant_query)).scalars().first()

    assert expected in chunks
    assert chunks[-1] == expected
    assert latest_assistant is not None
    assert latest_assistant.content == expected
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_prefers_model_end_final_content_over_delegation_chunk(brain_ingestion_env, monkeypatch):
    """A delegate's model-end content wins over the master's hand-off chunk.

    That content must be the last streamed chunk and the content of the newest
    persisted assistant message.
    """
    session, user = brain_ingestion_env
    expected = '最终建议:先完成对话系统,再回归验证。'
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingDelegationThenModelEndGraph())
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(user.id, '帮我安排今天先做什么。')
    chunks = [event['content'] async for event in stream if event.get('type') == 'chunk']

    newest_assistant_query = (
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    latest_assistant = (await session.execute(newest_assistant_query)).scalars().first()

    assert expected in chunks
    assert chunks[-1] == expected
    assert latest_assistant is not None
    assert latest_assistant.content == expected
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch):
    """An official-OpenAI ``BadRequestError`` surfaces as an error and is not retried.

    Setup: the user's LLM config is stubbed to the official OpenAI endpoint,
    and the graph's stream raises the patched ``BadRequestError`` before
    producing any output.

    Expected: the service does not fall back to ``ainvoke``. It emits one
    generic error event, streams the apology text, and persists that text as
    the assistant message.
    """
    session, user = brain_ingestion_env
    graph = FakeStreamingOpenAIBadRequestGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    monkeypatch.setattr(agent_service, 'BadRequestError', FakeOpenAIBadRequestError)

    async def fake_get_user_llm_config(self, user_id, model_name=None):
        # Pretend the user selected the official OpenAI endpoint.
        return {
            'name': 'Official OpenAI',
            'provider': 'openai',
            'model': 'gpt-4o',
            'base_url': 'https://api.openai.com/v1',
            'enabled': True,
        }

    # monkeypatch restores the real method at teardown, so no manual backup is needed.
    monkeypatch.setattr(AgentService, '_get_user_llm_config', fake_get_user_llm_config)
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(
        user.id,
        '测试官方 OpenAI bad request 不应回退。',
    )

    chunks = []
    errors = []
    async for event in stream:
        if event.get('type') == 'chunk':
            chunks.append(event['content'])
        if event.get('type') == 'error':
            errors.append(event['error'])

    message_result = await session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    assistant_message = message_result.scalars().first()

    assert graph.astream_calls == 1
    assert graph.ainvoke_calls == 0
    assert errors == ['模型服务暂不可用,请稍后再试。']
    assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。']
    assert assistant_message is not None
    assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_ingestion_env, monkeypatch):
    """A generic 400-style error before any output triggers the ``ainvoke`` fallback.

    The raised exception is not the service's ``BadRequestError``, since this
    test does not patch it. The fallback answer must be streamed and recorded
    as the second conversation brain event.
    """
    session, user = brain_ingestion_env
    fallback_graph = FakeStreamingFallbackGraphGenericError()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: fallback_graph)
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(user.id, '帮我制定一下明天的计划。')
    chunks = [event['content'] async for event in stream if event.get('type') == 'chunk']

    conversation_events_query = (
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list((await session.execute(conversation_events_query)).scalars().all())

    expected_answer = '这是通用异常回退后的同步回答。'
    assert fallback_graph.astream_calls == 1
    assert fallback_graph.ainvoke_calls == 1
    assert ''.join(chunks) == expected_answer
    assert len(events) == 2
    assert events[1].source_id == conversation_id
    assert events[1].content_summary == expected_answer
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch):
    """Once some output has streamed, a later failure must not trigger ``ainvoke``.

    The partial text is kept as the assistant message and brain-event summary,
    and the error message is surfaced as an error event.
    """
    session, user = brain_ingestion_env

    class PartialThenFailGraph:
        """Streams one ``master`` chunk, then raises ``FakeStreamingBadRequestError``."""

        def __init__(self):
            self.ainvoke_calls = 0

        async def astream_events(self, state, version='v2'):
            partial = SimpleNamespace(content='前半段')
            yield {'event': 'on_chat_model_stream', 'name': 'master', 'data': {'chunk': partial}}
            raise FakeStreamingBadRequestError('stream interrupted')

        async def ainvoke(self, state):
            self.ainvoke_calls += 1
            return {'final_response': '不应触发'}

    graph = PartialThenFailGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    monkeypatch.setattr(agent_service, 'BadRequestError', FakeStreamingBadRequestError)
    service = AgentService(session)

    conversation_id, _message_id, stream = await service.chat(user.id, '测试部分流式输出失败。')

    chunks, errors = [], []
    async for event in stream:
        event_kind = event.get('type')
        if event_kind == 'chunk':
            chunks.append(event['content'])
        elif event_kind == 'error':
            errors.append(event['error'])

    assistant_query = (
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    assistant_message = (await session.execute(assistant_query)).scalars().first()

    events_query = (
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list((await session.execute(events_query)).scalars().all())

    assert chunks == ['前半段']
    assert graph.ainvoke_calls == 0
    assert errors == ['stream interrupted']
    assert assistant_message is not None
    assert assistant_message.content == '前半段'
    assert events[1].content_summary == '前半段'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch):
    """``chat_simple`` injects current date/time context into the graph state.

    The year check is derived from the clock instead of a hard-coded year, so
    the test keeps passing after the calendar year changes. A window of ±1 day
    around UTC "now" covers any real-world UTC offset (at most ±14h), whatever
    timezone the service formats in.
    """
    from datetime import datetime, timedelta, timezone

    session, user = brain_ingestion_env
    graph = CapturingStateGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    service = AgentService(session)

    await service.chat_simple(
        user.id,
        '3月29日,对话系统交付节点',
    )

    assert graph.captured_state is not None
    current_context = graph.captured_state.get('current_datetime_context')
    assert isinstance(current_context, str)
    assert current_context
    assert '当前时间' in current_context

    now_utc = datetime.now(timezone.utc)
    plausible_years = {str((now_utc + timedelta(days=offset)).year) for offset in (-1, 0, 1)}
    assert any(year in current_context for year in plausible_years), current_context

    current_reference = graph.captured_state.get('current_datetime_reference')
    assert isinstance(current_reference, dict)
    assert 'current_time_iso' in current_reference
    assert 'current_date_iso' in current_reference
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_user_llm_config_defaults_to_enabled_chat_model_not_vlm(brain_ingestion_env):
    """With no explicit model, an enabled chat model is chosen, never a VLM entry.

    The fixture's ``user`` was committed in another, already-closed session,
    so it is detached from ``session``. Assigning ``llm_config`` on it and
    committing would persist nothing. Load the row through ``session`` and
    update that attached instance so the config actually reaches the database.
    """
    session, user = brain_ingestion_env
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)

    config = await service._get_user_llm_config(user.id)

    assert config is not None
    assert config['name'] == 'Enabled Chat'
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_user_llm_config_returns_none_when_only_vlm_is_enabled(brain_ingestion_env):
    """If every chat model is disabled, an enabled VLM is not used as a fallback.

    The config is written through an instance attached to ``session``. The
    fixture's ``user`` is detached, so modifying it directly would leave the
    database untouched, and the test would pass without exercising this case.
    """
    session, user = brain_ingestion_env
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
        ],
        'vlm': [
            {'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)

    config = await service._get_user_llm_config(user.id)

    assert config is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_simple_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
    """Selecting a VLM-only model in ``chat_simple`` raises before anything is stored.

    The config is written through an instance attached to ``session`` because
    the fixture's ``user`` is detached. Afterwards no conversation, message,
    or brain event may exist.
    """
    session, user = brain_ingestion_env
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)

    with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
        await service.chat_simple(user.id, '测试聊天模型选择', model_name='Vision Only')

    # Use the detached ``user.id`` (not ``db_user``) in case the service
    # rolled back and expired the attached instance.
    conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
    message_result = await session.execute(select(Message))
    brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))

    assert conversation_result.scalars().all() == []
    assert message_result.scalars().all() == []
    assert brain_event_result.scalars().all() == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_streaming_chat_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
    """Selecting a VLM-only model in streaming ``chat`` raises before anything is stored.

    The config is written through an instance attached to ``session`` because
    the fixture's ``user`` is detached. Afterwards no conversation, message,
    or brain event may exist.
    """
    session, user = brain_ingestion_env
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)

    with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
        await service.chat(user.id, '测试流式聊天模型选择', model_name='Vision Only')

    # Use the detached ``user.id`` (not ``db_user``) in case the service
    # rolled back and expired the attached instance.
    conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
    message_result = await session.execute(select(Message))
    brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))

    assert conversation_result.scalars().all() == []
    assert message_result.scalars().all() == []
    assert brain_event_result.scalars().all() == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
    """``build_memory_context`` renders memories, summaries and a knowledge-brain section.

    Two active ``project_fact`` brain memories are seeded. The one sourced from
    ``conversation``/``document`` must appear in the context. The
    ``forum``-sourced one must be absent, even though its importance is higher.
    """
    session, user = brain_ingestion_env

    conversation = Conversation(user_id=user.id, title='Brain context test')
    session.add(conversation)
    await session.flush()  # populate conversation.id for the rows below

    session.add_all([
        UserMemory(
            user_id=user.id,
            memory_type='preference',
            content='用户偏好结构化输出。',
            importance=6,
            source_conversation_id=conversation.id,
        ),
        MemorySummary(
            user_id=user.id,
            conversation_id=conversation.id,
            summary_text='之前讨论了知识大脑的整体设计。',
            turn_count=8,
        ),
        BrainMemory(
            user_id=user.id,
            memory_type='project_fact',
            title='Knowledge brain phase 1',
            content='Jarvis should learn from conversation and document events first.',
            importance=9,
            confidence=0.93,
            status='active',
            origin_source_types=['conversation', 'document'],
            metadata_={'source_count': 2},
        ),
        BrainMemory(
            user_id=user.id,
            memory_type='project_fact',
            title='Forum moderation policy',
            content='Forum moderation escalation stays separate from the current task.',
            importance=10,
            confidence=0.95,
            status='active',
            origin_source_types=['forum'],
            metadata_={'source_count': 1},
        ),
    ])
    await session.commit()

    context = await memory_service.build_memory_context(
        session,
        user.id,
        conversation.id,
        'Jarvis 接下来应该优先做什么?',
    )

    for expected_fragment in (
        '【用户记忆】',
        '【之前对话摘要】',
        '【知识大脑】',
        'Knowledge brain phase 1',
        'Jarvis should learn from conversation and document events first.',
    ):
        assert expected_fragment in context
    assert 'Forum moderation policy' not in context
|