from __future__ import annotations

import calendar
import re
from datetime import UTC, date, datetime, timedelta
from typing import Any

from app.core.agent_enums import AgentPermissionLevel
from app.schemas.ontology import (
    OntologyConstraint,
    OntologyEntity,
    OntologyMetric,
    OntologyPermission,
    OntologyTimeRange,
)
from app.services.document_numbering import DOCUMENT_NUMBER_EXTRACT_PATTERN
from app.services.ontology_rules import (
    AMOUNT_PATTERN,
    DATE_RANGE_PATTERN,
    EXPLICIT_DATE_PATTERN,
    EXPLICIT_MONTH_PATTERN,
    EXPENSE_APPLICATION_ATTACHMENT_REQUIRED_TYPES,
    EXPENSE_APPLICATION_CONTEXT_TYPES,
    EXPENSE_APPLICATION_KEYWORDS,
    EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS,
    EXPENSE_TYPE_KEYWORDS,
    GENERIC_EXPENSE_APPLICATION_PROMPTS,
    GENERIC_EXPENSE_PROMPTS,
    LOCATION_KEYWORDS,
    MONTH_DAY_PATTERN,
    MONTH_DAY_RANGE_PATTERN,
    ReferenceCatalog,
    STATUS_KEYWORDS,
    TOP_N_PATTERN,
)

# --- Pre-compiled extraction patterns -------------------------------------------

# "客户A" / "客户 12" style customer references; group 1 is the suffix.
_CUSTOMER_SUFFIX_PATTERN = re.compile(r"客户\s*([A-Za-z0-9一二三四五六七八九十]+)")
# "客户名称：XXX" labels; accepts ASCII or full-width colon and stops at common
# ASCII / full-width separators.
_LABELED_CUSTOMER_PATTERN = re.compile(r"客户名称[:：]\s*(?P<name>[^\n,，。;；]+)")
# "供应商A" style vendor references; group 1 is the suffix.
_VENDOR_SUFFIX_PATTERN = re.compile(r"供应商\s*([A-Za-z0-9一二三四五六七八九十]+)")
# A common Chinese surname plus one or two CJK characters, only when followed by a
# time / expense cue (so arbitrary two-character words are not taken as names).
_EMPLOYEE_NAME_PATTERN = re.compile(
    r"(?P<name>[赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦许何吕施张孔曹严华金魏陶姜"
    r"戚谢邹喻柏水窦章云苏潘葛范彭郎鲁韦昌马苗凤花方俞任袁柳鲍史唐费廉岑"
    r"薛雷贺倪汤滕殷罗毕郝邬安常乐于时傅卞康伍余元卜顾孟平黄和穆萧尹姚邵"
    r"湛汪祁毛禹狄米贝明臧计成戴宋庞熊纪舒屈项祝董梁杜阮蓝闵席季强贾路江"
    r"童颜郭梅盛林钟徐邱骆高夏蔡田樊胡凌霍虞万支柯管卢莫房裘缪解应宗丁宣"
    r"邓洪包左石崔吉龚程嵇邢裴陆荣翁荀羊惠甄曲家封芮储靳汲邴糜松井段富巫"
    r"乌焦巴弓牧隗山谷车侯伊宫宁仇栾刘景詹束龙叶司黎薄印白怀蒲邰从鄂索咸"
    r"籍卓蔺屠蒙池乔阴胥能苍双闻莘党翟谭贡姬申扶堵冉宰郦雍桑桂牛寿通边扈"
    r"燕冀浦尚农温别庄晏柴瞿阎连茹习艾容向古易慎戈廖庾终暨居衡步都耿满弘"
    r"匡国文寇广禄阙东欧殳沃利蔚越夔隆师巩聂晁勾敖融冷辛阚那简饶曾关蒯相"
    r"查后荆游竺权盖益桓公][\u4e00-\u9fa5]{1,2})(?=\s*(?:\d{4}年|\d{1,2}月|本月|"
    r"上月|本周|报销|差旅|费用|申请))"
)
_PROJECT_CODE_PATTERN = re.compile(r"PRJ-[A-Z]+-\d+", re.IGNORECASE)
# Business document codes -> entity type; evaluated in this order.
_DOCUMENT_CODE_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
    ("receivable", re.compile(r"AR-\d{6}-\d{3}", re.IGNORECASE)),
    ("payable", re.compile(r"AP-\d{6}-\d{3}", re.IGNORECASE)),
    ("invoice", re.compile(r"INV-[A-Z]+-\d+", re.IGNORECASE)),
    ("contract", re.compile(r"CTR-[A-Z]+-\d+", re.IGNORECASE)),
)
# Bare "N月" month mention, used as the last-resort time-range fallback.
_BARE_MONTH_PATTERN = re.compile(r"(?P<month>\d{1,2})月")

# Keywords that, together with "客户", indicate customer entertainment.
_CUSTOMER_ENTERTAINMENT_KEYWORDS = ("吃饭", "用餐", "餐饮", "宴请", "请客", "招待", "接待")
_EXPLICIT_ENTERTAINMENT_KEYWORDS = ("招待", "接待", "吃饭", "用餐", "宴请", "请客", "客户餐")

# (keywords, display label, normalized expense type, confidence), applied in order.
# Entities are upserted first-writer-wins, so a type already found earlier (e.g.
# "meal" from the customer-entertainment signal) is never overwritten here.
_KEYWORD_EXPENSE_TYPE_RULES: tuple[tuple[tuple[str, ...], str, str, float], ...] = (
    (
        (
            "打车", "网约车", "出租车票", "出租车", "车费", "乘车", "用车", "叫车",
            "车资", "的士票", "的士", "滴滴", "市内交通", "地铁", "公交", "停车费",
            "过路费", "通行费", "高速费",
        ),
        "交通",
        "transport",
        0.9,
    ),
    (
        ("出差", "机票", "飞机票", "航班", "火车票", "火车", "高铁票", "高铁", "动车", "行程单"),
        "差旅",
        "travel",
        0.88,
    ),
    (
        ("酒店", "酒店发票", "住宿", "住宿费", "宾馆", "民宿", "房费", "客房"),
        "住宿",
        "hotel",
        0.86,
    ),
    (
        ("餐费", "用餐", "午餐", "晚餐", "早餐", "餐饮"),
        "业务招待费",
        "meal",
        0.84,
    ),
    (
        (
            "办公用品", "文具", "耗材", "办公耗材", "打印纸", "办公设备", "键盘", "鼠标",
            "白板", "硒鼓", "墨盒",
        ),
        "办公用品费",
        "office",
        0.87,
    ),
    (
        ("培训", "讲师费", "课时费", "课程费", "教材", "认证费", "考试费"),
        "培训费",
        "training",
        0.84,
    ),
    (
        ("通讯费", "话费", "电话费", "手机费", "流量费", "宽带费", "网络费"),
        "通讯费",
        "communication",
        0.84,
    ),
    (
        ("福利费", "团建", "慰问", "节日福利", "体检费", "员工关怀"),
        "福利费",
        "welfare",
        0.84,
    ),
)

# Relative-day keywords and their offset from "today". A keyword that contains
# another keyword must come first ("大前天" before "前天", "大后天" before "后天"),
# because matching is a plain substring test and the first hit wins.
_RELATIVE_DAY_KEYWORDS: tuple[tuple[str, int], ...] = (
    ("大前天", -3),
    ("前天", -2),
    ("昨日", -1),
    ("昨天", -1),
    ("今天", 0),
    ("明天", 1),
    ("大后天", 3),
    ("后天", 2),
)

# Entity types that are turned into equality constraints.
_CONSTRAINT_ENTITY_TYPES = frozenset(
    {
        "employee",
        "department",
        "customer",
        "vendor",
        "project",
        "location",
        "expense_type",
        "document_type",
        "workflow_stage",
    }
)


def _coerce_int(value: Any, default: int = 0) -> int:
    """Return ``int(value)``, or *default* when *value* is falsy or not int-convertible."""
    try:
        return int(value or default)
    except (TypeError, ValueError):
        return default


def _calendar_month_bounds(year: int, month: int) -> tuple[date, date]:
    """Return the first and last day of *month* in *year*.

    Raises:
        ValueError: for an out-of-range year or month (``calendar.IllegalMonthError``
            is a ``ValueError`` subclass).
    """
    last_day = calendar.monthrange(year, month)[1]
    return date(year, month, 1), date(year, month, last_day)


class OntologyExtractionMixin:
    """Rule-based extraction of ontology pieces from a natural-language query.

    Produces entities, a time range, metrics, constraints, a confidence value and
    the default "missing slots" for an expense draft. ``_compact``,
    ``_is_generic_expense_prompt``, ``_normalize_amount`` and ``_normalize_operator``
    are called on ``self`` but are not defined in this mixin; presumably the host
    class provides them.
    """

    @staticmethod
    def _is_expense_application_context_value(context_json: dict[str, Any]) -> bool:
        """Return True when the request context marks an expense-application flow.

        Any of ``document_type``, ``application_stage`` or ``session_type`` being in
        ``EXPENSE_APPLICATION_CONTEXT_TYPES``, or ``entry_source`` naming one of the
        application entry points, is sufficient.
        """
        document_type = str(context_json.get("document_type") or "").strip()
        application_stage = str(context_json.get("application_stage") or "").strip()
        entry_source = str(context_json.get("entry_source") or "").strip()
        session_type = str(context_json.get("session_type") or "").strip()
        return (
            document_type in EXPENSE_APPLICATION_CONTEXT_TYPES
            or application_stage in EXPENSE_APPLICATION_CONTEXT_TYPES
            or session_type in EXPENSE_APPLICATION_CONTEXT_TYPES
            or entry_source in {"application", "documents_application", "expense_application"}
        )

    @staticmethod
    def _has_expense_application_signal(compact_query: str) -> bool:
        """Return True when the compacted query contains any expense-application keyword."""
        return any(keyword in compact_query for keyword in EXPENSE_APPLICATION_KEYWORDS)

    def _infer_default_missing_slots(
        self,
        compact_query: str,
        *,
        scenario: str,
        intent: str,
        entities: list[OntologyEntity],
        time_range: OntologyTimeRange,
        context_json: dict[str, Any],
    ) -> list[str]:
        """Return the slot names still needed to draft an expense document.

        Only applies to ``scenario == "expense"`` with ``intent == "draft"``;
        any other combination yields ``[]``. Three modes, checked in order:

        * expense *application* (context flag, query keyword or a
          ``document_type=expense_application`` entity): slots may also be filled
          from ``context_json["review_form_values"]``; the result is ordered by
          ``EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS`` then ``"attachments"``;
        * generic expense prompt: type, amount, time range, reason, attachments;
        * customer entertainment: customer name, participants, attachments.

        A non-numeric ``attachment_count`` in the context is treated as 0.
        """
        if scenario != "expense" or intent != "draft":
            return []

        entity_types = {item.type for item in entities}
        attachment_count = _coerce_int(context_json.get("attachment_count"))
        missing_slots: list[str] = []

        application_mode = (
            self._is_expense_application_context_value(context_json)
            or self._has_expense_application_signal(compact_query)
            or any(
                item.type == "document_type" and item.normalized_value == "expense_application"
                for item in entities
            )
        )
        if application_mode:
            form_values = context_json.get("review_form_values")
            if not isinstance(form_values, dict):
                form_values = {}

            def form_text(*keys: str) -> str:
                # Mirrors ``str(a or b or "").strip()``: first truthy value, stripped.
                for key in keys:
                    value = form_values.get(key)
                    if value:
                        return str(value).strip()
                return ""

            expense_type_codes = {
                str(item.normalized_value or item.value or "").strip()
                for item in entities
                if item.type == "expense_type"
            }
            if "expense_type" not in entity_types and not form_text("expense_type"):
                missing_slots.append("expense_type")
            if "amount" not in entity_types and not form_text("amount"):
                missing_slots.append("amount")
            if not time_range.start_date and not form_text("time_range", "business_time"):
                missing_slots.append("time_range")
            reason_value = form_text("reason", "business_reason", "reason_value")
            if not reason_value and compact_query in GENERIC_EXPENSE_APPLICATION_PROMPTS:
                missing_slots.append("reason")
            if attachment_count <= 0 and expense_type_codes & EXPENSE_APPLICATION_ATTACHMENT_REQUIRED_TYPES:
                missing_slots.append("attachments")
            # dict.fromkeys keeps order and drops a duplicate "attachments" should the
            # required-slot constant already contain it.
            ordered_keys = dict.fromkeys([*EXPENSE_APPLICATION_REQUIRED_SLOT_KEYS, "attachments"])
            return [key for key in ordered_keys if key in missing_slots]

        if self._is_generic_expense_prompt(compact_query):
            if "expense_type" not in entity_types:
                missing_slots.append("expense_type")
            if "amount" not in entity_types:
                missing_slots.append("amount")
            if not time_range.start_date:
                missing_slots.append("time_range")
            missing_slots.append("reason")
            if attachment_count <= 0:
                missing_slots.append("attachments")
            return missing_slots

        has_entertainment_type = any(
            item.normalized_value == "entertainment"
            for item in entities
            if item.type == "expense_type"
        )
        has_explicit_entertainment_text = "客户" in compact_query and any(
            keyword in compact_query for keyword in _EXPLICIT_ENTERTAINMENT_KEYWORDS
        )
        if has_entertainment_type or has_explicit_entertainment_text:
            if "customer" not in entity_types:
                missing_slots.append("customer_name")
            missing_slots.append("participants")
            if attachment_count <= 0:
                missing_slots.append("attachments")
        return missing_slots

    @staticmethod
    def _resolve_confidence(
        *,
        model_confidence: float | None,
        fallback_confidence: float,
        clarification_required: bool,
        permission: OntologyPermission,
    ) -> float:
        """Combine model and rule confidence into a value rounded to 2 decimals.

        The model value (or the fallback when the model gave none) is clamped to
        ``[0.0, 0.98]``. A FORBIDDEN permission floors it at 0.86; otherwise a
        pending clarification caps it at 0.58.
        """
        confidence = fallback_confidence if model_confidence is None else float(model_confidence)
        confidence = max(0.0, min(confidence, 0.98))
        is_forbidden = permission.level == AgentPermissionLevel.FORBIDDEN.value
        if is_forbidden:
            confidence = max(confidence, 0.86)
        if clarification_required and not is_forbidden:
            confidence = min(confidence, 0.58)
        return round(confidence, 2)

    def _extract_entities(
        self,
        query: str,
        compact_query: str,
        reference: ReferenceCatalog,
        *,
        context_json: dict[str, Any] | None = None,
    ) -> list[OntologyEntity]:
        """Extract ontology entities from *query* / *compact_query*.

        Entities are de-duplicated on ``(type, normalized_value)``; the first
        occurrence wins, so the order of the rules below is also their priority.
        The result preserves first-insertion order.
        """
        entities: dict[tuple[str, str], OntologyEntity] = {}
        context_json = context_json or {}

        def upsert(entity: OntologyEntity) -> None:
            entities.setdefault((entity.type, entity.normalized_value), entity)

        if self._is_expense_application_context_value(context_json) or self._has_expense_application_signal(
            compact_query
        ):
            upsert(
                self._make_entity(
                    "document_type", "费用申请", "expense_application", role="target", confidence=0.94
                )
            )
            upsert(
                self._make_entity("workflow_stage", "前置申请", "pre_approval", role="target", confidence=0.9)
            )

        # Customers / vendors referenced by a short suffix, e.g. "客户A".
        for match in _CUSTOMER_SUFFIX_PATTERN.finditer(query):
            normalized = f"客户{match.group(1).strip()}".replace(" ", "")
            upsert(self._make_entity("customer", match.group(0).strip(), normalized, role="filter"))
        labeled_customer_match = _LABELED_CUSTOMER_PATTERN.search(query)
        if labeled_customer_match:
            customer_name = labeled_customer_match.group("name").strip()
            upsert(self._make_entity("customer", customer_name, customer_name, role="filter"))
        for match in _VENDOR_SUFFIX_PATTERN.finditer(query):
            normalized = f"供应商{match.group(1).strip()}".replace(" ", "")
            upsert(self._make_entity("vendor", match.group(0).strip(), normalized, role="filter"))

        employee_match = _EMPLOYEE_NAME_PATTERN.search(query)
        if employee_match:
            name = employee_match.group("name")
            upsert(self._make_entity("employee", name, name, role="filter"))

        # Known names from the reference catalog, matched on the compacted text.
        catalog_sources = (
            ("employee", reference.employees),
            ("department", reference.departments),
            ("customer", reference.customers),
            ("vendor", reference.vendors),
            ("project", reference.projects),
        )
        for entity_type, names in catalog_sources:
            for name in names:
                if self._compact(name) in compact_query:
                    upsert(self._make_entity(entity_type, name, name, role="filter"))

        for code in _PROJECT_CODE_PATTERN.findall(query):
            upsert(self._make_entity("project", code, code.upper(), role="filter"))
        for match in DOCUMENT_NUMBER_EXTRACT_PATTERN.finditer(query):
            code = match.group(0)
            upsert(self._make_entity("expense_claim", code, code.upper()))
        for entity_type, pattern in _DOCUMENT_CODE_PATTERNS:
            for code in pattern.findall(query):
                upsert(self._make_entity(entity_type, code, code.upper()))

        for location in LOCATION_KEYWORDS:
            if location in query:
                upsert(self._make_entity("location", location, location, role="filter", confidence=0.86))
        for label, normalized in EXPENSE_TYPE_KEYWORDS.items():
            if label in query:
                upsert(self._make_entity("expense_type", label, normalized, role="filter"))

        if "客户" in query and any(keyword in query for keyword in _CUSTOMER_ENTERTAINMENT_KEYWORDS):
            upsert(self._make_entity("expense_type", "业务招待费", "meal", role="filter", confidence=0.96))
        for keywords, label, normalized, confidence in _KEYWORD_EXPENSE_TYPE_RULES:
            if any(keyword in query for keyword in keywords):
                upsert(
                    self._make_entity("expense_type", label, normalized, role="filter", confidence=confidence)
                )

        for amount in self._extract_amount_entities(query):
            upsert(amount)
        return list(entities.values())

    def _extract_amount_entities(self, query: str) -> list[OntologyEntity]:
        """Return one ``amount`` entity per ``AMOUNT_PATTERN`` match in *query*.

        Matches with neither a comparison prefix nor a unit are skipped (a bare
        number is not treated as an amount). A prefixed match gets role
        ``"threshold"``; otherwise ``"target"``.
        """
        entities: list[OntologyEntity] = []
        for match in AMOUNT_PATTERN.finditer(query):
            raw_value = match.group("value")
            unit = match.group("unit")
            prefix = match.group("prefix")
            if raw_value is None or (prefix is None and unit is None):
                continue
            entities.append(
                self._make_entity(
                    "amount",
                    f"{raw_value}{unit or ''}",
                    str(self._normalize_amount(raw_value, unit)),
                    role="threshold" if prefix else "target",
                    confidence=0.9,
                )
            )
        return entities

    @staticmethod
    def _make_entity(
        entity_type: str,
        value: str,
        normalized_value: str,
        *,
        role: str = "target",
        confidence: float = 0.92,
    ) -> OntologyEntity:
        """Build an ``OntologyEntity`` with this module's default role and confidence."""
        return OntologyEntity(
            type=entity_type,
            value=value,
            normalized_value=normalized_value,
            role=role,
            confidence=confidence,
        )

    @staticmethod
    def _infer_scenario_from_entities(entities: list[OntologyEntity]) -> str | None:
        """Guess the business scenario from entity types (payable > receivable > expense)."""
        entity_types = {item.type for item in entities}
        if entity_types & {"vendor", "payable"}:
            return "accounts_payable"
        if entity_types & {"customer", "receivable", "contract"}:
            return "accounts_receivable"
        if entity_types & {"employee", "expense_claim", "expense_type"}:
            return "expense"
        return None

    def _extract_time_range(
        self,
        query: str,
        compact_query: str,
        *,
        context_json: dict[str, Any],
    ) -> tuple[OntologyTimeRange, float]:
        """Resolve the query's time expression into a range and a confidence bonus.

        Relative expressions (days, week, month, quarter, year) are resolved
        against the client's "today" (see ``_resolve_reference_today``) and take
        precedence over explicit dates. Returns ``(OntologyTimeRange(), 0.0)``
        when nothing matches.
        """
        today = self._resolve_reference_today(context_json)

        for keyword, day_offset in _RELATIVE_DAY_KEYWORDS:
            if keyword in query:
                return self._single_day_range(today + timedelta(days=day_offset), keyword, "day"), 0.10

        if any(keyword in query for keyword in ("本周", "这周", "本星期", "这个星期")):
            week_start = today - timedelta(days=today.weekday())
            return self._range(week_start, week_start + timedelta(days=6), "本周", "week"), 0.10
        if any(keyword in query for keyword in ("上周", "上星期", "上个星期")):
            # Previous Monday..Sunday.
            week_end = today - timedelta(days=today.weekday() + 1)
            return self._range(week_end - timedelta(days=6), week_end, "上周", "week"), 0.10
        if any(keyword in query for keyword in ("本月", "这个月")):
            start, end = _calendar_month_bounds(today.year, today.month)
            return self._range(start, end, "本月", "month"), 0.10
        if any(keyword in query for keyword in ("上月", "上个月")):
            last_day_prev_month = today.replace(day=1) - timedelta(days=1)
            start, end = _calendar_month_bounds(last_day_prev_month.year, last_day_prev_month.month)
            return self._range(start, end, "上月", "month"), 0.10
        if any(keyword in query for keyword in ("本季度", "这个季度")):
            start_month = (today.month - 1) // 3 * 3 + 1
            start, _ = _calendar_month_bounds(today.year, start_month)
            _, end = _calendar_month_bounds(today.year, start_month + 2)
            return self._range(start, end, "本季度", "quarter"), 0.10
        if "今年" in query:
            return self._range(date(today.year, 1, 1), date(today.year, 12, 31), "今年", "year"), 0.10
        if "去年" in query or "上一年" in query:
            year = today.year - 1
            return self._range(date(year, 1, 1), date(year, 12, 31), "去年", "year"), 0.10

        explicit = self._extract_explicit_time_range(query, compact_query, today)
        if explicit is not None:
            return explicit
        return OntologyTimeRange(), 0.0

    def _extract_explicit_time_range(
        self,
        query: str,
        compact_query: str,
        today: date,
    ) -> tuple[OntologyTimeRange, float] | None:
        """Match explicit date / month / month-day expressions, most specific first.

        Dates without a year use ``today.year``. A match that names an impossible
        date (e.g. "13月", "2月30日") is skipped and the next pattern is tried,
        instead of raising.
        """
        match = DATE_RANGE_PATTERN.search(query)
        if match:
            start = self._parse_iso_date(match.group("start"))
            end = self._parse_iso_date(match.group("end"))
            if start and end:
                return self._range(start, end, match.group(0), "custom"), 0.10

        match = EXPLICIT_DATE_PATTERN.search(query)
        if match:
            try:
                explicit = date(int(match.group("year")), int(match.group("month")), int(match.group("day")))
            except ValueError:
                pass
            else:
                return self._single_day_range(explicit, match.group(0), "day"), 0.10

        match = EXPLICIT_MONTH_PATTERN.search(query)
        if match:
            try:
                start, end = _calendar_month_bounds(int(match.group("year")), int(match.group("month")))
            except ValueError:
                pass
            else:
                return self._range(start, end, match.group(0), "month"), 0.10

        match = MONTH_DAY_RANGE_PATTERN.search(query)
        if match:
            try:
                start = date(today.year, int(match.group("start_month")), int(match.group("start_day")))
                end = date(today.year, int(match.group("end_month")), int(match.group("end_day")))
            except ValueError:
                pass
            else:
                return self._range(start, end, match.group(0), "custom"), 0.10

        match = MONTH_DAY_PATTERN.search(compact_query)
        if match:
            try:
                explicit = date(today.year, int(match.group("month")), int(match.group("day")))
            except ValueError:
                pass
            else:
                return self._single_day_range(explicit, match.group(0), "day"), 0.08

        match = _BARE_MONTH_PATTERN.search(compact_query)
        if match:
            try:
                start, end = _calendar_month_bounds(today.year, int(match.group("month")))
            except ValueError:
                pass
            else:
                return self._range(start, end, match.group(0), "month"), 0.08
        return None

    @staticmethod
    def _resolve_reference_today(context_json: dict[str, Any]) -> date:
        """Return the client's local date, falling back to today's UTC date.

        ``client_now_iso`` is parsed with ``datetime.fromisoformat`` ("Z" accepted);
        a naive value is treated as UTC. ``client_timezone_offset_minutes`` is
        subtracted from the UTC instant. This matches the JavaScript
        ``Date.getTimezoneOffset()`` convention (UTC minus local, e.g. -480 for
        UTC+8); that convention is an assumption, to be confirmed against the client.
        When the offset is absent or 0, the timestamp's own wall-clock date is used.
        """
        client_now_iso = str(context_json.get("client_now_iso") or "").strip()
        if not client_now_iso:
            return datetime.now(UTC).date()
        try:
            client_now = datetime.fromisoformat(client_now_iso.replace("Z", "+00:00"))
        except ValueError:
            return datetime.now(UTC).date()
        if client_now.tzinfo is None:
            client_now = client_now.replace(tzinfo=UTC)
        offset_minutes = _coerce_int(context_json.get("client_timezone_offset_minutes"))
        if offset_minutes:
            # The offset is relative to UTC; normalise first so a timestamp that
            # already carries a non-UTC offset is not shifted twice.
            client_now = client_now.astimezone(UTC)
        return (client_now - timedelta(minutes=offset_minutes)).date()

    @staticmethod
    def _single_day_range(target: date, raw: str, granularity: str) -> OntologyTimeRange:
        """Return a range whose start and end are both *target*."""
        return OntologyTimeRange(
            raw=raw,
            start_date=target.isoformat(),
            end_date=target.isoformat(),
            granularity=granularity,
        )

    @staticmethod
    def _range(start: date, end: date, raw: str, granularity: str) -> OntologyTimeRange:
        """Return a range from *start* to *end* with ISO-formatted dates."""
        return OntologyTimeRange(
            raw=raw,
            start_date=start.isoformat(),
            end_date=end.isoformat(),
            granularity=granularity,
        )

    @staticmethod
    def _parse_iso_date(value: str) -> date | None:
        """Parse an ISO ``YYYY-MM-DD`` string, returning None when it is invalid."""
        try:
            return date.fromisoformat(value)
        except ValueError:
            return None

    def _extract_metrics(self, compact_query: str) -> list[OntologyMetric]:
        """Return the metrics requested by *compact_query*, one per metric name.

        A top-N request replaces the ``amount`` metric with a sorted, limited sum;
        "最低" switches the sort to ascending.
        """
        metrics: dict[str, OntologyMetric] = {}

        def upsert(metric: OntologyMetric) -> None:
            metrics[metric.name] = metric

        if any(keyword in compact_query for keyword in ("多少钱", "金额", "总额", "支出", "回款", "应收", "应付")):
            upsert(OntologyMetric(name="amount", aggregation="sum", unit="CNY"))
        if any(keyword in compact_query for keyword in ("多少笔", "几笔", "数量", "条数", "单数")):
            upsert(OntologyMetric(name="count", aggregation="count", unit="records"))
        if "超标" in compact_query or "超预算" in compact_query:
            upsert(OntologyMetric(name="amount_over_limit"))
        if "逾期" in compact_query or "账龄" in compact_query:
            upsert(OntologyMetric(name="overdue"))
        if "重复" in compact_query:
            upsert(OntologyMetric(name="duplicate_expense"))

        top_match = TOP_N_PATTERN.search(compact_query)
        if top_match:
            upsert(
                OntologyMetric(
                    name="amount",
                    aggregation="sum",
                    unit="CNY",
                    sort="asc" if "最低" in compact_query else "desc",
                    top_n=int(top_match.group("top")),
                )
            )
        return list(metrics.values())

    def _extract_constraints(
        self,
        compact_query: str,
        entities: list[OntologyEntity],
    ) -> list[OntologyConstraint]:
        """Build query constraints from entities, status keywords, amounts and top-N.

        Constraints are de-duplicated on ``(field, operator, str(value), currency)``,
        keeping the first. Only the first amount match that carries a comparison
        prefix becomes an amount constraint.
        """
        constraints: dict[tuple[str, str, str, str | None], OntologyConstraint] = {}

        def upsert(constraint: OntologyConstraint) -> None:
            key = (constraint.field, constraint.operator, str(constraint.value), constraint.currency)
            constraints.setdefault(key, constraint)

        for entity in entities:
            if entity.type in _CONSTRAINT_ENTITY_TYPES:
                upsert(OntologyConstraint(field=entity.type, operator="=", value=entity.normalized_value))

        for keyword, normalized in STATUS_KEYWORDS.items():
            if keyword in compact_query:
                upsert(OntologyConstraint(field="status", operator="=", value=normalized))

        prefixed_amount = next(
            (match for match in AMOUNT_PATTERN.finditer(compact_query) if match.group("prefix")),
            None,
        )
        if prefixed_amount is not None:
            upsert(
                OntologyConstraint(
                    field="amount",
                    operator=self._normalize_operator(prefixed_amount.group("prefix")),
                    value=self._normalize_amount(prefixed_amount.group("value"), prefixed_amount.group("unit")),
                    currency="CNY",
                )
            )

        top_match = TOP_N_PATTERN.search(compact_query)
        if top_match:
            upsert(OntologyConstraint(field="top_n", operator="=", value=int(top_match.group("top"))))
            upsert(
                OntologyConstraint(
                    field="sort_by",
                    operator="asc" if "最低" in compact_query else "desc",
                    value="amount",
                )
            )
        return list(constraints.values())