feat: 完善模型管理功能
- 新增模型 API 路由,支持 CRUD 和测试连接 - 支持 MiniMax、GLM、OpenAI Compatible 三种供应商 - 添加连接状态持久化 (untested/connected/disconnected) - 修复 CORS 和数据库模型兼容性问题 - 前端 UI 优化:供应商默认 API 地址自动填充 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,7 +4,7 @@ API v1 Router
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import files, projects, chunks, questions, datasets, eval
|
||||
from app.api.v1 import files, projects, chunks, questions, datasets, eval, models
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -15,3 +15,4 @@ api_router.include_router(chunks.router, prefix="/chunks", tags=["chunks"])
|
||||
api_router.include_router(questions.router, prefix="/questions", tags=["questions"])
|
||||
api_router.include_router(datasets.router, prefix="/datasets", tags=["datasets"])
|
||||
api_router.include_router(eval.router, prefix="/eval", tags=["eval"])
|
||||
api_router.include_router(models.router, prefix="/models", tags=["models"])
|
||||
|
||||
257
backend/app/api/v1/models/__init__.py
Normal file
257
backend/app/api/v1/models/__init__.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Model API Router
|
||||
"""
|
||||
import uuid
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.response import ApiResponse
|
||||
from app.models.models import ModelConfig
|
||||
from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"""Test model connection by calling the API"""
|
||||
if not model.api_key:
|
||||
return {"success": False, "message": "API Key is missing"}
|
||||
|
||||
api_base = model.api_base or ""
|
||||
provider = model.provider
|
||||
model_name = model.model_name
|
||||
api_key = model.api_key
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai":
|
||||
# OpenAI compatible API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 5
|
||||
}
|
||||
)
|
||||
elif provider == "minimax":
|
||||
# MiniMax API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions_v2",
|
||||
headers={
|
||||
**headers,
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
},
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif provider == "glm":
|
||||
# GLM API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
else:
|
||||
return {"success": False, "message": f"Unsupported provider: {provider}"}
|
||||
|
||||
if response.status_code == 200:
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
else:
|
||||
return {"success": False, "message": f"API error: {response.status_code} - {response.text[:100]}"}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {"success": False, "message": "Connection timeout"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"Connection failed: {str(e)}"}
|
||||
|
||||
|
||||
# Helper to convert string to UUID
|
||||
def parse_uuid(id_str: str) -> uuid.UUID:
|
||||
"""Parse string to UUID"""
|
||||
try:
|
||||
return uuid.UUID(id_str)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid UUID format")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
async def list_models(db: AsyncSession = Depends(get_db)):
|
||||
"""Get all models"""
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
models = result.scalars().all()
|
||||
# Convert to Pydantic schema
|
||||
model_responses = [ModelResponse.model_validate(m) for m in models]
|
||||
return ApiResponse(data=model_responses)
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new model"""
|
||||
# If setting as default, unset other defaults first
|
||||
if model.is_default == "true":
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(ModelConfig.project_id == None) # noqa: E711
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
db_model = ModelConfig(
|
||||
provider=model.provider,
|
||||
model_name=model.model_name,
|
||||
api_key=model.api_key,
|
||||
api_base=model.api_base,
|
||||
is_default=model.is_default,
|
||||
project_id=None # Global model config
|
||||
)
|
||||
db.add(db_model)
|
||||
await db.commit()
|
||||
await db.refresh(db_model)
|
||||
# Convert to Pydantic schema
|
||||
response = ModelResponse.model_validate(db_model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
async def get_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get a model by ID"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update a model"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# If setting as default, unset other defaults first
|
||||
if model_update.is_default == "true":
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(
|
||||
ModelConfig.project_id == None, # noqa: E711
|
||||
ModelConfig.id != model_uuid
|
||||
)
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
update_data = model_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(model, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.delete("/{model_id}", response_model=ApiResponse)
|
||||
async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a model"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
await db.delete(model)
|
||||
await db.commit()
|
||||
return ApiResponse(message="Model deleted successfully")
|
||||
|
||||
|
||||
@router.post("/{model_id}/set-default", response_model=ApiResponse)
|
||||
async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Set a model as default"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# Unset all other defaults
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(
|
||||
ModelConfig.project_id == None, # noqa: E711
|
||||
ModelConfig.id != model_uuid
|
||||
)
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
model.is_default = "true"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.post("/{model_id}/test", response_model=ApiResponse)
|
||||
async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Test model connection"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# Test the connection
|
||||
test_result = await test_model_connection(model)
|
||||
|
||||
# Save connection status to database
|
||||
model.connection_status = "connected" if test_result["success"] else "disconnected"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
|
||||
# Return updated model
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data={"test_result": test_result, "model": response})
|
||||
Reference in New Issue
Block a user