- 新增chat_sessions和chat_groups数据库表 - 更新skill_handler和model相关接口 - 修改main.go注册新路由 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
223 lines
5.6 KiB
Go
223 lines
5.6 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
"x-agents/server/internal/model"
|
||
"x-agents/server/internal/repository"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// ModelService 模型服务
|
||
type ModelService struct {
|
||
repo *repository.ModelRepository
|
||
}
|
||
|
||
func NewModelService(repo *repository.ModelRepository) *ModelService {
|
||
return &ModelService{repo: repo}
|
||
}
|
||
|
||
// List 获取模型列表
|
||
func (s *ModelService) List() ([]model.ModelInfo, error) {
|
||
return s.repo.FindAll()
|
||
}
|
||
|
||
// GetByID 根据 ID 获取模型
|
||
func (s *ModelService) GetByID(id string) (*model.ModelInfo, error) {
|
||
return s.repo.FindByID(id)
|
||
}
|
||
|
||
// Create 创建模型
|
||
func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, error) {
|
||
// 检查模型名称是否已存在
|
||
existing, err := s.repo.FindByName(req.Name)
|
||
if err == nil && existing != nil {
|
||
return nil, fmt.Errorf("model with name '%s' already exists", req.Name)
|
||
}
|
||
|
||
// 如果没有提供状态,默认设置为 inactive (0)
|
||
status := req.Status
|
||
if status == 0 {
|
||
status = 0 // inactive
|
||
}
|
||
|
||
info := &model.ModelInfo{
|
||
ID: uuid.New().String(),
|
||
Name: req.Name,
|
||
ModelType: req.ModelType,
|
||
Provider: req.Provider,
|
||
Model: req.Model,
|
||
APIKey: req.APIKey,
|
||
BaseURL: req.BaseURL,
|
||
APIEndpoint: req.APIEndpoint,
|
||
Status: status,
|
||
}
|
||
|
||
if err := s.repo.Create(info); err != nil {
|
||
return nil, err
|
||
}
|
||
return info, nil
|
||
}
|
||
|
||
// Update 更新模型
|
||
func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.ModelInfo, error) {
|
||
// 检查是否存在
|
||
_, err := s.repo.FindByID(id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("model not found")
|
||
}
|
||
|
||
// 构建更新字段
|
||
fields := make(map[string]interface{})
|
||
if req.Name != "" {
|
||
fields["name"] = req.Name
|
||
}
|
||
if req.ModelType != "" {
|
||
fields["model_type"] = req.ModelType
|
||
}
|
||
if req.Provider != "" {
|
||
fields["provider"] = req.Provider
|
||
}
|
||
if req.Model != "" {
|
||
fields["model"] = req.Model
|
||
}
|
||
if req.APIKey != "" {
|
||
fields["api_key"] = req.APIKey
|
||
}
|
||
if req.BaseURL != "" {
|
||
fields["base_url"] = req.BaseURL
|
||
}
|
||
if req.APIEndpoint != "" {
|
||
fields["api_endpoint"] = req.APIEndpoint
|
||
}
|
||
// Status为int类型,0表示inactive,1表示active
|
||
// 只在明确传入Status值时才更新(Status > 0 表示传入了值)
|
||
if req.Status > 0 {
|
||
fields["status"] = req.Status
|
||
}
|
||
|
||
if err := s.repo.UpdateFields(id, fields); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return s.repo.FindByID(id)
|
||
}
|
||
|
||
// Delete 删除模型
|
||
func (s *ModelService) Delete(id string) error {
|
||
// 检查是否存在
|
||
_, err := s.repo.FindByID(id)
|
||
if err != nil {
|
||
return fmt.Errorf("model not found")
|
||
}
|
||
return s.repo.Delete(id)
|
||
}
|
||
|
||
// TestConnection 测试模型连接
|
||
func (s *ModelService) TestConnection(req model.TestModelRequest) (*model.TestModelResponse, error) {
|
||
// 构建请求 URL
|
||
baseURL := req.BaseURL
|
||
// 去掉 base_url 末尾的斜杠
|
||
baseURL = strings.TrimRight(baseURL, "/")
|
||
|
||
if req.APIEndpoint != "" {
|
||
// 去掉 api_endpoint 开头的斜杠
|
||
apiEndpoint := strings.TrimLeft(req.APIEndpoint, "/")
|
||
baseURL = baseURL + "/" + apiEndpoint
|
||
} else {
|
||
// 根据 model_type 确定端点
|
||
switch req.ModelType {
|
||
case "embedding":
|
||
// embedding 模型使用 /v1/embeddings
|
||
switch req.Provider {
|
||
case "Ollama":
|
||
baseURL = baseURL + "/api/embeddings"
|
||
default:
|
||
baseURL = baseURL + "/v1/embeddings"
|
||
}
|
||
default:
|
||
// chat 模型使用 /chat/completions
|
||
switch req.Provider {
|
||
case "OpenAI":
|
||
baseURL = baseURL + "/chat/completions"
|
||
case "Ollama":
|
||
baseURL = baseURL + "/api/chat"
|
||
case "ali", "Ali", "aliyun", "Aliyun":
|
||
// 阿里云 DashScope 兼容 OpenAI 格式
|
||
baseURL = baseURL + "/chat/completions"
|
||
default:
|
||
// 默认使用 /chat/completions
|
||
baseURL = baseURL + "/chat/completions"
|
||
}
|
||
}
|
||
}
|
||
|
||
// 构建请求体 - 根据 model_type 使用不同的格式
|
||
var requestBody map[string]interface{}
|
||
if req.ModelType == "embedding" {
|
||
requestBody = map[string]interface{}{
|
||
"model": req.Model,
|
||
"input": "Hello",
|
||
"format": "float",
|
||
}
|
||
} else {
|
||
requestBody = map[string]interface{}{
|
||
"model": req.Model,
|
||
"messages": []map[string]string{
|
||
{"role": "user", "content": "Hello"},
|
||
},
|
||
"max_tokens": 10,
|
||
}
|
||
}
|
||
|
||
body, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||
}
|
||
|
||
// 创建 HTTP 请求
|
||
httpReq, err := http.NewRequest("POST", baseURL, bytes.NewBuffer(body))
|
||
if err != nil {
|
||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
if req.APIKey != "" {
|
||
// 根据不同 provider 使用不同的认证方式
|
||
switch req.Provider {
|
||
case "ali", "Ali", "aliyun", "Aliyun":
|
||
// 阿里云使用 x-api-key 头部
|
||
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
|
||
httpReq.Header.Set("x-api-key", req.APIKey)
|
||
default:
|
||
httpReq.Header.Set("Authorization", "Bearer "+req.APIKey)
|
||
}
|
||
}
|
||
|
||
// 发送请求,设置 10 秒超时
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return &model.TestModelResponse{Success: false, Message: err.Error()}, nil
|
||
}
|
||
|
||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||
return &model.TestModelResponse{Success: true, Message: "Connection successful"}, nil
|
||
}
|
||
|
||
return &model.TestModelResponse{Success: false, Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, string(respBody))}, nil
|
||
}
|