# X-Financial/server/tests/test_financial_risk_graph_algorithm.py
# (Header captured from a repository web view: 819 lines, 30 KiB, Python.)

from __future__ import annotations
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.algorithem.risk_graph import (
RiskGraphClaimItemSnapshot,
RiskGraphClaimSnapshot,
RiskGraphEvaluationContext,
RiskHistoryStats,
evaluate_financial_risk_graph,
map_ontology_to_risk_graph,
)
from app.algorithem.risk_graph.anomaly_models import AnomalyPoint, MultiModelAnomalyDetector
from app.algorithem.risk_graph.control_effect import ControlEffectAnalyzer
from app.algorithem.risk_graph.counterfactual import CounterfactualRiskAdvisor
from app.algorithem.risk_graph.engine import _apply_evidence_source_gate
from app.algorithem.risk_graph.entity_resolution import (
CanonicalEntityRegistry,
FinancialEntityResolver,
)
from app.algorithem.risk_graph.evaluation_cases import default_risk_evaluation_cases
from app.algorithem.risk_graph.features import HeterogeneousRiskGraphFeatureBuilder
from app.algorithem.risk_graph.lineage import RiskDataLineageBuilder
from app.algorithem.risk_graph.models import RiskEvidence, RiskGraphEdge
from app.algorithem.risk_graph.policy_knowledge_contract import (
PolicyKnowledgeItem,
PolicyKnowledgeOrganizingReport,
PolicySourceRef,
build_policy_ref,
)
from app.algorithem.risk_graph.process_mining import (
ConformanceRiskDetector,
ObjectCentricProcessMiner,
)
from app.algorithem.risk_graph.rule_discovery import CandidateRiskRuleDiscovery
from app.algorithem.risk_graph.temporal import TemporalRiskGraphMonitor
from app.db.base import Base
from app.models.financial_record import ExpenseClaim, ExpenseClaimItem
from app.models.hermes_config import HermesTaskConfig, HermesTaskExecutionLog
from app.models.hermes_report import HermesRiskReport
from app.models.risk_observation import RiskObservation
from app.schemas.risk_observation import RiskObservationFeedbackCreate
from app.services.hermes_risk_scanner import HermesRiskScannerService
from app.services.risk_observations import RiskObservationService
def test_risk_graph_engine_combines_rule_anomaly_graph_policy_and_history() -> None:
    """All five score components fire on a duplicate-invoice, high-amount claim.

    The target claim (12000, preapproval flag) shares invoice ``INV-001`` with a
    claim from another employee and sits far above four peer claims. A history
    record for ``duplicate_invoice`` is supplied. The engine is expected to emit
    exactly one high-risk observation whose decision trace passes the
    evidence-source gate, and whose graph nodes and edges carry provenance fields.
    """
    target = _snapshot(
        "c-risk",
        "BX-001",
        amount="12000",
        risk_flags=[{"risk_signal": "preapproval_absent", "severity": "high"}],
        items=[_item("i-risk", "hotel", "12000", invoice_id="INV-001")],
    )
    duplicate = _snapshot(
        "c-dup",
        "BX-002",
        amount="900",
        employee_name="李四",
        items=[_item("i-dup", "hotel", "900", invoice_id="INV-001")],
    )
    peers = []
    for position, peer_amount in enumerate((700, 800, 900, 1000), start=1):
        peers.append(
            _snapshot(
                f"c-peer-{position}",
                f"BX-10{position}",
                amount=str(peer_amount),
                employee_name=f"同事{position}",
            )
        )
    duplicate_history = RiskHistoryStats(
        risk_signal="duplicate_invoice",
        expense_type="travel",
        similar_case_count=10,
        confirmed_count=7,
        false_positive_count=1,
        returned_count=2,
    )
    context = RiskGraphEvaluationContext(
        claims=[target, duplicate, *peers],
        target_claim_ids={"c-risk"},
        history_stats=[duplicate_history],
    )
    result = evaluate_financial_risk_graph(context)

    assert len(result.observations) == 1
    (observation,) = result.observations
    assert observation.risk_signal == "duplicate_invoice"
    assert observation.risk_level == "high"
    assert observation.risk_score >= 80
    assert observation.automation_mode == "semi_auto_review"

    scores = observation.contribution_scores
    assert scores["S_rule"] == 82
    assert scores["S_anomaly"] >= 90
    assert scores["S_graph"] >= 95
    assert scores["S_policy"] > 0
    assert scores["S_history"] > 0

    assert len({evidence.source for evidence in observation.evidence}) >= 2

    trace = observation.decision_trace
    assert trace["raw_risk_score"] >= observation.risk_score
    assert trace["evidence_source_count"] >= 2
    assert trace["evidence_source_gate"] == "passed"
    assert trace["algorithm_version"] == "financial_risk_graph.v1"
    assert trace["decision_row"] == "high:70<=score<90"
    assert trace["feature_contributions_json"][0]["feature"] == "S_rule"
    assert trace["explanation_template_key"] == "risk.duplicate_invoice.high"
    sampling = trace["sampling_strategy"]
    assert sampling["strategy"] == "focused_review"
    assert sampling["replay_bucket"] == "high_risk"

    # The duplicate claim must be linked either as a similar case or via a graph edge.
    linked_to_duplicate = "c-dup" in observation.similar_case_claim_ids or any(
        "c-dup" in edge_key["target_key"] for edge_key in observation.graph_edge_keys
    )
    assert linked_to_duplicate

    node_payloads = [node.as_dict() for node in result.nodes]
    for provenance_field in ("canonical_id", "ontology_parse_id", "ontology_version"):
        assert all(provenance_field in payload for payload in node_payloads)
    edge_payloads = [edge.as_dict() for edge in result.edges]
    assert all(payload["source"] for payload in edge_payloads)
def test_high_risk_score_is_capped_when_only_one_evidence_source() -> None:
    """A 92-point score backed by only one evidence source is capped to 69."""
    lone_evidence = RiskEvidence(
        code="rule_signal",
        title="Rule signal",
        detail="Only one evidence source is present.",
        source="rule",
        score=92,
    )
    capped_score, gate_label = _apply_evidence_source_gate(92, [lone_evidence])
    assert (capped_score, gate_label) == (69, "capped_high_risk_single_source")
def test_heterogeneous_graph_feature_builder_outputs_core_features() -> None:
    """The feature builder summarizes node and edge types, meta-paths, clusters and density.

    Two claims share invoice ``INV-001``. The builder should report two claim
    nodes, at least one duplicate-invoice edge, and a first cluster of size >= 2.
    With ``claim:c-risk`` marked as risky, it should also report a non-zero
    neighbor risk density on the duplicate claim.
    """
    risky_claim = _snapshot(
        "c-risk",
        "BX-001",
        amount="12000",
        items=[_item("i-risk", "hotel", "12000", invoice_id="INV-001")],
    )
    duplicate_claim = _snapshot(
        "c-dup",
        "BX-002",
        amount="900",
        employee_name="李四",
        items=[_item("i-dup", "hotel", "900", invoice_id="INV-001")],
    )
    graph = evaluate_financial_risk_graph(
        RiskGraphEvaluationContext(
            claims=[risky_claim, duplicate_claim],
            target_claim_ids={"c-risk"},
        )
    )
    builder = HeterogeneousRiskGraphFeatureBuilder()
    features = builder.build(
        graph.nodes,
        graph.edges,
        risk_node_keys={"claim:c-risk"},
    )

    assert features.node_type_counts["claim"] == 2
    assert features.edge_type_counts["claim_duplicate_invoice"] >= 1
    assert features.meta_path_counts
    first_cluster = features.clusters[0]
    assert first_cluster["size"] >= 2
    assert features.neighbor_risk_density["claim:c-dup"] > 0
def test_temporal_risk_graph_monitor_detects_edge_changes() -> None:
    """Swapping one submission edge for three new ones is reported as graph drift.

    The employee's only previous edge (to ``claim:c1``) disappears, and edges to
    ``claim:c2`` through ``claim:c4`` appear. The monitor should report every
    change category and a net delta of +2 for ``employee_submits_claim``.
    """
    edge_type = "employee_submits_claim"
    previous_edges = [
        RiskGraphEdge(
            source_key="employee:e1",
            target_key="claim:c1",
            edge_type=edge_type,
        ),
    ]
    current_edges = [
        RiskGraphEdge(
            source_key="employee:e1",
            target_key=f"claim:c{claim_number}",
            edge_type=edge_type,
        )
        for claim_number in (2, 3, 4)
    ]
    diff = TemporalRiskGraphMonitor().monitor(
        previous_edges,
        current_edges,
        risk_node_keys={"claim:c2"},
    )
    change_types = {change.change_type for change in diff.changes}
    for expected_change in (
        "relationship_added",
        "relationship_removed",
        "relationship_surge",
        "target_migration",
        "risk_propagation",
    ):
        assert expected_change in change_types
    assert diff.edge_type_delta[edge_type] == 2
def test_financial_entity_resolver_and_registry_merge_aliases() -> None:
    """Two spellings of one supplier resolve to a single canonical registry entry.

    A padded, hyphenated "supplier" name from an invoice and a compact "merchant"
    name from a receipt must share a canonical id. After both are upserted, the
    registry must hold exactly one entity. That entity keeps both alias spellings
    and records the auditor who confirmed it.
    """
    resolver = FinancialEntityResolver()
    registry = CanonicalEntityRegistry()
    first = resolver.resolve("supplier", " 上海 差旅-供应商 ", source="invoice")
    second = resolver.resolve("merchant", "上海差旅供应商", source="receipt")
    assert first is not None
    assert second is not None
    assert first.canonical_id == second.canonical_id

    # Only the record returned by the final upsert is needed to confirm the entity.
    registry.upsert(first)
    saved = registry.upsert(second)
    confirmed = registry.confirm(saved.canonical_id, actor="auditor")

    assert len(registry.all()) == 1
    assert confirmed is not None
    assert confirmed.confirmed_by == "auditor"
    # The padded input is expected to be stored in its trimmed form.
    assert set(confirmed.aliases) == {"上海 差旅-供应商", "上海差旅供应商"}
def test_multi_model_anomaly_detector_combines_deterministic_signals() -> None:
    """A 3200 spend against a ~800 travel baseline trips every anomaly method.

    Five peer points (790-830, dated 2026-05-04 to 2026-05-14) form the baseline.
    The target point on 2026-05-18 must be reported by all five detection methods,
    with a top score of at least 90.
    """
    baseline = [(800, 4), (820, 11), (790, 12), (810, 13), (830, 14)]
    peer_points = [
        AnomalyPoint(
            key=f"peer-{rank}",
            amount=Decimal(str(value)),
            occurred_at=datetime(2026, 5, day, tzinfo=UTC),
            segment="travel",
        )
        for rank, (value, day) in enumerate(baseline, start=1)
    ]
    target_point = AnomalyPoint(
        key="target",
        amount=Decimal("3200"),
        occurred_at=datetime(2026, 5, 18, tzinfo=UTC),
        segment="travel",
    )
    signals = MultiModelAnomalyDetector().detect(
        [*peer_points, target_point],
        target_key="target",
    )
    methods = {signal.method for signal in signals}
    for expected_method in (
        "robust_statistics",
        "isolation_forest_proxy",
        "local_outlier_factor_proxy",
        "temporal_jump",
        "periodic_deviation",
    ):
        assert expected_method in methods
    assert max(signal.score for signal in signals) >= 90
def test_object_centric_process_miner_builds_replayable_events() -> None:
    """Process mining of one flagged claim yields lifecycle, invoice and risk events.

    The invoice event must reference both the claim and the invoice object, so the
    event log can be replayed per object.
    """
    claim = _snapshot(
        "c-process",
        "BX-PROCESS",
        amount="1200",
        risk_flags=[{"risk_signal": "preapproval_absent"}],
        items=[_item("i-process", "hotel", "1200", invoice_id="INV-PROCESS")],
    )
    events = ObjectCentricProcessMiner().build_from_claims([claim])
    event_types = {event.event_type for event in events}
    assert {"expense_occurred", "claim_submitted", "expense_item_recorded"} <= event_types
    assert "invoice_attached" in event_types
    assert "risk_flagged" in event_types

    # Look the event up only after its presence is asserted. With a None default,
    # a missing event fails with a readable assertion, not a bare StopIteration.
    invoice_event = next(
        (event for event in events if event.event_type == "invoice_attached"),
        None,
    )
    assert invoice_event is not None
    assert invoice_event.object_refs["claim"] == ["c-process"]
    assert invoice_event.object_refs["invoice"] == ["INV-PROCESS"]
def test_conformance_risk_detector_finds_process_violations() -> None:
    """Out-of-order and incomplete event logs surface conformance risks.

    - ``c-flow`` is paid (05-01) before it is submitted (05-02) or approved
      (05-03), and it is returned twice.
    - ``c-bypass`` has only an approval event.
    - ``c-invoice`` has only an invoice event.
    """
    event_specs = [
        ("e-payment", "payment_completed", "2026-05-01T09:00:00+00:00", "c-flow"),
        ("e-submit-1", "claim_submitted", "2026-05-02T09:00:00+00:00", "c-flow"),
        ("e-approve", "approval_approved", "2026-05-03T09:00:00+00:00", "c-flow"),
        ("e-return-1", "claim_returned", "2026-05-04T09:00:00+00:00", "c-flow"),
        ("e-submit-2", "claim_submitted", "2026-05-05T09:00:00+00:00", "c-flow"),
        ("e-return-2", "claim_returned", "2026-05-06T09:00:00+00:00", "c-flow"),
        ("e-approval-only", "approval_approved", "2026-05-01T09:00:00+00:00", "c-bypass"),
        ("e-invoice-only", "invoice_attached", "2026-05-01T09:00:00+00:00", "c-invoice"),
    ]
    rows = [_event_row(*spec) for spec in event_specs]
    events = ObjectCentricProcessMiner().build_from_dicts(rows)
    detected = ConformanceRiskDetector().detect(events)
    risk_codes = {risk.risk_code for risk in detected}
    for expected_code in (
        "payment_before_approval",
        "rework_loop",
        "approval_bypass",
        "process_bypass",
    ):
        assert expected_code in risk_codes
def test_risk_data_lineage_builder_collects_source_assets() -> None:
    """Lineage extracts tables, document/OCR/tool/rule ids and gates from an observation.

    The observation carries one OCR evidence item (document, OCR job and tool-call
    ids) and one rule-center evidence item (rule version). Its decision trace holds
    a passing evidence-source gate, a capped data-quality gate and an uncertainty
    sampling strategy.
    """
    ocr_evidence = {
        "source": "ocr",
        "metadata": {
            "document_id": "doc-1",
            "ocr_job_id": "ocr-1",
            "tool_call_id": "tool-1",
        },
    }
    rule_evidence = {
        "source": "rule_center",
        "metadata": {"rule_version": "rule.v2"},
    }
    observation = {
        "observation_key": "risk:c1:duplicate_invoice",
        "claim_id": "c1",
        "run_id": "agent-run-1",
        "algorithm_version": "financial_risk_graph.v1",
        "ontology_json": {"ontology_version": "ontology.v1"},
        "evidence": [ocr_evidence, rule_evidence],
        "decision_trace": {
            "evidence_source_gate": "passed",
            "data_quality_gate": "capped_missing_required_fields",
            "sampling_strategy": {"strategy": "uncertainty_sample"},
        },
    }
    lineage = RiskDataLineageBuilder().build_from_observation(
        observation,
        source_event_ids=["event-1"],
    )

    expected_tables = {"risk_observations", "expense_claims", "expense_claim_items"}
    assert expected_tables <= set(lineage.data_tables)
    assert lineage.document_ids == ["doc-1"]
    assert lineage.ocr_job_ids == ["ocr-1"]
    assert lineage.agent_run_ids == ["agent-run-1"]
    assert lineage.tool_call_ids == ["tool-1"]
    assert lineage.rule_versions == ["rule.v2"]
    assert lineage.ontology_version == "ontology.v1"
    assert lineage.algorithm_version == "financial_risk_graph.v1"
    assert lineage.source_event_ids == ["event-1"]
    # The passing evidence-source gate is not expected among the quality gates.
    assert lineage.quality_gates == ["capped_missing_required_fields", "uncertainty_sample"]
def test_policy_knowledge_organizing_report_exposes_risk_policy_refs() -> None:
    """The serialized policy report exposes the item's policy ref and its source citation."""
    travel_policy_source = PolicySourceRef(
        source_id="doc-travel-policy",
        title="差旅报销风险管控制度",
        location="第三章",
    )
    preapproval_item = PolicyKnowledgeItem(
        policy_ref=build_policy_ref("travel", "preapproval_absent"),
        title="差旅事前申请",
        summary="差旅报销需保留事前审批依据。",
        expense_type="travel",
        control_stage="reimbursement",
        trigger_conditions=["preapproval_absent"],
        source_refs=[travel_policy_source],
        review_status="confirmed",
    )
    report = PolicyKnowledgeOrganizingReport(
        summary="整理差旅预审批制度。",
        categories=["差旅", "事前申请"],
        knowledge_items=[preapproval_item],
        source_refs=[travel_policy_source],
    )
    payload = report.as_dict()

    assert payload["risk_policy_refs"] == ["policy.travel.preapproval_absent"]
    first_item = payload["knowledge_items"][0]
    assert first_item["source_refs"][0]["source_id"] == "doc-travel-policy"
def test_counterfactual_risk_advisor_returns_actionable_reductions() -> None:
    """Each strong contributor maps to a remediation action that lowers the score."""
    observation = {
        "contribution_scores": {"S_rule": 82, "S_anomaly": 90, "S_graph": 95},
        "evidence": [{"code": "duplicate_invoice_graph"}],
        "decision_trace": {"data_quality_gate": "capped_missing_required_fields"},
    }
    actions = CounterfactualRiskAdvisor().advise(observation)
    action_keys = {action.action_key for action in actions}
    for expected_key in (
        "complete_preapproval_or_required_attachment",
        "align_amount_with_peer_baseline",
        "replace_duplicate_or_conflicting_invoice",
        "supplement_missing_risk_data",
    ):
        assert expected_key in action_keys
    # Every suggested action must reduce (never raise) the expected risk score.
    assert all(action.expected_score_delta < 0 for action in actions)
def test_candidate_risk_rule_discovery_outputs_review_only_candidates() -> None:
    """Reviewer feedback on an observation produces one candidate rule awaiting review.

    The candidate takes its confidence and source from the feedback item. Its
    evidence combines the observation's graph evidence with the feedback.
    """
    observation = {
        "observation_key": "risk:c1:duplicate_invoice",
        "risk_signal": "duplicate_invoice",
        "confidence_score": 0.82,
        "evidence": [{"code": "duplicate_invoice_graph", "source": "graph"}],
    }
    feedback = {
        "observation_key": "risk:c1:duplicate_invoice",
        "feedback_type": "comment",
        "action": "rewrite",
        "decision": "candidate_rule_rewrite",
        "candidate_rule_source": "risk_observation_feedback",
        "confidence_score": 0.77,
        "comment": "建议沉淀重复票据候选规则。",
    }
    candidates = CandidateRiskRuleDiscovery().discover_from_feedback(
        observations=[observation],
        feedback_items=[feedback],
    )

    assert len(candidates) == 1
    (candidate,) = candidates
    assert candidate.rule_code == "candidate.risk.duplicate_invoice"
    assert candidate.status == "candidate_review"
    assert candidate.source == "risk_observation_feedback"
    assert candidate.confidence_score == 0.77
    for expected_source in ("graph", "risk_observation_feedback"):
        assert any(entry["source"] == expected_source for entry in candidate.evidence)
def test_control_effect_analyzer_compares_before_and_after_windows() -> None:
    """A window with lower, confirmed medium risks shows an improving control effect."""
    before_window = [
        {"risk_score": 90, "risk_level": "critical", "feedback_status": "false_positive"},
        {"risk_score": 80, "risk_level": "high", "feedback_status": "confirmed"},
    ]
    after_window = [
        {"risk_score": 62, "risk_level": "medium", "feedback_status": "confirmed"},
        {"risk_score": 55, "risk_level": "medium", "feedback_status": "confirmed"},
    ]
    summary = ControlEffectAnalyzer().compare(before=before_window, after=after_window)

    assert (summary.before_count, summary.after_count) == (2, 2)
    # Scores and high-risk share fall; confirmations rise; false positives fall.
    assert summary.average_score_delta < 0
    assert summary.high_rate_delta < 0
    assert summary.confirmation_rate_delta > 0
    assert summary.false_positive_rate_delta < 0
def test_risk_data_quality_gate_caps_strong_conclusion_for_low_quality_claim() -> None:
    """A critically flagged claim with no employee is capped at a medium score of 69.

    The target claim has an empty employee name, and its single item (900) differs
    from the claim amount (12000). The data-quality gate must fail, the case must
    go to uncertainty sampling, and the score must be capped.
    """
    target = _snapshot(
        "c-low-quality",
        "BX-005",
        amount="12000",
        employee_name="",
        risk_flags=[{"risk_signal": "preapproval_absent", "severity": "critical"}],
        items=[_item("i-low-quality", "hotel", "900", invoice_id="INV-LOW")],
    )
    peers = []
    for position, peer_amount in enumerate((700, 800, 900, 1000), start=1):
        peers.append(
            _snapshot(
                f"c-peer-quality-{position}",
                f"BX-30{position}",
                amount=str(peer_amount),
                employee_name=f"同事{position}",
            )
        )
    result = evaluate_financial_risk_graph(
        RiskGraphEvaluationContext(
            claims=[target, *peers],
            target_claim_ids={"c-low-quality"},
        )
    )

    assert len(result.observations) == 1
    (observation,) = result.observations
    trace = observation.decision_trace
    assert trace["data_quality_gate"] == "capped_missing_required_fields"
    quality = trace["data_quality"]
    assert quality["passed"] is False
    assert "employee" in quality["missing_fields"]
    sampling = trace["sampling_strategy"]
    assert sampling["strategy"] == "uncertainty_sample"
    assert sampling["replay_bucket"] == "data_quality_gate"
    uncertainty_reasons = trace["uncertainty_reasons_json"]
    assert "score_capped_by_gate" in uncertainty_reasons
    assert "data_quality_gate_not_passed" in uncertainty_reasons
    assert observation.risk_score == 69
    assert observation.risk_level == "medium"
def test_default_risk_evaluation_cases_cover_required_categories() -> None:
    """The built-in evaluation suite spans every required case category.

    Every case must also have a non-empty id and description.
    """
    required_categories = {
        "positive",
        "negative",
        "counterfactual",
        "noise",
        "historical_false_positive",
    }
    cases = default_risk_evaluation_cases()
    present_categories = {case.category for case in cases}
    assert required_categories <= present_categories
    assert all(case.case_id and case.description for case in cases)
def test_risk_graph_engine_avoids_false_risk_when_baseline_and_signals_are_missing() -> None:
    """A lone, unflagged claim yields no observations but still appears as a graph node."""
    clean_claim = _snapshot("c-clean", "BX-003", amount="300")
    context = RiskGraphEvaluationContext(
        claims=[clean_claim],
        target_claim_ids={"c-clean"},
    )
    result = evaluate_financial_risk_graph(context)

    assert result.observations == []
    node_keys = [node.key for node in result.nodes]
    assert "claim:c-clean" in node_keys
def test_risk_graph_engine_detects_multi_evidence_and_spatiotemporal_mismatch() -> None:
    """Disagreement between the claim and its documents raises multi-evidence and spatiotemporal signals.

    The target claims 8000 with two invoices, but its only item is 900, located in
    北京 and dated 2026-04-01. The claim itself uses the ``_snapshot`` defaults
    (上海, 2026-05-20).
    """
    mismatched_item = _item(
        "i-mismatch",
        "hotel",
        "900",
        invoice_id="INV-MISMATCH",
        item_location="北京",
        item_date=date(2026, 4, 1),
    )
    target = _snapshot(
        "c-mismatch",
        "BX-004",
        amount="8000",
        invoice_count=2,
        items=[mismatched_item],
    )
    peers = []
    for position, peer_amount in enumerate((700, 800, 900, 1000), start=1):
        peers.append(
            _snapshot(
                f"c-peer-mismatch-{position}",
                f"BX-20{position}",
                amount=str(peer_amount),
                employee_name=f"同事{position}",
            )
        )
    result = evaluate_financial_risk_graph(
        RiskGraphEvaluationContext(
            claims=[target, *peers],
            target_claim_ids={"c-mismatch"},
        )
    )

    assert len(result.observations) == 1
    (observation,) = result.observations
    evidence_codes = {evidence.code for evidence in observation.evidence}
    evidence_sources = {evidence.source for evidence in observation.evidence}
    for expected_code in (
        "document_amount_mismatch",
        "invoice_count_mismatch",
        "date_outside_claim_window",
        "location_mismatch_graph",
    ):
        assert expected_code in evidence_codes
    assert {"multi_evidence", "spatiotemporal"} <= evidence_sources
    assert observation.risk_signal in {"date_outside_trip", "document_expense_mismatch"}
def test_ontology_mapping_normalizes_signals_and_uses_confidence_gate() -> None:
    """A low-confidence ontology parse maps to a candidate-only, provenance-tagged graph.

    ``city_mismatch`` should be normalized to ``location_mismatch``. The employee
    entity's normalized value should become the canonical subject key. Every node
    should carry the parse id and ontology version, and every edge should be one of
    the ontology edge types, sourced from ``ontology``.
    """
    employee_entity = {
        "type": "employee",
        "value": "张三",
        "normalized_value": "E001",
        "role": "target",
        "confidence": 0.8,
    }
    expense_type_entity = {
        "type": "expense_type",
        "value": "差旅费",
        "normalized_value": "travel",
        "role": "filter",
        "confidence": 0.9,
    }
    ontology_parse = {
        "run_id": "run-ontology-1",
        "scenario": "expense",
        "intent": "risk_check",
        "confidence": 0.49,
        "entities": [employee_entity, expense_type_entity],
        "constraints": [{"field": "amount", "operator": ">", "value": 5000}],
        "risk_flags": ["city_mismatch"],
    }
    mapping = map_ontology_to_risk_graph(ontology_parse, ontology_version="ontology.test")

    # An overall confidence of 0.49 is expected to keep the mapping candidate-only.
    assert mapping.gate == "candidate_only"
    assert mapping.canonical_subject_key == "employee:e001"
    assert [signal.code for signal in mapping.risk_signals] == ["location_mismatch"]

    node_payloads = [node.as_dict() for node in mapping.nodes]
    assert all(payload["canonical_id"] for payload in node_payloads)
    assert {payload["ontology_parse_id"] for payload in node_payloads} == {"run-ontology-1"}
    assert {payload["ontology_version"] for payload in node_payloads} == {"ontology.test"}

    allowed_edge_types = {"ontology_extracts", "ontology_constrains", "ontology_signals"}
    assert {edge.edge_type for edge in mapping.edges} <= allowed_edge_types
    assert {edge.as_dict()["source"] for edge in mapping.edges} == {"ontology"}
def test_hermes_risk_scanner_persists_algorithm_reports() -> None:
    """The global risk scan persists reports and observations and flags the target claim.

    The test seeds:
    - a running execution log;
    - a 12000 claim (``c-risk``) that shares invoice ``INV-001`` with ``c-dup``;
    - four peer claims;
    - an approved historical claim whose duplicate-invoice observation an auditor
      has confirmed.

    After ``scan_global_risks``, the target's observation must reference the
    execution log and carry a positive ``S_history`` contribution. The claim itself
    must carry a ``financial_risk_graph`` risk flag.
    """
    with _build_session() as db:
        config = HermesTaskConfig(
            task_type="global_risk_scan",
            cron_expression="0 0 * * *",
            is_enabled=True,
        )
        db.add(config)
        # Flush so config.id is populated before the execution log references it.
        db.flush()
        log = HermesTaskExecutionLog(config_id=config.id, status="running")
        db.add(log)

        target = _claim_orm(
            "c-risk",
            "BX-001",
            amount=Decimal("12000"),
            risk_flags=[{"risk_signal": "preapproval_absent", "severity": "high"}],
        )
        target.items.append(
            _claim_item_orm("item-risk", "c-risk", Decimal("12000"), invoice_id="INV-001")
        )
        duplicate = _claim_orm("c-dup", "BX-002", amount=Decimal("900"), employee_name="李四")
        duplicate.items.append(
            _claim_item_orm("item-dup", "c-dup", Decimal("900"), invoice_id="INV-001")
        )
        peers = [
            _claim_orm(
                f"c-peer-{index}",
                f"BX-10{index}",
                amount=Decimal(str(amount)),
                employee_name=f"同事{index}",
            )
            for index, amount in enumerate([700, 800, 900, 1000], start=1)
        ]
        historical = _claim_orm("c-history", "BX-HIST", amount=Decimal("1000"))
        historical.status = "approved"
        db.add_all([target, duplicate, *peers, historical])
        db.flush()

        # Seed a confirmed historical duplicate-invoice observation.
        observation_service = RiskObservationService(db)
        observation_service.upsert_observation(
            {
                "observation_key": "risk:c-history:duplicate_invoice",
                "subject_type": "expense_claim",
                "subject_key": "claim:c-history",
                "subject_label": "BX-HIST",
                "claim_id": "c-history",
                "claim_no": "BX-HIST",
                "risk_type": "duplicate_invoice",
                "risk_signal": "duplicate_invoice",
                "title": "Historical duplicate invoice risk",
                "description": "Confirmed historical duplicate invoice risk.",
                "risk_score": 82,
                "risk_level": "high",
                "source": "financial_risk_graph",
                "algorithm_version": "financial_risk_graph.test",
                "contribution_scores": {"S_rule": 82},
            }
        )
        observation_service.create_feedback(
            "risk:c-history:duplicate_invoice",
            RiskObservationFeedbackCreate(feedback_type="confirm", actor="auditor"),
        )

        summary = HermesRiskScannerService(db).scan_global_risks(log_id=log.id)
        reports = list(db.scalars(select(HermesRiskReport)).all())
        observations = list(db.scalars(select(RiskObservation)).all())
        # Default to None so a missing target observation fails with a readable
        # assertion below, not a bare StopIteration.
        target_observation = next(
            (item for item in observations if item.claim_id == "c-risk"),
            None,
        )
        refreshed_target = db.get(ExpenseClaim, "c-risk")

        assert summary["risk_observation_count"] >= 1
        assert any(report.risk_type == "duplicate_invoice" for report in reports)
        assert any(item.risk_signal == "duplicate_invoice" for item in observations)
        assert target_observation is not None, "scanner produced no observation for c-risk"
        assert target_observation.execution_log_id == log.id
        assert target_observation.contribution_scores_json.get("S_history", 0) > 0
        assert refreshed_target is not None
        assert refreshed_target.hermes_risk_flag is True
        assert any(
            isinstance(flag, dict) and flag.get("source") == "financial_risk_graph"
            for flag in refreshed_target.risk_flags_json
        )
def _snapshot(
    claim_id: str,
    claim_no: str,
    *,
    amount: str,
    employee_name: str = "张三",
    department_name: str = "销售部",
    employee_grade: str = "P7",
    expense_type: str = "travel",
    location: str = "上海",
    invoice_count: int = 0,
    occurred_at: datetime | None = None,
    submitted_at: datetime | None = None,
    risk_flags: list | None = None,
    items: list[RiskGraphClaimItemSnapshot] | None = None,
) -> RiskGraphClaimSnapshot:
    """Build a submitted claim snapshot for risk-graph tests.

    The employee and department names are reused as their ids. ``amount`` is
    parsed as a ``Decimal``. ``occurred_at`` defaults to 2026-05-20 00:00 UTC, and
    ``submitted_at`` defaults to one hour after the occurrence. Missing risk flags
    and items become empty lists.
    """
    when_occurred = occurred_at if occurred_at is not None else datetime(2026, 5, 20, tzinfo=UTC)
    when_submitted = (
        submitted_at if submitted_at is not None else when_occurred + timedelta(hours=1)
    )
    return RiskGraphClaimSnapshot(
        claim_id=claim_id,
        claim_no=claim_no,
        employee_id=employee_name,
        employee_name=employee_name,
        department_id=department_name,
        department_name=department_name,
        employee_grade=employee_grade,
        expense_type=expense_type,
        amount=Decimal(amount),
        invoice_count=invoice_count,
        occurred_at=when_occurred,
        submitted_at=when_submitted,
        status="submitted",
        location=location,
        risk_flags=risk_flags or [],
        items=items or [],
    )
def _item(
    item_id: str,
    item_type: str,
    amount: str,
    *,
    invoice_id: str | None = None,
    item_location: str = "上海",
    item_date: date | None = None,
) -> RiskGraphClaimItemSnapshot:
    """Build an expense-item snapshot for risk-graph tests.

    ``amount`` is parsed as a ``Decimal``. The location defaults to 上海 and the
    date to 2026-05-20.
    """
    effective_date = item_date if item_date is not None else date(2026, 5, 20)
    return RiskGraphClaimItemSnapshot(
        item_id=item_id,
        item_type=item_type,
        item_amount=Decimal(amount),
        item_location=item_location,
        item_date=effective_date,
        invoice_id=invoice_id,
    )
def _event_row(event_id: str, event_type: str, occurred_at: str, claim_id: str) -> dict:
return {
"event_id": event_id,
"event_type": event_type,
"occurred_at": occurred_at,
"object_refs": {"claim": [claim_id]},
"source": "test",
}
def _build_session() -> Session:
    """Return a session bound to a fresh in-memory SQLite database with all tables created.

    ``StaticPool`` reuses a single connection, so the in-memory database survives
    for the engine's lifetime. ``check_same_thread=False`` allows that connection
    to be used outside the thread that created it.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(autocommit=False, autoflush=False, bind=memory_engine)
    return make_session()
def _claim_orm(
    claim_id: str,
    claim_no: str,
    *,
    amount: Decimal,
    employee_name: str = "张三",
    risk_flags: list | None = None,
) -> ExpenseClaim:
    """Build an ``ExpenseClaim`` row (not yet added to a session) for scanner tests.

    The claim is a submitted CNY travel claim in ``manager_review``, with one
    invoice. The employee name is reused as the employee id. The claim occurred at
    2026-05-20 00:00 UTC and was submitted one hour later.
    """
    occurred = datetime(2026, 5, 20, tzinfo=UTC)
    submitted = occurred + timedelta(hours=1)
    flags = risk_flags or []
    return ExpenseClaim(
        id=claim_id,
        claim_no=claim_no,
        employee_id=employee_name,
        employee_name=employee_name,
        department_id="sales",
        department_name="销售部",
        expense_type="travel",
        reason="客户拜访",
        location="上海",
        amount=amount,
        currency="CNY",
        invoice_count=1,
        occurred_at=occurred,
        submitted_at=submitted,
        status="submitted",
        approval_stage="manager_review",
        risk_flags_json=flags,
    )
def _claim_item_orm(
    item_id: str,
    claim_id: str,
    amount: Decimal,
    *,
    invoice_id: str,
) -> ExpenseClaimItem:
    """Build a hotel ``ExpenseClaimItem`` in 上海 dated 2026-05-20 for ``claim_id``."""
    item_fields = {
        "id": item_id,
        "claim_id": claim_id,
        "item_date": date(2026, 5, 20),
        "item_type": "hotel",
        "item_reason": "客户拜访住宿",
        "item_location": "上海",
        "item_amount": amount,
        "invoice_id": invoice_id,
    }
    return ExpenseClaimItem(**item_fields)