""" Tests for MemoryExtractor (M.4) Tests: extract_from_conversation, _deduplicate, _is_similar, save_memories. """ import pytest from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, AsyncMock, patch from app.services.memory.memory_extractor import ( MemoryExtractor, ExtractedMemory, MEMORY_TYPES, ) def create_mock_message(role: str = "user", content: str = "test"): """Create a mock Message.""" msg = MagicMock() msg.role = role msg.content = content msg.created_at = datetime.now(UTC) return msg def create_mock_user_memory( id: int = 1, content: str = "test memory", memory_type: str = "fact", importance_score: float = 0.5, is_archived: bool = False, ): """Create a mock UserMemory.""" mem = MagicMock() mem.id = id mem.content = content mem.memory_type = memory_type mem.importance_score = importance_score mem.is_archived = is_archived return mem class TestExtractedMemory: """Test ExtractedMemory dataclass.""" def test_extracted_memory_fields(self): """ExtractedMemory has correct fields.""" mem = ExtractedMemory( memory_type="fact", content="用户喜欢喝咖啡", confidence=0.9, source_conversation_id="conv-123", ) assert mem.memory_type == "fact" assert mem.content == "用户喜欢喝咖啡" assert mem.confidence == 0.9 assert mem.source_conversation_id == "conv-123" class TestMemoryExtractorIsSimilar: """Test _is_similar() method.""" def test_is_similar_high_overlap(self): """High keyword overlap returns True.""" extractor = MemoryExtractor() # Use English with clear word overlap result = extractor._is_similar("I like coffee and tea", "I like coffee and tea with milk") assert result is True def test_is_similar_low_overlap(self): """Low keyword overlap returns False.""" extractor = MemoryExtractor() result = extractor._is_similar("用户喜欢喝咖啡", "今天天气很好") assert result is False def test_is_similar_empty_content(self): """Empty content returns False.""" extractor = MemoryExtractor() result = extractor._is_similar("", "some text") assert result is False def test_is_similar_substring_match(self): """Same first 20 chars returns True.""" extractor = MemoryExtractor() # First 20 chars: "这是一个测试字符串ABCDEFGHIJKLMN" (20 chars) text1 = "这是一个测试字符串ABCDEFGHIJKLMNOPQRSTUVWXYZ" text2 = "这是一个测试字符串ABCDEFGHIJKLMNQRSTUVWXYZ" result = extractor._is_similar(text1, text2) assert result is True def test_is_similar_case_insensitive(self): """Comparison is case insensitive.""" extractor = MemoryExtractor() result = extractor._is_similar("USER LIKES COFFEE", "user likes coffee and tea") assert result is True class TestMemoryExtractorDeduplicate: """Test _deduplicate() method. Note: Full async integration tests would require proper AsyncSession mocking. These tests verify the deduplication logic with simplified synchronous mocks. """ @pytest.mark.asyncio async def test_deduplicate_empty_list(self): """Empty list returns empty list.""" extractor = MemoryExtractor() mock_db = AsyncMock() result = await extractor._deduplicate(mock_db, "user-123", []) assert result == [] class TestMemoryExtractorExtractFromConversation: """Test extract_from_conversation() method.""" @pytest.mark.asyncio async def test_extract_skips_short_conversation(self): """Less than 2 messages returns empty list.""" extractor = MemoryExtractor() mock_db = AsyncMock() messages = [create_mock_message()] result = await extractor.extract_from_conversation( mock_db, "user-123", "conv-123", messages ) assert result == [] @pytest.mark.asyncio async def test_extract_calls_llm(self): """Calls LLM to extract memories.""" extractor = MemoryExtractor() mock_db = AsyncMock() messages = [create_mock_message(), create_mock_message()] with patch.object( extractor, "_call_llm_extract", return_value=[{"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}], ) as mock_call: with patch.object(extractor, "_deduplicate", return_value=[]): result = await extractor.extract_from_conversation( mock_db, "user-123", "conv-123", messages ) mock_call.assert_called_once() @pytest.mark.asyncio async def test_extract_filters_invalid_types(self): """Filters out memories with invalid type.""" extractor = MemoryExtractor() mock_db = AsyncMock() messages = [create_mock_message(), create_mock_message()] with patch.object( extractor, "_call_llm_extract", return_value=[ {"type": "invalid_type", "content": "test", "confidence": 0.5}, {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}, ], ): valid_mem = ExtractedMemory( memory_type="fact", content="用户喜欢喝咖啡", confidence=0.9, source_conversation_id="conv-123", ) with patch.object(extractor, "_deduplicate", return_value=[valid_mem]): result = await extractor.extract_from_conversation( mock_db, "user-123", "conv-123", messages ) assert len(result) == 1 assert result[0].memory_type == "fact" @pytest.mark.asyncio async def test_extract_filters_empty_content(self): """Filters out memories with empty content.""" extractor = MemoryExtractor() mock_db = AsyncMock() messages = [create_mock_message(), create_mock_message()] with patch.object( extractor, "_call_llm_extract", return_value=[ {"type": "fact", "content": "", "confidence": 0.5}, {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}, ], ): valid_mem = ExtractedMemory( memory_type="fact", content="用户喜欢喝咖啡", confidence=0.9, source_conversation_id="conv-123", ) with patch.object(extractor, "_deduplicate", return_value=[valid_mem]): result = await extractor.extract_from_conversation( mock_db, "user-123", "conv-123", messages ) assert len(result) == 1 @pytest.mark.asyncio async def test_extract_sets_source_conversation_id(self): """Sets source_conversation_id on extracted memories.""" extractor = MemoryExtractor() mock_db = AsyncMock() messages = [create_mock_message(), create_mock_message()] with patch.object( extractor, "_call_llm_extract", return_value=[ {"type": "fact", "content": "用户喜欢喝咖啡", "confidence": 0.9}, ], ): valid_mem = ExtractedMemory( memory_type="fact", content="用户喜欢喝咖啡", confidence=0.9, source_conversation_id="conv-abc", ) with patch.object(extractor, "_deduplicate", return_value=[valid_mem]): result = await extractor.extract_from_conversation( mock_db, "user-123", "conv-abc", messages ) assert len(result) == 1 assert result[0].source_conversation_id == "conv-abc" class TestMemoryExtractorSaveMemories: """Test save_memories() method.""" @pytest.mark.asyncio async def test_save_memories_adds_to_db(self): """Adds memories to db and commits.""" extractor = MemoryExtractor() mock_db = AsyncMock() mock_db.commit = AsyncMock() mock_db.refresh = AsyncMock() memories = [ ExtractedMemory( memory_type="fact", content="用户喜欢喝咖啡", confidence=0.9, source_conversation_id="conv-123", ) ] with patch.object(extractor, "_deduplicate", return_value=memories): result = await extractor.save_memories(mock_db, "user-123", "conv-123", memories) assert len(result) == 1 mock_db.add.assert_called() mock_db.commit.assert_called() class TestMemoryTypes: """Test MEMORY_TYPES constant.""" def test_memory_types_has_all_types(self): """MEMORY_TYPES includes all expected types.""" assert "fact" in MEMORY_TYPES assert "preference" in MEMORY_TYPES assert "goal" in MEMORY_TYPES assert "pain_point" in MEMORY_TYPES assert "event" in MEMORY_TYPES if __name__ == "__main__": pytest.main([__file__, "-v"])