185 lines
7.0 KiB
Python
185 lines
7.0 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from sqlalchemy import create_engine
|
||
|
|
from sqlalchemy.orm import Session, sessionmaker
|
||
|
|
from sqlalchemy.pool import StaticPool
|
||
|
|
|
||
|
|
from app.db.base import Base
|
||
|
|
from app.services import runtime_chat as runtime_chat_module
|
||
|
|
from app.services.runtime_chat import RuntimeChatService
|
||
|
|
|
||
|
|
|
||
|
|
def build_session_factory() -> sessionmaker[Session]:
    """Return a session factory bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the factory is
    returned. ``StaticPool`` reuses a single connection, so all sessions made by
    the factory share the same ``:memory:`` database. ``check_same_thread=False``
    lets that shared connection be used outside the thread that opened it.
    """
    memory_engine = create_engine(
        "sqlite+pysqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    Base.metadata.create_all(bind=memory_engine)
    factory = sessionmaker(bind=memory_engine, autocommit=False, autoflush=False)
    return factory
|
||
|
|
|
||
|
|
|
||
|
|
def _clear_runtime_chat_cooldown() -> None:
    """Empty ``runtime_chat._slot_failure_until`` in place.

    The mapping lives at module level, so it persists across tests. Clearing it
    (rather than rebinding it) keeps every existing reference to the same object
    in sync. Presumably it maps slot name -> cooldown expiry; confirm against
    ``app.services.runtime_chat``.
    """
    cooldown_state = runtime_chat_module._slot_failure_until
    cooldown_state.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_chat_fails_over_to_backup_before_retrying_main(monkeypatch) -> None:
    """When ``main`` fails, ``complete`` answers from ``backup`` before retrying ``main``.

    The stubbed ``main`` slot raises and ``backup`` returns an answer. The test
    therefore expects exactly one call per slot, ``main`` first.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": "MiniMax" if slot == "main" else "GLM",
                    "endpoint": "https://example.com/v1",
                    "model": "main-model" if slot == "main" else "backup-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(config, messages, *, max_tokens, temperature, timeout_seconds):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                if config["slot"] == "main":
                    raise RuntimeError("main unavailable")
                return "backup answer"

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)

            answer = service.complete([{"role": "user", "content": "hello"}])

            assert answer == "backup answer"
            assert calls == ["main", "backup"]
    finally:
        # The failed ``main`` slot is recorded in module-level cooldown state that
        # outlives this test. Reset it so later tests (here or in other modules)
        # are not affected.
        _clear_runtime_chat_cooldown()
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_chat_does_not_rehit_failed_slots_during_cooldown(monkeypatch) -> None:
    """Slots that failed stay untouched on later retry attempts of the same ``complete`` call.

    Both slots always raise, and ``sleep`` is stubbed so retry back-off costs nothing.
    ``complete`` must give up with ``None`` after calling each slot only once.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(config, messages, *, max_tokens, temperature, timeout_seconds):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                raise RuntimeError("unavailable")

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)
            monkeypatch.setattr("app.services.runtime_chat.sleep", lambda *_args, **_kwargs: None)

            assert service.complete([{"role": "user", "content": "hello"}]) is None
            assert calls == ["main", "backup"]
    finally:
        # Both slots end this test in cooldown (module-level state). Reset it so
        # later tests do not inherit skipped slots.
        _clear_runtime_chat_cooldown()
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_chat_disables_glm_thinking_for_direct_user_answers(monkeypatch) -> None:
    """A GLM request built by ``_request_openai_compatible`` sends ``thinking: disabled``.

    The HTTP layer (``_send_json_request``) is replaced by a recorder that returns
    a canned completion. The test checks the returned answer, the ``thinking``
    payload field, and that the caller's timeout is forwarded unchanged.
    """
    _clear_runtime_chat_cooldown()
    session_factory = build_session_factory()
    with session_factory() as db:
        service = RuntimeChatService(db)
        sent: dict[str, object] = {}

        def fake_send_json_request(method, url, *, headers, payload, timeout_seconds):
            sent.update(
                method=method,
                url=url,
                headers=headers,
                payload=payload,
                timeout_seconds=timeout_seconds,
            )
            return 200, {"choices": [{"message": {"content": "ok"}}]}

        monkeypatch.setattr("app.services.runtime_chat._send_json_request", fake_send_json_request)

        request_kwargs = {
            "provider": "GLM",
            "endpoint": "https://open.bigmodel.cn/api/paas/v4/",
            "model": "glm-5.1",
            "api_key": "secret",
            "messages": [{"role": "user", "content": "hello"}],
            "max_tokens": 32,
            "temperature": 0.2,
            "timeout_seconds": 17,
        }
        answer = service._request_openai_compatible(**request_kwargs)

        assert answer == "ok"
        assert sent["payload"]["thinking"] == {"type": "disabled"}
        assert sent["timeout_seconds"] == 17
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_chat_supports_single_pass_fast_failover(monkeypatch) -> None:
    """With ``max_attempts=1`` each slot is tried once, using its own per-slot timeout.

    The ``slot_timeouts`` overrides (8s for ``main``, 20s for ``backup``) must
    replace the default ``timeout_seconds=15``. Because every slot raises,
    ``complete`` returns ``None``.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[tuple[str, int]] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(config, messages, *, max_tokens, temperature, timeout_seconds):
                del messages, max_tokens, temperature
                calls.append((config["slot"], timeout_seconds))
                raise RuntimeError("unavailable")

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)

            assert (
                service.complete(
                    [{"role": "user", "content": "hello"}],
                    timeout_seconds=15,
                    slot_timeouts={"main": 8, "backup": 20},
                    max_attempts=1,
                )
                is None
            )
            assert calls == [("main", 8), ("backup", 20)]
    finally:
        # Both slots failed and are now in module-level cooldown. Reset it so
        # later tests start with every slot available.
        _clear_runtime_chat_cooldown()
|
||
|
|
|
||
|
|
|
||
|
|
def test_runtime_chat_skips_slot_during_cooldown(monkeypatch) -> None:
    """A slot that failed in one ``complete`` call is skipped by the next call.

    The first call tries ``main`` (fails) and then ``backup`` (succeeds). The
    second call goes straight to ``backup`` because ``main`` is still cooling down.
    """
    _clear_runtime_chat_cooldown()
    try:
        session_factory = build_session_factory()
        with session_factory() as db:
            service = RuntimeChatService(db)
            calls: list[str] = []

            def fake_load_chat_slot(slot: str):
                return {
                    "slot": slot,
                    "provider": slot,
                    "endpoint": "https://example.com/v1",
                    "model": f"{slot}-model",
                    "apiKey": "secret",
                }

            def fake_request_chat_completion(config, messages, *, max_tokens, temperature, timeout_seconds):
                del messages, max_tokens, temperature, timeout_seconds
                calls.append(config["slot"])
                if config["slot"] == "main":
                    raise RuntimeError("main unavailable")
                return "backup answer"

            monkeypatch.setattr(service, "_load_chat_slot", fake_load_chat_slot)
            monkeypatch.setattr(service, "_request_chat_completion", fake_request_chat_completion)

            assert service.complete([{"role": "user", "content": "hello"}], max_attempts=1) == "backup answer"
            assert service.complete([{"role": "user", "content": "hello again"}], max_attempts=1) == "backup answer"
            assert calls == ["main", "backup", "backup"]
    finally:
        # This test deliberately leaves ``main`` cooling down in module-level state.
        # Reset it so later tests are not silently routed to ``backup``.
        _clear_runtime_chat_cooldown()
|