refactor(backend): update services and tests
- services/expense_claims.py: update expense claims service - services/user_agent.py: update user agent service - tests/test_orchestrator_service.py: update orchestrator service tests - tests/test_user_agent_service.py: update user agent service tests
This commit is contained in:
@@ -25,6 +25,7 @@ EXPENSE_TYPE_LABELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
|
PRIVILEGED_CLAIM_ROLE_CODES = {"manager", "finance", "approver", "auditor", "executive"}
|
||||||
|
MAX_DRAFT_CLAIMS_PER_USER = 3
|
||||||
|
|
||||||
|
|
||||||
class ExpenseClaimService:
|
class ExpenseClaimService:
|
||||||
@@ -117,7 +118,7 @@ class ExpenseClaimService:
|
|||||||
|
|
||||||
before_json = self._serialize_claim(claim)
|
before_json = self._serialize_claim(claim)
|
||||||
claim.status = "submitted"
|
claim.status = "submitted"
|
||||||
claim.approval_stage = "审批流转"
|
claim.approval_stage = "AI验审"
|
||||||
claim.submitted_at = datetime.now(UTC)
|
claim.submitted_at = datetime.now(UTC)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
@@ -172,7 +173,39 @@ class ExpenseClaimService:
|
|||||||
is_new_claim = claim is None
|
is_new_claim = claim is None
|
||||||
before_json = self._serialize_claim(claim) if claim is not None else None
|
before_json = self._serialize_claim(claim) if claim is not None else None
|
||||||
|
|
||||||
employee = self._resolve_employee(ontology=ontology, context_json=context_json)
|
employee = self._resolve_employee(
|
||||||
|
ontology=ontology,
|
||||||
|
context_json=context_json,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
draft_owner_name = (
|
||||||
|
employee.name
|
||||||
|
if employee is not None
|
||||||
|
else self._resolve_employee_name(
|
||||||
|
ontology=ontology,
|
||||||
|
context_json=context_json,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_new_claim:
|
||||||
|
existing_draft_count = self._count_draft_claims_for_owner(
|
||||||
|
employee=employee,
|
||||||
|
employee_name=draft_owner_name,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if existing_draft_count >= MAX_DRAFT_CLAIMS_PER_USER:
|
||||||
|
return {
|
||||||
|
"message": (
|
||||||
|
f"你当前已保存 {MAX_DRAFT_CLAIMS_PER_USER} 个草稿,请先完成已保存的草稿,"
|
||||||
|
"才能再次新建草稿。"
|
||||||
|
),
|
||||||
|
"draft_limit_reached": True,
|
||||||
|
"draft_only": False,
|
||||||
|
"status": "blocked",
|
||||||
|
"draft_count": existing_draft_count,
|
||||||
|
"max_draft_count": MAX_DRAFT_CLAIMS_PER_USER,
|
||||||
|
}
|
||||||
|
|
||||||
amount = self._resolve_amount(ontology.entities, context_json=context_json)
|
amount = self._resolve_amount(ontology.entities, context_json=context_json)
|
||||||
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
|
occurred_at = self._resolve_occurred_at(ontology, context_json=context_json)
|
||||||
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
|
expense_type = self._resolve_expense_type(ontology.entities, context_json=context_json)
|
||||||
@@ -202,11 +235,7 @@ class ExpenseClaimService:
|
|||||||
claim = ExpenseClaim(
|
claim = ExpenseClaim(
|
||||||
claim_no=self._generate_claim_no(final_occurred_at),
|
claim_no=self._generate_claim_no(final_occurred_at),
|
||||||
employee_id=employee.id if employee is not None else None,
|
employee_id=employee.id if employee is not None else None,
|
||||||
employee_name=employee.name if employee is not None else self._resolve_employee_name(
|
employee_name=draft_owner_name,
|
||||||
ontology=ontology,
|
|
||||||
context_json=context_json,
|
|
||||||
user_id=user_id,
|
|
||||||
),
|
|
||||||
department_id=employee.organization_unit_id if employee is not None else None,
|
department_id=employee.organization_unit_id if employee is not None else None,
|
||||||
department_name=self._resolve_department_name(
|
department_name=self._resolve_department_name(
|
||||||
employee=employee,
|
employee=employee,
|
||||||
@@ -221,7 +250,7 @@ class ExpenseClaimService:
|
|||||||
invoice_count=final_attachment_count,
|
invoice_count=final_attachment_count,
|
||||||
occurred_at=final_occurred_at,
|
occurred_at=final_occurred_at,
|
||||||
status="draft",
|
status="draft",
|
||||||
approval_stage="待补充",
|
approval_stage="待提交",
|
||||||
risk_flags_json=final_risk_flags,
|
risk_flags_json=final_risk_flags,
|
||||||
)
|
)
|
||||||
self.db.add(claim)
|
self.db.add(claim)
|
||||||
@@ -251,7 +280,7 @@ class ExpenseClaimService:
|
|||||||
claim.invoice_count = final_attachment_count
|
claim.invoice_count = final_attachment_count
|
||||||
claim.occurred_at = final_occurred_at
|
claim.occurred_at = final_occurred_at
|
||||||
claim.status = "draft"
|
claim.status = "draft"
|
||||||
claim.approval_stage = "待补充"
|
claim.approval_stage = "待提交"
|
||||||
claim.risk_flags_json = final_risk_flags
|
claim.risk_flags_json = final_risk_flags
|
||||||
|
|
||||||
self.db.flush()
|
self.db.flush()
|
||||||
@@ -355,12 +384,77 @@ class ExpenseClaimService:
|
|||||||
)
|
)
|
||||||
return f"{prefix}{existing + 1:03d}"
|
return f"{prefix}{existing + 1:03d}"
|
||||||
|
|
||||||
|
def _count_draft_claims_for_owner(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
employee: Employee | None,
|
||||||
|
employee_name: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> int:
|
||||||
|
owner_filters = self._build_draft_owner_filters(
|
||||||
|
employee=employee,
|
||||||
|
employee_name=employee_name,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if not owner_filters:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.select_from(ExpenseClaim)
|
||||||
|
.where(ExpenseClaim.status == "draft")
|
||||||
|
.where(or_(*owner_filters))
|
||||||
|
)
|
||||||
|
return int(self.db.scalar(stmt) or 0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_draft_owner_filters(
|
||||||
|
*,
|
||||||
|
employee: Employee | None,
|
||||||
|
employee_name: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> list[Any]:
|
||||||
|
conditions: list[Any] = []
|
||||||
|
seen: set[tuple[str, str]] = set()
|
||||||
|
|
||||||
|
def add_condition(field_name: str, value: str | None) -> None:
|
||||||
|
normalized = str(value or "").strip()
|
||||||
|
if not normalized or normalized == "待补充":
|
||||||
|
return
|
||||||
|
|
||||||
|
marker = (field_name, normalized.lower())
|
||||||
|
if marker in seen:
|
||||||
|
return
|
||||||
|
seen.add(marker)
|
||||||
|
|
||||||
|
if field_name == "employee_id":
|
||||||
|
conditions.append(ExpenseClaim.employee_id == normalized)
|
||||||
|
return
|
||||||
|
conditions.append(ExpenseClaim.employee_name == normalized)
|
||||||
|
|
||||||
|
if employee is not None:
|
||||||
|
add_condition("employee_id", employee.id)
|
||||||
|
add_condition("employee_name", employee.name)
|
||||||
|
add_condition("employee_name", employee.email)
|
||||||
|
|
||||||
|
add_condition("employee_name", employee_name)
|
||||||
|
add_condition("employee_name", user_id)
|
||||||
|
return conditions
|
||||||
|
|
||||||
def _resolve_employee(
|
def _resolve_employee(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ontology: OntologyParseResult,
|
ontology: OntologyParseResult,
|
||||||
context_json: dict[str, Any],
|
context_json: dict[str, Any],
|
||||||
|
user_id: str | None,
|
||||||
) -> Employee | None:
|
) -> Employee | None:
|
||||||
|
normalized_user_id = str(user_id or "").strip()
|
||||||
|
if normalized_user_id:
|
||||||
|
stmt = select(Employee).where(func.lower(Employee.email) == normalized_user_id.lower()).limit(1)
|
||||||
|
employee = self.db.scalar(stmt)
|
||||||
|
if employee is not None:
|
||||||
|
return employee
|
||||||
|
|
||||||
employee_name = self._resolve_employee_name(
|
employee_name = self._resolve_employee_name(
|
||||||
ontology=ontology,
|
ontology=ontology,
|
||||||
context_json=context_json,
|
context_json=context_json,
|
||||||
|
|||||||
@@ -207,6 +207,8 @@ class UserAgentService:
|
|||||||
|
|
||||||
if payload.ontology.intent == "draft":
|
if payload.ontology.intent == "draft":
|
||||||
tool_message = str(payload.tool_payload.get("message") or "").strip()
|
tool_message = str(payload.tool_payload.get("message") or "").strip()
|
||||||
|
if payload.tool_payload.get("draft_limit_reached"):
|
||||||
|
return tool_message or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||||
if tool_message and (
|
if tool_message and (
|
||||||
str(payload.tool_payload.get("claim_id") or "").strip()
|
str(payload.tool_payload.get("claim_id") or "").strip()
|
||||||
or str(payload.tool_payload.get("claim_no") or "").strip()
|
or str(payload.tool_payload.get("claim_no") or "").strip()
|
||||||
@@ -988,6 +990,11 @@ class UserAgentService:
|
|||||||
return None
|
return None
|
||||||
if payload.ontology.intent not in {"draft", "operate"}:
|
if payload.ontology.intent not in {"draft", "operate"}:
|
||||||
return None
|
return None
|
||||||
|
if payload.tool_payload.get("draft_limit_reached"):
|
||||||
|
return (
|
||||||
|
str(payload.tool_payload.get("message") or "").strip()
|
||||||
|
or "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||||
|
)
|
||||||
|
|
||||||
review_action = str(payload.context_json.get("review_action") or "").strip()
|
review_action = str(payload.context_json.get("review_action") or "").strip()
|
||||||
if review_action == "save_draft":
|
if review_action == "save_draft":
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from app.api.deps import get_db
|
|||||||
from app.db.base import Base
|
from app.db.base import Base
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
|
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
|
||||||
|
from app.models.employee import Employee
|
||||||
from app.models.financial_record import ExpenseClaim
|
from app.models.financial_record import ExpenseClaim
|
||||||
from app.schemas.settings import SettingsWrite
|
from app.schemas.settings import SettingsWrite
|
||||||
from app.services.agent_assets import AgentAssetService
|
from app.services.agent_assets import AgentAssetService
|
||||||
@@ -183,6 +184,153 @@ def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
|||||||
assert claim.items
|
assert claim.items
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_blocks_fourth_expense_draft_for_same_user() -> None:
|
||||||
|
client, session_factory = build_client()
|
||||||
|
user_id = "zhangsan@example.com"
|
||||||
|
|
||||||
|
with session_factory() as db:
|
||||||
|
db.add(
|
||||||
|
Employee(
|
||||||
|
employee_no="E1001",
|
||||||
|
name="张三",
|
||||||
|
email=user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
for amount, city in ((120, "上海"), (240, "北京"), (360, "深圳")):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/orchestrator/run",
|
||||||
|
json={
|
||||||
|
"source": "user_message",
|
||||||
|
"user_id": user_id,
|
||||||
|
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元",
|
||||||
|
"context_json": {
|
||||||
|
"role_codes": ["finance"],
|
||||||
|
"name": "张三",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
|
||||||
|
|
||||||
|
blocked_response = client.post(
|
||||||
|
"/api/v1/orchestrator/run",
|
||||||
|
json={
|
||||||
|
"source": "user_message",
|
||||||
|
"user_id": user_id,
|
||||||
|
"message": "帮我生成报销草稿,我昨天去杭州出差,交通费480元",
|
||||||
|
"context_json": {
|
||||||
|
"role_codes": ["finance"],
|
||||||
|
"name": "张三",
|
||||||
|
"review_action": "save_draft",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert blocked_response.status_code == 200
|
||||||
|
blocked_payload = blocked_response.json()
|
||||||
|
assert blocked_payload["status"] == "succeeded"
|
||||||
|
assert "你当前已保存 3 个草稿" in blocked_payload["result"]["answer"]
|
||||||
|
assert blocked_payload["result"]["draft_payload"]["claim_id"] is None
|
||||||
|
assert blocked_payload["result"]["draft_payload"]["claim_no"] is None
|
||||||
|
assert blocked_payload["result"]["draft_payload"]["status"] == "blocked"
|
||||||
|
|
||||||
|
with session_factory() as db:
|
||||||
|
draft_count = db.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(ExpenseClaim)
|
||||||
|
.where(ExpenseClaim.status == "draft")
|
||||||
|
)
|
||||||
|
assert draft_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_allows_existing_draft_update_when_user_already_has_three_drafts() -> None:
|
||||||
|
client, session_factory = build_client()
|
||||||
|
user_id = "lisi@example.com"
|
||||||
|
|
||||||
|
with session_factory() as db:
|
||||||
|
db.add(
|
||||||
|
Employee(
|
||||||
|
employee_no="E1002",
|
||||||
|
name="李四",
|
||||||
|
email=user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
first_response = client.post(
|
||||||
|
"/api/v1/orchestrator/run",
|
||||||
|
json={
|
||||||
|
"source": "user_message",
|
||||||
|
"user_id": user_id,
|
||||||
|
"message": "帮我生成报销草稿,我昨天去上海出差,交通费120元",
|
||||||
|
"context_json": {
|
||||||
|
"role_codes": ["finance"],
|
||||||
|
"name": "李四",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert first_response.status_code == 200
|
||||||
|
first_payload = first_response.json()
|
||||||
|
claim_id = first_payload["result"]["draft_payload"]["claim_id"]
|
||||||
|
conversation_id = first_payload["conversation_id"]
|
||||||
|
assert claim_id
|
||||||
|
assert conversation_id
|
||||||
|
|
||||||
|
for amount, city in ((240, "北京"), (360, "深圳")):
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/orchestrator/run",
|
||||||
|
json={
|
||||||
|
"source": "user_message",
|
||||||
|
"user_id": user_id,
|
||||||
|
"message": f"帮我生成报销草稿,我昨天去{city}出差,交通费{amount}元",
|
||||||
|
"context_json": {
|
||||||
|
"role_codes": ["finance"],
|
||||||
|
"name": "李四",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["result"]["draft_payload"]["claim_no"].startswith("EXP-")
|
||||||
|
|
||||||
|
update_response = client.post(
|
||||||
|
"/api/v1/orchestrator/run",
|
||||||
|
json={
|
||||||
|
"source": "user_message",
|
||||||
|
"user_id": user_id,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"message": "金额改成888元",
|
||||||
|
"context_json": {
|
||||||
|
"role_codes": ["finance"],
|
||||||
|
"name": "李四",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_response.status_code == 200
|
||||||
|
update_payload = update_response.json()
|
||||||
|
assert update_payload["result"]["draft_payload"]["claim_id"] == claim_id
|
||||||
|
|
||||||
|
with session_factory() as db:
|
||||||
|
claim = db.scalar(select(ExpenseClaim).where(ExpenseClaim.id == claim_id))
|
||||||
|
assert claim is not None
|
||||||
|
assert float(claim.amount) == 888.0
|
||||||
|
|
||||||
|
draft_count = db.scalar(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(ExpenseClaim)
|
||||||
|
.where(ExpenseClaim.employee_id == claim.employee_id)
|
||||||
|
.where(ExpenseClaim.status == "draft")
|
||||||
|
)
|
||||||
|
assert draft_count == 3
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
|
def test_orchestrator_persists_conversation_and_reuses_expense_draft_context() -> None:
|
||||||
client, session_factory = build_client()
|
client, session_factory = build_client()
|
||||||
|
|
||||||
|
|||||||
@@ -246,6 +246,36 @@ def test_user_agent_draft_returns_structured_payload() -> None:
|
|||||||
assert response.answer == response.review_payload.body_message
|
assert response.answer == response.review_payload.body_message
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_returns_draft_limit_message_when_save_is_blocked() -> None:
|
||||||
|
session_factory = build_session_factory()
|
||||||
|
with session_factory() as db:
|
||||||
|
ontology = SemanticOntologyService(db).parse(
|
||||||
|
OntologyParseRequest(
|
||||||
|
query="请按当前识别信息保存报销草稿",
|
||||||
|
user_id="pytest",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response = UserAgentService(db).respond(
|
||||||
|
UserAgentRequest(
|
||||||
|
run_id=ontology.run_id,
|
||||||
|
user_id="pytest",
|
||||||
|
message="请按当前识别信息保存报销草稿",
|
||||||
|
ontology=ontology,
|
||||||
|
context_json={"review_action": "save_draft"},
|
||||||
|
tool_payload={
|
||||||
|
"draft_limit_reached": True,
|
||||||
|
"message": "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。",
|
||||||
|
"status": "blocked",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
response.answer
|
||||||
|
== "你当前已保存 3 个草稿,请先完成已保存的草稿,才能再次新建草稿。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None:
|
def test_user_agent_builds_review_payload_for_multi_document_expense_flow() -> None:
|
||||||
session_factory = build_session_factory()
|
session_factory = build_session_factory()
|
||||||
with session_factory() as db:
|
with session_factory() as db:
|
||||||
|
|||||||
Reference in New Issue
Block a user