238 lines
7.5 KiB
Python
238 lines
7.5 KiB
Python
|
|
"""
|
||
|
|
Tests for MemoryRecallInjector (M.5)
|
||
|
|
|
||
|
|
Tests: build_context, _rank, _budget_select, _format, recall_user_memories_for_injection.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from datetime import UTC, datetime, timedelta
|
||
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||
|
|
|
||
|
|
from app.services.memory.recall_injector import (
|
||
|
|
MemoryRecallInjector,
|
||
|
|
recall_user_memories_for_injection,
|
||
|
|
MEMORY_TYPE_PRIORITY,
|
||
|
|
DEFAULT_TOKEN_BUDGET,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def create_mock_memory(
|
||
|
|
id: int = 1,
|
||
|
|
content: str = "test memory",
|
||
|
|
memory_type: str = "fact",
|
||
|
|
importance_score: float = 0.5,
|
||
|
|
is_archived: bool = False,
|
||
|
|
):
|
||
|
|
"""Create a mock UserMemory."""
|
||
|
|
mem = MagicMock()
|
||
|
|
mem.id = id
|
||
|
|
mem.content = content
|
||
|
|
mem.memory_type = memory_type
|
||
|
|
mem.importance_score = importance_score
|
||
|
|
mem.is_archived = is_archived
|
||
|
|
return mem
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryRecallInjectorFormat:
    """Tests for MemoryRecallInjector._format(), which renders memories as text."""

    def test_format_empty_list(self):
        """Rendering no memories produces the empty string."""
        assert MemoryRecallInjector()._format([]) == ""

    def test_format_single_memory(self):
        """One memory shows its content, its type tag and the section header."""
        rendered = MemoryRecallInjector()._format(
            [create_mock_memory(content="用户喜欢喝咖啡", memory_type="preference")]
        )

        for fragment in ("用户喜欢喝咖啡", "[preference]", "[关于你的记忆]"):
            assert fragment in rendered

    def test_format_multiple_memories(self):
        """Several memories are rendered as bullet lines under one header."""
        located = create_mock_memory(content="用户住在上海", memory_type="fact")
        liked = create_mock_memory(content="用户喜欢喝咖啡", memory_type="preference")

        rendered = MemoryRecallInjector()._format([located, liked])

        for fragment in (
            "[关于你的记忆]",
            "- [fact] 用户住在上海",
            "- [preference] 用户喜欢喝咖啡",
        ):
            assert fragment in rendered

    def test_format_handles_missing_type(self):
        """A memory whose type is None still has its content rendered."""
        untyped = create_mock_memory(memory_type=None, content="some content")

        assert "some content" in MemoryRecallInjector()._format([untyped])
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryRecallInjectorBudgetSelect:
    """Test _budget_select() method."""

    def test_budget_select_respects_limit(self):
        """A budget too small for every memory selects a strict subset of the input."""
        # Use the same tiny budget for the constructor and the call, so the
        # test holds whichever one _budget_select actually reads.
        tiny_budget = 1
        injector = MemoryRecallInjector(token_budget=tiny_budget)

        memories = [
            create_mock_memory(id=1, content="短内容"),  # 3 chars
            create_mock_memory(id=2, content="这是一个比较长的内容记忆"),  # 12 chars
            create_mock_memory(id=3, content="这是非常非常长的内容记忆"),  # 12 chars
        ]

        selected = injector._budget_select(memories, tiny_budget)

        # The previous `len(selected) <= len(memories)` could never fail.
        # A 1-token budget cannot hold two 12-character memories, so an
        # implementation that ignores the budget is now caught.
        assert len(selected) < len(memories)
        # Selection may only pick from the input; it never invents items.
        assert all(item in memories for item in selected)

    def test_budget_select_empty_list(self):
        """Empty list returns empty."""
        injector = MemoryRecallInjector()

        selected = injector._budget_select([], 800)

        assert selected == []

    def test_budget_select_all_fit(self):
        """When all fit in budget, returns all."""
        injector = MemoryRecallInjector(token_budget=10000)  # Large budget

        memories = [
            create_mock_memory(content="short"),
            create_mock_memory(content="medium content"),
        ]

        selected = injector._budget_select(memories, 10000)

        assert len(selected) == 2
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryRecallInjectorRank:
    """Test _rank() method."""

    def test_rank_orders_by_score(self):
        """Memories sorted by relevance * 0.6 + importance * 0.4 * type_boost."""
        injector = MemoryRecallInjector()

        # pain_point gets 1.0 type boost, fact gets 0.8
        mem_pain = create_mock_memory(
            id=1, memory_type="pain_point", importance_score=0.9, content="pain"
        )
        mem_pain.similarity_score = 0.5
        mem_fact = create_mock_memory(
            id=2, memory_type="fact", importance_score=0.5, content="fact"
        )
        mem_fact.similarity_score = 0.5

        # pain_point: 0.5*0.6 + 0.9*0.4*1.0 = 0.30 + 0.36 = 0.66
        # fact:       0.5*0.6 + 0.5*0.4*0.8 = 0.30 + 0.16 = 0.46
        # Try both input orders. With the higher-scoring memory already first,
        # a no-op _rank would also pass; only the reversed order proves that
        # sorting actually happens.
        for candidates in ([mem_pain, mem_fact], [mem_fact, mem_pain]):
            ranked = injector._rank(candidates, "test query")

            # pain_point should come first due to type boost and higher importance
            assert ranked[0].memory_type == "pain_point"

    def test_rank_empty_list(self):
        """Empty list returns empty."""
        injector = MemoryRecallInjector()

        ranked = injector._rank([], "test query")

        assert ranked == []

    def test_rank_single_memory(self):
        """Single memory returns single item."""
        injector = MemoryRecallInjector()
        memory = create_mock_memory(content="only one")

        ranked = injector._rank([memory], "query")

        assert len(ranked) == 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryRecallInjectorBuildContext:
    """Tests for MemoryRecallInjector.build_context()."""

    @pytest.mark.asyncio
    async def test_build_context_returns_string(self):
        """build_context yields a str (possibly empty) and calls the recall helper once."""
        recall_target = (
            "app.services.memory.recall_injector.recall_user_memories_for_injection"
        )
        injector = MemoryRecallInjector()
        session = AsyncMock()

        with patch(recall_target, return_value=[]) as recall_stub:
            context = await injector.build_context(session, "user-123", "test message")

        assert isinstance(context, str)
        recall_stub.assert_called_once()
|
||
|
|
|
||
|
|
|
||
|
|
class TestRecallUserMemoriesForInjection:
    """Tests for the module-level recall_user_memories_for_injection() coroutine."""

    @staticmethod
    def _session_yielding(rows):
        """Return an async DB-session stub whose execute() result yields *rows*.

        The rows come back from ``(await session.execute(...)).scalars().all()``
        no matter what statement is executed.
        """
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = rows
        session = AsyncMock()
        session.execute = AsyncMock(return_value=query_result)
        return session

    @pytest.mark.asyncio
    async def test_returns_user_memories(self):
        """Rows produced by the session come back as UserMemory-like objects."""
        session = self._session_yielding([create_mock_memory(content="test")])

        recalled = await recall_user_memories_for_injection(
            session, "user-123", "test query", top_k=5
        )

        assert len(recalled) >= 1

    @pytest.mark.asyncio
    async def test_token_matching(self):
        """A query token contained in a memory's content yields that memory.

        Note: the session stub returns the row regardless of the SQL filter.
        """
        session = self._session_yielding([create_mock_memory(content="用户喜欢喝咖啡")])

        recalled = await recall_user_memories_for_injection(
            session, "user-123", "咖啡", top_k=5
        )

        # "咖啡" occurs in the stored content, so at least one memory is expected.
        assert len(recalled) >= 1
|
||
|
|
|
||
|
|
|
||
|
|
class TestMemoryTypePriority:
    """Tests for the MEMORY_TYPE_PRIORITY lookup table."""

    def test_priority_values(self):
        """pain_point ranks 1 (highest), followed by goal, preference, fact, event."""
        expected_ranks = {
            "pain_point": 1,
            "goal": 2,
            "preference": 3,
            "fact": 4,
            "event": 5,
        }

        for memory_type, rank in expected_ranks.items():
            assert MEMORY_TYPE_PRIORITY[memory_type] == rank
|
||
|
|
|
||
|
|
|
||
|
|
class TestDefaultTokenBudget:
    """Tests for the DEFAULT_TOKEN_BUDGET constant."""

    def test_default_budget_value(self):
        """DEFAULT_TOKEN_BUDGET is 800."""
        expected_budget = 800
        assert DEFAULT_TOKEN_BUDGET == expected_budget
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|