Files
X-Agents/server/internal/service/chat_group_service.go

266 lines
6.4 KiB
Go
Raw Normal View History

package service
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
"github.com/google/uuid"
)
// Sentinel errors returned by ChatGroupService.

// ErrNoAgents is returned by GroupChat when the supplied agent ID list parses
// to zero IDs.
var ErrNoAgents = errors.New("no agents provided")

// ErrAgentNotFound is returned by GroupChat when the referenced agents could
// not be loaded from the agent repository.
var ErrAgentNotFound = errors.New("agent not found")
// ChatGroupService is the group-chat service: it manages chat groups through
// groupRepo and runs multi-agent conversations by calling the Python agent
// service at pythonURL.
type ChatGroupService struct {
	groupRepo *repository.ChatGroupRepository // persistence for chat groups
	agentRepo *repository.AgentRepository     // agent lookup used by GroupChat
	pythonURL string                          // base URL of the Python agent service
	client    *http.Client                    // shared client for calls to pythonURL
}
// NewChatGroupService constructs a ChatGroupService from its repositories and
// the base URL of the Python agent service. All outbound calls share a single
// http.Client whose overall request timeout is 120 seconds.
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository, pythonURL string) *ChatGroupService {
	httpClient := &http.Client{Timeout: 120 * time.Second}
	svc := &ChatGroupService{
		groupRepo: groupRepo,
		agentRepo: agentRepo,
		pythonURL: pythonURL,
		client:    httpClient,
	}
	return svc
}
// CreateGroup builds a new chat group from req with a freshly generated UUID
// and status "active", persists it, and returns it. If the repository write
// fails, the repository's error is returned and the group is discarded.
func (s *ChatGroupService) CreateGroup(req model.CreateGroupRequest) (*model.ChatGroup, error) {
	groupID := uuid.New().String()
	group := &model.ChatGroup{
		ID:          groupID,
		UserID:      req.UserID,
		Name:        req.Name,
		Description: req.Description,
		AgentIDs:    req.AgentIDs,
		Status:      "active",
	}

	if err := s.groupRepo.Create(group); err != nil {
		return nil, err
	}
	return group, nil
}
// GetGroup fetches a single chat group by ID. It is a direct pass-through to
// the repository, so its not-found behavior is whatever FindByID does.
func (s *ChatGroupService) GetGroup(id string) (*model.ChatGroup, error) {
	group, err := s.groupRepo.FindByID(id)
	return group, err
}
// ListGroups returns one page of the chat groups belonging to userID together
// with an int64 from the repository (presumably the total match count —
// confirm in ChatGroupRepository.FindByUserID). A non-positive limit is
// replaced with a page size of 20; offset is forwarded unchanged.
func (s *ChatGroupService) ListGroups(userID string, limit, offset int) ([]model.ChatGroup, int64, error) {
	const defaultPageSize = 20

	pageSize := limit
	if pageSize <= 0 {
		pageSize = defaultPageSize
	}
	return s.groupRepo.FindByUserID(userID, pageSize, offset)
}
// UpdateGroup applies a partial update to the chat group with the given ID.
// Each of Name, Description, AgentIDs and Status is overwritten only when the
// request carries a non-empty value, so a field cannot be cleared through this
// method. The modified group is saved and returned; a lookup or save error is
// returned as-is with a nil group.
func (s *ChatGroupService) UpdateGroup(id string, req model.UpdateGroupRequest) (*model.ChatGroup, error) {
	group, err := s.groupRepo.FindByID(id)
	if err != nil {
		return nil, err
	}

	// Empty request fields mean "leave unchanged".
	if req.Name != "" {
		group.Name = req.Name
	}
	if req.Description != "" {
		group.Description = req.Description
	}
	if req.AgentIDs != "" {
		group.AgentIDs = req.AgentIDs
	}
	if req.Status != "" {
		group.Status = req.Status
	}

	if err := s.groupRepo.Update(group); err != nil {
		return nil, err
	}
	return group, nil
}
// DeleteGroup removes the chat group with the given ID by delegating to the
// repository and returning its error unchanged.
func (s *ChatGroupService) DeleteGroup(id string) error {
	err := s.groupRepo.Delete(id)
	return err
}
// GroupChat sends message to every agent referenced by agentIDs in a single
// "parallel" team-chat call to the Python service and aggregates the replies.
//
// agentIDs is parsed by parseAgentIDs (a JSON array or a comma-separated
// list). userID is converted to an int for the Python side; a non-numeric
// userID falls back to 1 (kept for compatibility, and logged).
//
// Returns ErrNoAgents when agentIDs yields no IDs, ErrAgentNotFound when the
// agent lookup fails or finds nothing, and a wrapped error for any failure
// while talking to the Python service.
func (s *ChatGroupService) GroupChat(userID, message, agentIDs, sessionID string) (*model.GroupChatResponse, error) {
	agentIDList := parseAgentIDs(agentIDs)
	if len(agentIDList) == 0 {
		return nil, ErrNoAgents
	}

	agents, err := s.agentRepo.FindByIDs(agentIDList)
	if err != nil {
		// Callers may compare against the ErrAgentNotFound sentinel, so keep
		// returning it, but don't lose the underlying repository error.
		log.Printf("[ChatGroupService] FindByIDs(%v) failed: %v", agentIDList, err)
		return nil, ErrAgentNotFound
	}
	if len(agents) == 0 {
		return nil, ErrAgentNotFound
	}
	if len(agents) < len(agentIDList) {
		log.Printf("[ChatGroupService] found %d agents for %d requested IDs", len(agents), len(agentIDList))
	}

	userIDInt, err := strconv.Atoi(userID)
	if err != nil {
		// NOTE(review): falling back to user 1 attributes this conversation to a
		// different user; kept for compatibility, consider rejecting instead.
		log.Printf("[ChatGroupService] non-numeric userID %q, falling back to 1", userID)
		userIDInt = 1
	}

	// The Python side identifies members by integer ID, so each agent is sent
	// as its 1-based position in agents; agentNames keeps the reverse mapping.
	memberAgentIDs := make([]int, len(agents))
	agentNames := make([]string, len(agents))
	for i := range agents {
		memberAgentIDs[i] = i + 1
		agentNames[i] = agents[i].Name
		log.Printf("[ChatGroupService] Agent: %s (ID: %s)", agents[i].Name, agents[i].ID)
	}

	teamResp, err := s.callTeamAgent(TeamChatRequest{
		SupervisorAgentID: 0, // no supervisor: rely on the parallel strategy
		MemberAgentIDs:    memberAgentIDs,
		Message:           message,
		UserID:            userIDInt,
		SessionID:         sessionID,
		Strategy:          "parallel",
	})
	if err != nil {
		return nil, err
	}

	log.Printf("[ChatGroupService] Raw teamResp: %+v", *teamResp)
	log.Printf("[ChatGroupService] SubtaskResults count: %d", len(teamResp.SubtaskResults))

	// Non-nil on purpose: if serialized to JSON it encodes as [] rather than null.
	subtaskResults := make([]model.SubtaskResult, 0, len(teamResp.SubtaskResults))
	for i, sr := range teamResp.SubtaskResults {
		if result, ok := toGroupSubtaskResult(i, sr, agentNames); ok {
			subtaskResults = append(subtaskResults, result)
		}
	}

	return &model.GroupChatResponse{
		SessionID:      sessionID,
		Reply:          teamResp.Response,
		DurationMs:     teamResp.DurationMs,
		Strategy:       teamResp.Strategy,
		SubtaskResults: subtaskResults,
	}, nil
}

// callTeamAgent POSTs teamReq as JSON to the Python team-chat endpoint and
// decodes the JSON reply. A non-200 status is reported as an error carrying
// the status code and the raw response body.
func (s *ChatGroupService) callTeamAgent(teamReq TeamChatRequest) (*TeamChatResponse, error) {
	endpoint := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)

	payload, err := json.Marshal(teamReq)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to call python team agent: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("python team agent error (status %d): %s", resp.StatusCode, string(body))
	}

	var teamResp TeamChatResponse
	if err := json.Unmarshal(body, &teamResp); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return &teamResp, nil
}

// toGroupSubtaskResult converts one raw subtask entry from the Python response
// (a JSON object with agent_id, status and result keys) into a
// model.SubtaskResult. pos is the entry's index in the response, and
// agentNames[k] is the name of the agent sent to Python as member ID k+1.
// It reports false when raw is not a JSON object; such entries are skipped.
func toGroupSubtaskResult(pos int, raw interface{}, agentNames []string) (model.SubtaskResult, bool) {
	fields, ok := raw.(map[string]interface{})
	if !ok {
		return model.SubtaskResult{}, false
	}

	// agent_id decodes as float64 (JSON number) or string; the model wants a string.
	var agentIDStr string
	switch v := fields["agent_id"].(type) {
	case float64:
		agentIDStr = strconv.Itoa(int(v))
	case string:
		agentIDStr = v
	default:
		agentIDStr = "0"
	}
	// NOTE(review): AgentID is the synthetic 1-based member index, not the
	// agent's real ID; consider exposing the real ID to clients.

	// Prefer the echoed member ID for the name: with the parallel strategy the
	// results may not come back in request order, so position alone is not a
	// reliable key. Fall back to position, then to the raw ID.
	agentName := ""
	if memberID, err := strconv.Atoi(agentIDStr); err == nil && memberID >= 1 && memberID <= len(agentNames) {
		agentName = agentNames[memberID-1]
	} else if pos < len(agentNames) {
		agentName = agentNames[pos]
	}
	if agentName == "" {
		agentName = agentIDStr
	}

	// A missing or null result becomes an empty reply instead of "<nil>".
	var reply string
	switch v := fields["result"].(type) {
	case nil:
		reply = ""
	case string:
		reply = v
	default:
		reply = fmt.Sprintf("%v", v)
	}

	return model.SubtaskResult{
		AgentID:    agentIDStr,
		AgentName:  agentName,
		Reply:      reply,
		TokensUsed: 0, // not yet returned by the Python service
		DurationMs: 0, // not yet returned by the Python service
	}, true
}
// parseAgentIDs splits an agent-ID specification into individual IDs.
//
// Two formats are accepted: a JSON array of strings (e.g. `["a","b"]`) and a
// comma-separated list (e.g. "a, b"). A single bare ID is a one-element list.
// Every ID is whitespace-trimmed and empty entries are dropped, so inputs like
// "a, b," yield ["a" "b"] instead of passing " b" or "" on as agent IDs.
// The result is never nil; it is empty when no IDs remain.
func parseAgentIDs(agentIDs string) []string {
	trimmed := strings.TrimSpace(agentIDs)
	if trimmed == "" {
		return []string{}
	}

	var parts []string
	if err := json.Unmarshal([]byte(trimmed), &parts); err != nil {
		// Not a JSON string array: treat it as a comma-separated list (which
		// also covers a single bare ID, since Split then returns it unchanged).
		parts = strings.Split(trimmed, ",")
	}

	ids := make([]string, 0, len(parts))
	for _, p := range parts {
		if id := strings.TrimSpace(p); id != "" {
			ids = append(ids, id)
		}
	}
	return ids
}