- 更新 agent handler 和 service 层 - 新增 chat_group handler 和 service - 删除废弃的 chat_handler - 更新 tool 相关处理 - 更新 API 文档和依赖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
266 lines
6.4 KiB
Go
266 lines
6.4 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"x-agents/server/internal/model"
|
||
"x-agents/server/internal/repository"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// Sentinel errors reported by ChatGroupService.

// ErrNoAgents means the supplied agent ID list parsed to no IDs at all.
var ErrNoAgents = errors.New("no agents provided")

// ErrAgentNotFound means the requested agents could not be loaded.
var ErrAgentNotFound = errors.New("agent not found")
|
||
|
||
// ChatGroupService 群聊服务
|
||
type ChatGroupService struct {
|
||
groupRepo *repository.ChatGroupRepository
|
||
agentRepo *repository.AgentRepository
|
||
pythonURL string
|
||
client *http.Client
|
||
}
|
||
|
||
// NewChatGroupService 创建群聊服务
|
||
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository, pythonURL string) *ChatGroupService {
|
||
return &ChatGroupService{
|
||
groupRepo: groupRepo,
|
||
agentRepo: agentRepo,
|
||
pythonURL: pythonURL,
|
||
client: &http.Client{
|
||
Timeout: 120 * time.Second,
|
||
},
|
||
}
|
||
}
|
||
|
||
// CreateGroup 创建群聊
|
||
func (s *ChatGroupService) CreateGroup(req model.CreateGroupRequest) (*model.ChatGroup, error) {
|
||
group := &model.ChatGroup{
|
||
ID: uuid.New().String(),
|
||
UserID: req.UserID,
|
||
Name: req.Name,
|
||
Description: req.Description,
|
||
AgentIDs: req.AgentIDs,
|
||
Status: "active",
|
||
}
|
||
|
||
err := s.groupRepo.Create(group)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return group, nil
|
||
}
|
||
|
||
// GetGroup 获取群聊详情
|
||
func (s *ChatGroupService) GetGroup(id string) (*model.ChatGroup, error) {
|
||
return s.groupRepo.FindByID(id)
|
||
}
|
||
|
||
// ListGroups 获取群聊列表
|
||
func (s *ChatGroupService) ListGroups(userID string, limit, offset int) ([]model.ChatGroup, int64, error) {
|
||
if limit <= 0 {
|
||
limit = 20
|
||
}
|
||
return s.groupRepo.FindByUserID(userID, limit, offset)
|
||
}
|
||
|
||
// UpdateGroup 更新群聊
|
||
func (s *ChatGroupService) UpdateGroup(id string, req model.UpdateGroupRequest) (*model.ChatGroup, error) {
|
||
group, err := s.groupRepo.FindByID(id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if req.Name != "" {
|
||
group.Name = req.Name
|
||
}
|
||
if req.Description != "" {
|
||
group.Description = req.Description
|
||
}
|
||
if req.AgentIDs != "" {
|
||
group.AgentIDs = req.AgentIDs
|
||
}
|
||
if req.Status != "" {
|
||
group.Status = req.Status
|
||
}
|
||
|
||
err = s.groupRepo.Update(group)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return group, nil
|
||
}
|
||
|
||
// DeleteGroup 删除群聊
|
||
func (s *ChatGroupService) DeleteGroup(id string) error {
|
||
return s.groupRepo.Delete(id)
|
||
}
|
||
|
||
// GroupChat 群聊对话
|
||
func (s *ChatGroupService) GroupChat(userID, message, agentIDs, sessionID string) (*model.GroupChatResponse, error) {
|
||
// 解析 Agent IDs
|
||
agentIDList := parseAgentIDs(agentIDs)
|
||
|
||
if len(agentIDList) == 0 {
|
||
return nil, ErrNoAgents
|
||
}
|
||
|
||
// 获取所有 Agent 信息
|
||
agents, err := s.agentRepo.FindByIDs(agentIDList)
|
||
if err != nil || len(agents) == 0 {
|
||
return nil, ErrAgentNotFound
|
||
}
|
||
|
||
// 解析 userID 为整数
|
||
userIDInt, err := strconv.Atoi(userID)
|
||
if err != nil {
|
||
userIDInt = 1 // 默认值
|
||
}
|
||
|
||
// 将 agent UUIDs 转换为整数 IDs
|
||
memberAgentIDs := make([]int, len(agents))
|
||
for i := range agents {
|
||
// 使用索引+1 作为 agent ID(因为 Python 端使用整数 ID)
|
||
memberAgentIDs[i] = i + 1
|
||
log.Printf("[ChatGroupService] Agent: %s (ID: %s)", agents[i].Name, agents[i].ID)
|
||
}
|
||
|
||
// 调用 Python TeamAgent 进行群聊
|
||
teamReq := TeamChatRequest{
|
||
SupervisorAgentID: 0, // 没有 supervisor,使用并行策略
|
||
MemberAgentIDs: memberAgentIDs,
|
||
Message: message,
|
||
UserID: userIDInt,
|
||
SessionID: sessionID,
|
||
Strategy: "parallel",
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)
|
||
jsonData, err := json.Marshal(teamReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||
}
|
||
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||
}
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to call python team agent: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("python team agent error: %s", string(body))
|
||
}
|
||
|
||
var teamResp TeamChatResponse
|
||
if err := json.Unmarshal(body, &teamResp); err != nil {
|
||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||
}
|
||
|
||
log.Printf("[ChatGroupService] Raw teamResp: %+v", teamResp)
|
||
log.Printf("[ChatGroupService] SubtaskResults count: %d", len(teamResp.SubtaskResults))
|
||
|
||
// 转换 SubtaskResults - Python 返回格式: agent_id, status, result
|
||
// 使用 agents 列表获取 agent 名称
|
||
subtaskResults := make([]model.SubtaskResult, 0)
|
||
for i, sr := range teamResp.SubtaskResults {
|
||
srMap, ok := sr.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// 获取 agent 名称(从已查询的 agents 列表)
|
||
agentName := ""
|
||
if i < len(agents) {
|
||
agentName = agents[i].Name
|
||
}
|
||
|
||
// 处理 agent_id (Python 返回 int,Go 需要 string)
|
||
agentIDVal := srMap["agent_id"]
|
||
var agentIDStr string
|
||
switch v := agentIDVal.(type) {
|
||
case float64:
|
||
agentIDStr = fmt.Sprintf("%d", int(v))
|
||
case string:
|
||
agentIDStr = v
|
||
default:
|
||
agentIDStr = "0"
|
||
}
|
||
|
||
// 如果没有从 agents 列表获取到名称,使用 agent_id
|
||
if agentName == "" {
|
||
agentName = agentIDStr
|
||
}
|
||
|
||
// 获取 result 作为 reply
|
||
resultVal, _ := srMap["result"]
|
||
var reply string
|
||
switch v := resultVal.(type) {
|
||
case string:
|
||
reply = v
|
||
default:
|
||
reply = fmt.Sprintf("%v", resultVal)
|
||
}
|
||
|
||
subtaskResults = append(subtaskResults, model.SubtaskResult{
|
||
AgentID: agentIDStr,
|
||
AgentName: agentName,
|
||
Reply: reply,
|
||
TokensUsed: 0, // Python 暂时未返回
|
||
DurationMs: 0, // Python 暂时未返回
|
||
})
|
||
}
|
||
|
||
// 汇总结果
|
||
response := &model.GroupChatResponse{
|
||
SessionID: sessionID,
|
||
Reply: teamResp.Response,
|
||
DurationMs: teamResp.DurationMs,
|
||
Strategy: teamResp.Strategy,
|
||
SubtaskResults: subtaskResults,
|
||
}
|
||
|
||
return response, nil
|
||
}
|
||
|
||
// parseAgentIDs splits a raw agent-ID list into individual IDs.
//
// Two input formats are accepted:
//   - a JSON array of strings, e.g. `["a","b"]`;
//   - a comma-separated list, e.g. "a, b".
//
// Input that is not a JSON string array is treated as comma-separated; a
// value without commas becomes a single ID. Surrounding whitespace is
// trimmed from every ID, and empty IDs are dropped. Blank input, ",", or
// `[""]` therefore yields an empty slice, which callers treat as "no
// agents".
func parseAgentIDs(agentIDs string) []string {
	trimmed := strings.TrimSpace(agentIDs)
	if trimmed == "" {
		return []string{}
	}

	var raw []string
	if err := json.Unmarshal([]byte(trimmed), &raw); err != nil {
		// Not a JSON string array: fall back to comma-separated form.
		raw = strings.Split(trimmed, ",")
	}

	ids := make([]string, 0, len(raw))
	for _, id := range raw {
		if id = strings.TrimSpace(id); id != "" {
			ids = append(ids, id)
		}
	}
	return ids
}
|