"""Tests for the steward LangGraph runtime and the endpoint fallbacks to the legacy agents."""
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from app.api.v1.endpoints import steward as steward_endpoint
|
|||
|
|
from app.core.config import get_settings
|
|||
|
|
from app.schemas.steward import (
|
|||
|
|
StewardRuntimeDecisionRequest,
|
|||
|
|
StewardSlotDecisionRequest,
|
|||
|
|
)
|
|||
|
|
from app.services.runtime_chat import (
|
|||
|
|
RuntimeChatCallTrace,
|
|||
|
|
RuntimeChatToolCall,
|
|||
|
|
RuntimeToolCallResult,
|
|||
|
|
)
|
|||
|
|
from app.services.steward_graph_runtime import StewardGraphRuntime
|
|||
|
|
from app.services.steward_runtime_decision_agent import STEWARD_RUNTIME_DECISION_FUNCTION_NAME
|
|||
|
|
from app.services.steward_slot_decision_agent import STEWARD_SLOT_DECISION_FUNCTION_NAME
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _FakeRuntime:
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
payloads: dict[str, dict[str, Any] | None] | None = None,
|
|||
|
|
*,
|
|||
|
|
fail_functions: set[str] | None = None,
|
|||
|
|
) -> None:
|
|||
|
|
self.payloads = payloads or {}
|
|||
|
|
self.fail_functions = fail_functions or set()
|
|||
|
|
self.called_functions: list[str] = []
|
|||
|
|
self.last_messages: list[dict[str, Any]] = []
|
|||
|
|
|
|||
|
|
def complete_with_tool_call(self, messages, tools, tool_choice, **kwargs):
|
|||
|
|
function_name = str(tool_choice["function"]["name"])
|
|||
|
|
self.called_functions.append(function_name)
|
|||
|
|
self.last_messages = messages
|
|||
|
|
if function_name in self.fail_functions:
|
|||
|
|
raise RuntimeError(f"{function_name} failed")
|
|||
|
|
payload = self.payloads.get(function_name)
|
|||
|
|
if payload is None:
|
|||
|
|
return RuntimeToolCallResult(tool_call=None, calls=[])
|
|||
|
|
return RuntimeToolCallResult(
|
|||
|
|
tool_call=RuntimeChatToolCall(name=function_name, arguments=payload),
|
|||
|
|
calls=[
|
|||
|
|
RuntimeChatCallTrace(
|
|||
|
|
slot=function_name,
|
|||
|
|
provider="fake",
|
|||
|
|
model="fake",
|
|||
|
|
attempt=1,
|
|||
|
|
status="succeeded",
|
|||
|
|
)
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _FailingGraphRuntime:
|
|||
|
|
def __init__(self, runtime_chat_service) -> None:
|
|||
|
|
self.runtime_chat_service = runtime_chat_service
|
|||
|
|
|
|||
|
|
def decide_slot(self, request):
|
|||
|
|
raise RuntimeError("langgraph runtime unavailable")
|
|||
|
|
|
|||
|
|
def decide_runtime(self, request):
|
|||
|
|
raise RuntimeError("langgraph runtime unavailable")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture(autouse=True)
def _clear_settings_cache():
    """Clear the ``get_settings`` cache before and after every test.

    Being autouse, this wraps each test in the module, so a cached settings
    object is never carried from one test into the next.
    """
    # Setup: drop whatever was cached before this test started.
    get_settings.cache_clear()

    yield

    # Teardown: drop whatever this test caused to be cached.
    get_settings.cache_clear()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_graph_runtime_routes_slot_decision_through_langgraph_tool_node() -> None:
    """A slot decision supplied by the model's tool call is reported as an LLM decision."""
    slot_payload = {
        "next_action": "ask_user",
        "required_fields": ["expense_type", "time_range", "location", "reason", "transport_mode"],
        "missing_fields": ["transport_mode"],
        "question": "请问您这次打算怎么出行?",
        "options": [
            {"field_key": "transport_mode", "label": "火车", "value": "火车"},
            {"field_key": "transport_mode", "label": "飞机", "value": "飞机"},
        ],
        "rationale": "出行方式会影响交通费用测算。",
    }
    runtime = _FakeRuntime({STEWARD_SLOT_DECISION_FUNCTION_NAME: slot_payload})

    known_fields = {
        "expense_type": "travel",
        "time_range": "2026-02-20 至 2026-02-23",
        "location": "上海",
        "reason": "国网仿生产服务器部署",
    }
    request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="2026-02-20 至 2026-02-23,上海出差,国网仿生产服务器部署",
        ontology_fields=known_fields,
        missing_fields=["transport_mode"],
    )

    graph_runtime = StewardGraphRuntime(runtime)
    result = graph_runtime.decide_slot(request)

    assert result.decision_source == "llm_function_call"
    assert result.next_action == "ask_user"
    assert result.missing_fields == ["transport_mode"]
    # Exactly one model call, for the slot-decision function.
    assert runtime.called_functions == [STEWARD_SLOT_DECISION_FUNCTION_NAME]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_graph_runtime_slot_graph_falls_back_when_tool_node_fails() -> None:
    """When the slot tool call raises, the rule fallback answers and a failed trace is recorded."""

    def _is_failed_slot_trace(trace) -> bool:
        # Matches the trace entry expected for the failed graph slot decision.
        return trace.get("slot") == "langgraph_slot_decision" and trace.get("status") == "failed"

    runtime = _FakeRuntime(fail_functions={STEWARD_SLOT_DECISION_FUNCTION_NAME})
    request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="上海出差,辅助国网仿生产部署",
        ontology_fields={
            "expense_type": "travel",
            "location": "上海",
            "reason": "辅助国网仿生产部署",
        },
        missing_fields=["transport_mode"],
    )

    graph_runtime = StewardGraphRuntime(runtime)
    result = graph_runtime.decide_slot(request)

    assert result.decision_source == "rule_fallback"
    assert result.next_action == "ask_user"
    assert result.missing_fields == ["transport_mode"]
    assert any(map(_is_failed_slot_trace, result.model_call_traces))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_graph_runtime_merges_memory_before_runtime_action_node() -> None:
    """Flow state carried in ``context_json`` is used to fill the pending slot.

    The fake returns no tool call, so the decision is expected to come from
    the rule fallback while the model is still invoked exactly once.
    """
    runtime = _FakeRuntime({STEWARD_RUNTIME_DECISION_FUNCTION_NAME: None})

    # Build the remembered conversation state from the innermost flow outwards.
    travel_flow = {
        "flow_id": "travel_application",
        "intent": "travel_application_create",
        "fields": {
            "expense_type": "travel",
            "time_range": "2026-07-02",
            "location": "北京",
            "reason": "客户现场支撑",
        },
        "missing_fields": ["transport_mode"],
    }
    remembered_state = {
        "active_flow": "travel_application",
        "flows": {"travel_application": travel_flow},
    }
    request = StewardRuntimeDecisionRequest(
        user_message="我坐高铁",
        runtime_state={},
        context_json={"conversation_state": {"steward_state": remembered_state}},
    )

    graph_runtime = StewardGraphRuntime(runtime)
    result = graph_runtime.decide_runtime(request)

    assert result.decision_source == "rule_fallback"
    assert result.next_action == "fill_current_slot"
    assert result.field_key == "transport_mode"
    assert result.field_value == "我坐高铁"
    assert result.steward_state["flows"]["travel_application"]["fields"]["transport_mode"] == "我坐高铁"
    assert result.steward_state["flows"]["travel_application"]["missing_fields"] == []
    assert runtime.called_functions == [STEWARD_RUNTIME_DECISION_FUNCTION_NAME]
    # The last message sent to the model must mention the steward state.
    assert "steward_state" in runtime.last_messages[-1]["content"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_graph_runtime_selected_flow_action_node_skips_model_call() -> None:
    """Choosing one of the pending candidate flows is resolved without any model call."""
    runtime = _FakeRuntime()

    pending_confirmation = {
        "status": "pending",
        "candidate_flows": [
            {"flow_id": "travel_application", "label": "补办出差申请"},
            {"flow_id": "travel_reimbursement", "label": "发起费用报销"},
        ],
    }
    travel_flow = {
        "flow_id": "travel_application",
        "intent": "travel_application_create",
        "status": "pending_flow_confirmation",
        "fields": {
            "time_range": "2026-02-20",
            "location": "上海",
            "expense_type": "travel",
            "reason": "辅助国网仿生产环境部署",
        },
        "missing_fields": ["transport_mode"],
    }
    steward_state = {
        "active_flow": "",
        "pending_flow_confirmation": pending_confirmation,
        "flows": {"travel_application": travel_flow},
    }
    request = StewardRuntimeDecisionRequest(
        user_message="补办出差申请",
        runtime_state={"steward_state": steward_state},
    )

    graph_runtime = StewardGraphRuntime(runtime)
    result = graph_runtime.decide_runtime(request)

    assert result.decision_source == "rule_fallback"
    assert result.next_action == "continue_selected_flow"
    assert result.target_task_id == "travel_application"
    assert result.steward_state["active_flow"] == "travel_application"
    # No function was requested from the fake runtime.
    assert runtime.called_functions == []
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_slot_endpoint_helper_falls_back_to_legacy_agent_when_langgraph_runtime_fails(monkeypatch) -> None:
    """With the graph runtime raising, the slot helper still returns a decision.

    The fake runtime is expected to be asked for the slot-decision function
    exactly once, and the final decision comes from the rule fallback.
    """
    # Select the langgraph runtime, then make it fail on every decision.
    monkeypatch.setenv("STEWARD_AGENT_RUNTIME", "langgraph")
    get_settings.cache_clear()
    monkeypatch.setattr(steward_endpoint, "StewardGraphRuntime", _FailingGraphRuntime)

    runtime = _FakeRuntime({STEWARD_SLOT_DECISION_FUNCTION_NAME: None})
    request = StewardSlotDecisionRequest(
        task_type="expense_application",
        user_message="上海出差,辅助国网仿生产部署",
        ontology_fields={
            "expense_type": "travel",
            "location": "上海",
            "reason": "辅助国网仿生产部署",
        },
        missing_fields=["transport_mode"],
    )

    result = steward_endpoint._decide_steward_slot(request, runtime)

    assert result.decision_source == "rule_fallback"
    assert result.next_action == "ask_user"
    assert runtime.called_functions == [STEWARD_SLOT_DECISION_FUNCTION_NAME]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_runtime_endpoint_helper_falls_back_to_legacy_agent_when_langgraph_runtime_fails(monkeypatch) -> None:
    """With the graph runtime raising, the runtime helper still returns a decision.

    The fake runtime is expected to be asked for the runtime-decision function
    exactly once, and the pending steward action is continued via the rule
    fallback.
    """
    # Select the langgraph runtime, then make it fail on every decision.
    monkeypatch.setenv("STEWARD_AGENT_RUNTIME", "langgraph")
    get_settings.cache_clear()
    monkeypatch.setattr(steward_endpoint, "StewardGraphRuntime", _FailingGraphRuntime)

    runtime = _FakeRuntime({STEWARD_RUNTIME_DECISION_FUNCTION_NAME: None})
    pending_action = {
        "message_id": "msg-next-task",
        "target_task_id": "task-reimbursement-meal",
    }
    request = StewardRuntimeDecisionRequest(
        user_message="确认",
        runtime_state={"pending_steward_action": pending_action},
    )

    result = steward_endpoint._decide_steward_runtime(request, runtime)

    assert result.decision_source == "rule_fallback"
    assert result.next_action == "continue_next_task"
    assert result.target_message_id == "msg-next-task"
    assert runtime.called_functions == [STEWARD_RUNTIME_DECISION_FUNCTION_NAME]
|