package service import ( "bytes" "encoding/json" "fmt" "io" "log" "net/http" "strings" "time" "x-agents/server/internal/model" "x-agents/server/internal/repository" "github.com/google/uuid" ) // ModelService 模型服务 type ModelService struct { repo *repository.ModelRepository } func NewModelService(repo *repository.ModelRepository) *ModelService { return &ModelService{repo: repo} } // List 获取模型列表 func (s *ModelService) List() ([]model.ModelInfo, error) { return s.repo.FindAll() } // GetByID 根据 ID 获取模型 func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) { return s.repo.FindByID(id) } // Create 创建模型 func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) { // 如果没有提供状态,默认设置为 inactive status := req.Status if status == "" { status = "inactive" } info := &model.ModelInfo{ ID: uuid.New().String(), Name: req.Name, ModelType: req.ModelType, Provider: req.Provider, Model: req.Model, APIKey: req.APIKey, BaseURL: req.BaseURL, APIEndpoint: req.APIEndpoint, Status: status, } if err := s.repo.Create(info); err != nil { return nil, err } return info, nil } // Update 更新模型 func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) { // 检查是否存在 _, err := s.repo.FindByID(id) if err != nil { return nil, fmt.Errorf("model not found") } // 构建更新字段 fields := make(map[string]interface{}) if req.Name != "" { fields["name"] = req.Name } if req.ModelType != "" { fields["model_type"] = req.ModelType } if req.Provider != "" { fields["provider"] = req.Provider } if req.Model != "" { fields["model"] = req.Model } if req.APIKey != "" { fields["api_key"] = req.APIKey } if req.BaseURL != "" { fields["base_url"] = req.BaseURL } if req.APIEndpoint != "" { fields["api_endpoint"] = req.APIEndpoint } if req.Status != "" { fields["status"] = req.Status } if err := s.repo.UpdateFields(id, fields); err != nil { return nil, err } return s.repo.FindByID(id) } // Delete 删除模型 func (s *ModelService) Delete(id string) error { // 检查是否存在 _, err := s.repo.FindByID(id) if err != nil { return fmt.Errorf("model not found") } return s.repo.Delete(id) } // TestConnection 测试模型连接 func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) { // 构建请求 URL baseURL := req.BaseURL // 去掉 base_url 末尾的斜杠 baseURL = strings.TrimRight(baseURL, "/") if req.APIEndpoint != "" { // 去掉 api_endpoint 开头的斜杠 apiEndpoint := strings.TrimLeft(req.APIEndpoint, "/") baseURL = baseURL + "/" + apiEndpoint } else { // 根据 model_type 确定端点 switch req.ModelType { case "embedding": // embedding 模型使用 /v1/embeddings switch req.Provider { case "Ollama": baseURL = baseURL + "/api/embeddings" default: baseURL = baseURL + "/v1/embeddings" } default: // chat 模型使用 /chat/completions switch req.Provider { case "OpenAI": baseURL = baseURL + "/v1/chat/completions" case "Ollama": baseURL = baseURL + "/api/chat" case "ali", "Ali", "aliyun", "Aliyun": // 阿里云 DashScope 兼容 OpenAI 格式 baseURL = baseURL + "/chat/completions" default: // 默认使用 OpenAI 兼容格式 baseURL = baseURL + "/v1/chat/completions" } } } // 构建请求体 - 根据 model_type 使用不同的格式 var requestBody map[string]interface{} if req.ModelType == "embedding" { requestBody = map[string]interface{}{ "model": req.Model, "input": "Hello", "format": "float", } } else { requestBody = map[string]interface{}{ "model": req.Model, "messages": []map[string]string{ {"role": "user", "content": "Hello"}, }, "max_tokens": 10, } } body, err := json.Marshal(requestBody) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } // 创建 HTTP 请求 httpReq, err := http.NewRequest("POST", baseURL, bytes.NewBuffer(body)) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } httpReq.Header.Set("Content-Type", "application/json") if req.APIKey != "" { // 根据不同 provider 使用不同的认证方式 switch req.Provider { case "ali", "Ali", "aliyun", "Aliyun": // 阿里云使用 x-api-key 头部 httpReq.Header.Set("Authorization", "Bearer "+req.APIKey) httpReq.Header.Set("x-api-key", req.APIKey) default: httpReq.Header.Set("Authorization", "Bearer "+req.APIKey) } } // 发送请求,设置 10 秒超时 client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(httpReq) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } defer resp.Body.Close() // 读取响应 respBody, err := io.ReadAll(resp.Body) if err != nil { return &model.TestModelResponse{Success: false, Message: err.Error()}, nil } if resp.StatusCode >= 200 && resp.StatusCode < 300 { return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil } return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil }