Files
YG-Datasets/backend/app/api/v1/models/__init__.py
Developer 7514e7e763 feat: 完善模型管理功能
- 新增模型 API 路由,支持 CRUD 和测试连接
- 支持 MiniMax、GLM、OpenAI Compatible 三种供应商
- 添加连接状态持久化 (untested/connected/disconnected)
- 修复 CORS 和数据库模型兼容性问题
- 前端 UI 优化:供应商默认 API 地址自动填充

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 23:02:43 +08:00

258 lines
8.5 KiB
Python

"""
Model API Router
"""
import uuid
import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.core.database import get_db
from app.api.response import ApiResponse
from app.models.models import ModelConfig
from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
# Router on which the model-configuration endpoints defined in this module are registered.
router = APIRouter()
async def test_model_connection(model: ModelConfig) -> dict:
    """Probe a model configuration by sending one minimal chat request.

    A single "Hi" user message is POSTed to the provider-specific chat
    endpoint under ``model.api_base``, authenticated with ``model.api_key``
    as a Bearer token.

    Args:
        model: Configuration supplying ``provider``, ``model_name``,
            ``api_key`` and ``api_base``.

    Returns:
        A dict ``{"success": bool, "message": str}``. ``success`` is True
        only when the provider answers with HTTP 200. Timeouts and any other
        error raised while sending the request are reported in the dict
        rather than propagated.
    """
    if not model.api_key:
        return {"success": False, "message": "API Key is missing"}

    # Endpoint path and extra payload fields for each supported provider.
    provider_specs = {
        "openai": ("/chat/completions", {"max_tokens": 5}),
        "minimax": ("/chat/completions_v2", {}),
        "glm": ("/chat/completions", {}),
    }
    provider = model.provider
    spec = provider_specs.get(provider)
    if spec is None:
        return {"success": False, "message": f"Unsupported provider: {provider}"}

    # Without a base URL the request would target a relative path and fail
    # with an opaque client error, so report the real cause instead.
    api_base = (model.api_base or "").strip().rstrip("/")
    if not api_base:
        return {"success": False, "message": "API base URL is missing"}

    path, extra_fields = spec
    headers = {
        "Authorization": f"Bearer {model.api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model.model_name,
        "messages": [{"role": "user", "content": "Hi"}],
        **extra_fields,
    }

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(
                f"{api_base}{path}",
                headers=headers,
                json=payload,
            )
            if response.status_code == 200:
                return {"success": True, "message": "Connection successful"}
            # Truncate the body so a large error page does not flood the message.
            return {
                "success": False,
                "message": f"API error: {response.status_code} - {response.text[:100]}",
            }
    except httpx.TimeoutException:
        return {"success": False, "message": "Connection timeout"}
    except Exception as e:
        # Deliberate catch-all: this is a diagnostic probe, so every failure
        # is turned into a result for the caller instead of an exception.
        return {"success": False, "message": f"Connection failed: {str(e)}"}
# Helper to convert a string into a UUID
def parse_uuid(id_str: str) -> uuid.UUID:
    """Convert ``id_str`` into a ``uuid.UUID``.

    Args:
        id_str: Textual UUID in any form accepted by ``uuid.UUID``.

    Returns:
        The parsed ``uuid.UUID``.

    Raises:
        HTTPException: With status 400 when the value cannot be parsed.
    """
    try:
        return uuid.UUID(id_str)
    except (ValueError, AttributeError, TypeError) as exc:
        # ValueError covers malformed strings; AttributeError/TypeError cover
        # non-string input. All of them are bad client input, not server errors.
        raise HTTPException(status_code=400, detail="Invalid UUID format") from exc
@router.get("", response_model=ApiResponse)
async def list_models(db: AsyncSession = Depends(get_db)):
    """Return every model configuration whose ``project_id`` is NULL."""
    global_models_query = select(ModelConfig).where(ModelConfig.project_id.is_(None))
    query_result = await db.execute(global_models_query)
    # Each ORM row is passed through the response schema before being returned.
    return ApiResponse(
        data=[ModelResponse.model_validate(row) for row in query_result.scalars().all()]
    )
@router.post("", response_model=ApiResponse)
async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
    """Create a model configuration with no project (``project_id`` NULL).

    When the payload marks the new model as default (``is_default == "true"``),
    all existing project-less configurations have their flag set to "false"
    before the new row is inserted.
    """
    if model.is_default == "true":
        clear_existing_defaults = (
            update(ModelConfig)
            .where(ModelConfig.project_id.is_(None))
            .values(is_default="false")
        )
        await db.execute(clear_existing_defaults)

    new_config = ModelConfig(
        provider=model.provider,
        model_name=model.model_name,
        api_key=model.api_key,
        api_base=model.api_base,
        is_default=model.is_default,
        project_id=None,  # Global model config
    )
    db.add(new_config)
    await db.commit()
    # Reload the row after commit before serialising it.
    await db.refresh(new_config)

    return ApiResponse(data=ModelResponse.model_validate(new_config))
@router.get("/{model_id}", response_model=ApiResponse)
async def get_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one project-less model configuration by its UUID.

    Raises:
        HTTPException: 400 from ``parse_uuid`` for a malformed id, or 404
            when no matching row exists.
    """
    lookup = select(ModelConfig).where(
        ModelConfig.id == parse_uuid(model_id),
        ModelConfig.project_id.is_(None),
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Model not found")

    return ApiResponse(data=ModelResponse.model_validate(found))
@router.put("/{model_id}", response_model=ApiResponse)
async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)):
    """Apply the explicitly supplied fields of ``model_update`` to a model.

    If the update sets ``is_default`` to "true", every other project-less
    configuration is switched to "false" first.

    Raises:
        HTTPException: 400 from ``parse_uuid`` for a malformed id, or 404
            when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if not existing:
        raise HTTPException(status_code=404, detail="Model not found")

    if model_update.is_default == "true":
        demote_others = (
            update(ModelConfig)
            .where(ModelConfig.project_id.is_(None), ModelConfig.id != target_id)
            .values(is_default="false")
        )
        await db.execute(demote_others)

    # exclude_unset keeps fields the client did not send out of the update.
    changes = model_update.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(existing, field_name, changes[field_name])

    await db.commit()
    await db.refresh(existing)

    return ApiResponse(data=ModelResponse.model_validate(existing))
@router.delete("/{model_id}", response_model=ApiResponse)
async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a project-less model configuration identified by its UUID.

    Raises:
        HTTPException: 400 from ``parse_uuid`` for a malformed id, or 404
            when no matching row exists.
    """
    lookup = select(ModelConfig).where(
        ModelConfig.id == parse_uuid(model_id),
        ModelConfig.project_id.is_(None),
    )
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(status_code=404, detail="Model not found")

    await db.delete(doomed)
    await db.commit()
    return ApiResponse(message="Model deleted successfully")
@router.post("/{model_id}/set-default", response_model=ApiResponse)
async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Mark one project-less model as default and clear the flag on the rest.

    Raises:
        HTTPException: 400 from ``parse_uuid`` for a malformed id, or 404
            when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    chosen = (await db.execute(lookup)).scalar_one_or_none()
    if not chosen:
        raise HTTPException(status_code=404, detail="Model not found")

    # Every other project-less configuration loses the default flag.
    demote_others = (
        update(ModelConfig)
        .where(ModelConfig.project_id.is_(None), ModelConfig.id != target_id)
        .values(is_default="false")
    )
    await db.execute(demote_others)

    chosen.is_default = "true"
    await db.commit()
    await db.refresh(chosen)

    return ApiResponse(data=ModelResponse.model_validate(chosen))
@router.post("/{model_id}/test", response_model=ApiResponse)
async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Run a connection test for a model and persist the outcome.

    The model's ``connection_status`` is stored as "connected" when the test
    succeeds and "disconnected" otherwise. The response carries both the raw
    test result and the refreshed model.

    Raises:
        HTTPException: 400 from ``parse_uuid`` for a malformed id, or 404
            when no matching row exists.
    """
    lookup = select(ModelConfig).where(
        ModelConfig.id == parse_uuid(model_id),
        ModelConfig.project_id.is_(None),
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Model not found")

    outcome = await test_model_connection(target)

    # Persist the result of this test on the row.
    if outcome["success"]:
        target.connection_status = "connected"
    else:
        target.connection_status = "disconnected"
    await db.commit()
    await db.refresh(target)

    return ApiResponse(
        data={"test_result": outcome, "model": ModelResponse.model_validate(target)}
    )