Files
X-Financial/server/tests/test_steward_graph_planner.py

235 lines
8.8 KiB
Python
Raw Normal View History

from __future__ import annotations
from app.api.v1.endpoints import steward as steward_endpoint
from app.core.config import get_settings
from app.schemas.steward import StewardPlanRequest
from app.services.steward_graph_planner import StewardGraphPlannerService
from app.services.steward_intent_agent import StewardIntentAgentResult
from app.services.steward_planner import StewardPlannerService
class GraphTravelApplicationIntentAgent:
    """Fake intent agent returning one Shanghai travel application to save as a draft.

    ``calls`` counts ``detect`` invocations so tests can assert how many times
    the planner consulted the (fake) model.
    """

    def __init__(self) -> None:
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned ``StewardIntentAgentResult`` with a ``save_draft`` task.

        The arguments mirror the signature the planner calls; they are ignored.
        """
        self.calls += 1
        ontology_fields = {
            "time_range": "2026-02-20 至 2026-02-23",
            "location": "上海",
            "expense_type": "差旅",
            "reason": "国网仿生产服务器部署",
            # NOTE(review): this fake emits "transport_type" whereas the submit
            # fake emits "transport_mode"; the save-draft test still expects
            # ontology_fields["transport_mode"] == "train", so presumably the
            # planner normalizes this alias — confirm this is intentional.
            "transport_type": "火车",
        }
        task = {
            "task_type": "expense_application",
            "title": "上海出差申请",
            "summary": (
                "2026-02-20 至 2026-02-23 前往上海,"
                "国网仿生产服务器部署,火车出行。"
            ),
            "requested_action": "save_draft",
            "confidence": 0.95,
            "ontology_fields": ontology_fields,
            "missing_fields": [],
        }
        thinking_event = {
            "stage": "task_split",
            "title": "识别出差申请草稿",
            "content": "模型识别到用户要创建上海出差申请,并保存草稿。",
        }
        trace = {
            "slot": "main",
            "provider": "MiniMax",
            "model": "abab-test",
            "attempt": 1,
            "status": "succeeded",
        }
        payload = {
            "thinking_events": [thinking_event],
            "tasks": [task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=payload, model_call_traces=[trace])
class GraphSubmitTravelApplicationIntentAgent:
    """Fake intent agent returning one Shanghai travel application to submit directly.

    ``calls`` counts ``detect`` invocations so tests can assert how many times
    the planner consulted the (fake) model.
    """

    def __init__(self) -> None:
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        """Return a canned ``StewardIntentAgentResult`` with a ``submit`` task.

        The arguments mirror the signature the planner calls; they are ignored.
        """
        self.calls += 1
        ontology_fields = {
            "time_range": "2026-02-20 至 2026-02-23",
            "location": "上海",
            "expense_type": "差旅",
            "reason": "辅助国网仿生产服务器部署",
            "transport_mode": "火车",
        }
        task = {
            "task_type": "expense_application",
            "title": "上海出差申请",
            "summary": (
                "2026-02-20 至 2026-02-23 前往上海,"
                "辅助国网仿生产服务器部署,火车出行。"
            ),
            "requested_action": "submit",
            "confidence": 0.96,
            "ontology_fields": ontology_fields,
            "missing_fields": [],
        }
        thinking_event = {
            "stage": "task_split",
            "title": "识别出差申请提交",
            "content": "模型识别到用户要创建上海出差申请,并直接提交。",
        }
        trace = {
            "slot": "main",
            "provider": "MiniMax",
            "model": "abab-test",
            "attempt": 1,
            "status": "succeeded",
        }
        payload = {
            "thinking_events": [thinking_event],
            "tasks": [task],
            "attachment_groups": [],
        }
        return StewardIntentAgentResult(payload=payload, model_call_traces=[trace])
class GraphEmptyIntentAgent:
    """Fake intent agent simulating a model that produced no tool call.

    ``detect`` always returns ``None``; the tests in this module use it to
    drive the planner onto its rule-based fallback path.
    """

    def __init__(self) -> None:
        self.calls = 0

    def detect(self, request, *, base_date, canonical_fields):
        """Record the invocation and report that nothing was detected."""
        self.calls = self.calls + 1
        return None
def test_langgraph_planner_preserves_llm_save_draft_plan() -> None:
    """An LLM-produced save-draft plan is kept, with normalized fields and traces."""
    agent = GraphTravelApplicationIntentAgent()
    planner = StewardGraphPlannerService(intent_agent=agent)
    plan = planner.build_plan(
        StewardPlanRequest(
            message="2026-02-20 至 2026-02-23上海出差国网仿生产服务器部署火车保存草稿",
            client_now_iso="2026-02-10T09:00:00+08:00",
        )
    )

    assert agent.calls == 1
    assert plan.planning_source == "llm_function_call"
    first_task = plan.tasks[0]
    assert first_task.requested_action == "save_draft"
    fields = first_task.ontology_fields
    assert fields["time_range"] == "2026-02-20 至 2026-02-23"
    # The fake returns "火车"; the planner is expected to emit "train".
    assert fields["transport_mode"] == "train"
    assert plan.model_call_traces[0]["provider"] == "MiniMax"
def test_langgraph_planner_builds_submit_action_steps_for_application() -> None:
    """A submit intent expands into the full application pipeline.

    The last step is ``submit_application`` and must wait for user confirmation.
    """
    agent = GraphSubmitTravelApplicationIntentAgent()
    planner = StewardGraphPlannerService(intent_agent=agent)
    plan = planner.build_plan(
        StewardPlanRequest(
            message="2026-02-20 至 2026-02-23去上海出差辅助国网仿生产服务器部署交通火车直接提交",
            client_now_iso="2026-02-10T09:00:00+08:00",
        )
    )
    expected_step_types = [
        "fill_application_fields",
        "build_application_preview",
        "validate_required_fields",
        "run_duplicate_precheck",
        "submit_application",
    ]

    assert agent.calls == 1
    assert plan.planning_source == "llm_function_call"
    assert plan.action_steps[0].action_type == "detect_intent"
    task_steps = plan.tasks[0].action_steps
    assert [step.action_type for step in task_steps] == expected_step_types
    assert task_steps[0].payload["ontology_fields"]["location"] == "上海"
    final_step = task_steps[-1]
    assert final_step.requires_confirmation is True
    assert final_step.status == "pending_confirmation"
def test_langgraph_planner_falls_back_when_model_returns_no_tool_call() -> None:
    """With no model result, the rule fallback still recovers the save-draft task.

    No model call traces are reported on this path.
    """
    agent = GraphEmptyIntentAgent()
    planner = StewardGraphPlannerService(intent_agent=agent)
    plan = planner.build_plan(
        StewardPlanRequest(
            message="2026-02-20 至 2026-02-23上海出差国网仿生产服务器部署火车保存草稿",
            client_now_iso="2026-02-10T09:00:00+08:00",
        )
    )

    assert agent.calls == 1
    assert plan.planning_source == "rule_fallback"
    first_task = plan.tasks[0]
    assert first_task.requested_action == "save_draft"
    fields = first_task.ontology_fields
    assert fields["time_range"] == "2026-02-20 至 2026-02-23"
    assert fields["transport_mode"] == "train"
    assert plan.model_call_traces == []
def test_langgraph_planner_rule_fallback_builds_save_draft_action_steps() -> None:
    """The rule fallback plans the draft pipeline ending in a planned save step."""
    planner = StewardGraphPlannerService(intent_agent=GraphEmptyIntentAgent())
    plan = planner.build_plan(
        StewardPlanRequest(
            message="2026-02-20 至 2026-02-23上海出差国网仿生产服务器部署火车保存草稿",
            client_now_iso="2026-02-10T09:00:00+08:00",
        )
    )
    expected_step_types = [
        "fill_application_fields",
        "build_application_preview",
        "validate_required_fields",
        "save_application_draft",
    ]

    assert plan.planning_source == "rule_fallback"
    first_task = plan.tasks[0]
    assert first_task.requested_action == "save_draft"
    task_steps = first_task.action_steps
    assert [step.action_type for step in task_steps] == expected_step_types
    assert task_steps[-1].status == "planned"
def _build_planner_for_runtime(monkeypatch, runtime: str | None):
    """Build the endpoint's planner with ``STEWARD_AGENT_RUNTIME`` set to *runtime*.

    Passing ``None`` removes the variable so the default runtime is exercised.
    ``get_settings`` is cached, so its cache is cleared before building (to
    pick up the patched environment) and again afterwards, even on failure, so
    later tests do not reuse settings cached under this test's environment.
    """
    if runtime is None:
        monkeypatch.delenv("STEWARD_AGENT_RUNTIME", raising=False)
    else:
        monkeypatch.setenv("STEWARD_AGENT_RUNTIME", runtime)
    get_settings.cache_clear()
    try:
        # A bare object() stands in for the DB session; presumably the builder
        # only stores it, since these tests only inspect the planner's type.
        return steward_endpoint._build_steward_planner(db=object())
    finally:
        get_settings.cache_clear()


def test_build_steward_planner_uses_langgraph_runtime_when_enabled(monkeypatch) -> None:
    """An explicit ``langgraph`` runtime selects the graph planner."""
    planner = _build_planner_for_runtime(monkeypatch, "langgraph")
    assert isinstance(planner, StewardGraphPlannerService)


def test_build_steward_planner_defaults_to_langgraph_runtime(monkeypatch) -> None:
    """With the runtime variable unset, the graph planner is the default."""
    planner = _build_planner_for_runtime(monkeypatch, None)
    assert isinstance(planner, StewardGraphPlannerService)


def test_build_steward_planner_can_fall_back_to_legacy_runtime(monkeypatch) -> None:
    """A ``legacy`` runtime selects the original planner service."""
    planner = _build_planner_for_runtime(monkeypatch, "legacy")
    assert isinstance(planner, StewardPlannerService)