From aabdf81073c4e35335185079483413c709995a05 Mon Sep 17 00:00:00 2001 From: "DESKTOP-72TV0V4\\caoxiaozhu" Date: Sat, 7 Mar 2026 13:53:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9EModel=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 Model 实体定义 - 实现 Model CRUD 接口 - 添加 Model 仓储层和服务层 Co-Authored-By: Claude Opus 4.6 --- server/internal/handler/model_handler.go | 108 +++++++++++++++ server/internal/model/model_info.go | 69 ++++++++++ server/internal/repository/model_repo.go | 53 ++++++++ server/internal/service/model_service.go | 166 +++++++++++++++++++++++ 4 files changed, 396 insertions(+) create mode 100644 server/internal/handler/model_handler.go create mode 100644 server/internal/model/model_info.go create mode 100644 server/internal/repository/model_repo.go create mode 100644 server/internal/service/model_service.go diff --git a/server/internal/handler/model_handler.go b/server/internal/handler/model_handler.go new file mode 100644 index 0000000..2f91f74 --- /dev/null +++ b/server/internal/handler/model_handler.go @@ -0,0 +1,108 @@ +package handler + +import ( + "net/http" + "x-agents/server/internal/model" + "x-agents/server/internal/service" + + "github.com/gin-gonic/gin" +) + +// ModelHandler 模型处理器 +type ModelHandler struct { + service *service.ModelService +} + +func NewModelHandler(svc *service.ModelService) *ModelHandler { + return &ModelHandler{service: svc} +} + +// List 获取列表 +func (h *ModelHandler) List(c *gin.Context) { + list, err := h.service.List() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if list == nil { + list = []model.ModelInfo{} + } + + c.JSON(http.StatusOK, gin.H{"list": list}) +} + +// GetByID 获取详情 +func (h *ModelHandler) GetByID(c *gin.Context) { + id := c.Param("id") + model, err := h.service.GetByID(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Model not found"}) + return + } + c.JSON(http.StatusOK, model) +} + +// Create 创建 +func (h *ModelHandler) Create(c *gin.Context) { + var req model.CreateModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result, err := h.service.Create(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// Update 更新 +func (h *ModelHandler) Update(c *gin.Context) { + id := c.Param("id") + var req model.UpdateModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result, err := h.service.Update(id, req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// Delete 删除 +func (h *ModelHandler) Delete(c *gin.Context) { + id := c.Param("id") + err := h.service.Delete(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// Test 测试连接 +func (h *ModelHandler) Test(c *gin.Context) { + var req model.TestModelRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result, err := h.service.TestConnection(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} diff --git a/server/internal/model/model_info.go b/server/internal/model/model_info.go new file mode 100644 index 0000000..2845b65 --- /dev/null +++ b/server/internal/model/model_info.go @@ -0,0 +1,69 @@ +package model + +import "time" + +// ModelInfo 模型信息 +type ModelInfo struct { + ID string `json:"id" gorm:"primaryKey;type:varchar(36)"` + Name string `json:"name" gorm:"type:varchar(255);not null"` + ModelType string `json:"model_type" gorm:"type:varchar(50);not null"` // chat/embedding/rerank/vlm + Provider string `json:"provider" gorm:"type:varchar(50);not null"` // OpenAI/Ollama + Model string `json:"model" gorm:"type:varchar(255);not null"` // 模型标识 + APIKey string `json:"api_key" gorm:"type:text"` // API 密钥 + BaseURL string `json:"base_url" gorm:"type:varchar(500)"` // 基础 URL + APIEndpoint string `json:"api_endpoint" gorm:"type:varchar(500)"` // API 端点路径 + Status string `json:"status" gorm:"type:varchar(20);default:active"` // active/inactive + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (ModelInfo) TableName() string { + return "model_info" +} + +// ModelListRequest 获取模型列表请求 +type ModelListRequest struct { +} + +// ModelListResponse 获取模型列表响应 +type ModelListResponse struct { + List []ModelInfo `json:"list"` +} + +// CreateModelRequest 创建模型请求 +type CreateModelRequest struct { + Name string `json:"name" binding:"required"` + ModelType string `json:"model_type" binding:"required"` + Provider string `json:"provider" binding:"required"` + Model string `json:"model" binding:"required"` + APIKey string `json:"api_key" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + APIEndpoint string `json:"api_endpoint"` +} + +// UpdateModelRequest 更新模型请求 +type UpdateModelRequest struct { + Name string `json:"name"` + ModelType string `json:"model_type"` + Provider string `json:"provider"` + Model string `json:"model"` + APIKey string `json:"api_key"` + BaseURL string `json:"base_url"` + APIEndpoint string `json:"api_endpoint"` + Status string `json:"status"` +} + +// TestModelRequest 测试模型连接请求 +type TestModelRequest struct { + Provider string `json:"provider" binding:"required"` + Model string `json:"model" binding:"required"` + APIKey string `json:"api_key" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + APIEndpoint string `json:"api_endpoint"` +} + +// TestModelResponse 测试模型连接响应 +type TestModelResponse struct { + Success bool `json:"success"` + Message string `json:"message"` +} diff --git a/server/internal/repository/model_repo.go b/server/internal/repository/model_repo.go new file mode 100644 index 0000000..0840bc6 --- /dev/null +++ b/server/internal/repository/model_repo.go @@ -0,0 +1,53 @@ +package repository + +import ( + "x-agents/server/internal/model" + + "gorm.io/gorm" +) + +// ModelRepository 模型仓储 +type ModelRepository struct { + db *gorm.DB +} + +func NewModelRepository(db *gorm.DB) *ModelRepository { + return &ModelRepository{db: db} +} + +// FindAll 获取所有模型 +func (r *ModelRepository) FindAll() ([]model.ModelInfo, error) { + var models []model.ModelInfo + err := r.db.Order("created_at desc").Find(&models).Error + return models, err +} + +// FindByID 根据 ID 获取模型 +func (r *ModelRepository) FindByID(id string) (*model.ModelInfo, error) { + var model model.ModelInfo + err := r.db.Where("id = ?", id).First(&model).Error + if err != nil { + return nil, err + } + return &model, nil +} + +// Create 创建模型 +func (r *ModelRepository) Create(info *model.ModelInfo) error { + return r.db.Create(info).Error +} + +// Update 更新模型 +func (r *ModelRepository) Update(id string, info *model.ModelInfo) error { + return r.db.Where("id = ?", id).Updates(info).Error +} + +// Delete 删除模型 +func (r *ModelRepository) Delete(id string) error { + return r.db.Where("id = ?", id).Delete(&model.ModelInfo{}).Error +} + +// UpdateFields 更新指定字段 +func (r *ModelRepository) UpdateFields(id string, fields map[string]interface{}) error { + return r.db.Model(&model.ModelInfo{}).Where("id = ?", id).Updates(fields).Error +} diff --git a/server/internal/service/model_service.go b/server/internal/service/model_service.go new file mode 100644 index 0000000..e1634fd --- /dev/null +++ b/server/internal/service/model_service.go @@ -0,0 +1,166 @@ +package service + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "x-agents/server/internal/model" + "x-agents/server/internal/repository" + + "github.com/google/uuid" +) + +// ModelService 模型服务 +type ModelService struct { + repo *repository.ModelRepository +} + +func NewModelService(repo *repository.ModelRepository) *ModelService { + return &ModelService{repo: repo} +} + +// List 获取模型列表 +func (s *ModelService) List() ([]model.ModelInfo, error) { + return s.repo.FindAll() +} + +// GetByID 根据 ID 获取模型 +func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) { + return s.repo.FindByID(id) +} + +// Create 创建模型 +func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) { + info := &model.ModelInfo{ + ID: uuid.New().String(), + Name: req.Name, + ModelType: req.ModelType, + Provider: req.Provider, + Model: req.Model, + APIKey: req.APIKey, + BaseURL: req.BaseURL, + APIEndpoint: req.APIEndpoint, + Status: "active", + } + + if err := s.repo.Create(info); err != nil { + return nil, err + } + return info, nil +} + +// Update 更新模型 +func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) { + // 检查是否存在 + _, err := s.repo.FindByID(id) + if err != nil { + return nil, fmt.Errorf("model not found") + } + + // 构建更新字段 + fields := make(map[string]interface{}) + if req.Name != "" { + fields["name"] = req.Name + } + if req.ModelType != "" { + fields["model_type"] = req.ModelType + } + if req.Provider != "" { + fields["provider"] = req.Provider + } + if req.Model != "" { + fields["model"] = req.Model + } + if req.APIKey != "" { + fields["api_key"] = req.APIKey + } + if req.BaseURL != "" { + fields["base_url"] = req.BaseURL + } + if req.APIEndpoint != "" { + fields["api_endpoint"] = req.APIEndpoint + } + if req.Status != "" { + fields["status"] = req.Status + } + + if err := s.repo.UpdateFields(id, fields); err != nil { + return nil, err + } + + return s.repo.FindByID(id) +} + +// Delete 删除模型 +func (s *ModelService) Delete(id string) error { + // 检查是否存在 + _, err := s.repo.FindByID(id) + if err != nil { + return fmt.Errorf("model not found") + } + return s.repo.Delete(id) +} + +// TestConnection 测试模型连接 +func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) { + // 构建请求 URL + baseURL := req.BaseURL + if req.APIEndpoint != "" { + baseURL = baseURL + req.APIEndpoint + } else { + // 默认端点 + switch req.Provider { + case "OpenAI": + baseURL = baseURL + "/v1/chat/completions" + case "Ollama": + baseURL = baseURL + "/api/chat" + } + } + + // 构建请求体 + requestBody := map[string]interface{}{ + "model": req.Model, + "messages": []map[string]string{ + {"role": "user", "content": "Hello"}, + }, + "max_tokens": 10, + } + + body, err := json.Marshal(requestBody) + if err != nil { + return &model.TestModelResponse{Success: false, Message: err.Error()}, nil + } + + // 创建 HTTP 请求 + httpReq, err := http.NewRequest("POST", baseURL, bytes.NewBuffer(body)) + if err != nil { + return &model.TestModelResponse{Success: false, Message: err.Error()}, nil + } + + httpReq.Header.Set("Content-Type", "application/json") + if req.APIKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+req.APIKey) + } + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(httpReq) + if err != nil { + return &model.TestModelResponse{Success: false, Message: err.Error()}, nil + } + defer resp.Body.Close() + + // 读取响应 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return &model.TestModelResponse{Success: false, Message: err.Error()}, nil + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil + } + + return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil +}