from __future__ import annotations

import logging
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.services.settings import SettingsService

logger = logging.getLogger(__name__)

# Context keys whose values are persisted into ``AgentConversation.state_json`` and
# copied back into later requests of the same conversation when they are missing.
STATEFUL_CONTEXT_KEYS = (
    "session_type",
    "entry_source",
    "request_context",
    "attachment_names",
    "attachment_count",
    "ocr_summary",
    "ocr_documents",
    "review_form_values",
    "business_time_context",
)
# Keys tied to an in-progress claim review. They are stripped from the request
# context (and not restored from state) unless the request continues that flow.
REVIEW_FLOW_CONTEXT_KEYS = {
    "draft_claim_id",
    "draft_claim_no",
    "draft_status",
    "request_context",
    "attachment_names",
    "attachment_count",
    "ocr_summary",
    "ocr_documents",
    "review_form_values",
    "business_time_context",
}
# Chinese message fragments meaning the user is continuing the current claim,
# e.g. "补充" (supplement), "继续上传" (keep uploading), "这张" (this one),
# "修改" (modify), "下一步" (next step), "保存草稿" (save draft).
REVIEW_FLOW_CONTINUATION_KEYWORDS = (
    "补充",
    "继续",
    "继续上传",
    "当前",
    "这张",
    "这个",
    "该单据",
    "现有",
    "已有",
    "关联",
    "合并",
    "修改",
    "更正",
    "改成",
    "调整",
    "下一步",
    "保存草稿",
)
# Chinese message fragments meaning the user wants to start a new expense claim,
# e.g. "申请报销" (apply for reimbursement), "我要报销" (I want to claim expenses),
# "新建报销" (create a new claim).
NEW_EXPENSE_PROMPT_KEYWORDS = (
    "申请报销",
    "我要报销",
    "我想报销",
    "帮我报销",
    "发起报销",
    "提交报销",
    "生成报销",
    "创建报销",
    "新建报销",
)
DEFAULT_CONVERSATION_RETENTION_DAYS = 3
# Upper bound (in days) applied to the configured retention period.
MAX_CONVERSATION_RETENTION_DAYS = 10
# Session type assumed when a context or stored state does not carry one.
DEFAULT_SESSION_TYPE = "expense"


def _clean_text(value: Any) -> str:
    """Return ``value`` as a stripped string; any falsy value becomes ``""``."""
    return str(value or "").strip()


class AgentConversationService:
    """Persist agent conversations, their messages and carried-over context state.

    Every write operation commits immediately on the session passed to ``__init__``.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    # ------------------------------------------------------------------ lifecycle

    def get_or_create_conversation(
        self,
        *,
        conversation_id: str | None,
        user_id: str | None,
        source: str,
        context_json: dict[str, Any],
    ) -> AgentConversation:
        """Return a reusable conversation for the caller, creating one if needed.

        Expired conversations are pruned first. An existing conversation is reused
        only when it belongs to the same user and has the same session type as
        ``context_json``. If the requested id belongs to another user or session
        type, a new conversation with a generated id is created. If the requested
        id does not exist, it is used as the new conversation's id.

        When a conversation is reused, its missing entry source and title are
        backfilled, its draft claim id is updated when the context names one, and
        its state is merged with the stateful keys of ``context_json``.
        """
        self.prune_expired_conversations()
        context = context_json or {}
        normalized_id = _clean_text(conversation_id)
        normalized_user_id = _clean_text(user_id) or None
        incoming_session_type = _clean_text(context.get("session_type")) or DEFAULT_SESSION_TYPE
        incoming_draft_claim_id = self._resolve_draft_claim_id(context)

        conversation = self.get_conversation(normalized_id) if normalized_id else None
        if conversation is not None and (
            (_clean_text(conversation.user_id) or None) != normalized_user_id
            or self._session_type_of(conversation) != incoming_session_type
        ):
            # Never hand another user's conversation (or one of another session type)
            # back to the caller; start a fresh one with a generated id instead.
            normalized_id = ""
            conversation = None

        if conversation is None:
            conversation = AgentConversation(
                conversation_id=normalized_id or f"conv_{uuid.uuid4().hex[:16]}",
                user_id=normalized_user_id,
                source=source,
                entry_source=_clean_text(context.get("entry_source")) or None,
                title=self._resolve_title(context),
                draft_claim_id=incoming_draft_claim_id or None,
                state_json=self._extract_state_json(context),
            )
        else:
            # Ownership already matched above, so user_id never needs backfilling.
            if not conversation.entry_source:
                conversation.entry_source = _clean_text(context.get("entry_source")) or None
            if not conversation.title:
                conversation.title = self._resolve_title(context)
            if incoming_draft_claim_id:
                conversation.draft_claim_id = incoming_draft_claim_id
            conversation.state_json = self._merge_state_json(
                conversation.state_json,
                self._extract_state_json(context),
            )

        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(conversation)
        return conversation

    def prune_expired_conversations(
        self,
        *,
        retention_days: int | None = None,
    ) -> int:
        """Delete unsaved conversations whose ``updated_at`` is older than the window.

        Args:
            retention_days: Retention window in days. ``None`` uses the configured
                value. The window is never shorter than one day.

        Returns:
            The number of deleted conversations. Conversations linked to a draft
            claim are never deleted here.
        """
        if retention_days is None:
            retention_days = self._resolve_retention_days()
        cutoff = datetime.now(UTC) - timedelta(days=max(1, retention_days))
        stmt = select(AgentConversation).where(AgentConversation.updated_at < cutoff)
        expired = [
            conversation
            for conversation in self.db.scalars(stmt).all()
            if not self._is_saved_conversation(conversation)
        ]
        return self._delete_all(expired)

    @staticmethod
    def _is_saved_conversation(conversation: AgentConversation) -> bool:
        """Return True when the conversation references a draft claim (column or state)."""
        if _clean_text(conversation.draft_claim_id):
            return True
        return bool(_clean_text((conversation.state_json or {}).get("draft_claim_id")))

    def _resolve_retention_days(self) -> int:
        """Return the configured retention period, clamped to 1..MAX days.

        Falls back to ``DEFAULT_CONVERSATION_RETENTION_DAYS`` if the settings cannot
        be loaded or the stored value is not an integer.
        """
        try:
            settings_row, _ = SettingsService(self.db).ensure_settings_ready()
            configured_days = int(
                getattr(
                    settings_row,
                    "conversation_retention_days",
                    DEFAULT_CONVERSATION_RETENTION_DAYS,
                )
                or DEFAULT_CONVERSATION_RETENTION_DAYS
            )
        except Exception:
            # Best effort: a settings problem must not block conversation handling.
            # Roll back so the session is usable again if the failure came from the DB.
            logger.warning(
                "Could not resolve conversation retention; using default",
                exc_info=True,
            )
            self.db.rollback()
            return DEFAULT_CONVERSATION_RETENTION_DAYS
        return max(1, min(configured_days, MAX_CONVERSATION_RETENTION_DAYS))

    # ------------------------------------------------------------------ lookups

    def get_conversation(self, conversation_id: str) -> AgentConversation | None:
        """Return the conversation with the given (whitespace-stripped) id, or None."""
        normalized_id = _clean_text(conversation_id)
        if not normalized_id:
            return None
        stmt = select(AgentConversation).where(AgentConversation.conversation_id == normalized_id)
        return self.db.scalar(stmt)

    def get_latest_conversation_for_user(
        self,
        *,
        user_id: str | None,
        source: str | None = "user_message",
        session_type: str | None = None,
        prefer_recoverable: bool = False,
    ) -> AgentConversation | None:
        """Return the user's most recently updated conversation, or None.

        Expired conversations are pruned first. Results are optionally filtered by
        ``source``. When ``session_type`` is given, only conversations of that type
        are considered. In that case ``prefer_recoverable`` returns the newest one
        that is recoverable (see ``_is_recoverable_conversation``), falling back to
        the newest match of that type.
        """
        self.prune_expired_conversations()
        normalized_user_id = _clean_text(user_id)
        if not normalized_user_id:
            return None

        stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
        if source:
            stmt = stmt.where(AgentConversation.source == source)
        stmt = stmt.order_by(
            AgentConversation.updated_at.desc(),
            AgentConversation.created_at.desc(),
        )
        conversations = list(self.db.scalars(stmt).all())

        normalized_session_type = _clean_text(session_type)
        if not normalized_session_type:
            return conversations[0] if conversations else None

        fallback: AgentConversation | None = None
        for conversation in conversations:
            if self._session_type_of(conversation) != normalized_session_type:
                continue
            if fallback is None:
                fallback = conversation
            if not prefer_recoverable or self._is_recoverable_conversation(conversation):
                return conversation
        return fallback

    @staticmethod
    def _is_recoverable_conversation(conversation: AgentConversation) -> bool:
        """Return True if the conversation has a draft claim or a previewable OCR document.

        A document is previewable when it has a ``preview_url`` or
        ``preview_data_url`` and its ``preview_kind`` is ``"image"`` or ``"pdf"``.
        """
        if _clean_text(conversation.draft_claim_id):
            return True
        documents = (conversation.state_json or {}).get("ocr_documents")
        if not isinstance(documents, list):
            return False
        for item in documents:
            if not isinstance(item, dict):
                continue
            has_preview = _clean_text(item.get("preview_url")) or _clean_text(
                item.get("preview_data_url")
            )
            if has_preview and _clean_text(item.get("preview_kind")) in {"image", "pdf"}:
                return True
        return False

    # ------------------------------------------------------------------ context

    def hydrate_context_json(
        self,
        *,
        conversation: AgentConversation,
        context_json: dict[str, Any],
        message: str | None = None,
        history_limit: int = 8,
    ) -> dict[str, Any]:
        """Return a copy of ``context_json`` enriched with the conversation's state.

        Review-flow keys are kept, and restored from state, only when the request
        continues the review flow (see ``_should_hydrate_review_flow_context``).
        Otherwise they are removed. The result always carries ``conversation_id``,
        the last ``history_limit`` messages as ``conversation_history``, and the
        raw stored state as ``conversation_state``.
        """
        merged = dict(context_json or {})
        state_json = dict(conversation.state_json or {})
        should_hydrate_review_flow = self._should_hydrate_review_flow_context(
            context_json=merged,
            message=message,
        )
        if not should_hydrate_review_flow:
            for key in REVIEW_FLOW_CONTEXT_KEYS:
                merged.pop(key, None)

        merged["conversation_id"] = conversation.conversation_id
        merged["conversation_history"] = self.list_message_history(
            conversation.conversation_id,
            limit=history_limit,
        )
        if conversation.last_scenario:
            merged.setdefault("conversation_scenario", conversation.last_scenario)
        if conversation.last_intent:
            merged.setdefault("conversation_intent", conversation.last_intent)
        if (
            should_hydrate_review_flow
            and conversation.draft_claim_id
            and not _clean_text(merged.get("draft_claim_id"))
        ):
            merged["draft_claim_id"] = conversation.draft_claim_id
        merged["conversation_state"] = state_json

        for key in STATEFUL_CONTEXT_KEYS:
            if key in REVIEW_FLOW_CONTEXT_KEYS and not should_hydrate_review_flow:
                continue
            # Only fill gaps: values supplied by the current request always win.
            if self._is_empty_value(merged.get(key)) and not self._is_empty_value(state_json.get(key)):
                merged[key] = state_json.get(key)
        return merged

    @staticmethod
    def _should_hydrate_review_flow_context(
        *,
        context_json: dict[str, Any],
        message: str | None,
    ) -> bool:
        """Decide whether the request continues the in-progress claim review flow.

        The checks run in priority order. Explicit UI signals come first: a scene
        selection, a draft claim id (unless the message asks for a new claim), a
        review action, the detail entry point, or attachments/OCR data. The message
        keywords are consulted last.
        """
        # Only ASCII spaces are removed, mirroring how keywords are matched.
        compact_message = str(message or "").replace(" ", "")
        asks_for_new_expense = bool(compact_message) and any(
            keyword in compact_message for keyword in NEW_EXPENSE_PROMPT_KEYWORDS
        )

        if isinstance(context_json.get("expense_scene_selection"), dict):
            return True
        if AgentConversationService._resolve_draft_claim_id(context_json):
            return not asks_for_new_expense
        if _clean_text(context_json.get("review_action")):
            return True
        if _clean_text(context_json.get("entry_source")) == "detail":
            return True
        if not AgentConversationService._is_empty_value(context_json.get("attachment_names")):
            return True
        if not AgentConversationService._is_empty_value(context_json.get("ocr_documents")):
            return True
        if _clean_text(context_json.get("ocr_summary")):
            return True
        try:
            if int(context_json.get("attachment_count") or 0) > 0:
                return True
        except (TypeError, ValueError):
            pass

        if not compact_message or asks_for_new_expense:
            return False
        return any(keyword in compact_message for keyword in REVIEW_FLOW_CONTINUATION_KEYWORDS)

    # ------------------------------------------------------------------ messages

    def append_message(
        self,
        *,
        conversation_id: str,
        role: str,
        content: str,
        run_id: str | None = None,
        message_json: dict[str, Any] | None = None,
    ) -> AgentConversationMessage | None:
        """Append a message to a conversation and bump its message counter.

        Returns None, writing nothing, when ``content`` is blank or the conversation
        does not exist. A blank ``role`` defaults to ``"user"``. The first user
        message becomes the conversation title (truncated to 48 characters) when no
        title is set yet. ``message_json["sequence"]`` defaults to the new counter
        value and is used to order messages.
        """
        normalized_content = _clean_text(content)
        if not normalized_content:
            return None
        conversation = self.get_conversation(conversation_id)
        if conversation is None:
            return None

        normalized_role = _clean_text(role) or "user"
        next_sequence = int(conversation.message_count or 0) + 1
        normalized_message_json = dict(message_json or {})
        normalized_message_json.setdefault("sequence", next_sequence)
        now = datetime.now(UTC)

        message = AgentConversationMessage(
            # Use the stored id: the lookup above strips whitespace from the argument.
            conversation_id=conversation.conversation_id,
            run_id=run_id,
            role=normalized_role,
            content=normalized_content,
            message_json=normalized_message_json,
            created_at=now,
        )
        conversation.message_count = next_sequence
        if normalized_role == "user" and not conversation.title:
            conversation.title = normalized_content[:48]
        conversation.updated_at = now

        self.db.add(message)
        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(message)
        return message

    def list_message_history(
        self,
        conversation_id: str,
        *,
        limit: int = 8,
    ) -> list[dict[str, Any]]:
        """Return the last ``limit`` messages as plain dicts (oldest first).

        Returns ``[]`` for a blank id or a non-positive ``limit``.
        """
        normalized_id = _clean_text(conversation_id)
        if not normalized_id or limit <= 0:
            return []
        return [
            {
                "role": item.role,
                "content": item.content,
                "run_id": item.run_id,
                "created_at": item.created_at.isoformat() if item.created_at else None,
            }
            for item in self.list_messages(normalized_id)[-limit:]
        ]

    def list_messages(
        self,
        conversation_id: str,
        *,
        limit: int | None = None,
    ) -> list[AgentConversationMessage]:
        """Return a conversation's messages in ``_message_sort_key`` order.

        A positive ``limit`` keeps only the first ``limit`` messages (the oldest).
        """
        normalized_id = _clean_text(conversation_id)
        if not normalized_id:
            return []
        stmt = (
            select(AgentConversationMessage)
            .where(AgentConversationMessage.conversation_id == normalized_id)
            .order_by(AgentConversationMessage.created_at.asc())
        )
        messages = sorted(self.db.scalars(stmt).all(), key=self._message_sort_key)
        if limit and limit > 0:
            return messages[:limit]
        return messages

    # ------------------------------------------------------------------ state

    def update_state(
        self,
        *,
        conversation_id: str,
        run_id: str | None,
        scenario: str | None,
        intent: str | None,
        context_json: dict[str, Any],
        draft_payload: dict[str, Any] | None = None,
    ) -> AgentConversation | None:
        """Record the outcome of an agent run on the conversation.

        Blank ``run_id`` / ``scenario`` / ``intent`` values keep the previous ones.
        Stateful keys from ``context_json`` are merged into the state. The
        ``claim_id``, ``claim_no`` and ``status`` of ``draft_payload`` are stored as
        ``draft_claim_id``, ``draft_claim_no`` and ``draft_status``. Returns None if
        the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)
        if conversation is None:
            return None

        conversation.last_run_id = _clean_text(run_id) or conversation.last_run_id
        conversation.last_scenario = _clean_text(scenario) or conversation.last_scenario
        conversation.last_intent = _clean_text(intent) or conversation.last_intent

        draft = draft_payload or {}
        draft_claim_id = _clean_text(draft.get("claim_id"))
        if draft_claim_id:
            conversation.draft_claim_id = draft_claim_id

        next_state = self._merge_state_json(
            conversation.state_json,
            self._extract_state_json(context_json),
        )
        for payload_key, state_key in (
            ("claim_id", "draft_claim_id"),
            ("claim_no", "draft_claim_no"),
            ("status", "draft_status"),
        ):
            value = _clean_text(draft.get(payload_key))
            if value:
                next_state[state_key] = value

        # Assign a new dict so the ORM sees the JSON column as changed.
        conversation.state_json = next_state
        conversation.updated_at = datetime.now(UTC)
        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(conversation)
        return conversation

    # ------------------------------------------------------------------ deletion

    def delete_user_conversations(
        self,
        *,
        user_id: str | None,
        source: str | None = "user_message",
        session_type: str | None = None,
    ) -> int:
        """Delete a user's conversations, optionally by source and session type.

        Returns the number of conversations deleted (0 for a blank user id).
        """
        normalized_user_id = _clean_text(user_id)
        if not normalized_user_id:
            return 0
        stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
        if source:
            stmt = stmt.where(AgentConversation.source == source)
        conversations = self._filter_by_session_type(
            list(self.db.scalars(stmt).all()),
            session_type,
        )
        return self._delete_all(conversations)

    def delete_conversations_for_draft_claim(
        self,
        *,
        claim_id: str | None,
        source: str | None = "user_message",
        session_type: str | None = "expense",
    ) -> int:
        """Delete the conversations linked to a draft claim.

        Matches on the ``draft_claim_id`` column, optionally filtered by source and
        session type. Returns the number deleted (0 for a blank claim id).
        """
        normalized_claim_id = _clean_text(claim_id)
        if not normalized_claim_id:
            return 0
        stmt = select(AgentConversation).where(
            AgentConversation.draft_claim_id == normalized_claim_id
        )
        if source:
            stmt = stmt.where(AgentConversation.source == source)
        conversations = self._filter_by_session_type(
            list(self.db.scalars(stmt).all()),
            session_type,
        )
        return self._delete_all(conversations)

    def delete_conversation(
        self,
        *,
        conversation_id: str | None,
        user_id: str | None = None,
        source: str | None = "user_message",
    ) -> int:
        """Delete one conversation and return 1, or return 0 if nothing was deleted.

        When ``user_id`` / ``source`` are given (non-blank), the conversation must
        match them, otherwise it is left untouched.
        """
        conversation = self.get_conversation(conversation_id) if _clean_text(conversation_id) else None
        if conversation is None:
            return 0
        normalized_user_id = _clean_text(user_id)
        if normalized_user_id and _clean_text(conversation.user_id) != normalized_user_id:
            return 0
        normalized_source = _clean_text(source)
        if normalized_source and _clean_text(conversation.source) != normalized_source:
            return 0
        return self._delete_all([conversation])

    def _delete_all(self, conversations: list[AgentConversation]) -> int:
        """Delete ``conversations`` in a single commit and return how many were removed."""
        if not conversations:
            return 0
        for conversation in conversations:
            self.db.delete(conversation)
        self.db.commit()
        return len(conversations)

    # ------------------------------------------------------------------ serialization

    def serialize_conversation(
        self,
        conversation: AgentConversation,
        *,
        include_messages: bool = True,
        message_limit: int | None = None,
    ) -> dict[str, Any]:
        """Return a plain-dict view of the conversation, optionally with its messages.

        ``message_limit`` is forwarded to ``list_messages`` and keeps the oldest messages.
        """
        payload: dict[str, Any] = {
            "conversation_id": conversation.conversation_id,
            "user_id": conversation.user_id,
            "source": conversation.source,
            "entry_source": conversation.entry_source,
            "title": conversation.title,
            "last_run_id": conversation.last_run_id,
            "last_scenario": conversation.last_scenario,
            "last_intent": conversation.last_intent,
            "draft_claim_id": conversation.draft_claim_id,
            "state_json": dict(conversation.state_json or {}),
            "message_count": int(conversation.message_count or 0),
            "updated_at": conversation.updated_at,
            "messages": [],
        }
        if include_messages:
            payload["messages"] = [
                self.serialize_message(item)
                for item in self.list_messages(conversation.conversation_id, limit=message_limit)
            ]
        return payload

    @staticmethod
    def serialize_message(message: AgentConversationMessage) -> dict[str, Any]:
        """Return a plain-dict view of a single message."""
        return {
            "id": message.id,
            "role": message.role,
            "content": message.content,
            "run_id": message.run_id,
            "message_json": dict(message.message_json or {}),
            "created_at": message.created_at,
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _message_sort_key(message: AgentConversationMessage) -> tuple[int, datetime, str, int, str]:
        """Sort key for messages.

        Messages are ordered by explicit sequence (messages without one sort last),
        then creation time, then run id. User messages come before others within
        the same run and time, and the id is the final tie-breaker.
        """
        message_json = dict(message.message_json or {})
        sequence = AgentConversationService._coerce_message_sequence(message_json.get("sequence"))
        created_at = message.created_at or datetime.min
        if created_at.tzinfo is None:
            # Some backends return naive datetimes; treat them as UTC so they never get
            # compared against aware values (which raises TypeError).
            created_at = created_at.replace(tzinfo=UTC)
        role_priority = 0 if _clean_text(message.role) == "user" else 1
        return (
            sequence if sequence is not None else 10**9,
            created_at,
            str(message.run_id or ""),
            role_priority,
            str(message.id or ""),
        )

    @staticmethod
    def _coerce_message_sequence(value: Any) -> int | None:
        """Return ``value`` as a positive int, or None if it is missing/invalid/non-positive."""
        try:
            normalized = int(value)
        except (TypeError, ValueError):
            return None
        return normalized if normalized > 0 else None

    @staticmethod
    def _is_empty_value(value: Any) -> bool:
        """Return True for None, blank strings and empty list/tuple/set/dict."""
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        if isinstance(value, (list, tuple, set, dict)):
            return len(value) == 0
        return False

    @staticmethod
    def _session_type_of(conversation: AgentConversation) -> str:
        """Return the conversation's stored session type, defaulting to ``"expense"``."""
        return _clean_text((conversation.state_json or {}).get("session_type")) or DEFAULT_SESSION_TYPE

    @classmethod
    def _filter_by_session_type(
        cls,
        conversations: list[AgentConversation],
        session_type: str | None,
    ) -> list[AgentConversation]:
        """Keep only conversations of ``session_type``; a blank type keeps all of them."""
        normalized_session_type = _clean_text(session_type)
        if not normalized_session_type:
            return conversations
        return [
            conversation
            for conversation in conversations
            if cls._session_type_of(conversation) == normalized_session_type
        ]

    @staticmethod
    def _resolve_title(context_json: dict[str, Any]) -> str | None:
        """Derive a title from ``request_context`` (reason, title or id), max 200 chars."""
        request_context = context_json.get("request_context")
        if isinstance(request_context, dict):
            for key in ("reason", "title", "id"):
                value = _clean_text(request_context.get(key))
                if value:
                    return value[:200]
        return None

    @staticmethod
    def _extract_state_json(context_json: dict[str, Any]) -> dict[str, Any]:
        """Pick the non-empty stateful keys (plus any draft claim id) out of a context."""
        context = context_json or {}
        state_json: dict[str, Any] = {}
        for key in STATEFUL_CONTEXT_KEYS:
            value = context.get(key)
            if not AgentConversationService._is_empty_value(value):
                state_json[key] = value
        draft_claim_id = AgentConversationService._resolve_draft_claim_id(context)
        if draft_claim_id:
            state_json["draft_claim_id"] = draft_claim_id
        return state_json

    @staticmethod
    def _resolve_draft_claim_id(context_json: dict[str, Any]) -> str:
        """Return the draft claim id from the context, or ``""``.

        The top-level ``draft_claim_id`` wins. Otherwise the claim id is read from
        ``request_context`` (snake_case or camelCase keys).
        """
        context = context_json or {}
        draft_claim_id = _clean_text(context.get("draft_claim_id"))
        if draft_claim_id:
            return draft_claim_id
        request_context = context.get("request_context")
        if isinstance(request_context, dict):
            return _clean_text(
                request_context.get("claim_id")
                or request_context.get("claimId")
                or request_context.get("draft_claim_id")
                or request_context.get("draftClaimId")
            )
        return ""

    @staticmethod
    def _merge_state_json(
        current_state: dict[str, Any] | None,
        incoming_state: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Return a new dict: ``current_state`` overlaid with the non-empty incoming values."""
        merged = dict(current_state or {})
        for key, value in (incoming_state or {}).items():
            if not AgentConversationService._is_empty_value(value):
                merged[key] = value
        return merged