Files
X-Financial/server/tests/test_steward_graph_runtime.py
caoxiaozhu 5311c99d69 refactor(server): steward 决策链路改用 LangGraph 编排
- 新增 StewardGraphPlannerService,用 LangGraph 状态图编排意图识别→流程判断→模型/规则分支→兜底,替代原 planner 内线性调用
- 新增 StewardGraphRuntimeService 编排运行时决策与槽位决策;StewardActionContracts/Executor 统一动作合约与执行
- steward_intent_agent/application_fact_resolver/runtime_chat 适配图执行器,config 暴露图相关开关
- pyproject/uv.lock 新增 langgraph 依赖
- 新增 graph_planner/graph_runtime/action_executor 测试,更新 intent_agent/planner/fact_resolver/runtime_chat/reimbursement 测试
2026-06-24 21:58:35 +08:00

269 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from typing import Any
import pytest
from app.api.v1.endpoints import steward as steward_endpoint
from app.core.config import get_settings
from app.schemas.steward import (
StewardRuntimeDecisionRequest,
StewardSlotDecisionRequest,
)
from app.services.runtime_chat import (
RuntimeChatCallTrace,
RuntimeChatToolCall,
RuntimeToolCallResult,
)
from app.services.steward_graph_runtime import StewardGraphRuntime
from app.services.steward_runtime_decision_agent import STEWARD_RUNTIME_DECISION_FUNCTION_NAME
from app.services.steward_slot_decision_agent import STEWARD_SLOT_DECISION_FUNCTION_NAME
class _FakeRuntime:
def __init__(
self,
payloads: dict[str, dict[str, Any] | None] | None = None,
*,
fail_functions: set[str] | None = None,
) -> None:
self.payloads = payloads or {}
self.fail_functions = fail_functions or set()
self.called_functions: list[str] = []
self.last_messages: list[dict[str, Any]] = []
def complete_with_tool_call(self, messages, tools, tool_choice, **kwargs):
function_name = str(tool_choice["function"]["name"])
self.called_functions.append(function_name)
self.last_messages = messages
if function_name in self.fail_functions:
raise RuntimeError(f"{function_name} failed")
payload = self.payloads.get(function_name)
if payload is None:
return RuntimeToolCallResult(tool_call=None, calls=[])
return RuntimeToolCallResult(
tool_call=RuntimeChatToolCall(name=function_name, arguments=payload),
calls=[
RuntimeChatCallTrace(
slot=function_name,
provider="fake",
model="fake",
attempt=1,
status="succeeded",
)
],
)
class _FailingGraphRuntime:
def __init__(self, runtime_chat_service) -> None:
self.runtime_chat_service = runtime_chat_service
def decide_slot(self, request):
raise RuntimeError("langgraph runtime unavailable")
def decide_runtime(self, request):
raise RuntimeError("langgraph runtime unavailable")
@pytest.fixture(autouse=True)
def _clear_settings_cache():
    """Clear the ``get_settings`` cache before and after every test.

    Some tests in this module set ``STEWARD_AGENT_RUNTIME`` through
    ``monkeypatch.setenv``; clearing on both sides of the test keeps a cached
    settings object from carrying over between tests.
    """
    # Setup: drop anything cached before this test runs.
    get_settings.cache_clear()
    yield
    # Teardown: drop anything cached while this test ran.
    get_settings.cache_clear()
def test_graph_runtime_routes_slot_decision_through_langgraph_tool_node() -> None:
    """A scripted slot-decision tool payload is returned as the slot decision.

    The fake runtime answers the slot-decision function with an ``ask_user``
    payload. The result must be attributed to ``llm_function_call`` and the
    fake must have been asked for exactly that one function.
    """
    transport_options = [
        {"field_key": "transport_mode", "label": "火车", "value": "火车"},
        {"field_key": "transport_mode", "label": "飞机", "value": "飞机"},
    ]
    scripted_decision = {
        "next_action": "ask_user",
        "required_fields": [
            "expense_type",
            "time_range",
            "location",
            "reason",
            "transport_mode",
        ],
        "missing_fields": ["transport_mode"],
        "question": "请问您这次打算怎么出行?",
        "options": transport_options,
        "rationale": "出行方式会影响交通费用测算。",
    }
    fake_runtime = _FakeRuntime({STEWARD_SLOT_DECISION_FUNCTION_NAME: scripted_decision})
    graph_runtime = StewardGraphRuntime(fake_runtime)
    known_fields = {
        "expense_type": "travel",
        "time_range": "2026-02-20 至 2026-02-23",
        "location": "上海",
        "reason": "国网仿生产服务器部署",
    }
    slot_request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="2026-02-20 至 2026-02-23上海出差国网仿生产服务器部署",
        ontology_fields=known_fields,
        missing_fields=["transport_mode"],
    )

    decision = graph_runtime.decide_slot(slot_request)

    assert decision.decision_source == "llm_function_call"
    assert decision.next_action == "ask_user"
    assert decision.missing_fields == ["transport_mode"]
    assert fake_runtime.called_functions == [STEWARD_SLOT_DECISION_FUNCTION_NAME]
def test_graph_runtime_slot_graph_falls_back_when_tool_node_fails() -> None:
    """A raising slot-decision tool call yields a rule-fallback decision.

    The fake runtime raises for the slot-decision function. The result must
    come from ``rule_fallback``, still ask the user for ``transport_mode``,
    and include a failed ``langgraph_slot_decision`` entry in its traces.
    """
    failing_runtime = _FakeRuntime(fail_functions={STEWARD_SLOT_DECISION_FUNCTION_NAME})
    graph_runtime = StewardGraphRuntime(failing_runtime)
    slot_request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="上海出差,辅助国网仿生产部署",
        ontology_fields={
            "expense_type": "travel",
            "location": "上海",
            "reason": "辅助国网仿生产部署",
        },
        missing_fields=["transport_mode"],
    )

    decision = graph_runtime.decide_slot(slot_request)

    assert decision.decision_source == "rule_fallback"
    assert decision.next_action == "ask_user"
    assert decision.missing_fields == ["transport_mode"]
    # The failure must be visible in the traces as a failed graph slot step.
    assert any(
        entry.get("slot") == "langgraph_slot_decision" and entry.get("status") == "failed"
        for entry in decision.model_call_traces
    )
def test_graph_runtime_merges_memory_before_runtime_action_node() -> None:
    """Flow state given only via ``context_json`` drives the runtime decision.

    ``runtime_state`` is empty; the active travel flow and its missing
    ``transport_mode`` field live under
    ``context_json["conversation_state"]["steward_state"]``. The fake runtime
    returns no tool call, so the result must be a ``rule_fallback`` that fills
    the slot with the user message. The last message sent to the fake must
    mention ``steward_state``.
    """
    travel_flow = {
        "flow_id": "travel_application",
        "intent": "travel_application_create",
        "fields": {
            "expense_type": "travel",
            "time_range": "2026-07-02",
            "location": "北京",
            "reason": "客户现场支撑",
        },
        "missing_fields": ["transport_mode"],
    }
    remembered_state = {
        "active_flow": "travel_application",
        "flows": {"travel_application": travel_flow},
    }
    silent_runtime = _FakeRuntime({STEWARD_RUNTIME_DECISION_FUNCTION_NAME: None})
    graph_runtime = StewardGraphRuntime(silent_runtime)
    runtime_request = StewardRuntimeDecisionRequest(
        user_message="我坐高铁",
        runtime_state={},
        context_json={"conversation_state": {"steward_state": remembered_state}},
    )

    decision = graph_runtime.decide_runtime(runtime_request)

    assert decision.decision_source == "rule_fallback"
    assert decision.next_action == "fill_current_slot"
    assert decision.field_key == "transport_mode"
    assert decision.field_value == "我坐高铁"
    flow_after = decision.steward_state["flows"]["travel_application"]
    assert flow_after["fields"]["transport_mode"] == "我坐高铁"
    assert flow_after["missing_fields"] == []
    assert silent_runtime.called_functions == [STEWARD_RUNTIME_DECISION_FUNCTION_NAME]
    assert "steward_state" in silent_runtime.last_messages[-1]["content"]
def test_graph_runtime_selected_flow_action_node_skips_model_call() -> None:
    """Selecting a pending candidate flow is resolved without a model call.

    The runtime state holds a pending flow confirmation with two candidates,
    and the user message equals the label of the ``travel_application``
    candidate. The result must be a ``rule_fallback`` that continues that
    flow and makes it active, and the fake runtime must never be called.
    """
    candidate_flows = [
        {"flow_id": "travel_application", "label": "补办出差申请"},
        {"flow_id": "travel_reimbursement", "label": "发起费用报销"},
    ]
    pending_travel_flow = {
        "flow_id": "travel_application",
        "intent": "travel_application_create",
        "status": "pending_flow_confirmation",
        "fields": {
            "time_range": "2026-02-20",
            "location": "上海",
            "expense_type": "travel",
            "reason": "辅助国网仿生产环境部署",
        },
        "missing_fields": ["transport_mode"],
    }
    steward_state = {
        "active_flow": "",
        "pending_flow_confirmation": {
            "status": "pending",
            "candidate_flows": candidate_flows,
        },
        "flows": {"travel_application": pending_travel_flow},
    }
    unscripted_runtime = _FakeRuntime()
    graph_runtime = StewardGraphRuntime(unscripted_runtime)
    runtime_request = StewardRuntimeDecisionRequest(
        user_message="补办出差申请",
        runtime_state={"steward_state": steward_state},
    )

    decision = graph_runtime.decide_runtime(runtime_request)

    assert decision.decision_source == "rule_fallback"
    assert decision.next_action == "continue_selected_flow"
    assert decision.target_task_id == "travel_application"
    assert decision.steward_state["active_flow"] == "travel_application"
    # No tool function may have been requested on this path.
    assert unscripted_runtime.called_functions == []
def test_slot_endpoint_helper_falls_back_to_legacy_agent_when_langgraph_runtime_fails(monkeypatch) -> None:
    """``_decide_steward_slot`` still answers when the graph runtime raises.

    ``STEWARD_AGENT_RUNTIME`` is set to ``langgraph`` and the endpoint's
    ``StewardGraphRuntime`` is swapped for a stub whose methods raise. The
    helper must return a ``rule_fallback`` ``ask_user`` decision, and the
    fake runtime must record exactly one slot-decision tool request
    (presumably from the legacy agent path, per the test name).
    """
    monkeypatch.setenv("STEWARD_AGENT_RUNTIME", "langgraph")
    # Settings are cached; clear so the patched environment variable is read.
    get_settings.cache_clear()
    monkeypatch.setattr(steward_endpoint, "StewardGraphRuntime", _FailingGraphRuntime)
    silent_runtime = _FakeRuntime({STEWARD_SLOT_DECISION_FUNCTION_NAME: None})
    slot_request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="上海出差,辅助国网仿生产部署",
        ontology_fields={
            "expense_type": "travel",
            "location": "上海",
            "reason": "辅助国网仿生产部署",
        },
        missing_fields=["transport_mode"],
    )

    decision = steward_endpoint._decide_steward_slot(slot_request, silent_runtime)

    assert decision.decision_source == "rule_fallback"
    assert decision.next_action == "ask_user"
    assert silent_runtime.called_functions == [STEWARD_SLOT_DECISION_FUNCTION_NAME]
def test_runtime_endpoint_helper_falls_back_to_legacy_agent_when_langgraph_runtime_fails(monkeypatch) -> None:
    """``_decide_steward_runtime`` still answers when the graph runtime raises.

    ``STEWARD_AGENT_RUNTIME`` is set to ``langgraph`` and the endpoint's
    ``StewardGraphRuntime`` is swapped for a stub whose methods raise. With a
    pending steward action in the runtime state and the user message "确认",
    the helper must return a ``rule_fallback`` ``continue_next_task`` decision
    targeting the pending message. The fake runtime must record exactly one
    runtime-decision tool request (presumably from the legacy agent path, per
    the test name).
    """
    monkeypatch.setenv("STEWARD_AGENT_RUNTIME", "langgraph")
    # Settings are cached; clear so the patched environment variable is read.
    get_settings.cache_clear()
    monkeypatch.setattr(steward_endpoint, "StewardGraphRuntime", _FailingGraphRuntime)
    silent_runtime = _FakeRuntime({STEWARD_RUNTIME_DECISION_FUNCTION_NAME: None})
    pending_action = {
        "message_id": "msg-next-task",
        "target_task_id": "task-reimbursement-meal",
    }
    runtime_request = StewardRuntimeDecisionRequest(
        user_message="确认",
        runtime_state={"pending_steward_action": pending_action},
    )

    decision = steward_endpoint._decide_steward_runtime(runtime_request, silent_runtime)

    assert decision.decision_source == "rule_fallback"
    assert decision.next_action == "continue_next_task"
    assert decision.target_message_id == "msg-next-task"
    assert silent_runtime.called_functions == [STEWARD_RUNTIME_DECISION_FUNCTION_NAME]