Files
X-Agents/server/internal/service/model_service.go

202 lines
5.2 KiB
Go
Raw Normal View History

package service
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
"github.com/google/uuid"
)
// ModelService implements the model-management operations (list, lookup,
// create, partial update, delete) and the endpoint connectivity probe.
// All persistence goes through the wrapped ModelRepository.
type ModelService struct {
	// repo is the persistence layer used by every CRUD method.
	repo *repository.ModelRepository
}
// NewModelService returns a ModelService that reads and writes models
// through repo.
func NewModelService(repo *repository.ModelRepository) *ModelService {
	svc := new(ModelService)
	svc.repo = repo
	return svc
}
// List returns all stored models as reported by the repository's FindAll.
// The repository's result and error are passed through unchanged.
func (s *ModelService) List() ([]model.ModelInfo, error) {
	models, err := s.repo.FindAll()
	return models, err
}
// GetByID looks up a single model by its ID via the repository.
// The repository's result and error are passed through unchanged.
func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) {
	found, err := s.repo.FindByID(id)
	return found, err
}
// Create persists a new model built from req and returns it.
//
// A fresh UUID is generated for the ID. When req.Status is empty the model
// is stored with status "inactive". If the repository insert fails, its
// error is returned and no model is returned.
func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) {
	created := &model.ModelInfo{
		ID:          uuid.New().String(),
		Name:        req.Name,
		ModelType:   req.ModelType,
		Provider:    req.Provider,
		Model:       req.Model,
		APIKey:      req.APIKey,
		BaseURL:     req.BaseURL,
		APIEndpoint: req.APIEndpoint,
		Status:      req.Status,
	}
	// Status is optional on creation; new models default to inactive.
	if created.Status == "" {
		created.Status = "inactive"
	}

	if err := s.repo.Create(created); err != nil {
		return nil, err
	}
	return created, nil
}
// Update applies a partial update to the model identified by id and returns
// the record as re-read from the repository afterwards.
//
// Only non-empty fields of req are written, keyed by column name. An empty
// string therefore means "leave unchanged", so a field cannot be cleared
// through this method. Any failure of the initial lookup is reported as
// "model not found"; errors from the write or the re-read are returned as-is.
func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) {
	if _, lookupErr := s.repo.FindByID(id); lookupErr != nil {
		return nil, fmt.Errorf("model not found")
	}

	changes := make(map[string]interface{})
	if v := req.Name; v != "" {
		changes["name"] = v
	}
	if v := req.ModelType; v != "" {
		changes["model_type"] = v
	}
	if v := req.Provider; v != "" {
		changes["provider"] = v
	}
	if v := req.Model; v != "" {
		changes["model"] = v
	}
	if v := req.APIKey; v != "" {
		changes["api_key"] = v
	}
	if v := req.BaseURL; v != "" {
		changes["base_url"] = v
	}
	if v := req.APIEndpoint; v != "" {
		changes["api_endpoint"] = v
	}
	if v := req.Status; v != "" {
		changes["status"] = v
	}

	if writeErr := s.repo.UpdateFields(id, changes); writeErr != nil {
		return nil, writeErr
	}
	return s.repo.FindByID(id)
}
// Delete removes the model identified by id.
//
// The model is looked up first. Any lookup failure is reported as
// "model not found", and the repository's Delete is not called in that case.
// Otherwise the repository's delete result is returned unchanged.
func (s *ModelService) Delete(id string) error {
	if _, lookupErr := s.repo.FindByID(id); lookupErr != nil {
		return fmt.Errorf("model not found")
	}
	return s.repo.Delete(id)
}
// TestConnection probes a model endpoint by sending a minimal chat request
// and reports whether it answered with a 2xx status.
//
// The target URL is built by testConnURL: req.BaseURL joined with
// req.APIEndpoint when that is set, otherwise a provider-specific default
// path. Provider names are matched case-insensitively, so "Ollama" and
// "ollama" behave the same. The payload is a single "Hello" user message
// capped at 10 output tokens, with streaming explicitly disabled.
//
// Every failure is reported through the returned TestModelResponse with
// Success=false: a missing base_url, request construction, a transport
// error, or a non-2xx status. The error result is always nil.
//
// NOTE(review): this makes the server POST to a caller-chosen URL. If
// untrusted users can reach this endpoint it is an SSRF vector; consider
// restricting the allowed schemes and hosts.
func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) {
	const (
		// probeTimeout bounds the whole exchange, including reading the body.
		probeTimeout = 10 * time.Second
		// probeMaxTokens keeps the generated reply, and therefore latency, tiny.
		probeMaxTokens = 10
		// probeMaxBody caps how much of the response is buffered, logged and
		// echoed back; the remote side is caller-chosen and not trusted.
		probeMaxBody = 64 << 10
	)

	log.Printf("[TestConnection] 开始测试连接: provider=%s, model=%s, base_url=%s, api_endpoint=%s", req.Provider, req.Model, req.BaseURL, req.APIEndpoint)

	fail := func(msg string) (*model.TestModelResponse, error) {
		return &model.TestModelResponse{Success: false, Message: msg}, nil
	}

	if strings.TrimSpace(req.BaseURL) == "" {
		return fail("base_url is required")
	}

	provider := strings.ToLower(strings.TrimSpace(req.Provider))
	requestURL := testConnURL(provider, req.BaseURL, req.APIEndpoint)
	log.Printf("[TestConnection] 请求 URL: %s", requestURL)

	payload := map[string]interface{}{
		"model": req.Model,
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
		"max_tokens": probeMaxTokens,
		// Ollama's /api/chat streams by default. Without this, the probe would
		// wait for the whole streamed reply inside probeTimeout. OpenAI-style
		// APIs already default to false, so sending it is harmless there.
		"stream": false,
	}
	if provider == "ollama" {
		// Ollama's native API ignores max_tokens; its equivalent is
		// options.num_predict.
		payload["options"] = map[string]interface{}{"num_predict": probeMaxTokens}
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return fail(err.Error())
	}

	httpReq, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewReader(body))
	if err != nil {
		return fail(err.Error())
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if req.APIKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
		// Aliyun additionally receives the key in x-api-key.
		// NOTE(review): DashScope's OpenAI-compatible mode should only need
		// the Bearer token; confirm before dropping this header.
		if provider == "ali" || provider == "aliyun" {
			httpReq.Header.Set("x-api-key", req.APIKey)
		}
	}

	client := &http.Client{Timeout: probeTimeout}
	resp, err := client.Do(httpReq)
	if err != nil {
		return fail(err.Error())
	}
	defer resp.Body.Close()

	respBody, truncated, err := testConnReadBody(resp.Body, probeMaxBody)
	if err != nil {
		return fail(err.Error())
	}
	bodyText := string(respBody)
	if truncated {
		bodyText += "...(truncated)"
	}
	log.Printf("[TestConnection] 响应状态码: %d, 响应体: %s", resp.StatusCode, bodyText)

	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil
	}
	return fail(fmt.Sprintf("HTTP %d: %s", resp.StatusCode, bodyText))
}

// testConnURL builds the URL probed by TestConnection.
//
// Surrounding whitespace and trailing slashes are removed from baseURL. A
// non-empty apiEndpoint is then appended, minus its leading slashes.
// Otherwise a default path is chosen from provider, which must already be
// lower-cased:
//   - "ollama": /api/chat, Ollama's native chat API.
//   - "ali", "aliyun": /chat/completions, because the DashScope base URL is
//     expected to include the version, e.g.
//     https://dashscope.aliyuncs.com/compatible-mode/v1
//   - anything else: the OpenAI-compatible /v1/chat/completions, or just
//     /chat/completions when baseURL already ends in /v1 (as in
//     https://api.openai.com/v1) so the version segment is not doubled.
func testConnURL(provider, baseURL, apiEndpoint string) string {
	base := strings.TrimRight(strings.TrimSpace(baseURL), "/")
	if apiEndpoint != "" {
		return base + "/" + strings.TrimLeft(apiEndpoint, "/")
	}
	switch provider {
	case "ollama":
		return base + "/api/chat"
	case "ali", "aliyun":
		return base + "/chat/completions"
	}
	if strings.HasSuffix(base, "/v1") {
		return base + "/chat/completions"
	}
	return base + "/v1/chat/completions"
}

// testConnReadBody reads at most limit bytes from r. truncated reports
// whether r held more data than that.
func testConnReadBody(r io.Reader, limit int64) (data []byte, truncated bool, err error) {
	// Read one extra byte so an over-long body can be told apart from one
	// that is exactly limit bytes long.
	data, err = io.ReadAll(io.LimitReader(r, limit+1))
	if err != nil {
		return nil, false, err
	}
	if int64(len(data)) > limit {
		return data[:limit], true, nil
	}
	return data, false, nil
}