Files
YG-Datasets/backend/app/api/v1/models/__init__.py
Developer 6aa271c4f7 refactor: 前端架构重构 - 提取 CSS 和逻辑到独立模块
前端重构:
- 删除旧的大体积 Vue 组件(HomeView, FileManage, TextSplit 等)
- 删除旧的 composables(useFormatters, useModels, useProjects)
- 新增 core/, page-logic/, pages/, shared/ 模块化目录结构
- 提取 CSS 到 styles/pages/ 目录
- 添加全局样式 variables.css 和 common.css

后端 API 更新:
- chunks: 语义分割 API 增强
- files: 文件处理 API 更新
- models: 模型管理 API 更新
- questions: 问答管理 API 更新
- database: 数据库连接优化
- semantic_embedding: 语义嵌入服务优化

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 14:23:34 +08:00

306 lines
11 KiB
Python

"""
Model API Router
"""
import uuid
import httpx
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.core.database import get_db
from app.api.response import ApiResponse
from app.models.models import ModelConfig
from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse
# Router on which the model-configuration endpoints in this module are registered.
router = APIRouter()
VALID_MODEL_TYPES = {"chat", "vlm", "embedding", "rerank"}
def normalize_model_type(model_type: str | None, model_name: str | None) -> str:
"""Normalize model type, with keyword fallback for legacy records."""
if model_type in VALID_MODEL_TYPES and model_type != "chat":
return model_type
normalized_name = (model_name or "").strip().lower()
rerank_keywords = ("rerank", "bce-reranker", "gte-rerank")
embedding_keywords = (
"embedding",
"embed",
"text-embedding",
"bge-",
"bge_m3",
"gte-",
"m3e",
"e5-",
"jina-embeddings",
)
vlm_keywords = ("vl", "vision", "visual", "multimodal", "qwen-vl", "gpt-4o")
if any(keyword in normalized_name for keyword in rerank_keywords):
return "rerank"
if any(keyword in normalized_name for keyword in embedding_keywords):
return "embedding"
if any(keyword in normalized_name for keyword in vlm_keywords):
return "vlm"
return model_type if model_type in VALID_MODEL_TYPES else "chat"
async def test_model_connection(model: ModelConfig) -> dict:
    """Probe a model's remote API with one minimal request.

    The request is chosen from the provider and the normalized model type:
    a short chat completion for chat/vlm models, a one-string embedding call
    for embedding models. Combinations that have no automated probe are
    reported as failures before any HTTP client is created.

    Args:
        model: Model configuration providing ``api_key``, ``api_base``,
            ``provider``, ``model_name`` and ``model_type``.

    Returns:
        A dict with ``success`` (bool) and ``message`` (str). Errors raised
        while sending the request are caught and reported in ``message``.
    """
    if not model.api_key:
        return {"success": False, "message": "API Key is missing"}

    provider = model.provider
    model_name = model.model_name
    model_type = normalize_model_type(model.model_type, model_name)
    unsupported = {
        "success": False,
        "message": f"Unsupported provider/type: {provider}/{model_type}",
    }

    # Decide the endpoint and payload first, so that combinations without a
    # probe return without opening a connection pool.
    if model_type in {"chat", "vlm"}:
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": "Hi"}],
        }
        if provider in {"openai", "ali"}:
            # OpenAI-compatible API; max_tokens limits the reply to 5 tokens.
            endpoint = "chat/completions"
            payload["max_tokens"] = 5
        elif provider == "minimax":
            endpoint = "chat/completions_v2"
        elif provider == "glm":
            endpoint = "chat/completions"
        else:
            return unsupported
    elif model_type == "embedding":
        if provider in {"openai", "ali", "glm"}:
            endpoint = "embeddings"
            payload = {"model": model_name, "input": "test"}
        elif provider == "minimax":
            return {"success": False, "message": "MiniMax embedding 自动测试暂未接入,请手动确认端点与模型"}
        else:
            return unsupported
    elif model_type == "rerank":
        return {"success": False, "message": "Rerank 自动测试暂未接入,请先保存配置并在实际流程中验证"}
    else:
        return unsupported

    # Without a base URL the request could only fail with a low-level
    # protocol error; report the actual configuration problem instead.
    base_url = (model.api_base or "").rstrip("/")
    if not base_url:
        return {"success": False, "message": "API Base URL is missing"}

    headers = {
        "Authorization": f"Bearer {model.api_key}",
        "Content-Type": "application/json",
    }

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(
                f"{base_url}/{endpoint}",
                headers=headers,
                json=payload,
            )
    except httpx.TimeoutException:
        return {"success": False, "message": "Connection timeout"}
    except Exception as exc:
        # Deliberately broad: any failure is turned into a result dict
        # rather than an exception for the caller.
        return {"success": False, "message": f"Connection failed: {exc}"}

    if response.status_code == 200:
        return {"success": True, "message": "Connection successful"}
    return {
        "success": False,
        "message": f"API error: {response.status_code} - {response.text[:100]}",
    }
# Helper to convert string to UUID
def parse_uuid(id_str: str) -> uuid.UUID:
"""Parse string to UUID"""
try:
return uuid.UUID(id_str)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid UUID format")
@router.get("", response_model=ApiResponse)
async def list_models(db: AsyncSession = Depends(get_db)):
    """List the model configurations whose ``project_id`` is NULL.

    Returns:
        An ``ApiResponse`` whose ``data`` is a list of ``ModelResponse``
        objects, one per matching row.
    """
    stmt = select(ModelConfig).where(ModelConfig.project_id.is_(None))
    rows = (await db.execute(stmt)).scalars().all()

    # Validate each ORM row into the response schema.
    serialized = []
    for row in rows:
        serialized.append(ModelResponse.model_validate(row))
    return ApiResponse(data=serialized)
@router.post("", response_model=ApiResponse)
async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new model configuration with no project (``project_id`` NULL).

    When the incoming model is flagged as default (``is_default == "true"``),
    every existing project-less configuration is first marked non-default.

    Returns:
        An ``ApiResponse`` whose ``data`` is the created ``ModelResponse``.
    """
    wants_default = model.is_default == "true"
    if wants_default:
        # Clear the default flag on existing rows before inserting the new one.
        clear_defaults = (
            update(ModelConfig)
            .where(ModelConfig.project_id.is_(None))
            .values(is_default="false")
        )
        await db.execute(clear_defaults)

    fields = {
        "provider": model.provider,
        "model_type": model.model_type,
        "model_name": model.model_name,
        "api_key": model.api_key,
        "api_base": model.api_base,
        "is_default": model.is_default,
        # No owning project: this is a global model config.
        "project_id": None,
    }
    created = ModelConfig(**fields)

    db.add(created)
    await db.commit()
    await db.refresh(created)

    return ApiResponse(data=ModelResponse.model_validate(created))
@router.get("/{model_id}", response_model=ApiResponse)
async def get_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one project-less model configuration by its id.

    Raises:
        HTTPException: 400 from ``parse_uuid`` when ``model_id`` is not a
            valid UUID; 404 when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    stmt = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Model not found")

    return ApiResponse(data=ModelResponse.model_validate(found))
@router.put("/{model_id}", response_model=ApiResponse)
async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)):
    """Apply the fields set in the request body to an existing model configuration.

    Only fields explicitly present in ``model_update`` are written. If the
    update marks the model as default, all other project-less configurations
    are marked non-default first.

    Raises:
        HTTPException: 400 from ``parse_uuid`` when ``model_id`` is not a
            valid UUID; 404 when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    model = (await db.execute(lookup)).scalar_one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    if model_update.is_default == "true":
        # Only one default at a time: reset every other row.
        reset_others = (
            update(ModelConfig)
            .where(
                ModelConfig.project_id.is_(None),
                ModelConfig.id != target_id,
            )
            .values(is_default="false")
        )
        await db.execute(reset_others)

    # exclude_unset keeps fields the client did not send out of the update.
    changes = model_update.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(model, field_name, changes[field_name])

    await db.commit()
    await db.refresh(model)

    return ApiResponse(data=ModelResponse.model_validate(model))
@router.delete("/{model_id}", response_model=ApiResponse)
async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a project-less model configuration.

    Raises:
        HTTPException: 400 from ``parse_uuid`` when ``model_id`` is not a
            valid UUID; 404 when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(status_code=404, detail="Model not found")

    await db.delete(doomed)
    await db.commit()
    return ApiResponse(message="Model deleted successfully")
@router.post("/{model_id}/set-default", response_model=ApiResponse)
async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Mark one project-less model configuration as the default.

    Every other project-less configuration has ``is_default`` set to
    ``"false"``, then the target row is set to ``"true"``.

    Raises:
        HTTPException: 400 from ``parse_uuid`` when ``model_id`` is not a
            valid UUID; 404 when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    model = (await db.execute(lookup)).scalar_one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    # Reset the flag on all rows except the target.
    reset_others = (
        update(ModelConfig)
        .where(
            ModelConfig.project_id.is_(None),
            ModelConfig.id != target_id,
        )
        .values(is_default="false")
    )
    await db.execute(reset_others)

    model.is_default = "true"
    await db.commit()
    await db.refresh(model)

    return ApiResponse(data=ModelResponse.model_validate(model))
@router.post("/{model_id}/test", response_model=ApiResponse)
async def test_model(model_id: str, db: AsyncSession = Depends(get_db)):
    """Run a connection test for a model and store the outcome.

    After the test, the row's ``model_type`` is rewritten with its normalized
    value and ``connection_status`` is set to ``"connected"`` or
    ``"disconnected"`` according to the test result.

    Returns:
        An ``ApiResponse`` whose ``data`` holds ``test_result`` (the dict from
        ``test_model_connection``) and ``model`` (the refreshed
        ``ModelResponse``).

    Raises:
        HTTPException: 400 from ``parse_uuid`` when ``model_id`` is not a
            valid UUID; 404 when no matching row exists.
    """
    target_id = parse_uuid(model_id)
    lookup = select(ModelConfig).where(
        ModelConfig.id == target_id,
        ModelConfig.project_id.is_(None),
    )
    model = (await db.execute(lookup)).scalar_one_or_none()
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    outcome = await test_model_connection(model)

    # Persist the normalized type together with the observed status.
    model.model_type = normalize_model_type(model.model_type, model.model_name)
    if outcome["success"]:
        model.connection_status = "connected"
    else:
        model.connection_status = "disconnected"
    await db.commit()
    await db.refresh(model)

    refreshed = ModelResponse.model_validate(model)
    return ApiResponse(data={"test_result": outcome, "model": refreshed})