- 将业务无关输入细分为 greeting / meaningless / off_business 三类场景 - 新增 StewardOffTopicAgent,用 function calling 生成管家语气引导回复 - steward endpoint 与 user_agent_application 串联 off_topic 引导话术 - 补充 planner 与 user agent 的 off_topic 覆盖测试
438 lines
16 KiB
Python
438 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from collections.abc import AsyncIterator
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.deps import get_db
|
|
from app.models.financial_record import ExpenseClaim
|
|
from app.schemas.common import ErrorResponse
|
|
from app.schemas.steward import (
|
|
StewardPlanRequest,
|
|
StewardPlanResponse,
|
|
StewardRuntimeDecisionRequest,
|
|
StewardRuntimeDecisionResponse,
|
|
StewardSlotDecisionRequest,
|
|
StewardSlotDecisionResponse,
|
|
StewardThinkingEvent,
|
|
)
|
|
from app.services.agent_conversations import AgentConversationService
|
|
from app.services.expense_claim_draft_flow import APPROVED_APPLICATION_LINK_STATUSES
|
|
from app.services.expense_claims import ExpenseClaimService
|
|
from app.services.runtime_chat import RuntimeChatService
|
|
from app.services.steward_flow_state import StewardFlowStateService
|
|
from app.services.steward_intent_agent import StewardIntentAgent
|
|
from app.services.steward_off_topic_agent import StewardOffTopicAgent
|
|
from app.services.steward_planner import StewardPlannerService
|
|
from app.services.steward_runtime_decision_agent import StewardRuntimeDecisionAgent
|
|
from app.services.steward_slot_decision_agent import StewardSlotDecisionAgent
|
|
|
|
# Router for the steward endpoints; every route declared on it is served under the "/steward" prefix.
router = APIRouter(prefix="/steward")
|
|
# Dependency alias: a SQLAlchemy Session resolved through FastAPI's Depends(get_db).
DbSession = Annotated[Session, Depends(get_db)]
|
|
|
|
|
|
@router.post(
    "/plans",
    response_model=StewardPlanResponse,
    summary="生成小财管家任务计划",
    description="把首页自然语言和附件元信息拆解为可确认、可追踪、可分派的财务任务计划。",
    responses={
        status.HTTP_400_BAD_REQUEST: {
            "model": ErrorResponse,
            "description": "请求缺少任务描述,无法生成小财管家计划。",
        }
    },
)
def create_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StewardPlanResponse:
    """Build a steward plan for the request and persist it on the conversation.

    The request is first enriched with the travel application gate lookup,
    then handed to the planner; the resulting plan is returned with the
    conversation id and merged steward state attached.

    Raises:
        HTTPException: 400 when any step raises ``ValueError``; the error
            text is forwarded as the response detail.
    """
    try:
        steward_planner = _build_steward_planner(db)
        request = _hydrate_required_application_gate(db, payload, steward_planner)
        built_plan = steward_planner.build_plan(request)
        return _attach_conversation_state(db, request, built_plan)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
|
|
|
|
|
|
@router.post(
    "/slot-decisions",
    response_model=StewardSlotDecisionResponse,
    summary="判断小财管家当前任务字段缺口",
    description="结合当前任务、本体字段和用户上下文,使用 function calling 判断下一步应先追问用户还是展示核对结果。",
)
def create_steward_slot_decision(
    payload: StewardSlotDecisionRequest,
    db: DbSession,
) -> StewardSlotDecisionResponse:
    """Delegate the slot decision to ``StewardSlotDecisionAgent`` and return its answer.

    The agent is constructed per request around a ``RuntimeChatService``
    bound to the request's database session.
    """
    chat_service = RuntimeChatService(db)
    slot_agent = StewardSlotDecisionAgent(chat_service)
    return slot_agent.decide(payload)
|
|
|
|
|
|
@router.post(
    "/runtime-decisions",
    response_model=StewardRuntimeDecisionResponse,
    summary="判断小财管家运行时下一步动作",
    description="结合任务队列、当前结构化结果和用户输入,使用 function calling 判断应提交当前单据、继续下一任务、补字段或重新规划。",
)
def create_steward_runtime_decision(
    payload: StewardRuntimeDecisionRequest,
    db: DbSession,
) -> StewardRuntimeDecisionResponse:
    """Decide the next runtime action and persist the resulting steward state.

    The request is first completed with any steward state stored on the
    conversation, then passed to ``StewardRuntimeDecisionAgent``; the
    decision's steward state (if any) is written back to the conversation.
    """
    request = _hydrate_runtime_decision_payload(db, payload)
    runtime_agent = StewardRuntimeDecisionAgent(RuntimeChatService(db))
    outcome = runtime_agent.decide(request)
    return _attach_runtime_conversation_state(db, request, outcome)
|
|
|
|
|
|
@router.post(
    "/plans/stream",
    summary="流式生成小财管家任务计划",
    description="以 NDJSON 逐条返回小财管家的过程摘要事件,最后返回完整任务计划。",
)
async def stream_steward_plan(payload: StewardPlanRequest, db: DbSession) -> StreamingResponse:
    """Return a streaming NDJSON response that emits plan events.

    The planner is built eagerly; the actual planning work happens lazily
    inside the event generator while the response body is being consumed.

    NOTE(review): the generator keeps using ``db`` while streaming — confirm
    that ``get_db`` keeps the session open until the stream has finished.
    """
    steward_planner = _build_steward_planner(db)
    event_stream = _iter_steward_plan_events(payload, steward_planner, db)
    return StreamingResponse(event_stream, media_type="application/x-ndjson")
|
|
|
|
|
|
async def _iter_steward_plan_events(
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
    db: Session,
) -> AsyncIterator[str]:
    """Yield NDJSON lines describing the planning progress, ending with the plan.

    Emission order:
      1. one fixed "thinking" event announcing that input is being read;
      2. on ``ValueError`` during planning, a single "error" event and stop;
      3. otherwise one "thinking" event per ``plan.thinking_events`` entry,
         each followed by a 0.6 s pause;
      4. a final "plan" event carrying the full plan.
    """
    opening_event = StewardThinkingEvent(
        event_id="intent_agent_stream_start",
        stage="stream_start",
        title="读取用户输入",
        content="我先识别申请/报销边界;如果是历史差旅描述,会先查询可关联申请单再决定流程。",
        status="running",
    )
    yield _encode_stream_event("thinking", opening_event.model_dump(mode="json"))
    # Give control back to the event loop once before the synchronous planning
    # calls below (presumably so the first event can be flushed — confirm).
    await asyncio.sleep(0)

    try:
        request = _hydrate_required_application_gate(db, payload, planner)
        draft_plan = planner.build_plan(request)
        plan = _attach_conversation_state(db, request, draft_plan)
    except ValueError as exc:
        yield _encode_stream_event("error", {"message": str(exc)})
        return

    for thinking_event in plan.thinking_events:
        yield _encode_stream_event("thinking", thinking_event.model_dump(mode="json"))
        # Fixed pause after every thinking event, including the last one.
        await asyncio.sleep(0.6)

    yield _encode_stream_event("plan", plan.model_dump(mode="json"))
|
|
|
|
|
|
def _encode_stream_event(event: str, data: dict[str, Any]) -> str:
    """Serialize one stream event as a newline-terminated JSON line.

    The line has the shape ``{"event": <event>, "data": <data>}``; non-ASCII
    characters are written as-is rather than escaped.
    """
    envelope = {"event": event, "data": data}
    encoded = json.dumps(envelope, ensure_ascii=False)
    return f"{encoded}\n"
|
|
|
|
|
|
def _build_steward_planner(db: Session) -> StewardPlannerService:
    """Assemble a planner whose intent and off-topic agents share one chat service."""
    chat_service = RuntimeChatService(db)
    intent_agent = StewardIntentAgent(chat_service)
    off_topic_agent = StewardOffTopicAgent(chat_service)
    return StewardPlannerService(
        intent_agent=intent_agent,
        off_topic_agent=off_topic_agent,
    )
|
|
|
|
|
|
def _hydrate_required_application_gate(
    db: Session,
    payload: StewardPlanRequest,
    planner: StewardPlannerService,
) -> StewardPlanRequest:
    """Attach travel application candidates to the request context when needed.

    The payload is returned untouched when the travel gate is already marked
    as checked, or when the planner does not consider the message an
    ambiguous travel flow. Otherwise a copy is returned whose
    ``context_json["required_application_gate"]["travel"]`` records the
    lookup result (total count plus at most five candidates).
    """
    context = dict(payload.context_json or {})
    existing_gate = context.get("required_application_gate")
    gate_is_dict = isinstance(existing_gate, dict)

    if gate_is_dict:
        travel_entry = existing_gate.get("travel")
        if isinstance(travel_entry, dict) and travel_entry.get("checked") is True:
            # The lookup was already performed for this request context.
            return payload

    cleaned_message = planner._clean_text(payload.message)
    reference_date = planner._resolve_base_date(payload.client_now_iso, context)
    is_ambiguous = planner._looks_like_ambiguous_travel_flow(cleaned_message, reference_date, payload)
    if not is_ambiguous:
        return payload

    matches = _query_required_application_gate_candidates(db, payload, context)
    # Keep any non-travel gate entries the caller already supplied.
    updated_gate = dict(existing_gate) if gate_is_dict else {}
    updated_gate["travel"] = {
        "checked": True,
        "candidate_count": len(matches),
        "candidates": matches[:5],
    }
    updated_context = {**context, "required_application_gate": updated_gate}
    return payload.model_copy(update={"context_json": updated_context})
|
|
|
|
|
|
def _query_required_application_gate_candidates(
    db: Session,
    payload: StewardPlanRequest,
    context_json: dict[str, Any],
) -> list[dict[str, Any]]:
    """Collect serialized travel application claims that the request may link to.

    Loads the 200 most recently submitted/updated claims and keeps those that
    are expense applications, have an approved-for-linking status, belong to
    one of the requester's identities (when any identity is known), and look
    like travel applications for the given message.

    NOTE(review): the 200-row limit is applied before the Python-side
    filters, so matching claims older than the newest 200 rows are never
    considered. When no identity can be resolved the identity filter is
    skipped entirely — confirm both are intended.
    """
    identities = _resolve_required_application_gate_identities(payload, context_json)
    newest_claims = (
        select(ExpenseClaim)
        .order_by(ExpenseClaim.submitted_at.desc(), ExpenseClaim.updated_at.desc())
        .limit(200)
    )
    rows = db.scalars(newest_claims).all()
    return [
        _serialize_required_application_gate_candidate(row)
        for row in rows
        if ExpenseClaimService._is_expense_application_claim(row)
        and str(row.status or "").strip().lower() in APPROVED_APPLICATION_LINK_STATUSES
        and (not identities or _claim_matches_required_application_identity(row, identities))
        and _claim_matches_required_travel_application(row, payload.message)
    ]
|
|
|
|
|
|
def _resolve_required_application_gate_identities(
    payload: StewardPlanRequest,
    context_json: dict[str, Any],
) -> set[str]:
    """Return the normalized, non-empty identity strings known for the requester.

    Sources are ``payload.user_id`` plus the ``user_id``, ``username``,
    ``name``, ``employee_id``, ``employee_no`` and ``employee_name`` entries
    of ``context_json``.
    """
    context_keys = (
        "user_id",
        "username",
        "name",
        "employee_id",
        "employee_no",
        "employee_name",
    )
    raw_values = [payload.user_id, *(context_json.get(key) for key in context_keys)]
    normalized = (_normalize_required_application_identity(raw) for raw in raw_values)
    return {identity for identity in normalized if identity}
|
|
|
|
|
|
def _normalize_required_application_identity(value: Any) -> str:
    """Return ``value`` as a trimmed, case-folded string; falsy values become ``""``."""
    text = str(value or "")
    return text.strip().casefold()
|
|
|
|
|
|
def _claim_matches_required_application_identity(claim: ExpenseClaim, identities: set[str]) -> bool:
    """Tell whether the claim's employee id or name is one of ``identities``.

    Both claim fields are normalized the same way as the identities; empty
    values never count as a match.
    """
    raw_fields = (claim.employee_id, claim.employee_name)
    claim_identities = {_normalize_required_application_identity(raw) for raw in raw_fields} - {""}
    return not claim_identities.isdisjoint(identities)
|
|
|
|
|
|
def _claim_matches_required_travel_application(claim: ExpenseClaim, message: str) -> bool:
    """Decide whether ``claim`` looks like a travel application for ``message``.

    A claim matches when any of the following holds, checked in order:
      1. its expense type contains a travel keyword
         ("travel", "trip", "差旅", "出差");
      2. its reason, location or claim number contains "差旅" or "出差";
      3. its location appears in the message and the message mentions "出差".

    For rule 3 all whitespace is removed from both the message and the
    location before comparing, so a location such as "New York" still
    matches a message that spells it with or without spaces.

    Args:
        claim: Object exposing ``expense_type``, ``reason``, ``location``
            and ``claim_no``; ``None`` values are treated as empty.
        message: The user's free-text request; ``None`` is treated as empty.

    Returns:
        ``True`` when one of the rules above applies, otherwise ``False``.
    """
    expense_type = str(claim.expense_type or "").strip().casefold()
    if any(token in expense_type for token in ("travel", "trip", "差旅", "出差")):
        return True

    claim_text = "".join(
        str(part or "") for part in (claim.reason, claim.location, claim.claim_no)
    )
    if "差旅" in claim_text or "出差" in claim_text:
        return True

    # Compact both sides identically; str.split() with no argument splits on
    # any Unicode whitespace, so tabs and full-width spaces are removed too.
    compact_message = "".join(str(message or "").split())
    compact_location = "".join(str(claim.location or "").split())
    return bool(
        compact_location
        and compact_location in compact_message
        and "出差" in compact_message
    )
|
|
|
|
|
|
def _serialize_required_application_gate_candidate(claim: ExpenseClaim) -> dict[str, Any]:
    """Turn a claim into the candidate dict stored on the travel gate.

    Every value is exposed twice: under a short key (``id``, ``claim_no``,
    ``reason``, ``location``, ``business_time``, ``status_label``) and under
    the matching ``application_``-prefixed key.
    """
    business_time = _resolve_required_application_business_time(claim)
    status_label = _resolve_required_application_status_label(claim.status)

    short_form = {
        "id": str(claim.id or "").strip(),
        "claim_no": str(claim.claim_no or "").strip(),
        "reason": str(claim.reason or "").strip(),
        "location": str(claim.location or "").strip(),
        "business_time": business_time,
        "status_label": status_label,
    }
    prefixed_form = {
        "application_claim_id": short_form["id"],
        "application_claim_no": short_form["claim_no"],
        "application_reason": short_form["reason"],
        "application_location": short_form["location"],
        "application_business_time": business_time,
        "application_status_label": status_label,
    }
    return {**short_form, **prefixed_form}
|
|
|
|
|
|
def _resolve_required_application_business_time(claim: ExpenseClaim) -> str:
    """Return the business time recorded for an application claim.

    Each dict entry of ``claim.risk_flags_json`` is searched in order. Within
    an entry, the entry itself is inspected first, followed by its
    ``application_detail``, ``applicationDetail``, ``review_form_values``
    and ``reviewFormValues`` sub-dicts. In every inspected dict the keys
    ``application_business_time``, ``applicationBusinessTime``,
    ``business_time`` and ``businessTime`` are tried in that priority; the
    first value that is non-blank after stripping wins. A blank or
    whitespace-only value therefore no longer hides a usable value stored
    under a lower-priority key of the same dict.

    Args:
        claim: Object exposing ``risk_flags_json`` and ``occurred_at``.

    Returns:
        The stripped business-time text; otherwise the ISO date of
        ``claim.occurred_at`` when it is set; otherwise ``""``.
    """
    time_keys = (
        "application_business_time",
        "applicationBusinessTime",
        "business_time",
        "businessTime",
    )
    nested_keys = (
        "application_detail",
        "applicationDetail",
        "review_form_values",
        "reviewFormValues",
    )

    for flag in list(claim.risk_flags_json or []):
        if not isinstance(flag, dict):
            continue
        sources = (flag, *(flag.get(key) for key in nested_keys))
        for source in sources:
            if not isinstance(source, dict):
                continue
            for key in time_keys:
                text = str(source.get(key) or "").strip()
                if text:
                    return text

    if claim.occurred_at is not None:
        return claim.occurred_at.date().isoformat()
    return ""
|
|
|
|
|
|
def _resolve_required_application_status_label(status: Any) -> str:
    """Map a claim status to its display label.

    ``approved`` and ``completed`` (compared trimmed and lower-cased) have
    dedicated labels; any other status is returned in its normalized form.
    """
    status_key = str(status or "").strip().lower()
    labels = {
        "approved": "已审批",
        "completed": "已完成",
    }
    if status_key in labels:
        return labels[status_key]
    return status_key
|
|
|
|
|
|
def _attach_conversation_state(
    db: Session,
    payload: StewardPlanRequest,
    plan: StewardPlanResponse,
) -> StewardPlanResponse:
    """Persist the plan on the conversation and return the plan with ids attached.

    Steps:
      1. resolve or create the conversation for the request context
         (``session_type`` defaults to ``"steward"``);
      2. merge the plan into the current steward state and store it;
      3. append the user message and the assistant's plan summary;
      4. return a copy of ``plan`` carrying ``conversation_id`` and
         ``steward_state``.
    """
    context = dict(payload.context_json or {})
    session_type = str(context.get("session_type") or "steward").strip()
    context["session_type"] = session_type or "steward"

    conversations = AgentConversationService(db)
    conversation = conversations.get_or_create_conversation(
        conversation_id=_resolve_conversation_id(context),
        user_id=payload.user_id,
        source="user_message",
        context_json=context,
    )

    previous_state = _resolve_current_steward_state(conversation.state_json, context)
    merged_state = StewardFlowStateService().merge_plan(previous_state, plan)

    refreshed = conversations.update_state(
        conversation_id=conversation.conversation_id,
        run_id=None,
        scenario="steward",
        intent="plan",
        context_json={**context, "steward_state": merged_state},
    )
    if refreshed:
        # Prefer the updated record; keep the earlier one if nothing truthy came back.
        conversation = refreshed

    conversations.append_message(
        conversation_id=conversation.conversation_id,
        role="user",
        content=payload.message,
        message_json={"source": "steward_plan_request"},
    )
    conversations.append_message(
        conversation_id=conversation.conversation_id,
        role="assistant",
        content=plan.summary,
        message_json={
            "source": "steward_plan_response",
            "plan_id": plan.plan_id,
            "steward_state": merged_state,
        },
    )

    plan_updates = {
        "conversation_id": conversation.conversation_id,
        "steward_state": merged_state,
    }
    return plan.model_copy(update=plan_updates)
|
|
|
|
|
|
def _attach_runtime_conversation_state(
    db: Session,
    payload: StewardRuntimeDecisionRequest,
    decision: StewardRuntimeDecisionResponse,
) -> StewardRuntimeDecisionResponse:
    """Store the decision's steward state on the conversation, if possible.

    Nothing is written when the decision carries no non-empty dict state or
    when the request context names no conversation. The decision itself is
    always returned unchanged.
    """
    new_state = decision.steward_state
    if isinstance(new_state, dict) and new_state:
        context = dict(payload.context_json or {})
        conversation_id = _resolve_conversation_id(context)
        if conversation_id:
            AgentConversationService(db).update_state(
                conversation_id=conversation_id,
                run_id=None,
                scenario="steward",
                intent="runtime_decision",
                context_json={**context, "steward_state": new_state},
            )
    return decision
|
|
|
|
|
|
def _hydrate_runtime_decision_payload(
    db: Session,
    payload: StewardRuntimeDecisionRequest,
) -> StewardRuntimeDecisionRequest:
    """Fill in the steward state from the stored conversation when the request lacks it.

    The payload is returned as-is when it already carries a non-empty
    ``steward_state`` (in ``runtime_state`` or ``context_json``), when no
    conversation id is present, or when the conversation has no stored
    state. Otherwise a copy is returned with the stored state placed in
    ``runtime_state["steward_state"]`` and
    ``context_json["conversation_state"]["steward_state"]``.
    """
    context = dict(payload.context_json or {})
    runtime = dict(payload.runtime_state or {})

    # A state supplied by the caller wins over whatever is stored.
    for holder in (runtime, context):
        supplied = holder.get("steward_state")
        if isinstance(supplied, dict) and supplied:
            return payload

    conversation_id = _resolve_conversation_id(context)
    if not conversation_id:
        return payload

    conversation = AgentConversationService(db).get_conversation(conversation_id)
    persisted = None
    if conversation and isinstance(conversation.state_json, dict):
        persisted = conversation.state_json.get("steward_state")
    if not isinstance(persisted, dict) or not persisted:
        return payload

    runtime["steward_state"] = persisted
    conversation_state = dict(context.get("conversation_state") or {})
    conversation_state["steward_state"] = persisted
    context["conversation_state"] = conversation_state

    return payload.model_copy(
        update={
            "runtime_state": runtime,
            "context_json": context,
        }
    )
|
|
|
|
|
|
def _resolve_conversation_id(context_json: dict[str, Any]) -> str | None:
    """Read the conversation id from the context, or ``None`` when absent/blank.

    ``conversation_id`` is preferred; ``conversationId`` is used when the
    former is falsy. The value is stringified and stripped.
    """
    raw_id = context_json.get("conversation_id") or context_json.get("conversationId") or ""
    cleaned_id = str(raw_id).strip()
    if cleaned_id:
        return cleaned_id
    return None
|
|
|
|
|
|
def _resolve_current_steward_state(
    conversation_state: dict[str, Any] | None,
    context_json: dict[str, Any],
) -> dict[str, Any]:
    """Pick the steward state to merge the next plan into.

    A non-empty dict stored under ``conversation_state["steward_state"]``
    takes precedence. Otherwise the state sent in ``context_json``
    (``steward_state``, falling back to ``stewardState``) is used when it is
    a dict; in every other case an empty dict is returned.
    """
    persisted = None
    if isinstance(conversation_state, dict):
        persisted = conversation_state.get("steward_state")
    if isinstance(persisted, dict) and persisted:
        return persisted

    supplied = context_json.get("steward_state") or context_json.get("stewardState")
    if isinstance(supplied, dict):
        return supplied
    return {}
|