# X-Financial/server/tests/test_risk_observations_service.py
# Tests for RiskObservationService and the /api/v1/risk-observations endpoints.

from __future__ import annotations
from collections.abc import Generator
from datetime import UTC, datetime
from decimal import Decimal
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.api.v1.endpoints.risk_observations import router as risk_observations_router
from app.db.base import Base
from app.models.employee import Employee
from app.models.financial_record import ExpenseClaim
from app.models.risk_observation import RiskObservation
from app.schemas.risk_observation import RiskObservationFeedbackCreate
from app.algorithem.risk_graph.replay import AlgorithmReplaySetBuilder
from app.services.risk_observations import RiskObservationService
def test_risk_observation_service_upserts_and_summarizes_dashboard() -> None:
    """Upsert two observations, confirm one, and check detail, dashboard and history.

    Two claims (c1, c2) are created for the same employee. Claim c1 gets a
    ``duplicate_invoice`` observation and c2 a ``preapproval_absent`` one. After a
    ``confirm`` feedback on the c1 observation, the test checks three things:
    the fields exposed on the stored observation, the dashboard aggregates, and
    the per-signal history stats.
    """
    duplicate_key = "risk:c1:duplicate_invoice"
    # NOTE(review): apart from the keys overridden below, this payload inherits
    # everything from the c1 base payload. That includes subject_key
    # "claim:c1", subject_label "BX-001", canonical_subject_key and
    # graph_node_keys. Confirm that this is intended for the c2 observation.
    preapproval_payload = {
        **_observation_payload("risk:c2:preapproval_absent"),
        "claim_id": "c2",
        "claim_no": "BX-002",
        "risk_signal": "preapproval_absent",
        "risk_type": "preapproval_absent",
        "risk_score": 72,
        "risk_level": "high",
    }

    with _build_session() as db:
        db.add(_employee_orm())
        db.add_all([_claim_orm("c1", "BX-001"), _claim_orm("c2", "BX-002")])
        db.flush()

        service = RiskObservationService(db)
        service.upsert_observation(_observation_payload(duplicate_key))
        service.upsert_observation(preapproval_payload)
        db.commit()

        feedback = service.create_feedback(
            duplicate_key,
            RiskObservationFeedbackCreate(feedback_type="confirm", actor="auditor"),
        )
        dashboard = service.summarize_dashboard(window_days=30)
        history = service.build_history_stats(risk_signals={"duplicate_invoice"})
        refreshed = service.get_observation(duplicate_key)

        assert feedback.feedback_type == "confirm"
        assert refreshed is not None

        # The payload supplies these values at top level or nested inside
        # ontology_json / decision_trace; the stored observation must expose
        # each one as an attribute.
        expected_observation_fields = {
            "status": "confirmed",
            "source": "financial_risk_graph",
            "algorithm_version": "financial_risk_graph.v1",
            "evaluation_case_id": "case-duplicate-invoice",
            "ontology_parse_id": "parse-1",
            "ontology_version": "ontology.v1",
            "domain": "expense",
            "scenario": "reimbursement",
            "intent": "risk_check",
            "ontology_entities_json": [{"type": "claim", "value": "c1"}],
            "risk_signals_json": [{"code": "duplicate_invoice"}],
            "canonical_subject_key": "claim:c1",
        }
        for field_name, expected in expected_observation_fields.items():
            assert getattr(refreshed, field_name) == expected, field_name
        assert refreshed.sampling_strategy["strategy"] == "focused_review"

        # Headline counters: both claims are 1200 CNY, so the total is 2400.
        expected_dashboard_totals = {
            "total_observations": 2,
            "high_or_above_count": 2,
            "risk_clue_count": 1,
            "confirmed_count": 1,
            "feedback_sample_count": 1,
            "total_amount": 2400.0,
        }
        for field_name, expected in expected_dashboard_totals.items():
            assert getattr(dashboard, field_name) == expected, field_name

        # Distribution maps: (bucket, expected count) per dashboard attribute.
        expected_distribution_entries = {
            "level_distribution": ("high", 2),
            "signal_distribution": ("duplicate_invoice", 1),
            "department_distribution": ("风控部", 2),
            "expense_type_distribution": ("travel", 2),
            "employee_grade_distribution": ("P6", 2),
            "supplier_distribution": ("上海差旅供应商", 2),
        }
        for field_name, (bucket, count) in expected_distribution_entries.items():
            assert getattr(dashboard, field_name)[bucket] == count, field_name

        # Ranked lists: the expected name of the first entry in each.
        expected_top_names = {
            "top_departments": "风控部",
            "top_employees": "风险员工",
            "top_suppliers": "上海差旅供应商",
            "top_expense_types": "travel",
            "top_rules": "policy.duplicate_invoice",
        }
        for field_name, expected_name in expected_top_names.items():
            assert getattr(dashboard, field_name)[0]["name"] == expected_name, field_name
        assert dashboard.top_departments[0]["amount"] == 2400.0
        # The two signals tie on count, so either may rank first.
        assert dashboard.top_risk_signals[0]["name"] in {
            "duplicate_invoice",
            "preapproval_absent",
        }
        assert dashboard.daily_trend

        assert history[0].risk_signal == "duplicate_invoice"
        assert history[0].confirmed_count == 1
def test_platform_rule_flags_are_persisted_as_risk_observations() -> None:
    """Check that one rule-center risk flag on a claim is stored as one RiskObservation row.

    The expected column values below pin how the flag's keys map onto the
    stored observation: rule_code to risk_signal, severity to risk_level,
    hit_source to source, rule_version to algorithm_version, and a fixed
    ``S_rule`` contribution.
    """
    duplicate_invoice_flag = {
        "hit_source": "rule_center",
        "rule_type": "risk",
        "rule_code": "risk.invoice.duplicate_invoice",
        "rule_version": "v1.2.0",
        "severity": "critical",
        "action": "block",
        "label": "重复发票校验",
        "message": "票据号码已在其他报销单中出现。",
        "evidence": {"invoice_no": "INV-001"},
    }

    with _build_session() as db:
        claim = _claim_orm("c-platform", "BX-PLATFORM")
        db.add(claim)
        db.flush()

        service = RiskObservationService(db)
        observations = service.upsert_platform_risk_flags(claim, [duplicate_invoice_flag])
        db.commit()
        assert len(observations) == 1

        persisted = db.query(RiskObservation).filter_by(claim_id="c-platform").one()
        expected_columns = {
            "risk_signal": "duplicate_invoice",
            "risk_level": "critical",
            "source": "rule_center",
            "algorithm_version": "v1.2.0",
            "contribution_scores_json": {"S_rule": 100},
        }
        for column, expected in expected_columns.items():
            assert getattr(persisted, column) == expected, column
def test_risk_observation_storage_ready_is_cached_per_bind(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ensure_storage_ready() triggers only one create_all call for a given bind.

    The test calls it twice on one service instance and once on a new instance
    that shares the same session. It expects exactly one call to
    ``Base.metadata.create_all``.

    ``RiskObservationService._storage_ready_cache`` is class-level state. The
    test clears it before the checks, and again in a ``finally`` block, so the
    result does not depend on earlier tests and cannot leak into later ones.
    This holds even when an assertion fails.
    """
    with _build_session() as db:
        RiskObservationService._storage_ready_cache.clear()
        try:
            create_all_calls: list[object] = []
            original_create_all = Base.metadata.create_all

            def spy_create_all(*args, **kwargs):
                # Record the call, then delegate so tables are still created.
                create_all_calls.append(kwargs.get("bind"))
                return original_create_all(*args, **kwargs)

            # monkeypatch restores the real create_all at fixture teardown.
            monkeypatch.setattr(Base.metadata, "create_all", spy_create_all)

            service = RiskObservationService(db)
            service.ensure_storage_ready()
            service.ensure_storage_ready()
            RiskObservationService(db).ensure_storage_ready()

            assert len(create_all_calls) == 1
        finally:
            # Previously this clear only ran on success; a failing assert left
            # the class-level cache populated for subsequent tests.
            RiskObservationService._storage_ready_cache.clear()
def test_risk_observation_endpoints_return_list_detail_dashboard_and_feedback() -> None:
    """Exercise the list, execution-log, detail, dashboard and feedback endpoints end to end.

    One high-level observation linked to execution log ``exec-1`` is seeded
    directly through the service. A ``false_positive`` feedback is then posted
    over HTTP. It must appear in the detail response and in the stored row's
    status fields.
    """
    observation_key = "risk:c1:duplicate_invoice"
    detail_url = "/api/v1/risk-observations/risk:c1:duplicate_invoice"

    client, session_factory = _build_client()
    with session_factory() as seed_db:
        RiskObservationService(seed_db).upsert_observation(
            _observation_payload(observation_key),
            execution_log_id="exec-1",
        )
        seed_db.commit()

    list_response = client.get("/api/v1/risk-observations", params={"risk_level": "high"})
    execution_log_response = client.get("/api/v1/risk-observations/execution-log/exec-1")
    detail_response = client.get(detail_url)
    dashboard_response = client.get("/api/v1/risk-observations/dashboard")
    feedback_response = client.post(
        "/api/v1/risk-observations/risk:c1:duplicate_invoice/feedback",
        json={"feedback_type": "false_positive", "actor": "auditor", "comment": "误报"},
    )

    assert list_response.status_code == 200
    assert list_response.json()["total"] == 1

    assert execution_log_response.status_code == 200
    assert len(execution_log_response.json()) == 1

    assert detail_response.status_code == 200
    assert detail_response.json()["risk_signal"] == "duplicate_invoice"

    assert dashboard_response.status_code == 200
    dashboard_body = dashboard_response.json()
    assert dashboard_body["total_observations"] == 1
    assert dashboard_body["risk_clue_count"] == 1
    # The dashboard is fetched before the feedback POST, so no samples yet.
    assert dashboard_body["feedback_sample_count"] == 0
    assert "top_departments" in dashboard_body

    assert feedback_response.status_code == 200
    assert feedback_response.json()["feedback_type"] == "false_positive"

    updated_detail_response = client.get(detail_url)
    assert updated_detail_response.status_code == 200
    assert updated_detail_response.json()["feedback_items"][0]["feedback_type"] == "false_positive"

    # Read the row back through a fresh session to check persisted state.
    with session_factory() as verify_db:
        stored = (
            verify_db.query(RiskObservation)
            .filter_by(observation_key=observation_key)
            .one()
        )
        assert stored.status == "false_positive"
        assert stored.feedback_status == "false_positive"
def test_risk_observation_feedback_pool_fields_and_replay_set_contract() -> None:
    """Check feedback payload fields and the replay-set contract built from observations.

    A ``comment``/``rewrite`` feedback carries decision metadata in
    ``payload_json``. The returned feedback must expose those values as
    attributes. The stored observation is then turned into a replay set, and
    its first case must keep the claim id and the ontology and algorithm
    versions.
    """
    observation_key = "risk:c1:duplicate_invoice"
    feedback_request = RiskObservationFeedbackCreate(
        feedback_type="comment",
        action="rewrite",
        actor="auditor",
        comment="建议生成候选规则",
        payload_json={
            "decision": "candidate_rule_rewrite",
            "candidate_rule_source": "risk_observation_feedback",
            "confidence_score": 0.76,
            "escalation_target": "finance_manager",
            "supplement_required": True,
        },
    )

    with _build_session() as db:
        service = RiskObservationService(db)
        service.upsert_observation(_observation_payload(observation_key))
        db.commit()

        feedback = service.create_feedback(observation_key, feedback_request)
        observation = service.get_observation(observation_key)
        assert observation is not None

        # Every replay field except "decision_trace" comes from the observation
        # attribute of the same name. "decision_trace" comes from
        # decision_trace_json and is added last.
        replay_row = {
            attribute: getattr(observation, attribute)
            for attribute in (
                "observation_key",
                "claim_id",
                "risk_signal",
                "risk_score",
                "risk_level",
                "algorithm_version",
                "feedback_status",
                "ontology_json",
            )
        }
        replay_row["decision_trace"] = observation.decision_trace_json

        replay_set = AlgorithmReplaySetBuilder().build_from_observations(
            "replay-set-1",
            [replay_row],
            created_at=datetime(2026, 5, 30, tzinfo=UTC),
        )

        assert feedback.decision == "candidate_rule_rewrite"
        assert feedback.candidate_rule_source == "risk_observation_feedback"
        assert feedback.confidence_score == 0.76
        assert feedback.escalation_target == "finance_manager"
        assert feedback.supplement_required is True

        first_case = replay_set.cases[0]
        assert replay_set.replay_set_id == "replay-set-1"
        assert first_case.claim_id == "c1"
        assert first_case.ontology_version == "ontology.v1"
        assert first_case.algorithm_version == "financial_risk_graph.v1"
def _build_session() -> Session:
    """Return a new Session on a fresh in-memory SQLite database with all tables created.

    With ``StaticPool`` every checkout reuses the same single connection, so
    the in-memory database persists for the whole lifetime of the engine.
    ``check_same_thread=False`` lets that connection be used from other
    threads. The session has autoflush disabled.
    """
    in_memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=in_memory_engine)
    make_session = sessionmaker(bind=in_memory_engine, autoflush=False, autocommit=False)
    return make_session()
def _build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a TestClient for an app that mounts only the risk-observation router, plus its session factory.

    The router is mounted under ``/api/v1``. ``get_db`` is overridden so that
    each request gets its own Session from the returned factory, closed when
    the request finishes. The factory is bound to a fresh in-memory SQLite
    database behind a single ``StaticPool`` connection, so rows seeded
    through it are visible to the endpoints.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=engine)
    make_session = sessionmaker(bind=engine, autoflush=False, autocommit=False)

    def _session_per_request() -> Generator[Session, None, None]:
        # Session's context manager closes the session on exit.
        with make_session() as session:
            yield session

    api = FastAPI()
    api.include_router(risk_observations_router, prefix="/api/v1")
    api.dependency_overrides[get_db] = _session_per_request
    return TestClient(api), make_session
def _observation_payload(observation_key: str) -> dict:
return {
"observation_key": observation_key,
"subject_type": "expense_claim",
"subject_key": "claim:c1",
"subject_label": "BX-001",
"claim_id": "c1",
"claim_no": "BX-001",
"risk_type": "duplicate_invoice",
"risk_signal": "duplicate_invoice",
"title": "Duplicate invoice risk",
"description": "Same invoice appears in multiple claims.",
"risk_score": 86,
"risk_level": "high",
"confidence_score": "0.81",
"control_stage": "reimbursement",
"control_mode": "risk_observation",
"automation_mode": "semi_auto_review",
"source": "financial_risk_graph",
"algorithm_version": "financial_risk_graph.v1",
"contribution_scores": {"S_rule": 82, "S_graph": 95},
"baseline": {"scope": "expense_type", "sample_size": 4},
"evidence": [
{
"code": "duplicate_invoice_graph",
"source": "graph",
"metadata": {"vendor_name": "上海差旅供应商"},
}
],
"graph_node_keys": ["claim:c1", "vendor:上海差旅供应商"],
"graph_edge_keys": [],
"policy_refs": ["policy.duplicate_invoice"],
"similar_case_claim_ids": ["c2"],
"ontology_json": {
"gate": "review",
"ontology_parse_id": "parse-1",
"ontology_version": "ontology.v1",
"domain": "expense",
"scenario": "reimbursement",
"intent": "risk_check",
"ontology_entities_json": [{"type": "claim", "value": "c1"}],
"risk_signals_json": [{"code": "duplicate_invoice"}],
"canonical_subject_key": "claim:c1",
},
"decision_trace": {
"formula": "weighted",
"sampling_strategy": {"strategy": "focused_review", "threshold": 70},
"evaluation_case_id": "case-duplicate-invoice",
},
}
def _employee_orm() -> Employee:
    """Return a new, not-yet-added Employee with id ``emp-risk`` and grade P6.

    ``_claim_orm`` uses the same id as the ``employee_id`` of its claims.
    """
    profile = {
        "id": "emp-risk",
        "employee_no": "E-RISK",
        "name": "风险员工",
        "email": "risk.employee@example.com",
        "position": "高级专员",
        "grade": "P6",
    }
    return Employee(**profile)
def _claim_orm(claim_id: str, claim_no: str) -> ExpenseClaim:
    """Return a new submitted travel claim of 1200 CNY for employee ``emp-risk``.

    Only ``claim_id`` and ``claim_no`` vary. The occurrence and submission
    times are both fixed at 2026-05-20 00:00 UTC so date-window tests are
    deterministic.
    """
    fixed_timestamp = datetime(2026, 5, 20, tzinfo=UTC)
    return ExpenseClaim(
        id=claim_id,
        claim_no=claim_no,
        # Owner and department (same employee id as _employee_orm()).
        employee_id="emp-risk",
        employee_name="风险员工",
        department_id="dept-risk",
        department_name="风控部",
        # Expense details.
        expense_type="travel",
        reason="客户拜访",
        location="上海",
        amount=Decimal("1200"),
        currency="CNY",
        invoice_count=1,
        # Lifecycle.
        occurred_at=fixed_timestamp,
        submitted_at=fixed_timestamp,
        status="submitted",
        approval_stage="manager_review",
        risk_flags_json=[],
    )