157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import Generator
|
||
|
|
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
from sqlalchemy import create_engine, select
|
||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
||
|
|
from sqlalchemy.pool import StaticPool
|
||
|
|
|
||
|
|
from app.api.deps import get_db
|
||
|
|
from app.core.agent_enums import AgentAssetDomain
|
||
|
|
from app.db.base import Base
|
||
|
|
from app.main import create_app
|
||
|
|
from app.models.agent_asset import AgentAsset, AgentAssetRuleFeedback
|
||
|
|
from app.schemas.agent_asset import (
|
||
|
|
AgentAssetRiskRuleFeedbackCreate,
|
||
|
|
AgentAssetRiskRuleGenerateRequest,
|
||
|
|
)
|
||
|
|
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
|
||
|
|
from app.services.agent_assets import AgentAssetService
|
||
|
|
from app.services.risk_rule_generation import RiskRuleGenerationService
|
||
|
|
|
||
|
|
|
||
|
|
class NullRuntimeChatService:
    """Test double for the runtime chat service.

    It is handed to ``RiskRuleGenerationService`` as ``runtime_chat_service``. Its
    only method, ``complete``, accepts any arguments and produces ``None``
    (presumably so rule generation runs without a real chat backend; confirm
    against the service).
    """

    def complete(self, *args, **kwargs) -> None:
        """Accept and discard any positional/keyword arguments; return ``None``."""
        del args, kwargs  # intentionally unused
|
||
|
|
|
||
|
|
|
||
|
|
def build_session() -> Session:
    """Return a new ORM session bound to a freshly created in-memory SQLite database.

    The schema from ``Base.metadata`` is created before the session is handed out.
    ``StaticPool`` with ``check_same_thread=False`` follows SQLAlchemy's documented
    recipe for sharing one in-memory SQLite connection, so the data outlives a
    single checkout.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    make_session = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)
    return make_session()
|
||
|
|
|
||
|
|
|
||
|
|
def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Build a ``TestClient`` whose ``get_db`` dependency uses an in-memory database.

    Returns:
        A ``(client, session_factory)`` pair. The factory is bound to the same
        engine the app's overridden ``get_db`` uses, so tests can seed data
        directly and then exercise it over HTTP.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    session_factory = sessionmaker(bind=memory_engine, autoflush=False, autocommit=False)

    def override_db() -> Generator[Session, None, None]:
        # Exiting the Session context manager closes the session once the
        # request that consumed it is done.
        with session_factory() as session:
            yield session

    app = create_app()
    app.dependency_overrides[get_db] = override_db
    return TestClient(app), session_factory
|
||
|
|
|
||
|
|
|
||
|
|
def test_risk_rule_feedback_records_misjudgement_without_modifying_rule(tmp_path) -> None:
    """Recording false-positive feedback persists a row but leaves the rule asset untouched.

    A rule asset is generated and a ``false_positive`` feedback is submitted via
    ``AgentAssetService.create_risk_rule_feedback``. The test then checks that:
    the feedback row is stored, its version matches the asset's
    ``working_version``, and the asset's ``status`` and ``config_json`` are the
    same after reloading from the database.
    """
    from copy import deepcopy

    with build_session() as db:
        asset_id = _create_rule(db, tmp_path)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None
        # Deep-copy the snapshot. A shallow ``dict()`` copy shares nested objects
        # with ``asset.config_json``, so an in-place mutation of a nested value
        # would appear in both and the final equality check could not catch it.
        before_config = deepcopy(dict(asset.config_json or {}))
        before_status = asset.status

        feedback = AgentAssetService(db).create_risk_rule_feedback(
            asset_id,
            AgentAssetRiskRuleFeedbackCreate(
                feedback_type="false_positive",
                subject_type="expense_claim",
                subject_key="CLAIM-001",
                subject_label="差旅报销 CLAIM-001",
                actual_result={"hit": True, "severity": "high"},
                expected_result={"hit": False},
                comment="票据城市实际与行程一致,当前规则误判。",
                payload={"source": "expense_review"},
            ),
            actor="employee",
        )

        stored = db.scalar(
            select(AgentAssetRuleFeedback).where(
                AgentAssetRuleFeedback.feedback_id == feedback.feedback_id
            )
        )
        assert stored is not None
        assert feedback.feedback_type == "false_positive"
        assert feedback.version == asset.working_version
        assert stored.actual_result_json["hit"] is True

        # Reload the asset so the comparison reflects what is persisted.
        db.refresh(asset)
        assert asset.status == before_status
        assert asset.config_json == before_config
|
||
|
|
|
||
|
|
|
||
|
|
def test_risk_rule_feedback_endpoint_allows_ordinary_user_and_manager_list(tmp_path) -> None:
    """An ordinary user can POST rule feedback and a manager can list it.

    The rule asset is seeded through the session factory shared with the app.
    A ``false_negative`` feedback is then posted with ``user`` role headers and
    listed with ``manager`` role headers.
    """
    client, session_factory = build_client()
    with session_factory() as db:
        asset_id = _create_rule(db, tmp_path)

    feedback_url = f"/api/v1/agent-assets/{asset_id}/risk-rules/feedback"

    response = client.post(
        feedback_url,
        headers=_user_headers(),
        json={
            "feedback_type": "false_negative",
            "subject_type": "expense_claim",
            "subject_key": "CLAIM-002",
            "actual_result": {"hit": False},
            "expected_result": {"hit": True, "severity": "medium"},
            "comment": "这张票据应命中风险但没有命中。",
        },
    )

    assert response.status_code == 201
    created = response.json()
    assert created["created_by"] == "employee"
    assert created["status"] == "open"

    list_response = client.get(feedback_url, headers=_manager_headers())
    assert list_response.status_code == 200
    items = list_response.json()
    # The asset was just created, so only the feedback posted above should be
    # listed. Checking the count fails clearly (rather than with an IndexError)
    # when the list is empty, and also catches unexpected extra rows.
    assert len(items) == 1
    assert items[0]["feedback_type"] == "false_negative"
|
||
|
|
|
||
|
|
|
||
|
|
def _create_rule(db: Session, tmp_path) -> str:
    """Generate a travel-expense risk rule asset in ``db`` and return the result.

    The rule library is rooted at ``tmp_path / "rules"`` and the chat service is
    ``NullRuntimeChatService``. Callers use the return value as the asset id
    (e.g. ``db.get(AgentAsset, asset_id)``).
    """
    generator = RiskRuleGenerationService(
        db,
        rule_library_manager=AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules"),
        runtime_chat_service=NullRuntimeChatService(),
    )
    request = AgentAssetRiskRuleGenerateRequest(
        business_domain=AgentAssetDomain.EXPENSE,
        expense_category="travel",
        rule_title="差旅票据城市规则",
        natural_language="差旅票据城市与申报目的地不一致时,提示补充说明。",
    )
    return generator.generate_rule_asset(request, actor="pytest")
|
||
|
|
|
||
|
|
|
||
|
|
def _user_headers() -> dict[str, str]:
|
||
|
|
return {
|
||
|
|
"x-auth-username": "employee",
|
||
|
|
"x-auth-name": "employee",
|
||
|
|
"x-auth-role-codes": "user",
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _manager_headers() -> dict[str, str]:
|
||
|
|
return {
|
||
|
|
"x-auth-username": "manager",
|
||
|
|
"x-auth-name": "manager",
|
||
|
|
"x-auth-role-codes": "manager",
|
||
|
|
}
|