Files
X-Financial/server/tests/test_ontology_service.py
caoxiaozhu 22d47cbf2b feat(backend): add ontology and orchestrator API endpoints
New endpoints:
- server/src/app/api/v1/endpoints/ontology.py: ontology API
- server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API

New schemas:
- server/src/app/schemas/ontology.py: ontology data schemas
- server/src/app/schemas/orchestrator.py: orchestrator data schemas
- server/src/app/schemas/user_agent.py: user agent data schemas

New services:
- server/src/app/services/ontology.py: ontology business logic
- server/src/app/services/orchestrator.py: orchestrator business logic
- server/src/app/services/runtime_chat.py: runtime chat service
- server/src/app/services/user_agent.py: user agent service

New tests:
- server/tests/test_ontology_service.py
- server/tests/test_orchestrator_service.py
- server/tests/test_user_agent_service.py
2026-05-12 01:24:39 +00:00

398 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from collections.abc import Generator
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ontology import OntologyParseRequest
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory backed by a brand-new in-memory SQLite database.

    Every call creates its own engine, creates all tables registered on
    ``Base.metadata`` against it, and hands back a ``sessionmaker`` bound to
    that engine (autoflush and autocommit both disabled).
    """
    # NOTE(review): StaticPool plus check_same_thread=False presumably lets the
    # single in-memory connection be reused from other threads (e.g. the
    # TestClient) -- confirm against the SQLAlchemy SQLite docs.
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)

    factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return factory
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a ``TestClient`` whose ``get_db`` dependency uses a test database.

    Returns:
        A ``(client, session_factory)`` pair. The factory is the one the
        application's overridden ``get_db`` dependency draws sessions from.
    """
    session_factory = build_session_factory()

    def override_db() -> Generator[Session, None, None]:
        # Leaving the ``with`` block closes the session once the request that
        # consumed the dependency has finished.
        with session_factory() as session:
            yield session

    application = create_app()
    application.dependency_overrides[get_db] = override_db
    test_client = TestClient(application)
    return test_client, session_factory
# Parametrized evaluation set. Each row is
# (case id, query, expected scenario, expected intent,
#  expected permission level, context_json passed with the request).
EVALUATION_CASES = [
    pytest.param(query, scenario, intent, permission, context_json, id=case_id)
    for case_id, query, scenario, intent, permission, context_json in (
        # --- expense ---
        ("expense-risk-check", "查一下本周报销超标风险", "expense", "risk_check", "read", {}),
        (
            "expense-query-employee-month",
            "张三 4 月差旅报销金额是多少",
            "expense",
            "query",
            "read",
            {},
        ),
        (
            "expense-explain-policy",
            "为什么酒店超标报销不能直接通过",
            "expense",
            "explain",
            "read",
            {},
        ),
        ("expense-topn-query", "列出金额最高的10笔报销", "expense", "query", "read", {}),
        ("expense-draft", "帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write", {}),
        (
            "expense-narrative-draft",
            "我今天去客户现场招待了客户花销了1000元",
            "expense",
            "draft",
            "draft_write",
            {},
        ),
        # --- accounts receivable ---
        (
            "ar-query-customer-month",
            "客户 A 这个月还有多少应收",
            "accounts_receivable",
            "query",
            "read",
            {},
        ),
        (
            "ar-compare-customers",
            "对比客户A和客户B本月应收差异",
            "accounts_receivable",
            "compare",
            "read",
            {},
        ),
        ("ar-risk-check", "检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read", {}),
        ("ar-draft", "生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write", {}),
        ("ar-aging-query", "查询客户B账龄明细", "accounts_receivable", "query", "read", {}),
        # --- accounts payable ---
        (
            "ap-query-vendor-tomorrow",
            "供应商 B 明天要付多少钱",
            "accounts_payable",
            "query",
            "read",
            {},
        ),
        (
            "ap-compare-vendors",
            "对比供应商A和供应商B本月应付差异",
            "accounts_payable",
            "compare",
            "read",
            {},
        ),
        ("ap-risk-check", "检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read", {}),
        ("ap-draft", "生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write", {}),
        (
            "ap-operate-approval-required",
            "帮我安排付款给供应商B",
            "accounts_payable",
            "operate",
            "approval_required",
            {"role_codes": ["finance"]},
        ),
        # --- knowledge ---
        ("knowledge-query", "公司财务制度在哪里看", "knowledge", "query", "read", {}),
        ("knowledge-explain", "规则中心的审核依据是什么", "knowledge", "explain", "read", {}),
        ("knowledge-query-library", "知识库里有没有双人复核制度", "knowledge", "query", "read", {}),
        # --- operations expected to be forbidden for the plain "user" role ---
        (
            "forbidden-direct-payment",
            "帮我直接付款给供应商B",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-activate-rule",
            "帮我上线付款双人复核规则",
            "accounts_payable",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
        (
            "forbidden-delete-expense",
            "帮我删除今天的报销记录",
            "expense",
            "operate",
            "forbidden",
            {"role_codes": ["user"]},
        ),
    )
]
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict,
) -> None:
    """Parse one evaluation query and compare it with the expected labels.

    For every entry of ``EVALUATION_CASES`` the parsed scenario, intent and
    permission level must equal the expected values, and the returned run id
    must carry the ``run_`` prefix.
    """
    session_factory = build_session_factory()
    request = OntologyParseRequest(
        query=query,
        user_id="pytest",
        context_json=context_json,
    )

    with session_factory() as db:
        parsed = SemanticOntologyService(db).parse(request)

        assert parsed.scenario == scenario
        assert parsed.intent == intent
        assert parsed.permission.level == permission
        assert parsed.run_id.startswith("run_")
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """Check entity, time-range and constraint extraction on a single query.

    The query names an employee, a month (2026-04), an expense type and an
    amount threshold; each must show up in the corresponding part of the
    parse result.
    """
    request = OntologyParseRequest(
        query="张三 2026年4月差旅报销金额超过5000元的明细",
        user_id="pytest",
    )
    session_factory = build_session_factory()

    with session_factory() as db:
        parsed = SemanticOntologyService(db).parse(request)

        def has_entity(entity_type: str, normalized: str) -> bool:
            # True when at least one extracted entity matches both fields.
            return any(
                entity.type == entity_type and entity.normalized_value == normalized
                for entity in parsed.entities
            )

        def has_constraint(field: str, operator: str, value: int) -> bool:
            # True when at least one extracted constraint matches all three fields.
            return any(
                constraint.field == field
                and constraint.operator == operator
                and constraint.value == value
                for constraint in parsed.constraints
            )

        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        assert parsed.time_range.start_date == "2026-04-01"
        assert parsed.time_range.end_date == "2026-04-30"
        assert has_entity("employee", "张三")
        assert has_entity("expense_type", "travel")
        assert has_constraint("amount", ">", 5000)
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """A narrative that mentions a customer must still parse as an expense draft.

    Besides scenario, intent and permission, the result must keep the raw time
    expression, ask for clarification, report ``customer_name`` and
    ``participants`` as missing slots, and include an ``entertainment``
    expense-type entity.
    """
    request = OntologyParseRequest(
        query="我今天去客户现场招待了客户花销了1000元",
        user_id="pytest",
    )
    session_factory = build_session_factory()

    with session_factory() as db:
        parsed = SemanticOntologyService(db).parse(request)

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.time_range.raw == "今天"
        assert parsed.clarification_required is True
        for slot in ("customer_name", "participants"):
            assert slot in parsed.missing_slots

        entertainment_entities = [
            entity
            for entity in parsed.entities
            if entity.type == "expense_type" and entity.normalized_value == "entertainment"
        ]
        assert entertainment_entities
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
    """When ``_parse_with_model`` yields a result, the service must use it.

    ``_parse_with_model`` is replaced on the service instance with a stub that
    returns a fixed ``LlmOntologyParseResult``; the final parse result must
    reflect that stub and report the ``llm_primary`` strategy.
    """
    follow_up_question = "请补充费用类型、金额和票据附件。"

    def fake_model_parse(**kwargs):
        # Accepts (and ignores) whatever keyword arguments the service passes.
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question=follow_up_question,
            missing_slots=["expense_type", "amount", "attachments"],
            ambiguity=[],
            entity_hints=[],
        )

    session_factory = build_session_factory()
    with session_factory() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(service, "_parse_with_model", fake_model_parse)

        parsed = service.parse(
            OntologyParseRequest(
                query="我要报销",
                user_id="pytest",
            )
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.parse_strategy == "llm_primary"
        assert parsed.clarification_required is True
        assert "expense_type" in parsed.missing_slots
        assert parsed.clarification_question == follow_up_question
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """POST /api/v1/ontology/parse returns the full payload and records a run.

    Despite the test name, the key check below covers sixteen top-level keys.
    The run id from the response is then used to fetch the agent run and
    verify that the parse was stored under both ``ontology_json`` and
    ``semantic_parse``.
    """
    expected_keys = {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
    request_body = {
        "query": "查一下本周报销超标风险",
        "user_id": "pytest",
        "context_json": {"role_codes": ["finance"]},
    }

    client, _ = build_client()
    response = client.post("/api/v1/ontology/parse", json=request_body)
    assert response.status_code == 200

    payload = response.json()
    assert payload["scenario"] == "expense"
    assert payload["intent"] == "risk_check"
    assert payload["permission"]["level"] == "read"
    assert payload["run_id"].startswith("run_")
    assert expected_keys.issubset(payload)

    # The parse call must have left a retrievable agent-run record behind.
    run_url = f"/api/v1/agent-runs/{payload['run_id']}"
    run_response = client.get(run_url)
    assert run_response.status_code == 200

    run_payload = run_response.json()
    assert run_payload["ontology_json"]["scenario"] == "expense"
    assert run_payload["ontology_json"]["intent"] == "risk_check"
    assert run_payload["semantic_parse"]["scenario"] == "expense"
    assert run_payload["semantic_parse"]["intent"] == "risk_check"
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A direct-payment request sent with the ``user`` role is parsed as forbidden.

    The endpoint still answers 200; the refusal is expressed in the payload:
    permission level ``forbidden``, clarification required, and a non-empty
    ``field_errors`` value.
    """
    request_body = {
        "query": "帮我直接付款给供应商B",
        "user_id": "pytest",
        "context_json": {"role_codes": ["user"]},
    }

    client, _ = build_client()
    response = client.post("/api/v1/ontology/parse", json=request_body)
    assert response.status_code == 200

    payload = response.json()
    permission = payload["permission"]
    assert payload["scenario"] == "accounts_payable"
    assert payload["intent"] == "operate"
    assert permission["level"] == "forbidden"
    assert payload["clarification_required"] is True
    assert payload["field_errors"]