Files
JARVIS/backend/tests/backend/app/services/test_brain_ingestion.py

697 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import datetime, timezone
from io import BytesIO
from types import SimpleNamespace

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.datastructures import UploadFile

import app.models  # noqa: F401
from app.database import Base
from app.models.brain import BrainEvent, BrainMemory
from app.models.conversation import Conversation, Message
from app.models.memory import MemorySummary, UserMemory
from app.models.user import User
from app.services import agent_service, memory_service
from app.services.agent_service import AgentService
from app.services.auth_service import get_password_hash
from app.services.document_service import DocumentService
class FakeGraph:
    """Minimal stand-in for the compiled agent graph used by ``chat_simple``.

    ``ainvoke`` ignores the incoming state and always answers with the same
    canned final response.
    """

    async def ainvoke(self, state):
        canned_reply = "已记录你的请求。"
        return dict(final_response=canned_reply)
class FakeStreamingGraph:
    """Streaming graph fake that emits exactly one ``master`` model-token chunk."""

    async def astream_events(self, state, version="v2"):
        token_chunk = SimpleNamespace(content="这是流式回复。")
        stream_event = {
            "event": "on_chat_model_stream",
            "name": "master",
            "data": {"chunk": token_chunk},
        }
        yield stream_event
class FakeStreamingFinalResponseGraph:
    """Streaming graph fake that emits no model chunks, only an ``on_chain_end``.

    The single event carries ``output["final_response"]``, so the service has to
    take the answer from the chain-end payload rather than from streamed tokens.
    """

    async def astream_events(self, state, version="v2"):
        final_output = {"final_response": "这是最终回答。"}
        yield {"event": "on_chain_end", "name": "master", "data": {"output": final_output}}
class FakeStreamingBadRequestError(Exception):
    """Test-only streaming failure; some tests also patch it in as ``agent_service.BadRequestError``."""
class FakeStreamingBadRequestError2(Exception):
    """Separate test-only exception type whose message embeds a provider HTTP 400 payload.

    Raised by ``FakeStreamingFallbackGraphGenericError``. In the fallback test it is
    NOT patched in as ``BadRequestError``, so the service sees it as a generic error.
    """
class FakeOpenAIBadRequestError(Exception):
    """Stand-in for the OpenAI SDK's ``BadRequestError``; patched into ``agent_service.BadRequestError``."""
class FakeStreamingOpenAIBadRequestGraph:
    """Graph whose stream fails immediately with a fake OpenAI bad-request error.

    ``astream_calls`` and ``ainvoke_calls`` count how often each entry point runs,
    so tests can assert that the service did NOT fall back to ``ainvoke``.
    """

    def __init__(self):
        self.astream_calls, self.ainvoke_calls = 0, 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        failure_message = 'invalid_request_error: tool arguments failed validation'
        raise FakeOpenAIBadRequestError(failure_message)
        # Never reached; the bare ``yield`` only makes this an async generator,
        # matching the shape of a real ``astream_events``.
        yield

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return dict(final_response="不应触发同步回退。")
class FakeStreamingFallbackGraph:
    """Graph whose stream fails with an "invalid chat setting (2013)" error.

    ``ainvoke`` supplies a synchronous answer. Both entry points are counted so
    tests can check how many times each one was used.
    """

    def __init__(self):
        self.astream_calls, self.ainvoke_calls = 0, 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        failure_message = 'invalid params, invalid chat setting (2013)'
        raise FakeStreamingBadRequestError(failure_message)
        # Never reached; keeps this function an async generator.
        yield

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return dict(final_response="这是回退后的同步回答。")
class FakeStreamingFallbackGraphGenericError:
    """Graph whose stream fails with a generic exception carrying a raw HTTP 400 body.

    The error text mimics a provider's ``bad_request_error`` JSON payload, and
    ``ainvoke`` returns a synchronous fallback answer. Both entry points are counted.
    """

    def __init__(self):
        self.astream_calls, self.ainvoke_calls = 0, 0

    async def astream_events(self, state, version="v2"):
        self.astream_calls += 1
        provider_payload = "Error code: 400 - {'type': 'error', 'error': {'type': 'bad_request_error', 'message': 'invalid params, invalid chat setting (2013)', 'http_code': '400'}}"
        raise FakeStreamingBadRequestError2(provider_payload)
        # Never reached; keeps this function an async generator.
        yield

    async def ainvoke(self, state):
        self.ainvoke_calls += 1
        return dict(final_response="这是通用异常回退后的同步回答。")
class FakeStreamingDelegationThenFinalResponseGraph:
    """Simulates the master agent handing off to ``schedule_planner``, which then finishes.

    The stream first yields a ``master`` token chunk announcing the hand-off. It then
    yields an ``on_chain_end`` from ``schedule_planner`` that carries the real final response.
    """

    async def astream_events(self, state, version="v2"):
        handoff_text = "现在显示收到3月28日的任务是完成对话系统。\n\n我将这部分转给schedule_planner他会根据这个目标结合你当前的进度和资源给出具体的安排建议。"
        planner_answer = "今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。"
        yield {
            "event": "on_chat_model_stream",
            "name": "master",
            "data": {"chunk": SimpleNamespace(content=handoff_text)},
        }
        yield {
            "event": "on_chain_end",
            "name": "schedule_planner",
            "data": {"output": {"final_response": planner_answer}},
        }
class FakeStreamingDelegationThenModelEndGraph:
    """Simulates a hand-off chunk from ``master`` followed by ``schedule_planner``'s model-end output.

    Unlike the chain-end variant, the final answer arrives as an ``on_chat_model_end``
    event whose ``output`` is a message-like object exposing ``.content``.
    """

    async def astream_events(self, state, version="v2"):
        handoff_text = "我将这部分转给schedule_planner。"
        planner_message = "最终建议:先完成对话系统,再回归验证。"
        yield {
            "event": "on_chat_model_stream",
            "name": "master",
            "data": {"chunk": SimpleNamespace(content=handoff_text)},
        }
        yield {
            "event": "on_chat_model_end",
            "name": "schedule_planner",
            "data": {"output": SimpleNamespace(content=planner_message)},
        }
class CapturingStateGraph:
    """Graph fake that remembers the state it was invoked with.

    Args:
        final_response: Text returned under ``"final_response"`` from ``ainvoke``.
    """

    def __init__(self, final_response: str = '已记录你的请求。'):
        self.final_response = final_response
        # Populated by ``ainvoke``; stays ``None`` until the graph is called.
        self.captured_state = None

    async def ainvoke(self, state):
        # Store the object itself (not a copy) so tests see exactly what the service passed.
        self.captured_state = state
        reply = self.final_response
        return dict(final_response=reply)
@pytest.fixture
async def brain_ingestion_env(tmp_path, monkeypatch):
    """Provide an isolated SQLite-backed session plus a freshly created user.

    Creates a throwaway database under ``tmp_path``, creates all tables, and
    inserts one ``User``. It replaces the agent graph and user-context hooks with
    inert fakes and redirects document uploads into ``tmp_path``.

    Yields:
        tuple: ``(session, user)``. ``session`` is a new session from the factory.
        ``user`` was loaded in a separate session that has already been closed,
        so it is detached from ``session``. To persist changes to it, work on
        ``await session.get(User, user.id)`` (or ``session.add(user)``).
    """
    db_path = tmp_path / 'test_brain_ingestion.db'
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
    # Dispose the engine even if schema creation or user seeding raises before the yield.
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with session_factory() as session:
            user = User(
                username='brain-ingestion-tester',
                email='brain-ingestion@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Brain Ingestion Tester',
            )
            session.add(user)
            await session.commit()
            await session.refresh(user)

        # Swap real collaborators for inert fakes so no LLM or real upload dir is touched.
        monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
        monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
        monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
        monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))

        async with session_factory() as session:
            yield session, user
    finally:
        await engine.dispose()
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
    """A user turn sent through ``chat_simple`` is recorded as one pending conversation BrainEvent."""
    db, tester = brain_ingestion_env
    agent = AgentService(db)

    conv_id, _msg_id, _reply, _model = await agent.chat_simple(
        tester.id,
        '请记住我这周要完成知识大脑第一阶段。',
    )

    query = (
        select(BrainEvent)
        .where(BrainEvent.user_id == tester.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    rows = (await db.execute(query)).scalars().all()
    # Keep only the user-role event; the assistant reply is recorded separately.
    user_rows = [row for row in rows if row.metadata_ == {'role': 'user'}]

    assert len(user_rows) == 1
    (user_event,) = user_rows
    assert user_event.source_id == conv_id
    assert user_event.event_type == 'message_created'
    assert user_event.title == 'User message'
    assert '知识大脑第一阶段' in (user_event.content_summary or '')
    assert user_event.status == 'pending'
@pytest.mark.asyncio
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
    """Uploading a markdown file records a pending ``document_uploaded`` BrainEvent."""
    db, tester = brain_ingestion_env
    docs = DocumentService(db)
    payload = BytesIO('# Brain\n\nCapture important product knowledge.'.encode('utf-8'))
    upload = UploadFile(filename='brain-notes.md', file=payload)

    uploaded = await docs.upload_document(tester.id, upload)

    lookup = select(BrainEvent).where(
        BrainEvent.user_id == tester.id,
        BrainEvent.source_type == 'document',
        BrainEvent.source_id == uploaded.id,
    )
    brain_event = (await db.execute(lookup)).scalar_one_or_none()

    expected_metadata = {
        'document_id': uploaded.id,
        'file_type': 'md',
        'ingestion_status': 'uploaded',
    }
    assert brain_event is not None
    assert brain_event.event_type == 'document_uploaded'
    assert brain_event.title == 'brain-notes.md'
    assert 'Capture important product knowledge.' in (brain_event.content_summary or '')
    assert brain_event.metadata_ == expected_metadata
    assert brain_event.status == 'pending'
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
    """``chat_simple`` records the assistant reply as the second conversation BrainEvent."""
    db, tester = brain_ingestion_env
    agent = AgentService(db)

    conv_id, _msg_id, reply, _model = await agent.chat_simple(
        tester.id,
        '帮我总结今天知识大脑的进展。',
    )

    ordered_events = (
        await db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == tester.id, BrainEvent.source_type == 'conversation')
            .order_by(BrainEvent.created_at.asc())
        )
    ).scalars().all()

    # Exactly two conversation events; the later one is the assistant turn.
    assert len(ordered_events) == 2
    assistant_event = ordered_events[1]
    assert assistant_event.source_id == conv_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == reply
    assert assistant_event.metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
    """Streamed model chunks are joined into the assistant BrainEvent's summary."""
    db, tester = brain_ingestion_env
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
    agent = AgentService(db)

    conv_id, _msg_id, stream = await agent.chat(
        tester.id,
        '用流式回复告诉我今天知识大脑学到了什么。',
    )
    # Drain the whole stream before querying the database.
    streamed = [item['content'] async for item in stream if item.get('type') == 'chunk']

    ordered_events = (
        await db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == tester.id, BrainEvent.source_type == 'conversation')
            .order_by(BrainEvent.created_at.asc())
        )
    ).scalars().all()

    assert ''.join(streamed) == '这是流式回复。'
    assert len(ordered_events) == 2
    assistant_event = ordered_events[1]
    assert assistant_event.source_id == conv_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == '这是流式回复。'
    assert assistant_event.metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_emits_final_response_from_chain_end_when_no_model_chunks_exist(brain_ingestion_env, monkeypatch):
    """With no token chunks, the ``on_chain_end`` final response is streamed and recorded."""
    db, tester = brain_ingestion_env
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingFinalResponseGraph())
    agent = AgentService(db)
    expected_answer = '这是最终回答。'

    conv_id, _msg_id, stream = await agent.chat(
        tester.id,
        '直接给我最终回答。',
    )
    streamed = [item['content'] async for item in stream if item.get('type') == 'chunk']

    ordered_events = (
        await db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == tester.id, BrainEvent.source_type == 'conversation')
            .order_by(BrainEvent.created_at.asc())
        )
    ).scalars().all()

    assert ''.join(streamed) == expected_answer
    assert len(ordered_events) == 2
    assistant_event = ordered_events[1]
    assert assistant_event.source_id == conv_id
    assert assistant_event.event_type == 'message_created'
    assert assistant_event.title == 'Assistant message'
    assert assistant_event.content_summary == expected_answer
    assert assistant_event.metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_prefers_chain_end_final_response_over_delegation_chunk(brain_ingestion_env, monkeypatch):
    """A sub-agent's ``on_chain_end`` answer wins over the master's hand-off chunk.

    The final answer must be the last streamed chunk and the persisted assistant message.
    """
    db, tester = brain_ingestion_env
    monkeypatch.setattr(
        agent_service,
        'get_agent_graph',
        lambda: FakeStreamingDelegationThenFinalResponseGraph(),
    )
    agent = AgentService(db)
    expected = '今天先把完成对话系统拆成三步:先回顾问题,再补测试,最后验证交互链路。'

    conv_id, _msg_id, stream = await agent.chat(tester.id, '帮我安排今天先做什么。')
    streamed = [item['content'] async for item in stream if item.get('type') == 'chunk']

    latest_assistant = (
        await db.execute(
            select(Message)
            .where(Message.conversation_id == conv_id, Message.role == 'assistant')
            .order_by(Message.created_at.desc())
        )
    ).scalars().first()

    assert expected in streamed
    assert streamed[-1] == expected
    assert latest_assistant is not None
    assert latest_assistant.content == expected
@pytest.mark.asyncio
async def test_streaming_chat_prefers_model_end_final_content_over_delegation_chunk(brain_ingestion_env, monkeypatch):
    """A sub-agent's ``on_chat_model_end`` content wins over the master's hand-off chunk.

    The final answer must be the last streamed chunk and the persisted assistant message.
    """
    db, tester = brain_ingestion_env
    monkeypatch.setattr(
        agent_service,
        'get_agent_graph',
        lambda: FakeStreamingDelegationThenModelEndGraph(),
    )
    agent = AgentService(db)
    expected = '最终建议:先完成对话系统,再回归验证。'

    conv_id, _msg_id, stream = await agent.chat(tester.id, '帮我安排今天先做什么。')
    streamed = [item['content'] async for item in stream if item.get('type') == 'chunk']

    latest_assistant = (
        await db.execute(
            select(Message)
            .where(Message.conversation_id == conv_id, Message.role == 'assistant')
            .order_by(Message.created_at.desc())
        )
    ).scalars().first()

    assert expected in streamed
    assert streamed[-1] == expected
    assert latest_assistant is not None
    assert latest_assistant.content == expected
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_for_official_openai_bad_request_without_output(brain_ingestion_env, monkeypatch):
    """An official-OpenAI ``BadRequestError`` with no output is reported, not retried via ``ainvoke``.

    The graph raises the patched-in ``BadRequestError`` before yielding anything, and
    the user's LLM config is forced to point at api.openai.com. The service must emit
    a generic "model unavailable" error, persist it as the assistant message, and
    never call the synchronous fallback.
    """
    session, user = brain_ingestion_env
    graph = FakeStreamingOpenAIBadRequestGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    monkeypatch.setattr(agent_service, 'BadRequestError', FakeOpenAIBadRequestError)

    async def fake_get_user_llm_config(self, user_id, model_name=None):
        # Pretend the user selected the official OpenAI endpoint.
        return {
            'name': 'Official OpenAI',
            'provider': 'openai',
            'model': 'gpt-4o',
            'base_url': 'https://api.openai.com/v1',
            'enabled': True,
        }

    # monkeypatch restores the original method automatically at teardown.
    monkeypatch.setattr(AgentService, '_get_user_llm_config', fake_get_user_llm_config)
    service = AgentService(session)
    conversation_id, _message_id, stream = await service.chat(
        user.id,
        '测试官方 OpenAI bad request 不应回退。',
    )

    chunks = []
    errors = []
    async for event in stream:
        if event.get('type') == 'chunk':
            chunks.append(event['content'])
        if event.get('type') == 'error':
            errors.append(event['error'])

    message_result = await session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    assistant_message = message_result.scalars().first()

    assert graph.astream_calls == 1
    assert graph.ainvoke_calls == 0
    assert errors == ['模型服务暂不可用,请稍后再试。']
    assert chunks == ['抱歉,发生错误: 模型服务暂不可用,请稍后再试。']
    assert assistant_message is not None
    assert assistant_message.content == '抱歉,发生错误: 模型服务暂不可用,请稍后再试。'
@pytest.mark.asyncio
async def test_streaming_chat_falls_back_for_generic_400_streaming_error(brain_ingestion_env, monkeypatch):
    """A generic streaming exception with a 400 "invalid chat setting" payload triggers one ``ainvoke`` fallback.

    ``BadRequestError`` is not patched here, so the service only sees a plain exception.
    The fallback answer must be streamed and recorded as the assistant BrainEvent.
    """
    db, tester = brain_ingestion_env
    graph = FakeStreamingFallbackGraphGenericError()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    agent = AgentService(db)
    fallback_text = '这是通用异常回退后的同步回答。'

    conv_id, _msg_id, stream = await agent.chat(tester.id, '帮我制定一下明天的计划。')
    streamed = [item['content'] async for item in stream if item.get('type') == 'chunk']

    ordered_events = (
        await db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == tester.id, BrainEvent.source_type == 'conversation')
            .order_by(BrainEvent.created_at.asc())
        )
    ).scalars().all()

    assert (graph.astream_calls, graph.ainvoke_calls) == (1, 1)
    assert ''.join(streamed) == fallback_text
    assert len(ordered_events) == 2
    assert ordered_events[1].source_id == conv_id
    assert ordered_events[1].content_summary == fallback_text
@pytest.mark.asyncio
async def test_streaming_chat_does_not_fall_back_after_partial_stream_output(brain_ingestion_env, monkeypatch):
    """A failure after some chunks were streamed must not trigger the ``ainvoke`` fallback.

    The partial text is kept as the assistant message and BrainEvent summary, and the
    raw exception message is surfaced as an ``error`` event.
    """
    session, user = brain_ingestion_env

    class PartialThenFailGraph:
        """Streams one ``master`` chunk, then raises the patched-in ``BadRequestError``."""

        def __init__(self):
            self.ainvoke_calls = 0

        async def astream_events(self, state, version='v2'):
            yield {
                'event': 'on_chat_model_stream',
                'name': 'master',
                'data': {'chunk': SimpleNamespace(content='前半段')},
            }
            raise FakeStreamingBadRequestError('stream interrupted')

        async def ainvoke(self, state):
            self.ainvoke_calls += 1
            return {'final_response': '不应触发'}

    graph = PartialThenFailGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    monkeypatch.setattr(agent_service, 'BadRequestError', FakeStreamingBadRequestError)
    service = AgentService(session)
    conversation_id, _message_id, stream = await service.chat(
        user.id,
        '测试部分流式输出失败。',
    )

    chunks = []
    errors = []
    async for event in stream:
        if event.get('type') == 'chunk':
            chunks.append(event['content'])
        if event.get('type') == 'error':
            errors.append(event['error'])

    message_result = await session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id, Message.role == 'assistant')
        .order_by(Message.created_at.desc())
    )
    assistant_message = message_result.scalars().first()
    brain_event_result = await session.execute(
        select(BrainEvent)
        .where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
        .order_by(BrainEvent.created_at.asc())
    )
    events = list(brain_event_result.scalars().all())

    assert chunks == ['前半段']
    assert graph.ainvoke_calls == 0
    assert errors == ['stream interrupted']
    assert assistant_message is not None
    assert assistant_message.content == '前半段'
    # Guard the index below so a missing event fails as an assertion, not an IndexError.
    assert len(events) >= 2, events
    assert events[1].content_summary == '前半段'
@pytest.mark.asyncio
async def test_chat_simple_passes_current_datetime_context_into_langgraph_state(brain_ingestion_env, monkeypatch):
    """``chat_simple`` puts a current-date context string and reference dict into the graph state.

    The year check is derived from the clock instead of a hard-coded literal, so the
    test keeps passing after the calendar year changes.
    """
    session, user = brain_ingestion_env
    graph = CapturingStateGraph()
    monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: graph)
    service = AgentService(session)
    await service.chat_simple(
        user.id,
        '3月29日对话系统交付节点',
    )

    assert graph.captured_state is not None
    current_context = graph.captured_state.get('current_datetime_context')
    assert isinstance(current_context, str)
    assert current_context
    assert '当前时间' in current_context
    # The service may render the time in a non-UTC zone. Near New Year that local
    # year can differ from the UTC year by one, so accept any of the three.
    utc_year = datetime.now(timezone.utc).year
    plausible_years = [str(year) for year in (utc_year - 1, utc_year, utc_year + 1)]
    assert any(year in current_context for year in plausible_years), current_context

    current_reference = graph.captured_state.get('current_datetime_reference')
    assert isinstance(current_reference, dict)
    assert 'current_time_iso' in current_reference
    assert 'current_date_iso' in current_reference
@pytest.mark.asyncio
async def test_get_user_llm_config_defaults_to_enabled_chat_model_not_vlm(brain_ingestion_env):
    """With no explicit model, the enabled chat model is selected, never the disabled one or a VLM."""
    session, user = brain_ingestion_env
    # The fixture's ``user`` was loaded in a session that has since closed, so it is
    # detached from ``session``. Mutating it and committing ``session`` would persist
    # nothing. Load a session-bound copy and update that instead.
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)
    config = await service._get_user_llm_config(user.id)

    assert config is not None
    assert config['name'] == 'Enabled Chat'
@pytest.mark.asyncio
async def test_get_user_llm_config_returns_none_when_only_vlm_is_enabled(brain_ingestion_env):
    """If no chat model is enabled, no config is returned, even though a VLM is enabled."""
    session, user = brain_ingestion_env
    # The fixture's ``user`` is detached from ``session`` (it came from a closed
    # session). Without a session-bound copy, this config would never be persisted
    # and the test would pass vacuously.
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Disabled Chat', 'provider': 'openai', 'model': 'disabled-chat', 'enabled': False},
        ],
        'vlm': [
            {'name': 'Enabled Vision', 'provider': 'openai', 'model': 'enabled-vision', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)
    config = await service._get_user_llm_config(user.id)

    assert config is None
@pytest.mark.asyncio
async def test_chat_simple_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
    """Choosing a VLM for ``chat_simple`` raises ``ValueError`` and writes nothing.

    No conversation, message or BrainEvent may be left behind after the rejection.
    """
    session, user = brain_ingestion_env
    # The fixture's ``user`` is detached from ``session``; update a session-bound copy
    # so the config is actually committed to the database.
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)
    with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
        await service.chat_simple(user.id, '测试聊天模型选择', model_name='Vision Only')

    conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
    message_result = await session.execute(select(Message))
    brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
    assert conversation_result.scalars().all() == []
    assert message_result.scalars().all() == []
    assert brain_event_result.scalars().all() == []
@pytest.mark.asyncio
async def test_streaming_chat_rejects_vlm_model_without_persisting_conversation_state(brain_ingestion_env):
    """Choosing a VLM for streaming ``chat`` raises ``ValueError`` up front and writes nothing.

    No conversation, message or BrainEvent may be left behind after the rejection.
    """
    session, user = brain_ingestion_env
    # The fixture's ``user`` is detached from ``session``; update a session-bound copy
    # so the config is actually committed to the database.
    db_user = await session.get(User, user.id)
    assert db_user is not None
    db_user.llm_config = {
        'chat': [
            {'name': 'Enabled Chat', 'provider': 'openai', 'model': 'enabled-chat', 'enabled': True},
        ],
        'vlm': [
            {'name': 'Vision Only', 'provider': 'openai', 'model': 'vision-only', 'enabled': True},
        ],
    }
    await session.commit()

    service = AgentService(session)
    with pytest.raises(ValueError, match='所选模型不可用于聊天,请切换到聊天模型'):
        await service.chat(user.id, '测试流式聊天模型选择', model_name='Vision Only')

    conversation_result = await session.execute(select(Conversation).where(Conversation.user_id == user.id))
    message_result = await session.execute(select(Message))
    brain_event_result = await session.execute(select(BrainEvent).where(BrainEvent.user_id == user.id))
    assert conversation_result.scalars().all() == []
    assert message_result.scalars().all() == []
    assert brain_event_result.scalars().all() == []
@pytest.mark.asyncio
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
    """``build_memory_context`` renders user memory, prior summaries and a knowledge-brain section.

    The test seeds one user memory, one conversation summary and two active brain
    memories. The conversation/document-sourced memory must appear; the forum-sourced
    "Forum moderation policy" memory must not.
    """
    db, tester = brain_ingestion_env
    conversation = Conversation(user_id=tester.id, title='Brain context test')
    db.add(conversation)
    await db.flush()  # conversation.id is needed by the rows seeded below

    db.add_all([
        UserMemory(
            user_id=tester.id,
            memory_type='preference',
            content='用户偏好结构化输出。',
            importance=6,
            source_conversation_id=conversation.id,
        ),
        MemorySummary(
            user_id=tester.id,
            conversation_id=conversation.id,
            summary_text='之前讨论了知识大脑的整体设计。',
            turn_count=8,
        ),
        BrainMemory(
            user_id=tester.id,
            memory_type='project_fact',
            title='Knowledge brain phase 1',
            content='Jarvis should learn from conversation and document events first.',
            importance=9,
            confidence=0.93,
            status='active',
            origin_source_types=['conversation', 'document'],
            metadata_={'source_count': 2},
        ),
        BrainMemory(
            user_id=tester.id,
            memory_type='project_fact',
            title='Forum moderation policy',
            content='Forum moderation escalation stays separate from the current task.',
            importance=10,
            confidence=0.95,
            status='active',
            origin_source_types=['forum'],
            metadata_={'source_count': 1},
        ),
    ])
    await db.commit()

    context = await memory_service.build_memory_context(
        db,
        tester.id,
        conversation.id,
        'Jarvis 接下来应该优先做什么?',
    )

    for fragment in (
        '【用户记忆】',
        '【之前对话摘要】',
        '【知识大脑】',
        'Knowledge brain phase 1',
        'Jarvis should learn from conversation and document events first.',
    ):
        assert fragment in context
    assert 'Forum moderation policy' not in context