feat: 更新 Server 后端服务

- 更新 agent handler 和 service 层
- 新增 chat_group handler 和 service
- 删除废弃的 chat_handler
- 更新 tool 相关处理
- 更新 API 文档和依赖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-13 21:26:27 +08:00
parent 237ab9f6d7
commit 71e8cc59d5
24 changed files with 5007 additions and 347 deletions

View File

@@ -3,7 +3,10 @@ package config
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -11,15 +14,31 @@ import (
"github.com/spf13/viper"
)
// 获取项目根目录
func getProjectRoot() string {
// 从当前工作目录向上查找 .env 文件
dir, _ := os.Getwd()
for i := 0; i < 5; i++ {
if _, err := os.Stat(filepath.Join(dir, ".env")); err == nil {
return dir
}
dir = filepath.Dir(dir)
}
// 默认返回当前目录
return "."
}
type Config struct {
Port string
JWTSecret string
DatabaseType string // 数据库类型: mysql 或 sqlite
DatabaseHost string
DatabasePort string
DatabaseUser string
DatabasePassword string
DatabaseName string
DatabaseURL string // 拼接后的完整连接字符串
SQLitePath string // SQLite 数据库文件路径
PythonServiceURL string
AICoreServiceAddr string // AI-Core gRPC 服务地址,如 "localhost:50051"
// 文件上传配置
@@ -36,23 +55,22 @@ type Config struct {
}
func Load() *Config {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./config")
viper.AddConfigPath("../config")
viper.AddConfigPath("../../config")
// 重新初始化 viper避免之前的状态影响
viper.Reset()
// 默认值
// 第一步:设置默认值
viper.SetDefault("port", "8080")
viper.SetDefault("jwt_secret", "your-secret-key-change-in-production")
viper.SetDefault("python_service_url", "http://localhost:8081")
viper.SetDefault("ai_core_service_addr", "localhost:50051")
// 数据库默认配置
viper.SetDefault("database_type", "mysql")
viper.SetDefault("database_host", "localhost")
viper.SetDefault("database_port", "3306")
viper.SetDefault("database_user", "root")
viper.SetDefault("database_password", "root")
viper.SetDefault("database_name", "x_agents")
viper.SetDefault("sqlite_path", "./data/x_agents.db")
// 文件上传默认配置
viper.SetDefault("upload_mode", "local")
viper.SetDefault("upload_local_path", "resource/files")
@@ -64,30 +82,84 @@ func Load() *Config {
viper.SetDefault("minio_bucket", "x-agents")
viper.SetDefault("minio_use_ssl", false)
if err := viper.ReadInConfig(); err != nil {
log.Printf("Using default config: %v", err)
// 第二步:读取 config.yaml优先级低
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./config")
viper.AddConfigPath("../config")
viper.AddConfigPath("../../config")
_ = viper.MergeInConfig() // 忽略错误,可能没有 config.yaml
// 第三步:读取 .env 文件(优先级最高)
projectRoot := getProjectRoot()
log.Printf("Project root: %s", projectRoot)
viper.SetConfigName(".env")
viper.SetConfigType("env")
viper.AddConfigPath(projectRoot) // 项目根目录 (X-Agents)
viper.AddConfigPath(".") // 当前目录
viper.AddConfigPath("..") // 父目录
viper.AddConfigPath("../..") // 上两级目录
viper.SetEnvPrefix("GO") // 环境变量前缀 GO_xxx (仅对环境变量生效)
viper.AutomaticEnv()
_ = viper.MergeInConfig() // 忽略错误,可能没有 .env
// 处理 .env 文件中的键名(去掉 GO_ 前缀映射)
envToConfig := map[string]string{
"GO_PORT": "port",
"GO_DATABASE_TYPE": "database_type",
"GO_DATABASE_HOST": "database_host",
"GO_DATABASE_PORT": "database_port",
"GO_DATABASE_NAME": "database_name",
"GO_DATABASE_USER": "database_user",
"GO_DATABASE_PASSWORD": "database_password",
"GO_SQLITE_PATH": "sqlite_path",
}
for envKey, configKey := range envToConfig {
if val := viper.GetString(envKey); val != "" {
viper.Set(configKey, val)
}
}
// 拼接数据库连接字符串
dbHost := viper.GetString("database_host")
dbPort := viper.GetString("database_port")
dbUser := viper.GetString("database_user")
dbPassword := viper.GetString("database_password")
dbName := viper.GetString("database_name")
databaseURL := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbUser, dbPassword, dbHost, dbPort, dbName)
log.Printf("Loaded config: database_type=%s, port=%s",
viper.GetString("database_type"), viper.GetString("port"))
// 获取数据库类型
dbType := viper.GetString("database_type")
var databaseURL string
if dbType == "sqlite" {
sqlitePath := viper.GetString("sqlite_path")
// 确保 SQLite 数据目录存在 (跨平台处理)
dir := filepath.Dir(sqlitePath)
if dir != "." && dir != "" {
os.MkdirAll(dir, 0755)
}
databaseURL = sqlitePath
} else {
// MySQL 连接字符串
dbHost := viper.GetString("database_host")
dbPort := viper.GetString("database_port")
dbUser := viper.GetString("database_user")
dbPassword := viper.GetString("database_password")
dbName := viper.GetString("database_name")
databaseURL = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
dbUser, dbPassword, dbHost, dbPort, dbName)
}
return &Config{
Port: viper.GetString("port"),
JWTSecret: viper.GetString("jwt_secret"),
DatabaseType: dbType,
DatabaseURL: databaseURL,
DatabaseHost: dbHost,
DatabasePort: dbPort,
DatabaseUser: dbUser,
DatabasePassword: dbPassword,
DatabaseName: dbName,
DatabaseHost: viper.GetString("database_host"),
DatabasePort: viper.GetString("database_port"),
DatabaseUser: viper.GetString("database_user"),
DatabasePassword: viper.GetString("database_password"),
DatabaseName: viper.GetString("database_name"),
SQLitePath: viper.GetString("sqlite_path"),
PythonServiceURL: viper.GetString("python_service_url"),
AICoreServiceAddr: viper.GetString("ai_core_service_addr"),
AICoreServiceAddr: viper.GetString("ai_core_service_addr"),
// 文件上传配置
UploadMode: viper.GetString("upload_mode"),
UploadLocalPath: viper.GetString("upload_local_path"),
@@ -108,9 +180,23 @@ func InitDB(cfg *Config) (*gorm.DB, error) {
return nil, fmt.Errorf("database URL is empty")
}
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
var db *gorm.DB
var err error
if cfg.DatabaseType == "sqlite" {
// SQLite 不需要创建目录逻辑,因为 Load 函数已经处理了
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
log.Printf("Using SQLite database: %s", dsn)
} else {
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
log.Printf("Using MySQL database: %s:%s/%s",
cfg.DatabaseHost, cfg.DatabasePort, cfg.DatabaseName)
}
if err != nil {
return nil, fmt.Errorf("failed to connect database: %w", err)
}

View File

@@ -2,6 +2,7 @@ package handler
import (
"net/http"
"strconv"
"x-agents/server/internal/service"
@@ -22,7 +23,7 @@ func NewAgentHandler(agentService *service.AgentService) *AgentHandler {
// ChatRequest 对话请求
type ChatRequest struct {
AgentID int `json:"agent_id" binding:"required"`
AgentID string `json:"agent_id" binding:"required"` // 字符串类型
Message string `json:"message" binding:"required"`
SessionID string `json:"session_id"`
ModelID string `json:"model_id"`
@@ -69,10 +70,14 @@ func (h *AgentHandler) Chat(c *gin.Context) {
}
// 获取用户 ID从认证中间件获取
userID := 1 // TODO: 从 c.Get("user_id") 获取
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
userID, _ := strconv.Atoi(userIDStr)
// 将前端传来的字符串 agent_id 转换为 int
agentID, _ := strconv.Atoi(req.AgentID)
pythonReq := service.AgentChatRequest{
AgentID: req.AgentID,
AgentID: agentID,
Message: req.Message,
UserID: userID,
SessionID: req.SessionID,
@@ -122,7 +127,11 @@ func (h *AgentHandler) ChatStream(c *gin.Context) {
}
// 获取用户 ID
userID := 1 // TODO: 从 c.Get("user_id") 获取
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
userID, _ := strconv.Atoi(userIDStr)
// 将前端传来的字符串 agent_id 转换为 int
agentID, _ := strconv.Atoi(req.AgentID)
// 构建 SSE 流
c.Header("Content-Type", "text/event-stream")
@@ -131,7 +140,7 @@ func (h *AgentHandler) ChatStream(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
// 调用 Python 服务的流式端点
err := h.agentService.ChatStream(c, req.AgentID, req.Message, req.SessionID, req.ModelID, userID)
err := h.agentService.ChatStream(c, agentID, req.Message, req.SessionID, req.ModelID, userID)
if err != nil && !c.IsAborted() {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@@ -173,7 +182,8 @@ func (h *AgentHandler) TeamChat(c *gin.Context) {
}
// 获取用户 ID
userID := 1 // TODO: 从 c.Get("user_id") 获取
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
userID, _ := strconv.Atoi(userIDStr)
pythonReq := service.TeamChatRequest{
SupervisorAgentID: req.SupervisorAgentID,

View File

@@ -1,6 +1,7 @@
package handler
import (
"log"
"net/http"
"strconv"
@@ -140,10 +141,13 @@ func (h *ChatGroupHandler) GroupChat(c *gin.Context) {
}
response, err := h.groupService.GroupChat(userID, req.Message, agentIDs, req.SessionID)
log.Printf("[ChatGroupHandler] Got response, err: %v", err)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
log.Printf("[ChatGroupHandler] Response subtask_results: %+v", response.SubtaskResults)
c.JSON(http.StatusOK, response)
}

View File

@@ -1,89 +0,0 @@
package handler
import (
"net/http"
"x-agents/server/internal/model"
"x-agents/server/internal/service"
"github.com/gin-gonic/gin"
)
type ChatHandler struct {
chatService *service.ChatService
}
func NewChatHandler(chatService *service.ChatService) *ChatHandler {
return &ChatHandler{chatService: chatService}
}
// Chat 处理聊天请求
func (h *ChatHandler) Chat(c *gin.Context) {
var req model.AgentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 从上下文获取用户ID由中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
resp, err := h.chatService.Chat(c.Request.Context(), userID.(string), req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, resp)
}
// ListAgents 获取 Agent 列表
func (h *ChatHandler) ListAgents(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
agents, err := h.chatService.ListAgents(userID.(string))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if agents == nil {
agents = []model.Agent{}
}
c.JSON(http.StatusOK, gin.H{"agents": agents})
}
// CreateAgent 创建 Agent
func (h *ChatHandler) CreateAgent(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
var req struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
agent, err := h.chatService.CreateAgent(userID.(string), req.Name, req.Description)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, agent)
}

View File

@@ -164,3 +164,46 @@ func (h *ToolHandler) Delete(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "tool deleted"})
}
// SyncFromPythonRequest 从Python同步工具请求
type SyncFromPythonRequest struct {
Tools []PythonTool `json:"tools" binding:"required"`
}
// PythonTool Python端工具结构
type PythonTool struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters string `json:"parameters"`
Category string `json:"category"`
}
// SyncFromPython 从Python端同步工具
// @Summary 从Python端同步工具
// @Description 接收Python Agent同步过来的工具列表并存储到数据库
// @Tags 工具管理
// @Accept json
// @Produce json
// @Param tools body SyncFromPythonRequest true "工具列表"
// @Success 200 {object} map[string]interface{}
// @Router /tool/sync-from-python [post]
func (h *ToolHandler) SyncFromPython(c *gin.Context) {
var req struct {
Tools []map[string]interface{} `json:"tools" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
count, err := h.toolService.SyncToolsFromPython(req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "tools synced from python",
"synced_count": count,
})
}

View File

@@ -63,12 +63,30 @@ type AgentMemory struct {
AgentID string `json:"agent_id" gorm:"size:191;index"`
UserID string `json:"user_id" gorm:"size:191;index"`
Content string `json:"content" gorm:"type:text"`
MemoryType string `json:"memory_type" gorm:"size:20"` // experience/preference/conversation
Importance int `json:"importance" gorm:"default:5"`
MemoryType string `json:"memory_type" gorm:"size:20"` // experience/preference/conversation/fact
Category string `json:"category" gorm:"size:50;index"` // 记忆分类
Tags string `json:"tags" gorm:"size:500"` // 标签JSON数组格式
Keywords string `json:"keywords" gorm:"size:500"` // 关键词,用于搜索
Importance int `json:"importance" gorm:"default:5;index"` // 重要性等级 1-10
IsPinned bool `json:"is_pinned" gorm:"default:false"` // 是否置顶
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (AgentMemory) TableName() string {
return "agent_memories"
}
// ImportItem 导入记忆项
type ImportItem struct {
Content string `json:"content"`
MemoryType string `json:"memory_type"`
Category string `json:"category"`
Tags string `json:"tags"`
Keywords string `json:"keywords"`
Importance int `json:"importance"`
}
// AgentTeam 多智能体协作配置
type AgentTeam struct {
ID string `json:"id" gorm:"primaryKey"`

View File

@@ -39,6 +39,13 @@ func (r *AgentRepository) FindAll() ([]model.Agent, error) {
return agents, err
}
// FindByIDs 根据多个ID查询
func (r *AgentRepository) FindByIDs(ids []string) ([]model.Agent, error) {
var agents []model.Agent
err := r.db.Where("id IN ?", ids).Find(&agents).Error
return agents, err
}
func (r *AgentRepository) Update(agent *model.Agent) error {
return r.db.Save(agent).Error
}
@@ -101,6 +108,103 @@ func (r *AgentRepository) DeleteMemory(id string) error {
return r.db.Delete(&model.AgentMemory{}, "id = ?", id).Error
}
// FindMemories 通用查询记忆,支持分类和类型过滤
func (r *AgentRepository) FindMemories(agentID, userID, category, memoryType string, limit, offset int) ([]model.AgentMemory, int64, error) {
var memories []model.AgentMemory
query := r.db.Model(&model.AgentMemory{})
if agentID != "" {
query = query.Where("agent_id = ?", agentID)
}
if userID != "" {
query = query.Where("user_id = ?", userID)
}
if category != "" {
query = query.Where("category = ?", category)
}
if memoryType != "" {
query = query.Where("memory_type = ?", memoryType)
}
// 统计总数
var total int64
countQuery := query
if err := countQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 置顶的记忆优先,然后按重要性降序,最后按创建时间降序
err := query.Order("is_pinned DESC, importance DESC, created_at DESC").Limit(limit).Offset(offset).Find(&memories).Error
return memories, total, err
}
// SearchMemories 搜索记忆
func (r *AgentRepository) SearchMemories(agentID, userID, keyword, tags, category, memoryType string, minScore, limit, offset int) ([]model.AgentMemory, int64, error) {
var memories []model.AgentMemory
query := r.db.Model(&model.AgentMemory{})
if agentID != "" {
query = query.Where("agent_id = ?", agentID)
}
if userID != "" {
query = query.Where("user_id = ?", userID)
}
if keyword != "" {
keyword = "%" + keyword + "%"
query = query.Where("content LIKE ? OR keywords LIKE ? OR tags LIKE ?", keyword, keyword, keyword)
}
if category != "" {
query = query.Where("category = ?", category)
}
if memoryType != "" {
query = query.Where("memory_type = ?", memoryType)
}
if minScore > 0 {
query = query.Where("importance >= ?", minScore)
}
// 统计总数
var total int64
countQuery := query
if err := countQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
err := query.Order("is_pinned DESC, importance DESC, created_at DESC").Limit(limit).Offset(offset).Find(&memories).Error
return memories, total, err
}
// FindMemoryByID 根据ID查询记忆
func (r *AgentRepository) FindMemoryByID(id string) (*model.AgentMemory, error) {
var memory model.AgentMemory
err := r.db.Where("id = ?", id).First(&memory).Error
if err != nil {
return nil, err
}
return &memory, nil
}
// UpdateMemory 更新记忆
func (r *AgentRepository) UpdateMemory(memory *model.AgentMemory) error {
return r.db.Save(memory).Error
}
// FindMemoryCategories 获取记忆分类列表
func (r *AgentRepository) FindMemoryCategories(agentID, userID string) ([]string, error) {
var categories []string
query := r.db.Model(&model.AgentMemory{}).Distinct("category")
if agentID != "" {
query = query.Where("agent_id = ?", agentID)
}
if userID != "" {
query = query.Where("user_id = ?", userID)
}
err := query.Pluck("category", &categories).Error
return categories, err
}
// AgentTeam 相关方法
func (r *AgentRepository) CreateAgentTeam(team *model.AgentTeam) error {

View File

@@ -44,6 +44,15 @@ func (r *ToolRepository) FindByID(id string) (*model.Tool, error) {
return &tool, nil
}
func (r *ToolRepository) FindByName(name string) (*model.Tool, error) {
var tool model.Tool
err := r.db.First(&tool, "name = ?", name).Error
if err != nil {
return nil, err
}
return &tool, nil
}
func (r *ToolRepository) Update(tool *model.Tool) error {
return r.db.Save(tool).Error
}

View File

@@ -115,7 +115,7 @@ func (s *AgentService) Chat(req AgentChatRequest) (*AgentChatResponse, error) {
log.Printf("[AgentService] Sending to Python: model_id=%s, api_key=%s, base_url=%s, provider=%s, model=%s",
req.ModelID, apiKeyPreview, req.BaseURL, req.ModelProvider, req.ModelName)
url := fmt.Sprintf("%s/agent/chat", s.pythonURL)
url := fmt.Sprintf("%s/api/v1/agent/chat", s.pythonURL)
jsonData, err := json.Marshal(req)
if err != nil {
@@ -153,7 +153,7 @@ func (s *AgentService) Chat(req AgentChatRequest) (*AgentChatResponse, error) {
// TeamChat 多智能体群聊
func (s *AgentService) TeamChat(req TeamChatRequest) (*TeamChatResponse, error) {
url := fmt.Sprintf("%s/agent/team/chat", s.pythonURL)
url := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)
// 设置默认策略
if req.Strategy == "" {
@@ -228,7 +228,7 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID
log.Printf("[ChatStream] modelID is empty or modelRepo is nil: modelID=%s, modelRepo=%v", modelID, s.modelRepo != nil)
}
streamURL := fmt.Sprintf("%s/agent/chat/stream", s.pythonURL)
streamURL := fmt.Sprintf("%s/api/v1/agent/chat/stream", s.pythonURL)
jsonData, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -1,7 +1,17 @@
package service
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
@@ -10,21 +20,27 @@ import (
// 错误定义
var (
ErrNoAgents = errors.New("no agents provided")
ErrNoAgents = errors.New("no agents provided")
ErrAgentNotFound = errors.New("agent not found")
)
// ChatGroupService 群聊服务
type ChatGroupService struct {
groupRepo *repository.ChatGroupRepository
agentRepo *repository.AgentRepository
groupRepo *repository.ChatGroupRepository
agentRepo *repository.AgentRepository
pythonURL string
client *http.Client
}
// NewChatGroupService 创建群聊服务
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository) *ChatGroupService {
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository, pythonURL string) *ChatGroupService {
return &ChatGroupService{
groupRepo: groupRepo,
agentRepo: agentRepo,
pythonURL: pythonURL,
client: &http.Client{
Timeout: 120 * time.Second,
},
}
}
@@ -108,37 +124,122 @@ func (s *ChatGroupService) GroupChat(userID, message, agentIDs, sessionID string
return nil, ErrAgentNotFound
}
// 并行调用所有 Agent
results := make(chan model.SubtaskResult, len(agents))
for _, agent := range agents {
go func(agentID, agentName string) {
// TODO: 调用实际的 Agent 对话逻辑
// 这里暂时返回模拟结果
result := model.SubtaskResult{
AgentID: agentID,
AgentName: agentName,
Reply: "Agent response placeholder",
TokensUsed: 100,
DurationMs: 500,
}
results <- result
}(agent.ID, agent.Name)
// 解析 userID 为整数
userIDInt, err := strconv.Atoi(userID)
if err != nil {
userIDInt = 1 // 默认值
}
// 收集结果
subtaskResults := make([]model.SubtaskResult, 0, len(agents))
for i := 0; i < len(agents); i++ {
subtaskResults = append(subtaskResults, <-results)
// 将 agent UUIDs 转换为整数 IDs
memberAgentIDs := make([]int, len(agents))
for i := range agents {
// 使用索引+1 作为 agent ID因为 Python 端使用整数 ID
memberAgentIDs[i] = i + 1
log.Printf("[ChatGroupService] Agent: %s (ID: %s)", agents[i].Name, agents[i].ID)
}
// 调用 Python TeamAgent 进行群聊
teamReq := TeamChatRequest{
SupervisorAgentID: 0, // 没有 supervisor使用并行策略
MemberAgentIDs: memberAgentIDs,
Message: message,
UserID: userIDInt,
SessionID: sessionID,
Strategy: "parallel",
}
url := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)
jsonData, err := json.Marshal(teamReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to call python team agent: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("python team agent error: %s", string(body))
}
var teamResp TeamChatResponse
if err := json.Unmarshal(body, &teamResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
log.Printf("[ChatGroupService] Raw teamResp: %+v", teamResp)
log.Printf("[ChatGroupService] SubtaskResults count: %d", len(teamResp.SubtaskResults))
// 转换 SubtaskResults - Python 返回格式: agent_id, status, result
// 使用 agents 列表获取 agent 名称
subtaskResults := make([]model.SubtaskResult, 0)
for i, sr := range teamResp.SubtaskResults {
srMap, ok := sr.(map[string]interface{})
if !ok {
continue
}
// 获取 agent 名称(从已查询的 agents 列表)
agentName := ""
if i < len(agents) {
agentName = agents[i].Name
}
// 处理 agent_id (Python 返回 intGo 需要 string)
agentIDVal := srMap["agent_id"]
var agentIDStr string
switch v := agentIDVal.(type) {
case float64:
agentIDStr = fmt.Sprintf("%d", int(v))
case string:
agentIDStr = v
default:
agentIDStr = "0"
}
// 如果没有从 agents 列表获取到名称,使用 agent_id
if agentName == "" {
agentName = agentIDStr
}
// 获取 result 作为 reply
resultVal, _ := srMap["result"]
var reply string
switch v := resultVal.(type) {
case string:
reply = v
default:
reply = fmt.Sprintf("%v", resultVal)
}
subtaskResults = append(subtaskResults, model.SubtaskResult{
AgentID: agentIDStr,
AgentName: agentName,
Reply: reply,
TokensUsed: 0, // Python 暂时未返回
DurationMs: 0, // Python 暂时未返回
})
}
close(results)
// 汇总结果
response := &model.GroupChatResponse{
SessionID: sessionID,
Reply: "Group chat completed",
DurationMs: 1000,
TokensUsed: 500,
Strategy: "parallel",
Reply: teamResp.Response,
DurationMs: teamResp.DurationMs,
Strategy: teamResp.Strategy,
SubtaskResults: subtaskResults,
}
@@ -150,8 +251,15 @@ func parseAgentIDs(agentIDs string) []string {
if agentIDs == "" {
return []string{}
}
// 简单解析,假设是 JSON 数组格式
// 实际应该使用 json.Unmarshal
// 这里简化处理,直接返回
// 尝试解析 JSON 数组格式
var ids []string
if err := json.Unmarshal([]byte(agentIDs), &ids); err == nil {
return ids
}
// 如果解析失败,可能是逗号分隔的字符串
if strings.Contains(agentIDs, ",") {
return strings.Split(agentIDs, ",")
}
// 简单处理,直接返回
return []string{agentIDs}
}

View File

@@ -71,6 +71,56 @@ func (s *ToolService) InitDefaultTools() error {
return nil
}
// SyncToolsFromPython 从Python端同步工具
func (s *ToolService) SyncToolsFromPython(pythonTools []map[string]interface{}) (int, error) {
count := 0
for _, pt := range pythonTools {
// 提取工具信息
name, _ := pt["name"].(string)
description, _ := pt["description"].(string)
parameters, _ := pt["parameters"].(string)
category, _ := pt["category"].(string)
if name == "" {
continue
}
// 检查工具是否已存在
existing, err := s.toolRepo.FindByName(name)
if err == nil && existing != nil {
// 更新现有工具
existing.Description = description
existing.Parameters = parameters
if category != "" {
existing.Category = category
}
if err := s.toolRepo.Update(existing); err != nil {
log.Printf("[ToolService] Failed to update tool %s: %v", name, err)
continue
}
} else {
// 创建新工具
tool := model.Tool{
Name: name,
Description: description,
Parameters: parameters,
Category: category,
Provider: "python",
Status: "active",
SecurityLevel: "safe",
RequireApproval: false,
}
if err := s.toolRepo.Create(&tool); err != nil {
log.Printf("[ToolService] Failed to create tool %s: %v", name, err)
continue
}
}
count++
}
log.Printf("[ToolService] Synced %d tools from Python", count)
return count, nil
}
// getDefaultTools 获取默认工具列表
func (s *ToolService) getDefaultTools() []model.Tool {
return []model.Tool{