feat: 新增Model管理模块
- 添加 Model 实体定义 - 实现 Model CRUD 接口 - 添加 Model 仓储层和服务层 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
108
server/internal/handler/model_handler.go
Normal file
108
server/internal/handler/model_handler.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ModelHandler 模型处理器
|
||||
type ModelHandler struct {
|
||||
service *service.ModelService
|
||||
}
|
||||
|
||||
func NewModelHandler(svc *service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{service: svc}
|
||||
}
|
||||
|
||||
// List 获取列表
|
||||
func (h *ModelHandler) List(c *gin.Context) {
|
||||
list, err := h.service.List()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if list == nil {
|
||||
list = []model.ModelInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"list": list})
|
||||
}
|
||||
|
||||
// GetByID 获取详情
|
||||
func (h *ModelHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
model, err := h.service.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Model not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model)
|
||||
}
|
||||
|
||||
// Create 创建
|
||||
func (h *ModelHandler) Create(c *gin.Context) {
|
||||
var req model.CreateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.service.Create(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (h *ModelHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req model.UpdateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.service.Update(id, req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (h *ModelHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
err := h.service.Delete(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// Test 测试连接
|
||||
func (h *ModelHandler) Test(c *gin.Context) {
|
||||
var req model.TestModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.service.TestConnection(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
69
server/internal/model/model_info.go
Normal file
69
server/internal/model/model_info.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id" gorm:"primaryKey;type:varchar(36)"`
|
||||
Name string `json:"name" gorm:"type:varchar(255);not null"`
|
||||
ModelType string `json:"model_type" gorm:"type:varchar(50);not null"` // chat/embedding/rerank/vlm
|
||||
Provider string `json:"provider" gorm:"type:varchar(50);not null"` // OpenAI/Ollama
|
||||
Model string `json:"model" gorm:"type:varchar(255);not null"` // 模型标识
|
||||
APIKey string `json:"api_key" gorm:"type:text"` // API 密钥
|
||||
BaseURL string `json:"base_url" gorm:"type:varchar(500)"` // 基础 URL
|
||||
APIEndpoint string `json:"api_endpoint" gorm:"type:varchar(500)"` // API 端点路径
|
||||
Status string `json:"status" gorm:"type:varchar(20);default:active"` // active/inactive
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (ModelInfo) TableName() string {
|
||||
return "model_info"
|
||||
}
|
||||
|
||||
// ModelListRequest 获取模型列表请求
|
||||
type ModelListRequest struct {
|
||||
}
|
||||
|
||||
// ModelListResponse 获取模型列表响应
|
||||
type ModelListResponse struct {
|
||||
List []ModelInfo `json:"list"`
|
||||
}
|
||||
|
||||
// CreateModelRequest 创建模型请求
|
||||
type CreateModelRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ModelType string `json:"model_type" binding:"required"`
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
Model string `json:"model" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
APIEndpoint string `json:"api_endpoint"`
|
||||
}
|
||||
|
||||
// UpdateModelRequest 更新模型请求
|
||||
type UpdateModelRequest struct {
|
||||
Name string `json:"name"`
|
||||
ModelType string `json:"model_type"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIEndpoint string `json:"api_endpoint"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// TestModelRequest 测试模型连接请求
|
||||
type TestModelRequest struct {
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
Model string `json:"model" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
APIEndpoint string `json:"api_endpoint"`
|
||||
}
|
||||
|
||||
// TestModelResponse 测试模型连接响应
|
||||
type TestModelResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
53
server/internal/repository/model_repo.go
Normal file
53
server/internal/repository/model_repo.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"x-agents/server/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ModelRepository 模型仓储
|
||||
type ModelRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewModelRepository(db *gorm.DB) *ModelRepository {
|
||||
return &ModelRepository{db: db}
|
||||
}
|
||||
|
||||
// FindAll 获取所有模型
|
||||
func (r *ModelRepository) FindAll() ([]model.ModelInfo, error) {
|
||||
var models []model.ModelInfo
|
||||
err := r.db.Order("created_at desc").Find(&models).Error
|
||||
return models, err
|
||||
}
|
||||
|
||||
// FindByID 根据 ID 获取模型
|
||||
func (r *ModelRepository) FindByID(id string) (*model.ModelInfo, error) {
|
||||
var model model.ModelInfo
|
||||
err := r.db.Where("id = ?", id).First(&model).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
// Create 创建模型
|
||||
func (r *ModelRepository) Create(info *model.ModelInfo) error {
|
||||
return r.db.Create(info).Error
|
||||
}
|
||||
|
||||
// Update 更新模型
|
||||
func (r *ModelRepository) Update(id string, info *model.ModelInfo) error {
|
||||
return r.db.Where("id = ?", id).Updates(info).Error
|
||||
}
|
||||
|
||||
// Delete 删除模型
|
||||
func (r *ModelRepository) Delete(id string) error {
|
||||
return r.db.Where("id = ?", id).Delete(&model.ModelInfo{}).Error
|
||||
}
|
||||
|
||||
// UpdateFields 更新指定字段
|
||||
func (r *ModelRepository) UpdateFields(id string, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.ModelInfo{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
166
server/internal/service/model_service.go
Normal file
166
server/internal/service/model_service.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ModelService 模型服务
|
||||
type ModelService struct {
|
||||
repo *repository.ModelRepository
|
||||
}
|
||||
|
||||
func NewModelService(repo *repository.ModelRepository) *ModelService {
|
||||
return &ModelService{repo: repo}
|
||||
}
|
||||
|
||||
// List 获取模型列表
|
||||
func (s *ModelService) List() ([]model.ModelInfo, error) {
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取模型
|
||||
func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) {
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// Create 创建模型
|
||||
func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) {
|
||||
info := &model.ModelInfo{
|
||||
ID: uuid.New().String(),
|
||||
Name: req.Name,
|
||||
ModelType: req.ModelType,
|
||||
Provider: req.Provider,
|
||||
Model: req.Model,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
APIEndpoint: req.APIEndpoint,
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
if err := s.repo.Create(info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Update 更新模型
|
||||
func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) {
|
||||
// 检查是否存在
|
||||
_, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found")
|
||||
}
|
||||
|
||||
// 构建更新字段
|
||||
fields := make(map[string]interface{})
|
||||
if req.Name != "" {
|
||||
fields["name"] = req.Name
|
||||
}
|
||||
if req.ModelType != "" {
|
||||
fields["model_type"] = req.ModelType
|
||||
}
|
||||
if req.Provider != "" {
|
||||
fields["provider"] = req.Provider
|
||||
}
|
||||
if req.Model != "" {
|
||||
fields["model"] = req.Model
|
||||
}
|
||||
if req.APIKey != "" {
|
||||
fields["api_key"] = req.APIKey
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
fields["base_url"] = req.BaseURL
|
||||
}
|
||||
if req.APIEndpoint != "" {
|
||||
fields["api_endpoint"] = req.APIEndpoint
|
||||
}
|
||||
if req.Status != "" {
|
||||
fields["status"] = req.Status
|
||||
}
|
||||
|
||||
if err := s.repo.UpdateFields(id, fields); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// Delete 删除模型
|
||||
func (s *ModelService) Delete(id string) error {
|
||||
// 检查是否存在
|
||||
_, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model not found")
|
||||
}
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
|
||||
// TestConnection 测试模型连接
|
||||
func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) {
|
||||
// 构建请求 URL
|
||||
baseURL := req.BaseURL
|
||||
if req.APIEndpoint != "" {
|
||||
baseURL = baseURL + req.APIEndpoint
|
||||
} else {
|
||||
// 默认端点
|
||||
switch req.Provider {
|
||||
case "OpenAI":
|
||||
baseURL = baseURL + "/v1/chat/completions"
|
||||
case "Ollama":
|
||||
baseURL = baseURL + "/api/chat"
|
||||
}
|
||||
}
|
||||
|
||||
// 构建请求体
|
||||
requestBody := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "Hello"},
|
||||
},
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequest("POST", baseURL, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if req.APIKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil
|
||||
}
|
||||
|
||||
return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user