# X-Financial/server/tests/test_runtime_chat_service.py
# Tests for RuntimeChatService: main/backup slot failover, slot cooldown,
# per-slot timeouts, and OpenAI-compatible request payloads.

from __future__ import annotations
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.db.base import Base
from app.services import runtime_chat as runtime_chat_module
from app.services.runtime_chat import RuntimeChatService
def build_session_factory() -> sessionmaker[Session]:
    """Create a session factory backed by a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created up front. ``StaticPool``
    together with ``check_same_thread=False`` makes all sessions from this factory
    reuse one connection, so they all see the same in-memory database.
    """
    sqlite_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=sqlite_engine)
    factory = sessionmaker(bind=sqlite_engine, autocommit=False, autoflush=False)
    return factory
def _clear_runtime_chat_cooldown() -> None:
    """Empty the runtime_chat module's slot-cooldown state.

    ``_slot_failure_until`` is module-level, so it outlives individual services
    and tests; clearing it gives each test a clean starting point.
    """
    cooldown_state = runtime_chat_module._slot_failure_until
    cooldown_state.clear()
def test_runtime_chat_fails_over_to_backup_before_retrying_main(monkeypatch) -> None:
    """A failing ``main`` slot hands off to ``backup`` instead of being retried first.

    The slot loader and the completion request are stubbed, so no network call
    is made; ``calls`` records the order in which slots were attempted.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": "MiniMax" if slot == "main" else "GLM",
                    "endpoint": "https://example.com/v1",
                    "model": "main-model" if slot == "main" else "backup-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(
                config, messages, *, max_tokens, temperature, timeout_seconds
            ):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                if config["slot"] == "main":
                    raise RuntimeError("main unavailable")
                return "backup answer"

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            answer = service.complete([{"role": "user", "content": "hello"}])
            assert answer == "backup answer"
            assert calls == ["main", "backup"]
    finally:
        # A failed slot can leave an entry in the module-level cooldown state;
        # clear it on exit so later tests do not inherit a cooled-down "main".
        _clear_runtime_chat_cooldown()
def test_runtime_chat_complete_with_trace_records_slot_failover(monkeypatch) -> None:
    """``complete_with_trace`` records the failed ``main`` call and the successful ``backup`` call.

    Checks per-call status, the provider and error message of the failed call,
    and that ``calls_as_dicts`` exposes the backup model.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": "MiniMax" if slot == "main" else "GLM",
                    "endpoint": "https://example.com/v1",
                    "model": "main-model" if slot == "main" else "backup-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(
                config, messages, *, max_tokens, temperature, timeout_seconds
            ):
                del messages, max_tokens, temperature, timeout_seconds
                if config["slot"] == "main":
                    raise RuntimeError("incorrect api key")
                return "backup answer"

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            result = service.complete_with_trace([{"role": "user", "content": "hello"}])
            assert result.text == "backup answer"
            assert [item.status for item in result.calls] == ["failed", "succeeded"]
            assert result.calls[0].provider == "MiniMax"
            assert result.calls[0].error_message == "incorrect api key"
            assert result.calls_as_dicts()[1]["model"] == "backup-model"
    finally:
        # The failed "main" call can leave module-level cooldown state behind;
        # reset it so test order does not matter.
        _clear_runtime_chat_cooldown()
def test_runtime_chat_does_not_rehit_failed_slots_during_cooldown(monkeypatch) -> None:
    """When every slot fails, ``complete`` returns None and tries each slot only once.

    ``sleep`` in the runtime_chat module is stubbed out so any retry back-off
    does not slow the test down.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(
                config, messages, *, max_tokens, temperature, timeout_seconds
            ):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                raise RuntimeError("unavailable")

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            monkeypatch.setattr("app.services.runtime_chat.sleep", lambda *_args, **_kwargs: None)
            assert service.complete([{"role": "user", "content": "hello"}]) is None
            assert calls == ["main", "backup"]
    finally:
        # Both slots fail here and may remain on cooldown at module level;
        # clear that state so subsequent tests are not silently short-circuited.
        _clear_runtime_chat_cooldown()
def test_runtime_chat_disables_glm_thinking_for_direct_user_answers(monkeypatch) -> None:
    """GLM requests for direct answers carry ``thinking: {"type": "disabled"}``.

    The JSON transport is replaced by a stub that remembers every argument it is
    given and replies with a minimal completion whose content is ``"ok"``. The
    caller's ``timeout_seconds`` must reach the transport unchanged.
    """
    _clear_runtime_chat_cooldown()
    session_factory = build_session_factory()
    with session_factory() as db:
        service = RuntimeChatService(db)
        seen: dict[str, object] = {}

        def stub_send_json_request(method, url, *, headers, payload, timeout_seconds):
            seen.update(
                method=method,
                url=url,
                headers=headers,
                payload=payload,
                timeout_seconds=timeout_seconds,
            )
            return 200, {"choices": [{"message": {"content": "ok"}}]}

        monkeypatch.setattr("app.services.runtime_chat._send_json_request", stub_send_json_request)
        request_kwargs = dict(
            provider="GLM",
            endpoint="https://open.bigmodel.cn/api/paas/v4/",
            model="glm-5.1",
            api_key="secret",
            messages=[{"role": "user", "content": "hello"}],
            max_tokens=32,
            temperature=0.2,
            timeout_seconds=17,
        )
        answer = service._request_openai_compatible(**request_kwargs)

        assert answer == "ok"
        sent_payload = seen["payload"]
        assert sent_payload["thinking"] == {"type": "disabled"}
        assert seen["timeout_seconds"] == 17
def test_runtime_chat_openai_compatible_tool_call_payload(monkeypatch) -> None:
    """Tool-call requests target ``/chat/completions`` and parse the returned function call.

    The stubbed transport answers with one ``submit_steward_intent_plan`` tool
    call whose JSON arguments are ``{"tasks": []}``. The test checks the parsed
    tool call and the outgoing URL, payload ``tools``/``tool_choice``, and bearer
    authorization header.
    """
    _clear_runtime_chat_cooldown()
    session_factory = build_session_factory()
    with session_factory() as db:
        service = RuntimeChatService(db)
        seen: dict[str, object] = {}
        tool_name = "submit_steward_intent_plan"
        function_call = {
            "id": "call_001",
            "type": "function",
            "function": {
                "name": tool_name,
                "arguments": "{\"tasks\": []}",
            },
        }
        canned_response = {"choices": [{"message": {"tool_calls": [function_call]}}]}

        def stub_send_json_request(method, url, *, headers, payload, timeout_seconds):
            seen.update(
                method=method,
                url=url,
                headers=headers,
                payload=payload,
                timeout_seconds=timeout_seconds,
            )
            return 200, canned_response

        monkeypatch.setattr("app.services.runtime_chat._send_json_request", stub_send_json_request)
        tool_call = service._request_openai_compatible_tool_call(
            provider="OpenAI Compatible",
            endpoint="https://api.example.com/v1",
            model="gpt-test",
            api_key="secret",
            messages=[{"role": "user", "content": "hello"}],
            tools=[{"type": "function", "function": {"name": tool_name}}],
            tool_choice={"type": "function", "function": {"name": tool_name}},
            max_tokens=128,
            temperature=0.1,
            timeout_seconds=19,
        )

        assert tool_call is not None
        assert tool_call.name == tool_name
        assert tool_call.arguments == {"tasks": []}
        assert seen["url"] == "https://api.example.com/v1/chat/completions"
        sent_payload = seen["payload"]
        assert sent_payload["tools"][0]["function"]["name"] == tool_name
        assert sent_payload["tool_choice"]["function"]["name"] == tool_name
        assert seen["headers"]["Authorization"] == "Bearer secret"
def test_runtime_chat_supports_single_pass_fast_failover(monkeypatch) -> None:
    """With ``max_attempts=1`` each slot is tried once, using its own timeout.

    ``slot_timeouts`` overrides the call-level ``timeout_seconds`` (15) per slot,
    so ``main`` sees 8 and ``backup`` sees 20. Both fail, so ``complete``
    returns None.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[tuple[str, int]] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(
                config, messages, *, max_tokens, temperature, timeout_seconds
            ):
                del messages, max_tokens, temperature
                calls.append((config["slot"], timeout_seconds))
                raise RuntimeError("unavailable")

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            assert (
                service.complete(
                    [{"role": "user", "content": "hello"}],
                    timeout_seconds=15,
                    slot_timeouts={"main": 8, "backup": 20},
                    max_attempts=1,
                )
                is None
            )
            assert calls == [("main", 8), ("backup", 20)]
    finally:
        # Both slots fail and may stay on module-level cooldown; reset so the
        # next test is not affected by this one.
        _clear_runtime_chat_cooldown()
def test_runtime_chat_skips_slot_during_cooldown(monkeypatch) -> None:
    """After ``main`` fails, the next ``complete`` goes straight to ``backup``.

    The first call tries ``main`` (fails) and then ``backup`` (succeeds). The
    second call must not touch ``main`` again while it is cooling down.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(
                config, messages, *, max_tokens, temperature, timeout_seconds
            ):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                if config["slot"] == "main":
                    raise RuntimeError("main unavailable")
                return "backup answer"

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            first = service.complete([{"role": "user", "content": "hello"}], max_attempts=1)
            second = service.complete([{"role": "user", "content": "hello again"}], max_attempts=1)
            assert first == "backup answer"
            assert second == "backup answer"
            assert calls == ["main", "backup", "backup"]
    finally:
        # This test deliberately leaves "main" on cooldown; clear the
        # module-level state so it does not leak into later tests.
        _clear_runtime_chat_cooldown()