# Repository file view: X-Financial/server/src/app/services/ontology_detection.py
# (558 lines, 21 KiB, Python)

from __future__ import annotations
import json
import re
from typing import Any
from pydantic import ValidationError
from app.core.logging import get_logger
from app.schemas.ontology import (
OntologyConstraint,
OntologyEntity,
OntologyMetric,
OntologyParseRequest,
OntologyTimeRange,
)
from app.services.ontology_rules import (
AR_CORE_KEYWORDS,
AP_CORE_KEYWORDS,
COMPARE_KEYWORDS,
DRAFT_FOLLOW_UP_KEYWORDS,
DRAFT_KEYWORDS,
EXPENSE_APPLICATION_CONTEXT_TYPES,
EXPENSE_APPLICATION_KEYWORDS,
EXPENSE_NARRATIVE_KEYWORDS,
EXPENSE_REVIEW_ACTIONS,
EXPLAIN_KEYWORDS,
GENERIC_EXPENSE_PROMPTS,
KNOWLEDGE_INTENTS,
LlmOntologyEntityHint,
LlmOntologyParseResult,
OPERATE_KEYWORDS,
QUERY_KEYWORDS,
RISK_KEYWORDS,
SCENARIO_KEYWORDS,
STATUS_KEYWORDS,
)
# Module-level logger, obtained from the project's logging helper under the
# "app.services.ontology" name.
logger = get_logger("app.services.ontology")
# Substrings that mark a query as describing a transport expense. When one of
# these appears (and no explicit entertainment wording does), a competing
# "entertainment" expense type is dropped in favour of "transport".
TRANSPORT_EXPENSE_OVERRIDE_KEYWORDS = (
    # Ride-hailing and taxis.
    "打车", "网约车", "出租车票", "出租车", "的士票", "的士", "滴滴",
    # Generic local-transport wording.
    "市内交通", "乘车", "乘车费", "用车", "叫车",
    # Fares.
    "车费", "车资",
    # Airport trips.
    "机场",
)

# Substrings that explicitly describe business entertainment. Their presence
# blocks the transport override above.
EXPLICIT_ENTERTAINMENT_KEYWORDS = (
    # Hospitality wording.
    "业务招待", "招待费", "招待", "宴请", "请客",
    # Meals with customers.
    "请客户吃饭", "客户吃饭", "客户用餐", "客户餐",
    # Business receptions.
    "商务接待", "商务宴请", "接待餐",
)
class OntologyDetectionMixin:
    """Rule-based and model-assisted detection helpers for ontology parsing.

    The mixin relies on two members that are not defined here and must be
    supplied by the host class: ``self.runtime_chat_service`` (its
    ``complete`` method is called by ``_parse_with_model``) and
    ``self._resolve_context_scenario`` (called by
    ``_should_inherit_expense_draft``).
    """

    @staticmethod
    def _is_expense_application_context(context_json: dict[str, Any]) -> bool:
        """Return True when the request context marks an expense application.

        The context qualifies when ``document_type``, ``application_stage`` or
        ``session_type`` is listed in ``EXPENSE_APPLICATION_CONTEXT_TYPES``, or
        when ``entry_source`` is one of the known application entry points.
        Missing or falsy values are treated as empty strings.
        """

        def _field(name: str) -> str:
            return str(context_json.get(name) or "").strip()

        return (
            _field("document_type") in EXPENSE_APPLICATION_CONTEXT_TYPES
            or _field("application_stage") in EXPENSE_APPLICATION_CONTEXT_TYPES
            or _field("session_type") in EXPENSE_APPLICATION_CONTEXT_TYPES
            or _field("entry_source")
            in {"application", "documents_application", "expense_application"}
        )

    @staticmethod
    def _looks_like_expense_application(compact_query: str) -> bool:
        """Return True when the query contains any expense-application keyword."""
        return any(keyword in compact_query for keyword in EXPENSE_APPLICATION_KEYWORDS)

    def _detect_scenario(self, compact_query: str) -> tuple[str, float]:
        """Score the query against ``SCENARIO_KEYWORDS`` and pick a scenario.

        Returns:
            ``(scenario, confidence)``. The confidence is the summed keyword
            weight of the winning scenario, capped at 0.34 and rounded to two
            decimals. When nothing matches, a query that mentions "单据"
            together with a status keyword falls back to ``("expense", 0.14)``;
            otherwise ``("unknown", 0.0)`` is returned.
        """
        scores = {
            scenario: sum(
                (weight for keyword, weight in keywords if keyword in compact_query),
                0.0,
            )
            for scenario, keywords in SCENARIO_KEYWORDS.items()
        }
        # Guard against an empty keyword table instead of letting max() raise.
        best_scenario = max(scores, key=scores.get) if scores else "unknown"
        best_score = scores.get(best_scenario, 0.0)

        if best_score <= 0:
            if "单据" in compact_query and any(
                keyword in compact_query for keyword in STATUS_KEYWORDS
            ):
                return "expense", 0.14
            return "unknown", 0.0

        if best_scenario == "knowledge":
            # A knowledge win is overridden by any business scenario that also
            # scored; ties resolve in the order listed here.
            business_scenarios = ("expense", "accounts_receivable", "accounts_payable")
            top_business = max(
                business_scenarios, key=lambda name: scores.get(name, 0.0)
            )
            top_business_score = scores.get(top_business, 0.0)
            if top_business_score > 0:
                best_scenario = top_business
                best_score = top_business_score

        return best_scenario, round(min(best_score, 0.34), 2)

    def _detect_intent(
        self,
        compact_query: str,
        *,
        scenario: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
    ) -> tuple[str, float]:
        """Derive an intent and a rule confidence from keyword matches.

        The checks run in a fixed priority order: operate, expense status /
        history queries, explicit draft keywords, "报销" plus an expense type,
        generic expense prompts, compare, explain, risk check, query, and
        finally the expense-narrative heuristic. The default is
        ``("query", 0.10)``.
        """

        def _mentions(keywords) -> bool:
            return any(keyword in compact_query for keyword in keywords)

        if _mentions(OPERATE_KEYWORDS):
            return "operate", 0.30

        if scenario == "expense":
            # "草稿" alone still counts as a status filter ("草稿单据"), so it is
            # excluded from the draft keywords that veto a status query.
            status_document_query = (
                "单据" in compact_query
                and _mentions(STATUS_KEYWORDS)
                and not any(
                    keyword in compact_query
                    for keyword in DRAFT_KEYWORDS
                    if keyword != "草稿"
                )
            )
            historical_document_query = _mentions(
                ("报销的单据", "报销单据", "报销过的单据", "报销记录")
            )
            status_phrase_query = _mentions(
                (
                    "报销了吗",
                    "报销了么",
                    "报销了没",
                    "报销了没有",
                    "报销没",
                    "单据状态",
                    "审批状态",
                    "报销进度",
                    "到哪了",
                    "到了哪",
                    "有没有报销",
                    "是否报销",
                    "进行中的单据",
                    "草稿单据",
                    "草稿的单据",
                    "待补充单据",
                    "审批中的单据",
                    "已提交单据",
                    "已入账单据",
                )
            )
            if status_phrase_query or status_document_query or historical_document_query:
                return "query", 0.24

        if _mentions(DRAFT_KEYWORDS):
            return "draft", 0.26

        if (
            scenario == "expense"
            and "报销" in compact_query
            and any(
                item.type == "expense_type"
                and str(item.normalized_value or item.value or "").strip()
                for item in entities
            )
            and not _mentions(
                (
                    *QUERY_KEYWORDS,
                    *COMPARE_KEYWORDS,
                    *EXPLAIN_KEYWORDS,
                    *RISK_KEYWORDS,
                )
            )
        ):
            return "draft", 0.25

        if scenario == "expense" and self._is_generic_expense_prompt(compact_query):
            return "draft", 0.24
        if _mentions(COMPARE_KEYWORDS):
            return "compare", 0.24
        if _mentions(EXPLAIN_KEYWORDS):
            return "explain", 0.22
        if _mentions(RISK_KEYWORDS):
            return "risk_check", 0.24
        if _mentions(QUERY_KEYWORDS):
            return "query", 0.20
        if self._looks_like_expense_narrative(
            compact_query,
            scenario=scenario,
            entities=entities,
            time_range=time_range,
        ):
            return "draft", 0.22
        return "query", 0.10

    @staticmethod
    def _looks_like_follow_up_message(compact_query: str) -> bool:
        """Return True when the query reads like a follow-up to a prior turn.

        A message is a follow-up when it contains a draft follow-up keyword,
        starts with a demonstrative ("这个" / "那个"), or is short (at most 12
        characters) and mentions no scenario keyword at all.

        The prefix tuple previously also contained empty strings. Because
        ``str.startswith("")`` is always True, every non-empty query was
        classified as a follow-up and the length / domain-keyword check below
        was unreachable. Only the non-empty prefixes are kept.
        """
        if not compact_query:
            return False
        if any(keyword in compact_query for keyword in DRAFT_FOLLOW_UP_KEYWORDS):
            return True
        if compact_query.startswith(("这个", "那个")):
            return True
        has_domain_keyword = any(
            keyword in compact_query
            for keyword, _weight in (
                *SCENARIO_KEYWORDS["expense"],
                *SCENARIO_KEYWORDS["accounts_receivable"],
                *SCENARIO_KEYWORDS["accounts_payable"],
                *SCENARIO_KEYWORDS["knowledge"],
            )
        )
        return len(compact_query) <= 12 and not has_domain_keyword

    def _should_inherit_expense_draft(
        self,
        compact_query: str,
        *,
        scenario: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
        context_json: dict[str, Any],
    ) -> bool:
        """Decide whether the current turn continues an existing expense draft.

        A recognised ``review_action`` always inherits. Otherwise the context
        must resolve to the expense scenario or carry a ``draft_claim_id``.
        Follow-up keywords, an expense narrative, or a follow-up-style message
        then inherit; operate / compare / risk / query wording does not. As a
        last resort, a draft id combined with at least one business entity
        inherits.

        The ``scenario`` argument is accepted for interface compatibility; the
        narrative check is always evaluated as if the scenario were "expense".
        """
        context_scenario = self._resolve_context_scenario(context_json)
        draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
        review_action = str(context_json.get("review_action") or "").strip()

        def _mentions(keywords) -> bool:
            return any(keyword in compact_query for keyword in keywords)

        if review_action in EXPENSE_REVIEW_ACTIONS:
            return True
        if context_scenario != "expense" and not draft_claim_id:
            return False
        if _mentions(DRAFT_FOLLOW_UP_KEYWORDS):
            return True
        if self._looks_like_expense_narrative(
            compact_query,
            scenario="expense",
            entities=entities,
            time_range=time_range,
        ):
            return True
        if self._looks_like_follow_up_message(compact_query):
            return True
        if _mentions(OPERATE_KEYWORDS):
            return False
        # Unpacking works for any iterable pair, unlike ``+`` which requires
        # both keyword collections to be the same sequence type.
        if _mentions((*COMPARE_KEYWORDS, *RISK_KEYWORDS)):
            return False
        if _mentions(QUERY_KEYWORDS):
            return False
        inheritable_types = {
            "amount",
            "customer",
            "employee",
            "expense_type",
            "project",
            "invoice",
        }
        return bool(
            draft_claim_id and any(item.type in inheritable_types for item in entities)
        )

    @staticmethod
    def _is_generic_expense_prompt(compact_query: str) -> bool:
        """Return True when the query exactly matches a generic expense prompt."""
        return compact_query in GENERIC_EXPENSE_PROMPTS

    @staticmethod
    def _looks_like_expense_narrative(
        compact_query: str,
        *,
        scenario: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
    ) -> bool:
        """Return True when the query describes an expense that was incurred.

        Requires both an expense signal (a narrative keyword or an
        ``expense_type`` entity) and a context signal (a start date, an
        ``amount`` entity, or "报销" together with an expense type). Queries in
        the knowledge scenario, or containing core AR / AP keywords, never
        qualify.
        """
        if scenario not in {"expense", "accounts_receivable", "accounts_payable", "unknown"}:
            return False
        if any(
            keyword in compact_query
            for keyword in (*AR_CORE_KEYWORDS, *AP_CORE_KEYWORDS)
        ):
            return False

        entity_types = {item.type for item in entities}
        has_expense_type = "expense_type" in entity_types
        has_expense_signal = has_expense_type or any(
            keyword in compact_query for keyword in EXPENSE_NARRATIVE_KEYWORDS
        )
        has_context_signal = (
            bool(time_range.start_date)
            or "amount" in entity_types
            or ("报销" in compact_query and has_expense_type)
        )
        return has_expense_signal and has_context_signal

    def _parse_with_model(
        self,
        *,
        payload: OntologyParseRequest,
        query: str,
        compact_query: str,
        fallback_scenario: str,
        fallback_intent: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
        metrics: list[OntologyMetric],
        constraints: list[OntologyConstraint],
    ) -> LlmOntologyParseResult | None:
        """Ask the chat model for a structured parse of the query.

        Builds the prompt from the rule-based candidates, calls
        ``self.runtime_chat_service.complete`` and validates the extracted
        JSON object against ``LlmOntologyParseResult``.

        Returns:
            The validated result, or ``None`` when the response holds no JSON
            object or fails validation (the validation error is logged).
            Exceptions raised by the chat service itself are not caught here.
        """
        messages = self._build_model_messages(
            payload=payload,
            query=query,
            compact_query=compact_query,
            fallback_scenario=fallback_scenario,
            fallback_intent=fallback_intent,
            entities=entities,
            time_range=time_range,
            metrics=metrics,
            constraints=constraints,
        )
        response_text = self.runtime_chat_service.complete(
            messages,
            max_tokens=600,
            temperature=0.0,
        )
        payload_json = self._extract_json_payload(response_text)
        if payload_json is None:
            return None
        try:
            return LlmOntologyParseResult.model_validate(payload_json)
        except ValidationError as exc:
            logger.warning("Semantic model output validation failed: %s", exc)
            return None

    @staticmethod
    def _build_model_messages(
        *,
        payload: OntologyParseRequest,
        query: str,
        compact_query: str,
        fallback_scenario: str,
        fallback_intent: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
        metrics: list[OntologyMetric],
        constraints: list[OntologyConstraint],
    ) -> list[dict[str, str]]:
        """Build the system and user chat messages for the semantic parser.

        The user message embeds a JSON dump of the query, a selected subset of
        ``payload.context_json`` and the rule-based candidates, followed by an
        example of the expected output format.
        """
        # (key, default) pairs copied from the request context, in the order
        # they are serialised into the prompt.
        context_fields: tuple[tuple[str, Any], ...] = (
            ("entry_source", None),
            ("attachment_names", []),
            ("attachment_count", 0),
            ("ocr_summary", ""),
            ("ocr_documents", []),
            ("request_context", None),
            ("role_codes", []),
            ("conversation_id", None),
            ("conversation_scenario", None),
            ("conversation_intent", None),
            ("document_type", None),
            ("application_stage", None),
            ("application_fields", None),
            ("draft_claim_id", None),
            ("review_action", None),
            ("review_form_values", None),
            ("conversation_history", []),
        )
        context_json = payload.context_json
        facts = {
            "query": query,
            "compact_query": compact_query,
            "context": {
                key: context_json.get(key, default) for key, default in context_fields
            },
            "rule_candidates": {
                "scenario": fallback_scenario,
                "intent": fallback_intent,
                "entities": [item.model_dump(mode="json") for item in entities],
                "time_range": time_range.model_dump(mode="json"),
                "metrics": [item.model_dump(mode="json") for item in metrics],
                "constraints": [item.model_dump(mode="json") for item in constraints],
            },
        }
        system_prompt = (
            "你是企业财务共享平台的语义解析器。"
            "你的任务是把用户输入解析为固定 JSON用于后续路由、追问和权限判断。"
            "只输出 JSON 对象,不要输出 Markdown、代码块、解释、标题或 <think>。"
            "场景 scenario 只能是expense, accounts_receivable, "
            "accounts_payable, knowledge, unknown。"
            "意图 intent 只能是query, explain, compare, risk_check, draft, operate。"
            "如果用户是在描述一笔待处理费用、待报销事项、上传票据或希望整理报销,"
            "即使没有明确说“生成草稿”,也优先使用 expense + draft。"
            "如果提供了 conversation_history必须把最近轮次作为当前追问的上下文"
            "正确理解“这个”“那笔”“改成 800”“继续补充”这类省略表达。"
            "出现“客户”不等于应收,出现“供应商”不等于应付,必须结合动作词和业务目标判断。"
            "只有明确查询、统计、列出、多少、明细、对比时才优先使用 query 或 compare。"
            "附件名称和 OCR 摘要只作为辅助证据,不能编造未出现的事实。"
            "如果用户明确提到打车、的士票、出租车票、网约车、乘车费、车费等交通票据,"
            "即使句子里出现“客户”,也必须优先识别为 transport不要推断为 entertainment。"
            "不要输出用户原文未出现、且与规则候选冲突的费用类型。"
            "信息不足时 clarification_required=true并给出一句简短中文追问。"
            "missing_slots 使用简短 snake_case例如 expense_type, amount, "
            "customer_name, participants, attachments。"
            "entity_hints 只填写你比较确定的业务对象;如果不确定,可以返回空数组。"
        )
        user_prompt = (
            "请根据以下事实输出 JSON\n"
            f"{json.dumps(facts, ensure_ascii=False, indent=2, default=str)}\n\n"
            "输出格式:\n"
            "{\n"
            ' "scenario": "expense",\n'
            ' "intent": "draft",\n'
            ' "confidence": 0.88,\n'
            ' "clarification_required": true,\n'
            ' "clarification_question": "请补充发生时间、金额和票据附件。",\n'
            ' "missing_slots": ["time_range", "amount", "attachments"],\n'
            ' "ambiguity": [],\n'
            ' "entity_hints": [\n'
            ' {"type": "expense_type", "value": "交通费", '
            '"normalized_value": "transport", "role": "filter", '
            '"confidence": 0.86}\n'
            " ]\n"
            "}"
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    @staticmethod
    def _extract_json_payload(response_text: str | None) -> dict[str, Any] | None:
        """Extract the first usable JSON object from a model response.

        ``<think>...</think>`` sections are removed first. Candidates are then
        tried in order: the body of a fenced code block, the whole cleaned
        text, and the span from the first ``{`` to the last ``}``. As a final
        fallback the object starting at the first ``{`` is decoded on its own,
        which recovers responses where brace-containing text follows the JSON.

        Returns:
            The parsed ``dict``, or ``None`` when no candidate decodes to a
            JSON object.
        """
        if not response_text:
            return None
        cleaned = re.sub(
            r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE
        ).strip()
        if not cleaned:
            return None

        candidates: list[str] = []
        fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", cleaned, flags=re.DOTALL)
        if fenced_match:
            candidates.append(fenced_match.group(1))
        candidates.append(cleaned)
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end > start:
            candidates.append(cleaned[start : end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        if start != -1:
            # raw_decode stops at the end of the first complete JSON value and
            # ignores whatever trails it.
            try:
                parsed, _consumed = json.JSONDecoder().raw_decode(cleaned[start:])
            except json.JSONDecodeError:
                return None
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _resolve_scenario(
        fallback_scenario: str,
        model_parse: LlmOntologyParseResult | None,
    ) -> str:
        """Pick the final scenario from the model parse and the rule fallback.

        The model's scenario wins unless there is no model parse, or the model
        answered "unknown" while the rules found something more specific.
        """
        if model_parse is None:
            return fallback_scenario
        model_scenario = model_parse.scenario
        if model_scenario == "unknown" and fallback_scenario != "unknown":
            return fallback_scenario
        return model_scenario

    def _resolve_intent(
        self,
        compact_query: str,
        *,
        fallback_intent: str,
        scenario: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
        model_parse: LlmOntologyParseResult | None,
    ) -> str:
        """Pick the final intent from the model parse and the rule fallback.

        In the knowledge scenario only ``KNOWLEDGE_INTENTS`` are allowed, with
        "query" as the last resort. In the expense scenario a "query" candidate
        is upgraded to "draft" when the query is a generic expense prompt or
        the rules already said "draft". ``entities`` and ``time_range`` are not
        used by this method.
        """
        candidate = fallback_intent if model_parse is None else model_parse.intent
        if scenario == "knowledge":
            if candidate in KNOWLEDGE_INTENTS:
                return candidate
            if fallback_intent in KNOWLEDGE_INTENTS:
                return fallback_intent
            return "query"
        if (
            candidate == "query"
            and scenario == "expense"
            and (
                self._is_generic_expense_prompt(compact_query)
                or fallback_intent == "draft"
            )
        ):
            return "draft"
        return candidate

    @staticmethod
    def _merge_entities(
        base_entities: list[OntologyEntity],
        entity_hints: list[LlmOntologyEntityHint],
        compact_query: str = "",
    ) -> list[OntologyEntity]:
        """Merge model entity hints into the rule-based entities.

        Entities are keyed by ``(type, normalized value)``, where the
        normalized value falls back to the raw value when unset. The same
        normalisation is applied to base entities and hints so that matching
        pairs merge. A hint replaces an existing entry only when its
        confidence is strictly higher; hints with an empty value are skipped.
        When the transport override applies, "entertainment" expense types are
        removed from the result.
        """

        def _normalized(entity: Any) -> str:
            return str(entity.normalized_value or entity.value or "").strip()

        merged: dict[tuple[str, str], OntologyEntity] = {
            (str(item.type).strip(), _normalized(item)): item for item in base_entities
        }
        for hint in entity_hints:
            value = str(hint.value or "").strip()
            if not value:
                continue
            hint_type = str(hint.type).strip()
            normalized_value = str(hint.normalized_value or "").strip() or value
            candidate = OntologyEntity(
                type=hint_type,
                value=value,
                normalized_value=normalized_value,
                role=str(hint.role or "target").strip() or "target",
                confidence=float(hint.confidence),
            )
            key = (hint_type, normalized_value)
            existing = merged.get(key)
            if existing is None or existing.confidence < candidate.confidence:
                merged[key] = candidate

        items = list(merged.values())
        if OntologyDetectionMixin._should_transport_override_entertainment(
            compact_query,
            items,
        ):
            # Uses the same normalisation as the override check so an entity
            # identified only by ``value`` is removed as well.
            items = [
                item
                for item in items
                if not (
                    item.type == "expense_type"
                    and _normalized(item) == "entertainment"
                )
            ]
        return items

    @staticmethod
    def _should_transport_override_entertainment(
        compact_query: str,
        entities: list[OntologyEntity],
    ) -> bool:
        """Return True when transport should replace an entertainment guess.

        Applies only when both "transport" and "entertainment" expense types
        are present, the query contains a transport keyword, and it contains
        no explicit entertainment keyword.
        """
        expense_types = {
            str(item.normalized_value or item.value or "").strip()
            for item in entities
            if item.type == "expense_type"
        }
        if not {"transport", "entertainment"} <= expense_types:
            return False
        mentions_transport = any(
            keyword in compact_query for keyword in TRANSPORT_EXPENSE_OVERRIDE_KEYWORDS
        )
        if not mentions_transport:
            return False
        mentions_entertainment = any(
            keyword in compact_query for keyword in EXPLICIT_ENTERTAINMENT_KEYWORDS
        )
        return not mentions_entertainment

    @staticmethod
    def _normalize_short_text_list(values: list[str]) -> list[str]:
        """Strip, de-duplicate and cap a list of short strings at six items.

        Order of first appearance is preserved, and blank or ``None`` entries
        are dropped. A ``None`` argument is treated as an empty list.
        """
        normalized: list[str] = []
        seen: set[str] = set()
        for value in values or ():
            cleaned = str(value or "").strip()
            if not cleaned or cleaned in seen:
                continue
            seen.add(cleaned)
            normalized.append(cleaned)
            if len(normalized) == 6:
                break
        return normalized