""" Tests for MemoryRecallInjector (M.5) Tests: build_context, _rank, _budget_select, _format, recall_user_memories_for_injection. """ import pytest from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, AsyncMock, patch from app.services.memory.recall_injector import ( MemoryRecallInjector, recall_user_memories_for_injection, MEMORY_TYPE_PRIORITY, DEFAULT_TOKEN_BUDGET, ) def create_mock_memory( id: int = 1, content: str = "test memory", memory_type: str = "fact", importance_score: float = 0.5, is_archived: bool = False, ): """Create a mock UserMemory.""" mem = MagicMock() mem.id = id mem.content = content mem.memory_type = memory_type mem.importance_score = importance_score mem.is_archived = is_archived return mem class TestMemoryRecallInjectorFormat: """Test _format() method.""" def test_format_empty_list(self): """Empty list returns empty string.""" injector = MemoryRecallInjector() result = injector._format([]) assert result == "" def test_format_single_memory(self): """Single memory formatted correctly.""" injector = MemoryRecallInjector() memory = create_mock_memory(content="用户喜欢喝咖啡", memory_type="preference") result = injector._format([memory]) assert "用户喜欢喝咖啡" in result assert "[preference]" in result assert "[关于你的记忆]" in result def test_format_multiple_memories(self): """Multiple memories formatted with bullets.""" injector = MemoryRecallInjector() mem1 = create_mock_memory(content="用户住在上海", memory_type="fact") mem2 = create_mock_memory(content="用户喜欢喝咖啡", memory_type="preference") result = injector._format([mem1, mem2]) assert "[关于你的记忆]" in result assert "- [fact] 用户住在上海" in result assert "- [preference] 用户喜欢喝咖啡" in result def test_format_handles_missing_type(self): """Memory without type falls back gracefully.""" injector = MemoryRecallInjector() memory = create_mock_memory(memory_type=None, content="some content") result = injector._format([memory]) assert "some content" in result class TestMemoryRecallInjectorBudgetSelect: """Test _budget_select() method.""" def test_budget_select_respects_limit(self): """Stops when token budget exhausted.""" injector = MemoryRecallInjector(token_budget=50) # Small budget memories = [ create_mock_memory(content="短内容"), # ~6 chars → ~3 tokens create_mock_memory(content="这是一个比较长的内容记忆"), # ~12 chars → ~6 tokens create_mock_memory(content="这是非常非常长的内容记忆"), # ~14 chars → ~7 tokens ] selected = injector._budget_select(memories, 50) # Should select as many as fit in budget assert len(selected) <= len(memories) def test_budget_select_empty_list(self): """Empty list returns empty.""" injector = MemoryRecallInjector() selected = injector._budget_select([], 800) assert selected == [] def test_budget_select_all_fit(self): """When all fit in budget, returns all.""" injector = MemoryRecallInjector(token_budget=10000) # Large budget memories = [ create_mock_memory(content="short"), create_mock_memory(content="medium content"), ] selected = injector._budget_select(memories, 10000) assert len(selected) == 2 class TestMemoryRecallInjectorRank: """Test _rank() method.""" def test_rank_orders_by_score(self): """Memories sorted by relevance * 0.6 + importance * 0.4 * type_boost.""" injector = MemoryRecallInjector() # pain_point gets 1.0 type boost, fact gets 0.8 mem_pain = create_mock_memory( id=1, memory_type="pain_point", importance_score=0.9, content="pain" ) mem_pain.similarity_score = 0.5 mem_fact = create_mock_memory( id=2, memory_type="fact", importance_score=0.5, content="fact" ) mem_fact.similarity_score = 0.5 # pain_point: 0.5*0.6 + 0.9*0.4*1.0 = 0.30 + 0.36 = 0.66 # fact: 0.5*0.6 + 0.5*0.4*0.8 = 0.30 + 0.16 = 0.46 ranked = injector._rank([mem_pain, mem_fact], "test query") # pain_point should come first due to type boost and higher importance assert ranked[0].memory_type == "pain_point" def test_rank_empty_list(self): """Empty list returns empty.""" injector = MemoryRecallInjector() ranked = injector._rank([], "test query") assert ranked == [] def test_rank_single_memory(self): """Single memory returns single item.""" injector = MemoryRecallInjector() memory = create_mock_memory(content="only one") ranked = injector._rank([memory], "query") assert len(ranked) == 1 class TestMemoryRecallInjectorBuildContext: """Test build_context() method.""" @pytest.mark.asyncio async def test_build_context_returns_string(self): """Returns string (possibly empty).""" injector = MemoryRecallInjector() mock_db = AsyncMock() with patch( "app.services.memory.recall_injector.recall_user_memories_for_injection", return_value=[], ) as mock_recall: result = await injector.build_context(mock_db, "user-123", "test message") assert isinstance(result, str) mock_recall.assert_called_once() class TestRecallUserMemoriesForInjection: """Test recall_user_memories_for_injection() function.""" @pytest.mark.asyncio async def test_returns_user_memories(self): """Returns UserMemory objects.""" mock_db = AsyncMock() mock_mem = create_mock_memory(content="test") mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_mem] mock_db.execute = AsyncMock(return_value=mock_result) result = await recall_user_memories_for_injection( mock_db, "user-123", "test query", top_k=5 ) assert len(result) >= 1 @pytest.mark.asyncio async def test_token_matching(self): """Query tokens are matched against memory content.""" mock_db = AsyncMock() mock_mem = create_mock_memory(content="用户喜欢喝咖啡") mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [mock_mem] mock_db.execute = AsyncMock(return_value=mock_result) result = await recall_user_memories_for_injection(mock_db, "user-123", "咖啡", top_k=5) # Should match because "咖啡" is in content assert len(result) >= 1 class TestMemoryTypePriority: """Test MEMORY_TYPE_PRIORITY constant.""" def test_priority_values(self): """pain_point=1 (highest), goal=2, preference=3, fact=4, event=5.""" assert MEMORY_TYPE_PRIORITY["pain_point"] == 1 assert MEMORY_TYPE_PRIORITY["goal"] == 2 assert MEMORY_TYPE_PRIORITY["preference"] == 3 assert MEMORY_TYPE_PRIORITY["fact"] == 4 assert MEMORY_TYPE_PRIORITY["event"] == 5 class TestDefaultTokenBudget: """Test DEFAULT_TOKEN_BUDGET constant.""" def test_default_budget_value(self): """Default token budget is 800.""" assert DEFAULT_TOKEN_BUDGET == 800 if __name__ == "__main__": pytest.main([__file__, "-v"])