Add brain memory services and APIs

Introduce the backend pieces for brain memory ingestion, routing, and
system telemetry so the new knowledge workflows can project data into a
brain view. The supporting tests lock in the new behavior and keep the
expanded backend surface stable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-22 13:47:34 +08:00
parent e3691b01bb
commit d2447ee635
28 changed files with 2278 additions and 197 deletions

View File

@@ -0,0 +1,155 @@
import sys
from unittest.mock import Mock
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault('psutil', Mock())
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.brain import BrainMemory, BrainTag
from app.models.knowledge_graph import KGEdge, KGNode
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.graph import router as graph_router
from app.services.auth_service import get_password_hash
from app.services.brain_service import BrainService
from app.services.graph_service import GraphService
@pytest.fixture
async def brain_graph_env(tmp_path):
db_path = tmp_path / 'test_brain_graph.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain-graph@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Graph Tester',
)
session.add(user)
await session.flush()
session.add_all([
BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Knowledge brain phase 1',
content='Jarvis should learn from conversations and documents first.',
importance=9,
confidence=0.95,
status='active',
origin_source_types=['conversation', 'document'],
),
BrainMemory(
user_id=user.id,
memory_type='user_preference',
title='Structured delivery preference',
content='The user prefers concise structured summaries.',
importance=7,
confidence=0.88,
status='active',
origin_source_types=['conversation'],
),
BrainTag(
user_id=user.id,
name='knowledge-brain',
category='topic',
priority='important',
score=9.5,
),
BrainTag(
user_id=user.id,
name='conversation',
category='source',
priority='secondary',
score=7.0,
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
app = FastAPI()
app.include_router(graph_router)
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield session_factory, user, app
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_build_graph_projects_kg_nodes_and_edges_from_brain_data(brain_graph_env):
session_factory, user, _app = brain_graph_env
async with session_factory() as session:
service = GraphService(session)
await service.build_graph(user.id)
node_result = await session.execute(
select(KGNode).where(KGNode.user_id == user.id).order_by(KGNode.name.asc())
)
nodes = list(node_result.scalars().all())
edge_result = await session.execute(select(KGEdge))
edges = list(edge_result.scalars().all())
node_names = [node.name for node in nodes]
assert 'Knowledge brain phase 1' in node_names
assert 'Structured delivery preference' in node_names
assert 'knowledge-brain' in node_names
assert len(edges) >= 2
@pytest.mark.asyncio
async def test_run_learning_triggers_graph_rebuild(brain_graph_env, monkeypatch):
session_factory, user, _app = brain_graph_env
calls: list[str] = []
async def fake_build_graph(self, user_id, document_ids=None):
calls.append(user_id)
monkeypatch.setattr(GraphService, 'build_graph', fake_build_graph)
async with session_factory() as session:
service = BrainService(session)
await service.run_learning(user.id)
assert calls == [user.id]
@pytest.mark.asyncio
async def test_graph_api_returns_brain_projected_graph_after_build(brain_graph_env):
session_factory, user, app = brain_graph_env
async with session_factory() as session:
service = GraphService(session)
await service.build_graph(user.id)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/graph')
assert response.status_code == 200
payload = response.json()
assert payload['stats']['node_count'] >= 3
assert payload['stats']['edge_count'] >= 2
assert any(node['name'] == 'Knowledge brain phase 1' for node in payload['nodes'])
assert any(node['name'] == 'knowledge-brain' for node in payload['nodes'])

View File

@@ -0,0 +1,237 @@
from io import BytesIO
from types import SimpleNamespace
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.datastructures import UploadFile
import app.models # noqa: F401
from app.database import Base
from app.models.brain import BrainEvent, BrainMemory
from app.models.conversation import Conversation
from app.models.memory import MemorySummary, UserMemory
from app.models.user import User
from app.services import agent_service, memory_service
from app.services.agent_service import AgentService
from app.services.auth_service import get_password_hash
from app.services.document_service import DocumentService
class FakeGraph:
async def ainvoke(self, state):
return {"final_response": "已记录你的请求。"}
class FakeStreamingGraph:
async def astream_events(self, state, version="v2"):
yield {
"event": "on_chat_model_stream",
"name": "master",
"data": {"chunk": SimpleNamespace(content="这是流式回复。")},
}
@pytest.fixture
async def brain_ingestion_env(tmp_path, monkeypatch):
db_path = tmp_path / 'test_brain_ingestion.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain-ingestion@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Ingestion Tester',
)
session.add(user)
await session.commit()
await session.refresh(user)
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeGraph())
monkeypatch.setattr(agent_service, 'set_current_user', lambda user_id: None)
monkeypatch.setattr(agent_service, 'clear_current_user', lambda: None)
monkeypatch.setattr('app.services.document_service.settings.UPLOAD_DIR', str(tmp_path / 'uploads'))
async with session_factory() as session:
yield session, user
await engine.dispose()
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_user_message(brain_ingestion_env):
session, user = brain_ingestion_env
service = AgentService(session)
conversation_id, _message_id, _response, _model_name = await service.chat_simple(
user.id,
'请记住我这周要完成知识大脑第一阶段。',
)
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
user_events = [event for event in events if event.metadata_ == {'role': 'user'}]
assert len(user_events) == 1
assert user_events[0].source_id == conversation_id
assert user_events[0].event_type == 'message_created'
assert user_events[0].title == 'User message'
assert '知识大脑第一阶段' in (user_events[0].content_summary or '')
assert user_events[0].status == 'pending'
@pytest.mark.asyncio
async def test_upload_document_creates_brain_event_for_document_flow(brain_ingestion_env):
session, user = brain_ingestion_env
service = DocumentService(session)
upload = UploadFile(
filename='brain-notes.md',
file=BytesIO('# Brain\n\nCapture important product knowledge.'.encode('utf-8')),
)
document = await service.upload_document(user.id, upload)
result = await session.execute(
select(BrainEvent)
.where(
BrainEvent.user_id == user.id,
BrainEvent.source_type == 'document',
BrainEvent.source_id == document.id,
)
)
event = result.scalar_one_or_none()
assert event is not None
assert event.event_type == 'document_uploaded'
assert event.title == 'brain-notes.md'
assert 'Capture important product knowledge.' in (event.content_summary or '')
assert event.metadata_ == {
'document_id': document.id,
'file_type': 'md',
'ingestion_status': 'uploaded',
}
assert event.status == 'pending'
@pytest.mark.asyncio
async def test_chat_simple_creates_brain_event_for_assistant_message(brain_ingestion_env):
session, user = brain_ingestion_env
service = AgentService(session)
conversation_id, _message_id, response, _model_name = await service.chat_simple(
user.id,
'帮我总结今天知识大脑的进展。',
)
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].event_type == 'message_created'
assert events[1].title == 'Assistant message'
assert events[1].content_summary == response
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_streaming_chat_creates_brain_event_for_assistant_message(brain_ingestion_env, monkeypatch):
session, user = brain_ingestion_env
monkeypatch.setattr(agent_service, 'get_agent_graph', lambda: FakeStreamingGraph())
service = AgentService(session)
conversation_id, _message_id, stream = await service.chat(
user.id,
'用流式回复告诉我今天知识大脑学到了什么。',
)
chunks = []
async for event in stream:
if event.get('type') == 'chunk':
chunks.append(event['content'])
result = await session.execute(
select(BrainEvent)
.where(BrainEvent.user_id == user.id, BrainEvent.source_type == 'conversation')
.order_by(BrainEvent.created_at.asc())
)
events = list(result.scalars().all())
assert ''.join(chunks) == '这是流式回复。'
assert len(events) == 2
assert events[1].source_id == conversation_id
assert events[1].event_type == 'message_created'
assert events[1].title == 'Assistant message'
assert events[1].content_summary == '这是流式回复。'
assert events[1].metadata_ == {'role': 'assistant'}
@pytest.mark.asyncio
async def test_build_memory_context_includes_brain_memory_section(brain_ingestion_env):
session, user = brain_ingestion_env
conversation = Conversation(user_id=user.id, title='Brain context test')
session.add(conversation)
await session.flush()
session.add(UserMemory(
user_id=user.id,
memory_type='preference',
content='用户偏好结构化输出。',
importance=6,
source_conversation_id=conversation.id,
))
session.add(MemorySummary(
user_id=user.id,
conversation_id=conversation.id,
summary_text='之前讨论了知识大脑的整体设计。',
turn_count=8,
))
session.add(BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Knowledge brain phase 1',
content='Jarvis should learn from conversation and document events first.',
importance=9,
confidence=0.93,
status='active',
origin_source_types=['conversation', 'document'],
metadata_={'source_count': 2},
))
session.add(BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Forum moderation policy',
content='Forum moderation escalation stays separate from the current task.',
importance=10,
confidence=0.95,
status='active',
origin_source_types=['forum'],
metadata_={'source_count': 1},
))
await session.commit()
context = await memory_service.build_memory_context(
session,
user.id,
conversation.id,
'Jarvis 接下来应该优先做什么?',
)
assert '【用户记忆】' in context
assert '【之前对话摘要】' in context
assert '【知识大脑】' in context
assert 'Knowledge brain phase 1' in context
assert 'Jarvis should learn from conversation and document events first.' in context
assert 'Forum moderation policy' not in context

View File

@@ -0,0 +1,194 @@
import sys
from unittest.mock import Mock
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
sys.modules.setdefault('psutil', Mock())
import app.models # noqa: F401
from app.database import Base, get_db
from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
from app.models.user import User
from app.routers.auth import get_current_user
from app.routers.brain import router as brain_router
from app.services.auth_service import get_password_hash
@pytest.fixture
async def brain_router_env(tmp_path):
db_path = tmp_path / 'test_brain_router.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='brain@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Brain Tester',
)
session.add(user)
await session.flush()
session.add_all([
BrainMemory(
user_id=user.id,
memory_type='project_fact',
title='Current project direction',
content='Jarvis knowledge brain should learn from all major product surfaces.',
importance=8,
confidence=0.92,
status='active',
),
BrainMemory(
user_id=user.id,
memory_type='preference',
title='User prefers brain-first UX',
content='The knowledge brain should be broader than the graph page.',
importance=7,
confidence=0.88,
status='active',
),
BrainTag(
user_id=user.id,
name='knowledge-brain',
category='topic',
priority='important',
score=9.5,
),
BrainTag(
user_id=user.id,
name='graph',
category='topic',
priority='secondary',
score=4.0,
),
BrainEvent(
user_id=user.id,
source_type='conversation',
source_id='conv-1',
event_type='created',
title='Conversation created',
content_summary='User described the desired knowledge brain behavior.',
status='pending',
),
BrainEvent(
user_id=user.id,
source_type='document',
source_id='doc-1',
event_type='indexed',
title='Document indexed',
content_summary='A strategic document was indexed into the system.',
status='processed',
),
BrainCandidate(
user_id=user.id,
candidate_type='project_fact',
title='Brain spans all product surfaces',
summary='The knowledge brain should learn from conversation, docs, tasks, todos, and forum.',
importance_score=9.2,
confidence_score=0.95,
status='new',
),
])
await session.commit()
await session.refresh(user)
async def override_get_db():
async with session_factory() as session:
yield session
async def override_get_current_user():
return user
test_app = FastAPI()
test_app.include_router(brain_router)
test_app.dependency_overrides[get_db] = override_get_db
test_app.dependency_overrides[get_current_user] = override_get_current_user
try:
yield test_app
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_brain_overview_returns_memory_and_tag_summary(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/overview')
assert response.status_code == 200
payload = response.json()
assert payload['active_memory_count'] == 2
assert payload['important_tag_count'] == 1
assert payload['secondary_tag_count'] == 1
assert payload['recent_memory_titles'] == [
'Current project direction',
'User prefers brain-first UX',
]
@pytest.mark.asyncio
async def test_list_brain_memories_returns_active_memories_sorted_by_importance(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/memories')
assert response.status_code == 200
payload = response.json()
assert [item['title'] for item in payload] == [
'Current project direction',
'User prefers brain-first UX',
]
assert all(item['status'] == 'active' for item in payload)
@pytest.mark.asyncio
async def test_list_brain_tags_groups_important_and_secondary_tags(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/tags')
assert response.status_code == 200
payload = response.json()
assert [item['name'] for item in payload['important']] == ['knowledge-brain']
assert [item['name'] for item in payload['secondary']] == ['graph']
@pytest.mark.asyncio
async def test_list_brain_events_returns_latest_events_first(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.get('/api/brain/events')
assert response.status_code == 200
payload = response.json()
assert len(payload) == 2
assert payload[0]['title'] == 'Document indexed'
assert payload[1]['title'] == 'Conversation created'
@pytest.mark.asyncio
async def test_manual_brain_learning_run_returns_processed_counts(brain_router_env):
transport = ASGITransport(app=brain_router_env)
async with AsyncClient(transport=transport, base_url='http://testserver') as client:
response = await client.post('/api/brain/learn/run')
assert response.status_code == 200
payload = response.json()
assert payload == {
'events_considered': 1,
'candidates_created': 1,
'memories_promoted': 1,
}

View File

@@ -0,0 +1,371 @@
import json
from io import BytesIO
from types import SimpleNamespace
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.datastructures import UploadFile
import app.models # noqa: F401
from app.database import Base
from app.models.document import Document, DocumentChunk
from app.models.folder import Folder
from app.models.user import User
from app.services.auth_service import get_password_hash
from app.services.knowledge_service import KnowledgeService, SearchResult
from app.services.graph_service import GraphService
class FakeCollection:
def __init__(self):
self.add_calls = []
self.delete_calls = []
def add(self, *, ids, documents, metadatas):
self.add_calls.append({
'ids': ids,
'documents': documents,
'metadatas': metadatas,
})
def delete(self, *, where):
self.delete_calls.append(where)
def query(self, **kwargs):
self.last_query = kwargs
return {
'ids': [['chunk-schema', 'chunk-rows']],
'documents': [['schema chunk', 'row chunk']],
'metadatas': [[
{
'document_id': 'doc-1',
'document_title': 'Revenue',
'chunk_index': 0,
'content_type': 'table_schema',
'sheet_name': 'Revenue',
'row_start': 0,
'row_end': 0,
},
{
'document_id': 'doc-1',
'document_title': 'Revenue',
'chunk_index': 1,
'content_type': 'table_rows',
'sheet_name': 'Revenue',
'row_start': 1,
'row_end': 10,
},
]],
'distances': [[0.3, 0.35]],
}
@pytest.fixture
async def knowledge_test_env(tmp_path):
db_path = tmp_path / 'test_knowledge.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with session_factory() as session:
user = User(
email='knowledge@example.com',
hashed_password=get_password_hash('secret123'),
full_name='Knowledge Tester',
)
session.add(user)
await session.flush()
root = Folder(user_id=user.id, name='Finance', parent_id=None)
session.add(root)
await session.flush()
child = Folder(user_id=user.id, name='Reports', parent_id=root.id)
session.add(child)
await session.flush()
document = Document(
id='doc-1',
user_id=user.id,
title='Revenue Workbook',
filename='revenue.xlsx',
file_type='xlsx',
file_size=128,
file_path=str(tmp_path / 'revenue.xlsx'),
folder_id=child.id,
summary='Revenue summary',
chunk_count=2,
is_indexed=False,
)
session.add(document)
session.add_all([
DocumentChunk(
id='chunk-schema',
document_id=document.id,
chunk_index=0,
content='schema chunk',
metadata_=json.dumps({
'content_type': 'table_schema',
'sheet_name': 'Revenue',
'headers': ['region', 'amount'],
'source_order': 0,
'section_path': ['Revenue'],
'page_number': 1,
}),
),
DocumentChunk(
id='chunk-rows',
document_id=document.id,
chunk_index=1,
content='row chunk',
metadata_=json.dumps({
'content_type': 'table_rows',
'sheet_name': 'Revenue',
'row_start': 1,
'row_end': 10,
'source_order': 1,
'section_path': ['Revenue'],
'page_number': 1,
}),
),
])
await session.commit()
await session.refresh(user)
await session.refresh(document)
await session.refresh(child)
yield session, user, document, child
await engine.dispose()
@pytest.mark.asyncio
async def test_index_document_writes_folder_and_structure_metadata(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
await service.index_document(document.id, user.id)
assert fake_collection.add_calls
metadatas = fake_collection.add_calls[0]['metadatas']
assert metadatas[0]['folder_path'] == '/Finance/Reports'
assert metadatas[0]['content_type'] == 'table_schema'
assert metadatas[0]['sheet_name'] == 'Revenue'
assert metadatas[1]['content_type'] == 'table_rows'
await session.refresh(document)
assert document.is_indexed is True
assert document.ingestion_status == 'ready'
assert document.indexed_at is not None
@pytest.mark.asyncio
async def test_retrieve_prefers_table_schema_for_tabular_queries(knowledge_test_env):
session, user, _document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
results = await service.retrieve('excel表 Revenue 的列有哪些', user.id, top_k=2, use_rerank=True)
assert [item.chunk_id for item in results] == ['chunk-schema', 'chunk-rows']
metadata = json.loads(results[0].metadata_)
assert metadata['content_type'] == 'table_schema'
@pytest.mark.asyncio
async def test_context_expansion_uses_same_sheet_for_table_rows(knowledge_test_env):
session, user, _document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
prev_chunk, next_chunk = await service._get_related_chunks('chunk-rows', 1, 'doc-1')
assert prev_chunk == 'schema chunk'
assert next_chunk is None
def test_rerank_boosts_table_chunks_when_query_mentions_sheet():
service = KnowledgeService(db=None, user_id='user-1')
schema = SearchResult(
chunk_id='schema',
document_id='doc-1',
document_title='Revenue Workbook',
content='Columns: region amount',
score=0.6,
metadata_=json.dumps({'content_type': 'table_schema', 'sheet_name': 'Revenue'}),
)
paragraph = SearchResult(
chunk_id='paragraph',
document_id='doc-1',
document_title='Revenue Workbook',
content='General revenue narrative',
score=0.65,
metadata_=json.dumps({'content_type': 'paragraph', 'section_title': 'Overview'}),
)
ranked = service._rerank('sheet Revenue 有哪些列', [paragraph, schema], top_k=2)
assert [item.chunk_id for item in ranked] == ['schema', 'paragraph']
@pytest.mark.asyncio
async def test_context_expansion_prefers_same_section_before_linear_neighbors(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
session.add_all([
DocumentChunk(
id='chunk-overview',
document_id=document.id,
chunk_index=2,
content='overview chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Overview'],
'section_title': 'Overview',
'source_order': 2,
'page_number': 2,
}),
),
DocumentChunk(
id='chunk-overview-2',
document_id=document.id,
chunk_index=3,
content='overview details chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Overview'],
'section_title': 'Overview',
'source_order': 3,
'page_number': 2,
}),
),
DocumentChunk(
id='chunk-appendix',
document_id=document.id,
chunk_index=4,
content='appendix chunk',
metadata_=json.dumps({
'content_type': 'paragraph',
'section_path': ['Appendix'],
'section_title': 'Appendix',
'source_order': 4,
'page_number': 3,
}),
),
])
await session.commit()
service = KnowledgeService(session, user_id=user.id)
prev_chunk, next_chunk = await service._get_related_chunks('chunk-overview-2', 3, 'doc-1')
assert prev_chunk == 'overview chunk'
assert next_chunk is None
@pytest.mark.asyncio
async def test_reindex_document_rebuilds_chunks_and_versions(knowledge_test_env):
session, user, _document, folder = knowledge_test_env
from app.services.document_service import DocumentService
from app.services.knowledge_service import KnowledgeService
upload = UploadFile(
filename='reindex.md',
file=BytesIO(b'# Intro\n\nOriginal content\n\n## Details\n\nUpdated content'),
)
doc_service = DocumentService(session)
document = await doc_service.upload_document(user.id, upload, folder_id=folder.id)
chunk_result = await session.execute(select(DocumentChunk).where(DocumentChunk.document_id == document.id))
original_chunks = list(chunk_result.scalars().all())
assert original_chunks
service = KnowledgeService(session, user_id=user.id)
rebuilt = await service.reindex_document(document.id, user.id)
assert rebuilt is True
await session.refresh(document)
assert document.parser_version == 'v2'
assert document.index_version == 'v2'
assert document.ingestion_status == 'ready'
new_chunk_result = await session.execute(
select(DocumentChunk)
.where(DocumentChunk.document_id == document.id)
.order_by(DocumentChunk.chunk_index)
)
rebuilt_chunks = list(new_chunk_result.scalars().all())
assert rebuilt_chunks
assert all(chunk.metadata_ for chunk in rebuilt_chunks)
@pytest.mark.asyncio
async def test_reindex_document_chunks_reuses_existing_db_chunks(knowledge_test_env):
session, user, document, _folder = knowledge_test_env
service = KnowledgeService(session, user_id=user.id)
fake_collection = FakeCollection()
service.get_collection = lambda user_id: fake_collection
chunk_result = await session.execute(
select(DocumentChunk)
.where(DocumentChunk.id == 'chunk-schema')
)
chunk = chunk_result.scalar_one()
chunk.content = 'edited schema chunk'
document.ingestion_status = 'indexing'
await session.commit()
rebuilt = await service.reindex_document_chunks(document.id, user.id)
assert rebuilt is True
assert fake_collection.delete_calls == [{'document_id': document.id}]
assert fake_collection.add_calls
assert fake_collection.add_calls[0]['documents'][0] == 'edited schema chunk'
await session.refresh(document)
assert document.ingestion_status == 'ready'
assert document.indexed_at is not None
@pytest.mark.anyio
async def test_graph_service_uses_user_llm_configured_model(knowledge_test_env, monkeypatch):
session, user, document, _folder = knowledge_test_env
document.is_indexed = True
await session.commit()
used_providers = []
class FakeLLM:
async def invoke(self, _messages):
return SimpleNamespace(content=json.dumps({
'entities': [{'name': 'Revenue', 'type': 'topic', 'description': 'Revenue topic'}],
'relations': [],
}))
async def fake_get_user_llm_config(self, user_id, model_name=None):
assert user_id == user.id
assert model_name is None
return {
'provider': 'openai',
'model': 'user-model',
'api_key': 'secret',
'base_url': 'https://example.com/v1',
'enabled': True,
}
def fake_create_llm_from_config(config):
used_providers.append(config['provider'])
return FakeLLM()
def fail_if_global_llm_used():
raise AssertionError('global get_llm should not be used when user config exists')
monkeypatch.setattr('app.services.graph_service.get_llm', fail_if_global_llm_used)
monkeypatch.setattr('app.services.graph_service.resolve_user_llm', fake_get_user_llm_config, raising=False)
monkeypatch.setattr('app.services.graph_service._create_llm_from_config', fake_create_llm_from_config, raising=False)
service = GraphService(session)
await service.build_graph(user.id)
assert used_providers == ['openai']

View File

@@ -0,0 +1,28 @@
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
import app.models # noqa: F401
from app.database import Base
@pytest.mark.anyio
async def test_brain_tables_are_registered_in_metadata(tmp_path):
db_path = tmp_path / 'test_brain_models.db'
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'"))
table_names = {row[0] for row in result.fetchall()}
await engine.dispose()
assert 'brain_events' in table_names
assert 'brain_candidates' in table_names
assert 'brain_memories' in table_names
assert 'brain_tags' in table_names
assert 'brain_event_tags' in table_names
assert 'brain_memory_tags' in table_names
assert 'brain_memory_sources' in table_names