refactor(backend): update orchestrator endpoint and services

- endpoints/orchestrator.py: update orchestrator API endpoint
- services/agent_conversations.py: update agent conversations service
- services/orchestrator.py: update orchestrator service
- services/user_agent.py: update user agent service
This commit is contained in:
caoxiaozhu
2026-05-13 13:06:52 +00:00
parent 0f7bd43ce3
commit 70cff69b7f
4 changed files with 359 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ from app.models.agent_conversation import AgentConversation, AgentConversationMe
from app.services.settings import SettingsService
STATEFUL_CONTEXT_KEYS = (
"session_type",
"entry_source",
"request_context",
"attachment_names",
@@ -37,10 +38,16 @@ class AgentConversationService:
normalized_id = str(conversation_id or "").strip()
normalized_user_id = str(user_id or "").strip() or None
incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense"
conversation = self.get_conversation(normalized_id) if normalized_id else None
if conversation is not None and conversation.user_id != normalized_user_id:
normalized_id = ""
conversation = None
if conversation is not None:
existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if existing_session_type != incoming_session_type:
normalized_id = ""
conversation = None
if conversation is None:
conversation = AgentConversation(
@@ -117,6 +124,7 @@ class AgentConversationService:
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
) -> AgentConversation | None:
self.prune_expired_conversations()
@@ -128,7 +136,16 @@ class AgentConversationService:
if source:
stmt = stmt.where(AgentConversation.source == source)
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
return self.db.scalar(stmt.limit(1))
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if not normalized_session_type:
return conversations[0] if conversations else None
for conversation in conversations:
current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if current_session_type == normalized_session_type:
return conversation
return None
def hydrate_context_json(
self,
@@ -285,6 +302,7 @@ class AgentConversationService:
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
) -> int:
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
@@ -294,6 +312,14 @@ class AgentConversationService:
if source:
stmt = stmt.where(AgentConversation.source == source)
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if normalized_session_type:
conversations = [
conversation
for conversation in conversations
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
== normalized_session_type
]
if not conversations:
return 0
@@ -303,6 +329,33 @@ class AgentConversationService:
self.db.commit()
return len(conversations)
def delete_conversation(
self,
*,
conversation_id: str | None,
user_id: str | None = None,
source: str | None = "user_message",
) -> int:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return 0
conversation = self.get_conversation(normalized_id)
if conversation is None:
return 0
normalized_user_id = str(user_id or "").strip()
if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id:
return 0
normalized_source = str(source or "").strip()
if normalized_source and str(conversation.source or "").strip() != normalized_source:
return 0
self.db.delete(conversation)
self.db.commit()
return 1
def serialize_conversation(
self,
conversation: AgentConversation,