refactor(backend): update services and tests
- services/expense_claims.py: update expense claims service - services/user_agent.py: update user agent service - tests/test_orchestrator_service.py: update orchestrator service tests - tests/test_user_agent_service.py: update user agent service tests
This commit is contained in:
@@ -25,6 +25,7 @@ EXPENSE_TYPE_LABELS = {
|
||||
}
|
||||
|
||||
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
|
||||
MAX_DRAFT_CLAIMS_PER_USER = 3
|
||||
|
||||
|
||||
class ExpenseClaimService:
|
||||
@@ -117,7 +118,7 @@ class ExpenseClaimService:
|
||||
|
||||
before_json = self._serialize_claim(claim)
|
||||
claim.status = "submitted"
|
||||
claim.approval_stage = "审批流转"
|
||||
claim.approval_stage = "AI验审"
|
||||
claim.submitted_at = datetime.now(UTC)
|
||||
|
||||
self.db.commit()
|
||||
@@ -172,7 +173,39 @@ class ExpenseClaimService:
|
||||
is_new_claim = claim is None
|
||||
before_json = self._serialize_claim(claim) if claim is not None else None
|
||||
|
||||
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
|
||||
employee = self._resolve_employee(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=user_id,
|
||||
)
|
||||
draft_owner_name = (
|
||||
employee.name
|
||||
if employee is not None
|
||||
else self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
if is_new_claim:
|
||||
existing_draft_count = self._count_draft_claims_for_owner(
|
||||
employee=employee,
|
||||
employee_name=draft_owner_name,
|
||||
user_id=user_id,
|
||||
)
|
||||
if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
|
||||
return {
|
||||
"message": (
|
||||
f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
|
||||
"才能再次新建草稿。"
|
||||
),
|
||||
"draft_limit_reached": True,
|
||||
"draft_only": False,
|
||||
"status": "blocked",
|
||||
"draft_count": existing_draft_count,
|
||||
"max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
|
||||
}
|
||||
|
||||
amount = self._resolve_amount(ontology.entities, context_json=context_json)
|
||||
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
|
||||
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
|
||||
@@ -202,11 +235,7 @@ class ExpenseClaimService:
|
||||
claim = ExpenseClaim(
|
||||
claim_no=self._generate_claim_no(final_occurred_at),
|
||||
employee_id=employee.id if employee is not None else None,
|
||||
employee_name=employee.name if employee is not None else self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
user_id=user_id,
|
||||
),
|
||||
employee_name=draft_owner_name,
|
||||
department_id=employee.organization_unit_id if employee is not None else None,
|
||||
department_name=self._resolve_department_name(
|
||||
employee=employee,
|
||||
@@ -221,7 +250,7 @@ class ExpenseClaimService:
|
||||
invoice_count=final_attachment_count,
|
||||
occurred_at=final_occurred_at,
|
||||
status="draft",
|
||||
approval_stage="待补充",
|
||||
approval_stage="待提交",
|
||||
risk_flags_json=final_risk_flags,
|
||||
)
|
||||
self.db.add(claim)
|
||||
@@ -251,7 +280,7 @@ class ExpenseClaimService:
|
||||
claim.invoice_count = final_attachment_count
|
||||
claim.occurred_at = final_occurred_at
|
||||
claim.status = "draft"
|
||||
claim.approval_stage = "待补充"
|
||||
claim.approval_stage = "待提交"
|
||||
claim.risk_flags_json = final_risk_flags
|
||||
|
||||
self.db.flush()
|
||||
@@ -355,12 +384,77 @@ class ExpenseClaimService:
|
||||
)
|
||||
return f"{prefix}{existing + 1:03d}"
|
||||
|
||||
def _count_draft_claims_for_owner(
|
||||
self,
|
||||
*,
|
||||
employee: Employee | None,
|
||||
employee_name: str,
|
||||
user_id: str | None,
|
||||
) -> int:
|
||||
owner_filters = self._build_draft_owner_filters(
|
||||
employee=employee,
|
||||
employee_name=employee_name,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not owner_filters:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(ExpenseClaim)
|
||||
.where(ExpenseClaim.status == "draft")
|
||||
.where(or_(*owner_filters))
|
||||
)
|
||||
return int(self.db.scalar(stmt) or 0)
|
||||
|
||||
@staticmethod
|
||||
def _build_draft_owner_filters(
|
||||
*,
|
||||
employee: Employee | None,
|
||||
employee_name: str,
|
||||
user_id: str | None,
|
||||
) -> list[Any]:
|
||||
conditions: list[Any] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
|
||||
def add_condition(field_name: str, value: str | None) -> None:
|
||||
normalized = str(value or "").strip()
|
||||
if not normalized or normalized == "待补充":
|
||||
return
|
||||
|
||||
marker = (field_name, normalized.lower())
|
||||
if marker in seen:
|
||||
return
|
||||
seen.add(marker)
|
||||
|
||||
if field_name == "employee_id":
|
||||
conditions.append(ExpenseClaim.employee_id == normalized)
|
||||
return
|
||||
conditions.append(ExpenseClaim.employee_name == normalized)
|
||||
|
||||
if employee is not None:
|
||||
add_condition("employee_id", employee.id)
|
||||
add_condition("employee_name", employee.name)
|
||||
add_condition("employee_name", employee.email)
|
||||
|
||||
add_condition("employee_name", employee_name)
|
||||
add_condition("employee_name", user_id)
|
||||
return conditions
|
||||
|
||||
def _resolve_employee(
|
||||
self,
|
||||
*,
|
||||
ontology: OntologyParseResult,
|
||||
context_json: dict[str, Any],
|
||||
user_id: str | None,
|
||||
) -> Employee | None:
|
||||
normalized_user_id = str(user_id or "").strip()
|
||||
if normalized_user_id:
|
||||
stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
|
||||
employee = self.db.scalar(stmt)
|
||||
if employee is not None:
|
||||
return employee
|
||||
|
||||
employee_name = self._resolve_employee_name(
|
||||
ontology=ontology,
|
||||
context_json=context_json,
|
||||
|
||||
@@ -207,6 +207,8 @@ class UserAgentService:
|
||||
|
||||
if payload.ontology.intent == "draft":
|
||||
tool_message = str(payload.tool_payload.get("message") or "").strip()
|
||||
if payload.tool_payload.get("draft_limit_reached"):
|
||||
return tool_message or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||
if tool_message and (
|
||||
str(payload.tool_payload.get("claim_id") or "").strip()
|
||||
or str(payload.tool_payload.get("claim_no") or "").strip()
|
||||
@@ -988,6 +990,11 @@ class UserAgentService:
|
||||
return None
|
||||
if payload.ontology.intent not in {"draft", "operate"}:
|
||||
return None
|
||||
if payload.tool_payload.get("draft_limit_reached"):
|
||||
return (
|
||||
str(payload.tool_payload.get("message") or "").strip()
|
||||
or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||
)
|
||||
|
||||
review_action = str(payload.context_json.get("review_action") or "").strip()
|
||||
if review_action == "save_draft":
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
|
||||
from app.models.employee import Employee
|
||||
from app.models.financial_record import ExpenseClaim
|
||||
from app.schemas.settings import SettingsWrite
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
@@ -183,6 +184,153 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
||||
assert claim.items
|
||||
|
||||
|
||||
def test_orchestrator_blocks_fourth_expense_draft_for_same_user() -> None:
|
||||
client, session_factory = build_client()
|
||||
user_id = "zhangsan@example.com"
|
||||
|
||||
with session_factory() as db:
|
||||
db.add(
|
||||
Employee(
|
||||
employee_no="E1001",
|
||||
name="张三",
|
||||
email=user_id,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
for amount, city in ((120, "上海"), (240, "北京"), (360, "深圳")):
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": user_id,
|
||||
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"name": "张三",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
|
||||
|
||||
blocked_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": user_id,
|
||||
"message": "帮我生成报销草稿,我昨天去杭州出差,交通费480元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"name": "张三",
|
||||
"review_action": "save_draft",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert blocked_response.status_code == 200
|
||||
blocked_payload = blocked_response.json()
|
||||
assert blocked_payload["status"] == "succeeded"
|
||||
assert "你当前已保存 3 个草稿" in blocked_payload["result"]["answer"]
|
||||
assert blocked_payload["result"]["draft_payload"]["claim_id"] is None
|
||||
assert blocked_payload["result"]["draft_payload"]["claim_no"] is None
|
||||
assert blocked_payload["result"]["draft_payload"]["status"] == "blocked"
|
||||
|
||||
with session_factory() as db:
|
||||
draft_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(ExpenseClaim)
|
||||
.where(ExpenseClaim.status == "draft")
|
||||
)
|
||||
assert draft_count == 3
|
||||
|
||||
|
||||
def test_orchestrator_allows_existing_draft_update_when_user_already_has_three_drafts() -> None:
|
||||
client, session_factory = build_client()
|
||||
user_id = "lisi@example.com"
|
||||
|
||||
with session_factory() as db:
|
||||
db.add(
|
||||
Employee(
|
||||
employee_no="E1002",
|
||||
name="李四",
|
||||
email=user_id,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
first_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": user_id,
|
||||
"message": "帮我生成报销草稿,我昨天去上海出差,交通费120元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"name": "李四",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert first_response.status_code == 200
|
||||
first_payload = first_response.json()
|
||||
claim_id = first_payload["result"]["draft_payload"]["claim_id"]
|
||||
conversation_id = first_payload["conversation_id"]
|
||||
assert claim_id
|
||||
assert conversation_id
|
||||
|
||||
for amount, city in ((240, "北京"), (360, "深圳")):
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": user_id,
|
||||
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"name": "李四",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
|
||||
|
||||
update_response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"message": "金额改成888元",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"name": "李四",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert update_response.status_code == 200
|
||||
update_payload = update_response.json()
|
||||
assert update_payload["result"]["draft_payload"]["claim_id"] == claim_id
|
||||
|
||||
with session_factory() as db:
|
||||
claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
|
||||
assert claim is not None
|
||||
assert float(claim.amount) == 888.0
|
||||
|
||||
draft_count = db.scalar(
|
||||
select(func.count())
|
||||
.select_from(ExpenseClaim)
|
||||
.where(ExpenseClaim.employee_id == claim.employee_id)
|
||||
.where(ExpenseClaim.status == "draft")
|
||||
)
|
||||
assert draft_count == 3
|
||||
|
||||
|
||||
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
|
||||
@@ -246,6 +246,36 @@ def test_user_agent_draft_returns_structured_payload() -> None:
|
||||
assert response.answer == response.review_payload.body_message
|
||||
|
||||
|
||||
def test_user_agent_returns_draft_limit_message_when_save_is_blocked() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="请按当前识别信息保存报销草稿",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="请按当前识别信息保存报销草稿",
|
||||
ontology=ontology,
|
||||
context_json={"review_action": "save_draft"},
|
||||
tool_payload={
|
||||
"draft_limit_reached": True,
|
||||
"message": "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。",
|
||||
"status": "blocked",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
response.answer
|
||||
== "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||
)
|
||||
|
||||
|
||||
def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
|
||||
Reference in New Issue
Block a user