diff --git a/backend/app/api/dependencies.py b/backend/app/api/dependencies.py index 09721cc..b9784b9 100644 --- a/backend/app/api/dependencies.py +++ b/backend/app/api/dependencies.py @@ -2,7 +2,7 @@ API Dependencies API 依赖项 """ -from typing import Annotated +from typing import Annotated, Optional from fastapi import Depends from app.core.auth import verify_api_key diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py index aacac86..c41ec91 100644 --- a/backend/app/api/v1/__init__.py +++ b/backend/app/api/v1/__init__.py @@ -4,7 +4,7 @@ API v1 Router from fastapi import APIRouter -from app.api.v1 import files, projects, chunks, questions, datasets, eval +from app.api.v1 import files, projects, chunks, questions, datasets, eval, models api_router = APIRouter() @@ -15,3 +15,4 @@ api_router.include_router(chunks.router, prefix="/chunks", tags=["chunks"]) api_router.include_router(questions.router, prefix="/questions", tags=["questions"]) api_router.include_router(datasets.router, prefix="/datasets", tags=["datasets"]) api_router.include_router(eval.router, prefix="/eval", tags=["eval"]) +api_router.include_router(models.router, prefix="/models", tags=["models"]) diff --git a/backend/app/api/v1/models/__init__.py b/backend/app/api/v1/models/__init__.py new file mode 100644 index 0000000..f6c043f --- /dev/null +++ b/backend/app/api/v1/models/__init__.py @@ -0,0 +1,257 @@ +""" +Model API Router +""" +import uuid +import httpx +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, update + +from app.core.database import get_db +from app.api.response import ApiResponse +from app.models.models import ModelConfig +from app.schemas.model import ModelCreate, ModelUpdate, ModelResponse + +router = APIRouter() + + +async def test_model_connection(model: ModelConfig) -> dict: + """Test model connection by calling the API""" + if not model.api_key: + return {"success": False, "message": "API Key is missing"} + + api_base = model.api_base or "" + provider = model.provider + model_name = model.model_name + api_key = model.api_key + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + if provider == "openai": + # OpenAI compatible API test + response = await client.post( + f"{api_base.rstrip('/')}/chat/completions", + headers=headers, + json={ + "model": model_name, + "messages": [{"role": "user", "content": "Hi"}], + "max_tokens": 5 + } + ) + elif provider == "minimax": + # MiniMax API test + response = await client.post( + f"{api_base.rstrip('/')}/chat/completions_v2", + headers={ + **headers, + "Authorization": f"Bearer {api_key}" + }, + json={ + "model": model_name, + "messages": [{"role": "user", "content": "Hi"}] + } + ) + elif provider == "glm": + # GLM API test + response = await client.post( + f"{api_base.rstrip('/')}/chat/completions", + headers=headers, + json={ + "model": model_name, + "messages": [{"role": "user", "content": "Hi"}] + } + ) + else: + return {"success": False, "message": f"Unsupported provider: {provider}"} + + if response.status_code == 200: + return {"success": True, "message": "Connection successful"} + else: + return {"success": False, "message": f"API error: {response.status_code} - {response.text[:100]}"} + + except httpx.TimeoutException: + return {"success": False, "message": "Connection timeout"} + except Exception as e: + return {"success": False, "message": f"Connection failed: {str(e)}"} + + +# Helper to convert string to UUID +def parse_uuid(id_str: str) -> uuid.UUID: + """Parse string to UUID""" + try: + return uuid.UUID(id_str) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid UUID format") + + +@router.get("", response_model=ApiResponse) +async def list_models(db: AsyncSession = Depends(get_db)): + """Get all models""" + result = await db.execute( + select(ModelConfig).where(ModelConfig.project_id == None) # noqa: E711 + ) + models = result.scalars().all() + # Convert to Pydantic schema + model_responses = [ModelResponse.model_validate(m) for m in models] + return ApiResponse(data=model_responses) + + +@router.post("", response_model=ApiResponse) +async def create_model(model: ModelCreate, db: AsyncSession = Depends(get_db)): + """Create a new model""" + # If setting as default, unset other defaults first + if model.is_default == "true": + await db.execute( + update(ModelConfig) + .where(ModelConfig.project_id == None) # noqa: E711 + .values(is_default="false") + ) + + db_model = ModelConfig( + provider=model.provider, + model_name=model.model_name, + api_key=model.api_key, + api_base=model.api_base, + is_default=model.is_default, + project_id=None # Global model config + ) + db.add(db_model) + await db.commit() + await db.refresh(db_model) + # Convert to Pydantic schema + response = ModelResponse.model_validate(db_model) + return ApiResponse(data=response) + + +@router.get("/{model_id}", response_model=ApiResponse) +async def get_model(model_id: str, db: AsyncSession = Depends(get_db)): + """Get a model by ID""" + model_uuid = parse_uuid(model_id) + result = await db.execute( + select(ModelConfig).where( + ModelConfig.id == model_uuid, + ModelConfig.project_id == None # noqa: E711 + ) + ) + model = result.scalar_one_or_none() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + response = ModelResponse.model_validate(model) + return ApiResponse(data=response) + + +@router.put("/{model_id}", response_model=ApiResponse) +async def update_model(model_id: str, model_update: ModelUpdate, db: AsyncSession = Depends(get_db)): + """Update a model""" + model_uuid = parse_uuid(model_id) + result = await db.execute( + select(ModelConfig).where( + ModelConfig.id == model_uuid, + ModelConfig.project_id == None # noqa: E711 + ) + ) + model = result.scalar_one_or_none() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + # If setting as default, unset other defaults first + if model_update.is_default == "true": + await db.execute( + update(ModelConfig) + .where( + ModelConfig.project_id == None, # noqa: E711 + ModelConfig.id != model_uuid + ) + .values(is_default="false") + ) + + update_data = model_update.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(model, key, value) + + await db.commit() + await db.refresh(model) + response = ModelResponse.model_validate(model) + return ApiResponse(data=response) + + +@router.delete("/{model_id}", response_model=ApiResponse) +async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)): + """Delete a model""" + model_uuid = parse_uuid(model_id) + result = await db.execute( + select(ModelConfig).where( + ModelConfig.id == model_uuid, + ModelConfig.project_id == None # noqa: E711 + ) + ) + model = result.scalar_one_or_none() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + await db.delete(model) + await db.commit() + return ApiResponse(message="Model deleted successfully") + + +@router.post("/{model_id}/set-default", response_model=ApiResponse) +async def set_default_model(model_id: str, db: AsyncSession = Depends(get_db)): + """Set a model as default""" + model_uuid = parse_uuid(model_id) + result = await db.execute( + select(ModelConfig).where( + ModelConfig.id == model_uuid, + ModelConfig.project_id == None # noqa: E711 + ) + ) + model = result.scalar_one_or_none() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + # Unset all other defaults + await db.execute( + update(ModelConfig) + .where( + ModelConfig.project_id == None, # noqa: E711 + ModelConfig.id != model_uuid + ) + .values(is_default="false") + ) + + model.is_default = "true" + await db.commit() + await db.refresh(model) + response = ModelResponse.model_validate(model) + return ApiResponse(data=response) + + +@router.post("/{model_id}/test", response_model=ApiResponse) +async def test_model(model_id: str, db: AsyncSession = Depends(get_db)): + """Test model connection""" + model_uuid = parse_uuid(model_id) + result = await db.execute( + select(ModelConfig).where( + ModelConfig.id == model_uuid, + ModelConfig.project_id == None # noqa: E711 + ) + ) + model = result.scalar_one_or_none() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + # Test the connection + test_result = await test_model_connection(model) + + # Save connection status to database + model.connection_status = "connected" if test_result["success"] else "disconnected" + await db.commit() + await db.refresh(model) + + # Return updated model + response = ModelResponse.model_validate(model) + return ApiResponse(data={"test_result": test_result, "model": response}) diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 03270cd..244e522 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -100,3 +100,8 @@ async def get_db() -> AsyncSession: raise finally: await session.close() + + +# Import all models to register them with Base.metadata +# This ensures all models are loaded before create_all is called +from app.models.models import * # noqa: F401, F403, E402 diff --git a/backend/app/main.py b/backend/app/main.py index ea0c03b..7310ba7 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -21,6 +21,9 @@ from app.core.database import init_db, close_db from app.core.exceptions import AppException from app.core.logging import logger +# Import all models to register them with Base.metadata +from app.models.models import * # noqa: F401, F403 + class RequestIDMiddleware(BaseHTTPMiddleware): """Middleware to add request ID to each request""" @@ -83,7 +86,7 @@ app.add_middleware(RequestIDMiddleware) # CORS - Configure properly for production # For development, you can use ["*"] but for production, specify exact origins -ALLOWED_ORIGINS = settings.ALLOWED_ORIGINS.split(",") if settings.ALLOWED_ORIGINS else ["*"] +ALLOWED_ORIGINS = ["*"] app.add_middleware( CORSMiddleware, diff --git a/backend/app/models/models.py b/backend/app/models/models.py index 37264f9..4d03ccd 100644 --- a/backend/app/models/models.py +++ b/backend/app/models/models.py @@ -135,12 +135,13 @@ class ModelConfig(Base, UUIDMixin, TimestampMixin): """Model configuration for LLM providers""" __tablename__ = "model_configs" - project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) - provider = Column(String(50), nullable=False) # openai, anthropic, ollama, custom + project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True) + provider = Column(String(50), nullable=False) # minimax, glm, openai model_name = Column(String(100)) api_key = Column(String(500)) api_base = Column(String(500)) is_default = Column(String(10), default="false") + connection_status = Column(String(20), default="untested") # untested, connected, disconnected # Relationships project = relationship("Project", back_populates="model_configs") diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index 09b144b..7e3a668 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -50,6 +50,13 @@ from app.schemas.eval import ( TaskResponse, ) +from app.schemas.model import ( + ModelBase, + ModelCreate, + ModelUpdate, + ModelResponse, +) + __all__ = [ # Base "TimestampMixin", @@ -86,4 +93,9 @@ __all__ = [ "EvalDatasetResponse", "TaskBase", "TaskResponse", + # Model + "ModelBase", + "ModelCreate", + "ModelUpdate", + "ModelResponse", ] diff --git a/backend/app/schemas/model.py b/backend/app/schemas/model.py new file mode 100644 index 0000000..f21d19c --- /dev/null +++ b/backend/app/schemas/model.py @@ -0,0 +1,41 @@ +""" +Model Schema +""" +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional +from datetime import datetime +from uuid import UUID + + +class ModelBase(BaseModel): + """Base model schema""" + provider: str = Field(..., description="Model provider: minimax, glm, openai") + model_name: str = Field(..., description="Model name") + api_key: Optional[str] = Field(None, description="API key") + api_base: Optional[str] = Field(None, description="API base URL") + is_default: str = Field(default="false", description="Is default model: true/false") + + +class ModelCreate(ModelBase): + """Model creation schema""" + pass + + +class ModelUpdate(BaseModel): + """Model update schema""" + provider: Optional[str] = None + model_name: Optional[str] = None + api_key: Optional[str] = None + api_base: Optional[str] = None + is_default: Optional[str] = None + + +class ModelResponse(ModelBase): + """Model response schema""" + id: UUID + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + project_id: Optional[UUID] = None + connection_status: Optional[str] = Field(default="untested") + + model_config = ConfigDict(from_attributes=True) diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 75aaafc..ed98c14 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -1,6 +1,6 @@ import axios from 'axios' import type { AxiosInstance } from 'axios' -import type { Project, ProjectCreate, ProjectUpdate } from '@/types' +import type { Project, ProjectCreate, ProjectUpdate, Model, ModelCreate } from '@/types' const request: AxiosInstance = axios.create({ baseURL: import.meta.env.PROD @@ -91,4 +91,14 @@ export const evalApi = { getResults: (projectId: string, taskId: string) => request.get(`/projects/${projectId}/eval-tasks/${taskId}`) } +export const modelApi = { + list: () => request.get('/models/'), + get: (id: string) => request.get(`/models/${id}`), + create: (data: ModelCreate) => request.post<{ id: string }>('/models/', data), + update: (id: string, data: Partial) => request.put(`/models/${id}`, data), + delete: (id: string) => request.delete(`/models/${id}`), + setDefault: (id: string) => request.post(`/models/${id}/set-default`), + test: (id: string) => request.post<{ success: boolean; message: string }>(`/models/${id}/test`) +} + export default request diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js index e8b4aa4..920d9c1 100644 --- a/frontend/src/router/index.js +++ b/frontend/src/router/index.js @@ -51,11 +51,6 @@ const routes = [ path: '/models', name: 'ModelSettings', component: () => import('@/views/ModelSettingsView.vue') - }, - { - path: '/data-square', - name: 'DataSquare', - component: () => import('@/views/DataSquareView.vue') } ] diff --git a/frontend/src/types/model.d.ts b/frontend/src/types/model.d.ts index d53fd55..7d61b1f 100644 --- a/frontend/src/types/model.d.ts +++ b/frontend/src/types/model.d.ts @@ -2,6 +2,18 @@ * Model Configuration Types */ +export interface Model { + id: string + provider: ModelProvider + model_name: string + api_key?: string + api_base?: string + is_default: 'true' | 'false' + connection_status?: 'untested' | 'connected' | 'disconnected' + created_at?: string + updated_at?: string +} + export interface ModelConfig { id: string provider: ModelProvider @@ -9,11 +21,12 @@ export interface ModelConfig { api_key?: string api_base?: string is_default: 'true' | 'false' + connection_status?: 'untested' | 'connected' | 'disconnected' created_at?: string updated_at?: string } -export type ModelProvider = 'openai' | 'anthropic' | 'google' | 'other' +export type ModelProvider = 'minimax' | 'glm' | 'openai' export interface ModelCreate { provider: ModelProvider diff --git a/frontend/src/views/HomeView.vue b/frontend/src/views/HomeView.vue index 3f2aec2..ad97e56 100644 --- a/frontend/src/views/HomeView.vue +++ b/frontend/src/views/HomeView.vue @@ -3,6 +3,33 @@
+ + +
AI 驱动数据生成 @@ -20,103 +47,85 @@ 创建项目 - - - 数据集广场 + + + 模型管理
- +
- -
-
-
-
- - - - - - - - + +
+
+
+
+
+
+
+ + + + + + + + + +
+
+
+ + +
+
+
+
+
-
-
-
-
- -
- 多格式支持 - PDF DOCX EPUB Excel +
+
+
+
+
+
+ + 处理完成
- -
-
-
-
- - - - - - - - -
-
-
-
-
- -
- AI 生成 - 智能问答 自动标注 -
+ +
+ + 多格式支持 +
+
+ + AI 生成 +
+
+ + 智能评估
- -
-
-
-
- - - - - - - - -
-
-
-
-
- -
- 智能评估 - 质量分析 模型对比 -
+ +
+ + API 集成
-
-
- - -
-
-
- +
+ + 批量处理
-
-

模型配置

-

管理 AI 模型 API 配置

+
+ + 数据安全 +
+
+ + 可视化
-
@@ -179,7 +188,7 @@ import { ref, onMounted } from 'vue' import { useRouter } from 'vue-router' import { ElMessage } from 'element-plus' -import { FolderAdd } from '@element-plus/icons-vue' +import { FolderAdd, Check, Connection, Clock, Lock, TrendCharts } from '@element-plus/icons-vue' import { projectApi } from '@/api' import type { Project, ProjectCreate } from '@/types' @@ -275,5 +284,5 @@ onMounted(() => fetchProjects()) diff --git a/frontend/src/views/ModelSettingsView.vue b/frontend/src/views/ModelSettingsView.vue index 861d02c..9a63e00 100644 --- a/frontend/src/views/ModelSettingsView.vue +++ b/frontend/src/views/ModelSettingsView.vue @@ -17,9 +17,9 @@

- 模型配置 + 模型管理

-

管理您的 AI 模型 API 配置

+

管理您的 AI 模型 API

@@ -31,19 +31,6 @@
- -
-
-
- {{ stat.icon }} -
-
- {{ stat.value }} - {{ stat.label }} -
-
-
-
@@ -98,9 +85,11 @@