from __future__ import annotations import json from unittest.mock import MagicMock, patch import pytest from app.services.few_shot_retrieval import FewShotRetriever from app.services.few_shot_store import FewShotStore from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict: return { "sample_id": "s1", "score": score, "label": label, "domain": "expense", "risk_type": risk_type, "conclusion_text": conclusion, "payload_json": { "risk_signal": risk_type, "risk_level": "high", "ontology": {"scenario": "reimbursement"}, "feedback_comment": "", }, } def test_retrieve_returns_injection_blocks_with_token_budget() -> None: store = MagicMock(spec=FewShotStore) store.search.return_value = [ _hit(0.9, "confirmed", "确认重复发票需拦截"), _hit(0.8, "false_positive", "此情形属于正常拆单不拦截"), _hit(0.7, "confirmed", "确认重复发票需拦截"), # 重复结论应被去重 ] retriever = FewShotRetriever(store) blocks = retriever.retrieve_for_risk_rule_generation( domain="expense", natural_language="同一发票重复报销" ) assert len(blocks) == 2 assert blocks[0]["score"] == 0.9 assert blocks[0]["label"] == "confirmed" assert blocks[0]["source"] == "historical_confirmed" assert blocks[1]["label"] == "false_positive" # 去重:第三条结论与第一条相同,应被过滤 conclusions = [b["conclusion"] for b in blocks] assert len(set(conclusions)) == len(conclusions) def test_retrieve_empty_case_text_returns_empty() -> None: store = MagicMock(spec=FewShotStore) retriever = FewShotRetriever(store) assert retriever.retrieve_for_risk_rule_generation(natural_language="") == [] store.search.assert_not_called() def test_retrieve_truncates_overlong_conclusion() -> None: store = MagicMock(spec=FewShotStore) long_text = "长结论" * 500 store.search.return_value = [ _hit(0.9, "confirmed", long_text), ] retriever = FewShotRetriever(store) blocks = retriever.retrieve_for_risk_rule_generation(natural_language="x") assert len(blocks) == 1 # 超长结论应被截断到单条上限 from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS def test_build_prompt_merges_few_shot_into_examples() -> None: samples = [ { "source": "historical_confirmed", "label": "confirmed", "domain": "expense", "risk_type": "duplicate_invoice", "conclusion": "确认重复发票", "context": {"risk_signal": "duplicate_invoice"}, } ] messages = build_risk_rule_compiler_messages( domain="expense", domain_label="报销", business_stage="reimbursement", business_stage_label="报销", expense_category=None, expense_category_label="", natural_language="重复发票规则", available_fields=[{"key": "attachment.invoice_no", "label": "发票号", "type": "string", "source": "attachment"}], few_shot_samples=samples, ) assert len(messages) == 2 payload = json.loads(messages[1]["content"]) examples = payload["examples"] # 前两条是历史样本,后面是内置 examples assert examples[0]["source"] == "historical_confirmed" assert examples[0]["conclusion"] == "确认重复发票" # 内置 example 仍存在(无 source 字段) assert any("user_rule" in ex for ex in examples) def test_build_prompt_without_few_shot_is_backward_compatible() -> None: messages = build_risk_rule_compiler_messages( domain="expense", domain_label="报销", business_stage="reimbursement", business_stage_label="报销", expense_category=None, expense_category_label="", natural_language="重复发票规则", available_fields=[], ) payload = json.loads(messages[1]["content"]) examples = payload["examples"] # 无 few_shot_samples 时 examples 里不应有 historical_confirmed 来源 assert all(ex.get("source") != "historical_confirmed" for ex in examples)