feat(memory): complete M.2-M.5 memory upgrade phases with tests
- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
This commit is contained in:
290
backend/tests/services/test_memory_extractor.py
Normal file
290
backend/tests/services/test_memory_extractor.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Tests for MemoryExtractor (M.4)
|
||||
|
||||
Tests: extract_from_conversation, _deduplicate, _is_similar, save_memories.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from app.services.memory.memory_extractor import (
|
||||
MemoryExtractor,
|
||||
ExtractedMemory,
|
||||
MEMORY_TYPES,
|
||||
)
|
||||
|
||||
|
||||
def create_mock_message(role: str = "user", content: str = "test"):
|
||||
"""Create a mock Message."""
|
||||
msg = MagicMock()
|
||||
msg.role = role
|
||||
msg.content = content
|
||||
msg.created_at = datetime.now(UTC)
|
||||
return msg
|
||||
|
||||
|
||||
def create_mock_user_memory(
|
||||
id: int = 1,
|
||||
content: str = "test memory",
|
||||
memory_type: str = "fact",
|
||||
importance_score: float = 0.5,
|
||||
is_archived: bool = False,
|
||||
):
|
||||
"""Create a mock UserMemory."""
|
||||
mem = MagicMock()
|
||||
mem.id = id
|
||||
mem.content = content
|
||||
mem.memory_type = memory_type
|
||||
mem.importance_score = importance_score
|
||||
mem.is_archived = is_archived
|
||||
return mem
|
||||
|
||||
|
||||
class TestExtractedMemory:
|
||||
"""Test ExtractedMemory dataclass."""
|
||||
|
||||
def test_extracted_memory_fields(self):
|
||||
"""ExtractedMemory has correct fields."""
|
||||
mem = ExtractedMemory(
|
||||
memory_type="fact",
|
||||
content="用户喜欢喝咖啡",
|
||||
confidence=0.9,
|
||||
source_conversation_id="conv-123",
|
||||
)
|
||||
|
||||
assert mem.memory_type == "fact"
|
||||
assert mem.content == "用户喜欢喝咖啡"
|
||||
assert mem.confidence == 0.9
|
||||
assert mem.source_conversation_id == "conv-123"
|
||||
|
||||
|
||||
class TestMemoryExtractorIsSimilar:
|
||||
"""Test _is_similar() method."""
|
||||
|
||||
def test_is_similar_high_overlap(self):
|
||||
"""High keyword overlap returns True."""
|
||||
extractor = MemoryExtractor()
|
||||
|
||||
# Use English with clear word overlap
|
||||
result = extractor._is_similar("I like coffee and tea", "I like coffee and tea with milk")
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_similar_low_overlap(self):
|
||||
"""Low keyword overlap returns False."""
|
||||
extractor = MemoryExtractor()
|
||||
|
||||
result = extractor._is_similar("用户喜欢喝咖啡", "今天天气很好")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_similar_empty_content(self):
|
||||
"""Empty content returns False."""
|
||||
extractor = MemoryExtractor()
|
||||
|
||||
result = extractor._is_similar("", "some text")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_similar_substring_match(self):
|
||||
"""Same first 20 chars returns True."""
|
||||
extractor = MemoryExtractor()
|
||||
|
||||
# First 20 chars: "这是一个测试字符串ABCDEFGHIJKLMN" (20 chars)
|
||||
text1 = "这是一个测试字符串ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
text2 = "这是一个测试字符串ABCDEFGHIJKLMNQRSTUVWXYZ"
|
||||
result = extractor._is_similar(text1, text2)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_similar_case_insensitive(self):
|
||||
"""Comparison is case insensitive."""
|
||||
extractor = MemoryExtractor()
|
||||
|
||||
result = extractor._is_similar("USER LIKES COFFEE", "user likes coffee and tea")
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestMemoryExtractorDeduplicate:
|
||||
"""Test _deduplicate() method.
|
||||
|
||||
Note: Full async integration tests would require proper AsyncSession mocking.
|
||||
These tests verify the deduplication logic with simplified synchronous mocks.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicate_empty_list(self):
|
||||
"""Empty list returns empty list."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
|
||||
result = await extractor._deduplicate(mock_db, "user-123", [])
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestMemoryExtractorExtractFromConversation:
|
||||
"""Test extract_from_conversation() method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_skips_short_conversation(self):
|
||||
"""Less than 2 messages returns empty list."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
messages = [create_mock_message()]
|
||||
|
||||
result = await extractor.extract_from_conversation(
|
||||
mock_db, "user-123", "conv-123", messages
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_calls_llm(self):
|
||||
"""Calls LLM to extract memories."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
messages = [create_mock_message(), create_mock_message()]
|
||||
|
||||
with patch.object(
|
||||
extractor,
|
||||
"_call_llm_extract",
|
||||
return_value=[{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}],
|
||||
) as mock_call:
|
||||
with patch.object(extractor, "_deduplicate", return_value=[]):
|
||||
result = await extractor.extract_from_conversation(
|
||||
mock_db, "user-123", "conv-123", messages
|
||||
)
|
||||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_filters_invalid_types(self):
|
||||
"""Filters out memories with invalid type."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
messages = [create_mock_message(), create_mock_message()]
|
||||
|
||||
with patch.object(
|
||||
extractor,
|
||||
"_call_llm_extract",
|
||||
return_value=[
|
||||
{"type": "invalid_type", "content": "test", "confidence": 0.5},
|
||||
{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
|
||||
],
|
||||
):
|
||||
valid_mem = ExtractedMemory(
|
||||
memory_type="fact",
|
||||
content="用户喜欢喝咖啡",
|
||||
confidence=0.9,
|
||||
source_conversation_id="conv-123",
|
||||
)
|
||||
with patch.object(extractor, "_deduplicate", return_value=[valid_mem]):
|
||||
result = await extractor.extract_from_conversation(
|
||||
mock_db, "user-123", "conv-123", messages
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].memory_type == "fact"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_filters_empty_content(self):
|
||||
"""Filters out memories with empty content."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
messages = [create_mock_message(), create_mock_message()]
|
||||
|
||||
with patch.object(
|
||||
extractor,
|
||||
"_call_llm_extract",
|
||||
return_value=[
|
||||
{"type": "fact", "content": "", "confidence": 0.5},
|
||||
{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
|
||||
],
|
||||
):
|
||||
valid_mem = ExtractedMemory(
|
||||
memory_type="fact",
|
||||
content="用户喜欢喝咖啡",
|
||||
confidence=0.9,
|
||||
source_conversation_id="conv-123",
|
||||
)
|
||||
with patch.object(extractor, "_deduplicate", return_value=[valid_mem]):
|
||||
result = await extractor.extract_from_conversation(
|
||||
mock_db, "user-123", "conv-123", messages
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_sets_source_conversation_id(self):
|
||||
"""Sets source_conversation_id on extracted memories."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
messages = [create_mock_message(), create_mock_message()]
|
||||
|
||||
with patch.object(
|
||||
extractor,
|
||||
"_call_llm_extract",
|
||||
return_value=[
|
||||
{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
|
||||
],
|
||||
):
|
||||
valid_mem = ExtractedMemory(
|
||||
memory_type="fact",
|
||||
content="用户喜欢喝咖啡",
|
||||
confidence=0.9,
|
||||
source_conversation_id="conv-abc",
|
||||
)
|
||||
with patch.object(extractor, "_deduplicate", return_value=[valid_mem]):
|
||||
result = await extractor.extract_from_conversation(
|
||||
mock_db, "user-123", "conv-abc", messages
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].source_conversation_id == "conv-abc"
|
||||
|
||||
|
||||
class TestMemoryExtractorSaveMemories:
|
||||
"""Test save_memories() method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_memories_adds_to_db(self):
|
||||
"""Adds memories to db and commits."""
|
||||
extractor = MemoryExtractor()
|
||||
mock_db = AsyncMock()
|
||||
mock_db.commit = AsyncMock()
|
||||
mock_db.refresh = AsyncMock()
|
||||
|
||||
memories = [
|
||||
ExtractedMemory(
|
||||
memory_type="fact",
|
||||
content="用户喜欢喝咖啡",
|
||||
confidence=0.9,
|
||||
source_conversation_id="conv-123",
|
||||
)
|
||||
]
|
||||
|
||||
with patch.object(extractor, "_deduplicate", return_value=memories):
|
||||
result = await extractor.save_memories(mock_db, "user-123", "conv-123", memories)
|
||||
|
||||
assert len(result) == 1
|
||||
mock_db.add.assert_called()
|
||||
mock_db.commit.assert_called()
|
||||
|
||||
|
||||
class TestMemoryTypes:
|
||||
"""Test MEMORY_TYPES constant."""
|
||||
|
||||
def test_memory_types_has_all_types(self):
|
||||
"""MEMORY_TYPES includes all expected types."""
|
||||
assert "fact" in MEMORY_TYPES
|
||||
assert "preference" in MEMORY_TYPES
|
||||
assert "goal" in MEMORY_TYPES
|
||||
assert "pain_point" in MEMORY_TYPES
|
||||
assert "event" in MEMORY_TYPES
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user