Files
X-Financial/server/tests/test_ontology_service.py
caoxiaozhu 50b1c3f9a9 feat: 增强规则资产管理与审计页面运行时调试
后端新增规则资产版本管理和规则文件 CRUD 接口,优化风险
规则生成模板执行和员工数据模型字段,知识库 RAG 增强本
地回退和文档提取能力,清理旧风险规则文件统一由生成引擎
管理,前端审计页面增加运行时调试面板和规则资产编辑交互,
补充单元测试覆盖。
2026-05-24 21:44:17 +08:00

752 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ontology import OntologyParseRequest
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a brand-new, schema-initialised in-memory SQLite DB.

    ``StaticPool`` keeps one single connection for the engine, and
    ``check_same_thread=False`` lets that connection be used from threads other
    than the creating one, so every session produced by the factory sees the
    same in-memory database (and therefore the tables created here).
    """
    in_memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    # Create every table registered on the declarative Base before handing out sessions.
    Base.metadata.create_all(bind=in_memory_engine)
    return sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=in_memory_engine,
    )
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a TestClient whose ``get_db`` dependency is served from an isolated in-memory DB.

    Returns the client together with the session factory it uses, so callers can
    open their own sessions against the same database.
    """
    factory = build_session_factory()
    application = create_app()

    def _request_scoped_session() -> Generator[Session, None, None]:
        # One session per request; always closed once the request finishes.
        session = factory()
        try:
            yield session
        finally:
            session.close()

    application.dependency_overrides[get_db] = _request_scoped_session
    client = TestClient(application)
    return client, factory
# Day-3 evaluation set. Each row is:
# (case id, query, expected scenario, expected intent, expected permission level, context_json)
_EVALUATION_CASE_ROWS = (
    ("expense-risk-check", "查一下本周报销超标风险", "expense", "risk_check", "read", {}),
    ("expense-query-employee-month", "张三 4 月差旅报销金额是多少", "expense", "query", "read", {}),
    ("expense-explain-policy", "为什么酒店超标报销不能直接通过", "expense", "explain", "read", {}),
    ("expense-topn-query", "列出金额最高的10笔报销", "expense", "query", "read", {}),
    ("expense-draft", "帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write", {}),
    ("expense-narrative-draft", "我今天去客户现场招待了客户花销了1000元", "expense", "draft", "draft_write", {}),
    ("ar-query-customer-month", "客户 A 这个月还有多少应收", "accounts_receivable", "query", "read", {}),
    ("ar-compare-customers", "对比客户A和客户B本月应收差异", "accounts_receivable", "compare", "read", {}),
    ("ar-risk-check", "检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read", {}),
    ("ar-draft", "生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write", {}),
    ("ar-aging-query", "查询客户B账龄明细", "accounts_receivable", "query", "read", {}),
    ("ap-query-vendor-tomorrow", "供应商 B 明天要付多少钱", "accounts_payable", "query", "read", {}),
    ("ap-compare-vendors", "对比供应商A和供应商B本月应付差异", "accounts_payable", "compare", "read", {}),
    ("ap-risk-check", "检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read", {}),
    ("ap-draft", "生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write", {}),
    (
        "ap-operate-approval-required",
        "帮我安排付款给供应商B",
        "accounts_payable",
        "operate",
        "approval_required",
        {"role_codes": ["finance"]},
    ),
    ("knowledge-query", "公司财务制度在哪里看", "knowledge", "query", "read", {}),
    ("knowledge-explain", "规则中心的审核依据是什么", "knowledge", "explain", "read", {}),
    ("knowledge-query-library", "知识库里有没有双人复核制度", "knowledge", "query", "read", {}),
    # Operations a plain "user" role must never be allowed to perform.
    (
        "forbidden-direct-payment",
        "帮我直接付款给供应商B",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
    (
        "forbidden-activate-rule",
        "帮我上线付款双人复核规则",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
    (
        "forbidden-delete-expense",
        "帮我删除今天的报销记录",
        "expense",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
)

# pytest parameters in the order expected by the day-3 evaluation test:
# (query, scenario, intent, permission, context_json), with the case id as the test id.
EVALUATION_CASES = [
    pytest.param(query, scenario, intent, permission, context_json, id=case_id)
    for case_id, query, scenario, intent, permission, context_json in _EVALUATION_CASE_ROWS
]
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict,
) -> None:
    """Each evaluation query must parse to its expected scenario, intent and permission level."""
    make_session = build_session_factory()
    request = OntologyParseRequest(
        query=query,
        user_id="pytest",
        context_json=context_json,
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == scenario
        assert parsed.intent == intent
        assert parsed.permission.level == permission
        # Every parse is recorded under a generated run identifier.
        assert parsed.run_id.startswith("run_")
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """An explicit "YYYY年M月" phrase resolves to that calendar month's first and last day."""
    # NOTE(review): the test name promises entity and constraint checks (张三, 超过5000元),
    # but only scenario, intent and time range are asserted — consider adding those checks.
    make_session = build_session_factory()
    request = OntologyParseRequest(
        query="张三 2026年4月差旅报销金额超过5000元的明细",
        user_id="pytest",
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        window = parsed.time_range
        assert window.start_date == "2026-04-01"
        assert window.end_date == "2026-04-30"
def test_semantic_ontology_service_treats_travel_amount_question_as_knowledge_query() -> None:
    """A travel-allowance question in a knowledge session is answered as a plain knowledge query."""
    make_session = build_session_factory()
    employee_context = {
        "role_codes": ["employee"],
        "name": "曹笑竹",
        "grade": "P3",
        "session_type": "knowledge",
    }
    request = OntologyParseRequest(
        query="我要去武汉出差3天请问我一共可以报销多少费用",
        user_id="pytest",
        context_json=employee_context,
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        # No slot-filling / clarification round-trip is expected for a knowledge question.
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None
        assert parsed.missing_slots == []
def test_semantic_ontology_service_keeps_travel_amount_follow_up_in_knowledge_query() -> None:
    """A follow-up question after a travel-allowance question stays in the knowledge scenario."""
    make_session = build_session_factory()
    previous_turn = {
        "role": "user",
        "content": "我要去武汉出差3天请问我一共可以报销多少费用",
    }
    follow_up_context = {
        "role_codes": ["employee"],
        "name": "曹笑竹",
        "grade": "P3",
        "session_type": "knowledge",
        "conversation_history": [previous_turn],
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="那P4员工可以报销多少钱",
                user_id="pytest",
                context_json=follow_up_context,
            )
        )
        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        assert parsed.clarification_required is False
def test_semantic_ontology_service_rejects_draft_intent_inside_knowledge_session(
    monkeypatch,
) -> None:
    """A model-proposed expense draft is overridden to a knowledge query in a knowledge session."""
    make_session = build_session_factory()

    def _stub_model_parse(**kwargs) -> LlmOntologyParseResult:
        # Simulates the LLM insisting on an expense draft that needs clarification.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question="请补充招待对象和票据附件。",
            missing_slots=["participants", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    with make_session() as session:
        service = SemanticOntologyService(session)
        monkeypatch.setattr(service, "_parse_with_model", _stub_model_parse)
        knowledge_context = {
            "role_codes": ["employee"],
            "name": "曹笑竹",
            "grade": "P3",
            "session_type": "knowledge",
        }
        parsed = service.parse(
            OntologyParseRequest(
                query="我要去北京出差3天一共可以报销多少钱",
                user_id="pytest",
                context_json=knowledge_context,
            )
        )
        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        # The stub's clarification prompt must not leak into the knowledge answer.
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None
def test_review_next_step_context_inherits_expense_draft_flow() -> None:
    """A review "next_step" action on an existing draft claim continues the expense draft flow."""
    make_session = build_session_factory()
    review_context = {
        "review_action": "next_step",
        "draft_claim_id": "claim-1",
        "attachment_count": 1,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我已核对右侧识别结果,请进入下一步。",
                user_id="pytest",
                context_json=review_context,
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """A first-person customer-entertainment story becomes an expense draft that asks for details."""
    make_session = build_session_factory()
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我今天去客户现场招待了客户花销了1000元",
                user_id="pytest",
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.time_range.raw == "今天"
        # The narrative omits who was entertained, so those slots must be requested.
        assert parsed.clarification_required is True
        assert "customer_name" in parsed.missing_slots
        assert "participants" in parsed.missing_slots
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", "meal") in entity_pairs
def test_semantic_ontology_service_uses_client_local_date_for_relative_time() -> None:
    """"昨天" is resolved against the client's local date, not the server's UTC date."""
    make_session = build_session_factory()
    # 16:30Z with offset -480 (JS getTimezoneOffset sign convention, i.e. UTC+8) is
    # already 2026-05-13 locally, so "yesterday" must resolve to 2026-05-12.
    client_clock = {
        "client_now_iso": "2026-05-12T16:30:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我昨天请客户吃饭花了200元",
                user_id="pytest",
                context_json=client_clock,
            )
        )
        window = parsed.time_range
        assert window.raw == "昨天"
        assert window.start_date == "2026-05-12"
        assert window.end_date == "2026-05-12"
def test_semantic_ontology_service_extracts_day_before_yesterday_from_client_local_date() -> None:
    """"前天" is resolved two days before the client's local date."""
    make_session = build_session_factory()
    # Same clock as the "昨天" case: the local date is 2026-05-13 (UTC+8), so "前天" is 2026-05-11.
    client_clock = {
        "client_now_iso": "2026-05-12T16:30:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我前天请客户吃饭花了200元",
                user_id="pytest",
                context_json=client_clock,
            )
        )
        window = parsed.time_range
        assert window.raw == "前天"
        assert window.start_date == "2026-05-11"
        assert window.end_date == "2026-05-11"
def test_semantic_ontology_service_treats_status_document_text_as_query() -> None:
    """Asking for documents in draft status is a read-only query with a status constraint."""
    make_session = build_session_factory()
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="查询草稿的单据",
                user_id="pytest",
            )
        )
        assert parsed.scenario == "expense"
        # "草稿" here describes a status filter, not a request to write a draft.
        assert parsed.intent == "query"
        assert parsed.permission.level == "read"
        constraint_pairs = [(constraint.field, constraint.value) for constraint in parsed.constraints]
        assert ("status", "draft") in constraint_pairs
def test_semantic_ontology_service_extracts_history_query_time_and_location() -> None:
    """"去年" maps to the previous calendar year and the city is extracted as a location entity."""
    make_session = build_session_factory()
    client_clock = {
        "client_now_iso": "2026-05-21T04:00:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我去年去北京报销的单据",
                user_id="pytest",
                context_json=client_clock,
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        window = parsed.time_range
        assert window.raw == "去年"
        assert window.start_date == "2025-01-01"
        assert window.end_date == "2025-12-31"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("location", "北京") in entity_pairs
def test_semantic_ontology_service_understands_last_week_claim_progress_query() -> None:
    """"上周" resolves to the previous Monday-to-Sunday week in the client's local time."""
    make_session = build_session_factory()
    # 2026-05-21 (local, UTC+8) is a Thursday, so last week spans Mon 05-11 .. Sun 05-17.
    client_clock = {
        "client_now_iso": "2026-05-21T04:00:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="我上周提交的单据报销了么?",
                user_id="pytest",
                context_json=client_clock,
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        window = parsed.time_range
        assert window.raw == "上周"
        assert window.start_date == "2026-05-11"
        assert window.end_date == "2026-05-17"
def test_semantic_ontology_service_maps_office_supplies_to_office_expense_type() -> None:
    """Office supplies and stationery are classified as the "office" expense type."""
    make_session = build_session_factory()
    request = OntologyParseRequest(
        query="我买了办公用品和文具花了88元帮我报销",
        user_id="pytest",
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", "office") in entity_pairs
def test_semantic_ontology_service_maps_riding_fare_to_transport_expense_type() -> None:
    """A ride fare ("乘车费用") is classified as the "transport" expense type."""
    make_session = build_session_factory()
    request = OntologyParseRequest(
        query="业务发生时间:2026-03-04送客户去林萃小区办事请报销乘车费用",
        user_id="pytest",
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", "transport") in entity_pairs
def test_semantic_ontology_service_maps_taxi_ticket_reimbursement_to_transport_draft() -> None:
    """A taxi receipt for driving a customer is transport, not entertainment."""
    make_session = build_session_factory()
    request = OntologyParseRequest(
        query="送客户去机场,报销的士票",
        user_id="pytest",
    )
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", "transport") in entity_pairs
        # Mentioning a customer alone must not pull the claim into entertainment.
        assert ("expense_type", "entertainment") not in entity_pairs
@pytest.mark.parametrize(
    "query,expected_type",
    [
        ("报销飞机票和行程单", "travel"),
        ("报销酒店发票和房费", "hotel"),
        ("报销滴滴打车票", "transport"),
        ("报销工作餐餐费", "meal"),
        ("报销会议场地费", "meeting"),
        ("报销客户接待餐", "meal"),
        ("报销打印纸和硒鼓", "office"),
        ("报销培训课程费", "training"),
        ("报销手机话费和流量费", "communication"),
        ("报销员工体检费", "welfare"),
    ],
)
def test_semantic_ontology_service_covers_common_expense_scene_keywords(
    query: str,
    expected_type: str,
) -> None:
    """Common reimbursement keywords produce an expense draft carrying the matching expense type."""
    make_session = build_session_factory()
    request = OntologyParseRequest(query=query, user_id="pytest")
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(request)
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", expected_type) in entity_pairs
def test_semantic_ontology_service_connects_expense_application_to_ontology() -> None:
    """A pre-approval expense application surfaces its document type, stage and expense type."""
    make_session = build_session_factory()
    application_context = {
        "document_type": "expense_application",
        "application_stage": "pre_approval",
        "entry_source": "documents_application",
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="申请2026-06-01 ~ 2026-06-03去北京做客户现场验收差旅预算18000元",
                user_id="pytest",
                context_json=application_context,
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        constraint_pairs = [(constraint.field, constraint.value) for constraint in parsed.constraints]
        assert ("document_type", "expense_application") in entity_pairs
        assert ("workflow_stage", "pre_approval") in entity_pairs
        # The document type must also act as a filter constraint, not just an entity.
        assert ("document_type", "expense_application") in constraint_pairs
        assert ("expense_type", "travel") in entity_pairs
def test_semantic_ontology_service_requires_attachment_for_meeting_application() -> None:
    """A meeting application submitted with zero attachments lists "attachments" as missing."""
    make_session = build_session_factory()
    application_context = {
        "document_type": "expense_application",
        "application_stage": "pre_approval",
        "entry_source": "documents_application",
        "attachment_count": 0,
    }
    with make_session() as session:
        parsed = SemanticOntologyService(session).parse(
            OntologyParseRequest(
                query="发起会务申请2026-06-01 ~ 2026-06-02上海产品发布会预算32000元",
                user_id="pytest",
                context_json=application_context,
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        entity_pairs = [(entity.type, entity.normalized_value) for entity in parsed.entities]
        assert ("expense_type", "meeting") in entity_pairs
        assert "attachments" in parsed.missing_slots
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
    """When the model parse succeeds, its result (and clarification prompt) is used as primary."""
    make_session = build_session_factory()
    expected_question = "请补充费用类型、金额和票据附件。"

    def _stub_model_parse(**kwargs) -> LlmOntologyParseResult:
        # Simulates an LLM that recognises a bare expense request but needs more details.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question=expected_question,
            missing_slots=["expense_type", "amount", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    with make_session() as session:
        service = SemanticOntologyService(session)
        monkeypatch.setattr(service, "_parse_with_model", _stub_model_parse)
        parsed = service.parse(
            OntologyParseRequest(
                query="我要报销",
                user_id="pytest",
            )
        )
        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.parse_strategy == "llm_primary"
        assert parsed.clarification_required is True
        assert "expense_type" in parsed.missing_slots
        assert parsed.clarification_question == expected_question
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """POST /ontology/parse returns the full parse contract and persists a retrievable agent run."""
    # NOTE(review): despite "eight_fields" in the name, the contract below lists sixteen keys.
    required_fields = {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
    api_client, _unused_factory = build_client()
    parse_response = api_client.post(
        "/api/v1/ontology/parse",
        json={
            "query": "查一下本周报销超标风险",
            "user_id": "pytest",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert parse_response.status_code == 200
    body = parse_response.json()
    assert body["scenario"] == "expense"
    assert body["intent"] == "risk_check"
    assert body["permission"]["level"] == "read"
    assert body["run_id"].startswith("run_")
    missing_fields = required_fields - set(body)
    assert not missing_fields

    # The run id returned by the parse must resolve to a stored trace of the same parse.
    trace_response = api_client.get(f"/api/v1/agent-runs/{body['run_id']}")
    assert trace_response.status_code == 200
    trace = trace_response.json()
    for section in ("ontology_json", "semantic_parse"):
        assert trace[section]["scenario"] == "expense"
        assert trace[section]["intent"] == "risk_check"
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A plain user asking to pay a vendor directly gets a forbidden permission, not an HTTP error."""
    api_client, _unused_factory = build_client()
    request_body = {
        "query": "帮我直接付款给供应商B",
        "user_id": "pytest",
        "context_json": {"role_codes": ["user"]},
    }
    response = api_client.post("/api/v1/ontology/parse", json=request_body)
    # The refusal is carried in the payload; the endpoint itself still answers 200.
    assert response.status_code == 200
    body = response.json()
    assert body["scenario"] == "accounts_payable"
    assert body["intent"] == "operate"
    assert body["permission"]["level"] == "forbidden"
    assert body["clarification_required"] is True
    assert body["field_errors"]