# Commit summary:
# - M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting)
# - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders)
# - M.4: MemoryExtractor with LLM-based memory extraction from conversations
# - M.5: MemoryRecallInjector with token budget control for prompt injection
# - All phases include comprehensive unit tests (109 tests passing)
# - Updated checklist.md to mark all tasks complete
# File info: 291 lines, 9.2 KiB, Python
"""
Tests for MemoryExtractor (M.4).

Tests: extract_from_conversation, _deduplicate, _is_similar, save_memories.
"""

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.memory.memory_extractor import (
    MEMORY_TYPES,
    ExtractedMemory,
    MemoryExtractor,
)


def create_mock_message(role: str = "user", content: str = "test"):
|
|
"""Create a mock Message."""
|
|
msg = MagicMock()
|
|
msg.role = role
|
|
msg.content = content
|
|
msg.created_at = datetime.now(UTC)
|
|
return msg
|
|
|
|
|
|
def create_mock_user_memory(
    id: int = 1,  # noqa: A002 - shadows the builtin, kept so keyword callers keep working
    content: str = "test memory",
    memory_type: str = "fact",
    importance_score: float = 0.5,
    is_archived: bool = False,
) -> MagicMock:
    """Create a mock UserMemory.

    Args:
        id: Value assigned to the mock's ``id`` attribute.
        content: Value assigned to the mock's ``content`` attribute.
        memory_type: Value assigned to the mock's ``memory_type`` attribute.
        importance_score: Value assigned to ``importance_score``.
        is_archived: Value assigned to ``is_archived``.

    Returns:
        A ``MagicMock`` with the five attributes above set.
    """
    mem = MagicMock()
    mem.id = id
    mem.content = content
    mem.memory_type = memory_type
    mem.importance_score = importance_score
    mem.is_archived = is_archived
    return mem


class TestExtractedMemory:
    """Test ExtractedMemory dataclass."""

    def test_extracted_memory_fields(self):
        """ExtractedMemory stores the values passed to its constructor."""
        mem = ExtractedMemory(
            memory_type="fact",
            content="用户喜欢喝咖啡",
            confidence=0.9,
            source_conversation_id="conv-123",
        )

        assert mem.memory_type == "fact"
        assert mem.content == "用户喜欢喝咖啡"
        assert mem.confidence == 0.9
        assert mem.source_conversation_id == "conv-123"


class TestMemoryExtractorIsSimilar:
    """Test _is_similar() method."""

    def test_is_similar_high_overlap(self):
        """High keyword overlap returns True."""
        extractor = MemoryExtractor()

        # Use English with clear word overlap
        result = extractor._is_similar(
            "I like coffee and tea",
            "I like coffee and tea with milk",
        )

        assert result is True

    def test_is_similar_low_overlap(self):
        """Low keyword overlap returns False."""
        extractor = MemoryExtractor()

        result = extractor._is_similar("用户喜欢喝咖啡", "今天天气很好")

        assert result is False

    def test_is_similar_empty_content(self):
        """Empty content returns False."""
        extractor = MemoryExtractor()

        result = extractor._is_similar("", "some text")

        assert result is False

    def test_is_similar_substring_match(self):
        """Same first 20 chars returns True."""
        extractor = MemoryExtractor()

        # The two strings share the 23-character prefix
        # "这是一个测试字符串ABCDEFGHIJKLMN" (9 CJK characters + 14 letters) and
        # diverge only after it, so their first 20 characters are identical.
        text1 = "这是一个测试字符串ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        text2 = "这是一个测试字符串ABCDEFGHIJKLMNQRSTUVWXYZ"
        result = extractor._is_similar(text1, text2)

        assert result is True

    def test_is_similar_case_insensitive(self):
        """Comparison is case insensitive."""
        extractor = MemoryExtractor()

        result = extractor._is_similar("USER LIKES COFFEE", "user likes coffee and tea")

        assert result is True


class TestMemoryExtractorDeduplicate:
    """Test _deduplicate() method.

    Note: Full async integration tests would require proper AsyncSession mocking.
    These tests exercise the deduplication logic against a simplified
    ``AsyncMock`` database object instead.
    """

    @pytest.mark.asyncio
    async def test_deduplicate_empty_list(self):
        """Empty list returns empty list."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()

        result = await extractor._deduplicate(mock_db, "user-123", [])

        assert result == []


class TestMemoryExtractorExtractFromConversation:
    """Test extract_from_conversation() method."""

    @staticmethod
    def _coffee_memory(conversation_id: str) -> ExtractedMemory:
        """Build the single valid memory that the patched ``_deduplicate`` returns."""
        return ExtractedMemory(
            memory_type="fact",
            content="用户喜欢喝咖啡",
            confidence=0.9,
            source_conversation_id=conversation_id,
        )

    @pytest.mark.asyncio
    async def test_extract_skips_short_conversation(self):
        """Less than 2 messages returns empty list."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message()]

        result = await extractor.extract_from_conversation(
            mock_db, "user-123", "conv-123", messages
        )

        assert result == []

    @pytest.mark.asyncio
    async def test_extract_calls_llm(self):
        """Calls LLM to extract memories."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message(), create_mock_message()]
        llm_items = [{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}]

        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_items) as mock_call,
            patch.object(extractor, "_deduplicate", return_value=[]),
        ):
            # Only the LLM invocation matters here, so the return value is
            # intentionally not bound.
            await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )

        mock_call.assert_called_once()

    @pytest.mark.asyncio
    async def test_extract_filters_invalid_types(self):
        """Filters out memories with invalid type."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message(), create_mock_message()]
        llm_items = [
            {"type": "invalid_type", "content": "test", "confidence": 0.5},
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_memory("conv-123")

        # NOTE(review): `_deduplicate` is patched to return a fixed list, so the
        # assertions below check what the patch returns rather than the filter
        # itself. Inspecting the arguments `_deduplicate` receives would verify
        # the filtering directly — TODO confirm its call signature first.
        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_items),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )

        assert len(result) == 1
        assert result[0].memory_type == "fact"

    @pytest.mark.asyncio
    async def test_extract_filters_empty_content(self):
        """Filters out memories with empty content."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message(), create_mock_message()]
        llm_items = [
            {"type": "fact", "content": "", "confidence": 0.5},
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_memory("conv-123")

        # NOTE(review): as with the invalid-type test, the patched
        # `_deduplicate` fixes the result, so this asserts on the patch's
        # return value rather than on the empty-content filter itself.
        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_items),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )

        assert len(result) == 1

    @pytest.mark.asyncio
    async def test_extract_sets_source_conversation_id(self):
        """Sets source_conversation_id on extracted memories."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message(), create_mock_message()]
        llm_items = [
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_memory("conv-abc")

        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_items),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-abc", messages
            )

        assert len(result) == 1
        assert result[0].source_conversation_id == "conv-abc"


class TestMemoryExtractorSaveMemories:
    """Test save_memories() method."""

    @pytest.mark.asyncio
    async def test_save_memories_adds_to_db(self):
        """Adds memories to db and commits."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        # Every attribute of an AsyncMock is itself an AsyncMock, so a plain
        # `db.add(obj)` call would create a coroutine that is never awaited.
        # NOTE(review): assumes `db` is a SQLAlchemy AsyncSession, whose add()
        # is a synchronous method — confirm against save_memories().
        mock_db.add = MagicMock()
        mock_db.commit = AsyncMock()
        mock_db.refresh = AsyncMock()

        memories = [
            ExtractedMemory(
                memory_type="fact",
                content="用户喜欢喝咖啡",
                confidence=0.9,
                source_conversation_id="conv-123",
            )
        ]

        # Bypass deduplication so every supplied memory is treated as new.
        with patch.object(extractor, "_deduplicate", return_value=memories):
            result = await extractor.save_memories(mock_db, "user-123", "conv-123", memories)

        assert len(result) == 1
        mock_db.add.assert_called()
        mock_db.commit.assert_called()


class TestMemoryTypes:
    """Test MEMORY_TYPES constant."""

    def test_memory_types_has_all_types(self):
        """MEMORY_TYPES includes all expected types."""
        expected = ("fact", "preference", "goal", "pain_point", "event")

        # Collect every absent name so a failure reports all of them at once
        # instead of stopping at the first missing type.
        missing = [name for name in expected if name not in MEMORY_TYPES]

        assert not missing, f"MEMORY_TYPES is missing: {missing}"


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))