120 lines
4.3 KiB
Python
120 lines
4.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.services.few_shot_retrieval import FewShotRetriever
|
||
|
|
from app.services.few_shot_store import FewShotStore
|
||
|
|
from app.services.risk_rule_generation_prompt import build_risk_rule_compiler_messages
|
||
|
|
|
||
|
|
|
||
|
|
def _hit(score: float, label: str, conclusion: str, risk_type: str = "duplicate_invoice") -> dict:
|
||
|
|
return {
|
||
|
|
"sample_id": "s1",
|
||
|
|
"score": score,
|
||
|
|
"label": label,
|
||
|
|
"domain": "expense",
|
||
|
|
"risk_type": risk_type,
|
||
|
|
"conclusion_text": conclusion,
|
||
|
|
"payload_json": {
|
||
|
|
"risk_signal": risk_type,
|
||
|
|
"risk_level": "high",
|
||
|
|
"ontology": {"scenario": "reimbursement"},
|
||
|
|
"feedback_comment": "",
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def test_retrieve_returns_injection_blocks_with_token_budget() -> None:
    """Three hits, two sharing a conclusion, must come back as two distinct blocks."""
    repeated_conclusion = "确认重复发票需拦截"
    hits = [
        _hit(0.9, "confirmed", repeated_conclusion),
        _hit(0.8, "false_positive", "此情形属于正常拆单不拦截"),
        # Same conclusion as the first hit: expected to be removed as a duplicate.
        _hit(0.7, "confirmed", repeated_conclusion),
    ]
    store = MagicMock(spec=FewShotStore)
    store.search.return_value = hits

    blocks = FewShotRetriever(store).retrieve_for_risk_rule_generation(
        domain="expense",
        natural_language="同一发票重复报销",
    )

    assert len(blocks) == 2
    top = blocks[0]
    assert top["score"] == 0.9
    assert top["label"] == "confirmed"
    assert top["source"] == "historical_confirmed"
    assert blocks[1]["label"] == "false_positive"

    # Deduplication: the third hit repeats the first conclusion, so every
    # conclusion left in the result must be unique.
    seen_conclusions = [block["conclusion"] for block in blocks]
    assert len(seen_conclusions) == len(set(seen_conclusions))
|
||
|
|
|
||
|
|
|
||
|
|
def test_retrieve_empty_case_text_returns_empty() -> None:
    """An empty ``natural_language`` yields ``[]`` and never queries the store."""
    mock_store = MagicMock(spec=FewShotStore)

    result = FewShotRetriever(mock_store).retrieve_for_risk_rule_generation(
        natural_language=""
    )

    assert result == []
    # The retriever is expected to return before touching the store at all.
    mock_store.search.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
def test_retrieve_truncates_overlong_conclusion() -> None:
    """An over-long conclusion must come back no longer than the per-sample cap."""
    # Kept as a function-scope import (as the test originally had it), but
    # hoisted to the top because the input below is now sized from the cap.
    from app.services.few_shot_retrieval import SINGLE_SAMPLE_MAX_CHARS

    fragment = "长结论"
    # Always build an input strictly longer than the cap. A fixed 500 repeats
    # (1500 chars) would make this test pass vacuously if the cap were ever
    # raised to 1500 or more; when 1500 chars already exceeds the cap the
    # input is exactly what it was before.
    repeats = max(500, SINGLE_SAMPLE_MAX_CHARS // len(fragment) + 1)
    long_text = fragment * repeats

    store = MagicMock(spec=FewShotStore)
    store.search.return_value = [
        _hit(0.9, "confirmed", long_text),
    ]
    retriever = FewShotRetriever(store)

    blocks = retriever.retrieve_for_risk_rule_generation(natural_language="x")

    assert len(blocks) == 1
    # The over-long conclusion must be cut down to the single-sample limit.
    assert len(blocks[0]["conclusion"]) <= SINGLE_SAMPLE_MAX_CHARS
|
||
|
|
|
||
|
|
|
||
|
|
def test_build_prompt_merges_few_shot_into_examples() -> None:
    """A supplied few-shot sample leads ``examples``; built-in examples remain."""
    historical_sample = {
        "source": "historical_confirmed",
        "label": "confirmed",
        "domain": "expense",
        "risk_type": "duplicate_invoice",
        "conclusion": "确认重复发票",
        "context": {"risk_signal": "duplicate_invoice"},
    }
    invoice_no_field = {
        "key": "attachment.invoice_no",
        "label": "发票号",
        "type": "string",
        "source": "attachment",
    }

    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[invoice_no_field],
        few_shot_samples=[historical_sample],
    )

    assert len(messages) == 2
    # The second message's content is JSON carrying the merged examples list.
    examples = json.loads(messages[1]["content"])["examples"]

    # The historical sample comes first; built-in examples follow it.
    leading = examples[0]
    assert leading["source"] == "historical_confirmed"
    assert leading["conclusion"] == "确认重复发票"

    # Built-in examples are still present (recognised here by their
    # "user_rule" key).
    assert any("user_rule" in example for example in examples)
|
||
|
|
|
||
|
|
|
||
|
|
def test_build_prompt_without_few_shot_is_backward_compatible() -> None:
    """Omitting ``few_shot_samples`` must not add any historical example."""
    messages = build_risk_rule_compiler_messages(
        domain="expense",
        domain_label="报销",
        business_stage="reimbursement",
        business_stage_label="报销",
        expense_category=None,
        expense_category_label="",
        natural_language="重复发票规则",
        available_fields=[],
    )

    examples = json.loads(messages[1]["content"])["examples"]

    # With no few_shot_samples passed, no example may carry the
    # "historical_confirmed" source.
    assert not any(
        example.get("source") == "historical_confirmed" for example in examples
    )
|