from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.brain import BrainCandidate, BrainEvent, BrainMemory, BrainTag
from app.services.graph_service import GraphService

# Escape character used to neutralise LIKE wildcards in user-supplied search
# tokens ("/" rather than a backslash, the same choice SQLAlchemy makes for
# ``autoescape=True``, to avoid backslash-handling differences between DBs).
_LIKE_ESCAPE = "/"


def _escape_like(token: str) -> str:
    """Escape ``%``, ``_`` and the escape character itself for a LIKE pattern."""
    # The escape character must be doubled first, otherwise the escapes added
    # for % and _ below would themselves be escaped.
    return (
        token.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )


class BrainService:
    """Database operations for a user's brain events, memories and tags.

    Wraps an ``AsyncSession``. Provides helpers to record events, recall and
    list memories/tags/events, and run the learning pass that turns pending
    events into a candidate plus a promoted memory.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _active_memory_filter(user_id: str) -> tuple:
        """WHERE criteria selecting the user's memories with status "active"."""
        return (BrainMemory.user_id == user_id, BrainMemory.status == "active")

    @staticmethod
    def _memory_ranking() -> tuple:
        """ORDER BY clauses for memories: most important first, then newest."""
        return (BrainMemory.importance.desc(), BrainMemory.created_at.desc())

    async def _count_tags(self, user_id: str, priority: str) -> int:
        """Count the user's tags with the given ``priority`` (0 if none)."""
        result = await self.db.execute(
            select(func.count()).select_from(BrainTag).where(
                BrainTag.user_id == user_id,
                BrainTag.priority == priority,
            )
        )
        return result.scalar() or 0

    async def _tags_with_priority(self, user_id: str, priority: str) -> list[BrainTag]:
        """Return the user's tags with ``priority``, highest score then newest first."""
        result = await self.db.execute(
            select(BrainTag)
            .where(BrainTag.user_id == user_id, BrainTag.priority == priority)
            .order_by(BrainTag.score.desc(), BrainTag.created_at.desc())
        )
        return list(result.scalars().all())

    # --------------------------------------------------------------- public API

    async def create_event(
        self,
        user_id: str,
        *,
        source_type: str,
        source_id: str,
        event_type: str,
        title: str | None = None,
        content_summary: str | None = None,
        raw_excerpt: str | None = None,
        metadata_: dict | None = None,
        importance_signal: float = 0.0,
    ) -> BrainEvent:
        """Add a new event with status "pending" and flush it (no commit).

        The flush sends the INSERT so database-assigned values (e.g. the id)
        are available on the returned object; committing is left to the caller.
        """
        event = BrainEvent(
            user_id=user_id,
            source_type=source_type,
            source_id=source_id,
            event_type=event_type,
            title=title,
            content_summary=content_summary,
            raw_excerpt=raw_excerpt,
            metadata_=metadata_,
            importance_signal=importance_signal,
            status="pending",
        )
        self.db.add(event)
        await self.db.flush()
        return event

    async def recall_memories(
        self, user_id: str, current_query: str, top_k: int = 3
    ) -> list[BrainMemory]:
        """Return up to ``top_k`` active memories relevant to ``current_query``.

        The query is split on whitespace into lower-cased, de-duplicated tokens.
        A memory matches when any token occurs case-insensitively in its title
        or content. LIKE wildcards in tokens are matched literally. Results are
        ordered by importance, then newest first.

        If the query has no tokens, or none of its tokens match, the user's top
        ``top_k`` active memories are returned instead.
        """
        # str.split() already strips whitespace and drops empty parts;
        # dict.fromkeys de-duplicates while keeping first-seen order.
        query_tokens = list(dict.fromkeys(token.lower() for token in current_query.split()))
        active = select(BrainMemory).where(*self._active_memory_filter(user_id))
        ranking = self._memory_ranking()

        if query_tokens:
            patterns = [f"%{_escape_like(token)}%" for token in query_tokens]
            matching = active.where(
                or_(
                    *(
                        column.ilike(pattern, escape=_LIKE_ESCAPE)
                        for pattern in patterns
                        for column in (BrainMemory.title, BrainMemory.content)
                    )
                )
            )
            result = await self.db.execute(matching.order_by(*ranking).limit(top_k))
            memories = list(result.scalars().all())
            if memories:
                return memories
            # No token matched: fall through to the unfiltered top memories.
            # (Previously the condition was inverted, so this fallback only ran
            # when there were no tokens and merely repeated the same query.)

        fallback_result = await self.db.execute(active.order_by(*ranking).limit(top_k))
        return list(fallback_result.scalars().all())

    async def get_overview(self, user_id: str) -> dict:
        """Summarise the user's brain: counts plus the top 5 memory titles.

        Returns a dict with ``active_memory_count``, ``important_tag_count``,
        ``secondary_tag_count`` and ``recent_memory_titles``. Despite its name,
        ``recent_memory_titles`` is ordered by importance first and recency
        second, matching the other memory listings.
        """
        active_memory_count = (
            await self.db.execute(
                select(func.count())
                .select_from(BrainMemory)
                .where(*self._active_memory_filter(user_id))
            )
        ).scalar() or 0
        important_tag_count = await self._count_tags(user_id, "important")
        secondary_tag_count = await self._count_tags(user_id, "secondary")

        recent_memory_result = await self.db.execute(
            select(BrainMemory.title)
            .where(*self._active_memory_filter(user_id))
            .order_by(*self._memory_ranking())
            .limit(5)
        )
        recent_memory_titles = list(recent_memory_result.scalars().all())

        return {
            "active_memory_count": active_memory_count,
            "important_tag_count": important_tag_count,
            "secondary_tag_count": secondary_tag_count,
            "recent_memory_titles": recent_memory_titles,
        }

    async def list_memories(self, user_id: str) -> list[BrainMemory]:
        """Return all of the user's active memories, most important/newest first."""
        result = await self.db.execute(
            select(BrainMemory)
            .where(*self._active_memory_filter(user_id))
            .order_by(*self._memory_ranking())
        )
        return list(result.scalars().all())

    async def list_tags(self, user_id: str) -> dict:
        """Return the user's tags split by priority.

        Returns ``{"important": [...], "secondary": [...]}``. Each list is
        ordered by score descending, then newest first. Tags with any other
        priority are not included.
        """
        return {
            "important": await self._tags_with_priority(user_id, "important"),
            "secondary": await self._tags_with_priority(user_id, "secondary"),
        }

    async def list_events(self, user_id: str) -> list[BrainEvent]:
        """Return all of the user's events (any status), newest first."""
        result = await self.db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == user_id)
            .order_by(BrainEvent.created_at.desc())
        )
        return list(result.scalars().all())

    async def run_learning(self, user_id: str) -> dict:
        """Consolidate the user's pending events into one promoted memory.

        When there are pending events, this:

        - creates a "daily_learning" candidate (status "promoted");
        - creates an active memory that points back to that candidate;
        - marks every pending event "processed", stamped with the memory's
          ``created_at``.

        The session is then committed (even when nothing was pending), and the
        user's graph is rebuilt.

        Returns counts of events considered, candidates created and memories
        promoted.
        """
        pending_events_result = await self.db.execute(
            select(BrainEvent)
            .where(BrainEvent.user_id == user_id, BrainEvent.status == "pending")
            .order_by(BrainEvent.created_at.asc())
        )
        pending_events = list(pending_events_result.scalars().all())
        pending_count = len(pending_events)
        candidates_created = 0
        memories_promoted = 0

        if pending_events:
            summary = f"Processed {pending_count} pending brain events."
            candidate = BrainCandidate(
                user_id=user_id,
                candidate_type="daily_learning",
                title="Daily learning synthesis",
                summary=summary,
                importance_score=float(pending_count),
                confidence_score=1.0,
                status="promoted",
                source_event_ids=[event.id for event in pending_events],
            )
            self.db.add(candidate)
            # Flush so candidate.id is assigned before the memory references it.
            await self.db.flush()
            candidates_created = 1

            memory = BrainMemory(
                user_id=user_id,
                memory_type="daily_learning",
                title="Daily learning synthesis",
                content=summary,
                importance=max(pending_count, 1),
                confidence=1.0,
                status="active",
                origin_candidate_id=candidate.id,
                origin_source_types=sorted({event.source_type for event in pending_events}),
            )
            self.db.add(memory)
            # created_at is not passed above; presumably it is filled by a
            # column/server default (model not visible here), which is only
            # applied at flush. Flush + refresh so memory.created_at is loaded
            # instead of being None when copied onto the events below.
            await self.db.flush()
            await self.db.refresh(memory)
            memories_promoted = 1

            for event in pending_events:
                event.status = "processed"
                event.processed_at = memory.created_at

        await self.db.commit()
        await GraphService(self.db).build_graph(user_id)

        return {
            "events_considered": pending_count,
            "candidates_created": candidates_created,
            "memories_promoted": memories_promoted,
        }