"""Tests for semantic ontology parsing, at both the service layer and the HTTP endpoint.

Each test builds its own engine on a fresh ``sqlite+pysqlite:///:memory:``
database via ``build_session_factory``, so tests do not share database state.
"""

from __future__ import annotations

from collections.abc import Generator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.db.base import Base
from app.main import create_app
from app.schemas.ontology import OntologyParseRequest
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService

# Synthetic user id attached to every parse request issued by this module.
_USER_ID = "pytest"

# Keys the /ontology/parse response must contain (it may carry more).
_EXPECTED_PARSE_FIELDS = frozenset(
    {
        "scenario",
        "intent",
        "entities",
        "time_range",
        "metrics",
        "constraints",
        "risk_flags",
        "permission",
        "confidence",
        "missing_slots",
        "ambiguity",
        "parse_strategy",
        "clarification_required",
        "clarification_question",
        "run_id",
        "field_errors",
    }
)


def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a new in-memory SQLite DB with all tables created.

    ``StaticPool`` keeps a single connection, so every session sees the same
    in-memory database. ``check_same_thread=False`` lets that connection be
    used from threads other than the creating one (e.g. the TestClient app).
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    return sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Create a TestClient whose ``get_db`` dependency is served from a private in-memory DB.

    Returns:
        The client plus the session factory backing it, so callers can query
        the same database the app writes to.
    """
    session_factory = build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        # Exiting the ``with`` closes the session, even if the request raised.
        with session_factory() as session:
            yield session

    app.dependency_overrides[get_db] = override_db
    client = TestClient(app)
    return client, session_factory


def _fresh_session() -> Session:
    """Open a session on a brand-new in-memory database (one per call)."""
    factory = build_session_factory()
    return factory()


def _request(query: str, **fields: object) -> OntologyParseRequest:
    """Build a parse request for ``query`` on behalf of the test user.

    Extra keyword arguments (e.g. ``context_json``) are forwarded unchanged;
    when none are given, the schema's own defaults apply.
    """
    return OntologyParseRequest(query=query, user_id=_USER_ID, **fields)


def _knowledge_context(**extra: object) -> dict[str, object]:
    """Return a fresh context for a grade-P3 employee in a knowledge session.

    Any ``extra`` keys (such as ``conversation_history``) are appended after the base keys.
    """
    return {
        "role_codes": ["employee"],
        "name": "曹笑竹",
        "grade": "P3",
        "session_type": "knowledge",
        **extra,
    }


def _client_clock_context() -> dict[str, object]:
    """Return a fresh context carrying the client's clock: 2026-05-12T16:30Z, offset -480.

    The relative-date tests expect local dates that are only consistent with
    -480 meaning UTC+8, which puts the client's local date at 2026-05-13.
    """
    return {
        "client_now_iso": "2026-05-12T16:30:00.000Z",
        "client_timezone_offset_minutes": -480,
    }


def _stub_model_parse(clarification_question: str, missing_slots: list[str]):
    """Build a stand-in for ``_parse_with_model`` that always proposes expense/draft.

    Each call returns a new ``LlmOntologyParseResult`` with its own list of
    missing slots.
    """

    def fake_parse(**_kwargs: object) -> LlmOntologyParseResult:
        return LlmOntologyParseResult(
            scenario="expense",
            intent="draft",
            confidence=0.91,
            clarification_required=True,
            clarification_question=clarification_question,
            missing_slots=list(missing_slots),
            ambiguity=[],
            entity_hints=[],
        )

    return fake_parse


def _has_expense_type(parsed, value: str) -> bool:
    """Report whether ``parsed.entities`` contains an expense_type entity normalized to ``value``."""
    return any(
        entity.type == "expense_type" and entity.normalized_value == value
        for entity in parsed.entities
    )


def _case(
    case_id: str,
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict[str, object] | None = None,
):
    """Wrap one evaluation row as a ``pytest.param`` (a fresh empty context when none is given)."""
    return pytest.param(
        query,
        scenario,
        intent,
        permission,
        {} if context_json is None else context_json,
        id=case_id,
    )


EVALUATION_CASES = [
    # Expense claims.
    _case("expense-risk-check", "查一下本周报销超标风险", "expense", "risk_check", "read"),
    _case("expense-query-employee-month", "张三 4 月差旅报销金额是多少", "expense", "query", "read"),
    _case("expense-explain-policy", "为什么酒店超标报销不能直接通过", "expense", "explain", "read"),
    _case("expense-topn-query", "列出金额最高的10笔报销", "expense", "query", "read"),
    _case("expense-draft", "帮我生成张三4月差旅报销草稿", "expense", "draft", "draft_write"),
    _case(
        "expense-narrative-draft",
        "我今天去客户现场,招待了客户,花销了1000元",
        "expense",
        "draft",
        "draft_write",
    ),
    # Accounts receivable.
    _case("ar-query-customer-month", "客户 A 这个月还有多少应收", "accounts_receivable", "query", "read"),
    _case("ar-compare-customers", "对比客户A和客户B本月应收差异", "accounts_receivable", "compare", "read"),
    _case("ar-risk-check", "检查客户B逾期应收风险", "accounts_receivable", "risk_check", "read"),
    _case("ar-draft", "生成客户A回款跟进草稿", "accounts_receivable", "draft", "draft_write"),
    _case("ar-aging-query", "查询客户B账龄明细", "accounts_receivable", "query", "read"),
    # Accounts payable.
    _case("ap-query-vendor-tomorrow", "供应商 B 明天要付多少钱", "accounts_payable", "query", "read"),
    _case("ap-compare-vendors", "对比供应商A和供应商B本月应付差异", "accounts_payable", "compare", "read"),
    _case("ap-risk-check", "检查供应商B逾期付款风险", "accounts_payable", "risk_check", "read"),
    _case("ap-draft", "生成供应商A付款沟通草稿", "accounts_payable", "draft", "draft_write"),
    _case(
        "ap-operate-approval-required",
        "帮我安排付款给供应商B",
        "accounts_payable",
        "operate",
        "approval_required",
        {"role_codes": ["finance"]},
    ),
    # Knowledge base.
    _case("knowledge-query", "公司财务制度在哪里看", "knowledge", "query", "read"),
    _case("knowledge-explain", "规则中心的审核依据是什么", "knowledge", "explain", "read"),
    _case("knowledge-query-library", "知识库里有没有双人复核制度", "knowledge", "query", "read"),
    # Operations an unprivileged "user" role must not perform.
    _case(
        "forbidden-direct-payment",
        "帮我直接付款给供应商B",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
    _case(
        "forbidden-activate-rule",
        "帮我上线付款双人复核规则",
        "accounts_payable",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
    _case(
        "forbidden-delete-expense",
        "帮我删除今天的报销记录",
        "expense",
        "operate",
        "forbidden",
        {"role_codes": ["user"]},
    ),
]


@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
def test_semantic_ontology_service_matches_day3_evaluation_set(
    query: str,
    scenario: str,
    intent: str,
    permission: str,
    context_json: dict[str, object],
) -> None:
    """Each evaluation query resolves to the expected scenario, intent and permission level."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(_request(query, context_json=context_json))

        assert parsed.scenario == scenario
        assert parsed.intent == intent
        assert parsed.permission.level == permission
        assert parsed.run_id.startswith("run_")


def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
    """An explicit year-and-month phrase expands to that whole calendar month."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("张三 2026年4月差旅报销金额超过5000元的明细")
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "query"
        # NOTE(review): despite the test name, neither the employee entity nor
        # the amount constraint is asserted here; only the time range is.
        assert parsed.time_range.start_date == "2026-04-01"
        assert parsed.time_range.end_date == "2026-04-30"


def test_semantic_ontology_service_treats_travel_amount_question_as_knowledge_query() -> None:
    """A 'how much can I claim for this trip' question is a knowledge query needing no clarification."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request(
                "我要去武汉出差3天,请问我一共可以报销多少费用?",
                context_json=_knowledge_context(),
            )
        )

        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None
        assert parsed.missing_slots == []


def test_semantic_ontology_service_keeps_travel_amount_follow_up_in_knowledge_query() -> None:
    """A terse follow-up, with the earlier trip question as history, stays a knowledge query."""
    history = [
        {
            "role": "user",
            "content": "我要去武汉出差3天,请问我一共可以报销多少费用?",
        }
    ]
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request(
                "那P4员工可以报销多少钱?",
                context_json=_knowledge_context(conversation_history=history),
            )
        )

        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        assert parsed.clarification_required is False


def test_semantic_ontology_service_rejects_draft_intent_inside_knowledge_session(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """In a knowledge session, a model proposal of expense/draft ends up as knowledge/query."""
    with _fresh_session() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(
            service,
            "_parse_with_model",
            _stub_model_parse("请补充招待对象和票据附件。", ["participants", "attachments"]),
        )

        parsed = service.parse(
            _request(
                "我要去北京出差3天,一共可以报销多少钱?",
                context_json=_knowledge_context(),
            )
        )

        assert parsed.scenario == "knowledge"
        assert parsed.intent == "query"
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None


def test_review_next_step_context_inherits_expense_draft_flow() -> None:
    """A 'proceed to next step' review action on a draft claim continues the expense draft flow."""
    review_context = {
        "review_action": "next_step",
        "draft_claim_id": "claim-1",
        "attachment_count": 1,
    }
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("我已核对右侧识别结果,请进入下一步。", context_json=review_context)
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.clarification_required is False
        assert parsed.clarification_question is None


def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
    """A customer-entertainment narrative is an expense draft that asks for customer and attendees."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("我今天去客户现场,招待了客户,花销了1000元")
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.permission.level == "draft_write"
        assert parsed.time_range.raw == "今天"
        assert parsed.clarification_required is True
        assert "customer_name" in parsed.missing_slots
        assert "participants" in parsed.missing_slots
        assert _has_expense_type(parsed, "entertainment")


def test_semantic_ontology_service_uses_client_local_date_for_relative_time() -> None:
    """'Yesterday' is resolved against the client's local date, not the UTC date."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("我昨天请客户吃饭花了200元", context_json=_client_clock_context())
        )

        assert parsed.time_range.raw == "昨天"
        assert parsed.time_range.start_date == "2026-05-12"
        assert parsed.time_range.end_date == "2026-05-12"


def test_semantic_ontology_service_extracts_day_before_yesterday_from_client_local_date() -> None:
    """'The day before yesterday' is resolved against the client's local date."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("我前天请客户吃饭花了200元", context_json=_client_clock_context())
        )

        assert parsed.time_range.raw == "前天"
        assert parsed.time_range.start_date == "2026-05-11"
        assert parsed.time_range.end_date == "2026-05-11"


def test_semantic_ontology_service_maps_office_supplies_to_office_expense_type() -> None:
    """Office supplies and stationery map to the 'office' expense type."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("我买了办公用品和文具,花了88元,帮我报销")
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert _has_expense_type(parsed, "office")


def test_semantic_ontology_service_maps_riding_fare_to_transport_expense_type() -> None:
    """A ride fare maps to the 'transport' expense type."""
    with _fresh_session() as db:
        parsed = SemanticOntologyService(db).parse(
            _request("业务发生时间:2026-03-04,送客户去林萃小区办事,请报销乘车费用")
        )

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert _has_expense_type(parsed, "transport")


def test_semantic_ontology_service_uses_model_parse_when_available(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When the model parser returns a result, it drives the parse (strategy 'llm_primary')."""
    with _fresh_session() as db:
        service = SemanticOntologyService(db)
        monkeypatch.setattr(
            service,
            "_parse_with_model",
            _stub_model_parse(
                "请补充费用类型、金额和票据附件。",
                ["expense_type", "amount", "attachments"],
            ),
        )

        parsed = service.parse(_request("我要报销"))

        assert parsed.scenario == "expense"
        assert parsed.intent == "draft"
        assert parsed.parse_strategy == "llm_primary"
        assert parsed.clarification_required is True
        assert "expense_type" in parsed.missing_slots
        assert parsed.clarification_question == "请补充费用类型、金额和票据附件。"


def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
    """POST /ontology/parse returns the parse payload and records a retrievable agent run.

    The test name mentions eight fields, but the assertion requires every key
    in ``_EXPECTED_PARSE_FIELDS``.
    """
    client, _session_factory = build_client()

    response = client.post(
        "/api/v1/ontology/parse",
        json={
            "query": "查一下本周报销超标风险",
            "user_id": _USER_ID,
            "context_json": {"role_codes": ["finance"]},
        },
    )
    assert response.status_code == 200
    body = response.json()
    assert body["scenario"] == "expense"
    assert body["intent"] == "risk_check"
    assert body["permission"]["level"] == "read"
    assert body["run_id"].startswith("run_")
    assert set(body) >= _EXPECTED_PARSE_FIELDS

    # The run id from the parse response must resolve to a stored agent run
    # whose ontology_json and semantic_parse sections echo the same parse.
    trace_response = client.get(f"/api/v1/agent-runs/{body['run_id']}")
    assert trace_response.status_code == 200
    trace = trace_response.json()
    for section in ("ontology_json", "semantic_parse"):
        assert trace[section]["scenario"] == "expense"
        assert trace[section]["intent"] == "risk_check"


def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
    """A plain 'user' asking to pay a vendor directly is reported as forbidden.

    The refusal is carried in the response body (permission level, clarification
    flag, field errors); the HTTP status is still 200.
    """
    client, _session_factory = build_client()

    response = client.post(
        "/api/v1/ontology/parse",
        json={
            "query": "帮我直接付款给供应商B",
            "user_id": _USER_ID,
            "context_json": {"role_codes": ["user"]},
        },
    )
    assert response.status_code == 200
    body = response.json()
    assert body["scenario"] == "accounts_payable"
    assert body["intent"] == "operate"
    assert body["permission"]["level"] == "forbidden"
    assert body["clarification_required"] is True
    assert body["field_errors"]