"""Database-backed answers for orchestrator queries (expenses, AR, AP)."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import Any

from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session

from app.models.employee import Employee
from app.models.financial_record import (
    AccountsPayableRecord,
    AccountsReceivableRecord,
    ExpenseClaim,
)
from app.schemas.ontology import OntologyParseResult

# Role codes that may query expense claims beyond the caller's own.
PRIVILEGED_EXPENSE_QUERY_ROLE_CODES = {"finance"}
# Phrases in the user message that mean "my own claims".
SELF_REFERENCE_KEYWORDS = ("我的", "我自己", "本人", "我名下", "给我查", "我提交", "我申请")
EXPENSE_QUERY_RECENT_WINDOW_DAYS = 10
EXPENSE_QUERY_PREVIEW_LIMIT = 5

# Status code -> display label. Only keys of this mapping are accepted as
# status filters (see ``_resolve_expense_query_status_values``).
EXPENSE_STATUS_LABELS = {
    "archived": "归档",
    "draft": "草稿",
    "supplement": "待补充",
    "returned": "已退回",
    "submitted": "已提交",
    "review": "审核中",
    "approved": "已通过",
    "rejected": "已驳回",
    "paid": "归档",
}

# (keywords found in the message, status codes they imply).
EXPENSE_QUERY_STATUS_KEYWORDS = (
    (("归档", "已归档", "入账", "已入账", "已付款"), ("archived",)),
    (("审批通过", "审核通过", "已通过", "已审核"), ("approved",)),
    (("审批中", "审核中", "进行中", "流程中"), ("submitted", "review")),
    (("已提交", "提交了"), ("submitted",)),
    (("草稿", "待报销", "待提交"), ("draft",)),
    (("待补充", "待完善", "退回", "已退回"), ("supplement", "returned")),
    (("驳回", "已驳回", "拒绝"), ("rejected",)),
)

# Exact constraint value -> status code.
EXPENSE_STATUS_ALIASES = {
    "归档": "archived",
    "已归档": "archived",
    "入账": "archived",
    "已入账": "archived",
    "已付款": "archived",
    "已通过": "approved",
    "审批通过": "approved",
    "审核通过": "approved",
    "已审核": "approved",
    "审批中": "review",
    "审核中": "review",
    "进行中": "review",
    "已提交": "submitted",
    "草稿": "draft",
    "待报销": "draft",
    "待提交": "draft",
    "待补充": "supplement",
    "待完善": "supplement",
    "已退回": "returned",
    "退回": "returned",
    "驳回": "rejected",
    "已驳回": "rejected",
}

EXPENSE_STATUS_GROUP_LABELS = {
    "draft": "草稿",
    "in_progress": "审批中",
    "completed": "审批完成",
    "other": "其他状态",
}
EXPENSE_STATUS_GROUP_ORDER = ("draft", "in_progress", "completed", "other")

EXPENSE_RISK_LEVEL_LABELS = {
    "high": "高风险",
    "medium": "中风险",
    "warning": "中风险",
    "low": "低风险",
    "info": "提示",
}

EXPENSE_TYPE_LABELS = {
    "travel": "差旅费",
    "hotel": "住宿费",
    "transport": "交通费",
    "meal": "业务招待费",
    "meeting": "会务费",
    "entertainment": "业务招待费",
    "office": "办公用品费",
    "training": "培训费",
    "communication": "通讯费",
    "welfare": "福利费",
    "other": "其他费用",
}


class OrchestratorDatabaseQueryBuilder:
    """Build structured answers by querying financial records through a session."""

    def __init__(self, db: Session) -> None:
        self.db = db

    def build_database_answer(
        self,
        ontology: OntologyParseResult,
        *,
        user_id: str | None,
        context_json: dict[str, Any],
        message: str,
    ) -> dict[str, Any]:
        """Dispatch on ``ontology.scenario`` and return the matching answer.

        ``"expense"`` and ``"accounts_receivable"`` have dedicated builders;
        every other scenario falls through to the accounts-payable summary.
        """
        if ontology.scenario == "expense":
            return self._build_expense_database_answer(
                ontology=ontology,
                user_id=user_id,
                context_json=context_json,
                message=message,
            )
        if ontology.scenario == "accounts_receivable":
            return self._build_accounts_receivable_answer()
        return self._build_accounts_payable_answer()

    def _summarize_expense_claims(self, conditions: list[Any]) -> tuple[int, float]:
        """Return ``(row count, summed amount)`` of claims matching all conditions."""
        count_stmt = select(func.count()).select_from(ExpenseClaim)
        amount_stmt = select(
            func.coalesce(func.sum(ExpenseClaim.amount), 0)
        ).select_from(ExpenseClaim)
        for condition in conditions:
            count_stmt = count_stmt.where(condition)
            amount_stmt = amount_stmt.where(condition)
        return (
            int(self.db.scalar(count_stmt) or 0),
            float(self.db.scalar(amount_stmt) or 0),
        )

    def _build_expense_database_answer(
        self,
        *,
        ontology: OntologyParseResult,
        user_id: str | None,
        context_json: dict[str, Any],
        message: str,
    ) -> dict[str, Any]:
        """Build the ``expense_claim_list`` answer payload.

        Totals are always computed over the full scoped filter. When the
        recent-window rule applies, the displayed count/amount, the preview
        records and the status groups are further restricted to that window,
        and ``older_record_count`` reports how many matches fall outside it.
        """
        conditions, scope_label, scoped_to_current_user = self._build_expense_query_scope(
            ontology=ontology,
            user_id=user_id,
            context_json=context_json,
            message=message,
        )
        total_count, total_amount = self._summarize_expense_claims(conditions)

        recent_window_applied = self._should_limit_expense_query_to_recent_window(
            ontology, message
        )
        display_count = total_count
        display_amount = total_amount
        older_record_count = 0
        display_conditions = list(conditions)
        window_start_date: str | None = None
        window_end_date: str | None = None

        if recent_window_applied:
            reference_now = self._resolve_reference_now(context_json)
            recent_window_start, recent_window_end = (
                self._resolve_expense_recent_window_bounds(reference_now)
            )
            display_conditions.append(
                self._build_expense_recent_window_condition(
                    recent_window_start,
                    recent_window_end,
                )
            )
            window_start_date = recent_window_start.date().isoformat()
            # The window end is exclusive; step back to report the last included day.
            window_end_date = (
                recent_window_end - timedelta(microseconds=1)
            ).date().isoformat()
            display_count, display_amount = self._summarize_expense_claims(
                display_conditions
            )
            older_record_count = max(0, total_count - display_count)

        preview_stmt = (
            select(ExpenseClaim)
            .order_by(
                func.coalesce(
                    ExpenseClaim.submitted_at,
                    ExpenseClaim.created_at,
                    ExpenseClaim.occurred_at,
                ).desc(),
                ExpenseClaim.occurred_at.desc(),
            )
            .limit(EXPENSE_QUERY_PREVIEW_LIMIT)
        )
        for condition in display_conditions:
            preview_stmt = preview_stmt.where(condition)
        preview_claims = list(self.db.scalars(preview_stmt).all())
        status_groups = self._build_expense_status_groups(display_conditions)

        return {
            "result_type": "expense_claim_list",
            "record_count": display_count,
            "total_amount": round(display_amount, 2),
            "scope_label": scope_label,
            "title": (
                f"最近 {len(preview_claims)} 条{scope_label}"
                if preview_claims
                else f"{scope_label}筛选结果"
            ),
            "scoped_to_current_user": scoped_to_current_user,
            "recent_window_applied": recent_window_applied,
            "window_days": EXPENSE_QUERY_RECENT_WINDOW_DAYS if recent_window_applied else None,
            "window_start_date": window_start_date,
            "window_end_date": window_end_date,
            "preview_count": len(preview_claims),
            "preview_limit": EXPENSE_QUERY_PREVIEW_LIMIT,
            "older_record_count": older_record_count,
            "records": [
                self._build_expense_query_record(claim) for claim in preview_claims
            ],
            "status_groups": status_groups,
            "has_more_in_window": display_count > len(preview_claims),
            "total_matched_count": total_count,
        }

    def _build_accounts_receivable_answer(self) -> dict[str, Any]:
        """Return the row count and summed outstanding amount of all AR records."""
        total_count = int(
            self.db.scalar(
                select(func.count()).select_from(AccountsReceivableRecord)
            )
            or 0
        )
        total_amount = float(
            self.db.scalar(
                select(func.coalesce(func.sum(AccountsReceivableRecord.amount_outstanding), 0))
            )
            or 0
        )
        return {
            "record_count": total_count,
            "outstanding_amount": round(total_amount, 2),
        }

    def _build_accounts_payable_answer(self) -> dict[str, Any]:
        """Return the row count and summed outstanding amount of all AP records."""
        total_count = int(
            self.db.scalar(select(func.count()).select_from(AccountsPayableRecord))
            or 0
        )
        total_amount = float(
            self.db.scalar(
                select(func.coalesce(func.sum(AccountsPayableRecord.amount_outstanding), 0))
            )
            or 0
        )
        return {
            "record_count": total_count,
            "outstanding_amount": round(total_amount, 2),
        }

    @staticmethod
    def _should_limit_expense_query_to_recent_window(
        ontology: OntologyParseResult,
        message: str = "",
    ) -> bool:
        """Decide whether results should be limited to the recent window.

        True only when the message contains a recency keyword and the parsed
        ontology carries neither an explicit claim number nor an explicit
        start/end date.
        """
        has_explicit_claim_no = any(
            item.type == "expense_claim"
            and str(item.normalized_value or item.value or "").strip()
            for item in ontology.entities
        )
        has_explicit_time_range = bool(
            ontology.time_range.start_date or ontology.time_range.end_date
        )
        # Drop every kind of whitespace (including tabs and full-width spaces),
        # matching ``_is_self_expense_query``, so keywords split by whitespace
        # are still detected.
        compact_message = "".join(str(message or "").split())
        asks_recent_window = any(
            keyword in compact_message
            for keyword in ("近", "最近", "本周", "上周", "过去", "前几天", "这几天")
        )
        return asks_recent_window and not has_explicit_claim_no and not has_explicit_time_range

    @staticmethod
    def _resolve_reference_now(context_json: dict[str, Any]) -> datetime:
        """Return the reference "now" in UTC.

        Uses ``context_json["client_now_iso"]`` when it parses as ISO-8601
        (a naive value is taken to be UTC); otherwise the server's current
        UTC time.
        """
        raw_value = str(context_json.get("client_now_iso") or "").strip()
        if raw_value:
            normalized = raw_value.replace("Z", "+00:00")
            try:
                parsed = datetime.fromisoformat(normalized)
            except ValueError:
                # Unparseable client clock: fall back to the server clock below.
                pass
            else:
                if parsed.tzinfo is None:
                    return parsed.replace(tzinfo=UTC)
                return parsed.astimezone(UTC)
        return datetime.now(UTC)

    @staticmethod
    def _resolve_expense_recent_window_bounds(
        reference_now: datetime,
    ) -> tuple[datetime, datetime]:
        """Return ``(start, end)`` of the recent window as UTC datetimes.

        ``end`` is the UTC midnight following ``reference_now`` (exclusive);
        ``start`` is ``EXPENSE_QUERY_RECENT_WINDOW_DAYS`` days before ``end``.
        """
        normalized_now = reference_now.astimezone(UTC)
        window_end = normalized_now.replace(
            hour=0, minute=0, second=0, microsecond=0
        ) + timedelta(days=1)
        window_start = window_end - timedelta(days=EXPENSE_QUERY_RECENT_WINDOW_DAYS)
        return window_start, window_end

    @staticmethod
    def _build_expense_recent_window_condition(
        window_start: datetime,
        window_end: datetime,
    ) -> Any:
        """Return a SQL condition: document datetime in ``[window_start, window_end)``.

        The document datetime is the first non-null of ``submitted_at``,
        ``created_at`` and ``occurred_at``.
        """
        document_datetime = func.coalesce(
            ExpenseClaim.submitted_at,
            ExpenseClaim.created_at,
            ExpenseClaim.occurred_at,
        )
        return and_(document_datetime >= window_start, document_datetime < window_end)

    def _build_expense_status_groups(
        self,
        conditions: list[Any],
    ) -> list[dict[str, Any]]:
        """Count matching claims per status group, omitting empty groups.

        Groups are returned in ``EXPENSE_STATUS_GROUP_ORDER``.
        """
        stmt = (
            select(ExpenseClaim.status, func.count())
            .select_from(ExpenseClaim)
            .group_by(ExpenseClaim.status)
        )
        for condition in conditions:
            stmt = stmt.where(condition)

        grouped_counts = {key: 0 for key in EXPENSE_STATUS_GROUP_ORDER}
        for status, count in self.db.execute(stmt).all():
            group_key, _ = self._resolve_expense_status_group(str(status or "").strip())
            grouped_counts[group_key] = grouped_counts.get(group_key, 0) + int(count or 0)

        return [
            {
                "key": key,
                "label": EXPENSE_STATUS_GROUP_LABELS[key],
                "count": grouped_counts.get(key, 0),
            }
            for key in EXPENSE_STATUS_GROUP_ORDER
            if grouped_counts.get(key, 0) > 0
        ]

    @staticmethod
    def _resolve_expense_status_group(status: str) -> tuple[str, str]:
        """Map a raw status to ``(group key, group label)``; unknown -> ``"other"``."""
        normalized = str(status or "").strip().lower()
        if normalized == "draft":
            return "draft", EXPENSE_STATUS_GROUP_LABELS["draft"]
        if normalized in {"submitted", "review"}:
            return "in_progress", EXPENSE_STATUS_GROUP_LABELS["in_progress"]
        if normalized in {"approved", "paid"}:
            return "completed", EXPENSE_STATUS_GROUP_LABELS["completed"]
        return "other", EXPENSE_STATUS_GROUP_LABELS["other"]

    @staticmethod
    def _resolve_expense_query_document_datetime(
        claim: ExpenseClaim,
    ) -> datetime | None:
        """Return the first truthy of ``submitted_at``, ``created_at``, ``occurred_at``."""
        return claim.submitted_at or claim.created_at or claim.occurred_at

    def _build_expense_query_record(
        self,
        claim: ExpenseClaim,
    ) -> dict[str, Any]:
        """Serialize one claim into the preview-record dict of the answer."""
        status_group, status_group_label = self._resolve_expense_status_group(claim.status)
        document_datetime = self._resolve_expense_query_document_datetime(claim)
        approval_stage = str(claim.approval_stage or "").strip()
        # An approval stage mentioning archiving overrides the status label.
        status_label = (
            "已归档"
            if "归档" in approval_stage
            else EXPENSE_STATUS_LABELS.get(claim.status, claim.status or "处理中")
        )
        return {
            "claim_id": claim.id,
            "claim_no": claim.claim_no,
            "employee_name": claim.employee_name,
            "expense_type": claim.expense_type,
            "expense_type_label": EXPENSE_TYPE_LABELS.get(
                claim.expense_type, claim.expense_type or "报销"
            ),
            # A missing amount is reported as 0 instead of raising on float(None).
            "amount": round(float(claim.amount or 0), 2),
            "status": claim.status,
            "status_label": status_label,
            "status_group": status_group,
            "status_group_label": status_group_label,
            "approval_stage": approval_stage,
            "document_date": document_datetime.date().isoformat() if document_datetime else "",
            "occurred_at": claim.occurred_at.date().isoformat() if claim.occurred_at else "",
            "reason": claim.reason,
            "location": claim.location,
            "risk_flags": self._normalize_expense_query_risk_flags(claim.risk_flags_json),
        }

    @staticmethod
    def _normalize_expense_query_risk_flags(raw_flags: Any) -> list[dict[str, str]]:
        """Normalize stored risk flags into uniform display dicts.

        Accepts a list whose items are dicts or plain values; anything that is
        not a list yields ``[]``. Items without a summary are dropped. The
        ``key`` of each result keeps the 1-based position in the input list,
        so dropped items leave gaps in the numbering.
        """
        if not isinstance(raw_flags, list):
            return []

        normalized_flags: list[dict[str, str]] = []
        for index, raw_flag in enumerate(raw_flags, start=1):
            if isinstance(raw_flag, dict):
                raw_level = str(
                    raw_flag.get("severity") or raw_flag.get("level") or ""
                ).strip().lower()
                level = raw_level if raw_level in EXPENSE_RISK_LEVEL_LABELS else "medium"
                summary = str(
                    raw_flag.get("message")
                    or raw_flag.get("summary")
                    or raw_flag.get("title")
                    or raw_flag.get("label")
                    or ""
                ).strip()
                raw_points = raw_flag.get("points") or []
                # A scalar is a single point: iterating a string would split it
                # into characters and iterating a number would raise.
                if isinstance(raw_points, (str, int, float)):
                    raw_points = [raw_points]
                detail = ";".join(
                    str(point or "").strip()
                    for point in raw_points
                    if str(point or "").strip()
                )
                title = str(raw_flag.get("label") or EXPENSE_RISK_LEVEL_LABELS[level]).strip()
            else:
                raw_text = str(raw_flag or "").strip()
                if not raw_text:
                    continue
                level = (
                    "high"
                    if any(keyword in raw_text for keyword in ("高风险", "超标", "重复", "异常"))
                    else "medium"
                )
                summary = raw_text
                detail = raw_text
                title = EXPENSE_RISK_LEVEL_LABELS[level]

            if not summary:
                continue

            normalized_flags.append(
                {
                    "key": f"risk-{index}",
                    "level": level,
                    "level_label": EXPENSE_RISK_LEVEL_LABELS.get(level, "中风险"),
                    "title": title or EXPENSE_RISK_LEVEL_LABELS.get(level, "中风险"),
                    "summary": summary,
                    "detail": detail or summary,
                }
            )
        return normalized_flags

    def _build_owner_scope_condition(
        self,
        *,
        user_id: str | None,
        context_json: dict[str, Any],
    ) -> tuple[Any, str]:
        """Return ``(condition, label)`` restricting claims to the current user.

        When the user cannot be tied to any claim field, the condition compares
        the claim id with a sentinel string so the query returns no rows rather
        than everyone's claims.
        """
        owner_conditions, owner_label = self._build_current_user_claim_conditions(
            user_id=user_id,
            context_json=context_json,
        )
        if owner_conditions:
            return or_(*owner_conditions), owner_label
        return ExpenseClaim.id == "__no_visible_claim__", "你的报销单"

    def _build_expense_query_scope(
        self,
        *,
        ontology: OntologyParseResult,
        user_id: str | None,
        context_json: dict[str, Any],
        message: str,
    ) -> tuple[list[Any], str, bool]:
        """Translate the parsed query into SQL conditions plus a scope label.

        Returns ``(conditions, scope_label, scoped_to_current_user)``. All
        conditions are meant to be AND-ed together. Callers without a
        privileged role are always restricted to their own claims, and any
        employee names in the query are ignored for them.

        Raises:
            ValueError: if an amount constraint value is not numeric or a
                time-range date is not ISO formatted.
        """
        conditions: list[Any] = []

        explicit_employee_names = list(
            dict.fromkeys(
                str(item.value or "").strip()
                for item in ontology.entities
                if item.type == "employee" and str(item.value or "").strip()
            )
        )
        expense_claim_nos = list(
            dict.fromkeys(
                str(item.normalized_value or item.value or "").strip().upper()
                for item in ontology.entities
                if item.type == "expense_claim"
                and str(item.normalized_value or item.value or "").strip()
            )
        )
        expense_types = list(
            dict.fromkeys(
                str(item.normalized_value or item.value or "").strip()
                for item in ontology.entities
                if item.type == "expense_type"
                and str(item.normalized_value or item.value or "").strip()
            )
        )
        project_values = self._collect_expense_query_filter_values(ontology, "project")
        location_values = self._collect_expense_query_filter_values(ontology, "location")
        status_values = self._resolve_expense_query_status_values(
            [
                str(item.value).strip()
                for item in ontology.constraints
                if item.field == "status" and item.operator == "=" and str(item.value).strip()
            ],
            message,
        )
        amount_constraints = [
            item
            for item in ontology.constraints
            if item.field == "amount" and item.operator in {">", ">=", "<", "<=", "="}
        ]

        scope_label = "报销单"
        scoped_to_current_user = False

        if expense_claim_nos:
            conditions.append(ExpenseClaim.claim_no.in_(expense_claim_nos))
        if expense_types:
            conditions.append(ExpenseClaim.expense_type.in_(expense_types))

        # "archived" is matched through the approval stage text or an
        # approved/paid status, not through a direct status comparison.
        direct_status_values = [status for status in status_values if status != "archived"]
        if "archived" in status_values:
            conditions.append(
                or_(
                    ExpenseClaim.approval_stage.ilike("%归档%"),
                    ExpenseClaim.status.in_(["approved", "paid"]),
                )
            )
        if direct_status_values:
            conditions.append(ExpenseClaim.status.in_(direct_status_values))

        if project_values:
            project_conditions = []
            for value in project_values:
                pattern = f"%{value}%"
                project_conditions.append(ExpenseClaim.project_code.ilike(pattern))
                project_conditions.append(ExpenseClaim.reason.ilike(pattern))
            conditions.append(or_(*project_conditions))

        if location_values:
            location_conditions = []
            for value in location_values:
                pattern = f"%{value}%"
                location_conditions.append(ExpenseClaim.location.ilike(pattern))
                location_conditions.append(ExpenseClaim.reason.ilike(pattern))
            conditions.append(or_(*location_conditions))

        for item in amount_constraints:
            amount_value = float(item.value)
            if item.operator == ">":
                conditions.append(ExpenseClaim.amount > amount_value)
            elif item.operator == ">=":
                conditions.append(ExpenseClaim.amount >= amount_value)
            elif item.operator == "<":
                conditions.append(ExpenseClaim.amount < amount_value)
            elif item.operator == "<=":
                conditions.append(ExpenseClaim.amount <= amount_value)
            else:
                conditions.append(ExpenseClaim.amount == amount_value)

        # Explicit dates filter on occurred_at and are interpreted as UTC days.
        if ontology.time_range.start_date:
            conditions.append(
                ExpenseClaim.occurred_at
                >= datetime.fromisoformat(f"{ontology.time_range.start_date}T00:00:00+00:00")
            )
        if ontology.time_range.end_date:
            conditions.append(
                ExpenseClaim.occurred_at
                <= datetime.fromisoformat(
                    f"{ontology.time_range.end_date}T23:59:59.999999+00:00"
                )
            )

        has_privileged_access = self._has_privileged_expense_query_access(context_json)
        refers_to_self = self._is_self_expense_query(message)

        if not has_privileged_access:
            owner_condition, scope_label = self._build_owner_scope_condition(
                user_id=user_id,
                context_json=context_json,
            )
            conditions.append(owner_condition)
            scoped_to_current_user = True
        elif explicit_employee_names:
            conditions.append(ExpenseClaim.employee_name.in_(explicit_employee_names))
            scope_label = f"{'、'.join(explicit_employee_names)}的报销单"
        elif refers_to_self:
            owner_condition, scope_label = self._build_owner_scope_condition(
                user_id=user_id,
                context_json=context_json,
            )
            conditions.append(owner_condition)
            scoped_to_current_user = True
        else:
            scope_label = "全部报销单"

        return (
            conditions,
            self._compose_expense_scope_label(scope_label, status_values),
            scoped_to_current_user,
        )

    @staticmethod
    def _resolve_expense_query_status_values(
        raw_values: list[str],
        message: str,
    ) -> list[str]:
        """Collect status codes from explicit constraint values and the message.

        Constraint values go through ``EXPENSE_STATUS_ALIASES``; the message is
        scanned for ``EXPENSE_QUERY_STATUS_KEYWORDS``. The result is
        de-duplicated in first-seen order and limited to keys of
        ``EXPENSE_STATUS_LABELS``.
        """
        values: list[str] = []
        for raw_value in raw_values:
            normalized = str(raw_value or "").strip()
            if not normalized:
                continue
            values.append(EXPENSE_STATUS_ALIASES.get(normalized, normalized))

        # Drop every kind of whitespace, matching ``_is_self_expense_query``.
        compact_message = "".join(str(message or "").split())
        for keywords, statuses in EXPENSE_QUERY_STATUS_KEYWORDS:
            if any(keyword in compact_message for keyword in keywords):
                values.extend(statuses)

        return [
            status
            for status in dict.fromkeys(values)
            if status in EXPENSE_STATUS_LABELS
        ]

    @staticmethod
    def _compose_expense_scope_label(scope_label: str, status_values: list[str]) -> str:
        """Merge status labels into the scope label.

        The joined status text is inserted before "报销单" when the label
        contains it; otherwise it is appended in full-width parentheses.
        """
        normalized_scope = str(scope_label or "").strip() or "报销单"
        if not status_values:
            return normalized_scope

        status_labels = [
            EXPENSE_STATUS_LABELS.get(status, status)
            for status in status_values
            if status in EXPENSE_STATUS_LABELS
        ]
        if not status_labels:
            return normalized_scope

        status_text = "或".join(dict.fromkeys(status_labels))
        if "报销单" in normalized_scope:
            return normalized_scope.replace("报销单", f"{status_text}报销单")
        return f"{normalized_scope}({status_text})"

    @staticmethod
    def _collect_expense_query_filter_values(
        ontology: OntologyParseResult,
        field_name: str,
    ) -> list[str]:
        """Collect distinct filter values for ``field_name``.

        Values come from entities of that type and from ``=`` constraints on
        that field, in first-seen order.
        """
        values: list[str] = []
        for entity in ontology.entities:
            if entity.type != field_name:
                continue
            value = str(entity.normalized_value or entity.value or "").strip()
            if value:
                values.append(value)
        for constraint in ontology.constraints:
            if constraint.field != field_name or constraint.operator != "=":
                continue
            value = str(constraint.value or "").strip()
            if value:
                values.append(value)
        return list(dict.fromkeys(values))

    def _build_current_user_claim_conditions(
        self,
        *,
        user_id: str | None,
        context_json: dict[str, Any],
    ) -> tuple[list[Any], str]:
        """Build OR-able conditions that identify the current user's claims.

        ``user_id`` is looked up as an employee e-mail (case-insensitive).
        When an employee is found, claims match on the employee id, on the
        e-mail stored as employee name, and on the display name if that name
        is unique. Otherwise the raw ``user_id`` is matched against both the
        employee id and the employee name. ``context_json`` is accepted but
        not read here.

        Returns:
            ``(conditions, label)``; ``conditions`` is empty when ``user_id``
            is blank.
        """
        normalized_user_id = str(user_id or "").strip()
        employee = None
        if normalized_user_id:
            employee = self.db.scalar(
                select(Employee)
                .where(func.lower(Employee.email) == normalized_user_id.lower())
                .limit(1)
            )

        conditions: list[Any] = []
        seen: set[tuple[str, str]] = set()

        def add_condition(field_name: str, value: str | None) -> None:
            """Append an equality condition once per (field, value), ignoring case."""
            normalized = str(value or "").strip()
            if not normalized:
                return
            marker = (field_name, normalized.lower())
            if marker in seen:
                return
            seen.add(marker)
            if field_name == "employee_id":
                conditions.append(ExpenseClaim.employee_id == normalized)
                return
            conditions.append(ExpenseClaim.employee_name == normalized)

        if employee is not None:
            add_condition("employee_id", employee.id)
            add_condition("employee_name", employee.email)
            # Match by display name only when it cannot collide with a namesake.
            if self._employee_name_is_unique(employee):
                add_condition("employee_name", employee.name)
        else:
            add_condition("employee_id", normalized_user_id)
            add_condition("employee_name", normalized_user_id)

        subject_name = (employee.name if employee is not None else "") or normalized_user_id
        if subject_name:
            return conditions, "你的报销单"
        return conditions, "当前用户的报销单"

    def _employee_name_is_unique(self, employee: Employee) -> bool:
        """Return True when exactly one employee row carries this non-empty name."""
        normalized_name = str(employee.name or "").strip()
        if not normalized_name:
            return False
        same_name_count = int(
            self.db.scalar(
                select(func.count())
                .select_from(Employee)
                .where(Employee.name == normalized_name)
            )
            or 0
        )
        return same_name_count == 1

    @staticmethod
    def _has_privileged_expense_query_access(context_json: dict[str, Any]) -> bool:
        """Return True when ``context_json["role_codes"]`` holds a privileged role.

        Role codes are compared after stripping and lower-casing. A missing,
        ``None``, bare-string or non-iterable value fails closed (not
        privileged) instead of raising.
        """
        raw_codes = context_json.get("role_codes")
        if raw_codes is None or isinstance(raw_codes, (str, bytes)):
            return False
        try:
            code_iter = iter(raw_codes)
        except TypeError:
            return False
        role_codes = {
            str(item).strip().lower()
            for item in code_iter
            if str(item).strip()
        }
        return bool(role_codes & PRIVILEGED_EXPENSE_QUERY_ROLE_CODES)

    @staticmethod
    def _is_self_expense_query(message: str) -> bool:
        """Return True when the whitespace-stripped message contains a self-reference keyword."""
        compact_message = "".join(str(message or "").split())
        return any(keyword in compact_message for keyword in SELF_REFERENCE_KEYWORDS)