refactor(backend): update services and tests

- services/expense_claims.py: update expense claims service
- services/user_agent.py: update user agent service
- tests/test_orchestrator_service.py: update orchestrator service tests
- tests/test_user_agent_service.py: update user agent service tests
This commit is contained in:
caoxiaozhu
2026-05-13 03:39:41 +00:00
parent fae9966a11
commit 4db5e8ec16
4 changed files with 288 additions and 9 deletions

View File

@@ -25,6 +25,7 @@ EXPENSE_TYPE_LABELS = {
}
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
MAX_DRAFT_CLAIMS_PER_USER = 3
class ExpenseClaimService:
@@ -117,7 +118,7 @@ class ExpenseClaimService:
before_json = self._serialize_claim(claim)
claim.status = "submitted"
claim.approval_stage = "批流转"
claim.approval_stage = "AI验"
claim.submitted_at = datetime.now(UTC)
self.db.commit()
@@ -172,7 +173,39 @@ class ExpenseClaimService:
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
employee = self._resolve_employee(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
draft_owner_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
)
if is_new_claim:
existing_draft_count = self._count_draft_claims_for_owner(
employee=employee,
employee_name=draft_owner_name,
user_id=user_id,
)
if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
return {
"message": (
f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
"才能再次新建草稿。"
),
"draft_limit_reached": True,
"draft_only": False,
"status": "blocked",
"draft_count": existing_draft_count,
"max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
}
amount = self._resolve_amount(ontology.entities, context_json=context_json)
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
@@ -202,11 +235,7 @@ class ExpenseClaimService:
claim = ExpenseClaim(
claim_no=self._generate_claim_no(final_occurred_at),
employee_id=employee.id if employee is not None else None,
employee_name=employee.name if employee is not None else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
),
employee_name=draft_owner_name,
department_id=employee.organization_unit_id if employee is not None else None,
department_name=self._resolve_department_name(
employee=employee,
@@ -221,7 +250,7 @@ class ExpenseClaimService:
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="补充",
approval_stage="提交",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
@@ -251,7 +280,7 @@ class ExpenseClaimService:
claim.invoice_count = final_attachment_count
claim.occurred_at = final_occurred_at
claim.status = "draft"
claim.approval_stage = "补充"
claim.approval_stage = "提交"
claim.risk_flags_json = final_risk_flags
self.db.flush()
@@ -355,12 +384,77 @@ class ExpenseClaimService:
)
return f"{prefix}{existing + 1:03d}"
def _count_draft_claims_for_owner(
self,
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> int:
owner_filters = self._build_draft_owner_filters(
employee=employee,
employee_name=employee_name,
user_id=user_id,
)
if not owner_filters:
return 0
stmt = (
select(func.count())
.select_from(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
.where(or_(*owner_filters))
)
return int(self.db.scalar(stmt) or 0)
@staticmethod
def _build_draft_owner_filters(
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> list[Any]:
conditions: list[Any] = []
seen: set[tuple[str, str]] = set()
def add_condition(field_name: str, value: str | None) -> None:
normalized = str(value or "").strip()
if not normalized or normalized == "待补充":
return
marker = (field_name, normalized.lower())
if marker in seen:
return
seen.add(marker)
if field_name == "employee_id":
conditions.append(ExpenseClaim.employee_id == normalized)
return
conditions.append(ExpenseClaim.employee_name == normalized)
if employee is not None:
add_condition("employee_id", employee.id)
add_condition("employee_name", employee.name)
add_condition("employee_name", employee.email)
add_condition("employee_name", employee_name)
add_condition("employee_name", user_id)
return conditions
def _resolve_employee(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
) -> Employee | None:
normalized_user_id = str(user_id or "").strip()
if normalized_user_id:
stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
employee = self.db.scalar(stmt)
if employee is not None:
return employee
employee_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,