refactor(backend): update and add service layers

- services/ontology.py: update ontology service
- services/orchestrator.py: update orchestrator service
- services/user_agent.py: update user agent service
- services/settings.py: update settings service
- services/expense_claims.py: update expense claims service
- services/agent_conversations.py: add new agent conversations service
This commit is contained in:
caoxiaozhu
2026-05-12 06:36:09 +00:00
parent a6a28ba865
commit 01df3452fd
6 changed files with 1442 additions and 80 deletions

View File

@@ -0,0 +1,398 @@
from __future__ import annotations
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.agent_conversation import AgentConversation, AgentConversationMessage
from app.services.settings import SettingsService
STATEFUL_CONTEXT_KEYS = (
"entry_source",
"request_context",
"attachment_names",
"attachment_count",
"ocr_summary",
"ocr_documents",
)
DEFAULT_CONVERSATION_RETENTION_DAYS = 3
class AgentConversationService:
def __init__(self, db: Session) -> None:
self.db = db
def get_or_create_conversation(
self,
*,
conversation_id: str | None,
user_id: str | None,
source: str,
context_json: dict[str, Any],
) -> AgentConversation:
self.prune_expired_conversations()
normalized_id = str(conversation_id or "").strip()
normalized_user_id = str(user_id or "").strip() or None
conversation = self.get_conversation(normalized_id) if normalized_id else None
if conversation is not None and conversation.user_id != normalized_user_id:
normalized_id = ""
conversation = None
if conversation is None:
conversation = AgentConversation(
conversation_id=normalized_id or f"conv_{uuid.uuid4().hex[:16]}",
user_id=normalized_user_id,
source=source,
entry_source=str(context_json.get("entry_source") or "").strip() or None,
title=self._resolve_title(context_json),
state_json=self._extract_state_json(context_json),
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
if not conversation.user_id and normalized_user_id:
conversation.user_id = normalized_user_id
if not conversation.entry_source:
conversation.entry_source = str(context_json.get("entry_source") or "").strip() or None
if not conversation.title:
conversation.title = self._resolve_title(context_json)
conversation.state_json = self._merge_state_json(
conversation.state_json,
self._extract_state_json(context_json),
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
def prune_expired_conversations(
self,
*,
retention_days: int | None = None,
) -> int:
resolved_retention_days = retention_days or self._resolve_retention_days()
cutoff = datetime.now(UTC) - timedelta(days=max(1, resolved_retention_days))
stmt = select(AgentConversation).where(AgentConversation.updated_at < cutoff)
expired_conversations = list(self.db.scalars(stmt).all())
if not expired_conversations:
return 0
for conversation in expired_conversations:
self.db.delete(conversation)
self.db.commit()
return len(expired_conversations)
def _resolve_retention_days(self) -> int:
try:
settings_row, _ = SettingsService(self.db).ensure_settings_ready()
configured_days = int(
getattr(
settings_row,
"conversation_retention_days",
DEFAULT_CONVERSATION_RETENTION_DAYS,
)
or DEFAULT_CONVERSATION_RETENTION_DAYS
)
return max(1, min(configured_days, 10))
except Exception:
self.db.rollback()
return DEFAULT_CONVERSATION_RETENTION_DAYS
def get_conversation(self, conversation_id: str) -> AgentConversation | None:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return None
stmt = select(AgentConversation).where(AgentConversation.conversation_id == normalized_id)
return self.db.scalar(stmt)
def get_latest_conversation_for_user(
self,
*,
user_id: str | None,
source: str | None = "user_message",
) -> AgentConversation | None:
self.prune_expired_conversations()
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
return None
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
if source:
stmt = stmt.where(AgentConversation.source == source)
stmt = stmt.order_by(AgentConversation.updated_at.desc(), AgentConversation.created_at.desc())
return self.db.scalar(stmt.limit(1))
def hydrate_context_json(
self,
*,
conversation: AgentConversation,
context_json: dict[str, Any],
history_limit: int = 8,
) -> dict[str, Any]:
merged = dict(context_json or {})
state_json = dict(conversation.state_json or {})
merged["conversation_id"] = conversation.conversation_id
merged["conversation_history"] = self.list_message_history(
conversation.conversation_id,
limit=history_limit,
)
if conversation.last_scenario:
merged.setdefault("conversation_scenario", conversation.last_scenario)
if conversation.last_intent:
merged.setdefault("conversation_intent", conversation.last_intent)
if conversation.draft_claim_id and not str(merged.get("draft_claim_id") or "").strip():
merged["draft_claim_id"] = conversation.draft_claim_id
merged["conversation_state"] = state_json
for key in STATEFUL_CONTEXT_KEYS:
if self._is_empty_value(merged.get(key)) and not self._is_empty_value(state_json.get(key)):
merged[key] = state_json.get(key)
return merged
def append_message(
self,
*,
conversation_id: str,
role: str,
content: str,
run_id: str | None = None,
message_json: dict[str, Any] | None = None,
) -> AgentConversationMessage | None:
normalized_content = str(content or "").strip()
if not normalized_content:
return None
conversation = self.get_conversation(conversation_id)
if conversation is None:
return None
message = AgentConversationMessage(
conversation_id=conversation_id,
run_id=run_id,
role=str(role or "user").strip() or "user",
content=normalized_content,
message_json=message_json or {},
created_at=datetime.now(UTC),
)
conversation.message_count = int(conversation.message_count or 0) + 1
if role == "user" and not conversation.title:
conversation.title = normalized_content[:48]
conversation.updated_at = datetime.now(UTC)
self.db.add(message)
self.db.add(conversation)
self.db.commit()
self.db.refresh(message)
return message
def list_message_history(
self,
conversation_id: str,
*,
limit: int = 8,
) -> list[dict[str, Any]]:
normalized_id = str(conversation_id or "").strip()
if not normalized_id or limit <= 0:
return []
stmt = (
select(AgentConversationMessage)
.where(AgentConversationMessage.conversation_id == normalized_id)
.order_by(AgentConversationMessage.created_at.desc())
.limit(limit)
)
messages = list(self.db.scalars(stmt).all())
messages.reverse()
return [
{
"role": item.role,
"content": item.content,
"run_id": item.run_id,
"created_at": item.created_at.isoformat() if item.created_at else None,
}
for item in messages
]
def list_messages(
self,
conversation_id: str,
*,
limit: int | None = None,
) -> list[AgentConversationMessage]:
normalized_id = str(conversation_id or "").strip()
if not normalized_id:
return []
stmt = (
select(AgentConversationMessage)
.where(AgentConversationMessage.conversation_id == normalized_id)
.order_by(AgentConversationMessage.created_at.asc(), AgentConversationMessage.id.asc())
)
if limit and limit > 0:
stmt = stmt.limit(limit)
return list(self.db.scalars(stmt).all())
def update_state(
self,
*,
conversation_id: str,
run_id: str | None,
scenario: str | None,
intent: str | None,
context_json: dict[str, Any],
draft_payload: dict[str, Any] | None = None,
) -> AgentConversation | None:
conversation = self.get_conversation(conversation_id)
if conversation is None:
return None
conversation.last_run_id = str(run_id or "").strip() or conversation.last_run_id
conversation.last_scenario = str(scenario or "").strip() or conversation.last_scenario
conversation.last_intent = str(intent or "").strip() or conversation.last_intent
if draft_payload and str(draft_payload.get("claim_id") or "").strip():
conversation.draft_claim_id = str(draft_payload["claim_id"]).strip()
next_state = self._merge_state_json(
conversation.state_json,
self._extract_state_json(context_json),
)
if draft_payload:
if str(draft_payload.get("claim_id") or "").strip():
next_state["draft_claim_id"] = str(draft_payload["claim_id"]).strip()
if str(draft_payload.get("claim_no") or "").strip():
next_state["draft_claim_no"] = str(draft_payload["claim_no"]).strip()
if str(draft_payload.get("status") or "").strip():
next_state["draft_status"] = str(draft_payload["status"]).strip()
conversation.state_json = next_state
conversation.updated_at = datetime.now(UTC)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
return conversation
def delete_user_conversations(
self,
*,
user_id: str | None,
source: str | None = "user_message",
) -> int:
normalized_user_id = str(user_id or "").strip()
if not normalized_user_id:
return 0
stmt = select(AgentConversation).where(AgentConversation.user_id == normalized_user_id)
if source:
stmt = stmt.where(AgentConversation.source == source)
conversations = list(self.db.scalars(stmt).all())
if not conversations:
return 0
for conversation in conversations:
self.db.delete(conversation)
self.db.commit()
return len(conversations)
def serialize_conversation(
self,
conversation: AgentConversation,
*,
include_messages: bool = True,
message_limit: int | None = None,
) -> dict[str, Any]:
payload = {
"conversation_id": conversation.conversation_id,
"user_id": conversation.user_id,
"source": conversation.source,
"entry_source": conversation.entry_source,
"title": conversation.title,
"last_run_id": conversation.last_run_id,
"last_scenario": conversation.last_scenario,
"last_intent": conversation.last_intent,
"draft_claim_id": conversation.draft_claim_id,
"state_json": dict(conversation.state_json or {}),
"message_count": int(conversation.message_count or 0),
"updated_at": conversation.updated_at,
"messages": [],
}
if include_messages:
payload["messages"] = [
self.serialize_message(item)
for item in self.list_messages(conversation.conversation_id, limit=message_limit)
]
return payload
@staticmethod
def serialize_message(message: AgentConversationMessage) -> dict[str, Any]:
return {
"id": message.id,
"role": message.role,
"content": message.content,
"run_id": message.run_id,
"message_json": dict(message.message_json or {}),
"created_at": message.created_at,
}
@staticmethod
def _is_empty_value(value: Any) -> bool:
if value is None:
return True
if isinstance(value, str):
return not value.strip()
if isinstance(value, (list, tuple, set, dict)):
return len(value) == 0
return False
@staticmethod
def _resolve_title(context_json: dict[str, Any]) -> str | None:
request_context = context_json.get("request_context")
if isinstance(request_context, dict):
for key in ("reason", "title", "id"):
value = str(request_context.get(key) or "").strip()
if value:
return value[:200]
return None
@staticmethod
def _extract_state_json(context_json: dict[str, Any]) -> dict[str, Any]:
state_json: dict[str, Any] = {}
for key in STATEFUL_CONTEXT_KEYS:
value = context_json.get(key)
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
if isinstance(value, (list, dict)) and not value:
continue
state_json[key] = value
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if draft_claim_id:
state_json["draft_claim_id"] = draft_claim_id
return state_json
@staticmethod
def _merge_state_json(
current_state: dict[str, Any] | None,
incoming_state: dict[str, Any] | None,
) -> dict[str, Any]:
merged = dict(current_state or {})
for key, value in (incoming_state or {}).items():
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
if isinstance(value, (list, dict)) and not value:
continue
merged[key] = value
return merged

View File

@@ -40,6 +40,7 @@ class ExpenseClaimService:
self._ensure_ready() self._ensure_ready()
claim = self._find_target_claim(ontology=ontology, context_json=context_json) claim = self._find_target_claim(ontology=ontology, context_json=context_json)
is_new_claim = claim is None
before_json = self._serialize_claim(claim) if claim is not None else None before_json = self._serialize_claim(claim) if claim is not None else None
employee = self._resolve_employee(ontology=ontology, context_json=context_json) employee = self._resolve_employee(ontology=ontology, context_json=context_json)
@@ -47,12 +48,30 @@ class ExpenseClaimService:
occurred_at = self._resolve_occurred_at(ontology) occurred_at = self._resolve_occurred_at(ontology)
expense_type = self._resolve_expense_type(ontology.entities) expense_type = self._resolve_expense_type(ontology.entities)
location = self._resolve_location(message=message, context_json=context_json) location = self._resolve_location(message=message, context_json=context_json)
reason = self._resolve_reason(message=message, context_json=context_json) reason = self._resolve_reason(
message=message,
context_json=context_json,
allow_message_fallback=is_new_claim,
)
attachment_count = self._resolve_attachment_count(context_json) attachment_count = self._resolve_attachment_count(context_json)
final_amount = amount if amount is not None else (claim.amount if claim is not None else Decimal("0.00"))
final_occurred_at = (
occurred_at if occurred_at is not None else (claim.occurred_at if claim is not None else datetime.now(UTC))
)
final_expense_type = expense_type or (claim.expense_type if claim is not None else "other")
final_location = location or (claim.location if claim is not None else "待补充")
final_reason = reason or (claim.reason if claim is not None else "待补充")
final_attachment_count = (
attachment_count if attachment_count > 0 else int(claim.invoice_count or 0) if claim is not None else 0
)
final_risk_flags = list(ontology.risk_flags) or (
list(claim.risk_flags_json or []) if claim is not None else []
)
if claim is None: if claim is None:
claim = ExpenseClaim( claim = ExpenseClaim(
claim_no=self._generate_claim_no(occurred_at), claim_no=self._generate_claim_no(final_occurred_at),
employee_id=employee.id if employee is not None else None, employee_id=employee.id if employee is not None else None,
employee_name=employee.name if employee is not None else self._resolve_employee_name( employee_name=employee.name if employee is not None else self._resolve_employee_name(
ontology=ontology, ontology=ontology,
@@ -65,16 +84,16 @@ class ExpenseClaimService:
context_json=context_json, context_json=context_json,
), ),
project_code=self._resolve_project_code(ontology.entities), project_code=self._resolve_project_code(ontology.entities),
expense_type=expense_type, expense_type=final_expense_type,
reason=reason, reason=final_reason,
location=location, location=final_location,
amount=amount, amount=final_amount,
currency="CNY", currency="CNY",
invoice_count=attachment_count, invoice_count=final_attachment_count,
occurred_at=occurred_at, occurred_at=final_occurred_at,
status="draft", status="draft",
approval_stage="待补充", approval_stage="待补充",
risk_flags_json=list(ontology.risk_flags), risk_flags_json=final_risk_flags,
) )
self.db.add(claim) self.db.add(claim)
else: else:
@@ -86,6 +105,7 @@ class ExpenseClaimService:
ontology=ontology, ontology=ontology,
context_json=context_json, context_json=context_json,
user_id=user_id, user_id=user_id,
fallback=claim.employee_name,
) )
) )
claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id claim.department_id = employee.organization_unit_id if employee is not None else claim.department_id
@@ -95,24 +115,24 @@ class ExpenseClaimService:
fallback=claim.department_name, fallback=claim.department_name,
) )
claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code claim.project_code = self._resolve_project_code(ontology.entities) or claim.project_code
claim.expense_type = expense_type or claim.expense_type claim.expense_type = final_expense_type
claim.reason = reason claim.reason = final_reason
claim.location = location claim.location = final_location
claim.amount = amount claim.amount = final_amount
claim.invoice_count = attachment_count claim.invoice_count = final_attachment_count
claim.occurred_at = occurred_at claim.occurred_at = final_occurred_at
claim.status = "draft" claim.status = "draft"
claim.approval_stage = "待补充" claim.approval_stage = "待补充"
claim.risk_flags_json = list(ontology.risk_flags) claim.risk_flags_json = final_risk_flags
self.db.flush() self.db.flush()
self._upsert_primary_item( self._upsert_primary_item(
claim=claim, claim=claim,
occurred_at=occurred_at, occurred_at=final_occurred_at,
expense_type=expense_type, expense_type=final_expense_type,
amount=amount, amount=final_amount,
reason=reason, reason=final_reason,
location=location, location=final_location,
attachment_names=self._resolve_attachment_names(context_json), attachment_names=self._resolve_attachment_names(context_json),
) )
self.db.commit() self.db.commit()
@@ -130,7 +150,7 @@ class ExpenseClaimService:
return { return {
"message": ( "message": (
f"创建报销草稿 {claim.claim_no},当前状态为 draft。" f"{'创建' if is_new_claim else '更新'}报销草稿 {claim.claim_no},当前状态为 draft。"
"你可以继续补充费用明细、客户单位和票据附件。" "你可以继续补充费用明细、客户单位和票据附件。"
), ),
"draft_only": True, "draft_only": True,
@@ -229,6 +249,7 @@ class ExpenseClaimService:
ontology: OntologyParseResult, ontology: OntologyParseResult,
context_json: dict[str, Any], context_json: dict[str, Any],
user_id: str | None, user_id: str | None,
fallback: str = "待补充",
) -> str: ) -> str:
for item in ontology.entities: for item in ontology.entities:
if item.type == "employee" and item.value.strip(): if item.type == "employee" and item.value.strip():
@@ -237,7 +258,7 @@ class ExpenseClaimService:
value = str(context_json.get(key) or "").strip() value = str(context_json.get(key) or "").strip()
if value: if value:
return value return value
return str(user_id or "待补充").strip() or "待补充" return str(user_id or fallback).strip() or fallback
@staticmethod @staticmethod
def _resolve_department_name( def _resolve_department_name(
@@ -270,26 +291,33 @@ class ExpenseClaimService:
return None return None
@staticmethod @staticmethod
def _resolve_expense_type(entities: list[OntologyEntity]) -> str: def _resolve_expense_type(entities: list[OntologyEntity]) -> str | None:
for item in entities: for item in entities:
if item.type == "expense_type": if item.type == "expense_type":
normalized = item.normalized_value.strip() normalized = item.normalized_value.strip()
if normalized: if normalized:
return normalized return normalized
return "other" return None
@staticmethod @staticmethod
def _resolve_reason(*, message: str, context_json: dict[str, Any]) -> str: def _resolve_reason(
*,
message: str,
context_json: dict[str, Any],
allow_message_fallback: bool,
) -> str | None:
request_context = context_json.get("request_context") request_context = context_json.get("request_context")
if isinstance(request_context, dict): if isinstance(request_context, dict):
for key in ("reason", "title"): for key in ("reason", "title"):
value = str(request_context.get(key) or "").strip() value = str(request_context.get(key) or "").strip()
if value: if value:
return value return value
return str(message or "").strip()[:500] or "待补充" if not allow_message_fallback:
return None
return str(message or "").strip()[:500] or None
@staticmethod @staticmethod
def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str: def _resolve_location(*, message: str, context_json: dict[str, Any]) -> str | None:
request_context = context_json.get("request_context") request_context = context_json.get("request_context")
if isinstance(request_context, dict): if isinstance(request_context, dict):
for key in ("city", "location"): for key in ("city", "location"):
@@ -299,10 +327,10 @@ class ExpenseClaimService:
compact = str(message or "").replace(" ", "") compact = str(message or "").replace(" ", "")
if "客户现场" in compact: if "客户现场" in compact:
return "客户现场" return "客户现场"
return "待补充" return None
@staticmethod @staticmethod
def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime: def _resolve_occurred_at(ontology: OntologyParseResult) -> datetime | None:
start_date = ontology.time_range.start_date start_date = ontology.time_range.start_date
if start_date: if start_date:
try: try:
@@ -310,10 +338,10 @@ class ExpenseClaimService:
return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC) return datetime(parsed.year, parsed.month, parsed.day, tzinfo=UTC)
except ValueError: except ValueError:
pass pass
return datetime.now(UTC) return None
@staticmethod @staticmethod
def _resolve_amount(entities: list[OntologyEntity]) -> Decimal: def _resolve_amount(entities: list[OntologyEntity]) -> Decimal | None:
for item in entities: for item in entities:
if item.type != "amount" or item.role == "threshold": if item.type != "amount" or item.role == "threshold":
continue continue
@@ -321,7 +349,7 @@ class ExpenseClaimService:
return Decimal(item.normalized_value).quantize(Decimal("0.01")) return Decimal(item.normalized_value).quantize(Decimal("0.01"))
except (InvalidOperation, ValueError): except (InvalidOperation, ValueError):
continue continue
return Decimal("0.00") return None
@staticmethod @staticmethod
def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]: def _resolve_attachment_names(context_json: dict[str, Any]) -> list[str]:

View File

@@ -120,6 +120,24 @@ EXPLAIN_KEYWORDS = ("为什么", "依据", "原因", "怎么处理", "是否可
COMPARE_KEYWORDS = ("对比", "比较", "相比", "差异", "变化") COMPARE_KEYWORDS = ("对比", "比较", "相比", "差异", "变化")
RISK_KEYWORDS = ("风险", "异常", "重复", "超标", "超预算", "逾期", "验真", "巡检") RISK_KEYWORDS = ("风险", "异常", "重复", "超标", "超预算", "逾期", "验真", "巡检")
DRAFT_KEYWORDS = ("生成", "草稿", "起草", "拟一份", "创建", "发起", "准备") DRAFT_KEYWORDS = ("生成", "草稿", "起草", "拟一份", "创建", "发起", "准备")
DRAFT_FOLLOW_UP_KEYWORDS = (
"继续",
"补充",
"补一下",
"修改",
"改成",
"改为",
"换成",
"更新",
"确认",
"提交",
"保存",
"客户是",
"地点是",
"金额是",
"日期是",
"时间是",
)
OPERATE_KEYWORDS = ( OPERATE_KEYWORDS = (
"直接付款", "直接付款",
"帮我付款", "帮我付款",
@@ -200,6 +218,7 @@ STATUS_KEYWORDS = {
} }
PRIVILEGED_ROLE_CODES = {"manager", "finance", "approver", "executive"} PRIVILEGED_ROLE_CODES = {"manager", "finance", "approver", "executive"}
CONTEXTUAL_SCENARIOS = {"expense", "accounts_receivable", "accounts_payable", "knowledge"}
@dataclass(slots=True) @dataclass(slots=True)
@@ -289,12 +308,17 @@ class SemanticOntologyService:
raise ValueError("query 不能为空。") raise ValueError("query 不能为空。")
AgentFoundationService(self.db).ensure_foundation_ready() AgentFoundationService(self.db).ensure_foundation_ready()
context_json = payload.context_json or {}
reference = self._load_reference_catalog() reference = self._load_reference_catalog()
compact_query = self._compact(query) compact_query = self._compact(query)
entities = self._extract_entities(query, compact_query, reference) entities = self._extract_entities(query, compact_query, reference)
rule_scenario, scenario_score = self._detect_scenario(compact_query) rule_scenario, scenario_score = self._detect_scenario(compact_query)
time_range, _time_score = self._extract_time_range(query, compact_query) time_range, _time_score = self._extract_time_range(query, compact_query)
context_scenario = self._resolve_context_scenario(context_json)
if rule_scenario == "unknown" and context_scenario is not None:
rule_scenario = context_scenario
scenario_score = max(scenario_score, 0.14)
if rule_scenario == "unknown": if rule_scenario == "unknown":
inferred_scenario = self._infer_scenario_from_entities(entities) inferred_scenario = self._infer_scenario_from_entities(entities)
if inferred_scenario is not None: if inferred_scenario is not None:
@@ -316,6 +340,17 @@ class SemanticOntologyService:
entities=entities, entities=entities,
time_range=time_range, time_range=time_range,
) )
if self._should_inherit_expense_draft(
compact_query,
scenario=rule_scenario,
entities=entities,
time_range=time_range,
context_json=context_json,
):
rule_scenario = "expense"
rule_intent = "draft"
scenario_score = max(scenario_score, 0.18)
intent_score = max(intent_score, 0.18)
metrics = self._extract_metrics(compact_query) metrics = self._extract_metrics(compact_query)
constraints = self._extract_constraints(compact_query, entities) constraints = self._extract_constraints(compact_query, entities)
model_parse = self._parse_with_model( model_parse = self._parse_with_model(
@@ -353,7 +388,7 @@ class SemanticOntologyService:
intent=intent, intent=intent,
entities=entities, entities=entities,
time_range=time_range, time_range=time_range,
context_json=payload.context_json or {}, context_json=context_json,
) )
) )
ambiguity = self._normalize_short_text_list( ambiguity = self._normalize_short_text_list(
@@ -362,7 +397,7 @@ class SemanticOntologyService:
risk_flags = self._extract_risk_flags(compact_query, scenario) risk_flags = self._extract_risk_flags(compact_query, scenario)
permission = self._resolve_permission( permission = self._resolve_permission(
compact_query, compact_query,
payload.context_json or {}, context_json,
intent, intent,
) )
@@ -524,6 +559,13 @@ class SemanticOntologyService:
def _compact(text: str) -> str: def _compact(text: str) -> str:
return re.sub(r"\s+", "", text).lower() return re.sub(r"\s+", "", text).lower()
@staticmethod
def _resolve_context_scenario(context_json: dict[str, Any]) -> str | None:
value = str(context_json.get("conversation_scenario") or "").strip()
if value in CONTEXTUAL_SCENARIOS:
return value
return None
def _detect_scenario(self, compact_query: str) -> tuple[str, float]: def _detect_scenario(self, compact_query: str) -> tuple[str, float]:
scores = {key: 0.0 for key in SCENARIO_KEYWORDS} scores = {key: 0.0 for key in SCENARIO_KEYWORDS}
for scenario, keywords in SCENARIO_KEYWORDS.items(): for scenario, keywords in SCENARIO_KEYWORDS.items():
@@ -581,6 +623,68 @@ class SemanticOntologyService:
return "draft", 0.22 return "draft", 0.22
return "query", 0.10 return "query", 0.10
@staticmethod
def _looks_like_follow_up_message(compact_query: str) -> bool:
if not compact_query:
return False
if any(keyword in compact_query for keyword in DRAFT_FOLLOW_UP_KEYWORDS):
return True
if compact_query.startswith(("", "", "", "这个", "那个")):
return True
has_domain_keyword = any(
keyword in compact_query
for keyword, _weight in (
*SCENARIO_KEYWORDS["expense"],
*SCENARIO_KEYWORDS["accounts_receivable"],
*SCENARIO_KEYWORDS["accounts_payable"],
*SCENARIO_KEYWORDS["knowledge"],
)
)
return len(compact_query) <= 12 and not has_domain_keyword
def _should_inherit_expense_draft(
self,
compact_query: str,
*,
scenario: str,
entities: list[OntologyEntity],
time_range: OntologyTimeRange,
context_json: dict[str, Any],
) -> bool:
context_scenario = self._resolve_context_scenario(context_json)
draft_claim_id = str(context_json.get("draft_claim_id") or "").strip()
if context_scenario != "expense" and not draft_claim_id:
return False
if any(keyword in compact_query for keyword in DRAFT_FOLLOW_UP_KEYWORDS):
return True
if self._looks_like_expense_narrative(
compact_query,
scenario="expense",
entities=entities,
time_range=time_range,
):
return True
if self._looks_like_follow_up_message(compact_query):
return True
if any(keyword in compact_query for keyword in OPERATE_KEYWORDS):
return False
if any(keyword in compact_query for keyword in COMPARE_KEYWORDS + RISK_KEYWORDS):
return False
if any(keyword in compact_query for keyword in QUERY_KEYWORDS):
return False
return bool(
draft_claim_id
and any(
item.type
in {"amount", "customer", "employee", "expense_type", "project", "invoice"}
for item in entities
)
)
@staticmethod @staticmethod
def _is_generic_expense_prompt(compact_query: str) -> bool: def _is_generic_expense_prompt(compact_query: str) -> bool:
return compact_query in GENERIC_EXPENSE_PROMPTS return compact_query in GENERIC_EXPENSE_PROMPTS
@@ -670,6 +774,11 @@ class SemanticOntologyService:
"ocr_documents": payload.context_json.get("ocr_documents", []), "ocr_documents": payload.context_json.get("ocr_documents", []),
"request_context": payload.context_json.get("request_context"), "request_context": payload.context_json.get("request_context"),
"role_codes": payload.context_json.get("role_codes", []), "role_codes": payload.context_json.get("role_codes", []),
"conversation_id": payload.context_json.get("conversation_id"),
"conversation_scenario": payload.context_json.get("conversation_scenario"),
"conversation_intent": payload.context_json.get("conversation_intent"),
"draft_claim_id": payload.context_json.get("draft_claim_id"),
"conversation_history": payload.context_json.get("conversation_history", []),
}, },
"rule_candidates": { "rule_candidates": {
"scenario": fallback_scenario, "scenario": fallback_scenario,
@@ -690,6 +799,8 @@ class SemanticOntologyService:
"意图 intent 只能是query, explain, compare, risk_check, draft, operate。" "意图 intent 只能是query, explain, compare, risk_check, draft, operate。"
"如果用户是在描述一笔待处理费用、待报销事项、上传票据或希望整理报销," "如果用户是在描述一笔待处理费用、待报销事项、上传票据或希望整理报销,"
"即使没有明确说“生成草稿”,也优先使用 expense + draft。" "即使没有明确说“生成草稿”,也优先使用 expense + draft。"
"如果提供了 conversation_history必须把最近轮次作为当前追问的上下文"
"正确理解“这个”“那笔”“改成 800”“继续补充”这类省略表达。"
"出现“客户”不等于应收,出现“供应商”不等于应付,必须结合动作词和业务目标判断。" "出现“客户”不等于应收,出现“供应商”不等于应付,必须结合动作词和业务目标判断。"
"只有明确查询、统计、列出、多少、明细、对比时才优先使用 query 或 compare。" "只有明确查询、统计、列出、多少、明细、对比时才优先使用 query 或 compare。"
"附件名称和 OCR 摘要只作为辅助证据,不能编造未出现的事实。" "附件名称和 OCR 摘要只作为辅助证据,不能编造未出现的事实。"

View File

@@ -32,6 +32,7 @@ from app.schemas.orchestrator import (
) )
from app.schemas.user_agent import UserAgentRequest, UserAgentResponse from app.schemas.user_agent import UserAgentRequest, UserAgentResponse
from app.services.agent_assets import AgentAssetService from app.services.agent_assets import AgentAssetService
from app.services.agent_conversations import AgentConversationService
from app.services.expense_claims import ExpenseClaimService from app.services.expense_claims import ExpenseClaimService
from app.services.agent_foundation import AgentFoundationService from app.services.agent_foundation import AgentFoundationService
from app.services.agent_runs import AgentRunService from app.services.agent_runs import AgentRunService
@@ -62,6 +63,7 @@ class OrchestratorService:
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
self.db = db self.db = db
self.asset_service = AgentAssetService(db) self.asset_service = AgentAssetService(db)
self.conversation_service = AgentConversationService(db)
self.expense_claim_service = ExpenseClaimService(db) self.expense_claim_service = ExpenseClaimService(db)
self.run_service = AgentRunService(db) self.run_service = AgentRunService(db)
self.ontology_service = SemanticOntologyService(db) self.ontology_service = SemanticOntologyService(db)
@@ -69,10 +71,28 @@ class OrchestratorService:
def run(self, payload: OrchestratorRequest) -> OrchestratorResponse: def run(self, payload: OrchestratorRequest) -> OrchestratorResponse:
AgentFoundationService(self.db).ensure_foundation_ready() AgentFoundationService(self.db).ensure_foundation_ready()
context_json = dict(payload.context_json or {})
conversation_id = str(payload.conversation_id or "").strip() or None
conversation = None
if payload.source == AgentRunSource.USER_MESSAGE.value:
conversation = self.conversation_service.get_or_create_conversation(
conversation_id=conversation_id,
user_id=payload.user_id,
source=payload.source,
context_json=context_json,
)
conversation_id = conversation.conversation_id
context_json = self.conversation_service.hydrate_context_json(
conversation=conversation,
context_json=context_json,
)
route_json: dict[str, Any] = { route_json: dict[str, Any] = {
"orchestrated_by": AgentName.ORCHESTRATOR.value, "orchestrated_by": AgentName.ORCHESTRATOR.value,
"stage": "created", "stage": "created",
} }
if conversation_id:
route_json["conversation_id"] = conversation_id
run = self.run_service.create_run( run = self.run_service.create_run(
agent=AgentName.ORCHESTRATOR.value, agent=AgentName.ORCHESTRATOR.value,
source=payload.source, source=payload.source,
@@ -87,15 +107,27 @@ class OrchestratorService:
try: try:
message, task_asset = self._resolve_message(payload) message, task_asset = self._resolve_message(payload)
if conversation is not None:
self.conversation_service.append_message(
conversation_id=conversation.conversation_id,
role="user",
content=message,
run_id=run.run_id,
message_json={
"attachment_names": context_json.get("attachment_names", []),
"attachment_count": context_json.get("attachment_count", 0),
"ocr_summary": context_json.get("ocr_summary", ""),
},
)
ontology = self.ontology_service.parse_for_run( ontology = self.ontology_service.parse_for_run(
OntologyParseRequest( OntologyParseRequest(
query=message, query=message,
user_id=payload.user_id, user_id=payload.user_id,
context_json=payload.context_json, context_json=context_json,
), ),
run_id=run.run_id, run_id=run.run_id,
) )
if payload.context_json.get("simulate_orchestrator_exception"): if context_json.get("simulate_orchestrator_exception"):
raise RuntimeError("simulated orchestrator exception") raise RuntimeError("simulated orchestrator exception")
selected_agent, route_reason = self._select_agent(payload, ontology) selected_agent, route_reason = self._select_agent(payload, ontology)
capabilities = self._select_capabilities( capabilities = self._select_capabilities(
@@ -159,6 +191,7 @@ class OrchestratorService:
capabilities=capabilities, capabilities=capabilities,
requires_confirmation=requires_confirmation, requires_confirmation=requires_confirmation,
task_asset=task_asset, task_asset=task_asset,
context_json=context_json,
) )
else: else:
outcome = self._execute_user_agent( outcome = self._execute_user_agent(
@@ -167,6 +200,7 @@ class OrchestratorService:
ontology=ontology, ontology=ontology,
capabilities=capabilities, capabilities=capabilities,
requires_confirmation=requires_confirmation, requires_confirmation=requires_confirmation,
context_json=context_json,
) )
final_status = ( final_status = (
@@ -176,10 +210,19 @@ class OrchestratorService:
and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value and ontology.permission.level == AgentPermissionLevel.APPROVAL_REQUIRED.value
else outcome.status else outcome.status
) )
response_status = self._normalize_response_status(final_status)
result_message = ( result_message = (
str(outcome.result.get("message", "")).strip() str(outcome.result.get("message", "")).strip()
or "Orchestrator 执行完成。" or "Orchestrator 执行完成。"
) )
trace_summary = OrchestratorTraceSummary(
scenario=ontology.scenario,
intent=ontology.intent,
tool_count=outcome.tool_count,
failed_tool_count=outcome.failed_tool_count,
selected_capability_codes=selected_capability_codes,
degraded=outcome.degraded,
)
self.run_service.update_run( self.run_service.update_run(
run.run_id, run.run_id,
agent=selected_agent or AgentName.ORCHESTRATOR.value, agent=selected_agent or AgentName.ORCHESTRATOR.value,
@@ -195,22 +238,51 @@ class OrchestratorService:
error_message=None, error_message=None,
finished_at=datetime.now(UTC), finished_at=datetime.now(UTC),
) )
if conversation is not None and conversation_id:
draft_payload = outcome.result.get("draft_payload")
self.conversation_service.update_state(
conversation_id=conversation_id,
run_id=run.run_id,
scenario=ontology.scenario,
intent=ontology.intent,
context_json=context_json,
draft_payload=draft_payload if isinstance(draft_payload, dict) else None,
)
self.conversation_service.append_message(
conversation_id=conversation_id,
role="assistant",
content=result_message,
run_id=run.run_id,
message_json={
"status": final_status,
"scenario": ontology.scenario,
"intent": ontology.intent,
"attachment_names": context_json.get("attachment_names", []),
"attachment_count": context_json.get("attachment_count", 0),
"draft_payload": draft_payload if isinstance(draft_payload, dict) else None,
"orchestrator_payload": {
"run_id": run.run_id,
"conversation_id": conversation_id,
"selected_agent": selected_agent,
"route_reason": route_reason,
"permission_level": ontology.permission.level,
"status": response_status,
"requires_confirmation": requires_confirmation,
"trace_summary": trace_summary.model_dump(),
"result": outcome.result,
},
},
)
return OrchestratorResponse( return OrchestratorResponse(
run_id=run.run_id, run_id=run.run_id,
conversation_id=conversation_id,
selected_agent=selected_agent, selected_agent=selected_agent,
route_reason=route_reason, route_reason=route_reason,
permission_level=ontology.permission.level, permission_level=ontology.permission.level,
status=self._normalize_response_status(final_status), status=response_status,
result=outcome.result, result=outcome.result,
requires_confirmation=requires_confirmation, requires_confirmation=requires_confirmation,
trace_summary=OrchestratorTraceSummary( trace_summary=trace_summary,
scenario=ontology.scenario,
intent=ontology.intent,
tool_count=outcome.tool_count,
failed_tool_count=outcome.failed_tool_count,
selected_capability_codes=selected_capability_codes,
degraded=outcome.degraded,
),
) )
except Exception as exc: except Exception as exc:
logger.exception("Orchestrator run failed run_id=%s", run.run_id) logger.exception("Orchestrator run failed run_id=%s", run.run_id)
@@ -223,8 +295,25 @@ class OrchestratorService:
error_message=str(exc), error_message=str(exc),
finished_at=datetime.now(UTC), finished_at=datetime.now(UTC),
) )
if conversation is not None and conversation_id:
self.conversation_service.update_state(
conversation_id=conversation_id,
run_id=run.run_id,
scenario=None,
intent=None,
context_json=context_json,
draft_payload=None,
)
self.conversation_service.append_message(
conversation_id=conversation_id,
role="assistant",
content=f"Orchestrator 执行失败:{exc}",
run_id=run.run_id,
message_json={"status": AgentRunStatus.FAILED.value},
)
return OrchestratorResponse( return OrchestratorResponse(
run_id=run.run_id, run_id=run.run_id,
conversation_id=conversation_id,
selected_agent=None, selected_agent=None,
route_reason="orchestrator_exception", route_reason="orchestrator_exception",
permission_level=AgentPermissionLevel.READ.value, permission_level=AgentPermissionLevel.READ.value,
@@ -336,6 +425,7 @@ class OrchestratorService:
ontology: OntologyParseResult, ontology: OntologyParseResult,
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
requires_confirmation: bool, requires_confirmation: bool,
context_json: dict[str, Any],
) -> ExecutionOutcome: ) -> ExecutionOutcome:
selected_capability_codes = self._flatten_capability_codes(capabilities) selected_capability_codes = self._flatten_capability_codes(capabilities)
if requires_confirmation: if requires_confirmation:
@@ -347,7 +437,7 @@ class OrchestratorService:
"message": payload.message, "message": payload.message,
"permission_level": ontology.permission.level, "permission_level": ontology.permission.level,
}, },
context_json=payload.context_json, context_json=context_json,
executor=lambda: { executor=lambda: {
"confirmation_title": "操作需要确认", "confirmation_title": "操作需要确认",
"message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。", "message": f"{ontology.permission.reason} 当前仅返回确认摘要,不直接执行动作。",
@@ -372,7 +462,7 @@ class OrchestratorService:
tool_type=AgentToolType.DATABASE.value, tool_type=AgentToolType.DATABASE.value,
tool_name=self._database_tool_name(ontology.scenario), tool_name=self._database_tool_name(ontology.scenario),
request_json=self._build_ontology_json(ontology), request_json=self._build_ontology_json(ontology),
context_json=payload.context_json, context_json=context_json,
executor=lambda: self._build_database_answer(ontology), executor=lambda: self._build_database_answer(ontology),
fallback_factory=lambda exc: { fallback_factory=lambda exc: {
"message": f"数据库查询暂时不可用,已返回降级说明:{exc}", "message": f"数据库查询暂时不可用,已返回降级说明:{exc}",
@@ -386,7 +476,7 @@ class OrchestratorService:
user_id=payload.user_id, user_id=payload.user_id,
message=payload.message or "", message=payload.message or "",
ontology=ontology, ontology=ontology,
context_json=payload.context_json, context_json=context_json,
tool_payload=tool_payload, tool_payload=tool_payload,
selected_capability_codes=selected_capability_codes, selected_capability_codes=selected_capability_codes,
degraded=degraded, degraded=degraded,
@@ -409,7 +499,7 @@ class OrchestratorService:
tool_type=AgentToolType.DATABASE.value, tool_type=AgentToolType.DATABASE.value,
tool_name="knowledge.search", tool_name="knowledge.search",
request_json=self._build_ontology_json(ontology), request_json=self._build_ontology_json(ontology),
context_json=payload.context_json, context_json=context_json,
executor=lambda: self._build_knowledge_answer(ontology, capabilities), executor=lambda: self._build_knowledge_answer(ontology, capabilities),
fallback_factory=lambda exc: { fallback_factory=lambda exc: {
"message": f"知识检索暂时不可用,建议稍后重试:{exc}", "message": f"知识检索暂时不可用,建议稍后重试:{exc}",
@@ -423,7 +513,7 @@ class OrchestratorService:
user_id=payload.user_id, user_id=payload.user_id,
message=payload.message or "", message=payload.message or "",
ontology=ontology, ontology=ontology,
context_json=payload.context_json, context_json=context_json,
tool_payload=tool_payload, tool_payload=tool_payload,
selected_capability_codes=selected_capability_codes, selected_capability_codes=selected_capability_codes,
degraded=degraded, degraded=degraded,
@@ -446,7 +536,7 @@ class OrchestratorService:
tool_type=AgentToolType.RULE_ENGINE.value, tool_type=AgentToolType.RULE_ENGINE.value,
tool_name=self._rule_tool_name(capabilities), tool_name=self._rule_tool_name(capabilities),
request_json=self._build_ontology_json(ontology), request_json=self._build_ontology_json(ontology),
context_json=payload.context_json, context_json=context_json,
executor=lambda: self._build_rule_answer(ontology), executor=lambda: self._build_rule_answer(ontology),
fallback_factory=lambda exc: { fallback_factory=lambda exc: {
"message": f"规则检查暂时不可用,已返回人工复核建议:{exc}", "message": f"规则检查暂时不可用,已返回人工复核建议:{exc}",
@@ -460,7 +550,7 @@ class OrchestratorService:
user_id=payload.user_id, user_id=payload.user_id,
message=payload.message or "", message=payload.message or "",
ontology=ontology, ontology=ontology,
context_json=payload.context_json, context_json=context_json,
tool_payload=tool_payload, tool_payload=tool_payload,
selected_capability_codes=selected_capability_codes, selected_capability_codes=selected_capability_codes,
degraded=degraded, degraded=degraded,
@@ -499,7 +589,7 @@ class OrchestratorService:
user_id=payload.user_id, user_id=payload.user_id,
message=payload.message or "", message=payload.message or "",
ontology=ontology, ontology=ontology,
context_json=payload.context_json, context_json=context_json,
) )
fallback_factory = lambda exc: { fallback_factory = lambda exc: {
"message": f"报销草稿落库失败,请稍后再试:{exc}", "message": f"报销草稿落库失败,请稍后再试:{exc}",
@@ -511,7 +601,7 @@ class OrchestratorService:
tool_type=tool_type, tool_type=tool_type,
tool_name=tool_name, tool_name=tool_name,
request_json=self._build_ontology_json(ontology), request_json=self._build_ontology_json(ontology),
context_json=payload.context_json, context_json=context_json,
executor=executor, executor=executor,
fallback_factory=fallback_factory, fallback_factory=fallback_factory,
) )
@@ -522,7 +612,7 @@ class OrchestratorService:
user_id=payload.user_id, user_id=payload.user_id,
message=payload.message or "", message=payload.message or "",
ontology=ontology, ontology=ontology,
context_json=payload.context_json, context_json=context_json,
tool_payload=tool_payload, tool_payload=tool_payload,
selected_capability_codes=selected_capability_codes, selected_capability_codes=selected_capability_codes,
degraded=degraded, degraded=degraded,
@@ -548,6 +638,7 @@ class OrchestratorService:
capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]], capabilities: dict[str, list[AgentAssetListItem | AgentAssetRead]],
requires_confirmation: bool, requires_confirmation: bool,
task_asset: AgentAssetRead | None, task_asset: AgentAssetRead | None,
context_json: dict[str, Any],
) -> ExecutionOutcome: ) -> ExecutionOutcome:
if requires_confirmation: if requires_confirmation:
return ExecutionOutcome( return ExecutionOutcome(
@@ -566,7 +657,7 @@ class OrchestratorService:
tool_type=AgentToolType.RULE_ENGINE.value, tool_type=AgentToolType.RULE_ENGINE.value,
tool_name=self._rule_tool_name(capabilities), tool_name=self._rule_tool_name(capabilities),
request_json=self._build_ontology_json(ontology), request_json=self._build_ontology_json(ontology),
context_json=payload.context_json, context_json=context_json,
executor=lambda: self._build_rule_answer(ontology), executor=lambda: self._build_rule_answer(ontology),
fallback_factory=lambda exc: { fallback_factory=lambda exc: {
"message": f"规则巡检失败,已降级为待人工复核:{exc}", "message": f"规则巡检失败,已降级为待人工复核:{exc}",
@@ -581,7 +672,7 @@ class OrchestratorService:
"task_code": task_asset.code if task_asset is not None else "", "task_code": task_asset.code if task_asset is not None else "",
"scenario": ontology.scenario, "scenario": ontology.scenario,
}, },
context_json=payload.context_json, context_json=context_json,
executor=lambda: self._build_mcp_answer(task_asset, ontology), executor=lambda: self._build_mcp_answer(task_asset, ontology),
fallback_factory=lambda exc: { fallback_factory=lambda exc: {
"message": f"MCP 调用失败,已使用缓存快照降级:{exc}", "message": f"MCP 调用失败,已使用缓存快照降级:{exc}",
@@ -806,6 +897,8 @@ class OrchestratorService:
} }
if response.draft_payload is not None: if response.draft_payload is not None:
result["draft_payload"] = response.draft_payload.model_dump() result["draft_payload"] = response.draft_payload.model_dump()
if response.review_payload is not None:
result["review_payload"] = response.review_payload.model_dump()
return result return result
@staticmethod @staticmethod

View File

@@ -204,6 +204,7 @@ class SettingsService:
settings_row.admin_account = payload.adminForm.adminAccount settings_row.admin_account = payload.adminForm.adminAccount
settings_row.admin_email = payload.adminForm.adminEmail settings_row.admin_email = payload.adminForm.adminEmail
settings_row.session_timeout = payload.adminForm.sessionTimeout settings_row.session_timeout = payload.adminForm.sessionTimeout
settings_row.conversation_retention_days = payload.sessionForm.conversationRetentionDays
settings_row.notice_email = payload.adminForm.noticeEmail settings_row.notice_email = payload.adminForm.noticeEmail
settings_row.mfa_enabled = payload.adminForm.mfaEnabled settings_row.mfa_enabled = payload.adminForm.mfaEnabled
settings_row.strong_password = payload.adminForm.strongPassword settings_row.strong_password = payload.adminForm.strongPassword
@@ -429,6 +430,7 @@ class SettingsService:
admin_account=admin_account, admin_account=admin_account,
admin_email=admin_email, admin_email=admin_email,
session_timeout=30, session_timeout=30,
conversation_retention_days=3,
notice_email=admin_email, notice_email=admin_email,
mfa_enabled=True, mfa_enabled=True,
strong_password=True, strong_password=True,
@@ -520,6 +522,10 @@ class SettingsService:
if "system_settings" in table_names: if "system_settings" in table_names:
settings_columns = {column["name"] for column in inspector.get_columns("system_settings")} settings_columns = {column["name"] for column in inspector.get_columns("system_settings")}
if "conversation_retention_days" not in settings_columns:
migration_statements.append(
"ALTER TABLE system_settings ADD COLUMN conversation_retention_days INTEGER DEFAULT 3"
)
if "onlyoffice_enabled" not in settings_columns: if "onlyoffice_enabled" not in settings_columns:
migration_statements.append( migration_statements.append(
"ALTER TABLE system_settings ADD COLUMN onlyoffice_enabled BOOLEAN DEFAULT FALSE" "ALTER TABLE system_settings ADD COLUMN onlyoffice_enabled BOOLEAN DEFAULT FALSE"
@@ -600,6 +606,9 @@ class SettingsService:
"loginAlertEnabled": settings_row.login_alert_enabled, "loginAlertEnabled": settings_row.login_alert_enabled,
"adminPasswordConfigured": bool(secrets_row.admin_password_hash), "adminPasswordConfigured": bool(secrets_row.admin_password_hash),
}, },
sessionForm={
"conversationRetentionDays": settings_row.conversation_retention_days,
},
llmForm={ llmForm={
"mainProvider": main_model.provider, "mainProvider": main_model.provider,
"mainModel": main_model.model_name, "mainModel": main_model.model_name,

View File

@@ -2,14 +2,24 @@ from __future__ import annotations
import json import json
import re import re
from datetime import UTC, datetime, timedelta
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.agent_enums import AgentAssetStatus, AgentAssetType from app.core.agent_enums import AgentAssetStatus, AgentAssetType
from app.models.financial_record import ExpenseClaim
from app.schemas.agent_asset import AgentAssetListItem from app.schemas.agent_asset import AgentAssetListItem
from app.schemas.user_agent import ( from app.schemas.user_agent import (
UserAgentCitation, UserAgentCitation,
UserAgentDraftPayload, UserAgentDraftPayload,
UserAgentReviewAction,
UserAgentReviewClaimGroup,
UserAgentReviewDocumentCard,
UserAgentReviewDocumentField,
UserAgentReviewPayload,
UserAgentReviewRiskBrief,
UserAgentReviewSlotCard,
UserAgentRequest, UserAgentRequest,
UserAgentResponse, UserAgentResponse,
UserAgentSuggestedAction, UserAgentSuggestedAction,
@@ -53,8 +63,32 @@ EXPENSE_TYPE_LABELS = {
"meal": "餐费", "meal": "餐费",
"meeting": "会务", "meeting": "会务",
"entertainment": "招待", "entertainment": "招待",
"other": "其他",
} }
GROUP_SCENE_LABELS = {
"travel": "差旅费",
"entertainment": "业务招待费",
"meal": "伙食费",
"transport": "交通费",
"hotel": "住宿费",
"other": "其他费用",
}
SLOT_LABELS = {
"expense_type": "报销类型",
"customer_name": "客户名称",
"time_range": "发生时间",
"location": "地点",
"merchant_name": "酒店/商户",
"amount": "金额",
"participants": "参与人员",
"attachments": "票据附件",
}
DATE_TEXT_PATTERN = re.compile(r"(\d{4}[年/-]\d{1,2}[月/-]\d{1,2}日?)")
AMOUNT_TEXT_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*(?:元|万元|万)")
class UserAgentService: class UserAgentService:
def __init__(self, db: Session) -> None: def __init__(self, db: Session) -> None:
@@ -72,23 +106,32 @@ class UserAgentService:
if payload.ontology.intent == "draft" if payload.ontology.intent == "draft"
else None else None
) )
review_payload = self._build_review_payload(
payload,
citations=citations,
draft_payload=draft_payload,
)
if payload.degraded and payload.tool_payload.get("message"): if payload.degraded and payload.tool_payload.get("message"):
return UserAgentResponse( return UserAgentResponse(
answer=str(payload.tool_payload["message"]), answer=str(payload.tool_payload["message"]),
citations=citations, citations=citations,
suggested_actions=suggested_actions, suggested_actions=suggested_actions,
review_payload=review_payload,
risk_flags=risk_flags, risk_flags=risk_flags,
requires_confirmation=payload.requires_confirmation, requires_confirmation=payload.requires_confirmation,
) )
guided_answer = self._build_guided_answer(payload) guided_answer = None
if draft_payload is None or draft_payload.claim_id is None:
guided_answer = self._build_guided_answer(payload)
if guided_answer: if guided_answer:
return UserAgentResponse( return UserAgentResponse(
answer=guided_answer, answer=guided_answer,
citations=citations, citations=citations,
suggested_actions=suggested_actions, suggested_actions=suggested_actions,
draft_payload=draft_payload, draft_payload=draft_payload,
review_payload=review_payload,
risk_flags=risk_flags, risk_flags=risk_flags,
requires_confirmation=payload.requires_confirmation, requires_confirmation=payload.requires_confirmation,
) )
@@ -98,20 +141,23 @@ class UserAgentService:
citations=citations, citations=citations,
draft_payload=draft_payload, draft_payload=draft_payload,
) )
answer = self._generate_answer_with_model( answer = None
payload, if not self._should_skip_model_answer(payload, review_payload):
citations=citations, answer = self._generate_answer_with_model(
suggested_actions=suggested_actions, payload,
risk_flags=risk_flags, citations=citations,
draft_payload=draft_payload, suggested_actions=suggested_actions,
fallback_answer=fallback_answer, risk_flags=risk_flags,
) draft_payload=draft_payload,
fallback_answer=fallback_answer,
)
return UserAgentResponse( return UserAgentResponse(
answer=answer or fallback_answer, answer=answer or fallback_answer,
citations=citations, citations=citations,
suggested_actions=suggested_actions, suggested_actions=suggested_actions,
draft_payload=draft_payload, draft_payload=draft_payload,
review_payload=review_payload,
risk_flags=risk_flags, risk_flags=risk_flags,
requires_confirmation=payload.requires_confirmation, requires_confirmation=payload.requires_confirmation,
) )
@@ -129,6 +175,13 @@ class UserAgentService:
if payload.ontology.intent == "risk_check": if payload.ontology.intent == "risk_check":
return self._build_risk_answer(payload, citations) return self._build_risk_answer(payload, citations)
if payload.ontology.intent == "draft":
tool_message = str(payload.tool_payload.get("message") or "").strip()
if tool_message and (
str(payload.tool_payload.get("claim_id") or "").strip()
or str(payload.tool_payload.get("claim_no") or "").strip()
):
return tool_message
if payload.ontology.intent == "draft" and draft_payload is not None: if payload.ontology.intent == "draft" and draft_payload is not None:
return ( return (
f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容," f"已生成 {draft_payload.title},当前仅返回待人工确认的草稿内容,"
@@ -243,6 +296,11 @@ class UserAgentService:
"attachment_names": self._resolve_attachment_names(payload), "attachment_names": self._resolve_attachment_names(payload),
"ocr_summary": payload.context_json.get("ocr_summary", ""), "ocr_summary": payload.context_json.get("ocr_summary", ""),
"ocr_documents": payload.context_json.get("ocr_documents", []), "ocr_documents": payload.context_json.get("ocr_documents", []),
"conversation_id": payload.context_json.get("conversation_id"),
"conversation_scenario": payload.context_json.get("conversation_scenario"),
"conversation_intent": payload.context_json.get("conversation_intent"),
"draft_claim_id": payload.context_json.get("draft_claim_id"),
"conversation_history": self._resolve_conversation_history(payload),
}, },
"tool_payload": payload.tool_payload, "tool_payload": payload.tool_payload,
"citations": [item.model_dump(mode="json") for item in citations], "citations": [item.model_dump(mode="json") for item in citations],
@@ -267,6 +325,7 @@ class UserAgentService:
"并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。" "并明确要求补充费用类型、金额、时间、事由、参与对象或上传票据。"
"如果上下文里只有附件名称,必须明确说明你只拿到了附件名称," "如果上下文里只有附件名称,必须明确说明你只拿到了附件名称,"
"不能假装已看过图片、PDF 或发票内容。" "不能假装已看过图片、PDF 或发票内容。"
"如果提供了 conversation_history必须结合最近轮次理解追问、代词、省略字段和补充信息。"
"不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。" "不要声称已经提交、审批、付款、入账或真正执行了任何动作;如果只是建议、草稿或待确认,要明确说清楚。"
"若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。" "若给出了风险标签、制度引用或建议动作,可以简洁吸收进回答,但不要新增未提供的事实。"
"只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、" "只输出最终给用户看的自然语言,不要输出 JSON、Markdown、标题、"
@@ -447,6 +506,424 @@ class UserAgentService:
), ),
] ]
def _build_review_payload(
self,
payload: UserAgentRequest,
*,
citations: list[UserAgentCitation],
draft_payload: UserAgentDraftPayload | None,
) -> UserAgentReviewPayload | None:
attachment_count = self._resolve_attachment_count(payload)
ocr_documents = self._resolve_ocr_documents(payload)
if payload.ontology.scenario != "expense":
return None
if payload.ontology.intent not in {"draft", "operate"} and attachment_count <= 0 and not ocr_documents:
return None
slot_cards = self._build_review_slot_cards(payload, ocr_documents=ocr_documents)
document_cards = self._build_review_document_cards(payload, ocr_documents=ocr_documents)
claim_groups = self._build_review_claim_groups(
payload,
document_cards=document_cards,
)
risk_briefs = self._build_review_risk_briefs(
payload,
citations=citations,
document_cards=document_cards,
claim_groups=claim_groups,
)
confirmation_actions = self._build_review_confirmation_actions(
payload,
claim_groups=claim_groups,
draft_payload=draft_payload,
)
intent_summary = self._build_review_intent_summary(
payload,
slot_cards=slot_cards,
claim_groups=claim_groups,
)
return UserAgentReviewPayload(
intent_summary=intent_summary,
scenario=payload.ontology.scenario,
intent=payload.ontology.intent,
missing_slots=list(payload.ontology.missing_slots),
risk_briefs=risk_briefs,
slot_cards=slot_cards,
document_cards=document_cards,
claim_groups=claim_groups,
confirmation_actions=confirmation_actions,
)
def _build_review_slot_cards(
self,
payload: UserAgentRequest,
*,
ocr_documents: list[dict[str, object]],
) -> list[UserAgentReviewSlotCard]:
first_doc_fields = self._extract_document_fields(ocr_documents[0]) if ocr_documents else {}
missing_slots = set(payload.ontology.missing_slots)
entity_map = self._collect_entity_values(payload)
time_value = self._format_time_range(payload)
location_value = self._resolve_location_value(payload)
merchant_value = self._extract_document_merchant_name(ocr_documents[0]) if ocr_documents else ""
customer_value = entity_map.get("customer", "")
participants_value = entity_map.get("participants", "")
amount_value = entity_map.get("amount")
if not amount_value:
ocr_total_amount = self._sum_ocr_amounts(ocr_documents)
amount_value = f"{ocr_total_amount:.2f}" if ocr_total_amount > 0 else ""
expense_type_code = entity_map.get("expense_type_code", "")
expense_type_value = EXPENSE_TYPE_LABELS.get(expense_type_code, entity_map.get("expense_type", ""))
if not expense_type_value and ocr_documents:
expense_type_value = self._infer_expense_type_from_documents(payload, ocr_documents)
attachment_value = (
f"{self._resolve_attachment_count(payload)} 份附件"
if self._resolve_attachment_count(payload)
else ""
)
cards = [
self._make_slot_card(
key="expense_type",
value=expense_type_value,
source="user_text" if expense_type_value else "system",
confidence=0.9 if expense_type_value else 0.0,
missing_slots=missing_slots,
),
self._make_slot_card(
key="customer_name",
value=customer_value,
source="user_text" if customer_value else "system",
confidence=0.88 if customer_value else 0.0,
missing_slots=missing_slots,
),
self._make_slot_card(
key="time_range",
value=time_value,
source="user_text" if time_value else "system",
confidence=0.9 if time_value else 0.0,
missing_slots=missing_slots,
),
self._make_slot_card(
key="location",
value=location_value,
source="page_context" if location_value and location_value != "客户现场" else "user_text",
confidence=0.82 if location_value else 0.0,
required=False,
missing_slots=missing_slots,
),
self._make_slot_card(
key="merchant_name",
value=merchant_value,
source="ocr" if merchant_value else "system",
confidence=0.72 if merchant_value else 0.0,
required=False,
missing_slots=missing_slots,
),
self._make_slot_card(
key="amount",
value=amount_value,
source="user_text" if entity_map.get("amount") else "ocr" if amount_value else "system",
confidence=0.92 if amount_value else 0.0,
missing_slots=missing_slots,
),
self._make_slot_card(
key="participants",
value=participants_value,
source="user_text" if participants_value else "system",
confidence=0.8 if participants_value else 0.0,
missing_slots=missing_slots,
),
self._make_slot_card(
key="attachments",
value=attachment_value,
source="upload" if attachment_value else "system",
confidence=1.0 if attachment_value else 0.0,
missing_slots=missing_slots,
),
]
return cards
def _build_review_document_cards(
self,
payload: UserAgentRequest,
*,
ocr_documents: list[dict[str, object]],
) -> list[UserAgentReviewDocumentCard]:
cards: list[UserAgentReviewDocumentCard] = []
for index, item in enumerate(ocr_documents, start=1):
classified = self._classify_document(item, payload)
fields = self._extract_document_fields(item)
cards.append(
UserAgentReviewDocumentCard(
index=index,
filename=str(item.get("filename") or f"document-{index}"),
document_type=classified["document_type"],
suggested_expense_type=classified["expense_type"],
scene_label=GROUP_SCENE_LABELS.get(
classified["group_code"],
classified["scene_label"],
),
summary=str(item.get("summary") or item.get("text") or "").strip(),
avg_score=float(item.get("avg_score") or 0.0),
warnings=[str(warning) for warning in item.get("warnings", []) if str(warning).strip()],
fields=[
UserAgentReviewDocumentField(
label=label,
value=value,
source="ocr",
)
for label, value in fields.items()
if str(value).strip()
],
)
)
return cards
def _build_review_claim_groups(
self,
payload: UserAgentRequest,
*,
document_cards: list[UserAgentReviewDocumentCard],
) -> list[UserAgentReviewClaimGroup]:
groups: dict[str, dict[str, object]] = {}
for card in document_cards:
group_code = self._normalize_group_code(card.suggested_expense_type)
bucket = groups.setdefault(
group_code,
{
"document_indexes": [],
"amount_total": 0.0,
"expense_type": group_code,
"scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"),
"reasons": [],
},
)
bucket["document_indexes"].append(card.index)
bucket["amount_total"] = float(bucket["amount_total"]) + self._extract_amount_from_card(card)
bucket["reasons"].append(f"{card.filename} 识别为 {card.scene_label}")
if not groups:
expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "other")
group_code = self._normalize_group_code(expense_type_code)
groups[group_code] = {
"document_indexes": [],
"amount_total": self._resolve_amount_value(payload),
"expense_type": expense_type_code or "other",
"scene_label": GROUP_SCENE_LABELS.get(group_code, "其他费用"),
"reasons": ["当前主要依据用户文本和页面上下文进行分单建议。"],
}
claim_groups: list[UserAgentReviewClaimGroup] = []
for index, (group_code, bucket) in enumerate(groups.items(), start=1):
title = f"建议报销单 {index}{bucket['scene_label']}"
rationale = (
"".join(dict.fromkeys(str(item) for item in bucket["reasons"]))
if bucket["reasons"]
else "当前仅有单一场景,无需拆单。"
)
claim_groups.append(
UserAgentReviewClaimGroup(
group_code=group_code,
title=title,
expense_type=str(bucket["expense_type"]),
scene_label=str(bucket["scene_label"]),
document_indexes=list(bucket["document_indexes"]),
amount_total=round(float(bucket["amount_total"]), 2),
rationale=rationale,
)
)
return claim_groups
def _build_review_risk_briefs(
self,
payload: UserAgentRequest,
*,
citations: list[UserAgentCitation],
document_cards: list[UserAgentReviewDocumentCard],
claim_groups: list[UserAgentReviewClaimGroup],
) -> list[UserAgentReviewRiskBrief]:
briefs: list[UserAgentReviewRiskBrief] = []
employee_name = self._collect_entity_values(payload).get("employee_name") or str(
payload.context_json.get("name") or ""
).strip()
if employee_name:
since = datetime.now(UTC) - timedelta(days=90)
stmt = select(ExpenseClaim).where(
ExpenseClaim.employee_name == employee_name,
ExpenseClaim.occurred_at >= since,
)
recent_claims = list(self.db.scalars(stmt).all())
if recent_claims:
risky_count = sum(1 for item in recent_claims if item.risk_flags_json)
draft_count = sum(1 for item in recent_claims if item.status == "draft")
briefs.append(
UserAgentReviewRiskBrief(
title="历史报销画像",
level="info",
content=(
f"{employee_name} 最近 90 天共有 {len(recent_claims)} 笔报销,"
f"其中 {risky_count} 笔带风险标记,{draft_count} 笔仍处于草稿态。"
),
)
)
current_amount = self._resolve_amount_value(payload)
if current_amount > 0:
duplicate_count = sum(
1
for item in recent_claims
if abs(float(item.amount) - current_amount) < 0.01
)
if duplicate_count:
briefs.append(
UserAgentReviewRiskBrief(
title="金额重复预警",
level="warning",
content=(
f"近 90 天发现 {duplicate_count} 笔金额相同的报销记录,"
"提交前建议核对是否为重复报销或拆分不当。"
),
)
)
if citations:
briefs.append(
UserAgentReviewRiskBrief(
title="制度注意事项",
level="info",
content=citations[0].excerpt or f"请先核对 {citations[0].title} 的制度要求。",
)
)
warning_count = sum(len(item.warnings) for item in document_cards)
if warning_count:
briefs.append(
UserAgentReviewRiskBrief(
title="票据识别提醒",
level="warning",
content=f"当前共有 {warning_count} 条票据识别提示,建议逐张确认 OCR 识别字段。",
)
)
if len(claim_groups) > 1:
briefs.append(
UserAgentReviewRiskBrief(
title="建议拆单",
level="high",
content=f"系统检测到 {len(claim_groups)} 类费用场景,建议拆成多张报销单后再提交。",
)
)
return briefs[:4]
def _build_review_confirmation_actions(
self,
payload: UserAgentRequest,
*,
claim_groups: list[UserAgentReviewClaimGroup],
draft_payload: UserAgentDraftPayload | None,
) -> list[UserAgentReviewAction]:
actions: list[UserAgentReviewAction] = []
if claim_groups:
if len(claim_groups) > 1:
actions.append(
UserAgentReviewAction(
label=f"{len(claim_groups)} 张报销单生成",
action_type="split_claims",
description="保留当前识别结果,并按费用场景拆分生成多张报销草稿。",
emphasis="primary",
)
)
else:
actions.append(
UserAgentReviewAction(
label="确认并继续生成草稿",
action_type="confirm_review",
description="确认当前识别字段无误后,继续生成或覆盖当前报销草稿。",
emphasis="primary",
)
)
for slot in payload.ontology.missing_slots[:3]:
label = SLOT_LABELS.get(slot, slot)
actions.append(
UserAgentReviewAction(
label=f"补充{label}",
action_type="fill_slot",
description=f"当前还缺少 {label},补充后可提升分单和建单准确度。",
emphasis="secondary",
)
)
if self._resolve_attachment_count(payload) <= 0:
actions.append(
UserAgentReviewAction(
label="继续上传票据",
action_type="upload_more",
description="上传发票、行程单或电子票据后,系统会重新识别并完善报销分组。",
emphasis="secondary",
)
)
if draft_payload is not None and draft_payload.claim_no:
actions.append(
UserAgentReviewAction(
label=f"查看草稿 {draft_payload.claim_no}",
action_type="open_claim",
description="查看当前已创建的报销草稿,并继续补充字段或附件。",
emphasis="secondary",
)
)
return actions[:5]
def _build_review_intent_summary(
self,
payload: UserAgentRequest,
*,
slot_cards: list[UserAgentReviewSlotCard],
claim_groups: list[UserAgentReviewClaimGroup],
) -> str:
slots = {item.key: item for item in slot_cards}
expense_type = slots.get("expense_type")
amount = slots.get("amount")
time_range = slots.get("time_range")
location = slots.get("location")
customer = slots.get("customer_name")
summary = "系统识别出您想要发起一笔报销。"
if expense_type and expense_type.value:
summary = f"系统识别出您想要报销{expense_type.value}"
details: list[str] = []
if customer and customer.value:
details.append(f"客户名称:{customer.value}")
if time_range and time_range.value:
details.append(f"时间:{time_range.value}")
if location and location.value:
details.append(f"地点:{location.value}")
if amount and amount.value:
details.append(f"金额:{amount.value}")
if claim_groups and len(claim_groups) > 1:
details.append(f"建议拆分为 {len(claim_groups)} 张报销单")
if details:
return f"{summary} {''.join(details)}"
return summary
@staticmethod
def _should_skip_model_answer(
payload: UserAgentRequest,
review_payload: UserAgentReviewPayload | None,
) -> bool:
if review_payload is None:
return False
return payload.ontology.scenario == "expense" and (
payload.ontology.intent == "draft"
or int(payload.context_json.get("attachment_count") or 0) > 0
)
def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]: def _build_rule_citations(self, payload: UserAgentRequest) -> list[UserAgentCitation]:
domain = self._resolve_domain(payload.ontology.scenario) domain = self._resolve_domain(payload.ontology.scenario)
items = self.asset_service.list_assets( items = self.asset_service.list_assets(
@@ -516,6 +993,45 @@ class UserAgentService:
return [] return []
return [str(name) for name in names if str(name).strip()] return [str(name) for name in names if str(name).strip()]
@staticmethod
def _resolve_attachment_count(payload: UserAgentRequest) -> int:
names = UserAgentService._resolve_attachment_names(payload)
if names:
return len(names)
try:
return max(0, int(payload.context_json.get("attachment_count") or 0))
except (TypeError, ValueError):
return 0
@staticmethod
def _resolve_ocr_documents(payload: UserAgentRequest) -> list[dict[str, object]]:
documents = payload.context_json.get("ocr_documents")
if not isinstance(documents, list):
return []
normalized: list[dict[str, object]] = []
for item in documents[:8]:
if not isinstance(item, dict):
continue
normalized.append(item)
return normalized
@staticmethod
def _resolve_conversation_history(payload: UserAgentRequest) -> list[dict[str, object]]:
history = payload.context_json.get("conversation_history")
if not isinstance(history, list):
return []
normalized: list[dict[str, object]] = []
for item in history[-8:]:
if not isinstance(item, dict):
continue
role = str(item.get("role") or "").strip()
content = str(item.get("content") or "").strip()
if not role or not content:
continue
normalized.append({"role": role, "content": content})
return normalized
@staticmethod @staticmethod
def _resolve_domain(scenario: str) -> str | None: def _resolve_domain(scenario: str) -> str | None:
if scenario == "expense": if scenario == "expense":
@@ -557,3 +1073,210 @@ class UserAgentService:
if len(cleaned) >= 2: if len(cleaned) >= 2:
break break
return "".join(cleaned[:2]) return "".join(cleaned[:2])
def _collect_entity_values(self, payload: UserAgentRequest) -> dict[str, str]:
values = {
"employee_name": "",
"customer": "",
"participants": "",
"amount": "",
"expense_type": "",
"expense_type_code": "",
}
participants: list[str] = []
for item in payload.ontology.entities:
if item.type == "employee" and not values["employee_name"]:
values["employee_name"] = item.value
elif item.type == "customer" and not values["customer"]:
values["customer"] = item.value
elif item.type == "amount" and item.role != "threshold" and not values["amount"]:
values["amount"] = f"{item.value}" if "" not in item.value else item.value
elif item.type == "expense_type" and not values["expense_type_code"]:
values["expense_type_code"] = item.normalized_value
values["expense_type"] = EXPENSE_TYPE_LABELS.get(
item.normalized_value,
item.value,
)
elif item.type in {"participant", "person"} and item.value.strip():
participants.append(item.value.strip())
if participants:
values["participants"] = "".join(dict.fromkeys(participants))
return values
def _format_time_range(self, payload: UserAgentRequest) -> str:
time_range = payload.ontology.time_range
if time_range.raw:
return time_range.raw
if time_range.start_date and time_range.end_date:
if time_range.start_date == time_range.end_date:
return time_range.start_date
return f"{time_range.start_date}{time_range.end_date}"
return ""
def _resolve_location_value(self, payload: UserAgentRequest) -> str:
request_context = payload.context_json.get("request_context")
if isinstance(request_context, dict):
for key in ("city", "location"):
value = str(request_context.get(key) or "").strip()
if value:
return value
city_match = re.search(r"去(?P<city>[\u4e00-\u9fa5]{2,8})(?:出差|拜访|参会|见客户|客户现场)", payload.message)
if city_match:
return city_match.group("city").strip()
if "客户现场" in payload.message.replace(" ", ""):
return "客户现场"
return ""
def _make_slot_card(
self,
*,
key: str,
value: str,
source: str,
confidence: float,
missing_slots: set[str],
required: bool = True,
) -> UserAgentReviewSlotCard:
is_missing = key in missing_slots or not str(value).strip()
return UserAgentReviewSlotCard(
key=key,
label=SLOT_LABELS.get(key, key),
value=str(value or "").strip(),
source=source,
confidence=confidence,
required=required,
confirmed=not is_missing and source in {"user_text", "page_context", "upload"},
status="missing" if is_missing else "identified" if source == "user_text" else "inferred",
hint=f"建议补充 {SLOT_LABELS.get(key, key)}"
if is_missing and required
else "",
)
def _classify_document(
self,
item: dict[str, object],
payload: UserAgentRequest,
) -> dict[str, str]:
text = " ".join(
[
str(item.get("filename") or ""),
str(item.get("summary") or ""),
str(item.get("text") or ""),
]
).lower()
compact = text.replace(" ", "")
expense_type_code = self._collect_entity_values(payload).get("expense_type_code", "")
has_customer = bool(self._collect_entity_values(payload).get("customer"))
if any(keyword in compact for keyword in ("机票", "航班", "火车", "高铁", "行程单")):
return {
"document_type": "travel_ticket",
"expense_type": "travel",
"group_code": "travel",
"scene_label": "差旅票据",
}
if any(keyword in compact for keyword in ("酒店", "住宿", "宾馆")):
return {
"document_type": "hotel_invoice",
"expense_type": "hotel",
"group_code": "travel",
"scene_label": "住宿票据",
}
if any(keyword in compact for keyword in ("打车", "出租车", "滴滴", "网约车", "过路费", "停车")):
return {
"document_type": "transport_receipt",
"expense_type": "transport",
"group_code": "travel",
"scene_label": "交通票据",
}
if any(keyword in compact for keyword in ("", "饭店", "酒楼", "酒家", "餐饮", "meal")):
group_code = "entertainment" if expense_type_code == "entertainment" or has_customer else "meal"
return {
"document_type": "meal_receipt",
"expense_type": group_code,
"group_code": group_code,
"scene_label": "餐饮票据",
}
return {
"document_type": "other",
"expense_type": expense_type_code or "other",
"group_code": self._normalize_group_code(expense_type_code or "other"),
"scene_label": "其他票据",
}
@staticmethod
def _normalize_group_code(expense_type_code: str) -> str:
if expense_type_code in {"travel", "hotel", "transport"}:
return "travel"
if expense_type_code in {"entertainment", "meal"}:
return expense_type_code
return "other"
def _extract_document_fields(self, item: dict[str, object]) -> dict[str, str]:
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
fields: dict[str, str] = {}
amount_match = AMOUNT_TEXT_PATTERN.search(text)
if amount_match:
fields["金额"] = f"{amount_match.group(1)}"
date_match = DATE_TEXT_PATTERN.search(text)
if date_match:
fields["时间"] = date_match.group(1)
merchant = self._extract_document_merchant_name(item)
if merchant:
fields["商户/酒店"] = merchant
return fields
@staticmethod
def _extract_document_merchant_name(item: dict[str, object]) -> str:
text = " ".join([str(item.get("summary") or ""), str(item.get("text") or "")]).strip()
for keyword in ("酒店", "宾馆", "饭店", "酒楼", "餐厅", "航空", "铁路", "滴滴"):
if keyword in text:
return keyword
return ""
@staticmethod
def _extract_amount_from_card(card: UserAgentReviewDocumentCard) -> float:
for item in card.fields:
if item.label != "金额":
continue
try:
return float(str(item.value).replace("", "").strip())
except ValueError:
return 0.0
return 0.0
def _resolve_amount_value(self, payload: UserAgentRequest) -> float:
for item in payload.ontology.entities:
if item.type == "amount" and item.role != "threshold":
try:
return float(item.normalized_value)
except ValueError:
return 0.0
return 0.0
def _sum_ocr_amounts(self, ocr_documents: list[dict[str, object]]) -> float:
total = 0.0
for item in ocr_documents:
fields = self._extract_document_fields(item)
amount_text = str(fields.get("金额") or "").replace("", "").strip()
if not amount_text:
continue
try:
total += float(amount_text)
except ValueError:
continue
return total
def _infer_expense_type_from_documents(
self,
payload: UserAgentRequest,
ocr_documents: list[dict[str, object]],
) -> str:
labels: list[str] = []
for item in ocr_documents:
classified = self._classify_document(item, payload)
label = GROUP_SCENE_LABELS.get(classified["group_code"], "")
if label and label not in labels:
labels.append(label)
return " + ".join(labels[:3])