- 新增模型 API 路由,支持 CRUD 和测试连接 - 支持 MiniMax、GLM、OpenAI Compatible 三种供应商 - 添加连接状态持久化 (untested/connected/disconnected) - 修复 CORS 和数据库模型兼容性问题 - 前端 UI 优化:供应商默认 API 地址自动填充 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
258 lines
8.5 KiB
Python
258 lines
8.5 KiB
Python
"""
|
|
Model API Router
|
|
"""
|
|
import uuid
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, update
|
|
|
|
from app.core.database import get_db
|
|
from app.api.response import ApiResponse
|
|
from app.models.models import ModelConfig
|
|
from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
|
|
|
|
# Router for the global model-configuration endpoints (every handler below
# filters on project_id being NULL). Route paths are relative ("" and
# "/{model_id}"); the URL prefix is applied wherever this router is included,
# which is not visible in this module.
router = APIRouter()
|
|
|
|
|
|
# Per-provider shape of the connectivity probe:
# provider -> (path appended to api_base, extra JSON payload fields).
_PROVIDER_PROBES = {
    # OpenAI-compatible chat completions; cap output so the probe stays cheap.
    "openai": ("/chat/completions", {"max_tokens": 5}),
    # NOTE(review): MiniMax's documented chat endpoint is usually
    # ".../text/chatcompletion_v2"; confirm this path matches the api_base
    # the frontend fills in for MiniMax.
    "minimax": ("/chat/completions_v2", {}),
    "glm": ("/chat/completions", {}),
}


async def test_model_connection(model: ModelConfig) -> dict:
    """Test model connection by sending a minimal chat request to its API.

    A single-message ("Hi") chat completion is POSTed to the endpoint built
    from ``model.api_base`` and the provider-specific path.

    Args:
        model: Stored model configuration; ``api_key``, ``api_base``,
            ``provider`` and ``model_name`` are read.

    Returns:
        ``{"success": bool, "message": str}``. Configuration, network and
        provider errors are reported through ``message`` rather than raised.
    """
    api_key = (model.api_key or "").strip()
    if not api_key:
        return {"success": False, "message": "API Key is missing"}

    provider = model.provider
    probe = _PROVIDER_PROBES.get(provider)
    if probe is None:
        return {"success": False, "message": f"Unsupported provider: {provider}"}

    api_base = (model.api_base or "").strip().rstrip("/")
    if not api_base:
        # Without a base URL the request would target a bare relative path
        # and fail with an opaque client error; say what is wrong instead.
        return {"success": False, "message": "API base URL is missing"}

    path, extra_payload = probe
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model.model_name,
        "messages": [{"role": "user", "content": "Hi"}],
        **extra_payload,
    }

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(f"{api_base}{path}", headers=headers, json=payload)
    except httpx.TimeoutException:
        return {"success": False, "message": "Connection timeout"}
    except Exception as e:  # best-effort probe: report the failure, never raise
        # Some transport errors stringify to "", so fall back to the type name.
        detail = str(e) or type(e).__name__
        return {"success": False, "message": f"Connection failed: {detail}"}

    if response.status_code != 200:
        return {"success": False, "message": f"API error: {response.status_code} - {response.text[:100]}"}

    # MiniMax answers HTTP 200 even for errors such as an invalid key or
    # exhausted quota, and puts the real outcome in base_resp.status_code
    # (0 means success). Treat any non-zero code as a failed connection.
    try:
        body = response.json()
    except ValueError:  # non-JSON body: nothing more to inspect
        body = None
    base_resp = body.get("base_resp") if isinstance(body, dict) else None
    if isinstance(base_resp, dict) and base_resp.get("status_code") not in (None, 0):
        status_msg = str(base_resp.get("status_msg", ""))[:100]
        return {"success": False, "message": f"API error: {base_resp.get('status_code')} - {status_msg}"}

    return {"success": True, "message": "Connection successful"}
|
|
|
|
|
|
# Helper to convert a path-parameter string to a UUID
def parse_uuid(id_str: str) -> uuid.UUID:
    """Parse ``id_str`` into a ``uuid.UUID``.

    Any textual form accepted by ``uuid.UUID`` is allowed (hyphenated,
    plain hex, braces, ``urn:uuid:`` prefix). A value that is already a
    ``uuid.UUID`` is returned unchanged, so internal callers may pass
    either form.

    Raises:
        HTTPException: 400 with detail "Invalid UUID format" when the text
            is not a valid UUID.
    """
    if isinstance(id_str, uuid.UUID):
        return id_str
    try:
        return uuid.UUID(id_str)
    except ValueError:
        # `from None`: the client-facing 400 says everything needed; the
        # implicit ValueError context only adds noise to server tracebacks.
        raise HTTPException(status_code=400, detail="Invalid UUID format") from None
|
|
|
|
|
|
@router.get("", response_model=ApiResponse)
async def list_models(db: AsyncSession = Depends(get_db)):
    """Return every global model configuration (rows with no project_id).

    Each ORM row is converted to a ``ModelResponse`` before being wrapped
    in the standard ``ApiResponse`` envelope.
    """
    global_models = select(ModelConfig).where(ModelConfig.project_id.is_(None))
    rows = (await db.execute(global_models)).scalars().all()
    return ApiResponse(data=[ModelResponse.model_validate(row) for row in rows])
|
|
|
|
|
|
@router.post("", response_model=ApiResponse)
async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
    """Create a global model configuration (``project_id`` is None).

    When the payload asks for the new entry to be the default
    (``is_default == "true"``), every existing global model is first set to
    ``"false"``. The reset and the insert are committed together, and the
    stored row is returned as a ``ModelResponse``.
    """
    becomes_default = model.is_default == "true"
    if becomes_default:
        clear_defaults = (
            update(ModelConfig)
            .where(ModelConfig.project_id.is_(None))
            .values(is_default="false")
        )
        await db.execute(clear_defaults)

    new_row = ModelConfig(
        provider=model.provider,
        model_name=model.model_name,
        api_key=model.api_key,
        api_base=model.api_base,
        is_default=model.is_default,
        project_id=None,  # global model config, not tied to a project
    )
    db.add(new_row)
    await db.commit()
    await db.refresh(new_row)

    return ApiResponse(data=ModelResponse.model_validate(new_row))
|
|
|
|
|
|
@router.get("/{model_id}", response_model=ApiResponse)
async def get_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one global model configuration by its UUID.

    Raises:
        HTTPException: 400 for a malformed ``model_id`` (via ``parse_uuid``);
            404 when no global model has that id.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Model not found")

    return ApiResponse(data=ModelResponse.model_validate(found))
|
|
|
|
|
|
@router.put("/{model_id}", response_model=ApiResponse)
async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)):
    """Partially update a global model configuration.

    Only fields the client explicitly sent are applied. If the request
    makes this model the default (``is_default == "true"``), every other
    global model is set to ``"false"`` in the same transaction. When a
    connection-relevant field (provider, model_name, api_key, api_base)
    actually changes, ``connection_status`` is reset to ``"untested"``,
    because the last probe result no longer describes the new settings.

    Raises:
        HTTPException: 400 for a malformed ``model_id`` (via ``parse_uuid``);
            404 when no global model has that id.
    """
    model_uuid = parse_uuid(model_id)
    result = await db.execute(
        select(ModelConfig).where(
            ModelConfig.id == model_uuid,
            ModelConfig.project_id.is_(None),
        )
    )
    model = result.scalar_one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    # Partial update: only fields present in the request body.
    update_data = model_update.model_dump(exclude_unset=True)

    # Compare against the currently stored values before anything is
    # modified, so we know whether the saved probe result has gone stale.
    connection_fields = ("provider", "model_name", "api_key", "api_base")
    connection_changed = any(
        field in update_data and update_data[field] != getattr(model, field)
        for field in connection_fields
    )

    # If setting as default, unset the other global defaults first.
    if model_update.is_default == "true":
        await db.execute(
            update(ModelConfig)
            .where(
                ModelConfig.project_id.is_(None),
                ModelConfig.id != model_uuid,
            )
            .values(is_default="false")
        )

    for key, value in update_data.items():
        setattr(model, key, value)

    # Keep an explicitly supplied status; otherwise mark the old result stale.
    if connection_changed and "connection_status" not in update_data:
        model.connection_status = "untested"

    await db.commit()
    await db.refresh(model)
    return ApiResponse(data=ModelResponse.model_validate(model))
|
|
|
|
|
|
@router.delete("/{model_id}", response_model=ApiResponse)
async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Delete one global model configuration by its UUID.

    Raises:
        HTTPException: 400 for a malformed ``model_id`` (via ``parse_uuid``);
            404 when no global model has that id.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Model not found")

    await db.delete(doomed)
    await db.commit()
    return ApiResponse(message="Model deleted successfully")
|
|
|
|
|
|
@router.post("/{model_id}/set-default", response_model=ApiResponse)
async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Make one global model the default and clear the flag on all others.

    Both the bulk reset and the flag on the chosen model are committed in
    one transaction; the refreshed model is returned.

    Raises:
        HTTPException: 400 for a malformed ``model_id`` (via ``parse_uuid``);
            404 when no global model has that id.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    chosen = (await db.execute(lookup)).scalar_one_or_none()
    if chosen is None:
        raise HTTPException(status_code=404, detail="Model not found")

    clear_others = (
        update(ModelConfig)
        .where(
            ModelConfig.project_id.is_(None),
            ModelConfig.id != target_id,
        )
        .values(is_default="false")
    )
    await db.execute(clear_others)

    chosen.is_default = "true"
    await db.commit()
    await db.refresh(chosen)

    return ApiResponse(data=ModelResponse.model_validate(chosen))
|
|
|
|
|
|
@router.post("/{model_id}/test", response_model=ApiResponse)
async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Probe a global model's API and persist the outcome.

    Runs ``test_model_connection`` and stores ``"connected"`` or
    ``"disconnected"`` in ``connection_status``, depending on the probe's
    ``success`` flag. Returns both the raw probe result and the refreshed
    model.

    Raises:
        HTTPException: 400 for a malformed ``model_id`` (via ``parse_uuid``);
            404 when no global model has that id.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    model = (await db.execute(lookup)).scalar_one_or_none()
    if model is None:
        raise HTTPException(status_code=404, detail="Model not found")

    probe = await test_model_connection(model)

    # Persist the outcome so list/get endpoints can show it later.
    if probe["success"]:
        model.connection_status = "connected"
    else:
        model.connection_status = "disconnected"
    await db.commit()
    await db.refresh(model)

    refreshed = ModelResponse.model_validate(model)
    return ApiResponse(data={"test_result": probe, "model": refreshed})
|