# JARVIS/backend/tests/services/test_memory_extractor.py

"""
Tests for MemoryExtractor (M.4)
Tests: extract_from_conversation, _deduplicate, _is_similar, save_memories.
"""
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.memory.memory_extractor import (
MemoryExtractor,
ExtractedMemory,
MEMORY_TYPES,
)
def create_mock_message(role: str = "user", content: str = "test"):
"""Create a mock Message."""
msg = MagicMock()
msg.role = role
msg.content = content
msg.created_at = datetime.now(UTC)
return msg
def create_mock_user_memory(
    id: int = 1,
    content: str = "test memory",
    memory_type: str = "fact",
    importance_score: float = 0.5,
    is_archived: bool = False,
):
    """Build a MagicMock standing in for a stored user memory.

    Every keyword argument is copied onto the mock as an attribute of the
    same name, so callers can read ``mem.id``, ``mem.content`` and so on.
    """
    attributes = {
        "id": id,
        "content": content,
        "memory_type": memory_type,
        "importance_score": importance_score,
        "is_archived": is_archived,
    }
    memory = MagicMock()
    for attr_name, attr_value in attributes.items():
        setattr(memory, attr_name, attr_value)
    return memory
class TestExtractedMemory:
    """Checks on the ExtractedMemory value object."""

    def test_extracted_memory_fields(self):
        """Constructor keyword arguments are readable back as attributes."""
        expected_fields = {
            "memory_type": "fact",
            "content": "用户喜欢喝咖啡",
            "confidence": 0.9,
            "source_conversation_id": "conv-123",
        }
        extracted = ExtractedMemory(**expected_fields)
        for field_name, field_value in expected_fields.items():
            assert getattr(extracted, field_name) == field_value
class TestMemoryExtractorIsSimilar:
    """Behaviour of MemoryExtractor._is_similar()."""

    def test_is_similar_high_overlap(self):
        """Two sentences sharing most of their words are reported similar."""
        # English is used so the shared words are clearly delimited.
        shorter = "I like coffee and tea"
        longer = "I like coffee and tea with milk"
        verdict = MemoryExtractor()._is_similar(shorter, longer)
        assert verdict is True

    def test_is_similar_low_overlap(self):
        """Two unrelated sentences are reported not similar."""
        verdict = MemoryExtractor()._is_similar("用户喜欢喝咖啡", "今天天气很好")
        assert verdict is False

    def test_is_similar_empty_content(self):
        """An empty first argument is reported not similar."""
        verdict = MemoryExtractor()._is_similar("", "some text")
        assert verdict is False

    def test_is_similar_substring_match(self):
        """Strings sharing a long identical prefix are reported similar."""
        # Both strings start with the same 23 characters
        # (9 CJK characters + "ABCDEFGHIJKLMN") and only diverge after that:
        # the second string skips "OP".
        with_full_alphabet = "这是一个测试字符串ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        with_gap = "这是一个测试字符串ABCDEFGHIJKLMNQRSTUVWXYZ"
        verdict = MemoryExtractor()._is_similar(with_full_alphabet, with_gap)
        assert verdict is True

    def test_is_similar_case_insensitive(self):
        """Letter case does not affect the comparison."""
        upper_case = "USER LIKES COFFEE"
        lower_case = "user likes coffee and tea"
        verdict = MemoryExtractor()._is_similar(upper_case, lower_case)
        assert verdict is True
class TestMemoryExtractorDeduplicate:
    """Behaviour of MemoryExtractor._deduplicate().

    Only the empty-input path is covered here. A fuller async integration
    test would need a properly mocked AsyncSession; this test passes a bare
    AsyncMock in place of the database session.
    """

    @pytest.mark.asyncio
    async def test_deduplicate_empty_list(self):
        """With no candidate memories the result is an empty list."""
        db_session = AsyncMock()
        no_candidates = []
        deduplicated = await MemoryExtractor()._deduplicate(
            db_session, "user-123", no_candidates
        )
        assert deduplicated == []
class TestMemoryExtractorExtractFromConversation:
    """Test extract_from_conversation() method.

    ``_call_llm_extract`` and ``_deduplicate`` are replaced with
    ``patch.object`` stubs in every test that gets past the message-count
    check, so the tests only exercise the code between those two calls.

    NOTE(review): the filter and conversation-id tests stub ``_deduplicate``
    to return exactly the memory they then assert on. If
    extract_from_conversation returns the deduplication result unchanged,
    those assertions cannot detect a broken filter -- consider asserting on
    the arguments ``_deduplicate`` receives once its call shape is confirmed.
    """

    @staticmethod
    def _two_messages():
        """Return a two-message conversation of default mock messages."""
        return [create_mock_message(), create_mock_message()]

    @staticmethod
    def _coffee_fact(conversation_id: str = "conv-123"):
        """Return the ExtractedMemory the stubs treat as the valid result."""
        return ExtractedMemory(
            memory_type="fact",
            content="用户喜欢喝咖啡",
            confidence=0.9,
            source_conversation_id=conversation_id,
        )

    @pytest.mark.asyncio
    async def test_extract_skips_short_conversation(self):
        """Less than 2 messages returns empty list."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = [create_mock_message()]
        result = await extractor.extract_from_conversation(
            mock_db, "user-123", "conv-123", messages
        )
        assert result == []

    @pytest.mark.asyncio
    async def test_extract_calls_llm(self):
        """Calls LLM to extract memories."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = self._two_messages()
        llm_output = [
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        with (
            patch.object(
                extractor, "_call_llm_extract", return_value=llm_output
            ) as mock_call,
            patch.object(extractor, "_deduplicate", return_value=[]),
        ):
            # The return value is irrelevant here; only the LLM call matters.
            await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )
        mock_call.assert_called_once()

    @pytest.mark.asyncio
    async def test_extract_filters_invalid_types(self):
        """Filters out memories with invalid type."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = self._two_messages()
        llm_output = [
            {"type": "invalid_type", "content": "test", "confidence": 0.5},
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_fact()
        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_output),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )
        assert len(result) == 1
        assert result[0].memory_type == "fact"

    @pytest.mark.asyncio
    async def test_extract_filters_empty_content(self):
        """Filters out memories with empty content."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = self._two_messages()
        llm_output = [
            {"type": "fact", "content": "", "confidence": 0.5},
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_fact()
        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_output),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-123", messages
            )
        assert len(result) == 1

    @pytest.mark.asyncio
    async def test_extract_sets_source_conversation_id(self):
        """Sets source_conversation_id on extracted memories."""
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        messages = self._two_messages()
        llm_output = [
            {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9},
        ]
        valid_mem = self._coffee_fact("conv-abc")
        with (
            patch.object(extractor, "_call_llm_extract", return_value=llm_output),
            patch.object(extractor, "_deduplicate", return_value=[valid_mem]),
        ):
            result = await extractor.extract_from_conversation(
                mock_db, "user-123", "conv-abc", messages
            )
        assert len(result) == 1
        assert result[0].source_conversation_id == "conv-abc"
class TestMemoryExtractorSaveMemories:
    """Test save_memories() method."""

    @pytest.mark.asyncio
    async def test_save_memories_adds_to_db(self):
        """Adds memories to db and commits.

        ``_deduplicate`` is stubbed to pass the single input memory through,
        so exactly one memory is expected back.
        """
        extractor = MemoryExtractor()
        mock_db = AsyncMock()
        # Assumes ``db`` is a SQLAlchemy AsyncSession, whose add() is a plain
        # synchronous method. Left as an auto-created AsyncMock child, a
        # non-awaited ``db.add(...)`` call would create a coroutine that is
        # never awaited and emit a RuntimeWarning.
        mock_db.add = MagicMock()
        mock_db.commit = AsyncMock()
        mock_db.refresh = AsyncMock()
        memories = [
            ExtractedMemory(
                memory_type="fact",
                content="用户喜欢喝咖啡",
                confidence=0.9,
                source_conversation_id="conv-123",
            )
        ]
        with patch.object(extractor, "_deduplicate", return_value=memories):
            result = await extractor.save_memories(
                mock_db, "user-123", "conv-123", memories
            )
        assert len(result) == 1
        mock_db.add.assert_called()
        mock_db.commit.assert_called()
class TestMemoryTypes:
    """Checks on the MEMORY_TYPES constant."""

    def test_memory_types_has_all_types(self):
        """Every expected memory type name is a member of MEMORY_TYPES."""
        expected_types = ("fact", "preference", "goal", "pain_point", "event")
        for type_name in expected_types:
            assert type_name in MEMORY_TYPES
if __name__ == "__main__":
    # Run this file's tests verbosely when executed directly, and propagate
    # pytest's exit code so a failing run does not exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))