Files
X-Financial/server/src/app/api/v1/endpoints/steward.py
caoxiaozhu a6674a1e76 feat(steward): off_topic 场景细分与引导回复
- 将业务无关输入细分为 greeting / meaningless / off_business 三类场景
- 新增 StewardOffTopicAgent,用 function calling 生成管家语气引导回复
- steward endpoint 与 user_agent_application 串联 off_topic 引导话术
- 补充 planner 与 user agent 的 off_topic 覆盖测试
2026-06-18 22:12:10 +08:00

438 lines
16 KiB
Python

from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.api.deps import get_db
from app.models.financial_record import ExpenseClaim
from app.schemas.common import ErrorResponse
from app.schemas.steward import (
StewardPlanRequest,
StewardPlanResponse,
StewardRuntimeDecisionRequest,
StewardRuntimeDecisionResponse,
StewardSlotDecisionRequest,
StewardSlotDecisionResponse,
StewardThinkingEvent,
)
from app.services.agent_conversations import AgentConversationService
from app.services.expense_claim_draft_flow import APPROVED_APPLICATION_LINK_STATUSES
from app.services.expense_claims import ExpenseClaimService
from app.services.runtime_chat import RuntimeChatService
from app.services.steward_flow_state import StewardFlowStateService
from app.services.steward_intent_agent import StewardIntentAgent
from app.services.steward_off_topic_agent import StewardOffTopicAgent
from app.services.steward_planner import StewardPlannerService
from app.services.steward_runtime_decision_agent import StewardRuntimeDecisionAgent
from app.services.steward_slot_decision_agent import StewardSlotDecisionAgent
# Router for the steward ("小财管家") endpoints; every route below is mounted under /steward.
router = APIRouter(prefix="/steward")
# Reusable dependency alias: injects the SQLAlchemy Session provided by `get_db`.
DbSession = Annotated[Session, Depends(get_db)]
@router.post(
    "/plans",
    response_model=StewardPlanResponse,
    summary="生成小财管家任务计划",
    description="把首页自然语言和附件元信息拆解为可确认、可追踪、可分派的财务任务计划。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少任务描述,无法生成小财管家计划。",
        }
    },
)
def create_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StewardPlanResponse:
    """Build a steward task plan for ``payload`` and record it in the conversation.

    Three steps run in order:
    1. Enrich the request with the travel "required application" gate.
    2. Build the plan.
    3. Persist the plan into the agent conversation.

    Any ``ValueError`` raised by these steps is reported to the client as HTTP 400.
    """
    try:
        steward_planner = _build_steward_planner(db)
        gated_request = _hydrate_required_application_gate(db, payload, steward_planner)
        built_plan = steward_planner.build_plan(gated_request)
        response = _attach_conversation_state(db, gated_request, built_plan)
    except ValueError as error:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
    return response
@router.post(
    "/slot-decisions",
    response_model=StewardSlotDecisionResponse,
    summary="判断小财管家当前任务字段缺口",
    description="结合当前任务、本体字段和用户上下文,使用 function calling 判断下一步应先追问用户还是展示核对结果。",
)
def create_steward_slot_decision(
    payload: StewardSlotDecisionRequest,
    db: DbSession,
) -> StewardSlotDecisionResponse:
    """Delegate the "ask for missing fields vs. show results" choice to the slot agent."""
    chat_service = RuntimeChatService(db)
    slot_agent = StewardSlotDecisionAgent(chat_service)
    return slot_agent.decide(payload)
@router.post(
    "/runtime-decisions",
    response_model=StewardRuntimeDecisionResponse,
    summary="判断小财管家运行时下一步动作",
    description="结合任务队列、当前结构化结果和用户输入,使用 function calling 判断应提交当前单据、继续下一任务、补字段或重新规划。",
)
def create_steward_runtime_decision(
    payload: StewardRuntimeDecisionRequest,
    db: DbSession,
) -> StewardRuntimeDecisionResponse:
    """Decide the steward's next runtime action and persist the resulting state.

    Any steward state previously stored on the conversation is first restored
    into the request. The runtime-decision agent is then consulted, and its
    returned steward state is written back to the conversation.
    """
    enriched_request = _hydrate_runtime_decision_payload(db, payload)
    runtime_agent = StewardRuntimeDecisionAgent(RuntimeChatService(db))
    runtime_decision = runtime_agent.decide(enriched_request)
    return _attach_runtime_conversation_state(db, enriched_request, runtime_decision)
@router.post(
    "/plans/stream",
    summary="流式生成小财管家任务计划",
    description="以 NDJSON 逐条返回小财管家的过程摘要事件,最后返回完整任务计划。",
)
async def stream_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StreamingResponse:
    """Stream the steward planning process as newline-delimited JSON events."""
    steward_planner = _build_steward_planner(db)
    event_stream = _iter_steward_plan_events(payload, steward_planner, db)
    return StreamingResponse(event_stream, media_type="application/x-ndjson")
async def _iter_steward_plan_events(
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
    db: Session,
) -> AsyncIterator[str]:
    """Yield NDJSON lines describing the steward planning process.

    Events are emitted in this order:
    - an initial ``thinking`` event;
    - then either a single ``error`` event, when planning raises ``ValueError``,
    - or every thinking event of the built plan, paced 0.6 s apart, followed by
      the final ``plan`` event.

    The planning pipeline is synchronous and touches the database. It covers
    the gate lookup, the planner call and conversation persistence, and
    presumably also the LLM behind ``RuntimeChatService``. It therefore runs in
    a worker thread via ``asyncio.to_thread`` rather than blocking the event
    loop for every other request.

    Args:
        payload: The incoming plan request.
        planner: Planner used both for the gate heuristics and plan building.
        db: Session used by the synchronous pipeline. It is only ever touched
            from one thread at a time.

    Yields:
        One encoded NDJSON line per event.
    """
    yield _encode_stream_event(
        "thinking",
        StewardThinkingEvent(
            event_id="intent_agent_stream_start",
            stage="stream_start",
            title="读取用户输入",
            content="我先识别申请/报销边界;如果是历史差旅描述,会先查询可关联申请单再决定流程。",
            status="running",
        ).model_dump(mode="json"),
    )
    # Let the first event flush to the client before the heavy work starts.
    await asyncio.sleep(0)
    try:
        plan = await asyncio.to_thread(_build_plan_with_conversation_state, payload, planner, db)
    except ValueError as exc:
        yield _encode_stream_event("error", {"message": str(exc)})
        return
    for event in plan.thinking_events:
        yield _encode_stream_event("thinking", event.model_dump(mode="json"))
        # Deliberate pacing so the UI can render each thinking step.
        await asyncio.sleep(0.6)
    yield _encode_stream_event("plan", plan.model_dump(mode="json"))


def _build_plan_with_conversation_state(
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
    db: Session,
) -> StewardPlanResponse:
    """Run the blocking hydrate -> plan -> persist pipeline used by the stream."""
    hydrated_payload = _hydrate_required_application_gate(db, payload, planner)
    plan = planner.build_plan(hydrated_payload)
    return _attach_conversation_state(db, hydrated_payload, plan)
def _encode_stream_event(event: str, data: dict[str, Any]) -> str:
return json.dumps({"event": event, "data": data}, ensure_ascii=False) + "\n"
def _build_steward_planner(db: Session) -> StewardPlannerService:
    """Create a planner whose intent and off-topic agents share one RuntimeChatService."""
    shared_chat = RuntimeChatService(db)
    intent_agent = StewardIntentAgent(shared_chat)
    off_topic_agent = StewardOffTopicAgent(shared_chat)
    return StewardPlannerService(intent_agent=intent_agent, off_topic_agent=off_topic_agent)
def _hydrate_required_application_gate(
    db: Session,
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
) -> StewardPlanRequest:
    """Attach approved travel-application candidates to the request context.

    ``payload`` is returned untouched in two cases:
    - the travel gate is already marked ``checked`` in
      ``context_json["required_application_gate"]["travel"]``;
    - the planner does not regard the message as an ambiguous travel flow.

    Otherwise a copy is returned whose ``required_application_gate.travel``
    entry records ``checked=True``, the total candidate count and at most the
    first five candidates. Any other keys already present in the gate are
    preserved.
    """
    context = dict(payload.context_json or {})
    existing_gate = context.get("required_application_gate")
    gate_is_dict = isinstance(existing_gate, dict)
    travel_entry = existing_gate.get("travel") if gate_is_dict else None
    if isinstance(travel_entry, dict) and travel_entry.get("checked") is True:
        return payload

    cleaned_message = planner._clean_text(payload.message)
    reference_date = planner._resolve_base_date(payload.client_now_iso, context)
    if not planner._looks_like_ambiguous_travel_flow(cleaned_message, reference_date, payload):
        return payload

    matches = _query_required_application_gate_candidates(db, payload, context)
    updated_gate = dict(existing_gate) if gate_is_dict else {}
    updated_gate["travel"] = {
        "checked": True,
        "candidate_count": len(matches),
        "candidates": matches[:5],
    }
    updated_context = {**context, "required_application_gate": updated_gate}
    return payload.model_copy(update={"context_json": updated_context})
def _query_required_application_gate_candidates(
    db: Session,
    payload: StewardPlanRequest,
    context_json: dict[str, Any],
) -> list[dict[str, Any]]:
    """Return the requesting user's approved travel applications as gate candidates.

    A claim is kept only when all of the following hold:
    - it is an expense *application*;
    - its status is in ``APPROVED_APPLICATION_LINK_STATUSES``;
    - it belongs to one of the caller's resolved identities;
    - it looks travel-related.

    When no identity can be resolved from the request, an empty list is
    returned. Without an identity the ownership filter cannot be applied, and
    scanning unfiltered would expose other employees' applications (reason,
    location, claim numbers) to the caller.

    NOTE(review): only the 200 most recent claims across *all* users are
    scanned before the Python-side filters run. A user's older applications
    can therefore be missed on busy systems.

    Args:
        db: Database session used for the claim query.
        payload: The plan request (identity source and message text).
        context_json: Copy of the request context (additional identity keys).

    Returns:
        Serialized candidates, newest first.
    """
    identities = _resolve_required_application_gate_identities(payload, context_json)
    if not identities:
        return []
    stmt = (
        select(ExpenseClaim)
        .order_by(ExpenseClaim.submitted_at.desc(), ExpenseClaim.updated_at.desc())
        .limit(200)
    )
    candidates: list[dict[str, Any]] = []
    for claim in db.scalars(stmt).all():
        if not ExpenseClaimService._is_expense_application_claim(claim):
            continue
        if str(claim.status or "").strip().lower() not in APPROVED_APPLICATION_LINK_STATUSES:
            continue
        if not _claim_matches_required_application_identity(claim, identities):
            continue
        if not _claim_matches_required_travel_application(claim, payload.message):
            continue
        candidates.append(_serialize_required_application_gate_candidate(claim))
    return candidates
def _resolve_required_application_gate_identities(
payload: StewardPlanRequest,
context_json: dict[str, Any],
) -> set[str]:
raw_values = [
payload.user_id,
context_json.get("user_id"),
context_json.get("username"),
context_json.get("name"),
context_json.get("employee_id"),
context_json.get("employee_no"),
context_json.get("employee_name"),
]
identities: set[str] = set()
for value in raw_values:
normalized = _normalize_required_application_identity(value)
if normalized:
identities.add(normalized)
return identities
def _normalize_required_application_identity(value: Any) -> str:
return str(value or "").strip().casefold()
def _claim_matches_required_application_identity(claim: ExpenseClaim, identities: set[str]) -> bool:
claim_identities = {
_normalize_required_application_identity(claim.employee_id),
_normalize_required_application_identity(claim.employee_name),
}
claim_identities.discard("")
return bool(claim_identities.intersection(identities))
def _claim_matches_required_travel_application(claim: ExpenseClaim, message: str) -> bool:
expense_type = str(claim.expense_type or "").strip().casefold()
if any(token in expense_type for token in ("travel", "trip", "差旅", "出差")):
return True
claim_text = "".join(
[
str(claim.reason or ""),
str(claim.location or ""),
str(claim.claim_no or ""),
]
)
if "差旅" in claim_text or "出差" in claim_text:
return True
compact_message = str(message or "").replace(" ", "")
location = str(claim.location or "").strip()
return bool(location and location in compact_message and "出差" in compact_message)
def _serialize_required_application_gate_candidate(claim: ExpenseClaim) -> dict[str, Any]:
    """Flatten an application claim into the gate's candidate dictionary.

    Every value is emitted twice, under a plain key and under an
    ``application_``-prefixed alias, with identical content. Text fields are
    stringified and trimmed (``None`` becomes ``""``).
    """
    business_time = _resolve_required_application_business_time(claim)
    status_label = _resolve_required_application_status_label(claim.status)
    claim_id = str(claim.id or "").strip()
    claim_no = str(claim.claim_no or "").strip()
    reason = str(claim.reason or "").strip()
    location = str(claim.location or "").strip()
    plain_fields = {
        "id": claim_id,
        "claim_no": claim_no,
        "reason": reason,
        "location": location,
        "business_time": business_time,
        "status_label": status_label,
    }
    prefixed_fields = {
        "application_claim_id": claim_id,
        "application_claim_no": claim_no,
        "application_reason": reason,
        "application_location": location,
        "application_business_time": business_time,
        "application_status_label": status_label,
    }
    return {**plain_fields, **prefixed_fields}
def _resolve_required_application_business_time(claim: ExpenseClaim) -> str:
for flag in list(claim.risk_flags_json or []):
if not isinstance(flag, dict):
continue
for source in (
flag,
flag.get("application_detail"),
flag.get("applicationDetail"),
flag.get("review_form_values"),
flag.get("reviewFormValues"),
):
if not isinstance(source, dict):
continue
value = (
source.get("application_business_time")
or source.get("applicationBusinessTime")
or source.get("business_time")
or source.get("businessTime")
)
if str(value or "").strip():
return str(value).strip()
if claim.occurred_at is not None:
return claim.occurred_at.date().isoformat()
return ""
def _resolve_required_application_status_label(status: Any) -> str:
normalized = str(status or "").strip().lower()
return {
"approved": "已审批",
"completed": "已完成",
}.get(normalized, normalized)
def _attach_conversation_state(
    db: Session,
    payload: StewardPlanRequest,
    plan: StewardPlanResponse,
) -> StewardPlanResponse:
    """Persist a freshly built plan into the steward conversation and return an enriched copy.

    Side effects, in order:
    1. Get or create the conversation.
    2. Store the merged steward state via ``update_state``.
    3. Append the user's message.
    4. Append the plan summary as an assistant message.

    Args:
        db: Database session used by ``AgentConversationService``.
        payload: The (possibly hydrated) plan request. Its ``context_json``
            supplies the conversation id and is stored as conversation context.
        plan: The plan produced by the planner.

    Returns:
        A copy of ``plan`` with ``conversation_id`` and the merged
        ``steward_state`` filled in.
    """
    context_json = dict(payload.context_json or {})
    # Default the session type to "steward"; a blank/whitespace value also falls back to it.
    context_json["session_type"] = str(context_json.get("session_type") or "steward").strip() or "steward"
    conversation_service = AgentConversationService(db)
    conversation = conversation_service.get_or_create_conversation(
        conversation_id=_resolve_conversation_id(context_json),
        user_id=payload.user_id,
        source="user_message",
        context_json=context_json,
    )
    # A non-empty state already stored on the conversation wins over one sent by the client.
    current_state = _resolve_current_steward_state(conversation.state_json, context_json)
    steward_state = StewardFlowStateService().merge_plan(current_state, plan)
    # If update_state returns a falsy value (presumably None — confirm in the service),
    # keep the conversation loaded above.
    conversation = conversation_service.update_state(
        conversation_id=conversation.conversation_id,
        run_id=None,
        scenario="steward",
        intent="plan",
        context_json={
            **context_json,
            "steward_state": steward_state,
        },
    ) or conversation
    conversation_service.append_message(
        conversation_id=conversation.conversation_id,
        role="user",
        content=payload.message,
        message_json={"source": "steward_plan_request"},
    )
    conversation_service.append_message(
        conversation_id=conversation.conversation_id,
        role="assistant",
        content=plan.summary,
        message_json={
            "source": "steward_plan_response",
            "plan_id": plan.plan_id,
            "steward_state": steward_state,
        },
    )
    return plan.model_copy(
        update={
            "conversation_id": conversation.conversation_id,
            "steward_state": steward_state,
        }
    )
def _attach_runtime_conversation_state(
    db: Session,
    payload: StewardRuntimeDecisionRequest,
    decision: StewardRuntimeDecisionResponse,
) -> StewardRuntimeDecisionResponse:
    """Write the decision's steward state to the conversation when possible.

    Persistence happens only when the decision carries a non-empty
    ``steward_state`` dict and the request context names a conversation. The
    decision itself is always returned unchanged.
    """
    new_state = decision.steward_state
    if isinstance(new_state, dict) and new_state:
        request_context = dict(payload.context_json or {})
        target_conversation = _resolve_conversation_id(request_context)
        if target_conversation:
            AgentConversationService(db).update_state(
                conversation_id=target_conversation,
                run_id=None,
                scenario="steward",
                intent="runtime_decision",
                context_json={**request_context, "steward_state": new_state},
            )
    return decision
def _hydrate_runtime_decision_payload(
    db: Session,
    payload: StewardRuntimeDecisionRequest,
) -> StewardRuntimeDecisionRequest:
    """Restore the stored steward state into a runtime-decision request.

    ``payload`` is returned as-is in three cases:
    - ``runtime_state`` or ``context_json`` already carries a non-empty
      ``steward_state`` dict;
    - there is no conversation id;
    - the stored conversation has no usable steward state.

    Otherwise a copy is returned with the stored state injected into both
    ``runtime_state["steward_state"]`` and
    ``context_json["conversation_state"]["steward_state"]``.
    """
    context = dict(payload.context_json or {})
    runtime = dict(payload.runtime_state or {})
    already_present = any(
        isinstance(source.get("steward_state"), dict) and source["steward_state"]
        for source in (runtime, context)
    )
    if already_present:
        return payload

    conversation_id = _resolve_conversation_id(context)
    if not conversation_id:
        return payload

    conversation = AgentConversationService(db).get_conversation(conversation_id)
    stored = None
    if conversation and isinstance(conversation.state_json, dict):
        stored = conversation.state_json.get("steward_state")
    if not isinstance(stored, dict) or not stored:
        return payload

    runtime["steward_state"] = stored
    nested_state = dict(context.get("conversation_state") or {})
    nested_state["steward_state"] = stored
    context["conversation_state"] = nested_state
    return payload.model_copy(update={"runtime_state": runtime, "context_json": context})
def _resolve_conversation_id(context_json: dict[str, Any]) -> str | None:
return str(
context_json.get("conversation_id")
or context_json.get("conversationId")
or ""
).strip() or None
def _resolve_current_steward_state(
conversation_state: dict[str, Any] | None,
context_json: dict[str, Any],
) -> dict[str, Any]:
state_json = conversation_state if isinstance(conversation_state, dict) else {}
stored_state = state_json.get("steward_state")
if isinstance(stored_state, dict) and stored_state:
return stored_state
incoming_state = context_json.get("steward_state") or context_json.get("stewardState")
return incoming_state if isinstance(incoming_state, dict) else {}