feat(backend): add ontology and orchestrator API endpoints
New endpoints: - server/src/app/api/v1/endpoints/ontology.py: ontology API - server/src/app/api/v1/endpoints/orchestrator.py: orchestrator API New schemas: - server/src/app/schemas/ontology.py: ontology data schemas - server/src/app/schemas/orchestrator.py: orchestrator data schemas - server/src/app/schemas/user_agent.py: user agent data schemas New services: - server/src/app/services/ontology.py: ontology business logic - server/src/app/services/orchestrator.py: orchestrator business logic - server/src/app/services/runtime_chat.py: runtime chat service - server/src/app/services/user_agent.py: user agent service New tests: - server/tests/test_ontology_service.py - server/tests/test_orchestrator_service.py - server/tests/test_user_agent_service.py
This commit is contained in:
397
server/tests/test_ontology_service.py
Normal file
397
server/tests/test_ontology_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.schemas.ontology import OntologyParseRequest
|
||||
from app.services.ontology import LlmOntologyParseResult, SemanticOntologyService
|
||||
|
||||
|
||||
def build_session_factory() -> sessionmaker[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
|
||||
session_factory = build_session_factory()
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
EVALUATION_CASES = [
|
||||
pytest.param(
|
||||
"查一下本周报销超标风险",
|
||||
"expense",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="expense-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"张三 4 月差旅报销金额是多少",
|
||||
"expense",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="expense-query-employee-month",
|
||||
),
|
||||
pytest.param(
|
||||
"为什么酒店超标报销不能直接通过",
|
||||
"expense",
|
||||
"explain",
|
||||
"read",
|
||||
{},
|
||||
id="expense-explain-policy",
|
||||
),
|
||||
pytest.param(
|
||||
"列出金额最高的10笔报销",
|
||||
"expense",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="expense-topn-query",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我生成张三4月差旅报销草稿",
|
||||
"expense",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="expense-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"我今天去客户现场,招待了客户,花销了1000元",
|
||||
"expense",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="expense-narrative-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"客户 A 这个月还有多少应收",
|
||||
"accounts_receivable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ar-query-customer-month",
|
||||
),
|
||||
pytest.param(
|
||||
"对比客户A和客户B本月应收差异",
|
||||
"accounts_receivable",
|
||||
"compare",
|
||||
"read",
|
||||
{},
|
||||
id="ar-compare-customers",
|
||||
),
|
||||
pytest.param(
|
||||
"检查客户B逾期应收风险",
|
||||
"accounts_receivable",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="ar-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"生成客户A回款跟进草稿",
|
||||
"accounts_receivable",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="ar-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"查询客户B账龄明细",
|
||||
"accounts_receivable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ar-aging-query",
|
||||
),
|
||||
pytest.param(
|
||||
"供应商 B 明天要付多少钱",
|
||||
"accounts_payable",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="ap-query-vendor-tomorrow",
|
||||
),
|
||||
pytest.param(
|
||||
"对比供应商A和供应商B本月应付差异",
|
||||
"accounts_payable",
|
||||
"compare",
|
||||
"read",
|
||||
{},
|
||||
id="ap-compare-vendors",
|
||||
),
|
||||
pytest.param(
|
||||
"检查供应商B逾期付款风险",
|
||||
"accounts_payable",
|
||||
"risk_check",
|
||||
"read",
|
||||
{},
|
||||
id="ap-risk-check",
|
||||
),
|
||||
pytest.param(
|
||||
"生成供应商A付款沟通草稿",
|
||||
"accounts_payable",
|
||||
"draft",
|
||||
"draft_write",
|
||||
{},
|
||||
id="ap-draft",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我安排付款给供应商B",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"approval_required",
|
||||
{"role_codes": ["finance"]},
|
||||
id="ap-operate-approval-required",
|
||||
),
|
||||
pytest.param(
|
||||
"公司财务制度在哪里看",
|
||||
"knowledge",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-query",
|
||||
),
|
||||
pytest.param(
|
||||
"规则中心的审核依据是什么",
|
||||
"knowledge",
|
||||
"explain",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-explain",
|
||||
),
|
||||
pytest.param(
|
||||
"知识库里有没有双人复核制度",
|
||||
"knowledge",
|
||||
"query",
|
||||
"read",
|
||||
{},
|
||||
id="knowledge-query-library",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我直接付款给供应商B",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-direct-payment",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我上线付款双人复核规则",
|
||||
"accounts_payable",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-activate-rule",
|
||||
),
|
||||
pytest.param(
|
||||
"帮我删除今天的报销记录",
|
||||
"expense",
|
||||
"operate",
|
||||
"forbidden",
|
||||
{"role_codes": ["user"]},
|
||||
id="forbidden-delete-expense",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("query,scenario,intent,permission,context_json", EVALUATION_CASES)
|
||||
def test_semantic_ontology_service_matches_day3_evaluation_set(
|
||||
query: str,
|
||||
scenario: str,
|
||||
intent: str,
|
||||
permission: str,
|
||||
context_json: dict,
|
||||
) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query=query,
|
||||
user_id="pytest",
|
||||
context_json=context_json,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == scenario
|
||||
assert result.intent == intent
|
||||
assert result.permission.level == permission
|
||||
assert result.run_id.startswith("run_")
|
||||
|
||||
|
||||
def test_semantic_ontology_service_extracts_entities_time_and_constraints() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 2026年4月差旅报销金额超过5000元的明细",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "query"
|
||||
assert result.time_range.start_date == "2026-04-01"
|
||||
assert result.time_range.end_date == "2026-04-30"
|
||||
assert any(
|
||||
item.type == "employee" and item.normalized_value == "张三"
|
||||
for item in result.entities
|
||||
)
|
||||
assert any(
|
||||
item.type == "expense_type" and item.normalized_value == "travel"
|
||||
for item in result.entities
|
||||
)
|
||||
assert any(
|
||||
item.field == "amount" and item.operator == ">" and item.value == 5000
|
||||
for item in result.constraints
|
||||
)
|
||||
|
||||
|
||||
def test_semantic_ontology_service_prefers_expense_for_customer_entertainment_narrative() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
result = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我今天去客户现场,招待了客户,花销了1000元",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "draft"
|
||||
assert result.permission.level == "draft_write"
|
||||
assert result.time_range.raw == "今天"
|
||||
assert result.clarification_required is True
|
||||
assert "customer_name" in result.missing_slots
|
||||
assert "participants" in result.missing_slots
|
||||
assert any(
|
||||
item.type == "expense_type" and item.normalized_value == "entertainment"
|
||||
for item in result.entities
|
||||
)
|
||||
|
||||
|
||||
def test_semantic_ontology_service_uses_model_parse_when_available(monkeypatch) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
service = SemanticOntologyService(db)
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_parse_with_model",
|
||||
lambda **kwargs: LlmOntologyParseResult(
|
||||
scenario="expense",
|
||||
intent="draft",
|
||||
confidence=0.91,
|
||||
clarification_required=True,
|
||||
clarification_question="请补充费用类型、金额和票据附件。",
|
||||
missing_slots=["expense_type", "amount", "attachments"],
|
||||
ambiguity=[],
|
||||
entity_hints=[],
|
||||
),
|
||||
)
|
||||
|
||||
result = service.parse(
|
||||
OntologyParseRequest(
|
||||
query="我要报销",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
|
||||
assert result.scenario == "expense"
|
||||
assert result.intent == "draft"
|
||||
assert result.parse_strategy == "llm_primary"
|
||||
assert result.clarification_required is True
|
||||
assert "expense_type" in result.missing_slots
|
||||
assert result.clarification_question == "请补充费用类型、金额和票据附件。"
|
||||
|
||||
|
||||
def test_parse_ontology_endpoint_returns_eight_fields_and_writes_trace() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ontology/parse",
|
||||
json={
|
||||
"query": "查一下本周报销超标风险",
|
||||
"user_id": "pytest",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["scenario"] == "expense"
|
||||
assert payload["intent"] == "risk_check"
|
||||
assert payload["permission"]["level"] == "read"
|
||||
assert payload["run_id"].startswith("run_")
|
||||
assert set(payload) >= {
|
||||
"scenario",
|
||||
"intent",
|
||||
"entities",
|
||||
"time_range",
|
||||
"metrics",
|
||||
"constraints",
|
||||
"risk_flags",
|
||||
"permission",
|
||||
"confidence",
|
||||
"missing_slots",
|
||||
"ambiguity",
|
||||
"parse_strategy",
|
||||
"clarification_required",
|
||||
"clarification_question",
|
||||
"run_id",
|
||||
"field_errors",
|
||||
}
|
||||
|
||||
run_response = client.get(f"/api/v1/agent-runs/{payload['run_id']}")
|
||||
|
||||
assert run_response.status_code == 200
|
||||
run_payload = run_response.json()
|
||||
assert run_payload["ontology_json"]["scenario"] == "expense"
|
||||
assert run_payload["ontology_json"]["intent"] == "risk_check"
|
||||
assert run_payload["semantic_parse"]["scenario"] == "expense"
|
||||
assert run_payload["semantic_parse"]["intent"] == "risk_check"
|
||||
|
||||
|
||||
def test_parse_ontology_endpoint_returns_forbidden_for_unprivileged_payment_request() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ontology/parse",
|
||||
json={
|
||||
"query": "帮我直接付款给供应商B",
|
||||
"user_id": "pytest",
|
||||
"context_json": {"role_codes": ["user"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["scenario"] == "accounts_payable"
|
||||
assert payload["intent"] == "operate"
|
||||
assert payload["permission"]["level"] == "forbidden"
|
||||
assert payload["clarification_required"] is True
|
||||
assert payload["field_errors"]
|
||||
241
server/tests/test_orchestrator_service.py
Normal file
241
server/tests/test_orchestrator_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.db.base import Base
|
||||
from app.main import create_app
|
||||
from app.services.agent_assets import AgentAssetService
|
||||
|
||||
|
||||
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
app = create_app()
|
||||
|
||||
def override_db() -> Generator[Session, None, None]:
|
||||
db = session_factory()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_db
|
||||
return TestClient(app), session_factory
|
||||
|
||||
|
||||
def test_orchestrator_routes_user_query_to_user_agent() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "客户A这个月还有多少应收",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "read"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["result"]["answer"]
|
||||
assert payload["result"]["suggested_actions"]
|
||||
assert payload["trace_summary"]["tool_count"] >= 1
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "user_agent"
|
||||
assert run_detail["route_json"]["selected_agent"] == "user_agent"
|
||||
assert run_detail["semantic_parse"]["scenario"] == "accounts_receivable"
|
||||
assert run_detail["tool_calls"][0]["tool_type"] == "database"
|
||||
|
||||
|
||||
def test_orchestrator_routes_schedule_to_hermes() -> None:
|
||||
client, session_factory = build_client()
|
||||
|
||||
with session_factory() as db:
|
||||
task = next(
|
||||
item
|
||||
for item in AgentAssetService(db).list_assets(asset_type="task", status="active")
|
||||
if item.code == "task.hermes.daily_risk_scan"
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "schedule",
|
||||
"task_id": task.id,
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "hermes"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["trace_summary"]["tool_count"] == 2
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "hermes"
|
||||
assert run_detail["route_json"]["selected_agent"] == "hermes"
|
||||
assert len(run_detail["tool_calls"]) == 2
|
||||
|
||||
|
||||
def test_orchestrator_forbidden_request_does_not_call_downstream_agent() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我直接付款给供应商B",
|
||||
"context_json": {"role_codes": ["user"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] is None
|
||||
assert payload["permission_level"] == "forbidden"
|
||||
assert payload["status"] == "blocked"
|
||||
assert payload["trace_summary"]["tool_count"] == 0
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["agent"] == "orchestrator"
|
||||
assert run_detail["tool_calls"] == []
|
||||
|
||||
|
||||
def test_orchestrator_approval_required_returns_confirmation_result() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我安排付款给供应商B",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "approval_required"
|
||||
assert payload["requires_confirmation"] is True
|
||||
assert payload["status"] == "blocked"
|
||||
assert "确认" in payload["result"]["message"]
|
||||
|
||||
|
||||
def test_orchestrator_user_agent_draft_returns_structured_payload() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "帮我生成张三4月差旅报销草稿",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["result"]["draft_payload"]["confirmation_required"] is True
|
||||
assert payload["result"]["suggested_actions"]
|
||||
|
||||
|
||||
def test_orchestrator_treats_expense_narrative_as_draft_instead_of_ar_query() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "我今天去客户现场,招待了客户,花销了1000元",
|
||||
"context_json": {"role_codes": ["finance"]},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["permission_level"] == "draft_write"
|
||||
assert payload["status"] == "blocked"
|
||||
assert payload["route_reason"] == "clarification_required"
|
||||
assert payload["trace_summary"]["scenario"] == "expense"
|
||||
assert payload["trace_summary"]["intent"] == "draft"
|
||||
assert payload["trace_summary"]["tool_count"] == 0
|
||||
assert "应收场景数据" not in payload["result"]["message"]
|
||||
assert "请补充" in payload["result"]["message"]
|
||||
|
||||
|
||||
def test_orchestrator_tool_failure_is_logged_and_degraded() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"simulate_tool_failure": "database",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["selected_agent"] == "user_agent"
|
||||
assert payload["status"] == "succeeded"
|
||||
assert payload["trace_summary"]["failed_tool_count"] == 1
|
||||
assert payload["trace_summary"]["degraded"] is True
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["tool_calls"][0]["status"] == "failed"
|
||||
assert "simulated database failure" in run_detail["tool_calls"][0]["error_message"]
|
||||
|
||||
|
||||
def test_orchestrator_exception_is_written_to_agent_run() -> None:
|
||||
client, _ = build_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/orchestrator/run",
|
||||
json={
|
||||
"source": "user_message",
|
||||
"user_id": "pytest",
|
||||
"message": "查一下本周报销金额",
|
||||
"context_json": {
|
||||
"role_codes": ["finance"],
|
||||
"simulate_orchestrator_exception": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["status"] == "failed"
|
||||
|
||||
run_detail = client.get(f"/api/v1/agent-runs/{payload['run_id']}").json()
|
||||
assert run_detail["status"] == "failed"
|
||||
assert "simulated orchestrator exception" in run_detail["error_message"]
|
||||
179
server/tests/test_user_agent_service.py
Normal file
179
server/tests/test_user_agent_service.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.db.base import Base
|
||||
from app.schemas.ontology import OntologyParseRequest
|
||||
from app.schemas.user_agent import UserAgentRequest
|
||||
from app.services.ontology import SemanticOntologyService
|
||||
from app.services.user_agent import UserAgentService
|
||||
|
||||
|
||||
def build_session_factory() -> sessionmaker[Session]:
|
||||
engine = create_engine(
|
||||
"sqlite+pysqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def test_user_agent_query_returns_readable_answer_and_actions() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 4 月差旅报销金额是多少",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="张三 4 月差旅报销金额是多少",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 2, "total_amount": 8800.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert "8800.00" in response.answer
|
||||
assert len(response.suggested_actions) >= 1
|
||||
|
||||
|
||||
def test_user_agent_prefers_runtime_model_answer_when_available(monkeypatch) -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="张三 4 月差旅报销金额是多少",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
service = UserAgentService(db)
|
||||
monkeypatch.setattr(
|
||||
service,
|
||||
"_generate_answer_with_model",
|
||||
lambda *args, **kwargs: "这是模型回答",
|
||||
)
|
||||
|
||||
response = service.respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="张三 4 月差旅报销金额是多少",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 2, "total_amount": 8800.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.answer == "这是模型回答"
|
||||
|
||||
|
||||
def test_user_agent_sanitizes_model_thinking_blocks() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
service = UserAgentService(db)
|
||||
|
||||
assert (
|
||||
service._sanitize_model_answer("<think>内部推理</think>\n最终答复")
|
||||
== "最终答复"
|
||||
)
|
||||
|
||||
|
||||
def test_user_agent_guides_generic_expense_request() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我要报销",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="我要报销",
|
||||
ontology=ontology,
|
||||
tool_payload={"record_count": 9, "total_amount": 12345.0},
|
||||
)
|
||||
)
|
||||
|
||||
assert "补充费用类型" in response.answer
|
||||
assert "上传票据" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_guides_implicit_expense_draft_request() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="我今天去客户现场,招待了客户,花销了1000元",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="我今天去客户现场,招待了客户,花销了1000元",
|
||||
ontology=ontology,
|
||||
tool_payload={"draft_only": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "1000元" in response.answer
|
||||
assert "票据附件" in response.answer
|
||||
assert "报销草稿" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_risk_response_includes_rule_citations() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="检查重复报销风险",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="检查重复报销风险",
|
||||
ontology=ontology,
|
||||
tool_payload={"risk_flags": ["duplicate_expense"]},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.risk_flags == ["duplicate_expense"]
|
||||
assert any(item.source_type == "rule" for item in response.citations)
|
||||
assert "duplicate_expense" in response.answer
|
||||
|
||||
|
||||
def test_user_agent_draft_returns_structured_payload() -> None:
|
||||
session_factory = build_session_factory()
|
||||
with session_factory() as db:
|
||||
ontology = SemanticOntologyService(db).parse(
|
||||
OntologyParseRequest(
|
||||
query="帮我生成张三4月差旅报销草稿",
|
||||
user_id="pytest",
|
||||
)
|
||||
)
|
||||
response = UserAgentService(db).respond(
|
||||
UserAgentRequest(
|
||||
run_id=ontology.run_id,
|
||||
user_id="pytest",
|
||||
message="帮我生成张三4月差旅报销草稿",
|
||||
ontology=ontology,
|
||||
tool_payload={"draft_only": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.draft_payload is not None
|
||||
assert response.draft_payload.confirmation_required is True
|
||||
assert "待人工确认" in response.answer
|
||||
Reference in New Issue
Block a user