- M.2: ForgettingCurve, MemoryDecay, MemoryReinforcement (selective forgetting) - M.3: DailyDigestGenerator, ReminderScheduler, ProactiveInformer (proactive reminders) - M.4: MemoryExtractor with LLM-based memory extraction from conversations - M.5: MemoryRecallInjector with token budget control for prompt injection - All phases include comprehensive unit tests (109 tests passing) - Updated checklist.md to mark all tasks complete
137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
"""
|
|
ProactiveInformer
|
|
|
|
Checks conversation context and proactively informs user of relevant memories.
|
|
Only fires probabilistically based on trigger type.
|
|
"""
|
|
|
|
import random
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
# Trigger types and their firing probabilities
|
|
TRIGGERS = {
|
|
"high_importance_topic": {"probability": 0.8},
|
|
"repeat_question": {"probability": 1.0},
|
|
"forgotten_context": {"probability": 0.5},
|
|
"pending_goal": {"probability": 0.3},
|
|
}
|
|
|
|
# Message style templates for proactive informing
|
|
INFORM_STYLE = {
|
|
"casual": "对了,你之前提到...",
|
|
"gentle": "不知道你有没有注意到...",
|
|
"helpful": "我记起你关心这个,要不看看...",
|
|
}
|
|
|
|
# Keywords that indicate different trigger types
|
|
TRIGGER_KEYWORDS = {
|
|
"high_importance_topic": ["关于", "提到", "说过", "记得"],
|
|
"repeat_question": ["之前", "上次", "以前", "好像问过"],
|
|
"forgotten_context": ["忘了", "不记得", "记不清", "之前聊过"],
|
|
"pending_goal": ["目标", "计划", "想做", "打算", "想学"],
|
|
}
|
|
|
|
|
|
class ProactiveInformer:
|
|
"""Proactively inform users of relevant memories based on conversation context."""
|
|
|
|
def should_inform(self, trigger_type: str) -> bool:
|
|
"""Check if should inform based on probability."""
|
|
if trigger_type not in TRIGGERS:
|
|
return False
|
|
probability = TRIGGERS[trigger_type]["probability"]
|
|
return random.random() < probability
|
|
|
|
def detect_trigger(self, message: str) -> str | None:
|
|
"""Detect which trigger type (if any) this message corresponds to."""
|
|
msg_lower = message.lower()
|
|
for trigger_type, keywords in TRIGGER_KEYWORDS.items():
|
|
for keyword in keywords:
|
|
if keyword in msg_lower:
|
|
return trigger_type
|
|
return None
|
|
|
|
def get_inform_message(self, trigger_type: str, context: dict) -> str:
|
|
"""Generate natural proactive message based on trigger type."""
|
|
style = INFORM_STYLE.get(context.get("style", "casual"), INFORM_STYLE["casual"])
|
|
|
|
if trigger_type == "high_importance_topic":
|
|
memory_content = context.get("memory_content", "")
|
|
return f"{style}「{memory_content[:30]}」这个话题你很关心,要深入聊聊吗?"
|
|
|
|
elif trigger_type == "repeat_question":
|
|
return "你之前问过类似的问题,我之前的回答可能还有参考价值。"
|
|
|
|
elif trigger_type == "forgotten_context":
|
|
memory_content = context.get("memory_content", "")
|
|
return f"这个话题你一个月前聊过:「{memory_content[:30]}」。要恢复一下吗?"
|
|
|
|
elif trigger_type == "pending_goal":
|
|
goal_content = context.get("goal_content", "")
|
|
return f"你之前说要「{goal_content[:30]}」,有进展了吗?"
|
|
|
|
return ""
|
|
|
|
async def check_and_inform(
|
|
self,
|
|
db: "AsyncSession",
|
|
user_id: str,
|
|
current_message: str,
|
|
) -> str | None:
|
|
"""Main entry point.
|
|
|
|
Returns proactive message if should inform, else None.
|
|
"""
|
|
# 1. Detect trigger type
|
|
trigger_type = self.detect_trigger(current_message)
|
|
if not trigger_type:
|
|
return None
|
|
|
|
# 2. Check probability
|
|
if not self.should_inform(trigger_type):
|
|
return None
|
|
|
|
# 3. Fetch relevant context
|
|
context = await self._fetch_trigger_context(db, user_id, trigger_type)
|
|
if not context:
|
|
return None
|
|
|
|
# 4. Generate message
|
|
return self.get_inform_message(trigger_type, context)
|
|
|
|
async def _fetch_trigger_context(
|
|
self,
|
|
db: "AsyncSession",
|
|
user_id: str,
|
|
trigger_type: str,
|
|
) -> dict | None:
|
|
"""Fetch relevant context for the trigger type."""
|
|
from sqlalchemy import select
|
|
from app.models.memory import UserMemory
|
|
|
|
# Get most recent high-importance memory
|
|
result = await db.execute(
|
|
select(UserMemory)
|
|
.where(
|
|
UserMemory.user_id == user_id,
|
|
UserMemory.importance_level == "high",
|
|
UserMemory.is_archived == False,
|
|
)
|
|
.order_by(UserMemory.last_accessed_at.desc())
|
|
.limit(1)
|
|
)
|
|
memory = result.scalar_one_or_none()
|
|
|
|
if not memory:
|
|
return None
|
|
|
|
return {
|
|
"memory_content": memory.content,
|
|
"style": "casual",
|
|
"memory_id": str(memory.id),
|
|
}
|