refactor(backend): update orchestrator endpoint and services
- endpoints/orchestrator.py: update orchestrator API endpoint - services/agent_conversations.py: update agent conversations service - services/orchestrator.py: update orchestrator service - services/user_agent.py: update user agent service
This commit is contained in:
@@ -48,9 +48,14 @@ def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> Orchestrato
|
||||
def get_latest_conversation(
|
||||
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
|
||||
db: DbSession,
|
||||
session_type: Annotated[str | None, Query(description="会话类型,例如 expense / knowledge。")] = None,
|
||||
) -> ConversationLookupResponse:
|
||||
service = AgentConversationService(db)
|
||||
conversation = service.get_latest_conversation_for_user(user_id=user_id, source="user_message")
|
||||
conversation = service.get_latest_conversation_for_user(
|
||||
user_id=user_id,
|
||||
source="user_message",
|
||||
session_type=session_type,
|
||||
)
|
||||
if conversation is None:
|
||||
return ConversationLookupResponse(found=False, conversation=None)
|
||||
|
||||
@@ -60,6 +65,25 @@ def get_latest_conversation(
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/conversations/{conversation_id}",
|
||||
response_model=ConversationDeleteResponse,
|
||||
summary="删除当前用户单个会话",
|
||||
description="删除当前用户在智能体工作台中的单个会话,用于清空当前 session 内容。",
|
||||
)
|
||||
def delete_single_conversation(
|
||||
conversation_id: str,
|
||||
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
|
||||
db: DbSession,
|
||||
) -> ConversationDeleteResponse:
|
||||
deleted_count = AgentConversationService(db).delete_conversation(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
source="user_message",
|
||||
)
|
||||
return ConversationDeleteResponse(deleted_count=deleted_count)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/conversations",
|
||||
response_model=ConversationDeleteResponse,
|
||||
@@ -69,9 +93,11 @@ def get_latest_conversation(
|
||||
def delete_user_conversations(
|
||||
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
|
||||
db: DbSession,
|
||||
session_type: Annotated[str | None, Query(description="可选,会话类型,例如 expense / knowledge。")] = None,
|
||||
) -> ConversationDeleteResponse:
|
||||
deleted_count = AgentConversationService(db).delete_user_conversations(
|
||||
user_id=user_id,
|
||||
source="user_message",
|
||||
session_type=session_type,
|
||||
)
|
||||
return ConversationDeleteResponse(deleted_count=deleted_count)
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.models.agent_conversation import AgentConversation, AgentConversationMe
|
||||
from app.services.settings import SettingsService
|
||||
|
||||
STATEFUL_CONTEXT_KEYS = (
|
||||
"session_type",
|
||||
"entry_source",
|
||||
"request_context",
|
||||
"attachment_names",
|
||||
@@ -37,10 +38,16 @@ class AgentConversationService:
|
||||
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
normalized_user_id = str(user_id or "").strip() or None
|
||||
incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense"
|
||||
conversation = self.get_conversation(normalized_id) if normalized_id else None
|
||||
if conversation is not None and conversation.user_id != normalized_user_id:
|
||||
normalized_id = ""
|
||||
conversation = None
|
||||
if conversation is not None:
|
||||
existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
|
||||
if existing_session_type != incoming_session_type:
|
||||
normalized_id = ""
|
||||
conversation = None
|
||||
|
||||
if conversation is None:
|
||||
conversation = AgentConversation(
|
||||
@@ -117,6 +124,7 @@ class AgentConversationService:
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
session_type: str | None = None,
|
||||
) -> AgentConversation | None:
|
||||
self.prune_expired_conversations()
|
||||
|
||||
@@ -128,7 +136,16 @@ class AgentConversationService:
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
|
||||
return self.db.scalar(stmt.limit(1))
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
normalized_session_type = str(session_type or "").strip()
|
||||
if not normalized_session_type:
|
||||
return conversations[0] if conversations else None
|
||||
|
||||
for conversation in conversations:
|
||||
current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
|
||||
if current_session_type == normalized_session_type:
|
||||
return conversation
|
||||
return None
|
||||
|
||||
def hydrate_context_json(
|
||||
self,
|
||||
@@ -285,6 +302,7 @@ class AgentConversationService:
|
||||
*,
|
||||
user_id: str | None,
|
||||
source: str | None = "user_message",
|
||||
session_type: str | None = None,
|
||||
) -> int:
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if not normalized_user_id:
|
||||
@@ -294,6 +312,14 @@ class AgentConversationService:
|
||||
if source:
|
||||
stmt = stmt.where(AgentConversation.source == source)
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
normalized_session_type = str(session_type or "").strip()
|
||||
if normalized_session_type:
|
||||
conversations = [
|
||||
conversation
|
||||
for conversation in conversations
|
||||
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
|
||||
== normalized_session_type
|
||||
]
|
||||
if not conversations:
|
||||
return 0
|
||||
|
||||
@@ -303,6 +329,33 @@ class AgentConversationService:
|
||||
self.db.commit()
|
||||
return len(conversations)
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
user_id: str | None = None,
|
||||
source: str | None = "user_message",
|
||||
) -> int:
|
||||
normalized_id = str(conversation_id or "").strip()
|
||||
if not normalized_id:
|
||||
return 0
|
||||
|
||||
conversation = self.get_conversation(normalized_id)
|
||||
if conversation is None:
|
||||
return 0
|
||||
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id:
|
||||
return 0
|
||||
|
||||
normalized_source = str(source or "").strip()
|
||||
if normalized_source and str(conversation.source or "").strip() != normalized_source:
|
||||
return 0
|
||||
|
||||
self.db.delete(conversation)
|
||||
self.db.commit()
|
||||
return 1
|
||||
|
||||
def serialize_conversation(
|
||||
self,
|
||||
conversation: AgentConversation,
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import UTC, datetime
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent_enums import (
|
||||
@@ -18,6 +18,7 @@ from app.core.agent_enums import (
|
||||
AgentToolType,
|
||||
)
|
||||
from app.core.logging import get_logger
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import (
|
||||
AccountsPayableRecord,
|
||||
AccountsReceivableRecord,
|
||||
@@ -59,6 +60,10 @@ class ExecutionOutcome:
|
||||
failed_tool_count: int
|
||||
|
||||
|
||||
PRIVILEGED_EXPENSE_QUERY_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
|
||||
SELF_REFERENCE_KEYWORDS = ("我的", "我自己", "本人", "我名下", "给我查", "我提交", "我申请")
|
||||
|
||||
|
||||
class OrchestratorService:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
@@ -497,7 +502,12 @@ class OrchestratorService:
|
||||
tool_name=self._database_tool_name(ontology.scenario),
|
||||
request_json=self._build_ontology_json(ontology),
|
||||
context_json=context_json,
|
||||
executor=lambda: self._build_database_answer(ontology),
|
||||
executor=lambda: self._build_database_answer(
|
||||
ontology,
|
||||
user_id=payload.user_id,
|
||||
context_json=context_json,
|
||||
message=payload.message or "",
|
||||
),
|
||||
fallback_factory=lambda exc: {
|
||||
"message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
|
||||
"degraded": True,
|
||||
@@ -831,25 +841,56 @@ class OrchestratorService:
|
||||
if expected == tool_type.lower():
|
||||
raise RuntimeError(f"simulated {tool_type} failure")
|
||||
|
||||
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
|
||||
def _build_database_answer(
|
||||
self,
|
||||
ontology: OntologyParseResult,
|
||||
*,
|
||||
user_id: str | None,
|
||||
context_json: dict[str, Any],
|
||||
message: str,
|
||||
) -> dict[str, Any]:
|
||||
if ontology.scenario == "expense":
|
||||
count_stmt = select(func.count()).select_from(ExpenseClaim)
|
||||
amount_stmt = select(
|
||||
func.coalesce(func.sum(ExpenseClaim.amount), 0)
|
||||
).select_from(ExpenseClaim)
|
||||
employee_names = [
|
||||
item.normalized_value
|
||||
for item in ontology.entities
|
||||
if item.type == "employee"
|
||||
]
|
||||
if employee_names:
|
||||
count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
|
||||
amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
|
||||
amount_stmt = select(func.coalesce(func.sum(ExpenseClaim.amount), 0)).select_from(ExpenseClaim)
|
||||
preview_stmt = (
|
||||
select(ExpenseClaim)
|
||||
.order_by(ExpenseClaim.occurred_at.desc(), ExpenseClaim.created_at.desc())
|
||||
.limit(5)
|
||||
)
|
||||
conditions, scope_label, scoped_to_current_user = self._build_expense_query_scope(
|
||||
ontology=ontology,
|
||||
user_id=user_id,
|
||||
context_json=context_json,
|
||||
message=message,
|
||||
)
|
||||
for condition in conditions:
|
||||
count_stmt = count_stmt.where(condition)
|
||||
amount_stmt = amount_stmt.where(condition)
|
||||
preview_stmt = preview_stmt.where(condition)
|
||||
total_count = int(self.db.scalar(count_stmt) or 0)
|
||||
total_amount = float(self.db.scalar(amount_stmt) or 0)
|
||||
preview_claims = list(self.db.scalars(preview_stmt).all())
|
||||
return {
|
||||
"record_count": total_count,
|
||||
"total_amount": round(total_amount, 2),
|
||||
"scope_label": scope_label,
|
||||
"scoped_to_current_user": scoped_to_current_user,
|
||||
"records": [
|
||||
{
|
||||
"claim_id": claim.id,
|
||||
"claim_no": claim.claim_no,
|
||||
"employee_name": claim.employee_name,
|
||||
"expense_type": claim.expense_type,
|
||||
"amount": round(float(claim.amount), 2),
|
||||
"status": claim.status,
|
||||
"approval_stage": claim.approval_stage,
|
||||
"occurred_at": claim.occurred_at.date().isoformat() if claim.occurred_at else "",
|
||||
"reason": claim.reason,
|
||||
"location": claim.location,
|
||||
}
|
||||
for claim in preview_claims
|
||||
],
|
||||
"has_more": total_count > len(preview_claims),
|
||||
}
|
||||
|
||||
if ontology.scenario == "accounts_receivable":
|
||||
@@ -885,6 +926,183 @@ class OrchestratorService:
|
||||
"outstanding_amount": round(total_amount, 2),
|
||||
}
|
||||
|
||||
def _build_expense_query_scope(
|
||||
self,
|
||||
*,
|
||||
ontology: OntologyParseResult,
|
||||
user_id: str | None,
|
||||
context_json: dict[str, Any],
|
||||
message: str,
|
||||
) -> tuple[list[Any], str, bool]:
|
||||
conditions: list[Any] = []
|
||||
explicit_employee_names = list(
|
||||
dict.fromkeys(
|
||||
str(item.value or "").strip()
|
||||
for item in ontology.entities
|
||||
if item.type == "employee" and str(item.value or "").strip()
|
||||
)
|
||||
)
|
||||
expense_claim_nos = list(
|
||||
dict.fromkeys(
|
||||
str(item.normalized_value or item.value or "").strip().upper()
|
||||
for item in ontology.entities
|
||||
if item.type == "expense_claim" and str(item.normalized_value or item.value or "").strip()
|
||||
)
|
||||
)
|
||||
expense_types = list(
|
||||
dict.fromkeys(
|
||||
str(item.normalized_value or item.value or "").strip()
|
||||
for item in ontology.entities
|
||||
if item.type == "expense_type" and str(item.normalized_value or item.value or "").strip()
|
||||
)
|
||||
)
|
||||
status_values = list(
|
||||
dict.fromkeys(
|
||||
str(item.value).strip()
|
||||
for item in ontology.constraints
|
||||
if item.field == "status" and item.operator == "=" and str(item.value).strip()
|
||||
)
|
||||
)
|
||||
amount_constraints = [
|
||||
item
|
||||
for item in ontology.constraints
|
||||
if item.field == "amount" and item.operator in {">", ">=", "<", "<=", "="}
|
||||
]
|
||||
scope_label = "报销单"
|
||||
scoped_to_current_user = False
|
||||
|
||||
if expense_claim_nos:
|
||||
conditions.append(ExpenseClaim.claim_no.in_(expense_claim_nos))
|
||||
if expense_types:
|
||||
conditions.append(ExpenseClaim.expense_type.in_(expense_types))
|
||||
if status_values:
|
||||
conditions.append(ExpenseClaim.status.in_(status_values))
|
||||
|
||||
for item in amount_constraints:
|
||||
amount_value = float(item.value)
|
||||
if item.operator == ">":
|
||||
conditions.append(ExpenseClaim.amount > amount_value)
|
||||
elif item.operator == ">=":
|
||||
conditions.append(ExpenseClaim.amount >= amount_value)
|
||||
elif item.operator == "<":
|
||||
conditions.append(ExpenseClaim.amount < amount_value)
|
||||
elif item.operator == "<=":
|
||||
conditions.append(ExpenseClaim.amount <= amount_value)
|
||||
else:
|
||||
conditions.append(ExpenseClaim.amount == amount_value)
|
||||
|
||||
if ontology.time_range.start_date:
|
||||
conditions.append(
|
||||
ExpenseClaim.occurred_at
|
||||
>= datetime.fromisoformat(f"{ontology.time_range.start_date}T00:00:00+00:00")
|
||||
)
|
||||
if ontology.time_range.end_date:
|
||||
conditions.append(
|
||||
ExpenseClaim.occurred_at
|
||||
<= datetime.fromisoformat(f"{ontology.time_range.end_date}T23:59:59.999999+00:00")
|
||||
)
|
||||
|
||||
has_privileged_access = self._has_privileged_expense_query_access(context_json)
|
||||
refers_to_self = self._is_self_expense_query(message)
|
||||
if not has_privileged_access:
|
||||
owner_conditions, owner_label = self._build_current_user_claim_conditions(
|
||||
user_id=user_id,
|
||||
context_json=context_json,
|
||||
)
|
||||
if owner_conditions:
|
||||
conditions.append(or_(*owner_conditions))
|
||||
scope_label = owner_label
|
||||
scoped_to_current_user = True
|
||||
else:
|
||||
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
|
||||
scope_label = "你的报销单"
|
||||
scoped_to_current_user = True
|
||||
elif explicit_employee_names:
|
||||
conditions.append(ExpenseClaim.employee_name.in_(explicit_employee_names))
|
||||
scope_label = f"{'、'.join(explicit_employee_names)}的报销单"
|
||||
elif refers_to_self:
|
||||
owner_conditions, owner_label = self._build_current_user_claim_conditions(
|
||||
user_id=user_id,
|
||||
context_json=context_json,
|
||||
)
|
||||
if owner_conditions:
|
||||
conditions.append(or_(*owner_conditions))
|
||||
scope_label = owner_label
|
||||
scoped_to_current_user = True
|
||||
else:
|
||||
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
|
||||
scope_label = "你的报销单"
|
||||
scoped_to_current_user = True
|
||||
else:
|
||||
scope_label = "全部报销单"
|
||||
|
||||
return conditions, scope_label, scoped_to_current_user
|
||||
|
||||
def _build_current_user_claim_conditions(
|
||||
self,
|
||||
*,
|
||||
user_id: str | None,
|
||||
context_json: dict[str, Any],
|
||||
) -> tuple[list[Any], str]:
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
display_name = str(context_json.get("name") or "").strip()
|
||||
employee = None
|
||||
if normalized_user_id:
|
||||
employee = self.db.scalar(
|
||||
select(Employee)
|
||||
.where(func.lower(Employee.email) == normalized_user_id.lower())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
conditions: list[Any] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
|
||||
def add_condition(field_name: str, value: str | None) -> None:
|
||||
normalized = str(value or "").strip()
|
||||
if not normalized:
|
||||
return
|
||||
|
||||
marker = (field_name, normalized.lower())
|
||||
if marker in seen:
|
||||
return
|
||||
seen.add(marker)
|
||||
|
||||
if field_name == "employee_id":
|
||||
conditions.append(ExpenseClaim.employee_id == normalized)
|
||||
return
|
||||
conditions.append(ExpenseClaim.employee_name == normalized)
|
||||
|
||||
if employee is not None:
|
||||
add_condition("employee_id", employee.id)
|
||||
add_condition("employee_name", employee.name)
|
||||
add_condition("employee_name", employee.email)
|
||||
if not display_name:
|
||||
display_name = employee.name
|
||||
|
||||
add_condition("employee_name", display_name)
|
||||
add_condition("employee_name", normalized_user_id)
|
||||
|
||||
subject_name = display_name or (employee.name if employee is not None else "") or normalized_user_id
|
||||
if subject_name:
|
||||
return conditions, "你的报销单"
|
||||
return conditions, "当前用户的报销单"
|
||||
|
||||
@staticmethod
|
||||
def _has_privileged_expense_query_access(context_json: dict[str, Any]) -> bool:
|
||||
if bool(context_json.get("is_admin")):
|
||||
return True
|
||||
role_codes = {
|
||||
str(item).strip().lower()
|
||||
for item in context_json.get("role_codes", [])
|
||||
if str(item).strip()
|
||||
}
|
||||
return bool(role_codes & PRIVILEGED_EXPENSE_QUERY_ROLE_CODES)
|
||||
|
||||
@staticmethod
|
||||
def _is_self_expense_query(message: str) -> bool:
|
||||
compact_message = "".join(str(message or "").split())
|
||||
return any(keyword in compact_message for keyword in SELF_REFERENCE_KEYWORDS)
|
||||
|
||||
@staticmethod
|
||||
def _build_user_query_result(
|
||||
ontology: OntologyParseResult,
|
||||
|
||||
@@ -85,6 +85,15 @@ GROUP_SCENE_LABELS = {
|
||||
"other": "其他费用",
|
||||
}
|
||||
|
||||
EXPENSE_STATUS_LABELS = {
|
||||
"draft": "草稿",
|
||||
"submitted": "已提交",
|
||||
"review": "审核中",
|
||||
"approved": "已通过",
|
||||
"rejected": "已驳回",
|
||||
"paid": "已付款",
|
||||
}
|
||||
|
||||
SLOT_LABELS = {
|
||||
"expense_type": "报销类型",
|
||||
"customer_name": "客户名称",
|
||||
@@ -389,10 +398,41 @@ class UserAgentService:
|
||||
if scenario == "expense":
|
||||
record_count = int(data.get("record_count") or 0)
|
||||
total_amount = float(data.get("total_amount") or 0)
|
||||
return (
|
||||
f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。"
|
||||
"如需继续处理,可以查看明细或生成处理意见草稿。"
|
||||
)
|
||||
scope_label = str(data.get("scope_label") or subject).strip() or subject
|
||||
preview_records = data.get("records")
|
||||
if record_count <= 0:
|
||||
return f"当前没有查到{scope_label}。你可以补充时间范围、单号或状态继续筛选。"
|
||||
|
||||
summary = f"查到{scope_label}共 {record_count} 笔,金额合计 {total_amount:.2f} 元。"
|
||||
if not isinstance(preview_records, list) or not preview_records:
|
||||
return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。"
|
||||
|
||||
preview_text: list[str] = []
|
||||
for item in preview_records[:3]:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
claim_no = str(item.get("claim_no") or "未编号").strip() or "未编号"
|
||||
occurred_at = str(item.get("occurred_at") or "").strip()
|
||||
expense_type = EXPENSE_TYPE_LABELS.get(
|
||||
str(item.get("expense_type") or "").strip(),
|
||||
str(item.get("expense_type") or "报销").strip() or "报销",
|
||||
)
|
||||
amount = float(item.get("amount") or 0)
|
||||
status = EXPENSE_STATUS_LABELS.get(
|
||||
str(item.get("status") or "").strip(),
|
||||
str(item.get("status") or "处理中").strip() or "处理中",
|
||||
)
|
||||
date_prefix = f"{occurred_at}," if occurred_at else ""
|
||||
preview_text.append(
|
||||
f"{claim_no}({date_prefix}{expense_type},{amount:.2f} 元,{status})"
|
||||
)
|
||||
|
||||
if not preview_text:
|
||||
return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。"
|
||||
|
||||
has_more = bool(data.get("has_more")) or record_count > len(preview_records)
|
||||
more_hint = " 当前先展示最近几笔,可继续查看明细。" if has_more else ""
|
||||
return f"{summary} 其中包括:{';'.join(preview_text)}。{more_hint}".strip()
|
||||
|
||||
if scenario == "accounts_receivable":
|
||||
record_count = int(data.get("record_count") or 0)
|
||||
@@ -1249,6 +1289,8 @@ class UserAgentService:
|
||||
payload: UserAgentRequest,
|
||||
review_payload: UserAgentReviewPayload | None,
|
||||
) -> bool:
|
||||
if payload.ontology.scenario == "expense" and payload.ontology.intent in {"query", "compare"}:
|
||||
return True
|
||||
if review_payload is None:
|
||||
return False
|
||||
return payload.ontology.scenario == "expense" and (
|
||||
|
||||
Reference in New Issue
Block a user