Files
X-Agents/server/internal/service/model_service.go
DESKTOP-72TV0V4\caoxiaozhu e5ea4ff359 feat: 更新数据库和后端服务
- 新增chat_sessions和chat_groups数据库表
- 更新skill_handler和model相关接口
- 修改main.go注册新路由

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 14:33:54 +08:00

223 lines
5.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
"github.com/google/uuid"
)
// ModelService holds the business logic for managing model configurations.
// All persistence is delegated to the wrapped ModelRepository.
type ModelService struct {
	repo *repository.ModelRepository // storage backend for model records
}
// NewModelService returns a ModelService that reads and writes model
// records through repo.
func NewModelService(repo *repository.ModelRepository) *ModelService {
	svc := new(ModelService)
	svc.repo = repo
	return svc
}
// List returns all stored models. It passes through the repository's FindAll
// result and error unchanged.
func (s *ModelService) List() ([]model.ModelInfo, error) {
	models, err := s.repo.FindAll()
	return models, err
}
// GetByID returns the model whose ID equals id. It passes through the
// repository's FindByID result and error unchanged.
func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) {
	found, err := s.repo.FindByID(id)
	return found, err
}
// Create stores a new model configuration built from req and returns it.
//
// The new record gets a freshly generated UUID. Status is stored exactly as
// supplied; its zero value (0) means inactive, so an omitted status already
// yields an inactive model.
//
// It returns an error if FindByName reports an existing model with the same
// name. Any error from the repository's Create call is returned unchanged.
//
// NOTE(review): a FindByName error is treated as "no duplicate", so a lookup
// failure does not block creation. The name check and the insert are also
// separate calls. Unless the database enforces a unique index on the name
// column, two concurrent requests with the same name can both succeed.
func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) {
	if existing, err := s.repo.FindByName(req.Name); err == nil && existing != nil {
		return nil, fmt.Errorf("model with name '%s' already exists", req.Name)
	}

	info := &model.ModelInfo{
		ID:          uuid.New().String(),
		Name:        req.Name,
		ModelType:   req.ModelType,
		Provider:    req.Provider,
		Model:       req.Model,
		APIKey:      req.APIKey,
		BaseURL:     req.BaseURL,
		APIEndpoint: req.APIEndpoint,
		// The zero value (0) is "inactive"; no explicit defaulting is needed.
		Status: req.Status,
	}
	if err := s.repo.Create(info); err != nil {
		return nil, err
	}
	return info, nil
}
// Update applies a partial update to the model identified by id. It returns
// the record as re-read from the repository after the write.
//
// String fields are written only when non-empty. Status is written only when
// it is greater than zero; 0 means inactive and 1 means active. If the initial
// lookup fails for any reason, the error "model not found" is returned. Errors
// from UpdateFields and the final FindByID are returned unchanged.
//
// NOTE(review): because an omitted Status is indistinguishable from 0, this
// method can never switch a model back to inactive. Fixing that needs
// UpdateModelRequest.Status to encode "unset" separately (e.g. a *int).
func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) {
	if _, lookupErr := s.repo.FindByID(id); lookupErr != nil {
		return nil, fmt.Errorf("model not found")
	}

	// A column goes into the update set only when its "present" predicate
	// holds. Values are stored as interface{} so they keep their original
	// Go types.
	candidates := []struct {
		column  string
		present bool
		value   interface{}
	}{
		{"name", req.Name != "", req.Name},
		{"model_type", req.ModelType != "", req.ModelType},
		{"provider", req.Provider != "", req.Provider},
		{"model", req.Model != "", req.Model},
		{"api_key", req.APIKey != "", req.APIKey},
		{"base_url", req.BaseURL != "", req.BaseURL},
		{"api_endpoint", req.APIEndpoint != "", req.APIEndpoint},
		{"status", req.Status > 0, req.Status},
	}

	changes := make(map[string]interface{}, len(candidates))
	for _, c := range candidates {
		if c.present {
			changes[c.column] = c.value
		}
	}

	if err := s.repo.UpdateFields(id, changes); err != nil {
		return nil, err
	}
	return s.repo.FindByID(id)
}
// Delete removes the model with the given id.
//
// The model is looked up first. If that lookup fails for any reason, the
// error "model not found" is returned and no delete is attempted. Otherwise
// the repository's Delete error, if any, is returned unchanged.
func (s *ModelService) Delete(id string) error {
	if _, lookupErr := s.repo.FindByID(id); lookupErr != nil {
		return fmt.Errorf("model not found")
	}
	return s.repo.Delete(id)
}
// TestConnection sends one minimal request to the model endpoint described by
// req. It reports whether the endpoint answered with a 2xx status.
//
// Every failure is reported as TestModelResponse{Success: false} with a
// descriptive Message, and the returned error is always nil. The failures
// covered are:
//   - marshalling the payload
//   - building the request
//   - the transport (including the 10 s client timeout)
//   - a non-2xx status
//
// For a non-2xx status, Message has the form "HTTP <code>: <body>". At most
// maxErrorBodyBytes of the body are included.
//
// NOTE(review): req.BaseURL is caller-controlled, so this makes an outbound
// request to an arbitrary host and echoes part of its response. Make sure the
// calling handler is restricted to trusted users.
func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) {
	// Cap on how much of an error response is buffered and echoed back, so
	// a misbehaving endpoint cannot force an unbounded read.
	const maxErrorBodyBytes = 4 << 10

	body, err := json.Marshal(modelTestPayload(req))
	if err != nil {
		return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
	}

	httpReq, err := http.NewRequest(http.MethodPost, modelTestURL(req), bytes.NewReader(body))
	if err != nil {
		return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if req.APIKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
		switch req.Provider {
		case "ali", "Ali", "aliyun", "Aliyun":
			// Aliyun additionally receives the key in an x-api-key header.
			httpReq.Header.Set("x-api-key", req.APIKey)
		}
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
	}
	defer resp.Body.Close()

	// The status code alone decides success, so the body is not read on 2xx.
	// Draining a streamed response here could otherwise run into the client
	// timeout and report a false failure.
	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
		return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil
	}

	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
	if err != nil {
		return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
	}
	return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil
}

// modelTestURL returns the URL TestConnection posts to.
//
// If req.APIEndpoint is set, it is joined onto req.BaseURL with exactly one
// slash between them. Otherwise a default path is chosen from req.ModelType
// and req.Provider.
func modelTestURL(req model.TestModelRequest) string {
	base := strings.TrimRight(req.BaseURL, "/")
	if req.APIEndpoint != "" {
		return base + "/" + strings.TrimLeft(req.APIEndpoint, "/")
	}

	if req.ModelType == "embedding" {
		switch {
		case req.Provider == "Ollama":
			return base + "/api/embeddings"
		case strings.HasSuffix(base, "/v1"):
			// OpenAI-style base URLs usually already end in /v1 (as the
			// chat default below assumes). Avoid producing .../v1/v1/embeddings.
			return base + "/embeddings"
		default:
			return base + "/v1/embeddings"
		}
	}

	// Chat models use Ollama's native chat API. Everything else (OpenAI,
	// Aliyun DashScope compatible mode, and unknown providers) uses the
	// OpenAI-compatible path.
	if req.Provider == "Ollama" {
		return base + "/api/chat"
	}
	return base + "/chat/completions"
}

// modelTestPayload builds the smallest useful JSON body for a connection test:
// a one-word embedding input, or a one-message chat capped at 10 tokens.
func modelTestPayload(req model.TestModelRequest) map[string]interface{} {
	if req.ModelType == "embedding" {
		payload := map[string]interface{}{
			"model": req.Model,
			"input": "Hello",
			// OpenAI-compatible APIs name this parameter encoding_format.
			"encoding_format": "float",
		}
		if req.Provider == "Ollama" {
			// Ollama's /api/embeddings reads "prompt", not "input".
			payload["prompt"] = "Hello"
		}
		return payload
	}

	payload := map[string]interface{}{
		"model": req.Model,
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
		"max_tokens": 10,
	}
	if req.Provider == "Ollama" {
		// Ollama's /api/chat streams by default and ignores max_tokens.
		// Request a single response, with the token limit passed via
		// options.num_predict.
		payload["stream"] = false
		payload["options"] = map[string]interface{}{"num_predict": 10}
	}
	return payload
}