# Source: JARVIS/backend/tests/backend/app/test_conversation_router.py
# (file-viewer export metadata: 98 lines, 3.3 KiB, Python)

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
import app.models # noqa: F401
from app.database import Base, get_db, ensure_conversation_columns
from app.models.conversation import Conversation
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.conversation import router as conversation_router
from app.services.auth_service import get_password_hash
@pytest.fixture
async def conversation_env(tmp_path):
    """Yield a FastAPI app backed by a throwaway SQLite database.

    Setup steps:
      * build the schema from ``Base.metadata`` in a per-test SQLite file;
      * drop ``conversations.agent_state`` and then call
        ``ensure_conversation_columns``, to mimic a database created before
        that column existed;
      * seed one active user and one conversation owned by that user;
      * mount only the conversation router, overriding ``get_db`` (sessions on
        this database) and ``get_current_user`` (always the seeded user).

    The engine is disposed on teardown, and also when any setup step fails.
    """
    db_path = tmp_path / 'test_conversation_router.db'
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
    # Everything after engine creation sits under try/finally. A failure while
    # building the schema or seeding rows must still release the engine's
    # pooled connections.
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
            # Simulate a legacy schema that lacks agent_state.
            # ensure_conversation_columns is expected to restore it; that repair
            # is what the tests in this module exercise.
            # NOTE: SQLite supports DROP COLUMN only from version 3.35.0.
            await conn.execute(text('ALTER TABLE conversations DROP COLUMN agent_state'))
            await ensure_conversation_columns(conn)

        async with session_factory() as session:
            user = User(
                username='conversation_user',
                email='conversation@example.com',
                hashed_password=get_password_hash('secret123'),
                full_name='Conversation Tester',
                is_active=True,
            )
            session.add(user)
            # Flush so the INSERT runs and user.id is populated for the FK below.
            await session.flush()
            session.add(
                Conversation(
                    user_id=user.id,
                    title='Existing conversation',
                    message_count=3,
                )
            )
            await session.commit()
            # Reload the row before the session closes. The override below
            # returns this (then detached) instance.
            await session.refresh(user)

        async def override_get_db():
            """Provide a fresh session on the test database per request."""
            async with session_factory() as session:
                yield session

        async def override_get_current_user():
            """Authenticate every request as the seeded user."""
            return user

        test_app = FastAPI()
        test_app.include_router(conversation_router)
        test_app.dependency_overrides[get_db] = override_get_db
        test_app.dependency_overrides[get_current_user] = override_get_current_user

        yield test_app
    finally:
        await engine.dispose()
@pytest.mark.asyncio
async def test_list_conversations_succeeds_when_agent_state_column_was_missing(conversation_env):
    """``GET /api/conversations`` still lists the seeded conversation.

    The ``conversation_env`` database had its ``agent_state`` column dropped
    and then passed through ``ensure_conversation_columns``. The endpoint must
    answer 200 with exactly the one seeded conversation and its original
    title and message count.
    """
    async with AsyncClient(
        transport=ASGITransport(app=conversation_env),
        base_url='http://testserver',
    ) as http:
        resp = await http.get('/api/conversations')

    assert resp.status_code == 200

    conversations = resp.json()
    assert len(conversations) == 1

    only = conversations[0]
    assert only['title'] == 'Existing conversation'
    assert only['message_count'] == 3
@pytest.mark.asyncio
async def test_chat_stream_emits_error_event_when_agent_service_fails_before_stream_starts(
    conversation_env,
    monkeypatch,
):
    """A failure in ``AgentService.chat`` surfaces inside the stream body.

    ``AgentService.chat``, as referenced by the conversation router, is
    replaced with a coroutine that raises immediately. The streaming endpoint
    must still answer 200. Its body must contain an ``event: error`` entry
    that carries the exception message.
    """
    async def _boom(*_args, **_kwargs):
        raise RuntimeError('stream boot failed')

    monkeypatch.setattr('app.routers.conversation.AgentService.chat', _boom)

    async with AsyncClient(
        transport=ASGITransport(app=conversation_env),
        base_url='http://testserver',
    ) as http:
        resp = await http.post(
            '/api/conversations/chat/stream',
            json={'message': 'hello'},
        )

    assert resp.status_code == 200

    body = resp.text
    assert 'event: error' in body
    assert 'stream boot failed' in body