Files
X-Financial/server/tests/test_ontology_service.py
caoxiaozhu 88ff04bef8 feat: 新增归档中心页面并完善知识库与报销查询能力
新增前端归档中心视图及相关工具函数,扩充知识库文档分类和提取器支持多种格式,增强编排器报销查询的多维度检索,优化本体规则和用户代理审核消息,前端完善报销创建和审批详情交互细节,补充单元测试覆盖。
2026-05-22 16:00:19 +08:00

659 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ontology import OntologyParseRequest
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a brand-new in-memory SQLite database.

    Every call creates its own engine, so each caller gets an isolated, empty
    database with all tables from ``Base.metadata`` already created.

    ``StaticPool`` keeps a single connection, so all sessions see the same
    in-memory database. ``check_same_thread=False`` lets that connection be
    used from threads other than the creating one (e.g. a FastAPI threadpool).
    """
    sqlite_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=sqlite_engine)
    factory = sessionmaker(bind=sqlite_engine, autocommit=False, autoflush=False)
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient whose ``get_db`` dependency uses a fresh in-memory DB.

    Returns:
        A ``(client, session_factory)`` pair. The factory is the one backing the
        client's database, so a test can open its own sessions against the same
        data (e.g. to inspect rows written by a request).
    """
    factory = build_session_factory()
    application = create_app()

    def _request_scoped_session() -> Generator[Session, None, None]:
        # One session per request; leaving the ``with`` closes it even when the
        # endpoint raises.
        with factory() as session:
            yield session

    application.dependency_overrides[get_db] = _request_scoped_session
    client = TestClient(application)
    return client, factory
# Day-3 evaluation set for ``SemanticOntologyService.parse``.
# Each row is: (case id, query, expected scenario, expected intent,
# expected permission level, request context_json). The rows are expanded into
# ``pytest.param`` objects so the parametrised test receives the last five
# values positionally and pytest reports every case under its id.
EVALUATION_CASES = [
    pytest.param(query, scenario, intent, permission, context_json, id=case_id)
    for case_id, query, scenario, intent, permission, context_json in (
        # --- expense (报销) ---
        ("expense-risk-check", "查一下本周报销超标风险", "expense", "risk_check", "read", {}),
        ("expense-query-employee-month", "张三 4 月差旅报销金额是多少", "expense", "query", "read", {}),
        ("expense-explain-policy", "为什么酒店超标报销不能直接通过", "expense", "explain", "read", {}),
        ("expense-topn-query", "列出金额最高的10笔报销", "expense", "query", "read", {}),
        ("expense-draft", "帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write", {}),
        (
            "expense-narrative-draft",
            "我今天去客户现场招待了客户花销了1000元",
            "expense",
            "draft",
            "draft_write",
            {},
        ),
        # --- accounts receivable (应收) ---
        (
            "ar-query-customer-month",
            "客户 A 这个月还有多少应收",
            "accounts_receivable",
            "query",
            "read",
            {},
        ),
        (
            "ar-compare-customers",
            "对比客户A和客户B本月应收差异",
            "accounts_receivable",
            "compare",
            "read",
            {},
        ),
        ("ar-risk-check", "检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read", {}),
        ("ar-draft", "生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write", {}),
        ("ar-aging-query", "查询客户B账龄明细", "accounts_receivable", "query", "read", {}),
        # --- accounts payable (应付) ---
        (
            "ap-query-vendor-tomorrow",
            "供应商 B 明天要付多少钱",
            "accounts_payable",
            "query",
            "read",
            {},
        ),
        (
            "ap-compare-vendors",
            "对比供应商A和供应商B本月应付差异",
            "accounts_payable",
            "compare",
            "read",
            {},
        ),
        ("ap-risk-check", "检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read", {}),
        ("ap-draft", "生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write", {}),
        # A finance-role payment arrangement is expected to need approval
        # rather than be forbidden.
        (
            "ap-operate-approval-required",
            "帮我安排付款给供应商B",
            "accounts_payable",
            "operate",
            "approval_required",
            {"role_codes": ["finance"]},
        ),
        # --- knowledge base (知识库) ---
        ("knowledge-query", "公司财务制度在哪里看", "knowledge", "query", "read", {}),
        ("knowledge-explain", "规则中心的审核依据是什么", "knowledge", "explain", "read", {}),
        ("knowledge-query-library", "知识库里有没有双人复核制度", "knowledge", "query", "read", {}),
        # --- operations requested by a plain "user" role are expected to be forbidden ---
        (
            "forbidden-direct-payment",
            "帮我直接付款给供应商B",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-activate-rule",
            "帮我上线付款双人复核规则",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-delete-expense",
            "帮我删除今天的报销记录",
            "expense",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
    )
]
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict,
) -> None:
    """Each evaluation query parses to its expected scenario, intent and permission level.

    Every case gets its own in-memory database, so cases cannot affect one another.
    """
    request = OntologyParseRequest(
        query=query,
        user_id="pytest",
        context_json=context_json,
    )
    make_session = build_session_factory()
    with make_session() as db:
        parsed = SemanticOntologyService(db).parse(request)

        assert parsed.scenario == scenario
        assert parsed.intent == intent
        assert parsed.permission.level == permission
        # Every parse result carries a run id with the "run_" prefix.
        assert parsed.run_id.startswith("run_")
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """An explicit "2026年4月" in an expense query resolves to the whole of April 2026.

    NOTE(review): the test name also promises entity and constraint checks
    (张三, 差旅, "超过5000元" = "over 5000 yuan"), but only scenario, intent and
    time range are asserted. Add entity/constraint assertions once their
    normalized values are confirmed.
    """
    factory = build_session_factory()
    with factory() as db:
        service = SemanticOntologyService(db)
        result = service.parse(
            OntologyParseRequest(
                query="张三 2026年4月差旅报销金额超过5000元的明细",
                user_id="pytest",
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "query"
        window = result.time_range
        assert window.start_date == "2026-04-01"
        assert window.end_date == "2026-04-30"
def test_semantic_ontology_service_treats_travel_amount_question_as_knowledge_query() -> None:
    """A trip-allowance question in a knowledge session is a plain knowledge query.

    The query asks "I'm going to Wuhan for a 3-day business trip; how much can I
    claim in total?". The expected result is a knowledge query with no
    clarification question and no missing slots.
    """
    employee_context = {
        "role_codes": ["employee"],
        "name": "曹笑竹",
        "grade": "P3",
        "session_type": "knowledge",
    }
    factory = build_session_factory()
    with factory() as db:
        parser = SemanticOntologyService(db)
        result = parser.parse(
            OntologyParseRequest(
                query="我要去武汉出差3天请问我一共可以报销多少费用",
                user_id="pytest",
                context_json=employee_context,
            )
        )

        assert result.scenario == "knowledge"
        assert result.intent == "query"
        assert result.clarification_required is False
        assert result.clarification_question is None
        assert result.missing_slots == []
def test_semantic_ontology_service_keeps_travel_amount_follow_up_in_knowledge_query() -> None:
    """A follow-up allowance question stays a knowledge query without clarification.

    The follow-up "那P4员工可以报销多少钱" ("then how much can a P4 employee
    claim?") comes after the user already asked about a Wuhan trip, as recorded
    in ``conversation_history``. The result must remain a knowledge query. Like
    the sibling knowledge-session tests, no clarification is requested, so no
    clarification question should be produced either.
    """
    session_factory = build_session_factory()
    with session_factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="那P4员工可以报销多少钱",
                user_id="pytest",
                context_json={
                    "role_codes": ["employee"],
                    "name": "曹笑竹",
                    "grade": "P3",
                    "session_type": "knowledge",
                    "conversation_history": [
                        {
                            "role": "user",
                            "content": "我要去武汉出差3天请问我一共可以报销多少费用",
                        }
                    ],
                },
            )
        )

        assert result.scenario == "knowledge"
        assert result.intent == "query"
        assert result.clarification_required is False
        # Mirrors the first-turn test: "no clarification" must also mean no
        # stray clarification question leaks into the response.
        assert result.clarification_question is None
def test_semantic_ontology_service_rejects_draft_intent_inside_knowledge_session(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A knowledge session wins over a model parse that claims an expense draft.

    The stubbed model returns scenario "expense", intent "draft" and a
    clarification question. Because the request's ``session_type`` is
    "knowledge", the final result must be a knowledge query with no
    clarification.
    """

    def fake_model_parse(**_kwargs) -> LlmOntologyParseResult:
        # Stands in for the LLM call; returns a fresh result on every call,
        # just like the original inline lambda.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question="请补充招待对象和票据附件。",
            missing_slots=["participants", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    factory = build_session_factory()
    with factory() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(service, "_parse_with_model", fake_model_parse)

        request = OntologyParseRequest(
            query="我要去北京出差3天一共可以报销多少钱",
            user_id="pytest",
            context_json={
                "role_codes": ["employee"],
                "name": "曹笑竹",
                "grade": "P3",
                "session_type": "knowledge",
            },
        )
        result = service.parse(request)

        assert result.scenario == "knowledge"
        assert result.intent == "query"
        assert result.clarification_required is False
        assert result.clarification_question is None
def test_review_next_step_context_inherits_expense_draft_flow() -> None:
    """Confirming the review of an existing draft claim continues the expense draft flow.

    The query only says the recognition results on the right were checked and
    asks to move on. The expense/draft classification and the draft-write
    permission are presumably derived from the review context
    (``review_action="next_step"`` on ``claim-1`` with one attachment).
    """
    review_context = {
        "review_action": "next_step",
        "draft_claim_id": "claim-1",
        "attachment_count": 1,
    }
    factory = build_session_factory()
    with factory() as db:
        outcome = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我已核对右侧识别结果,请进入下一步。",
                user_id="pytest",
                context_json=review_context,
            )
        )

        assert outcome.scenario == "expense"
        assert outcome.intent == "draft"
        assert outcome.permission.level == "draft_write"
        assert outcome.clarification_required is False
        assert outcome.clarification_question is None
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """Entertaining a client and spending money is an expense draft, even though "客户" appears.

    The narrative ("I entertained a client on site today and spent 1000 yuan")
    must produce an entertainment expense draft dated "今天" (today). It must
    also ask for clarification, with the customer name and participants listed
    as missing.
    """
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我今天去客户现场招待了客户花销了1000元",
                user_id="pytest",
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "draft"
        assert result.permission.level == "draft_write"
        assert result.time_range.raw == "今天"
        assert result.clarification_required is True

        missing = result.missing_slots
        assert "customer_name" in missing
        assert "participants" in missing

        expense_types = [
            entity.normalized_value
            for entity in result.entities
            if entity.type == "expense_type"
        ]
        assert "entertainment" in expense_types
def test_semantic_ontology_service_uses_client_local_date_for_relative_time() -> None:
    """"昨天" (yesterday) is resolved against the client's local date, not UTC.

    The client clock reads 2026-05-12T16:30Z with a timezone offset of -480.
    The expected date implies the offset follows JavaScript's getTimezoneOffset
    sign (-480 means UTC+8), making the local date 2026-05-13. "Yesterday" is
    therefore 2026-05-12, not the UTC-based 2026-05-11.
    """
    client_clock = {
        "client_now_iso": "2026-05-12T16:30:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我昨天请客户吃饭花了200元",
                user_id="pytest",
                context_json=client_clock,
            )
        )

        span = result.time_range
        assert span.raw == "昨天"
        assert span.start_date == "2026-05-12"
        assert span.end_date == "2026-05-12"
def test_semantic_ontology_service_extracts_day_before_yesterday_from_client_local_date() -> None:
    """"前天" (the day before yesterday) is resolved against the client's local date.

    Same client clock as the "昨天" test: local date 2026-05-13, so two days
    earlier is 2026-05-11.
    """
    client_clock = {
        "client_now_iso": "2026-05-12T16:30:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我前天请客户吃饭花了200元",
                user_id="pytest",
                context_json=client_clock,
            )
        )

        span = result.time_range
        assert span.raw == "前天"
        assert span.start_date == "2026-05-11"
        assert span.end_date == "2026-05-11"
def test_semantic_ontology_service_treats_status_document_text_as_query() -> None:
    """"查询草稿的单据" (query draft documents) is a read-only query filtered by status.

    The word "草稿" (draft) here names a document status, so it must become a
    ``status == "draft"`` constraint rather than a draft-creation intent.
    """
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(query="查询草稿的单据", user_id="pytest")
        )

        assert result.scenario == "expense"
        assert result.intent == "query"
        assert result.permission.level == "read"

        status_filters = [
            constraint.value
            for constraint in result.constraints
            if constraint.field == "status"
        ]
        assert "draft" in status_filters
def test_semantic_ontology_service_extracts_history_query_time_and_location() -> None:
    """A historical claim lookup extracts "去年" (last year) and the location "北京".

    Query: "claims from my trip to Beijing last year". With the client's local
    date in 2026, "去年" must span all of 2025. Beijing must appear as a
    normalized location entity.
    """
    client_clock = {
        "client_now_iso": "2026-05-21T04:00:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我去年去北京报销的单据",
                user_id="pytest",
                context_json=client_clock,
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "query"

        span = result.time_range
        assert span.raw == "去年"
        assert span.start_date == "2025-01-01"
        assert span.end_date == "2025-12-31"

        locations = [
            entity.normalized_value
            for entity in result.entities
            if entity.type == "location"
        ]
        assert "北京" in locations
def test_semantic_ontology_service_understands_last_week_claim_progress_query() -> None:
    """"上周" (last week) resolves to the previous Monday-to-Sunday week.

    Query: "Have the claims I submitted last week been reimbursed?". The client's
    local date is Thursday 2026-05-21, so last week is 2026-05-11 (Mon) to
    2026-05-17 (Sun).
    """
    client_clock = {
        "client_now_iso": "2026-05-21T04:00:00.000Z",
        "client_timezone_offset_minutes": -480,
    }
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我上周提交的单据报销了么?",
                user_id="pytest",
                context_json=client_clock,
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "query"

        span = result.time_range
        assert span.raw == "上周"
        assert span.start_date == "2026-05-11"
        assert span.end_date == "2026-05-17"
def test_semantic_ontology_service_maps_office_supplies_to_office_expense_type() -> None:
    """Office supplies and stationery map to the "office" expense type in a draft.

    Query: "I bought office supplies and stationery for 88 yuan, please file a
    claim".
    """
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="我买了办公用品和文具花了88元帮我报销",
                user_id="pytest",
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "draft"

        expense_types = [
            entity.normalized_value
            for entity in result.entities
            if entity.type == "expense_type"
        ]
        assert "office" in expense_types
def test_semantic_ontology_service_maps_riding_fare_to_transport_expense_type() -> None:
    """A ride fare ("乘车费用") maps to the "transport" expense type in a draft.

    Query: a business date of 2026-03-04, driving a client to the Lincui
    residential area, with a request to reimburse the ride fare.
    """
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="业务发生时间:2026-03-04送客户去林萃小区办事请报销乘车费用",
                user_id="pytest",
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "draft"

        expense_types = [
            entity.normalized_value
            for entity in result.entities
            if entity.type == "expense_type"
        ]
        assert "transport" in expense_types
def test_semantic_ontology_service_maps_taxi_ticket_reimbursement_to_transport_draft() -> None:
    """A taxi receipt ("的士票") is a transport expense even when a client is mentioned.

    Query: "took a client to the airport, reimburse the taxi receipt". The
    mention of a client must not add an "entertainment" expense type.
    """
    factory = build_session_factory()
    with factory() as db:
        result = SemanticOntologyService(db).parse(
            OntologyParseRequest(
                query="送客户去机场,报销的士票",
                user_id="pytest",
            )
        )

        assert result.scenario == "expense"
        assert result.intent == "draft"

        expense_types = [
            entity.normalized_value
            for entity in result.entities
            if entity.type == "expense_type"
        ]
        assert "transport" in expense_types
        assert "entertainment" not in expense_types
def test_semantic_ontology_service_uses_model_parse_when_available(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the model parse is available, its result drives the output.

    ``_parse_with_model`` is stubbed to return an expense draft that needs
    clarification. For the bare query "我要报销" ("I want to file a claim"),
    the service must report ``parse_strategy == "llm_primary"``. It must also
    pass through the model's clarification question and missing slots.
    """

    def fake_model_parse(**_kwargs) -> LlmOntologyParseResult:
        # Stands in for the LLM call with a fixed, clarification-seeking answer.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question="请补充费用类型、金额和票据附件。",
            missing_slots=["expense_type", "amount", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    factory = build_session_factory()
    with factory() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(service, "_parse_with_model", fake_model_parse)

        result = service.parse(OntologyParseRequest(query="我要报销", user_id="pytest"))

        assert result.scenario == "expense"
        assert result.intent == "draft"
        assert result.parse_strategy == "llm_primary"
        assert result.clarification_required is True
        assert "expense_type" in result.missing_slots
        assert result.clarification_question == "请补充费用类型、金额和票据附件。"
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """POST /api/v1/ontology/parse returns the full parse payload and records an agent run.

    The returned ``run_id`` must be retrievable from /api/v1/agent-runs. The
    stored run must hold the same scenario/intent under both ``ontology_json``
    and ``semantic_parse``.

    NOTE(review): despite "eight_fields" in the name, sixteen top-level response
    fields are checked; the name looks stale.
    """
    required_fields = {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
    client, _ = build_client()

    parse_response = client.post(
        "/api/v1/ontology/parse",
        json={
            "query": "查一下本周报销超标风险",
            "user_id": "pytest",
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert parse_response.status_code == 200
    body = parse_response.json()
    assert body["scenario"] == "expense"
    assert body["intent"] == "risk_check"
    assert body["permission"]["level"] == "read"
    assert body["run_id"].startswith("run_")
    # Extra fields are allowed; only absence of a required one fails.
    assert required_fields.issubset(body)

    trace_response = client.get(f"/api/v1/agent-runs/{body['run_id']}")
    assert trace_response.status_code == 200
    trace = trace_response.json()
    for section in ("ontology_json", "semantic_parse"):
        assert trace[section]["scenario"] == "expense"
        assert trace[section]["intent"] == "risk_check"
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A plain "user" asking to pay a vendor directly is parsed as a forbidden operation.

    The endpoint still answers 200. The payload flags the request as a
    forbidden accounts-payable operation that needs clarification and carries
    at least one field error.
    """
    client, _ = build_client()
    request_body = {
        "query": "帮我直接付款给供应商B",
        "user_id": "pytest",
        "context_json": {"role_codes": ["user"]},
    }

    response = client.post("/api/v1/ontology/parse", json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["scenario"] == "accounts_payable"
    assert body["intent"] == "operate"
    assert body["permission"]["level"] == "forbidden"
    assert body["clarification_required"] is True
    assert body["field_errors"]