package service import ( "bytes" "encoding/json" "fmt" "io" "net/http" "x-agents/server/internal/model" "x-agents/server/internal/repository" "github.com/google/uuid" ) // ModelService 模型服务 type ModelService struct { repo *repository.ModelRepository } func NewModelService(repo *repository.ModelRepository) *ModelService { return &ModelService{repo: repo} } // List 获取模型列表 func (s *ModelService) List() ([]model.ModelInfo, error) { return s.repo.FindAll() } // GetByID 根据 ID 获取模型 func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) { return s.repo.FindByID(id) } // Create 创建模型 func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) { info := &model.ModelInfo{ ID: uuid.New().String(), Name: req.Name, ModelType: req.ModelType, Provider: req.Provider, Model: req.Model, APIKey: req.APIKey, BaseURL: req.BaseURL, APIEndpoint: req.APIEndpoint, Status: "active", } if err := s.repo.Create(info); err != nil { return nil, err } return info, nil } // Update 更新模型 func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) { // 检查是否存在 _, err := s.repo.FindByID(id) if err != nil { return nil, fmt.Errorf("model not found") } // 构建更新字段 fields := make(map[string]interface{}) if req.Name != "" { fields["name"] = req.Name } if req.ModelType != "" { fields["model_type"] = req.ModelType } if req.Provider != "" { fields["provider"] = req.Provider } if req.Model != "" { fields["model"] = req.Model } if req.APIKey != "" { fields["api_key"] = req.APIKey } if req.BaseURL != "" { fields["base_url"] = req.BaseURL } if req.APIEndpoint != "" { fields["api_endpoint"] = req.APIEndpoint } if req.Status != "" { fields["status"] = req.Status } if err := s.repo.UpdateFields(id, fields); err != nil { return nil, err } return s.repo.FindByID(id) } // Delete 删除模型 func (s *ModelService) Delete(id string) error { // 检查是否存在 _, err := s.repo.FindByID(id) if err != nil { return fmt.Errorf("model not found") } return s.repo.Delete(id) } // TestConnection 测试模型连接 func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) { // 构建请求 URL baseURL := req.BaseURL if req.APIEndpoint != "" { baseURL = baseURL + req.APIEndpoint } else { // 默认端点 switch req.Provider { case "OpenAI": baseURL = baseURL + "/v1/chat/completions" case "Ollama": baseURL = baseURL + "/api/chat" } } // 构建请求体 requestBody := map[string]interface{}{ "model": req.Model, "messages": []map[string]string{ {"role": "user", "content": "Hello"}, }, "max_tokens": 10, } body, err := json.Marshal(requestBody) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } // 创建 HTTP 请求 httpReq, err := http.NewRequest("POST", baseURL, bytes.NewBuffer(body)) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } httpReq.Header.Set("Content-Type", "application/json") if req.APIKey != "" { httpReq.Header.Set("Authorization", "Bearer "+req.APIKey) } // 发送请求 client := &http.Client{} resp, err := client.Do(httpReq) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } defer resp.Body.Close() // 读取响应 respBody, err := io.ReadAll(resp.Body) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } if resp.StatusCode >= 200 && resp.StatusCode < 300 { return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil } return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil }