Files
X-Financial/server/src/app/services/agent_conversations.py
caoxiaozhu d0e946cf47 feat: 完善文档中心与报销申请交互及侧边栏重构
后端优化编排器报销查询和本体检测精度,增强报销单草稿保
存和附件回填逻辑,前端重构侧边栏组件支持折叠和图标导
航,完善文档中心状态筛选和详情提示,报销创建和审批详情
页优化会话管理和费用明细交互,新增助手应用服务和预设动
作工具函数,补充单元测试覆盖。
2026-05-25 13:35:39 +08:00

691 lines
25 KiB
Python

from __future__ import annotations
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.services.settings import SettingsService
# Context keys that are persisted into a conversation's ``state_json`` and
# back-filled into later requests when the request leaves them empty.
STATEFUL_CONTEXT_KEYS = (
"session_type",
"entry_source",
"request_context",
"attachment_names",
"attachment_count",
"ocr_summary",
"ocr_documents",
"review_form_values",
"business_time_context",
)
# Context keys that belong to the claim review flow. They are removed from the
# hydrated context (and not restored from stored state) when a turn is judged
# not to be continuing that flow.
REVIEW_FLOW_CONTEXT_KEYS = {
"draft_claim_id",
"draft_claim_no",
"draft_status",
"request_context",
"attachment_names",
"attachment_count",
"ocr_summary",
"ocr_documents",
"review_form_values",
"business_time_context",
}
# Substrings of a user message that indicate the user is continuing to work on
# the current claim (e.g. "supplement", "continue", "this one", "modify",
# "next step", "save draft").
REVIEW_FLOW_CONTINUATION_KEYWORDS = (
"补充",
"继续",
"继续上传",
"当前",
"这张",
"这个",
"该单据",
"现有",
"已有",
"关联",
"合并",
"修改",
"更正",
"改成",
"调整",
"下一步",
"保存草稿",
)
# Substrings of a user message that indicate the user wants to start a new
# reimbursement request; these take precedence over the continuation keywords.
NEW_EXPENSE_PROMPT_KEYWORDS = (
"申请报销",
"我要报销",
"我想报销",
"帮我报销",
"发起报销",
"提交报销",
"生成报销",
"创建报销",
"新建报销",
)
# Retention period (in days) used when no retention setting can be read.
DEFAULT_CONVERSATION_RETENTION_DAYS = 3
class AgentConversationService:
def __init__(self, db: Session) -> None:
self.db = db
def get_or_create_conversation(
self,
*,
conversation_id: str | None,
user_id: str | None,
source: str,
context_json: dict[str, Any],
) -> AgentConversation:
self.prune_expired_conversations()
normalized_id = str(conversation_id or "").strip()
normalized_user_id = str(user_id or "").strip() or None
incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense"
incoming_draft_claim_id = self._resolve_draft_claim_id(context_json)
conversation = self.get_conversation(normalized_id) if normalized_id else None
if conversation is not None and conversation.user_id != normalized_user_id:
normalized_id = ""
conversation = None
if conversation is not None:
existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if existing_session_type != incoming_session_type:
normalized_id = ""
conversation = None
if conversation is not None and self._has_draft_claim_scope_conflict(
conversation,
incoming_draft_claim_id,
):
normalized_id = ""
conversation = None
if conversation is None:
conversation = AgentConversation(
conversation_id=normalized_id or f"conv_{uuid.uuid4().hex[:16]}",
user_id=normalized_user_id,
source=source,
entry_source=str(context_json.get("entry_source") or "").strip() or None,
title=self._resolve_title(context_json),
draft_claim_id=incoming_draft_claim_id or None,
state_json=self._extract_state_json(context_json),
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
if not conversation.user_id and normalized_user_id:
conversation.user_id = normalized_user_id
if not conversation.entry_source:
conversation.entry_source = str(context_json.get("entry_source") or "").strip() or None
if not conversation.title:
conversation.title = self._resolve_title(context_json)
if incoming_draft_claim_id:
conversation.draft_claim_id = incoming_draft_claim_id
conversation.state_json = self._merge_state_json(
conversation.state_json,
self._extract_state_json(context_json),
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
def prune_expired_conversations(
self,
*,
retention_days: int | None = None,
) -> int:
resolved_retention_days = retention_days or self._resolve_retention_days()
cutoff = datetime.now(UTC) - timedelta(days=max(1, resolved_retention_days))
stmt = select(AgentConversation).where(AgentConversation.updated_at < cutoff)
expired_conversations = [
conversation
for conversation in self.db.scalars(stmt).all()
if not self._is_saved_conversation(conversation)
]
if not expired_conversations:
return 0
for conversation in expired_conversations:
self.db.delete(conversation)
self.db.commit()
return len(expired_conversations)
@staticmethod
def _is_saved_conversation(conversation: AgentConversation) -> bool:
if str(conversation.draft_claim_id or "").strip():
return True
state_json = dict(conversation.state_json or {})
return bool(str(state_json.get("draft_claim_id") or "").strip())
def _resolve_retention_days(self) -> int:
try:
settings_row, _ = SettingsService(self.db).ensure_settings_ready()
configured_days = int(
getattr(
settings_row,
"conversation_retention_days",
DEFAULT_CONVERSATION_RETENTION_DAYS,
)
or DEFAULT_CONVERSATION_RETENTION_DAYS
)
return max(1, min(configured_days, 10))
except Exception:
self.db.rollback()
return DEFAULT_CONVERSATION_RETENTION_DAYS
def get_conversation(self, conversation_id: str) -> AgentConversation | None:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return None
stmt = select(AgentConversation).where(AgentConversation.conversation_id == normalized_id)
return self.db.scalar(stmt)
def get_latest_conversation_for_user(
self,
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
prefer_recoverable: bool = False,
) -> AgentConversation | None:
self.prune_expired_conversations()
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
return None
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
if source:
stmt = stmt.where(AgentConversation.source == source)
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if not normalized_session_type:
return conversations[0] if conversations else None
fallback_conversation = None
for conversation in conversations:
current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if current_session_type != normalized_session_type:
continue
if fallback_conversation is None:
fallback_conversation = conversation
if not prefer_recoverable or self._is_recoverable_conversation(conversation):
return conversation
return fallback_conversation
@staticmethod
def _is_recoverable_conversation(conversation: AgentConversation) -> bool:
if str(conversation.draft_claim_id or "").strip():
return True
state_json = dict(conversation.state_json or {})
documents = state_json.get("ocr_documents")
if not isinstance(documents, list):
return False
for item in documents:
if not isinstance(item, dict):
continue
preview_url = str(item.get("preview_url") or "").strip()
preview_data_url = str(item.get("preview_data_url") or "").strip()
preview_kind = str(item.get("preview_kind") or "").strip()
if (preview_url or preview_data_url) and preview_kind in {"image", "pdf"}:
return True
return False
def hydrate_context_json(
self,
*,
conversation: AgentConversation,
context_json: dict[str, Any],
message: str | None = None,
history_limit: int = 8,
) -> dict[str, Any]:
merged = dict(context_json or {})
incoming_draft_claim_id = self._resolve_draft_claim_id(merged)
if self._has_draft_claim_scope_conflict(conversation, incoming_draft_claim_id):
return merged
state_json = dict(conversation.state_json or {})
should_hydrate_review_flow = self._should_hydrate_review_flow_context(
context_json=merged,
message=message,
)
if not should_hydrate_review_flow:
for key in REVIEW_FLOW_CONTEXT_KEYS:
if key == "business_time_context" and not self._is_empty_value(merged.get(key)):
continue
merged.pop(key, None)
merged["conversation_id"] = conversation.conversation_id
merged["conversation_history"] = self.list_message_history(
conversation.conversation_id,
limit=history_limit,
)
if conversation.last_scenario:
merged.setdefault("conversation_scenario", conversation.last_scenario)
if conversation.last_intent:
merged.setdefault("conversation_intent", conversation.last_intent)
if (
should_hydrate_review_flow
and conversation.draft_claim_id
and not str(merged.get("draft_claim_id") or "").strip()
):
merged["draft_claim_id"] = conversation.draft_claim_id
merged["conversation_state"] = state_json
for key in STATEFUL_CONTEXT_KEYS:
if key in REVIEW_FLOW_CONTEXT_KEYS and not should_hydrate_review_flow:
continue
if self._is_empty_value(merged.get(key)) and not self._is_empty_value(state_json.get(key)):
merged[key] = state_json.get(key)
return merged
@staticmethod
def _should_hydrate_review_flow_context(
*,
context_json: dict[str, Any],
message: str | None,
) -> bool:
if isinstance(context_json.get("expense_scene_selection"), dict):
return True
if AgentConversationService._resolve_draft_claim_id(context_json):
compact_message = str(message or "").replace(" ", "")
if compact_message and any(keyword in compact_message for keyword in NEW_EXPENSE_PROMPT_KEYWORDS):
return False
return True
if str(context_json.get("review_action") or "").strip():
return True
if str(context_json.get("entry_source") or "").strip() == "detail":
return True
if not AgentConversationService._is_empty_value(context_json.get("attachment_names")):
return True
if not AgentConversationService._is_empty_value(context_json.get("ocr_documents")):
return True
if str(context_json.get("ocr_summary") or "").strip():
return True
try:
if int(context_json.get("attachment_count") or 0) > 0:
return True
except (TypeError, ValueError):
pass
compact_message = str(message or "").replace(" ", "")
if not compact_message:
return False
if any(keyword in compact_message for keyword in NEW_EXPENSE_PROMPT_KEYWORDS):
return False
return any(keyword in compact_message for keyword in REVIEW_FLOW_CONTINUATION_KEYWORDS)
def append_message(
self,
*,
conversation_id: str,
role: str,
content: str,
run_id: str | None = None,
message_json: dict[str, Any] | None = None,
) -> AgentConversationMessage | None:
normalized_content = str(content or "").strip()
if not normalized_content:
return None
conversation = self.get_conversation(conversation_id)
if conversation is None:
return None
next_sequence = int(conversation.message_count or 0) + 1
normalized_message_json = dict(message_json or {})
normalized_message_json.setdefault("sequence", next_sequence)
message = AgentConversationMessage(
conversation_id=conversation_id,
run_id=run_id,
role=str(role or "user").strip() or "user",
content=normalized_content,
message_json=normalized_message_json,
created_at=datetime.now(UTC),
)
conversation.message_count = next_sequence
if role == "user" and not conversation.title:
conversation.title = normalized_content[:48]
conversation.updated_at = datetime.now(UTC)
self.db.add(message)
self.db.add(conversation)
self.db.commit()
self.db.refresh(message)
return message
def list_message_history(
self,
conversation_id: str,
*,
limit: int = 8,
) -> list[dict[str, Any]]:
normalized_id = str(conversation_id or "").strip()
if not normalized_id or limit <= 0:
return []
messages = self.list_messages(normalized_id)
if limit > 0:
messages = messages[-limit:]
return [
{
"role": item.role,
"content": item.content,
"run_id": item.run_id,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
for item in messages
]
def list_messages(
self,
conversation_id: str,
*,
limit: int | None = None,
) -> list[AgentConversationMessage]:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return []
stmt = (
select(AgentConversationMessage)
.where(AgentConversationMessage.conversation_id == normalized_id)
.order_by(AgentConversationMessage.created_at.asc())
)
messages = list(self.db.scalars(stmt).all())
messages.sort(key=self._message_sort_key)
if limit and limit > 0:
return messages[:limit]
return messages
def update_state(
self,
*,
conversation_id: str,
run_id: str | None,
scenario: str | None,
intent: str | None,
context_json: dict[str, Any],
draft_payload: dict[str, Any] | None = None,
) -> AgentConversation | None:
conversation = self.get_conversation(conversation_id)
if conversation is None:
return None
conversation.last_run_id = str(run_id or "").strip() or conversation.last_run_id
conversation.last_scenario = str(scenario or "").strip() or conversation.last_scenario
conversation.last_intent = str(intent or "").strip() or conversation.last_intent
if draft_payload and str(draft_payload.get("claim_id") or "").strip():
conversation.draft_claim_id = str(draft_payload["claim_id"]).strip()
next_state = self._merge_state_json(
conversation.state_json,
self._extract_state_json(context_json),
)
if draft_payload:
if str(draft_payload.get("claim_id") or "").strip():
next_state["draft_claim_id"] = str(draft_payload["claim_id"]).strip()
if str(draft_payload.get("claim_no") or "").strip():
next_state["draft_claim_no"] = str(draft_payload["claim_no"]).strip()
if str(draft_payload.get("status") or "").strip():
next_state["draft_status"] = str(draft_payload["status"]).strip()
conversation.state_json = next_state
conversation.updated_at = datetime.now(UTC)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
def delete_user_conversations(
self,
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
) -> int:
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
return 0
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
if source:
stmt = stmt.where(AgentConversation.source == source)
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if normalized_session_type:
conversations = [
conversation
for conversation in conversations
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
== normalized_session_type
]
if not conversations:
return 0
for conversation in conversations:
self.db.delete(conversation)
self.db.commit()
return len(conversations)
def delete_conversations_for_draft_claim(
self,
*,
claim_id: str | None,
source: str | None = "user_message",
session_type: str | None = "expense",
) -> int:
normalized_claim_id = str(claim_id or "").strip()
if not normalized_claim_id:
return 0
stmt = select(AgentConversation).where(AgentConversation.draft_claim_id == normalized_claim_id)
if source:
stmt = stmt.where(AgentConversation.source == source)
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if normalized_session_type:
conversations = [
conversation
for conversation in conversations
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
== normalized_session_type
]
if not conversations:
return 0
for conversation in conversations:
self.db.delete(conversation)
self.db.commit()
return len(conversations)
def delete_conversation(
self,
*,
conversation_id: str | None,
user_id: str | None = None,
source: str | None = "user_message",
) -> int:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return 0
conversation = self.get_conversation(normalized_id)
if conversation is None:
return 0
normalized_user_id = str(user_id or "").strip()
if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id:
return 0
normalized_source = str(source or "").strip()
if normalized_source and str(conversation.source or "").strip() != normalized_source:
return 0
self.db.delete(conversation)
self.db.commit()
return 1
def serialize_conversation(
self,
conversation: AgentConversation,
*,
include_messages: bool = True,
message_limit: int | None = None,
) -> dict[str, Any]:
payload = {
"conversation_id": conversation.conversation_id,
"user_id": conversation.user_id,
"source": conversation.source,
"entry_source": conversation.entry_source,
"title": conversation.title,
"last_run_id": conversation.last_run_id,
"last_scenario": conversation.last_scenario,
"last_intent": conversation.last_intent,
"draft_claim_id": conversation.draft_claim_id,
"state_json": dict(conversation.state_json or {}),
"message_count": int(conversation.message_count or 0),
"updated_at": conversation.updated_at,
"messages": [],
}
if include_messages:
payload["messages"] = [
self.serialize_message(item)
for item in self.list_messages(conversation.conversation_id, limit=message_limit)
]
return payload
@staticmethod
def serialize_message(message: AgentConversationMessage) -> dict[str, Any]:
return {
"id": message.id,
"role": message.role,
"content": message.content,
"run_id": message.run_id,
"message_json": dict(message.message_json or {}),
"created_at": message.created_at,
}
@staticmethod
def _message_sort_key(message: AgentConversationMessage) -> tuple[int, datetime, str, int, str]:
message_json = dict(message.message_json or {})
sequence = AgentConversationService._coerce_message_sequence(message_json.get("sequence"))
created_at = message.created_at or datetime.min.replace(tzinfo=UTC)
run_id = str(message.run_id or "")
role_priority = 0 if str(message.role or "").strip() == "user" else 1
fallback_sequence = sequence if sequence is not None else 10**9
return (
fallback_sequence,
created_at,
run_id,
role_priority,
str(message.id or ""),
)
@staticmethod
def _coerce_message_sequence(value: Any) -> int | None:
try:
normalized = int(value)
except (TypeError, ValueError):
return None
return normalized if normalized > 0 else None
@staticmethod
def _is_empty_value(value: Any) -> bool:
if value is None:
return True
if isinstance(value, str):
return not value.strip()
if isinstance(value, (list, tuple, set, dict)):
return len(value) == 0
return False
@staticmethod
def _resolve_title(context_json: dict[str, Any]) -> str | None:
request_context = context_json.get("request_context")
if isinstance(request_context, dict):
for key in ("reason", "title", "id"):
value = str(request_context.get(key) or "").strip()
if value:
return value[:200]
return None
@staticmethod
def _extract_state_json(context_json: dict[str, Any]) -> dict[str, Any]:
state_json: dict[str, Any] = {}
for key in STATEFUL_CONTEXT_KEYS:
value = context_json.get(key)
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
if isinstance(value, (list, dict)) and not value:
continue
state_json[key] = value
draft_claim_id = AgentConversationService._resolve_draft_claim_id(context_json)
if draft_claim_id:
state_json["draft_claim_id"] = draft_claim_id
return state_json
@staticmethod
def _resolve_draft_claim_id(context_json: dict[str, Any]) -> str:
draft_claim_id = str((context_json or {}).get("draft_claim_id") or "").strip()
if draft_claim_id:
return draft_claim_id
request_context = (context_json or {}).get("request_context")
if isinstance(request_context, dict):
return str(
request_context.get("claim_id")
or request_context.get("claimId")
or request_context.get("draft_claim_id")
or request_context.get("draftClaimId")
or ""
).strip()
return ""
@staticmethod
def _resolve_conversation_draft_claim_id(conversation: AgentConversation) -> str:
state_json = dict(conversation.state_json or {})
return str(
conversation.draft_claim_id
or state_json.get("draft_claim_id")
or ""
).strip()
@staticmethod
def _has_draft_claim_scope_conflict(
conversation: AgentConversation,
incoming_draft_claim_id: str | None,
) -> bool:
incoming_claim_id = str(incoming_draft_claim_id or "").strip()
if not incoming_claim_id:
return False
existing_claim_id = AgentConversationService._resolve_conversation_draft_claim_id(conversation)
return bool(existing_claim_id and existing_claim_id != incoming_claim_id)
@staticmethod
def _merge_state_json(
current_state: dict[str, Any] | None,
incoming_state: dict[str, Any] | None,
) -> dict[str, Any]:
merged = dict(current_state or {})
for key, value in (incoming_state or {}).items():
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
if isinstance(value, (list, dict)) and not value:
continue
merged[key] = value
return merged