refactor(backend): update orchestrator endpoint and services

- endpoints/orchestrator.py: update orchestrator API endpoint
- services/agent_conversations.py: update agent conversations service
- services/orchestrator.py: update orchestrator service
- services/user_agent.py: update user agent service
This commit is contained in:
caoxiaozhu
2026-05-13 13:06:52 +00:00
parent 0f7bd43ce3
commit 70cff69b7f
4 changed files with 359 additions and 20 deletions

View File

@@ -48,9 +48,14 @@ def run_orchestrator(payload: OrchestratorRequest, db: DbSession) -> Orchestrato
def get_latest_conversation(
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
db: DbSession,
session_type: Annotated[str | None, Query(description="会话类型,例如 expense / knowledge。")] = None,
) -> ConversationLookupResponse:
service = AgentConversationService(db)
conversation = service.get_latest_conversation_for_user(user_id=user_id, source="user_message")
conversation = service.get_latest_conversation_for_user(
user_id=user_id,
source="user_message",
session_type=session_type,
)
if conversation is None:
return ConversationLookupResponse(found=False, conversation=None)
@@ -60,6 +65,25 @@ def get_latest_conversation(
)
@router.delete(
"/conversations/{conversation_id}",
response_model=ConversationDeleteResponse,
summary="删除当前用户单个会话",
description="删除当前用户在智能体工作台中的单个会话,用于清空当前 session 内容。",
)
def delete_single_conversation(
conversation_id: str,
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
db: DbSession,
) -> ConversationDeleteResponse:
deleted_count = AgentConversationService(db).delete_conversation(
conversation_id=conversation_id,
user_id=user_id,
source="user_message",
)
return ConversationDeleteResponse(deleted_count=deleted_count)
@router.delete(
"/conversations",
response_model=ConversationDeleteResponse,
@@ -69,9 +93,11 @@ def get_latest_conversation(
def delete_user_conversations(
user_id: Annotated[str, Query(min_length=1, description="当前用户 ID。")],
db: DbSession,
session_type: Annotated[str | None, Query(description="可选,会话类型,例如 expense / knowledge。")] = None,
) -> ConversationDeleteResponse:
deleted_count = AgentConversationService(db).delete_user_conversations(
user_id=user_id,
source="user_message",
session_type=session_type,
)
return ConversationDeleteResponse(deleted_count=deleted_count)

View File

@@ -11,6 +11,7 @@ from app.models.agent_conversation import AgentConversation, AgentConversationMe
from app.services.settings import SettingsService
STATEFUL_CONTEXT_KEYS = (
"session_type",
"entry_source",
"request_context",
"attachment_names",
@@ -37,10 +38,16 @@ class AgentConversationService:
normalized_id = str(conversation_id or "").strip()
normalized_user_id = str(user_id or "").strip() or None
incoming_session_type = str(context_json.get("session_type") or "").strip() or "expense"
conversation = self.get_conversation(normalized_id) if normalized_id else None
if conversation is not None and conversation.user_id != normalized_user_id:
normalized_id = ""
conversation = None
if conversation is not None:
existing_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if existing_session_type != incoming_session_type:
normalized_id = ""
conversation = None
if conversation is None:
conversation = AgentConversation(
@@ -117,6 +124,7 @@ class AgentConversationService:
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
) -> AgentConversation | None:
self.prune_expired_conversations()
@@ -128,7 +136,16 @@ class AgentConversationService:
if source:
stmt = stmt.where(AgentConversation.source == source)
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
return self.db.scalar(stmt.limit(1))
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if not normalized_session_type:
return conversations[0] if conversations else None
for conversation in conversations:
current_session_type = str((conversation.state_json or {}).get("session_type") or "").strip() or "expense"
if current_session_type == normalized_session_type:
return conversation
return None
def hydrate_context_json(
self,
@@ -285,6 +302,7 @@ class AgentConversationService:
*,
user_id: str | None,
source: str | None = "user_message",
session_type: str | None = None,
) -> int:
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
@@ -294,6 +312,14 @@ class AgentConversationService:
if source:
stmt = stmt.where(AgentConversation.source == source)
conversations = list(self.db.scalars(stmt).all())
normalized_session_type = str(session_type or "").strip()
if normalized_session_type:
conversations = [
conversation
for conversation in conversations
if (str((conversation.state_json or {}).get("session_type") or "").strip() or "expense")
== normalized_session_type
]
if not conversations:
return 0
@@ -303,6 +329,33 @@ class AgentConversationService:
self.db.commit()
return len(conversations)
def delete_conversation(
self,
*,
conversation_id: str | None,
user_id: str | None = None,
source: str | None = "user_message",
) -> int:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return 0
conversation = self.get_conversation(normalized_id)
if conversation is None:
return 0
normalized_user_id = str(user_id or "").strip()
if normalized_user_id and str(conversation.user_id or "").strip() != normalized_user_id:
return 0
normalized_source = str(source or "").strip()
if normalized_source and str(conversation.source or "").strip() != normalized_source:
return 0
self.db.delete(conversation)
self.db.commit()
return 1
def serialize_conversation(
self,
conversation: AgentConversation,

View File

@@ -5,7 +5,7 @@ from datetime import UTC, datetime
from time import perf_counter
from typing import Any
from sqlalchemy import func, select
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.core.agent_enums import (
@@ -18,6 +18,7 @@ from app.core.agent_enums import (
AgentToolType,
)
from app.core.logging import get_logger
from app.models.employee import Employee
from app.models.financial_record import (
AccountsPayableRecord,
AccountsReceivableRecord,
@@ -59,6 +60,10 @@ class ExecutionOutcome:
failed_tool_count: int
PRIVILEGED_EXPENSE_QUERY_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
SELF_REFERENCE_KEYWORDS = ("我的", "我自己", "本人", "我名下", "给我查", "我提交", "我申请")
class OrchestratorService:
def __init__(self, db: Session) -> None:
self.db = db
@@ -497,7 +502,12 @@ class OrchestratorService:
tool_name=self._database_tool_name(ontology.scenario),
request_json=self._build_ontology_json(ontology),
context_json=context_json,
executor=lambda: self._build_database_answer(ontology),
executor=lambda: self._build_database_answer(
ontology,
user_id=payload.user_id,
context_json=context_json,
message=payload.message or "",
),
fallback_factory=lambda exc: {
"message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
"degraded": True,
@@ -831,25 +841,56 @@ class OrchestratorService:
if expected == tool_type.lower():
raise RuntimeError(f"simulated {tool_type} failure")
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
def _build_database_answer(
self,
ontology: OntologyParseResult,
*,
user_id: str | None,
context_json: dict[str, Any],
message: str,
) -> dict[str, Any]:
if ontology.scenario == "expense":
count_stmt = select(func.count()).select_from(ExpenseClaim)
amount_stmt = select(
func.coalesce(func.sum(ExpenseClaim.amount), 0)
).select_from(ExpenseClaim)
employee_names = [
item.normalized_value
for item in ontology.entities
if item.type == "employee"
]
if employee_names:
count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
amount_stmt = select(func.coalesce(func.sum(ExpenseClaim.amount), 0)).select_from(ExpenseClaim)
preview_stmt = (
select(ExpenseClaim)
.order_by(ExpenseClaim.occurred_at.desc(), ExpenseClaim.created_at.desc())
.limit(5)
)
conditions, scope_label, scoped_to_current_user = self._build_expense_query_scope(
ontology=ontology,
user_id=user_id,
context_json=context_json,
message=message,
)
for condition in conditions:
count_stmt = count_stmt.where(condition)
amount_stmt = amount_stmt.where(condition)
preview_stmt = preview_stmt.where(condition)
total_count = int(self.db.scalar(count_stmt) or 0)
total_amount = float(self.db.scalar(amount_stmt) or 0)
preview_claims = list(self.db.scalars(preview_stmt).all())
return {
"record_count": total_count,
"total_amount": round(total_amount, 2),
"scope_label": scope_label,
"scoped_to_current_user": scoped_to_current_user,
"records": [
{
"claim_id": claim.id,
"claim_no": claim.claim_no,
"employee_name": claim.employee_name,
"expense_type": claim.expense_type,
"amount": round(float(claim.amount), 2),
"status": claim.status,
"approval_stage": claim.approval_stage,
"occurred_at": claim.occurred_at.date().isoformat() if claim.occurred_at else "",
"reason": claim.reason,
"location": claim.location,
}
for claim in preview_claims
],
"has_more": total_count > len(preview_claims),
}
if ontology.scenario == "accounts_receivable":
@@ -885,6 +926,183 @@ class OrchestratorService:
"outstanding_amount": round(total_amount, 2),
}
def _build_expense_query_scope(
self,
*,
ontology: OntologyParseResult,
user_id: str | None,
context_json: dict[str, Any],
message: str,
) -> tuple[list[Any], str, bool]:
conditions: list[Any] = []
explicit_employee_names = list(
dict.fromkeys(
str(item.value or "").strip()
for item in ontology.entities
if item.type == "employee" and str(item.value or "").strip()
)
)
expense_claim_nos = list(
dict.fromkeys(
str(item.normalized_value or item.value or "").strip().upper()
for item in ontology.entities
if item.type == "expense_claim" and str(item.normalized_value or item.value or "").strip()
)
)
expense_types = list(
dict.fromkeys(
str(item.normalized_value or item.value or "").strip()
for item in ontology.entities
if item.type == "expense_type" and str(item.normalized_value or item.value or "").strip()
)
)
status_values = list(
dict.fromkeys(
str(item.value).strip()
for item in ontology.constraints
if item.field == "status" and item.operator == "=" and str(item.value).strip()
)
)
amount_constraints = [
item
for item in ontology.constraints
if item.field == "amount" and item.operator in {">", ">=", "<", "<=", "="}
]
scope_label = "报销单"
scoped_to_current_user = False
if expense_claim_nos:
conditions.append(ExpenseClaim.claim_no.in_(expense_claim_nos))
if expense_types:
conditions.append(ExpenseClaim.expense_type.in_(expense_types))
if status_values:
conditions.append(ExpenseClaim.status.in_(status_values))
for item in amount_constraints:
amount_value = float(item.value)
if item.operator == ">":
conditions.append(ExpenseClaim.amount > amount_value)
elif item.operator == ">=":
conditions.append(ExpenseClaim.amount >= amount_value)
elif item.operator == "<":
conditions.append(ExpenseClaim.amount < amount_value)
elif item.operator == "<=":
conditions.append(ExpenseClaim.amount <= amount_value)
else:
conditions.append(ExpenseClaim.amount == amount_value)
if ontology.time_range.start_date:
conditions.append(
ExpenseClaim.occurred_at
>= datetime.fromisoformat(f"{ontology.time_range.start_date}T00:00:00+00:00")
)
if ontology.time_range.end_date:
conditions.append(
ExpenseClaim.occurred_at
<= datetime.fromisoformat(f"{ontology.time_range.end_date}T23:59:59.999999+00:00")
)
has_privileged_access = self._has_privileged_expense_query_access(context_json)
refers_to_self = self._is_self_expense_query(message)
if not has_privileged_access:
owner_conditions, owner_label = self._build_current_user_claim_conditions(
user_id=user_id,
context_json=context_json,
)
if owner_conditions:
conditions.append(or_(*owner_conditions))
scope_label = owner_label
scoped_to_current_user = True
else:
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
scope_label = "你的报销单"
scoped_to_current_user = True
elif explicit_employee_names:
conditions.append(ExpenseClaim.employee_name.in_(explicit_employee_names))
scope_label = f"{''.join(explicit_employee_names)}的报销单"
elif refers_to_self:
owner_conditions, owner_label = self._build_current_user_claim_conditions(
user_id=user_id,
context_json=context_json,
)
if owner_conditions:
conditions.append(or_(*owner_conditions))
scope_label = owner_label
scoped_to_current_user = True
else:
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
scope_label = "你的报销单"
scoped_to_current_user = True
else:
scope_label = "全部报销单"
return conditions, scope_label, scoped_to_current_user
def _build_current_user_claim_conditions(
self,
*,
user_id: str | None,
context_json: dict[str, Any],
) -> tuple[list[Any], str]:
normalized_user_id = str(user_id or "").strip()
display_name = str(context_json.get("name") or "").strip()
employee = None
if normalized_user_id:
employee = self.db.scalar(
select(Employee)
.where(func.lower(Employee.email) == normalized_user_id.lower())
.limit(1)
)
conditions: list[Any] = []
seen: set[tuple[str, str]] = set()
def add_condition(field_name: str, value: str | None) -> None:
normalized = str(value or "").strip()
if not normalized:
return
marker = (field_name, normalized.lower())
if marker in seen:
return
seen.add(marker)
if field_name == "employee_id":
conditions.append(ExpenseClaim.employee_id == normalized)
return
conditions.append(ExpenseClaim.employee_name == normalized)
if employee is not None:
add_condition("employee_id", employee.id)
add_condition("employee_name", employee.name)
add_condition("employee_name", employee.email)
if not display_name:
display_name = employee.name
add_condition("employee_name", display_name)
add_condition("employee_name", normalized_user_id)
subject_name = display_name or (employee.name if employee is not None else "") or normalized_user_id
if subject_name:
return conditions, "你的报销单"
return conditions, "当前用户的报销单"
@staticmethod
def _has_privileged_expense_query_access(context_json: dict[str, Any]) -> bool:
if bool(context_json.get("is_admin")):
return True
role_codes = {
str(item).strip().lower()
for item in context_json.get("role_codes", [])
if str(item).strip()
}
return bool(role_codes & PRIVILEGED_EXPENSE_QUERY_ROLE_CODES)
@staticmethod
def _is_self_expense_query(message: str) -> bool:
compact_message = "".join(str(message or "").split())
return any(keyword in compact_message for keyword in SELF_REFERENCE_KEYWORDS)
@staticmethod
def _build_user_query_result(
ontology: OntologyParseResult,

View File

@@ -85,6 +85,15 @@ GROUP_SCENE_LABELS = {
"other": "其他费用",
}
EXPENSE_STATUS_LABELS = {
"draft": "草稿",
"submitted": "已提交",
"review": "审核中",
"approved": "已通过",
"rejected": "已驳回",
"paid": "已付款",
}
SLOT_LABELS = {
"expense_type": "报销类型",
"customer_name": "客户名称",
@@ -389,10 +398,41 @@ class UserAgentService:
if scenario == "expense":
record_count = int(data.get("record_count") or 0)
total_amount = float(data.get("total_amount") or 0)
return (
f"{subject}共命中 {record_count} 笔报销,金额合计 {total_amount:.2f} 元。"
"如需继续处理,可以查看明细或生成处理意见草稿。"
)
scope_label = str(data.get("scope_label") or subject).strip() or subject
preview_records = data.get("records")
if record_count <= 0:
return f"当前没有查到{scope_label}。你可以补充时间范围、单号或状态继续筛选。"
summary = f"查到{scope_label}{record_count} 笔,金额合计 {total_amount:.2f} 元。"
if not isinstance(preview_records, list) or not preview_records:
return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。"
preview_text: list[str] = []
for item in preview_records[:3]:
if not isinstance(item, dict):
continue
claim_no = str(item.get("claim_no") or "未编号").strip() or "未编号"
occurred_at = str(item.get("occurred_at") or "").strip()
expense_type = EXPENSE_TYPE_LABELS.get(
str(item.get("expense_type") or "").strip(),
str(item.get("expense_type") or "报销").strip() or "报销",
)
amount = float(item.get("amount") or 0)
status = EXPENSE_STATUS_LABELS.get(
str(item.get("status") or "").strip(),
str(item.get("status") or "处理中").strip() or "处理中",
)
date_prefix = f"{occurred_at}" if occurred_at else ""
preview_text.append(
f"{claim_no}{date_prefix}{expense_type}{amount:.2f} 元,{status}"
)
if not preview_text:
return f"{summary} 如需继续处理,可以查看明细或生成处理意见草稿。"
has_more = bool(data.get("has_more")) or record_count > len(preview_records)
more_hint = " 当前先展示最近几笔,可继续查看明细。" if has_more else ""
return f"{summary} 其中包括:{''.join(preview_text)}{more_hint}".strip()
if scenario == "accounts_receivable":
record_count = int(data.get("record_count") or 0)
@@ -1249,6 +1289,8 @@ class UserAgentService:
payload: UserAgentRequest,
review_payload: UserAgentReviewPayload | None,
) -> bool:
if payload.ontology.scenario == "expense" and payload.ontology.intent in {"query", "compare"}:
return True
if review_payload is None:
return False
return payload.ontology.scenario == "expense" and (