Add brain memory services and APIs
Introduce the backend pieces for brain memory ingestion, routing, and system telemetry so the new knowledge workflows can project data into a brain view. The supporting tests lock in the new behavior and keep the expanded backend surface stable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.brain import BrainMemory, BrainTag
|
||||
from app.models.knowledge_graph import KGEdge, KGNode
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.graph import router as graph_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.brain_service import BrainService
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_graph_env(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_graph.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain-graph@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Graph Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Knowledge brain phase 1',
|
||||
content='Jarvis should learn from conversations and documents first.',
|
||||
importance=9,
|
||||
confidence=0.95,
|
||||
status='active',
|
||||
origin_source_types=['conversation', 'document'],
|
||||
),
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='user_preference',
|
||||
title='Structured delivery preference',
|
||||
content='The user prefers concise structured summaries.',
|
||||
importance=7,
|
||||
confidence=0.88,
|
||||
status='active',
|
||||
origin_source_types=['conversation'],
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='knowledge-brain',
|
||||
category='topic',
|
||||
priority='important',
|
||||
score=9.5,
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='conversation',
|
||||
category='source',
|
||||
priority='secondary',
|
||||
score=7.0,
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(graph_router)
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield session_factory, user, app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_graph_projects_kg_nodes_and_edges_from_brain_data(brain_graph_env):
|
||||
session_factory, user, _app = brain_graph_env
|
||||
|
||||
async with session_factory() as session:
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
node_result = await session.execute(
|
||||
select(KGNode).where(KGNode.user_id == user.id).order_by(KGNode.name.asc())
|
||||
)
|
||||
nodes = list(node_result.scalars().all())
|
||||
edge_result = await session.execute(select(KGEdge))
|
||||
edges = list(edge_result.scalars().all())
|
||||
|
||||
node_names = [node.name for node in nodes]
|
||||
assert 'Knowledge brain phase 1' in node_names
|
||||
assert 'Structured delivery preference' in node_names
|
||||
assert 'knowledge-brain' in node_names
|
||||
assert len(edges) >= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_learning_triggers_graph_rebuild(brain_graph_env, monkeypatch):
|
||||
session_factory, user, _app = brain_graph_env
|
||||
calls: list[str] = []
|
||||
|
||||
async def fake_build_graph(self, user_id, document_ids=None):
|
||||
calls.append(user_id)
|
||||
|
||||
monkeypatch.setattr(GraphService, 'build_graph', fake_build_graph)
|
||||
|
||||
async with session_factory() as session:
|
||||
service = BrainService(session)
|
||||
await service.run_learning(user.id)
|
||||
|
||||
assert calls == [user.id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_api_returns_brain_projected_graph_after_build(brain_graph_env):
|
||||
session_factory, user, app = brain_graph_env
|
||||
|
||||
async with session_factory() as session:
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/graph')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['stats']['node_count'] >= 3
|
||||
assert payload['stats']['edge_count'] >= 2
|
||||
assert any(node['name'] == 'Knowledge brain phase 1' for node in payload['nodes'])
|
||||
assert any(node['name'] == 'knowledge-brain' for node in payload['nodes'])
|
||||
237
backend/tests/backend/app/services/test_brain_ingestion.py
Normal file
237
backend/tests/backend/app/services/test_brain_ingestion.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from starlette.datastructures import UploadFile
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.brain import BrainEvent, BrainMemory
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.memory import MemorySummary, UserMemory
|
||||
from app.models.user import User
|
||||
from app.services import agent_service, memory_service
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.document_service import DocumentService
|
||||
|
||||
|
||||
class FakeGraph:
|
||||
async def ainvoke(self, state):
|
||||
return {"final_response": "已记录你的请求。"}
|
||||
|
||||
|
||||
class FakeStreamingGraph:
|
||||
async def astream_events(self, state, version="v2"):
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"name": "master",
|
||||
"data": {"chunk": SimpleNamespace(content="这是流式回复。")},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_ingestion_env(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / 'test_brain_ingestion.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain-ingestion@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Ingestion Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
|
||||
monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
|
||||
monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
|
||||
monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))
|
||||
|
||||
async with session_factory() as session:
|
||||
yield session, user
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, _response, _model_name = await service.chat_simple(
|
||||
user.id,
|
||||
'请记住我这周要完成知识大脑第一阶段。',
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
user_events = [event for event in events if event.metadata_ == {'role': 'user'}]
|
||||
|
||||
assert len(user_events) == 1
|
||||
assert user_events[0].source_id == conversation_id
|
||||
assert user_events[0].event_type == 'message_created'
|
||||
assert user_events[0].title == 'User message'
|
||||
assert '知识大脑第一阶段' in (user_events[0].content_summary or '')
|
||||
assert user_events[0].status == 'pending'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = DocumentService(session)
|
||||
upload = UploadFile(
|
||||
filename='brain-notes.md',
|
||||
file=BytesIO('# Brain\n\nCapture important product knowledge.'.encode('utf-8')),
|
||||
)
|
||||
|
||||
document = await service.upload_document(user.id, upload)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(
|
||||
BrainEvent.user_id == user.id,
|
||||
BrainEvent.source_type == 'document',
|
||||
BrainEvent.source_id == document.id,
|
||||
)
|
||||
)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
assert event is not None
|
||||
assert event.event_type == 'document_uploaded'
|
||||
assert event.title == 'brain-notes.md'
|
||||
assert 'Capture important product knowledge.' in (event.content_summary or '')
|
||||
assert event.metadata_ == {
|
||||
'document_id': document.id,
|
||||
'file_type': 'md',
|
||||
'ingestion_status': 'uploaded',
|
||||
}
|
||||
assert event.status == 'pending'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, response, _model_name = await service.chat_simple(
|
||||
user.id,
|
||||
'帮我总结今天知识大脑的进展。',
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == response
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
|
||||
session, user = brain_ingestion_env
|
||||
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
|
||||
service = AgentService(session)
|
||||
|
||||
conversation_id, _message_id, stream = await service.chat(
|
||||
user.id,
|
||||
'用流式回复告诉我今天知识大脑学到了什么。',
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for event in stream:
|
||||
if event.get('type') == 'chunk':
|
||||
chunks.append(event['content'])
|
||||
|
||||
result = await session.execute(
|
||||
select(BrainEvent)
|
||||
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
|
||||
.order_by(BrainEvent.created_at.asc())
|
||||
)
|
||||
events = list(result.scalars().all())
|
||||
|
||||
assert ''.join(chunks) == '这是流式回复。'
|
||||
assert len(events) == 2
|
||||
assert events[1].source_id == conversation_id
|
||||
assert events[1].event_type == 'message_created'
|
||||
assert events[1].title == 'Assistant message'
|
||||
assert events[1].content_summary == '这是流式回复。'
|
||||
assert events[1].metadata_ == {'role': 'assistant'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
|
||||
session, user = brain_ingestion_env
|
||||
conversation = Conversation(user_id=user.id, title='Brain context test')
|
||||
session.add(conversation)
|
||||
await session.flush()
|
||||
|
||||
session.add(UserMemory(
|
||||
user_id=user.id,
|
||||
memory_type='preference',
|
||||
content='用户偏好结构化输出。',
|
||||
importance=6,
|
||||
source_conversation_id=conversation.id,
|
||||
))
|
||||
session.add(MemorySummary(
|
||||
user_id=user.id,
|
||||
conversation_id=conversation.id,
|
||||
summary_text='之前讨论了知识大脑的整体设计。',
|
||||
turn_count=8,
|
||||
))
|
||||
session.add(BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Knowledge brain phase 1',
|
||||
content='Jarvis should learn from conversation and document events first.',
|
||||
importance=9,
|
||||
confidence=0.93,
|
||||
status='active',
|
||||
origin_source_types=['conversation', 'document'],
|
||||
metadata_={'source_count': 2},
|
||||
))
|
||||
session.add(BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Forum moderation policy',
|
||||
content='Forum moderation escalation stays separate from the current task.',
|
||||
importance=10,
|
||||
confidence=0.95,
|
||||
status='active',
|
||||
origin_source_types=['forum'],
|
||||
metadata_={'source_count': 1},
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
context = await memory_service.build_memory_context(
|
||||
session,
|
||||
user.id,
|
||||
conversation.id,
|
||||
'Jarvis 接下来应该优先做什么?',
|
||||
)
|
||||
|
||||
assert '【用户记忆】' in context
|
||||
assert '【之前对话摘要】' in context
|
||||
assert '【知识大脑】' in context
|
||||
assert 'Knowledge brain phase 1' in context
|
||||
assert 'Jarvis should learn from conversation and document events first.' in context
|
||||
assert 'Forum moderation policy' not in context
|
||||
194
backend/tests/backend/app/services/test_brain_router.py
Normal file
194
backend/tests/backend/app/services/test_brain_router.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import sys
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
sys.modules.setdefault('psutil', Mock())
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base, get_db
|
||||
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.brain import router as brain_router
|
||||
from app.services.auth_service import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def brain_router_env(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_router.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='brain@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Brain Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
session.add_all([
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='project_fact',
|
||||
title='Current project direction',
|
||||
content='Jarvis knowledge brain should learn from all major product surfaces.',
|
||||
importance=8,
|
||||
confidence=0.92,
|
||||
status='active',
|
||||
),
|
||||
BrainMemory(
|
||||
user_id=user.id,
|
||||
memory_type='preference',
|
||||
title='User prefers brain-first UX',
|
||||
content='The knowledge brain should be broader than the graph page.',
|
||||
importance=7,
|
||||
confidence=0.88,
|
||||
status='active',
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='knowledge-brain',
|
||||
category='topic',
|
||||
priority='important',
|
||||
score=9.5,
|
||||
),
|
||||
BrainTag(
|
||||
user_id=user.id,
|
||||
name='graph',
|
||||
category='topic',
|
||||
priority='secondary',
|
||||
score=4.0,
|
||||
),
|
||||
BrainEvent(
|
||||
user_id=user.id,
|
||||
source_type='conversation',
|
||||
source_id='conv-1',
|
||||
event_type='created',
|
||||
title='Conversation created',
|
||||
content_summary='User described the desired knowledge brain behavior.',
|
||||
status='pending',
|
||||
),
|
||||
BrainEvent(
|
||||
user_id=user.id,
|
||||
source_type='document',
|
||||
source_id='doc-1',
|
||||
event_type='indexed',
|
||||
title='Document indexed',
|
||||
content_summary='A strategic document was indexed into the system.',
|
||||
status='processed',
|
||||
),
|
||||
BrainCandidate(
|
||||
user_id=user.id,
|
||||
candidate_type='project_fact',
|
||||
title='Brain spans all product surfaces',
|
||||
summary='The knowledge brain should learn from conversation, docs, tasks, todos, and forum.',
|
||||
importance_score=9.2,
|
||||
confidence_score=0.95,
|
||||
status='new',
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def override_get_current_user():
|
||||
return user
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(brain_router)
|
||||
test_app.dependency_overrides[get_db] = override_get_db
|
||||
test_app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
try:
|
||||
yield test_app
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_overview_returns_memory_and_tag_summary(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/overview')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload['active_memory_count'] == 2
|
||||
assert payload['important_tag_count'] == 1
|
||||
assert payload['secondary_tag_count'] == 1
|
||||
assert payload['recent_memory_titles'] == [
|
||||
'Current project direction',
|
||||
'User prefers brain-first UX',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_memories_returns_active_memories_sorted_by_importance(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/memories')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['title'] for item in payload] == [
|
||||
'Current project direction',
|
||||
'User prefers brain-first UX',
|
||||
]
|
||||
assert all(item['status'] == 'active' for item in payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_tags_groups_important_and_secondary_tags(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/tags')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert [item['name'] for item in payload['important']] == ['knowledge-brain']
|
||||
assert [item['name'] for item in payload['secondary']] == ['graph']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_brain_events_returns_latest_events_first(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.get('/api/brain/events')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert len(payload) == 2
|
||||
assert payload[0]['title'] == 'Document indexed'
|
||||
assert payload[1]['title'] == 'Conversation created'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_brain_learning_run_returns_processed_counts(brain_router_env):
|
||||
transport = ASGITransport(app=brain_router_env)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
|
||||
response = await client.post('/api/brain/learn/run')
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload == {
|
||||
'events_considered': 1,
|
||||
'candidates_created': 1,
|
||||
'memories_promoted': 1,
|
||||
}
|
||||
371
backend/tests/backend/app/services/test_knowledge_service.py
Normal file
371
backend/tests/backend/app/services/test_knowledge_service.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from starlette.datastructures import UploadFile
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.models.folder import Folder
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import get_password_hash
|
||||
from app.services.knowledge_service import KnowledgeService, SearchResult
|
||||
from app.services.graph_service import GraphService
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self):
|
||||
self.add_calls = []
|
||||
self.delete_calls = []
|
||||
|
||||
def add(self, *, ids, documents, metadatas):
|
||||
self.add_calls.append({
|
||||
'ids': ids,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
})
|
||||
|
||||
def delete(self, *, where):
|
||||
self.delete_calls.append(where)
|
||||
|
||||
def query(self, **kwargs):
|
||||
self.last_query = kwargs
|
||||
return {
|
||||
'ids': [['chunk-schema', 'chunk-rows']],
|
||||
'documents': [['schema chunk', 'row chunk']],
|
||||
'metadatas': [[
|
||||
{
|
||||
'document_id': 'doc-1',
|
||||
'document_title': 'Revenue',
|
||||
'chunk_index': 0,
|
||||
'content_type': 'table_schema',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 0,
|
||||
'row_end': 0,
|
||||
},
|
||||
{
|
||||
'document_id': 'doc-1',
|
||||
'document_title': 'Revenue',
|
||||
'chunk_index': 1,
|
||||
'content_type': 'table_rows',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 1,
|
||||
'row_end': 10,
|
||||
},
|
||||
]],
|
||||
'distances': [[0.3, 0.35]],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def knowledge_test_env(tmp_path):
|
||||
db_path = tmp_path / 'test_knowledge.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email='knowledge@example.com',
|
||||
hashed_password=get_password_hash('secret123'),
|
||||
full_name='Knowledge Tester',
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
root = Folder(user_id=user.id, name='Finance', parent_id=None)
|
||||
session.add(root)
|
||||
await session.flush()
|
||||
child = Folder(user_id=user.id, name='Reports', parent_id=root.id)
|
||||
session.add(child)
|
||||
await session.flush()
|
||||
|
||||
document = Document(
|
||||
id='doc-1',
|
||||
user_id=user.id,
|
||||
title='Revenue Workbook',
|
||||
filename='revenue.xlsx',
|
||||
file_type='xlsx',
|
||||
file_size=128,
|
||||
file_path=str(tmp_path / 'revenue.xlsx'),
|
||||
folder_id=child.id,
|
||||
summary='Revenue summary',
|
||||
chunk_count=2,
|
||||
is_indexed=False,
|
||||
)
|
||||
session.add(document)
|
||||
session.add_all([
|
||||
DocumentChunk(
|
||||
id='chunk-schema',
|
||||
document_id=document.id,
|
||||
chunk_index=0,
|
||||
content='schema chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'table_schema',
|
||||
'sheet_name': 'Revenue',
|
||||
'headers': ['region', 'amount'],
|
||||
'source_order': 0,
|
||||
'section_path': ['Revenue'],
|
||||
'page_number': 1,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-rows',
|
||||
document_id=document.id,
|
||||
chunk_index=1,
|
||||
content='row chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'table_rows',
|
||||
'sheet_name': 'Revenue',
|
||||
'row_start': 1,
|
||||
'row_end': 10,
|
||||
'source_order': 1,
|
||||
'section_path': ['Revenue'],
|
||||
'page_number': 1,
|
||||
}),
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(document)
|
||||
await session.refresh(child)
|
||||
yield session, user, document, child
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_document_writes_folder_and_structure_metadata(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
await service.index_document(document.id, user.id)
|
||||
|
||||
assert fake_collection.add_calls
|
||||
metadatas = fake_collection.add_calls[0]['metadatas']
|
||||
assert metadatas[0]['folder_path'] == '/Finance/Reports'
|
||||
assert metadatas[0]['content_type'] == 'table_schema'
|
||||
assert metadatas[0]['sheet_name'] == 'Revenue'
|
||||
assert metadatas[1]['content_type'] == 'table_rows'
|
||||
|
||||
await session.refresh(document)
|
||||
assert document.is_indexed is True
|
||||
assert document.ingestion_status == 'ready'
|
||||
assert document.indexed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_prefers_table_schema_for_tabular_queries(knowledge_test_env):
|
||||
session, user, _document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
results = await service.retrieve('excel表 Revenue 的列有哪些', user.id, top_k=2, use_rerank=True)
|
||||
|
||||
assert [item.chunk_id for item in results] == ['chunk-schema', 'chunk-rows']
|
||||
metadata = json.loads(results[0].metadata_)
|
||||
assert metadata['content_type'] == 'table_schema'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_expansion_uses_same_sheet_for_table_rows(knowledge_test_env):
|
||||
session, user, _document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
|
||||
prev_chunk, next_chunk = await service._get_related_chunks('chunk-rows', 1, 'doc-1')
|
||||
|
||||
assert prev_chunk == 'schema chunk'
|
||||
assert next_chunk is None
|
||||
|
||||
|
||||
def test_rerank_boosts_table_chunks_when_query_mentions_sheet():
|
||||
service = KnowledgeService(db=None, user_id='user-1')
|
||||
schema = SearchResult(
|
||||
chunk_id='schema',
|
||||
document_id='doc-1',
|
||||
document_title='Revenue Workbook',
|
||||
content='Columns: region amount',
|
||||
score=0.6,
|
||||
metadata_=json.dumps({'content_type': 'table_schema', 'sheet_name': 'Revenue'}),
|
||||
)
|
||||
paragraph = SearchResult(
|
||||
chunk_id='paragraph',
|
||||
document_id='doc-1',
|
||||
document_title='Revenue Workbook',
|
||||
content='General revenue narrative',
|
||||
score=0.65,
|
||||
metadata_=json.dumps({'content_type': 'paragraph', 'section_title': 'Overview'}),
|
||||
)
|
||||
|
||||
ranked = service._rerank('sheet Revenue 有哪些列', [paragraph, schema], top_k=2)
|
||||
|
||||
assert [item.chunk_id for item in ranked] == ['schema', 'paragraph']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_expansion_prefers_same_section_before_linear_neighbors(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
session.add_all([
|
||||
DocumentChunk(
|
||||
id='chunk-overview',
|
||||
document_id=document.id,
|
||||
chunk_index=2,
|
||||
content='overview chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Overview'],
|
||||
'section_title': 'Overview',
|
||||
'source_order': 2,
|
||||
'page_number': 2,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-overview-2',
|
||||
document_id=document.id,
|
||||
chunk_index=3,
|
||||
content='overview details chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Overview'],
|
||||
'section_title': 'Overview',
|
||||
'source_order': 3,
|
||||
'page_number': 2,
|
||||
}),
|
||||
),
|
||||
DocumentChunk(
|
||||
id='chunk-appendix',
|
||||
document_id=document.id,
|
||||
chunk_index=4,
|
||||
content='appendix chunk',
|
||||
metadata_=json.dumps({
|
||||
'content_type': 'paragraph',
|
||||
'section_path': ['Appendix'],
|
||||
'section_title': 'Appendix',
|
||||
'source_order': 4,
|
||||
'page_number': 3,
|
||||
}),
|
||||
),
|
||||
])
|
||||
await session.commit()
|
||||
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
prev_chunk, next_chunk = await service._get_related_chunks('chunk-overview-2', 3, 'doc-1')
|
||||
|
||||
assert prev_chunk == 'overview chunk'
|
||||
assert next_chunk is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reindex_document_rebuilds_chunks_and_versions(knowledge_test_env):
|
||||
session, user, _document, folder = knowledge_test_env
|
||||
from app.services.document_service import DocumentService
|
||||
from app.services.knowledge_service import KnowledgeService
|
||||
|
||||
upload = UploadFile(
|
||||
filename='reindex.md',
|
||||
file=BytesIO(b'# Intro\n\nOriginal content\n\n## Details\n\nUpdated content'),
|
||||
)
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.upload_document(user.id, upload, folder_id=folder.id)
|
||||
|
||||
chunk_result = await session.execute(select(DocumentChunk).where(DocumentChunk.document_id == document.id))
|
||||
original_chunks = list(chunk_result.scalars().all())
|
||||
assert original_chunks
|
||||
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
rebuilt = await service.reindex_document(document.id, user.id)
|
||||
|
||||
assert rebuilt is True
|
||||
await session.refresh(document)
|
||||
assert document.parser_version == 'v2'
|
||||
assert document.index_version == 'v2'
|
||||
assert document.ingestion_status == 'ready'
|
||||
|
||||
new_chunk_result = await session.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.document_id == document.id)
|
||||
.order_by(DocumentChunk.chunk_index)
|
||||
)
|
||||
rebuilt_chunks = list(new_chunk_result.scalars().all())
|
||||
assert rebuilt_chunks
|
||||
assert all(chunk.metadata_ for chunk in rebuilt_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reindex_document_chunks_reuses_existing_db_chunks(knowledge_test_env):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
service = KnowledgeService(session, user_id=user.id)
|
||||
fake_collection = FakeCollection()
|
||||
service.get_collection = lambda user_id: fake_collection
|
||||
|
||||
chunk_result = await session.execute(
|
||||
select(DocumentChunk)
|
||||
.where(DocumentChunk.id == 'chunk-schema')
|
||||
)
|
||||
chunk = chunk_result.scalar_one()
|
||||
chunk.content = 'edited schema chunk'
|
||||
document.ingestion_status = 'indexing'
|
||||
await session.commit()
|
||||
|
||||
rebuilt = await service.reindex_document_chunks(document.id, user.id)
|
||||
|
||||
assert rebuilt is True
|
||||
assert fake_collection.delete_calls == [{'document_id': document.id}]
|
||||
assert fake_collection.add_calls
|
||||
assert fake_collection.add_calls[0]['documents'][0] == 'edited schema chunk'
|
||||
|
||||
await session.refresh(document)
|
||||
assert document.ingestion_status == 'ready'
|
||||
assert document.indexed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_graph_service_uses_user_llm_configured_model(knowledge_test_env, monkeypatch):
|
||||
session, user, document, _folder = knowledge_test_env
|
||||
document.is_indexed = True
|
||||
await session.commit()
|
||||
|
||||
used_providers = []
|
||||
|
||||
class FakeLLM:
|
||||
async def invoke(self, _messages):
|
||||
return SimpleNamespace(content=json.dumps({
|
||||
'entities': [{'name': 'Revenue', 'type': 'topic', 'description': 'Revenue topic'}],
|
||||
'relations': [],
|
||||
}))
|
||||
|
||||
async def fake_get_user_llm_config(self, user_id, model_name=None):
|
||||
assert user_id == user.id
|
||||
assert model_name is None
|
||||
return {
|
||||
'provider': 'openai',
|
||||
'model': 'user-model',
|
||||
'api_key': 'secret',
|
||||
'base_url': 'https://example.com/v1',
|
||||
'enabled': True,
|
||||
}
|
||||
|
||||
def fake_create_llm_from_config(config):
|
||||
used_providers.append(config['provider'])
|
||||
return FakeLLM()
|
||||
|
||||
def fail_if_global_llm_used():
|
||||
raise AssertionError('global get_llm should not be used when user config exists')
|
||||
|
||||
monkeypatch.setattr('app.services.graph_service.get_llm', fail_if_global_llm_used)
|
||||
monkeypatch.setattr('app.services.graph_service.resolve_user_llm', fake_get_user_llm_config, raising=False)
|
||||
monkeypatch.setattr('app.services.graph_service._create_llm_from_config', fake_create_llm_from_config, raising=False)
|
||||
|
||||
service = GraphService(session)
|
||||
await service.build_graph(user.id)
|
||||
|
||||
assert used_providers == ['openai']
|
||||
28
backend/tests/backend/app/test_brain_models.py
Normal file
28
backend/tests/backend/app/test_brain_models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
import app.models # noqa: F401
|
||||
from app.database import Base
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_brain_tables_are_registered_in_metadata(tmp_path):
|
||||
db_path = tmp_path / 'test_brain_models.db'
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
|
||||
table_names = {row[0] for row in result.fetchall()}
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
assert 'brain_events' in table_names
|
||||
assert 'brain_candidates' in table_names
|
||||
assert 'brain_memories' in table_names
|
||||
assert 'brain_tags' in table_names
|
||||
assert 'brain_event_tags' in table_names
|
||||
assert 'brain_memory_tags' in table_names
|
||||
assert 'brain_memory_sources' in table_names
|
||||
Reference in New Issue
Block a user