refactor(backend): update services and tests

- services/expense_claims.py: update expense claims service
- services/user_agent.py: update user agent service
- tests/test_orchestrator_service.py: update orchestrator service tests
- tests/test_user_agent_service.py: update user agent service tests
This commit is contained in:
caoxiaozhu
2026-05-13 03:39:41 +00:00
parent fae9966a11
commit 4db5e8ec16
4 changed files with 288 additions and 9 deletions

View File

@@ -25,6 +25,7 @@ EXPENSE_TYPE_LABELS = {
}
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
MAX_DRAFT_CLAIMS_PER_USER = 3
class ExpenseClaimService:
@@ -117,7 +118,7 @@ class ExpenseClaimService:
before_json = self._serialize_claim(claim)
claim.status = "submitted"
claim.approval_stage = "批流转"
claim.approval_stage = "AI验"
claim.submitted_at = datetime.now(UTC)
self.db.commit()
@@ -172,7 +173,39 @@ class ExpenseClaimService:
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
employee = self._resolve_employee(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
draft_owner_name = (
employee.name
if employee is not None
else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
)
)
if is_new_claim:
existing_draft_count = self._count_draft_claims_for_owner(
employee=employee,
employee_name=draft_owner_name,
user_id=user_id,
)
if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
return {
"message": (
f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
"才能再次新建草稿。"
),
"draft_limit_reached": True,
"draft_only": False,
"status": "blocked",
"draft_count": existing_draft_count,
"max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
}
amount = self._resolve_amount(ontology.entities, context_json=context_json)
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
@@ -202,11 +235,7 @@ class ExpenseClaimService:
claim = ExpenseClaim(
claim_no=self._generate_claim_no(final_occurred_at),
employee_id=employee.id if employee is not None else None,
employee_name=employee.name if employee is not None else self._resolve_employee_name(
ontology=ontology,
context_json=context_json,
user_id=user_id,
),
employee_name=draft_owner_name,
department_id=employee.organization_unit_id if employee is not None else None,
department_name=self._resolve_department_name(
employee=employee,
@@ -221,7 +250,7 @@ class ExpenseClaimService:
invoice_count=final_attachment_count,
occurred_at=final_occurred_at,
status="draft",
approval_stage="补充",
approval_stage="提交",
risk_flags_json=final_risk_flags,
)
self.db.add(claim)
@@ -251,7 +280,7 @@ class ExpenseClaimService:
claim.invoice_count = final_attachment_count
claim.occurred_at = final_occurred_at
claim.status = "draft"
claim.approval_stage = "补充"
claim.approval_stage = "提交"
claim.risk_flags_json = final_risk_flags
self.db.flush()
@@ -355,12 +384,77 @@ class ExpenseClaimService:
)
return f"{prefix}{existing + 1:03d}"
def _count_draft_claims_for_owner(
self,
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> int:
owner_filters = self._build_draft_owner_filters(
employee=employee,
employee_name=employee_name,
user_id=user_id,
)
if not owner_filters:
return 0
stmt = (
select(func.count())
.select_from(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
.where(or_(*owner_filters))
)
return int(self.db.scalar(stmt) or 0)
@staticmethod
def _build_draft_owner_filters(
*,
employee: Employee | None,
employee_name: str,
user_id: str | None,
) -> list[Any]:
conditions: list[Any] = []
seen: set[tuple[str, str]] = set()
def add_condition(field_name: str, value: str | None) -> None:
normalized = str(value or "").strip()
if not normalized or normalized == "待补充":
return
marker = (field_name, normalized.lower())
if marker in seen:
return
seen.add(marker)
if field_name == "employee_id":
conditions.append(ExpenseClaim.employee_id == normalized)
return
conditions.append(ExpenseClaim.employee_name == normalized)
if employee is not None:
add_condition("employee_id", employee.id)
add_condition("employee_name", employee.name)
add_condition("employee_name", employee.email)
add_condition("employee_name", employee_name)
add_condition("employee_name", user_id)
return conditions
def _resolve_employee(
self,
*,
ontology: OntologyParseResult,
context_json: dict[str, Any],
user_id: str | None,
) -> Employee | None:
normalized_user_id = str(user_id or "").strip()
if normalized_user_id:
stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
employee = self.db.scalar(stmt)
if employee is not None:
return employee
employee_name = self._resolve_employee_name(
ontology=ontology,
context_json=context_json,

View File

@@ -207,6 +207,8 @@ class UserAgentService:
if payload.ontology.intent == "draft":
tool_message = str(payload.tool_payload.get("message") or "").strip()
if payload.tool_payload.get("draft_limit_reached"):
return tool_message or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
if tool_message and (
str(payload.tool_payload.get("claim_id") or "").strip()
or str(payload.tool_payload.get("claim_no") or "").strip()
@@ -988,6 +990,11 @@ class UserAgentService:
return None
if payload.ontology.intent not in {"draft", "operate"}:
return None
if payload.tool_payload.get("draft_limit_reached"):
return (
str(payload.tool_payload.get("message") or "").strip()
or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
)
review_action = str(payload.context_json.get("review_action") or "").strip()
if review_action == "save_draft":

View File

@@ -12,6 +12,7 @@ from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.schemas.settings import SettingsWrite
from app.services.agent_assets import AgentAssetService
@@ -183,6 +184,153 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
assert claim.items
def test_orchestrator_blocks_fourth_expense_draft_for_same_user() -> None:
client, session_factory = build_client()
user_id = "zhangsan@example.com"
with session_factory() as db:
db.add(
Employee(
employee_no="E1001",
name="张三",
email=user_id,
)
)
db.commit()
for amount, city in ((120, "上海"), (240, "北京"), (360, "深圳")):
response = client.post(
"/api/v1/orchestrator/run",
json={
"source": "user_message",
"user_id": user_id,
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}",
"context_json": {
"role_codes": ["finance"],
"name": "张三",
},
},
)
assert response.status_code == 200
payload = response.json()
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
blocked_response = client.post(
"/api/v1/orchestrator/run",
json={
"source": "user_message",
"user_id": user_id,
"message": "帮我生成报销草稿我昨天去杭州出差交通费480元",
"context_json": {
"role_codes": ["finance"],
"name": "张三",
"review_action": "save_draft",
},
},
)
assert blocked_response.status_code == 200
blocked_payload = blocked_response.json()
assert blocked_payload["status"] == "succeeded"
assert "你当前已保存 3 个草稿" in blocked_payload["result"]["answer"]
assert blocked_payload["result"]["draft_payload"]["claim_id"] is None
assert blocked_payload["result"]["draft_payload"]["claim_no"] is None
assert blocked_payload["result"]["draft_payload"]["status"] == "blocked"
with session_factory() as db:
draft_count = db.scalar(
select(func.count())
.select_from(ExpenseClaim)
.where(ExpenseClaim.status == "draft")
)
assert draft_count == 3
def test_orchestrator_allows_existing_draft_update_when_user_already_has_three_drafts() -> None:
client, session_factory = build_client()
user_id = "lisi@example.com"
with session_factory() as db:
db.add(
Employee(
employee_no="E1002",
name="李四",
email=user_id,
)
)
db.commit()
first_response = client.post(
"/api/v1/orchestrator/run",
json={
"source": "user_message",
"user_id": user_id,
"message": "帮我生成报销草稿我昨天去上海出差交通费120元",
"context_json": {
"role_codes": ["finance"],
"name": "李四",
},
},
)
assert first_response.status_code == 200
first_payload = first_response.json()
claim_id = first_payload["result"]["draft_payload"]["claim_id"]
conversation_id = first_payload["conversation_id"]
assert claim_id
assert conversation_id
for amount, city in ((240, "北京"), (360, "深圳")):
response = client.post(
"/api/v1/orchestrator/run",
json={
"source": "user_message",
"user_id": user_id,
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}",
"context_json": {
"role_codes": ["finance"],
"name": "李四",
},
},
)
assert response.status_code == 200
payload = response.json()
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
update_response = client.post(
"/api/v1/orchestrator/run",
json={
"source": "user_message",
"user_id": user_id,
"conversation_id": conversation_id,
"message": "金额改成888元",
"context_json": {
"role_codes": ["finance"],
"name": "李四",
},
},
)
assert update_response.status_code == 200
update_payload = update_response.json()
assert update_payload["result"]["draft_payload"]["claim_id"] == claim_id
with session_factory() as db:
claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
assert claim is not None
assert float(claim.amount) == 888.0
draft_count = db.scalar(
select(func.count())
.select_from(ExpenseClaim)
.where(ExpenseClaim.employee_id == claim.employee_id)
.where(ExpenseClaim.status == "draft")
)
assert draft_count == 3
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
client, session_factory = build_client()

View File

@@ -246,6 +246,36 @@ def test_user_agent_draft_returns_structured_payload() -> None:
assert response.answer == response.review_payload.body_message
def test_user_agent_returns_draft_limit_message_when_save_is_blocked() -> None:
session_factory = build_session_factory()
with session_factory() as db:
ontology = SemanticOntologyService(db).parse(
OntologyParseRequest(
query="请按当前识别信息保存报销草稿",
user_id="pytest",
)
)
response = UserAgentService(db).respond(
UserAgentRequest(
run_id=ontology.run_id,
user_id="pytest",
message="请按当前识别信息保存报销草稿",
ontology=ontology,
context_json={"review_action": "save_draft"},
tool_payload={
"draft_limit_reached": True,
"message": "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。",
"status": "blocked",
},
)
)
assert (
response.answer
== "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
)
def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None:
session_factory = build_session_factory()
with session_factory() as db: