diff --git a/server/src/app/services/agent_conversations.py b/server/src/app/services/agent_conversations.py new file mode 100644 index 0000000..ea87a12 --- /dev/null +++ b/server/src/app/services/agent_conversations.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.models.agent_conversation import AgentConversation, AgentConversationMessage +from app.services.settings import SettingsService + +STATEFUL_CONTEXT_KEYS = ( + "entry_source", + "request_context", + "attachment_names", + "attachment_count", + "ocr_summary", + "ocr_documents", +) +DEFAULT_CONVERSATION_RETENTION_DAYS = 3 + + +class AgentConversationService: + def __init__(self, db: Session) -> None: + self.db = db + + def get_or_create_conversation( + self, + *, + conversation_id: str | None, + user_id: str | None, + source: str, + context_json: dict[str, Any], + ) -> AgentConversation: + self.prune_expired_conversations() + + normalized_id = str(conversation_id or "").strip() + normalized_user_id = str(user_id or "").strip() or None + conversation = self.get_conversation(normalized_id) if normalized_id else None + if conversation is not None and conversation.user_id != normalized_user_id: + normalized_id = "" + conversation = None + + if conversation is None: + conversation = AgentConversation( + conversation_id=normalized_id or f"conv_{uuid.uuid4().hex[:16]}", + user_id=normalized_user_id, + source=source, + entry_source=str(context_json.get("entry_source") or "").strip() or None, + title=self._resolve_title(context_json), + state_json=self._extract_state_json(context_json), + ) + self.db.add(conversation) + self.db.commit() + self.db.refresh(conversation) + return conversation + + if not conversation.user_id and normalized_user_id: + conversation.user_id = normalized_user_id + if not conversation.entry_source: + conversation.entry_source = str(context_json.get("entry_source") or "").strip() or None + if not conversation.title: + conversation.title = self._resolve_title(context_json) + conversation.state_json = self._merge_state_json( + conversation.state_json, + self._extract_state_json(context_json), + ) + self.db.add(conversation) + self.db.commit() + self.db.refresh(conversation) + return conversation + + def prune_expired_conversations( + self, + *, + retention_days: int | None = None, + ) -> int: + resolved_retention_days = retention_days or self._resolve_retention_days() + cutoff = datetime.now(UTC) - timedelta(days=max(1, resolved_retention_days)) + stmt = select(AgentConversation).where(AgentConversation.updated_at < cutoff) + expired_conversations = list(self.db.scalars(stmt).all()) + if not expired_conversations: + return 0 + + for conversation in expired_conversations: + self.db.delete(conversation) + + self.db.commit() + return len(expired_conversations) + + def _resolve_retention_days(self) -> int: + try: + settings_row, _ = SettingsService(self.db).ensure_settings_ready() + configured_days = int( + getattr( + settings_row, + "conversation_retention_days", + DEFAULT_CONVERSATION_RETENTION_DAYS, + ) + or DEFAULT_CONVERSATION_RETENTION_DAYS + ) + return max(1, min(configured_days, 10)) + except Exception: + self.db.rollback() + return DEFAULT_CONVERSATION_RETENTION_DAYS + + def get_conversation(self, conversation_id: str) -> AgentConversation | None: + normalized_id = str(conversation_id or "").strip() + if not normalized_id: + return None + stmt = select(AgentConversation).where(AgentConversation.conversation_id == normalized_id) + return self.db.scalar(stmt) + + def get_latest_conversation_for_user( + self, + *, + user_id: str | None, + source: str | None = "user_message", + ) -> AgentConversation | None: + self.prune_expired_conversations() + + normalized_user_id = str(user_id or "").strip() + if not normalized_user_id: + return None + + stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id) + if source: + stmt = stmt.where(AgentConversation.source == source) + stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc()) + return self.db.scalar(stmt.limit(1)) + + def hydrate_context_json( + self, + *, + conversation: AgentConversation, + context_json: dict[str, Any], + history_limit: int = 8, + ) -> dict[str, Any]: + merged = dict(context_json or {}) + state_json = dict(conversation.state_json or {}) + + merged["conversation_id"] = conversation.conversation_id + merged["conversation_history"] = self.list_message_history( + conversation.conversation_id, + limit=history_limit, + ) + if conversation.last_scenario: + merged.setdefault("conversation_scenario", conversation.last_scenario) + if conversation.last_intent: + merged.setdefault("conversation_intent", conversation.last_intent) + if conversation.draft_claim_id and not str(merged.get("draft_claim_id") or "").strip(): + merged["draft_claim_id"] = conversation.draft_claim_id + merged["conversation_state"] = state_json + + for key in STATEFUL_CONTEXT_KEYS: + if self._is_empty_value(merged.get(key)) and not self._is_empty_value(state_json.get(key)): + merged[key] = state_json.get(key) + + return merged + + def append_message( + self, + *, + conversation_id: str, + role: str, + content: str, + run_id: str | None = None, + message_json: dict[str, Any] | None = None, + ) -> AgentConversationMessage | None: + normalized_content = str(content or "").strip() + if not normalized_content: + return None + + conversation = self.get_conversation(conversation_id) + if conversation is None: + return None + + message = AgentConversationMessage( + conversation_id=conversation_id, + run_id=run_id, + role=str(role or "user").strip() or "user", + content=normalized_content, + message_json=message_json or {}, + created_at=datetime.now(UTC), + ) + conversation.message_count = int(conversation.message_count or 0) + 1 + if role == "user" and not conversation.title: + conversation.title = normalized_content[:48] + conversation.updated_at = datetime.now(UTC) + self.db.add(message) + self.db.add(conversation) + self.db.commit() + self.db.refresh(message) + return message + + def list_message_history( + self, + conversation_id: str, + *, + limit: int = 8, + ) -> list[dict[str, Any]]: + normalized_id = str(conversation_id or "").strip() + if not normalized_id or limit <= 0: + return [] + + stmt = ( + select(AgentConversationMessage) + .where(AgentConversationMessage.conversation_id == normalized_id) + .order_by(AgentConversationMessage.created_at.desc()) + .limit(limit) + ) + messages = list(self.db.scalars(stmt).all()) + messages.reverse() + return [ + { + "role": item.role, + "content": item.content, + "run_id": item.run_id, + "created_at": item.created_at.isoformat() if item.created_at else None, + } + for item in messages + ] + + def list_messages( + self, + conversation_id: str, + *, + limit: int | None = None, + ) -> list[AgentConversationMessage]: + normalized_id = str(conversation_id or "").strip() + if not normalized_id: + return [] + + stmt = ( + select(AgentConversationMessage) + .where(AgentConversationMessage.conversation_id == normalized_id) + .order_by(AgentConversationMessage.created_at.asc(), AgentConversationMessage.id.asc()) + ) + if limit and limit > 0: + stmt = stmt.limit(limit) + return list(self.db.scalars(stmt).all()) + + def update_state( + self, + *, + conversation_id: str, + run_id: str | None, + scenario: str | None, + intent: str | None, + context_json: dict[str, Any], + draft_payload: dict[str, Any] | None = None, + ) -> AgentConversation | None: + conversation = self.get_conversation(conversation_id) + if conversation is None: + return None + + conversation.last_run_id = str(run_id or "").strip() or conversation.last_run_id + conversation.last_scenario = str(scenario or "").strip() or conversation.last_scenario + conversation.last_intent = str(intent or "").strip() or conversation.last_intent + if draft_payload and str(draft_payload.get("claim_id") or "").strip(): + conversation.draft_claim_id = str(draft_payload["claim_id"]).strip() + + next_state = self._merge_state_json( + conversation.state_json, + self._extract_state_json(context_json), + ) + if draft_payload: + if str(draft_payload.get("claim_id") or "").strip(): + next_state["draft_claim_id"] = str(draft_payload["claim_id"]).strip() + if str(draft_payload.get("claim_no") or "").strip(): + next_state["draft_claim_no"] = str(draft_payload["claim_no"]).strip() + if str(draft_payload.get("status") or "").strip(): + next_state["draft_status"] = str(draft_payload["status"]).strip() + conversation.state_json = next_state + conversation.updated_at = datetime.now(UTC) + + self.db.add(conversation) + self.db.commit() + self.db.refresh(conversation) + return conversation + + def delete_user_conversations( + self, + *, + user_id: str | None, + source: str | None = "user_message", + ) -> int: + normalized_user_id = str(user_id or "").strip() + if not normalized_user_id: + return 0 + + stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id) + if source: + stmt = stmt.where(AgentConversation.source == source) + conversations = list(self.db.scalars(stmt).all()) + if not conversations: + return 0 + + for conversation in conversations: + self.db.delete(conversation) + + self.db.commit() + return len(conversations) + + def serialize_conversation( + self, + conversation: AgentConversation, + *, + include_messages: bool = True, + message_limit: int | None = None, + ) -> dict[str, Any]: + payload = { + "conversation_id": conversation.conversation_id, + "user_id": conversation.user_id, + "source": conversation.source, + "entry_source": conversation.entry_source, + "title": conversation.title, + "last_run_id": conversation.last_run_id, + "last_scenario": conversation.last_scenario, + "last_intent": conversation.last_intent, + "draft_claim_id": conversation.draft_claim_id, + "state_json": dict(conversation.state_json or {}), + "message_count": int(conversation.message_count or 0), + "updated_at": conversation.updated_at, + "messages": [], + } + if include_messages: + payload["messages"] = [ + self.serialize_message(item) + for item in self.list_messages(conversation.conversation_id, limit=message_limit) + ] + return payload + + @staticmethod + def serialize_message(message: AgentConversationMessage) -> dict[str, Any]: + return { + "id": message.id, + "role": message.role, + "content": message.content, + "run_id": message.run_id, + "message_json": dict(message.message_json or {}), + "created_at": message.created_at, + } + + @staticmethod + def _is_empty_value(value: Any) -> bool: + if value is None: + return True + if isinstance(value, str): + return not value.strip() + if isinstance(value, (list, tuple, set, dict)): + return len(value) == 0 + return False + + @staticmethod + def _resolve_title(context_json: dict[str, Any]) -> str | None: + request_context = context_json.get("request_context") + if isinstance(request_context, dict): + for key in ("reason", "title", "id"): + value = str(request_context.get(key) or "").strip() + if value: + return value[:200] + return None + + @staticmethod + def _extract_state_json(context_json: dict[str, Any]) -> dict[str, Any]: + state_json: dict[str, Any] = {} + for key in STATEFUL_CONTEXT_KEYS: + value = context_json.get(key) + if value is None: + continue + if isinstance(value, str) and not value.strip(): + continue + if isinstance(value, (list, dict)) and not value: + continue + state_json[key] = value + + draft_claim_id = str(context_json.get("draft_claim_id") or "").strip() + if draft_claim_id: + state_json["draft_claim_id"] = draft_claim_id + return state_json + + @staticmethod + def _merge_state_json( + current_state: dict[str, Any] | None, + incoming_state: dict[str, Any] | None, + ) -> dict[str, Any]: + merged = dict(current_state or {}) + for key, value in (incoming_state or {}).items(): + if value is None: + continue + if isinstance(value, str) and not value.strip(): + continue + if isinstance(value, (list, dict)) and not value: + continue + merged[key] = value + return merged diff --git a/server/src/app/services/expense_claims.py b/server/src/app/services/expense_claims.py index 288a1c3..197aa31 100644 --- a/server/src/app/services/expense_claims.py +++ b/server/src/app/services/expense_claims.py @@ -40,6 +40,7 @@ class ExpenseClaimService: self._ensure_ready() claim = self._find_target_claim(ontology=ontology, context_json=context_json) + is_new_claim = claim is None before_json = self._serialize_claim(claim) if claim is not None else None employee = self._resolve_employee(ontology=ontology, context_json=context_json) @@ -47,12 +48,30 @@ class ExpenseClaimService: occurred_at = self._resolve_occurred_at(ontology) expense_type = self._resolve_expense_type(ontology.entities) location = self._resolve_location(message=message, context_json=context_json) - reason = self._resolve_reason(message=message, context_json=context_json) + reason = self._resolve_reason( + message=message, + context_json=context_json, + allow_message_fallback=is_new_claim, + ) attachment_count = self._resolve_attachment_count(context_json) + final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00")) + final_occurred_at = ( + occurred_at if occurred_at is not None else (claim.occurred_at if claim is not None else datetime.now(UTC)) + ) + final_expense_type = expense_type or (claim.expense_type if claim is not None else "other") + final_location = location or (claim.location if claim is not None else "待补充") + final_reason = reason or (claim.reason if claim is not None else "待补充") + final_attachment_count = ( + attachment_count if attachment_count > 0 else int(claim.invoice_count or 0) if claim is not None else 0 + ) + final_risk_flags = list(ontology.risk_flags) or ( + list(claim.risk_flags_json or []) if claim is not None else [] + ) + if claim is None: claim = ExpenseClaim( - claim_no=self._generate_claim_no(occurred_at), + claim_no=self._generate_claim_no(final_occurred_at), employee_id=employee.id if employee is not None else None, employee_name=employee.name if employee is not None else self._resolve_employee_name( ontology=ontology, @@ -65,16 +84,16 @@ class ExpenseClaimService: context_json=context_json, ), project_code=self._resolve_project_code(ontology.entities), - expense_type=expense_type, - reason=reason, - location=location, - amount=amount, + expense_type=final_expense_type, + reason=final_reason, + location=final_location, + amount=final_amount, currency="CNY", - invoice_count=attachment_count, - occurred_at=occurred_at, + invoice_count=final_attachment_count, + occurred_at=final_occurred_at, status="draft", approval_stage="待补充", - risk_flags_json=list(ontology.risk_flags), + risk_flags_json=final_risk_flags, ) self.db.add(claim) else: @@ -86,6 +105,7 @@ class ExpenseClaimService: ontology=ontology, context_json=context_json, user_id=user_id, + fallback=claim.employee_name, ) ) claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id @@ -95,24 +115,24 @@ class ExpenseClaimService: fallback=claim.department_name, ) claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code - claim.expense_type = expense_type or claim.expense_type - claim.reason = reason - claim.location = location - claim.amount = amount - claim.invoice_count = attachment_count - claim.occurred_at = occurred_at + claim.expense_type = final_expense_type + claim.reason = final_reason + claim.location = final_location + claim.amount = final_amount + claim.invoice_count = final_attachment_count + claim.occurred_at = final_occurred_at claim.status = "draft" claim.approval_stage = "待补充" - claim.risk_flags_json = list(ontology.risk_flags) + claim.risk_flags_json = final_risk_flags self.db.flush() self._upsert_primary_item( claim=claim, - occurred_at=occurred_at, - expense_type=expense_type, - amount=amount, - reason=reason, - location=location, + occurred_at=final_occurred_at, + expense_type=final_expense_type, + amount=final_amount, + reason=final_reason, + location=final_location, attachment_names=self._resolve_attachment_names(context_json), ) self.db.commit() @@ -130,7 +150,7 @@ class ExpenseClaimService: return { "message": ( - f"已创建报销草稿 {claim.claim_no},当前状态为 draft。" + f"已{'创建' if is_new_claim else '更新'}报销草稿 {claim.claim_no},当前状态为 draft。" "你可以继续补充费用明细、客户单位和票据附件。" ), "draft_only": True, @@ -229,6 +249,7 @@ class ExpenseClaimService: ontology: OntologyParseResult, context_json: dict[str, Any], user_id: str | None, + fallback: str = "待补充", ) -> str: for item in ontology.entities: if item.type == "employee" and item.value.strip(): @@ -237,7 +258,7 @@ class ExpenseClaimService: value = str(context_json.get(key) or "").strip() if value: return value - return str(user_id or "待补充").strip() or "待补充" + return str(user_id or fallback).strip() or fallback @staticmethod def _resolve_department_name( @@ -270,26 +291,33 @@ class ExpenseClaimService: return None @staticmethod - def _resolve_expense_type(entities: list[OntologyEntity]) -> str: + def _resolve_expense_type(entities: list[OntologyEntity]) -> str | None: for item in entities: if item.type == "expense_type": normalized = item.normalized_value.strip() if normalized: return normalized - return "other" + return None @staticmethod - def _resolve_reason(*, message: str, context_json: dict[str, Any]) -> str: + def _resolve_reason( + *, + message: str, + context_json: dict[str, Any], + allow_message_fallback: bool, + ) -> str | None: request_context = context_json.get("request_context") if isinstance(request_context, dict): for key in ("reason", "title"): value = str(request_context.get(key) or "").strip() if value: return value - return str(message or "").strip()[:500] or "待补充" + if not allow_message_fallback: + return None + return str(message or "").strip()[:500] or None @staticmethod - def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str: + def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None: request_context = context_json.get("request_context") if isinstance(request_context, dict): for key in ("city", "location"): @@ -299,10 +327,10 @@ class ExpenseClaimService: compact = str(message or "").replace(" ", "") if "客户现场" in compact: return "客户现场" - return "待补充" + return None @staticmethod - def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime: + def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime | None: start_date = ontology.time_range.start_date if start_date: try: @@ -310,10 +338,10 @@ class ExpenseClaimService: return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC) except ValueError: pass - return datetime.now(UTC) + return None @staticmethod - def _resolve_amount(entities: list[OntologyEntity]) -> Decimal: + def _resolve_amount(entities: list[OntologyEntity]) -> Decimal | None: for item in entities: if item.type != "amount" or item.role == "threshold": continue @@ -321,7 +349,7 @@ class ExpenseClaimService: return Decimal(item.normalized_value).quantize(Decimal("0.01")) except (InvalidOperation, ValueError): continue - return Decimal("0.00") + return None @staticmethod def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]: diff --git a/server/src/app/services/ontology.py b/server/src/app/services/ontology.py index 86e412e..2b54e40 100644 --- a/server/src/app/services/ontology.py +++ b/server/src/app/services/ontology.py @@ -120,6 +120,24 @@ EXPLAIN_KEYWORDS = ("为什么", "依据", "原因", "怎么处理", "是否可 COMPARE_KEYWORDS = ("对比", "比较", "相比", "差异", "变化") RISK_KEYWORDS = ("风险", "异常", "重复", "超标", "超预算", "逾期", "验真", "巡检") DRAFT_KEYWORDS = ("生成", "草稿", "起草", "拟一份", "创建", "发起", "准备") +DRAFT_FOLLOW_UP_KEYWORDS = ( + "继续", + "补充", + "补一下", + "修改", + "改成", + "改为", + "换成", + "更新", + "确认", + "提交", + "保存", + "客户是", + "地点是", + "金额是", + "日期是", + "时间是", +) OPERATE_KEYWORDS = ( "直接付款", "帮我付款", @@ -200,6 +218,7 @@ STATUS_KEYWORDS = { } PRIVILEGED_ROLE_CODES = {"manager", "finance", "approver", "executive"} +CONTEXTUAL_SCENARIOS = {"expense", "accounts_receivable", "accounts_payable", "knowledge"} @dataclass(slots=True) @@ -289,12 +308,17 @@ class SemanticOntologyService: raise ValueError("query 不能为空。") AgentFoundationService(self.db).ensure_foundation_ready() + context_json = payload.context_json or {} reference = self._load_reference_catalog() compact_query = self._compact(query) entities = self._extract_entities(query, compact_query, reference) rule_scenario, scenario_score = self._detect_scenario(compact_query) time_range, _time_score = self._extract_time_range(query, compact_query) + context_scenario = self._resolve_context_scenario(context_json) + if rule_scenario == "unknown" and context_scenario is not None: + rule_scenario = context_scenario + scenario_score = max(scenario_score, 0.14) if rule_scenario == "unknown": inferred_scenario = self._infer_scenario_from_entities(entities) if inferred_scenario is not None: @@ -316,6 +340,17 @@ class SemanticOntologyService: entities=entities, time_range=time_range, ) + if self._should_inherit_expense_draft( + compact_query, + scenario=rule_scenario, + entities=entities, + time_range=time_range, + context_json=context_json, + ): + rule_scenario = "expense" + rule_intent = "draft" + scenario_score = max(scenario_score, 0.18) + intent_score = max(intent_score, 0.18) metrics = self._extract_metrics(compact_query) constraints = self._extract_constraints(compact_query, entities) model_parse = self._parse_with_model( @@ -353,7 +388,7 @@ class SemanticOntologyService: intent=intent, entities=entities, time_range=time_range, - context_json=payload.context_json or {}, + context_json=context_json, ) ) ambiguity = self._normalize_short_text_list( @@ -362,7 +397,7 @@ class SemanticOntologyService: risk_flags = self._extract_risk_flags(compact_query, scenario) permission = self._resolve_permission( compact_query, - payload.context_json or {}, + context_json, intent, ) @@ -524,6 +559,13 @@ class SemanticOntologyService: def _compact(text: str) -> str: return re.sub(r"\s+", "", text).lower() + @staticmethod + def _resolve_context_scenario(context_json: dict[str, Any]) -> str | None: + value = str(context_json.get("conversation_scenario") or "").strip() + if value in CONTEXTUAL_SCENARIOS: + return value + return None + def _detect_scenario(self, compact_query: str) -> tuple[str, float]: scores = {key: 0.0 for key in SCENARIO_KEYWORDS} for scenario, keywords in SCENARIO_KEYWORDS.items(): @@ -581,6 +623,68 @@ class SemanticOntologyService: return "draft", 0.22 return "query", 0.10 + @staticmethod + def _looks_like_follow_up_message(compact_query: str) -> bool: + if not compact_query: + return False + if any(keyword in compact_query for keyword in DRAFT_FOLLOW_UP_KEYWORDS): + return True + if compact_query.startswith(("那", "这", "它", "这个", "那个")): + return True + + has_domain_keyword = any( + keyword in compact_query + for keyword, _weight in ( + *SCENARIO_KEYWORDS["expense"], + *SCENARIO_KEYWORDS["accounts_receivable"], + *SCENARIO_KEYWORDS["accounts_payable"], + *SCENARIO_KEYWORDS["knowledge"], + ) + ) + return len(compact_query) <= 12 and not has_domain_keyword + + def _should_inherit_expense_draft( + self, + compact_query: str, + *, + scenario: str, + entities: list[OntologyEntity], + time_range: OntologyTimeRange, + context_json: dict[str, Any], + ) -> bool: + context_scenario = self._resolve_context_scenario(context_json) + draft_claim_id = str(context_json.get("draft_claim_id") or "").strip() + if context_scenario != "expense" and not draft_claim_id: + return False + + if any(keyword in compact_query for keyword in DRAFT_FOLLOW_UP_KEYWORDS): + return True + if self._looks_like_expense_narrative( + compact_query, + scenario="expense", + entities=entities, + time_range=time_range, + ): + return True + if self._looks_like_follow_up_message(compact_query): + return True + + if any(keyword in compact_query for keyword in OPERATE_KEYWORDS): + return False + if any(keyword in compact_query for keyword in COMPARE_KEYWORDS + RISK_KEYWORDS): + return False + if any(keyword in compact_query for keyword in QUERY_KEYWORDS): + return False + + return bool( + draft_claim_id + and any( + item.type + in {"amount", "customer", "employee", "expense_type", "project", "invoice"} + for item in entities + ) + ) + @staticmethod def _is_generic_expense_prompt(compact_query: str) -> bool: return compact_query in GENERIC_EXPENSE_PROMPTS @@ -670,6 +774,11 @@ class SemanticOntologyService: "ocr_documents": payload.context_json.get("ocr_documents", []), "request_context": payload.context_json.get("request_context"), "role_codes": payload.context_json.get("role_codes", []), + "conversation_id": payload.context_json.get("conversation_id"), + "conversation_scenario": payload.context_json.get("conversation_scenario"), + "conversation_intent": payload.context_json.get("conversation_intent"), + "draft_claim_id": payload.context_json.get("draft_claim_id"), + "conversation_history": payload.context_json.get("conversation_history", []), }, "rule_candidates": { "scenario": fallback_scenario, @@ -690,6 +799,8 @@ class SemanticOntologyService: "意图 intent 只能是:query, explain, compare, risk_check, draft, operate。" "如果用户是在描述一笔待处理费用、待报销事项、上传票据或希望整理报销," "即使没有明确说“生成草稿”,也优先使用 expense + draft。" + "如果提供了 conversation_history,必须把最近轮次作为当前追问的上下文," + "正确理解“这个”“那笔”“改成 800”“继续补充”这类省略表达。" "出现“客户”不等于应收,出现“供应商”不等于应付,必须结合动作词和业务目标判断。" "只有明确查询、统计、列出、多少、明细、对比时才优先使用 query 或 compare。" "附件名称和 OCR 摘要只作为辅助证据,不能编造未出现的事实。" diff --git a/server/src/app/services/orchestrator.py b/server/src/app/services/orchestrator.py index f1c7f67..71c4afa 100644 --- a/server/src/app/services/orchestrator.py +++ b/server/src/app/services/orchestrator.py @@ -32,6 +32,7 @@ from app.schemas.orchestrator import ( ) from app.schemas.user_agent import UserAgentRequest, UserAgentResponse from app.services.agent_assets import AgentAssetService +from app.services.agent_conversations import AgentConversationService from app.services.expense_claims import ExpenseClaimService from app.services.agent_foundation import AgentFoundationService from app.services.agent_runs import AgentRunService @@ -62,6 +63,7 @@ class OrchestratorService: def __init__(self, db: Session) -> None: self.db = db self.asset_service = AgentAssetService(db) + self.conversation_service = AgentConversationService(db) self.expense_claim_service = ExpenseClaimService(db) self.run_service = AgentRunService(db) self.ontology_service = SemanticOntologyService(db) @@ -69,10 +71,28 @@ class OrchestratorService: def run(self, payload: OrchestratorRequest) -> OrchestratorResponse: AgentFoundationService(self.db).ensure_foundation_ready() + context_json = dict(payload.context_json or {}) + conversation_id = str(payload.conversation_id or "").strip() or None + conversation = None + if payload.source == AgentRunSource.USER_MESSAGE.value: + conversation = self.conversation_service.get_or_create_conversation( + conversation_id=conversation_id, + user_id=payload.user_id, + source=payload.source, + context_json=context_json, + ) + conversation_id = conversation.conversation_id + context_json = self.conversation_service.hydrate_context_json( + conversation=conversation, + context_json=context_json, + ) + route_json: dict[str, Any] = { "orchestrated_by": AgentName.ORCHESTRATOR.value, "stage": "created", } + if conversation_id: + route_json["conversation_id"] = conversation_id run = self.run_service.create_run( agent=AgentName.ORCHESTRATOR.value, source=payload.source, @@ -87,15 +107,27 @@ class OrchestratorService: try: message, task_asset = self._resolve_message(payload) + if conversation is not None: + self.conversation_service.append_message( + conversation_id=conversation.conversation_id, + role="user", + content=message, + run_id=run.run_id, + message_json={ + "attachment_names": context_json.get("attachment_names", []), + "attachment_count": context_json.get("attachment_count", 0), + "ocr_summary": context_json.get("ocr_summary", ""), + }, + ) ontology = self.ontology_service.parse_for_run( OntologyParseRequest( query=message, user_id=payload.user_id, - context_json=payload.context_json, + context_json=context_json, ), run_id=run.run_id, ) - if payload.context_json.get("simulate_orchestrator_exception"): + if context_json.get("simulate_orchestrator_exception"): raise RuntimeError("simulated orchestrator exception") selected_agent, route_reason = self._select_agent(payload, ontology) capabilities = self._select_capabilities( @@ -159,6 +191,7 @@ class OrchestratorService: capabilities=capabilities, requires_confirmation=requires_confirmation, task_asset=task_asset, + context_json=context_json, ) else: outcome = self._execute_user_agent( @@ -167,6 +200,7 @@ class OrchestratorService: ontology=ontology, capabilities=capabilities, requires_confirmation=requires_confirmation, + context_json=context_json, ) final_status = ( @@ -176,10 +210,19 @@ class OrchestratorService: and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value else outcome.status ) + response_status = self._normalize_response_status(final_status) result_message = ( str(outcome.result.get("message", "")).strip() or "Orchestrator 执行完成。" ) + trace_summary = OrchestratorTraceSummary( + scenario=ontology.scenario, + intent=ontology.intent, + tool_count=outcome.tool_count, + failed_tool_count=outcome.failed_tool_count, + selected_capability_codes=selected_capability_codes, + degraded=outcome.degraded, + ) self.run_service.update_run( run.run_id, agent=selected_agent or AgentName.ORCHESTRATOR.value, @@ -195,22 +238,51 @@ class OrchestratorService: error_message=None, finished_at=datetime.now(UTC), ) + if conversation is not None and conversation_id: + draft_payload = outcome.result.get("draft_payload") + self.conversation_service.update_state( + conversation_id=conversation_id, + run_id=run.run_id, + scenario=ontology.scenario, + intent=ontology.intent, + context_json=context_json, + draft_payload=draft_payload if isinstance(draft_payload, dict) else None, + ) + self.conversation_service.append_message( + conversation_id=conversation_id, + role="assistant", + content=result_message, + run_id=run.run_id, + message_json={ + "status": final_status, + "scenario": ontology.scenario, + "intent": ontology.intent, + "attachment_names": context_json.get("attachment_names", []), + "attachment_count": context_json.get("attachment_count", 0), + "draft_payload": draft_payload if isinstance(draft_payload, dict) else None, + "orchestrator_payload": { + "run_id": run.run_id, + "conversation_id": conversation_id, + "selected_agent": selected_agent, + "route_reason": route_reason, + "permission_level": ontology.permission.level, + "status": response_status, + "requires_confirmation": requires_confirmation, + "trace_summary": trace_summary.model_dump(), + "result": outcome.result, + }, + }, + ) return OrchestratorResponse( run_id=run.run_id, + conversation_id=conversation_id, selected_agent=selected_agent, route_reason=route_reason, permission_level=ontology.permission.level, - status=self._normalize_response_status(final_status), + status=response_status, result=outcome.result, requires_confirmation=requires_confirmation, - trace_summary=OrchestratorTraceSummary( - scenario=ontology.scenario, - intent=ontology.intent, - tool_count=outcome.tool_count, - failed_tool_count=outcome.failed_tool_count, - selected_capability_codes=selected_capability_codes, - degraded=outcome.degraded, - ), + trace_summary=trace_summary, ) except Exception as exc: logger.exception("Orchestrator run failed run_id=%s", run.run_id) @@ -223,8 +295,25 @@ class OrchestratorService: error_message=str(exc), finished_at=datetime.now(UTC), ) + if conversation is not None and conversation_id: + self.conversation_service.update_state( + conversation_id=conversation_id, + run_id=run.run_id, + scenario=None, + intent=None, + context_json=context_json, + draft_payload=None, + ) + self.conversation_service.append_message( + conversation_id=conversation_id, + role="assistant", + content=f"Orchestrator 执行失败:{exc}", + run_id=run.run_id, + message_json={"status": AgentRunStatus.FAILED.value}, + ) return OrchestratorResponse( run_id=run.run_id, + conversation_id=conversation_id, selected_agent=None, route_reason="orchestrator_exception", permission_level=AgentPermissionLevel.READ.value, @@ -336,6 +425,7 @@ class OrchestratorService: ontology: OntologyParseResult, capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], requires_confirmation: bool, + context_json: dict[str, Any], ) -> ExecutionOutcome: selected_capability_codes = self._flatten_capability_codes(capabilities) if requires_confirmation: @@ -347,7 +437,7 @@ class OrchestratorService: "message": payload.message, "permission_level": ontology.permission.level, }, - context_json=payload.context_json, + context_json=context_json, executor=lambda: { "confirmation_title": "操作需要确认", "message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。", @@ -372,7 +462,7 @@ class OrchestratorService: tool_type=AgentToolType.DATABASE.value, tool_name=self._database_tool_name(ontology.scenario), request_json=self._build_ontology_json(ontology), - context_json=payload.context_json, + context_json=context_json, executor=lambda: self._build_database_answer(ontology), fallback_factory=lambda exc: { "message": f"数据库查询暂时不可用,已返回降级说明:{exc}", @@ -386,7 +476,7 @@ class OrchestratorService: user_id=payload.user_id, message=payload.message or "", ontology=ontology, - context_json=payload.context_json, + context_json=context_json, tool_payload=tool_payload, selected_capability_codes=selected_capability_codes, degraded=degraded, @@ -409,7 +499,7 @@ class OrchestratorService: tool_type=AgentToolType.DATABASE.value, tool_name="knowledge.search", request_json=self._build_ontology_json(ontology), - context_json=payload.context_json, + context_json=context_json, executor=lambda: self._build_knowledge_answer(ontology, capabilities), fallback_factory=lambda exc: { "message": f"知识检索暂时不可用,建议稍后重试:{exc}", @@ -423,7 +513,7 @@ class OrchestratorService: user_id=payload.user_id, message=payload.message or "", ontology=ontology, - context_json=payload.context_json, + context_json=context_json, tool_payload=tool_payload, selected_capability_codes=selected_capability_codes, degraded=degraded, @@ -446,7 +536,7 @@ class OrchestratorService: tool_type=AgentToolType.RULE_ENGINE.value, tool_name=self._rule_tool_name(capabilities), request_json=self._build_ontology_json(ontology), - context_json=payload.context_json, + context_json=context_json, executor=lambda: self._build_rule_answer(ontology), fallback_factory=lambda exc: { "message": f"规则检查暂时不可用,已返回人工复核建议:{exc}", @@ -460,7 +550,7 @@ class OrchestratorService: user_id=payload.user_id, message=payload.message or "", ontology=ontology, - context_json=payload.context_json, + context_json=context_json, tool_payload=tool_payload, selected_capability_codes=selected_capability_codes, degraded=degraded, @@ -499,7 +589,7 @@ class OrchestratorService: user_id=payload.user_id, message=payload.message or "", ontology=ontology, - context_json=payload.context_json, + context_json=context_json, ) fallback_factory = lambda exc: { "message": f"报销草稿落库失败,请稍后再试:{exc}", @@ -511,7 +601,7 @@ class OrchestratorService: tool_type=tool_type, tool_name=tool_name, request_json=self._build_ontology_json(ontology), - context_json=payload.context_json, + context_json=context_json, executor=executor, fallback_factory=fallback_factory, ) @@ -522,7 +612,7 @@ class OrchestratorService: user_id=payload.user_id, message=payload.message or "", ontology=ontology, - context_json=payload.context_json, + context_json=context_json, tool_payload=tool_payload, selected_capability_codes=selected_capability_codes, degraded=degraded, @@ -548,6 +638,7 @@ class OrchestratorService: capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], requires_confirmation: bool, task_asset: AgentAssetRead | None, + context_json: dict[str, Any], ) -> ExecutionOutcome: if requires_confirmation: return ExecutionOutcome( @@ -566,7 +657,7 @@ class OrchestratorService: tool_type=AgentToolType.RULE_ENGINE.value, tool_name=self._rule_tool_name(capabilities), request_json=self._build_ontology_json(ontology), - context_json=payload.context_json, + context_json=context_json, executor=lambda: self._build_rule_answer(ontology), fallback_factory=lambda exc: { "message": f"规则巡检失败,已降级为待人工复核:{exc}", @@ -581,7 +672,7 @@ class OrchestratorService: "task_code": task_asset.code if task_asset is not None else "", "scenario": ontology.scenario, }, - context_json=payload.context_json, + context_json=context_json, executor=lambda: self._build_mcp_answer(task_asset, ontology), fallback_factory=lambda exc: { "message": f"MCP 调用失败,已使用缓存快照降级:{exc}", @@ -806,6 +897,8 @@ class OrchestratorService: } if response.draft_payload is not None: result["draft_payload"] = response.draft_payload.model_dump() + if response.review_payload is not None: + result["review_payload"] = response.review_payload.model_dump() return result @staticmethod diff --git a/server/src/app/services/settings.py b/server/src/app/services/settings.py index bf4237d..8d47ca5 100644 --- a/server/src/app/services/settings.py +++ b/server/src/app/services/settings.py @@ -203,7 +203,8 @@ class SettingsService: settings_row.admin_account = payload.adminForm.adminAccount settings_row.admin_email = payload.adminForm.adminEmail - settings_row.session_timeout = payload.adminForm.sessionTimeout + settings_row.session_timeout = payload.adminForm.sessionTimeout + settings_row.conversation_retention_days = payload.sessionForm.conversationRetentionDays settings_row.notice_email = payload.adminForm.noticeEmail settings_row.mfa_enabled = payload.adminForm.mfaEnabled settings_row.strong_password = payload.adminForm.strongPassword @@ -428,8 +429,9 @@ class SettingsService: copyright_text=f"Copyright © 2024-{current_year} {company_name}. All Rights Reserved.", admin_account=admin_account, admin_email=admin_email, - session_timeout=30, - notice_email=admin_email, + session_timeout=30, + conversation_retention_days=3, + notice_email=admin_email, mfa_enabled=True, strong_password=True, login_alert_enabled=True, @@ -520,6 +522,10 @@ class SettingsService: if "system_settings" in table_names: settings_columns = {column["name"] for column in inspector.get_columns("system_settings")} + if "conversation_retention_days" not in settings_columns: + migration_statements.append( + "ALTER TABLE system_settings ADD COLUMN conversation_retention_days INTEGER DEFAULT 3" + ) if "onlyoffice_enabled" not in settings_columns: migration_statements.append( "ALTER TABLE system_settings ADD COLUMN onlyoffice_enabled BOOLEAN DEFAULT FALSE" @@ -588,20 +594,23 @@ class SettingsService: "recordNumber": settings_row.record_number, "copyright": settings_row.copyright_text, }, - adminForm={ - "adminAccount": settings_row.admin_account, - "adminEmail": settings_row.admin_email, - "newPassword": "", - "confirmPassword": "", + adminForm={ + "adminAccount": settings_row.admin_account, + "adminEmail": settings_row.admin_email, + "newPassword": "", + "confirmPassword": "", "sessionTimeout": settings_row.session_timeout, "noticeEmail": settings_row.notice_email, "mfaEnabled": settings_row.mfa_enabled, "strongPassword": settings_row.strong_password, - "loginAlertEnabled": settings_row.login_alert_enabled, - "adminPasswordConfigured": bool(secrets_row.admin_password_hash), - }, - llmForm={ - "mainProvider": main_model.provider, + "loginAlertEnabled": settings_row.login_alert_enabled, + "adminPasswordConfigured": bool(secrets_row.admin_password_hash), + }, + sessionForm={ + "conversationRetentionDays": settings_row.conversation_retention_days, + }, + llmForm={ + "mainProvider": main_model.provider, "mainModel": main_model.model_name, "mainEndpoint": main_model.endpoint, "mainApiKey": "", diff --git a/server/src/app/services/user_agent.py b/server/src/app/services/user_agent.py index d06f7e2..f9371ff 100644 --- a/server/src/app/services/user_agent.py +++ b/server/src/app/services/user_agent.py @@ -2,14 +2,24 @@ from __future__ import annotations import json import re +from datetime import UTC, datetime, timedelta +from sqlalchemy import select from sqlalchemy.orm import Session from app.core.agent_enums import AgentAssetStatus, AgentAssetType +from app.models.financial_record import ExpenseClaim from app.schemas.agent_asset import AgentAssetListItem from app.schemas.user_agent import ( UserAgentCitation, UserAgentDraftPayload, + UserAgentReviewAction, + UserAgentReviewClaimGroup, + UserAgentReviewDocumentCard, + UserAgentReviewDocumentField, + UserAgentReviewPayload, + UserAgentReviewRiskBrief, + UserAgentReviewSlotCard, UserAgentRequest, UserAgentResponse, UserAgentSuggestedAction, @@ -53,8 +63,32 @@ EXPENSE_TYPE_LABELS = { "meal": "餐费", "meeting": "会务", "entertainment": "招待", + "other": "其他", } +GROUP_SCENE_LABELS = { + "travel": "差旅费", + "entertainment": "业务招待费", + "meal": "伙食费", + "transport": "交通费", + "hotel": "住宿费", + "other": "其他费用", +} + +SLOT_LABELS = { + "expense_type": "报销类型", + "customer_name": "客户名称", + "time_range": "发生时间", + "location": "地点", + "merchant_name": "酒店/商户", + "amount": "金额", + "participants": "参与人员", + "attachments": "票据附件", +} + +DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)") +AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)") + class UserAgentService: def __init__(self, db: Session) -> None: @@ -72,23 +106,32 @@ class UserAgentService: if payload.ontology.intent == "draft" else None ) + review_payload = self._build_review_payload( + payload, + citations=citations, + draft_payload=draft_payload, + ) if payload.degraded and payload.tool_payload.get("message"): return UserAgentResponse( answer=str(payload.tool_payload["message"]), citations=citations, suggested_actions=suggested_actions, + review_payload=review_payload, risk_flags=risk_flags, requires_confirmation=payload.requires_confirmation, ) - guided_answer = self._build_guided_answer(payload) + guided_answer = None + if draft_payload is None or draft_payload.claim_id is None: + guided_answer = self._build_guided_answer(payload) if guided_answer: return UserAgentResponse( answer=guided_answer, citations=citations, suggested_actions=suggested_actions, draft_payload=draft_payload, + review_payload=review_payload, risk_flags=risk_flags, requires_confirmation=payload.requires_confirmation, ) @@ -98,20 +141,23 @@ class UserAgentService: citations=citations, draft_payload=draft_payload, ) - answer = self._generate_answer_with_model( - payload, - citations=citations, - suggested_actions=suggested_actions, - risk_flags=risk_flags, - draft_payload=draft_payload, - fallback_answer=fallback_answer, - ) + answer = None + if not self._should_skip_model_answer(payload, review_payload): + answer = self._generate_answer_with_model( + payload, + citations=citations, + suggested_actions=suggested_actions, + risk_flags=risk_flags, + draft_payload=draft_payload, + fallback_answer=fallback_answer, + ) return UserAgentResponse( answer=answer or fallback_answer, citations=citations, suggested_actions=suggested_actions, draft_payload=draft_payload, + review_payload=review_payload, risk_flags=risk_flags, requires_confirmation=payload.requires_confirmation, ) @@ -129,6 +175,13 @@ class UserAgentService: if payload.ontology.intent == "risk_check": return self._build_risk_answer(payload, citations) + if payload.ontology.intent == "draft": + tool_message = str(payload.tool_payload.get("message") or "").strip() + if tool_message and ( + str(payload.tool_payload.get("claim_id") or "").strip() + or str(payload.tool_payload.get("claim_no") or "").strip() + ): + return tool_message if payload.ontology.intent == "draft" and draft_payload is not None: return ( f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容," @@ -243,6 +296,11 @@ class UserAgentService: "attachment_names": self._resolve_attachment_names(payload), "ocr_summary": payload.context_json.get("ocr_summary", ""), "ocr_documents": payload.context_json.get("ocr_documents", []), + "conversation_id": payload.context_json.get("conversation_id"), + "conversation_scenario": payload.context_json.get("conversation_scenario"), + "conversation_intent": payload.context_json.get("conversation_intent"), + "draft_claim_id": payload.context_json.get("draft_claim_id"), + "conversation_history": self._resolve_conversation_history(payload), }, "tool_payload": payload.tool_payload, "citations": [item.model_dump(mode="json") for item in citations], @@ -267,6 +325,7 @@ class UserAgentService: "并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。" "如果上下文里只有附件名称,必须明确说明你只拿到了附件名称," "不能假装已看过图片、PDF 或发票内容。" + "如果提供了 conversation_history,必须结合最近轮次理解追问、代词、省略字段和补充信息。" "不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。" "若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。" "只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、" @@ -447,6 +506,424 @@ class UserAgentService: ), ] + def _build_review_payload( + self, + payload: UserAgentRequest, + *, + citations: list[UserAgentCitation], + draft_payload: UserAgentDraftPayload | None, + ) -> UserAgentReviewPayload | None: + attachment_count = self._resolve_attachment_count(payload) + ocr_documents = self._resolve_ocr_documents(payload) + if payload.ontology.scenario != "expense": + return None + if payload.ontology.intent not in {"draft", "operate"} and attachment_count <= 0 and not ocr_documents: + return None + + slot_cards = self._build_review_slot_cards(payload, ocr_documents=ocr_documents) + document_cards = self._build_review_document_cards(payload, ocr_documents=ocr_documents) + claim_groups = self._build_review_claim_groups( + payload, + document_cards=document_cards, + ) + risk_briefs = self._build_review_risk_briefs( + payload, + citations=citations, + document_cards=document_cards, + claim_groups=claim_groups, + ) + confirmation_actions = self._build_review_confirmation_actions( + payload, + claim_groups=claim_groups, + draft_payload=draft_payload, + ) + intent_summary = self._build_review_intent_summary( + payload, + slot_cards=slot_cards, + claim_groups=claim_groups, + ) + + return UserAgentReviewPayload( + intent_summary=intent_summary, + scenario=payload.ontology.scenario, + intent=payload.ontology.intent, + missing_slots=list(payload.ontology.missing_slots), + risk_briefs=risk_briefs, + slot_cards=slot_cards, + document_cards=document_cards, + claim_groups=claim_groups, + confirmation_actions=confirmation_actions, + ) + + def _build_review_slot_cards( + self, + payload: UserAgentRequest, + *, + ocr_documents: list[dict[str, object]], + ) -> list[UserAgentReviewSlotCard]: + first_doc_fields = self._extract_document_fields(ocr_documents[0]) if ocr_documents else {} + missing_slots = set(payload.ontology.missing_slots) + entity_map = self._collect_entity_values(payload) + + time_value = self._format_time_range(payload) + location_value = self._resolve_location_value(payload) + merchant_value = self._extract_document_merchant_name(ocr_documents[0]) if ocr_documents else "" + customer_value = entity_map.get("customer", "") + participants_value = entity_map.get("participants", "") + amount_value = entity_map.get("amount") + if not amount_value: + ocr_total_amount = self._sum_ocr_amounts(ocr_documents) + amount_value = f"{ocr_total_amount:.2f}元" if ocr_total_amount > 0 else "" + expense_type_code = entity_map.get("expense_type_code", "") + expense_type_value = EXPENSE_TYPE_LABELS.get(expense_type_code, entity_map.get("expense_type", "")) + if not expense_type_value and ocr_documents: + expense_type_value = self._infer_expense_type_from_documents(payload, ocr_documents) + attachment_value = ( + f"{self._resolve_attachment_count(payload)} 份附件" + if self._resolve_attachment_count(payload) + else "" + ) + + cards = [ + self._make_slot_card( + key="expense_type", + value=expense_type_value, + source="user_text" if expense_type_value else "system", + confidence=0.9 if expense_type_value else 0.0, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="customer_name", + value=customer_value, + source="user_text" if customer_value else "system", + confidence=0.88 if customer_value else 0.0, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="time_range", + value=time_value, + source="user_text" if time_value else "system", + confidence=0.9 if time_value else 0.0, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="location", + value=location_value, + source="page_context" if location_value and location_value != "客户现场" else "user_text", + confidence=0.82 if location_value else 0.0, + required=False, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="merchant_name", + value=merchant_value, + source="ocr" if merchant_value else "system", + confidence=0.72 if merchant_value else 0.0, + required=False, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="amount", + value=amount_value, + source="user_text" if entity_map.get("amount") else "ocr" if amount_value else "system", + confidence=0.92 if amount_value else 0.0, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="participants", + value=participants_value, + source="user_text" if participants_value else "system", + confidence=0.8 if participants_value else 0.0, + missing_slots=missing_slots, + ), + self._make_slot_card( + key="attachments", + value=attachment_value, + source="upload" if attachment_value else "system", + confidence=1.0 if attachment_value else 0.0, + missing_slots=missing_slots, + ), + ] + return cards + + def _build_review_document_cards( + self, + payload: UserAgentRequest, + *, + ocr_documents: list[dict[str, object]], + ) -> list[UserAgentReviewDocumentCard]: + cards: list[UserAgentReviewDocumentCard] = [] + for index, item in enumerate(ocr_documents, start=1): + classified = self._classify_document(item, payload) + fields = self._extract_document_fields(item) + cards.append( + UserAgentReviewDocumentCard( + index=index, + filename=str(item.get("filename") or f"document-{index}"), + document_type=classified["document_type"], + suggested_expense_type=classified["expense_type"], + scene_label=GROUP_SCENE_LABELS.get( + classified["group_code"], + classified["scene_label"], + ), + summary=str(item.get("summary") or item.get("text") or "").strip(), + avg_score=float(item.get("avg_score") or 0.0), + warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()], + fields=[ + UserAgentReviewDocumentField( + label=label, + value=value, + source="ocr", + ) + for label, value in fields.items() + if str(value).strip() + ], + ) + ) + return cards + + def _build_review_claim_groups( + self, + payload: UserAgentRequest, + *, + document_cards: list[UserAgentReviewDocumentCard], + ) -> list[UserAgentReviewClaimGroup]: + groups: dict[str, dict[str, object]] = {} + for card in document_cards: + group_code = self._normalize_group_code(card.suggested_expense_type) + bucket = groups.setdefault( + group_code, + { + "document_indexes": [], + "amount_total": 0.0, + "expense_type": group_code, + "scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"), + "reasons": [], + }, + ) + bucket["document_indexes"].append(card.index) + bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card) + bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}") + + if not groups: + expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other") + group_code = self._normalize_group_code(expense_type_code) + groups[group_code] = { + "document_indexes": [], + "amount_total": self._resolve_amount_value(payload), + "expense_type": expense_type_code or "other", + "scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"), + "reasons": ["当前主要依据用户文本和页面上下文进行分单建议。"], + } + + claim_groups: list[UserAgentReviewClaimGroup] = [] + for index, (group_code, bucket) in enumerate(groups.items(), start=1): + title = f"建议报销单 {index}:{bucket['scene_label']}" + rationale = ( + ";".join(dict.fromkeys(str(item) for item in bucket["reasons"])) + if bucket["reasons"] + else "当前仅有单一场景,无需拆单。" + ) + claim_groups.append( + UserAgentReviewClaimGroup( + group_code=group_code, + title=title, + expense_type=str(bucket["expense_type"]), + scene_label=str(bucket["scene_label"]), + document_indexes=list(bucket["document_indexes"]), + amount_total=round(float(bucket["amount_total"]), 2), + rationale=rationale, + ) + ) + return claim_groups + + def _build_review_risk_briefs( + self, + payload: UserAgentRequest, + *, + citations: list[UserAgentCitation], + document_cards: list[UserAgentReviewDocumentCard], + claim_groups: list[UserAgentReviewClaimGroup], + ) -> list[UserAgentReviewRiskBrief]: + briefs: list[UserAgentReviewRiskBrief] = [] + employee_name = self._collect_entity_values(payload).get("employee_name") or str( + payload.context_json.get("name") or "" + ).strip() + if employee_name: + since = datetime.now(UTC) - timedelta(days=90) + stmt = select(ExpenseClaim).where( + ExpenseClaim.employee_name == employee_name, + ExpenseClaim.occurred_at >= since, + ) + recent_claims = list(self.db.scalars(stmt).all()) + if recent_claims: + risky_count = sum(1 for item in recent_claims if item.risk_flags_json) + draft_count = sum(1 for item in recent_claims if item.status == "draft") + briefs.append( + UserAgentReviewRiskBrief( + title="历史报销画像", + level="info", + content=( + f"{employee_name} 最近 90 天共有 {len(recent_claims)} 笔报销," + f"其中 {risky_count} 笔带风险标记,{draft_count} 笔仍处于草稿态。" + ), + ) + ) + current_amount = self._resolve_amount_value(payload) + if current_amount > 0: + duplicate_count = sum( + 1 + for item in recent_claims + if abs(float(item.amount) - current_amount) < 0.01 + ) + if duplicate_count: + briefs.append( + UserAgentReviewRiskBrief( + title="金额重复预警", + level="warning", + content=( + f"近 90 天发现 {duplicate_count} 笔金额相同的报销记录," + "提交前建议核对是否为重复报销或拆分不当。" + ), + ) + ) + + if citations: + briefs.append( + UserAgentReviewRiskBrief( + title="制度注意事项", + level="info", + content=citations[0].excerpt or f"请先核对 {citations[0].title} 的制度要求。", + ) + ) + + warning_count = sum(len(item.warnings) for item in document_cards) + if warning_count: + briefs.append( + UserAgentReviewRiskBrief( + title="票据识别提醒", + level="warning", + content=f"当前共有 {warning_count} 条票据识别提示,建议逐张确认 OCR 识别字段。", + ) + ) + + if len(claim_groups) > 1: + briefs.append( + UserAgentReviewRiskBrief( + title="建议拆单", + level="high", + content=f"系统检测到 {len(claim_groups)} 类费用场景,建议拆成多张报销单后再提交。", + ) + ) + + return briefs[:4] + + def _build_review_confirmation_actions( + self, + payload: UserAgentRequest, + *, + claim_groups: list[UserAgentReviewClaimGroup], + draft_payload: UserAgentDraftPayload | None, + ) -> list[UserAgentReviewAction]: + actions: list[UserAgentReviewAction] = [] + + if claim_groups: + if len(claim_groups) > 1: + actions.append( + UserAgentReviewAction( + label=f"按 {len(claim_groups)} 张报销单生成", + action_type="split_claims", + description="保留当前识别结果,并按费用场景拆分生成多张报销草稿。", + emphasis="primary", + ) + ) + else: + actions.append( + UserAgentReviewAction( + label="确认并继续生成草稿", + action_type="confirm_review", + description="确认当前识别字段无误后,继续生成或覆盖当前报销草稿。", + emphasis="primary", + ) + ) + + for slot in payload.ontology.missing_slots[:3]: + label = SLOT_LABELS.get(slot, slot) + actions.append( + UserAgentReviewAction( + label=f"补充{label}", + action_type="fill_slot", + description=f"当前还缺少 {label},补充后可提升分单和建单准确度。", + emphasis="secondary", + ) + ) + + if self._resolve_attachment_count(payload) <= 0: + actions.append( + UserAgentReviewAction( + label="继续上传票据", + action_type="upload_more", + description="上传发票、行程单或电子票据后,系统会重新识别并完善报销分组。", + emphasis="secondary", + ) + ) + + if draft_payload is not None and draft_payload.claim_no: + actions.append( + UserAgentReviewAction( + label=f"查看草稿 {draft_payload.claim_no}", + action_type="open_claim", + description="查看当前已创建的报销草稿,并继续补充字段或附件。", + emphasis="secondary", + ) + ) + + return actions[:5] + + def _build_review_intent_summary( + self, + payload: UserAgentRequest, + *, + slot_cards: list[UserAgentReviewSlotCard], + claim_groups: list[UserAgentReviewClaimGroup], + ) -> str: + slots = {item.key: item for item in slot_cards} + expense_type = slots.get("expense_type") + amount = slots.get("amount") + time_range = slots.get("time_range") + location = slots.get("location") + customer = slots.get("customer_name") + + summary = "系统识别出您想要发起一笔报销。" + if expense_type and expense_type.value: + summary = f"系统识别出您想要报销{expense_type.value}。" + details: list[str] = [] + if customer and customer.value: + details.append(f"客户名称:{customer.value}") + if time_range and time_range.value: + details.append(f"时间:{time_range.value}") + if location and location.value: + details.append(f"地点:{location.value}") + if amount and amount.value: + details.append(f"金额:{amount.value}") + if claim_groups and len(claim_groups) > 1: + details.append(f"建议拆分为 {len(claim_groups)} 张报销单") + if details: + return f"{summary} {';'.join(details)}。" + return summary + + @staticmethod + def _should_skip_model_answer( + payload: UserAgentRequest, + review_payload: UserAgentReviewPayload | None, + ) -> bool: + if review_payload is None: + return False + return payload.ontology.scenario == "expense" and ( + payload.ontology.intent == "draft" + or int(payload.context_json.get("attachment_count") or 0) > 0 + ) + def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]: domain = self._resolve_domain(payload.ontology.scenario) items = self.asset_service.list_assets( @@ -516,6 +993,45 @@ class UserAgentService: return [] return [str(name) for name in names if str(name).strip()] + @staticmethod + def _resolve_attachment_count(payload: UserAgentRequest) -> int: + names = UserAgentService._resolve_attachment_names(payload) + if names: + return len(names) + try: + return max(0, int(payload.context_json.get("attachment_count") or 0)) + except (TypeError, ValueError): + return 0 + + @staticmethod + def _resolve_ocr_documents(payload: UserAgentRequest) -> list[dict[str, object]]: + documents = payload.context_json.get("ocr_documents") + if not isinstance(documents, list): + return [] + normalized: list[dict[str, object]] = [] + for item in documents[:8]: + if not isinstance(item, dict): + continue + normalized.append(item) + return normalized + + @staticmethod + def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]: + history = payload.context_json.get("conversation_history") + if not isinstance(history, list): + return [] + + normalized: list[dict[str, object]] = [] + for item in history[-8:]: + if not isinstance(item, dict): + continue + role = str(item.get("role") or "").strip() + content = str(item.get("content") or "").strip() + if not role or not content: + continue + normalized.append({"role": role, "content": content}) + return normalized + @staticmethod def _resolve_domain(scenario: str) -> str | None: if scenario == "expense": @@ -557,3 +1073,210 @@ class UserAgentService: if len(cleaned) >= 2: break return ";".join(cleaned[:2]) + + def _collect_entity_values(self, payload: UserAgentRequest) -> dict[str, str]: + values = { + "employee_name": "", + "customer": "", + "participants": "", + "amount": "", + "expense_type": "", + "expense_type_code": "", + } + participants: list[str] = [] + for item in payload.ontology.entities: + if item.type == "employee" and not values["employee_name"]: + values["employee_name"] = item.value + elif item.type == "customer" and not values["customer"]: + values["customer"] = item.value + elif item.type == "amount" and item.role != "threshold" and not values["amount"]: + values["amount"] = f"{item.value}元" if "元" not in item.value else item.value + elif item.type == "expense_type" and not values["expense_type_code"]: + values["expense_type_code"] = item.normalized_value + values["expense_type"] = EXPENSE_TYPE_LABELS.get( + item.normalized_value, + item.value, + ) + elif item.type in {"participant", "person"} and item.value.strip(): + participants.append(item.value.strip()) + if participants: + values["participants"] = "、".join(dict.fromkeys(participants)) + return values + + def _format_time_range(self, payload: UserAgentRequest) -> str: + time_range = payload.ontology.time_range + if time_range.raw: + return time_range.raw + if time_range.start_date and time_range.end_date: + if time_range.start_date == time_range.end_date: + return time_range.start_date + return f"{time_range.start_date} 至 {time_range.end_date}" + return "" + + def _resolve_location_value(self, payload: UserAgentRequest) -> str: + request_context = payload.context_json.get("request_context") + if isinstance(request_context, dict): + for key in ("city", "location"): + value = str(request_context.get(key) or "").strip() + if value: + return value + city_match = re.search(r"去(?P[\u4e00-\u9fa5]{2,8})(?:出差|拜访|参会|见客户|客户现场)", payload.message) + if city_match: + return city_match.group("city").strip() + if "客户现场" in payload.message.replace(" ", ""): + return "客户现场" + return "" + + def _make_slot_card( + self, + *, + key: str, + value: str, + source: str, + confidence: float, + missing_slots: set[str], + required: bool = True, + ) -> UserAgentReviewSlotCard: + is_missing = key in missing_slots or not str(value).strip() + return UserAgentReviewSlotCard( + key=key, + label=SLOT_LABELS.get(key, key), + value=str(value or "").strip(), + source=source, + confidence=confidence, + required=required, + confirmed=not is_missing and source in {"user_text", "page_context", "upload"}, + status="missing" if is_missing else "identified" if source == "user_text" else "inferred", + hint=f"建议补充 {SLOT_LABELS.get(key, key)}。" + if is_missing and required + else "", + ) + + def _classify_document( + self, + item: dict[str, object], + payload: UserAgentRequest, + ) -> dict[str, str]: + text = " ".join( + [ + str(item.get("filename") or ""), + str(item.get("summary") or ""), + str(item.get("text") or ""), + ] + ).lower() + compact = text.replace(" ", "") + expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "") + has_customer = bool(self._collect_entity_values(payload).get("customer")) + + if any(keyword in compact for keyword in ("机票", "航班", "火车", "高铁", "行程单")): + return { + "document_type": "travel_ticket", + "expense_type": "travel", + "group_code": "travel", + "scene_label": "差旅票据", + } + if any(keyword in compact for keyword in ("酒店", "住宿", "宾馆")): + return { + "document_type": "hotel_invoice", + "expense_type": "hotel", + "group_code": "travel", + "scene_label": "住宿票据", + } + if any(keyword in compact for keyword in ("打车", "出租车", "滴滴", "网约车", "过路费", "停车")): + return { + "document_type": "transport_receipt", + "expense_type": "transport", + "group_code": "travel", + "scene_label": "交通票据", + } + if any(keyword in compact for keyword in ("餐", "饭店", "酒楼", "酒家", "餐饮", "meal")): + group_code = "entertainment" if expense_type_code == "entertainment" or has_customer else "meal" + return { + "document_type": "meal_receipt", + "expense_type": group_code, + "group_code": group_code, + "scene_label": "餐饮票据", + } + return { + "document_type": "other", + "expense_type": expense_type_code or "other", + "group_code": self._normalize_group_code(expense_type_code or "other"), + "scene_label": "其他票据", + } + + @staticmethod + def _normalize_group_code(expense_type_code: str) -> str: + if expense_type_code in {"travel", "hotel", "transport"}: + return "travel" + if expense_type_code in {"entertainment", "meal"}: + return expense_type_code + return "other" + + def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]: + text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() + fields: dict[str, str] = {} + amount_match = AMOUNT_TEXT_PATTERN.search(text) + if amount_match: + fields["金额"] = f"{amount_match.group(1)}元" + date_match = DATE_TEXT_PATTERN.search(text) + if date_match: + fields["时间"] = date_match.group(1) + + merchant = self._extract_document_merchant_name(item) + if merchant: + fields["商户/酒店"] = merchant + return fields + + @staticmethod + def _extract_document_merchant_name(item: dict[str, object]) -> str: + text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip() + for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"): + if keyword in text: + return keyword + return "" + + @staticmethod + def _extract_amount_from_card(card: UserAgentReviewDocumentCard) -> float: + for item in card.fields: + if item.label != "金额": + continue + try: + return float(str(item.value).replace("元", "").strip()) + except ValueError: + return 0.0 + return 0.0 + + def _resolve_amount_value(self, payload: UserAgentRequest) -> float: + for item in payload.ontology.entities: + if item.type == "amount" and item.role != "threshold": + try: + return float(item.normalized_value) + except ValueError: + return 0.0 + return 0.0 + + def _sum_ocr_amounts(self, ocr_documents: list[dict[str, object]]) -> float: + total = 0.0 + for item in ocr_documents: + fields = self._extract_document_fields(item) + amount_text = str(fields.get("金额") or "").replace("元", "").strip() + if not amount_text: + continue + try: + total += float(amount_text) + except ValueError: + continue + return total + + def _infer_expense_type_from_documents( + self, + payload: UserAgentRequest, + ocr_documents: list[dict[str, object]], + ) -> str: + labels: list[str] = [] + for item in ocr_documents: + classified = self._classify_document(item, payload) + label = GROUP_SCENE_LABELS.get(classified["group_code"], "") + if label and label not in labels: + labels.append(label) + return " + ".join(labels[:3])