feat: 完善模型管理功能
- 新增模型 API 路由,支持 CRUD 和测试连接 - 支持 MiniMax、GLM、OpenAI Compatible 三种供应商 - 添加连接状态持久化 (untested/connected/disconnected) - 修复 CORS 和数据库模型兼容性问题 - 前端 UI 优化:供应商默认 API 地址自动填充 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
API Dependencies
|
||||
API 依赖项
|
||||
"""
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Optional
|
||||
from fastapi import Depends
|
||||
from app.core.auth import verify_api_key
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ API v1 Router
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import files, projects, chunks, questions, datasets, eval
|
||||
from app.api.v1 import files, projects, chunks, questions, datasets, eval, models
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -15,3 +15,4 @@ api_router.include_router(chunks.router, prefix="/chunks", tags=["chunks"])
|
||||
api_router.include_router(questions.router, prefix="/questions", tags=["questions"])
|
||||
api_router.include_router(datasets.router, prefix="/datasets", tags=["datasets"])
|
||||
api_router.include_router(eval.router, prefix="/eval", tags=["eval"])
|
||||
api_router.include_router(models.router, prefix="/models", tags=["models"])
|
||||
|
||||
257
backend/app/api/v1/models/__init__.py
Normal file
257
backend/app/api/v1/models/__init__.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Model API Router
|
||||
"""
|
||||
import uuid
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.response import ApiResponse
|
||||
from app.models.models import ModelConfig
|
||||
from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def test_model_connection(model: ModelConfig) -> dict:
|
||||
"""Test model connection by calling the API"""
|
||||
if not model.api_key:
|
||||
return {"success": False, "message": "API Key is missing"}
|
||||
|
||||
api_base = model.api_base or ""
|
||||
provider = model.provider
|
||||
model_name = model.model_name
|
||||
api_key = model.api_key
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai":
|
||||
# OpenAI compatible API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 5
|
||||
}
|
||||
)
|
||||
elif provider == "minimax":
|
||||
# MiniMax API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions_v2",
|
||||
headers={
|
||||
**headers,
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
},
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
elif provider == "glm":
|
||||
# GLM API test
|
||||
response = await client.post(
|
||||
f"{api_base.rstrip('/')}/chat/completions",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}
|
||||
)
|
||||
else:
|
||||
return {"success": False, "message": f"Unsupported provider: {provider}"}
|
||||
|
||||
if response.status_code == 200:
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
else:
|
||||
return {"success": False, "message": f"API error: {response.status_code} - {response.text[:100]}"}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {"success": False, "message": "Connection timeout"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"Connection failed: {str(e)}"}
|
||||
|
||||
|
||||
# Helper to convert string to UUID
|
||||
def parse_uuid(id_str: str) -> uuid.UUID:
|
||||
"""Parse string to UUID"""
|
||||
try:
|
||||
return uuid.UUID(id_str)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid UUID format")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
async def list_models(db: AsyncSession = Depends(get_db)):
|
||||
"""Get all models"""
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(ModelConfig.project_id == None) # noqa: E711
|
||||
)
|
||||
models = result.scalars().all()
|
||||
# Convert to Pydantic schema
|
||||
model_responses = [ModelResponse.model_validate(m) for m in models]
|
||||
return ApiResponse(data=model_responses)
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new model"""
|
||||
# If setting as default, unset other defaults first
|
||||
if model.is_default == "true":
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(ModelConfig.project_id == None) # noqa: E711
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
db_model = ModelConfig(
|
||||
provider=model.provider,
|
||||
model_name=model.model_name,
|
||||
api_key=model.api_key,
|
||||
api_base=model.api_base,
|
||||
is_default=model.is_default,
|
||||
project_id=None # Global model config
|
||||
)
|
||||
db.add(db_model)
|
||||
await db.commit()
|
||||
await db.refresh(db_model)
|
||||
# Convert to Pydantic schema
|
||||
response = ModelResponse.model_validate(db_model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
async def get_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Get a model by ID"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update a model"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# If setting as default, unset other defaults first
|
||||
if model_update.is_default == "true":
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(
|
||||
ModelConfig.project_id == None, # noqa: E711
|
||||
ModelConfig.id != model_uuid
|
||||
)
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
update_data = model_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(model, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.delete("/{model_id}", response_model=ApiResponse)
|
||||
async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a model"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
await db.delete(model)
|
||||
await db.commit()
|
||||
return ApiResponse(message="Model deleted successfully")
|
||||
|
||||
|
||||
@router.post("/{model_id}/set-default", response_model=ApiResponse)
|
||||
async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Set a model as default"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# Unset all other defaults
|
||||
await db.execute(
|
||||
update(ModelConfig)
|
||||
.where(
|
||||
ModelConfig.project_id == None, # noqa: E711
|
||||
ModelConfig.id != model_uuid
|
||||
)
|
||||
.values(is_default="false")
|
||||
)
|
||||
|
||||
model.is_default = "true"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data=response)
|
||||
|
||||
|
||||
@router.post("/{model_id}/test", response_model=ApiResponse)
|
||||
async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Test model connection"""
|
||||
model_uuid = parse_uuid(model_id)
|
||||
result = await db.execute(
|
||||
select(ModelConfig).where(
|
||||
ModelConfig.id == model_uuid,
|
||||
ModelConfig.project_id == None # noqa: E711
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# Test the connection
|
||||
test_result = await test_model_connection(model)
|
||||
|
||||
# Save connection status to database
|
||||
model.connection_status = "connected" if test_result["success"] else "disconnected"
|
||||
await db.commit()
|
||||
await db.refresh(model)
|
||||
|
||||
# Return updated model
|
||||
response = ModelResponse.model_validate(model)
|
||||
return ApiResponse(data={"test_result": test_result, "model": response})
|
||||
@@ -100,3 +100,8 @@ async def get_db() -> AsyncSession:
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
# Import all models to register them with Base.metadata
|
||||
# This ensures all models are loaded before create_all is called
|
||||
from app.models.models import * # noqa: F401, F403, E402
|
||||
|
||||
@@ -21,6 +21,9 @@ from app.core.database import init_db, close_db
|
||||
from app.core.exceptions import AppException
|
||||
from app.core.logging import logger
|
||||
|
||||
# Import all models to register them with Base.metadata
|
||||
from app.models.models import * # noqa: F401, F403
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add request ID to each request"""
|
||||
@@ -83,7 +86,7 @@ app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
# CORS - Configure properly for production
|
||||
# For development, you can use ["*"] but for production, specify exact origins
|
||||
ALLOWED_ORIGINS = settings.ALLOWED_ORIGINS.split(",") if settings.ALLOWED_ORIGINS else ["*"]
|
||||
ALLOWED_ORIGINS = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
||||
@@ -135,12 +135,13 @@ class ModelConfig(Base, UUIDMixin, TimestampMixin):
|
||||
"""Model configuration for LLM providers"""
|
||||
__tablename__ = "model_configs"
|
||||
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
provider = Column(String(50), nullable=False) # openai, anthropic, ollama, custom
|
||||
project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True)
|
||||
provider = Column(String(50), nullable=False) # minimax, glm, openai
|
||||
model_name = Column(String(100))
|
||||
api_key = Column(String(500))
|
||||
api_base = Column(String(500))
|
||||
is_default = Column(String(10), default="false")
|
||||
connection_status = Column(String(20), default="untested") # untested, connected, disconnected
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="model_configs")
|
||||
|
||||
@@ -50,6 +50,13 @@ from app.schemas.eval import (
|
||||
TaskResponse,
|
||||
)
|
||||
|
||||
from app.schemas.model import (
|
||||
ModelBase,
|
||||
ModelCreate,
|
||||
ModelUpdate,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"TimestampMixin",
|
||||
@@ -86,4 +93,9 @@ __all__ = [
|
||||
"EvalDatasetResponse",
|
||||
"TaskBase",
|
||||
"TaskResponse",
|
||||
# Model
|
||||
"ModelBase",
|
||||
"ModelCreate",
|
||||
"ModelUpdate",
|
||||
"ModelResponse",
|
||||
]
|
||||
|
||||
41
backend/app/schemas/model.py
Normal file
41
backend/app/schemas/model.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Model Schema
|
||||
"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class ModelBase(BaseModel):
|
||||
"""Base model schema"""
|
||||
provider: str = Field(..., description="Model provider: minimax, glm, openai")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
api_key: Optional[str] = Field(None, description="API key")
|
||||
api_base: Optional[str] = Field(None, description="API base URL")
|
||||
is_default: str = Field(default="false", description="Is default model: true/false")
|
||||
|
||||
|
||||
class ModelCreate(ModelBase):
|
||||
"""Model creation schema"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelUpdate(BaseModel):
|
||||
"""Model update schema"""
|
||||
provider: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
is_default: Optional[str] = None
|
||||
|
||||
|
||||
class ModelResponse(ModelBase):
|
||||
"""Model response schema"""
|
||||
id: UUID
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
project_id: Optional[UUID] = None
|
||||
connection_status: Optional[str] = Field(default="untested")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
Reference in New Issue
Block a user