refactor(backend): update orchestrator endpoint and services
- endpoints/orchestrator.py: update orchestrator API endpoint - services/agent_conversations.py: update agent conversations service - services/orchestrator.py: update orchestrator service - services/user_agent.py: update user agent service
This commit is contained in:
@@ -11,6 +11,7 @@ from app.models.agent_conversation import AgentConversation, AgentConversationMe
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
STATEFUL_CONTEXT_KEYS = (
|
||||
"session_type",
|
||||
"entry_source",
|
||||
"request_context",
|
||||
"attachment_names",
|
||||
@@ -37,10 +38,16 @@ class AgentConversationService:
|
||||
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
normalized_user_id = str(user_id or "").strip() or None
|
||||
incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense"
|
||||
conversation = self.get_conversation(normalized_id) if normalized_id else None
|
||||
if conversation is not None and conversation.user_id != normalized_user_id:
|
||||
normalized_id = ""
|
||||
conversation = None
|
||||
if conversation is not None:
|
||||
existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
|
||||
if existing_session_type != incoming_session_type:
|
||||
normalized_id = ""
|
||||
conversation = None
|
||||
|
||||
if conversation is None:
|
||||
conversation = AgentConversation(
|
||||
@@ -117,6 +124,7 @@ class AgentConversationService:
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
session_type: str | None = None,
|
||||
) -> AgentConversation | None:
|
||||
self.prune_expired_conversations()
|
||||
|
||||
@@ -128,7 +136,16 @@ class AgentConversationService:
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
|
||||
return self.db.scalar(stmt.limit(1))
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
normalized_session_type = str(session_type or "").strip()
|
||||
if not normalized_session_type:
|
||||
return conversations[0] if conversations else None
|
||||
|
||||
for conversation in conversations:
|
||||
current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
|
||||
if current_session_type == normalized_session_type:
|
||||
return conversation
|
||||
return None
|
||||
|
||||
def hydrate_context_json(
|
||||
self,
|
||||
@@ -285,6 +302,7 @@ class AgentConversationService:
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
session_type: str | None = None,
|
||||
) -> int:
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if not normalized_user_id:
|
||||
@@ -294,6 +312,14 @@ class AgentConversationService:
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
normalized_session_type = str(session_type or "").strip()
|
||||
if normalized_session_type:
|
||||
conversations = [
|
||||
conversation
|
||||
for conversation in conversations
|
||||
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
|
||||
== normalized_session_type
|
||||
]
|
||||
if not conversations:
|
||||
return 0
|
||||
|
||||
@@ -303,6 +329,33 @@ class AgentConversationService:
|
||||
self.db.commit()
|
||||
return len(conversations)
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
user_id: str | None = None,
|
||||
source: str | None = "user_message",
|
||||
) -> int:
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
if not normalized_id:
|
||||
return 0
|
||||
|
||||
conversation = self.get_conversation(normalized_id)
|
||||
if conversation is None:
|
||||
return 0
|
||||
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id:
|
||||
return 0
|
||||
|
||||
normalized_source = str(source or "").strip()
|
||||
if normalized_source and str(conversation.source or "").strip() != normalized_source:
|
||||
return 0
|
||||
|
||||
self.db.delete(conversation)
|
||||
self.db.commit()
|
||||
return 1
|
||||
|
||||
def serialize_conversation(
|
||||
self,
|
||||
conversation: AgentConversation,
|
||||
|
||||
Reference in New Issue
Block a user