# Tests for risk-rule feedback on agent assets.
#
# Covers two paths visible in this module:
#   * the service call ``AgentAssetService.create_risk_rule_feedback`` and
#   * the HTTP endpoints under ``/api/v1/agent-assets/{id}/risk-rules/feedback``.
# Every test runs against a private in-memory SQLite database.
from __future__ import annotations

from collections.abc import Generator
from pathlib import Path

from fastapi.testclient import TestClient
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.api.deps import get_db
from app.core.agent_enums import AgentAssetDomain
from app.db.base import Base
from app.main import create_app
from app.models.agent_asset import AgentAsset, AgentAssetRuleFeedback
from app.schemas.agent_asset import (
    AgentAssetRiskRuleFeedbackCreate,
    AgentAssetRiskRuleGenerateRequest,
)
from app.services.agent_asset_rule_library import AgentAssetRuleLibraryManager
from app.services.agent_assets import AgentAssetService
from app.services.risk_rule_generation import RiskRuleGenerationService


class NullRuntimeChatService:
    """Stand-in chat service whose ``complete`` accepts anything and returns ``None``."""

    def complete(self, *args, **kwargs) -> None:
        """Ignore all arguments and return ``None``."""
        return None


def _build_session_factory() -> sessionmaker[Session]:
    """Create a fresh in-memory SQLite database and return a session factory for it.

    The tables registered on ``Base.metadata`` are created before the factory
    is returned.
    """
    engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        connect_args={"check_same_thread": False},
        # StaticPool keeps a single connection, so every session produced by
        # the returned factory talks to the same in-memory database.
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)


def build_session() -> Session:
    """Return a new session bound to its own freshly created in-memory database."""
    return _build_session_factory()()


def build_client() -> tuple[TestClient, sessionmaker[Session]]:
    """Return a ``TestClient`` plus the session factory backing its ``get_db`` override.

    The application comes from ``create_app()``; its ``get_db`` dependency is
    overridden so each request gets a session from the returned factory.
    """
    session_factory = _build_session_factory()
    app = create_app()

    def override_db() -> Generator[Session, None, None]:
        """Yield one session per request and close it afterwards."""
        db = session_factory()
        try:
            yield db
        finally:
            db.close()

    app.dependency_overrides[get_db] = override_db
    return TestClient(app), session_factory


def test_risk_rule_feedback_records_misjudgement_without_modifying_rule(
    tmp_path: Path,
) -> None:
    """Creating feedback stores a feedback row and leaves the rule asset unchanged."""
    with build_session() as db:
        asset_id = _create_rule(db, tmp_path)
        asset = db.get(AgentAsset, asset_id)
        assert asset is not None

        # Snapshot the fields that must be unchanged once feedback is recorded.
        before_config = dict(asset.config_json or {})
        before_status = asset.status

        feedback = AgentAssetService(db).create_risk_rule_feedback(
            asset_id,
            AgentAssetRiskRuleFeedbackCreate(
                feedback_type="false_positive",
                subject_type="expense_claim",
                subject_key="CLAIM-001",
                subject_label="差旅报销 CLAIM-001",
                actual_result={"hit": True, "severity": "high"},
                expected_result={"hit": False},
                comment="票据城市实际与行程一致,当前规则误判。",
                payload={"source": "expense_review"},
            ),
            actor="employee",
        )

        stored = db.scalar(
            select(AgentAssetRuleFeedback).where(
                AgentAssetRuleFeedback.feedback_id == feedback.feedback_id
            )
        )

        assert stored is not None
        assert feedback.feedback_type == "false_positive"
        assert feedback.version == asset.working_version
        assert stored.actual_result_json["hit"] is True

        # Reload from the database before comparing against the snapshot.
        db.refresh(asset)
        assert asset.status == before_status
        assert asset.config_json == before_config


def test_risk_rule_feedback_endpoint_allows_ordinary_user_and_manager_list(
    tmp_path: Path,
) -> None:
    """A ``user``-role caller can POST feedback and a ``manager``-role caller can list it."""
    client, session_factory = build_client()
    with session_factory() as db:
        asset_id = _create_rule(db, tmp_path)

    feedback_url = f"/api/v1/agent-assets/{asset_id}/risk-rules/feedback"

    response = client.post(
        feedback_url,
        headers=_user_headers(),
        json={
            "feedback_type": "false_negative",
            "subject_type": "expense_claim",
            "subject_key": "CLAIM-002",
            "actual_result": {"hit": False},
            "expected_result": {"hit": True, "severity": "medium"},
            "comment": "这张票据应命中风险但没有命中。",
        },
    )

    assert response.status_code == 201
    created = response.json()
    assert created["created_by"] == "employee"
    assert created["status"] == "open"

    list_response = client.get(feedback_url, headers=_manager_headers())

    assert list_response.status_code == 200
    items = list_response.json()
    # Fail with an assertion (not an IndexError) when nothing is returned.
    assert items, "expected at least one feedback entry in the list response"
    assert items[0]["feedback_type"] == "false_negative"


def _create_rule(db: Session, tmp_path: Path) -> str:
    """Generate a travel-expense risk rule asset and return the service's result.

    The rule library is rooted at ``tmp_path / "rules"`` and the chat service
    is replaced with ``NullRuntimeChatService``.
    """
    service = RiskRuleGenerationService(
        db,
        rule_library_manager=AgentAssetRuleLibraryManager(rule_root=tmp_path / "rules"),
        runtime_chat_service=NullRuntimeChatService(),
    )
    return service.generate_rule_asset(
        AgentAssetRiskRuleGenerateRequest(
            business_domain=AgentAssetDomain.EXPENSE,
            expense_category="travel",
            rule_title="差旅票据城市规则",
            natural_language="差旅票据城市与申报目的地不一致时,提示补充说明。",
        ),
        actor="pytest",
    )


def _user_headers() -> dict[str, str]:
    """Return auth headers for the ``employee`` account with the ``user`` role code."""
    return {
        "x-auth-username": "employee",
        "x-auth-name": "employee",
        "x-auth-role-codes": "user",
    }


def _manager_headers() -> dict[str, str]:
    """Return auth headers for the ``manager`` account with the ``manager`` role code."""
    return {
        "x-auth-username": "manager",
        "x-auth-name": "manager",
        "x-auth-role-codes": "manager",
    }