Files
JARVIS/backend/tests/services/test_reinforcement.py

214 lines
7.5 KiB
Python
Raw Normal View History

"""
Tests for MemoryReinforcement (M.2)
Tests: trigger(), auto_reinforce().
"""
import pytest
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock
from app.services.memory.reinforcement import MemoryReinforcement
def create_mock_memory(
frequency_count: int = 0,
last_accessed_at=None,
last_recalled_at=None,
decay_score: float = 1.0,
importance_level: str = "medium",
):
"""Create a mock UserMemory for testing."""
memory = MagicMock()
memory.frequency_count = frequency_count
memory.last_accessed_at = last_accessed_at
memory.last_recalled_at = last_recalled_at
memory.decay_score = decay_score
memory.importance_level = importance_level
return memory
class TestMemoryReinforcementTrigger:
    """Tests for trigger(), the hook invoked each time a memory is recalled."""

    @staticmethod
    def _recall(**fields):
        """Create a mock memory from ``fields`` and pass it through trigger().

        Returns a 4-tuple ``(memory, result, started, finished)``. ``started``
        and ``finished`` are timezone-aware UTC timestamps taken immediately
        before and after the trigger() call, so timestamp tests can assert
        that any "now" value written by trigger() falls inside that window.
        """
        memory = create_mock_memory(**fields)
        engine = MemoryReinforcement()
        started = datetime.now(UTC)
        result = engine.trigger(memory)
        finished = datetime.now(UTC)
        return memory, result, started, finished

    def test_trigger_increments_frequency(self):
        """trigger() raises frequency_count by exactly one."""
        _, result, _, _ = self._recall(frequency_count=5)
        assert result.frequency_count == 6

    def test_trigger_frequency_capped_at_max(self):
        """trigger() never pushes frequency_count past MAX_FREQUENCY (10)."""
        _, result, _, _ = self._recall(frequency_count=10)
        assert result.frequency_count == 10

    def test_trigger_updates_last_accessed_at(self):
        """trigger() overwrites a stale last_accessed_at with the current time."""
        ten_days_ago = datetime.now(UTC) - timedelta(days=10)
        _, result, started, finished = self._recall(last_accessed_at=ten_days_ago)
        assert started <= result.last_accessed_at <= finished

    def test_trigger_updates_last_recalled_at(self):
        """trigger() stamps last_recalled_at with the current time."""
        _, result, started, finished = self._recall()
        assert started <= result.last_recalled_at <= finished

    def test_trigger_boosts_decay_score(self):
        """trigger() raises decay_score (boost of 0.1, capped at 0.95)."""
        _, result, _, _ = self._recall(decay_score=0.5)
        assert 0.5 < result.decay_score <= 0.95

    def test_trigger_decay_score_capped_at_095(self):
        """trigger() leaves a decay_score already at the 0.95 cap unchanged."""
        _, result, _, _ = self._recall(decay_score=0.95)
        assert result.decay_score == 0.95

    def test_trigger_from_zero_frequency(self):
        """trigger() lifts a memory with frequency_count 0 to 1."""
        _, result, _, _ = self._recall(frequency_count=0)
        assert result.frequency_count == 1

    def test_trigger_returns_same_memory_object(self):
        """trigger() modifies the memory in place and returns that same object."""
        memory, result, _, _ = self._recall()
        assert result is memory
class TestMemoryReinforcementAutoReinforce:
    """Test auto_reinforce() method - weekly maintenance for high-importance memories."""

    def test_auto_reinforce_skips_non_high_importance(self):
        """auto_reinforce() skips memories that are not high importance."""
        reinforcement = MemoryReinforcement()
        memory_low = create_mock_memory(importance_level="low", frequency_count=5)
        memory_medium = create_mock_memory(importance_level="medium", frequency_count=5)
        result = reinforcement.auto_reinforce([memory_low, memory_medium])
        assert len(result) == 0
        # Being left out of the result is not enough: a skipped memory must
        # also not have been boosted as a side effect.
        assert memory_low.frequency_count == 5
        assert memory_medium.frequency_count == 5

    def test_auto_reinforce_includes_high_importance(self):
        """auto_reinforce() includes high-importance memories."""
        reinforcement = MemoryReinforcement()
        memory_high = create_mock_memory(importance_level="high", frequency_count=5)
        result = reinforcement.auto_reinforce([memory_high])
        assert len(result) == 1
        assert result[0] is memory_high

    def test_auto_reinforce_skips_max_frequency(self):
        """auto_reinforce() skips high-importance memories already at MAX_FREQUENCY."""
        reinforcement = MemoryReinforcement()
        memory = create_mock_memory(importance_level="high", frequency_count=10)
        result = reinforcement.auto_reinforce([memory])
        assert len(result) == 0
        # A skipped memory must not get its access timestamp refreshed.
        assert memory.last_accessed_at is None

    def test_auto_reinforce_boosts_frequency(self):
        """auto_reinforce() applies 10% boost to frequency_count."""
        reinforcement = MemoryReinforcement()
        memory = create_mock_memory(importance_level="high", frequency_count=5)
        result = reinforcement.auto_reinforce([memory])
        # Expected: int(5 * 1.1 + 1) == int(6.5) == 6
        assert result[0].frequency_count == 6

    def test_auto_reinforce_frequency_capped(self):
        """auto_reinforce() caps frequency at MAX_FREQUENCY."""
        reinforcement = MemoryReinforcement()
        memory = create_mock_memory(importance_level="high", frequency_count=9)
        result = reinforcement.auto_reinforce([memory])
        assert result[0].frequency_count == 10

    def test_auto_reinforce_improves_decay_score(self):
        """auto_reinforce() improves decay_score by 5%."""
        reinforcement = MemoryReinforcement()
        memory = create_mock_memory(importance_level="high", frequency_count=5, decay_score=0.5)
        result = reinforcement.auto_reinforce([memory])
        assert result[0].decay_score > 0.5
        # 0.5 * 1.05 == 0.525
        assert result[0].decay_score == pytest.approx(0.525, abs=0.001)

    def test_auto_reinforce_updates_last_accessed(self):
        """auto_reinforce() updates last_accessed_at to now."""
        reinforcement = MemoryReinforcement()
        old_time = datetime.now(UTC) - timedelta(days=30)
        memory = create_mock_memory(
            importance_level="high", frequency_count=5, last_accessed_at=old_time
        )
        before = datetime.now(UTC)
        result = reinforcement.auto_reinforce([memory])
        after = datetime.now(UTC)
        assert before <= result[0].last_accessed_at <= after

    def test_auto_reinforce_empty_list(self):
        """auto_reinforce() handles empty list gracefully."""
        reinforcement = MemoryReinforcement()
        result = reinforcement.auto_reinforce([])
        assert result == []

    def test_auto_reinforce_mixed_memories(self):
        """auto_reinforce() processes only high-importance, leaves others untouched."""
        reinforcement = MemoryReinforcement()
        memory_high = create_mock_memory(importance_level="high", frequency_count=5)
        memory_low = create_mock_memory(importance_level="low", frequency_count=5)
        memory_medium = create_mock_memory(importance_level="medium", frequency_count=5)
        result = reinforcement.auto_reinforce([memory_high, memory_low, memory_medium])
        assert len(result) == 1
        assert result[0] is memory_high
        # "Untouched" means every field auto_reinforce() writes on a reinforced
        # memory (frequency_count, decay_score, last_accessed_at, per the tests
        # above) keeps the value it was created with.
        for untouched in (memory_low, memory_medium):
            assert untouched.frequency_count == 5
            assert untouched.decay_score == 1.0
            assert untouched.last_accessed_at is None
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of exiting; raise it so that
    # running this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))