refactor(backend): update orchestrator endpoint and services

- endpoints/orchestrator.py: update orchestrator API endpoint
- services/agent_conversations.py: update agent conversations service
- services/orchestrator.py: update orchestrator service
- services/user_agent.py: update user agent service
This commit is contained in:
caoxiaozhu
2026-05-13 13:06:52 +00:00
parent 0f7bd43ce3
commit 70cff69b7f
4 changed files with 359 additions and 20 deletions

View File

@@ -5,7 +5,7 @@ from datetime import UTC, datetime
from time import perf_counter
from typing import Any
from sqlalchemy import func, select
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.core.agent_enums import (
@@ -18,6 +18,7 @@ from app.core.agent_enums import (
AgentToolType,
)
from app.core.logging import get_logger
from app.models.employee import Employee
from app.models.financial_record import (
AccountsPayableRecord,
AccountsReceivableRecord,
@@ -59,6 +60,10 @@ class ExecutionOutcome:
failed_tool_count: int
PRIVILEGED_EXPENSE_QUERY_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
SELF_REFERENCE_KEYWORDS = ("我的", "我自己", "本人", "我名下", "给我查", "我提交", "我申请")
class OrchestratorService:
def __init__(self, db: Session) -> None:
self.db = db
@@ -497,7 +502,12 @@ class OrchestratorService:
tool_name=self._database_tool_name(ontology.scenario),
request_json=self._build_ontology_json(ontology),
context_json=context_json,
executor=lambda: self._build_database_answer(ontology),
executor=lambda: self._build_database_answer(
ontology,
user_id=payload.user_id,
context_json=context_json,
message=payload.message or "",
),
fallback_factory=lambda exc: {
"message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
"degraded": True,
@@ -831,25 +841,56 @@ class OrchestratorService:
if expected == tool_type.lower():
raise RuntimeError(f"simulated {tool_type} failure")
def _build_database_answer(self, ontology: OntologyParseResult) -> dict[str, Any]:
def _build_database_answer(
self,
ontology: OntologyParseResult,
*,
user_id: str | None,
context_json: dict[str, Any],
message: str,
) -> dict[str, Any]:
if ontology.scenario == "expense":
count_stmt = select(func.count()).select_from(ExpenseClaim)
amount_stmt = select(
func.coalesce(func.sum(ExpenseClaim.amount), 0)
).select_from(ExpenseClaim)
employee_names = [
item.normalized_value
for item in ontology.entities
if item.type == "employee"
]
if employee_names:
count_stmt = count_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
amount_stmt = amount_stmt.where(ExpenseClaim.employee_name.in_(employee_names))
amount_stmt = select(func.coalesce(func.sum(ExpenseClaim.amount), 0)).select_from(ExpenseClaim)
preview_stmt = (
select(ExpenseClaim)
.order_by(ExpenseClaim.occurred_at.desc(), ExpenseClaim.created_at.desc())
.limit(5)
)
conditions, scope_label, scoped_to_current_user = self._build_expense_query_scope(
ontology=ontology,
user_id=user_id,
context_json=context_json,
message=message,
)
for condition in conditions:
count_stmt = count_stmt.where(condition)
amount_stmt = amount_stmt.where(condition)
preview_stmt = preview_stmt.where(condition)
total_count = int(self.db.scalar(count_stmt) or 0)
total_amount = float(self.db.scalar(amount_stmt) or 0)
preview_claims = list(self.db.scalars(preview_stmt).all())
return {
"record_count": total_count,
"total_amount": round(total_amount, 2),
"scope_label": scope_label,
"scoped_to_current_user": scoped_to_current_user,
"records": [
{
"claim_id": claim.id,
"claim_no": claim.claim_no,
"employee_name": claim.employee_name,
"expense_type": claim.expense_type,
"amount": round(float(claim.amount), 2),
"status": claim.status,
"approval_stage": claim.approval_stage,
"occurred_at": claim.occurred_at.date().isoformat() if claim.occurred_at else "",
"reason": claim.reason,
"location": claim.location,
}
for claim in preview_claims
],
"has_more": total_count > len(preview_claims),
}
if ontology.scenario == "accounts_receivable":
@@ -885,6 +926,183 @@ class OrchestratorService:
"outstanding_amount": round(total_amount, 2),
}
def _build_expense_query_scope(
self,
*,
ontology: OntologyParseResult,
user_id: str | None,
context_json: dict[str, Any],
message: str,
) -> tuple[list[Any], str, bool]:
conditions: list[Any] = []
explicit_employee_names = list(
dict.fromkeys(
str(item.value or "").strip()
for item in ontology.entities
if item.type == "employee" and str(item.value or "").strip()
)
)
expense_claim_nos = list(
dict.fromkeys(
str(item.normalized_value or item.value or "").strip().upper()
for item in ontology.entities
if item.type == "expense_claim" and str(item.normalized_value or item.value or "").strip()
)
)
expense_types = list(
dict.fromkeys(
str(item.normalized_value or item.value or "").strip()
for item in ontology.entities
if item.type == "expense_type" and str(item.normalized_value or item.value or "").strip()
)
)
status_values = list(
dict.fromkeys(
str(item.value).strip()
for item in ontology.constraints
if item.field == "status" and item.operator == "=" and str(item.value).strip()
)
)
amount_constraints = [
item
for item in ontology.constraints
if item.field == "amount" and item.operator in {">", ">=", "<", "<=", "="}
]
scope_label = "报销单"
scoped_to_current_user = False
if expense_claim_nos:
conditions.append(ExpenseClaim.claim_no.in_(expense_claim_nos))
if expense_types:
conditions.append(ExpenseClaim.expense_type.in_(expense_types))
if status_values:
conditions.append(ExpenseClaim.status.in_(status_values))
for item in amount_constraints:
amount_value = float(item.value)
if item.operator == ">":
conditions.append(ExpenseClaim.amount > amount_value)
elif item.operator == ">=":
conditions.append(ExpenseClaim.amount >= amount_value)
elif item.operator == "<":
conditions.append(ExpenseClaim.amount < amount_value)
elif item.operator == "<=":
conditions.append(ExpenseClaim.amount <= amount_value)
else:
conditions.append(ExpenseClaim.amount == amount_value)
if ontology.time_range.start_date:
conditions.append(
ExpenseClaim.occurred_at
>= datetime.fromisoformat(f"{ontology.time_range.start_date}T00:00:00+00:00")
)
if ontology.time_range.end_date:
conditions.append(
ExpenseClaim.occurred_at
<= datetime.fromisoformat(f"{ontology.time_range.end_date}T23:59:59.999999+00:00")
)
has_privileged_access = self._has_privileged_expense_query_access(context_json)
refers_to_self = self._is_self_expense_query(message)
if not has_privileged_access:
owner_conditions, owner_label = self._build_current_user_claim_conditions(
user_id=user_id,
context_json=context_json,
)
if owner_conditions:
conditions.append(or_(*owner_conditions))
scope_label = owner_label
scoped_to_current_user = True
else:
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
scope_label = "你的报销单"
scoped_to_current_user = True
elif explicit_employee_names:
conditions.append(ExpenseClaim.employee_name.in_(explicit_employee_names))
scope_label = f"{''.join(explicit_employee_names)}的报销单"
elif refers_to_self:
owner_conditions, owner_label = self._build_current_user_claim_conditions(
user_id=user_id,
context_json=context_json,
)
if owner_conditions:
conditions.append(or_(*owner_conditions))
scope_label = owner_label
scoped_to_current_user = True
else:
conditions.append(ExpenseClaim.id == "__no_visible_claim__")
scope_label = "你的报销单"
scoped_to_current_user = True
else:
scope_label = "全部报销单"
return conditions, scope_label, scoped_to_current_user
def _build_current_user_claim_conditions(
self,
*,
user_id: str | None,
context_json: dict[str, Any],
) -> tuple[list[Any], str]:
normalized_user_id = str(user_id or "").strip()
display_name = str(context_json.get("name") or "").strip()
employee = None
if normalized_user_id:
employee = self.db.scalar(
select(Employee)
.where(func.lower(Employee.email) == normalized_user_id.lower())
.limit(1)
)
conditions: list[Any] = []
seen: set[tuple[str, str]] = set()
def add_condition(field_name: str, value: str | None) -> None:
normalized = str(value or "").strip()
if not normalized:
return
marker = (field_name, normalized.lower())
if marker in seen:
return
seen.add(marker)
if field_name == "employee_id":
conditions.append(ExpenseClaim.employee_id == normalized)
return
conditions.append(ExpenseClaim.employee_name == normalized)
if employee is not None:
add_condition("employee_id", employee.id)
add_condition("employee_name", employee.name)
add_condition("employee_name", employee.email)
if not display_name:
display_name = employee.name
add_condition("employee_name", display_name)
add_condition("employee_name", normalized_user_id)
subject_name = display_name or (employee.name if employee is not None else "") or normalized_user_id
if subject_name:
return conditions, "你的报销单"
return conditions, "当前用户的报销单"
@staticmethod
def _has_privileged_expense_query_access(context_json: dict[str, Any]) -> bool:
if bool(context_json.get("is_admin")):
return True
role_codes = {
str(item).strip().lower()
for item in context_json.get("role_codes", [])
if str(item).strip()
}
return bool(role_codes & PRIVILEGED_EXPENSE_QUERY_ROLE_CODES)
@staticmethod
def _is_self_expense_query(message: str) -> bool:
compact_message = "".join(str(message or "").split())
return any(keyword in compact_message for keyword in SELF_REFERENCE_KEYWORDS)
@staticmethod
def _build_user_query_result(
ontology: OntologyParseResult,