feat(steward): off_topic 场景细分与引导回复

- 将业务无关输入细分为 greeting / meaningless / off_business 三类场景
- 新增 StewardOffTopicAgent,用 function calling 生成管家语气引导回复
- steward endpoint 与 user_agent_application 串联 off_topic 引导话术
- 补充 planner 与 user agent 的 off_topic 覆盖测试
This commit is contained in:
caoxiaozhu
2026-06-18 22:12:10 +08:00
parent 127d603e7d
commit a6674a1e76
6 changed files with 952 additions and 50 deletions

View File

@@ -7,9 +7,11 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.models.financial_record import ExpenseClaim
from app.schemas.common import ErrorResponse
from app.schemas.steward import (
StewardPlanRequest,
@@ -21,9 +23,12 @@ from app.schemas.steward import (
StewardThinkingEvent,
)
from app.services.agent_conversations import AgentConversationService
from app.services.expense_claim_draft_flow import APPROVED_APPLICATION_LINK_STATUSES
from app.services.expense_claims import ExpenseClaimService
from app.services.runtime_chat import RuntimeChatService
from app.services.steward_flow_state import StewardFlowStateService
from app.services.steward_intent_agent import StewardIntentAgent
from app.services.steward_off_topic_agent import StewardOffTopicAgent
from app.services.steward_planner import StewardPlannerService
from app.services.steward_runtime_decision_agent import StewardRuntimeDecisionAgent
from app.services.steward_slot_decision_agent import StewardSlotDecisionAgent
@@ -46,8 +51,10 @@ DbSession = Annotated[Session, Depends(get_db)]
)
def create_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StewardPlanResponse:
try:
plan = _build_steward_planner(db).build_plan(payload)
return _attach_conversation_state(db, payload, plan)
planner = _build_steward_planner(db)
hydrated_payload = _hydrate_required_application_gate(db, payload, planner)
plan = planner.build_plan(hydrated_payload)
return _attach_conversation_state(db, hydrated_payload, plan)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
@@ -103,15 +110,16 @@ async def _iter_steward_plan_events(
event_id="intent_agent_stream_start",
stage="stream_start",
title="读取用户输入",
content="我先判断这句话里是否同时包含申请报销或附件归集事项,再决定处理顺序",
content="我先识别申请/报销边界;如果是历史差旅描述,会先查询可关联申请单再决定流程",
status="running",
).model_dump(mode="json"),
)
await asyncio.sleep(0)
try:
plan = planner.build_plan(payload)
plan = _attach_conversation_state(db, payload, plan)
hydrated_payload = _hydrate_required_application_gate(db, payload, planner)
plan = planner.build_plan(hydrated_payload)
plan = _attach_conversation_state(db, hydrated_payload, plan)
except ValueError as exc:
yield _encode_stream_event("error", {"message": str(exc)})
return
@@ -128,11 +136,179 @@ def _encode_stream_event(event: str, data: dict[str, Any]) -> str:
def _build_steward_planner(db: Session) -> StewardPlannerService:
runtime_chat = RuntimeChatService(db)
return StewardPlannerService(
intent_agent=StewardIntentAgent(RuntimeChatService(db)),
intent_agent=StewardIntentAgent(runtime_chat),
off_topic_agent=StewardOffTopicAgent(runtime_chat),
)
def _hydrate_required_application_gate(
db: Session,
payload: StewardPlanRequest,
planner: StewardPlannerService,
) -> StewardPlanRequest:
context_json = dict(payload.context_json or {})
required_gate = context_json.get("required_application_gate")
if isinstance(required_gate, dict):
travel_gate = required_gate.get("travel")
if isinstance(travel_gate, dict) and travel_gate.get("checked") is True:
return payload
message = planner._clean_text(payload.message)
base_date = planner._resolve_base_date(payload.client_now_iso, context_json)
if not planner._looks_like_ambiguous_travel_flow(message, base_date, payload):
return payload
candidates = _query_required_application_gate_candidates(db, payload, context_json)
next_required_gate = dict(required_gate) if isinstance(required_gate, dict) else {}
next_required_gate["travel"] = {
"checked": True,
"candidate_count": len(candidates),
"candidates": candidates[:5],
}
return payload.model_copy(
update={
"context_json": {
**context_json,
"required_application_gate": next_required_gate,
}
}
)
def _query_required_application_gate_candidates(
db: Session,
payload: StewardPlanRequest,
context_json: dict[str, Any],
) -> list[dict[str, Any]]:
identities = _resolve_required_application_gate_identities(payload, context_json)
stmt = (
select(ExpenseClaim)
.order_by(ExpenseClaim.submitted_at.desc(), ExpenseClaim.updated_at.desc())
.limit(200)
)
candidates: list[dict[str, Any]] = []
for claim in db.scalars(stmt).all():
if not ExpenseClaimService._is_expense_application_claim(claim):
continue
if str(claim.status or "").strip().lower() not in APPROVED_APPLICATION_LINK_STATUSES:
continue
if identities and not _claim_matches_required_application_identity(claim, identities):
continue
if not _claim_matches_required_travel_application(claim, payload.message):
continue
candidates.append(_serialize_required_application_gate_candidate(claim))
return candidates
def _resolve_required_application_gate_identities(
payload: StewardPlanRequest,
context_json: dict[str, Any],
) -> set[str]:
raw_values = [
payload.user_id,
context_json.get("user_id"),
context_json.get("username"),
context_json.get("name"),
context_json.get("employee_id"),
context_json.get("employee_no"),
context_json.get("employee_name"),
]
identities: set[str] = set()
for value in raw_values:
normalized = _normalize_required_application_identity(value)
if normalized:
identities.add(normalized)
return identities
def _normalize_required_application_identity(value: Any) -> str:
return str(value or "").strip().casefold()
def _claim_matches_required_application_identity(claim: ExpenseClaim, identities: set[str]) -> bool:
claim_identities = {
_normalize_required_application_identity(claim.employee_id),
_normalize_required_application_identity(claim.employee_name),
}
claim_identities.discard("")
return bool(claim_identities.intersection(identities))
def _claim_matches_required_travel_application(claim: ExpenseClaim, message: str) -> bool:
expense_type = str(claim.expense_type or "").strip().casefold()
if any(token in expense_type for token in ("travel", "trip", "差旅", "出差")):
return True
claim_text = "".join(
[
str(claim.reason or ""),
str(claim.location or ""),
str(claim.claim_no or ""),
]
)
if "差旅" in claim_text or "出差" in claim_text:
return True
compact_message = str(message or "").replace(" ", "")
location = str(claim.location or "").strip()
return bool(location and location in compact_message and "出差" in compact_message)
def _serialize_required_application_gate_candidate(claim: ExpenseClaim) -> dict[str, Any]:
business_time = _resolve_required_application_business_time(claim)
status_label = _resolve_required_application_status_label(claim.status)
return {
"id": str(claim.id or "").strip(),
"claim_no": str(claim.claim_no or "").strip(),
"reason": str(claim.reason or "").strip(),
"location": str(claim.location or "").strip(),
"business_time": business_time,
"status_label": status_label,
"application_claim_id": str(claim.id or "").strip(),
"application_claim_no": str(claim.claim_no or "").strip(),
"application_reason": str(claim.reason or "").strip(),
"application_location": str(claim.location or "").strip(),
"application_business_time": business_time,
"application_status_label": status_label,
}
def _resolve_required_application_business_time(claim: ExpenseClaim) -> str:
for flag in list(claim.risk_flags_json or []):
if not isinstance(flag, dict):
continue
for source in (
flag,
flag.get("application_detail"),
flag.get("applicationDetail"),
flag.get("review_form_values"),
flag.get("reviewFormValues"),
):
if not isinstance(source, dict):
continue
value = (
source.get("application_business_time")
or source.get("applicationBusinessTime")
or source.get("business_time")
or source.get("businessTime")
)
if str(value or "").strip():
return str(value).strip()
if claim.occurred_at is not None:
return claim.occurred_at.date().isoformat()
return ""
def _resolve_required_application_status_label(status: Any) -> str:
normalized = str(status or "").strip().lower()
return {
"approved": "已审批",
"completed": "已完成",
}.get(normalized, normalized)
def _attach_conversation_state(
db: Session,
payload: StewardPlanRequest,