refactor(backend): update and add service layers
- services/ontology.py: update ontology service - services/orchestrator.py: update orchestrator service - services/user_agent.py: update user agent service - services/settings.py: update settings service - services/expense_claims.py: update expense claims service - services/agent_conversations.py: add new agent conversations service
This commit is contained in:
398
server/src/app/services/agent_conversations.py
Normal file
398
server/src/app/services/agent_conversations.py
Normal file
@@ -0,0 +1,398 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
STATEFUL_CONTEXT_KEYS = (
|
||||
"entry_source",
|
||||
"request_context",
|
||||
"attachment_names",
|
||||
"attachment_count",
|
||||
"ocr_summary",
|
||||
"ocr_documents",
|
||||
)
|
||||
DEFAULT_CONVERSATION_RETENTION_DAYS = 3
|
||||
|
||||
|
||||
class AgentConversationService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_or_create_conversation(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
user_id: str | None,
|
||||
source: str,
|
||||
context_json: dict[str, Any],
|
||||
) -> AgentConversation:
|
||||
self.prune_expired_conversations()
|
||||
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
normalized_user_id = str(user_id or "").strip() or None
|
||||
conversation = self.get_conversation(normalized_id) if normalized_id else None
|
||||
if conversation is not None and conversation.user_id != normalized_user_id:
|
||||
normalized_id = ""
|
||||
conversation = None
|
||||
|
||||
if conversation is None:
|
||||
conversation = AgentConversation(
|
||||
conversation_id=normalized_id or f"conv_{uuid.uuid4().hex[:16]}",
|
||||
user_id=normalized_user_id,
|
||||
source=source,
|
||||
entry_source=str(context_json.get("entry_source") or "").strip() or None,
|
||||
title=self._resolve_title(context_json),
|
||||
state_json=self._extract_state_json(context_json),
|
||||
)
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
if not conversation.user_id and normalized_user_id:
|
||||
conversation.user_id = normalized_user_id
|
||||
if not conversation.entry_source:
|
||||
conversation.entry_source = str(context_json.get("entry_source") or "").strip() or None
|
||||
if not conversation.title:
|
||||
conversation.title = self._resolve_title(context_json)
|
||||
conversation.state_json = self._merge_state_json(
|
||||
conversation.state_json,
|
||||
self._extract_state_json(context_json),
|
||||
)
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
def prune_expired_conversations(
|
||||
self,
|
||||
*,
|
||||
retention_days: int | None = None,
|
||||
) -> int:
|
||||
resolved_retention_days = retention_days or self._resolve_retention_days()
|
||||
cutoff = datetime.now(UTC) - timedelta(days=max(1, resolved_retention_days))
|
||||
stmt = select(AgentConversation).where(AgentConversation.updated_at < cutoff)
|
||||
expired_conversations = list(self.db.scalars(stmt).all())
|
||||
if not expired_conversations:
|
||||
return 0
|
||||
|
||||
for conversation in expired_conversations:
|
||||
self.db.delete(conversation)
|
||||
|
||||
self.db.commit()
|
||||
return len(expired_conversations)
|
||||
|
||||
def _resolve_retention_days(self) -> int:
|
||||
try:
|
||||
settings_row, _ = SettingsService(self.db).ensure_settings_ready()
|
||||
configured_days = int(
|
||||
getattr(
|
||||
settings_row,
|
||||
"conversation_retention_days",
|
||||
DEFAULT_CONVERSATION_RETENTION_DAYS,
|
||||
)
|
||||
or DEFAULT_CONVERSATION_RETENTION_DAYS
|
||||
)
|
||||
return max(1, min(configured_days, 10))
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
return DEFAULT_CONVERSATION_RETENTION_DAYS
|
||||
|
||||
def get_conversation(self, conversation_id: str) -> AgentConversation | None:
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
if not normalized_id:
|
||||
return None
|
||||
stmt = select(AgentConversation).where(AgentConversation.conversation_id == normalized_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_latest_conversation_for_user(
|
||||
self,
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
) -> AgentConversation | None:
|
||||
self.prune_expired_conversations()
|
||||
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if not normalized_user_id:
|
||||
return None
|
||||
|
||||
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
|
||||
return self.db.scalar(stmt.limit(1))
|
||||
|
||||
def hydrate_context_json(
|
||||
self,
|
||||
*,
|
||||
conversation: AgentConversation,
|
||||
context_json: dict[str, Any],
|
||||
history_limit: int = 8,
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(context_json or {})
|
||||
state_json = dict(conversation.state_json or {})
|
||||
|
||||
merged["conversation_id"] = conversation.conversation_id
|
||||
merged["conversation_history"] = self.list_message_history(
|
||||
conversation.conversation_id,
|
||||
limit=history_limit,
|
||||
)
|
||||
if conversation.last_scenario:
|
||||
merged.setdefault("conversation_scenario", conversation.last_scenario)
|
||||
if conversation.last_intent:
|
||||
merged.setdefault("conversation_intent", conversation.last_intent)
|
||||
if conversation.draft_claim_id and not str(merged.get("draft_claim_id") or "").strip():
|
||||
merged["draft_claim_id"] = conversation.draft_claim_id
|
||||
merged["conversation_state"] = state_json
|
||||
|
||||
for key in STATEFUL_CONTEXT_KEYS:
|
||||
if self._is_empty_value(merged.get(key)) and not self._is_empty_value(state_json.get(key)):
|
||||
merged[key] = state_json.get(key)
|
||||
|
||||
return merged
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
run_id: str | None = None,
|
||||
message_json: dict[str, Any] | None = None,
|
||||
) -> AgentConversationMessage | None:
|
||||
normalized_content = str(content or "").strip()
|
||||
if not normalized_content:
|
||||
return None
|
||||
|
||||
conversation = self.get_conversation(conversation_id)
|
||||
if conversation is None:
|
||||
return None
|
||||
|
||||
message = AgentConversationMessage(
|
||||
conversation_id=conversation_id,
|
||||
run_id=run_id,
|
||||
role=str(role or "user").strip() or "user",
|
||||
content=normalized_content,
|
||||
message_json=message_json or {},
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
conversation.message_count = int(conversation.message_count or 0) + 1
|
||||
if role == "user" and not conversation.title:
|
||||
conversation.title = normalized_content[:48]
|
||||
conversation.updated_at = datetime.now(UTC)
|
||||
self.db.add(message)
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def list_message_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
*,
|
||||
limit: int = 8,
|
||||
) -> list[dict[str, Any]]:
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
if not normalized_id or limit <= 0:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == normalized_id)
|
||||
.order_by(AgentConversationMessage.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
messages.reverse()
|
||||
return [
|
||||
{
|
||||
"role": item.role,
|
||||
"content": item.content,
|
||||
"run_id": item.run_id,
|
||||
"created_at": item.created_at.isoformat() if item.created_at else None,
|
||||
}
|
||||
for item in messages
|
||||
]
|
||||
|
||||
def list_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
*,
|
||||
limit: int | None = None,
|
||||
) -> list[AgentConversationMessage]:
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
if not normalized_id:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(AgentConversationMessage)
|
||||
.where(AgentConversationMessage.conversation_id == normalized_id)
|
||||
.order_by(AgentConversationMessage.created_at.asc(), AgentConversationMessage.id.asc())
|
||||
)
|
||||
if limit and limit > 0:
|
||||
stmt = stmt.limit(limit)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def update_state(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str,
|
||||
run_id: str | None,
|
||||
scenario: str | None,
|
||||
intent: str | None,
|
||||
context_json: dict[str, Any],
|
||||
draft_payload: dict[str, Any] | None = None,
|
||||
) -> AgentConversation | None:
|
||||
conversation = self.get_conversation(conversation_id)
|
||||
if conversation is None:
|
||||
return None
|
||||
|
||||
conversation.last_run_id = str(run_id or "").strip() or conversation.last_run_id
|
||||
conversation.last_scenario = str(scenario or "").strip() or conversation.last_scenario
|
||||
conversation.last_intent = str(intent or "").strip() or conversation.last_intent
|
||||
if draft_payload and str(draft_payload.get("claim_id") or "").strip():
|
||||
conversation.draft_claim_id = str(draft_payload["claim_id"]).strip()
|
||||
|
||||
next_state = self._merge_state_json(
|
||||
conversation.state_json,
|
||||
self._extract_state_json(context_json),
|
||||
)
|
||||
if draft_payload:
|
||||
if str(draft_payload.get("claim_id") or "").strip():
|
||||
next_state["draft_claim_id"] = str(draft_payload["claim_id"]).strip()
|
||||
if str(draft_payload.get("claim_no") or "").strip():
|
||||
next_state["draft_claim_no"] = str(draft_payload["claim_no"]).strip()
|
||||
if str(draft_payload.get("status") or "").strip():
|
||||
next_state["draft_status"] = str(draft_payload["status"]).strip()
|
||||
conversation.state_json = next_state
|
||||
conversation.updated_at = datetime.now(UTC)
|
||||
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
def delete_user_conversations(
|
||||
self,
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
) -> int:
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if not normalized_user_id:
|
||||
return 0
|
||||
|
||||
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
if not conversations:
|
||||
return 0
|
||||
|
||||
for conversation in conversations:
|
||||
self.db.delete(conversation)
|
||||
|
||||
self.db.commit()
|
||||
return len(conversations)
|
||||
|
||||
def serialize_conversation(
|
||||
self,
|
||||
conversation: AgentConversation,
|
||||
*,
|
||||
include_messages: bool = True,
|
||||
message_limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = {
|
||||
"conversation_id": conversation.conversation_id,
|
||||
"user_id": conversation.user_id,
|
||||
"source": conversation.source,
|
||||
"entry_source": conversation.entry_source,
|
||||
"title": conversation.title,
|
||||
"last_run_id": conversation.last_run_id,
|
||||
"last_scenario": conversation.last_scenario,
|
||||
"last_intent": conversation.last_intent,
|
||||
"draft_claim_id": conversation.draft_claim_id,
|
||||
"state_json": dict(conversation.state_json or {}),
|
||||
"message_count": int(conversation.message_count or 0),
|
||||
"updated_at": conversation.updated_at,
|
||||
"messages": [],
|
||||
}
|
||||
if include_messages:
|
||||
payload["messages"] = [
|
||||
self.serialize_message(item)
|
||||
for item in self.list_messages(conversation.conversation_id, limit=message_limit)
|
||||
]
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def serialize_message(message: AgentConversationMessage) -> dict[str, Any]:
|
||||
return {
|
||||
"id": message.id,
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"run_id": message.run_id,
|
||||
"message_json": dict(message.message_json or {}),
|
||||
"created_at": message.created_at,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_empty_value(value: Any) -> bool:
|
||||
if value is None:
|
||||
return True
|
||||
if isinstance(value, str):
|
||||
return not value.strip()
|
||||
if isinstance(value, (list, tuple, set, dict)):
|
||||
return len(value) == 0
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _resolve_title(context_json: dict[str, Any]) -> str | None:
|
||||
request_context = context_json.get("request_context")
|
||||
if isinstance(request_context, dict):
|
||||
for key in ("reason", "title", "id"):
|
||||
value = str(request_context.get(key) or "").strip()
|
||||
if value:
|
||||
return value[:200]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_state_json(context_json: dict[str, Any]) -> dict[str, Any]:
|
||||
state_json: dict[str, Any] = {}
|
||||
for key in STATEFUL_CONTEXT_KEYS:
|
||||
value = context_json.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and not value.strip():
|
||||
continue
|
||||
if isinstance(value, (list, dict)) and not value:
|
||||
continue
|
||||
state_json[key] = value
|
||||
|
||||
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
|
||||
if draft_claim_id:
|
||||
state_json["draft_claim_id"] = draft_claim_id
|
||||
return state_json
|
||||
|
||||
@staticmethod
|
||||
def _merge_state_json(
|
||||
current_state: dict[str, Any] | None,
|
||||
incoming_state: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(current_state or {})
|
||||
for key, value in (incoming_state or {}).items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and not value.strip():
|
||||
continue
|
||||
if isinstance(value, (list, dict)) and not value:
|
||||
continue
|
||||
merged[key] = value
|
||||
return merged
|
||||
Reference in New Issue
Block a user