feat: 更新 Server 后端服务
- 更新 agent handler 和 service 层 - 新增 chat_group handler 和 service - 删除废弃的 chat_handler - 更新 tool 相关处理 - 更新 API 文档和依赖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,7 +3,10 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -11,15 +14,31 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// 获取项目根目录
|
||||
func getProjectRoot() string {
|
||||
// 从当前工作目录向上查找 .env 文件
|
||||
dir, _ := os.Getwd()
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := os.Stat(filepath.Join(dir, ".env")); err == nil {
|
||||
return dir
|
||||
}
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
// 默认返回当前目录
|
||||
return "."
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
JWTSecret string
|
||||
DatabaseType string // 数据库类型: mysql 或 sqlite
|
||||
DatabaseHost string
|
||||
DatabasePort string
|
||||
DatabaseUser string
|
||||
DatabasePassword string
|
||||
DatabaseName string
|
||||
DatabaseURL string // 拼接后的完整连接字符串
|
||||
SQLitePath string // SQLite 数据库文件路径
|
||||
PythonServiceURL string
|
||||
AICoreServiceAddr string // AI-Core gRPC 服务地址,如 "localhost:50051"
|
||||
// 文件上传配置
|
||||
@@ -36,23 +55,22 @@ type Config struct {
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("../config")
|
||||
viper.AddConfigPath("../../config")
|
||||
// 重新初始化 viper,避免之前的状态影响
|
||||
viper.Reset()
|
||||
|
||||
// 默认值
|
||||
// 第一步:设置默认值
|
||||
viper.SetDefault("port", "8080")
|
||||
viper.SetDefault("jwt_secret", "your-secret-key-change-in-production")
|
||||
viper.SetDefault("python_service_url", "http://localhost:8081")
|
||||
viper.SetDefault("ai_core_service_addr", "localhost:50051")
|
||||
// 数据库默认配置
|
||||
viper.SetDefault("database_type", "mysql")
|
||||
viper.SetDefault("database_host", "localhost")
|
||||
viper.SetDefault("database_port", "3306")
|
||||
viper.SetDefault("database_user", "root")
|
||||
viper.SetDefault("database_password", "root")
|
||||
viper.SetDefault("database_name", "x_agents")
|
||||
viper.SetDefault("sqlite_path", "./data/x_agents.db")
|
||||
// 文件上传默认配置
|
||||
viper.SetDefault("upload_mode", "local")
|
||||
viper.SetDefault("upload_local_path", "resource/files")
|
||||
@@ -64,30 +82,84 @@ func Load() *Config {
|
||||
viper.SetDefault("minio_bucket", "x-agents")
|
||||
viper.SetDefault("minio_use_ssl", false)
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Printf("Using default config: %v", err)
|
||||
// 第二步:读取 config.yaml(优先级低)
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./config")
|
||||
viper.AddConfigPath("../config")
|
||||
viper.AddConfigPath("../../config")
|
||||
_ = viper.MergeInConfig() // 忽略错误,可能没有 config.yaml
|
||||
|
||||
// 第三步:读取 .env 文件(优先级最高)
|
||||
projectRoot := getProjectRoot()
|
||||
log.Printf("Project root: %s", projectRoot)
|
||||
|
||||
viper.SetConfigName(".env")
|
||||
viper.SetConfigType("env")
|
||||
viper.AddConfigPath(projectRoot) // 项目根目录 (X-Agents)
|
||||
viper.AddConfigPath(".") // 当前目录
|
||||
viper.AddConfigPath("..") // 父目录
|
||||
viper.AddConfigPath("../..") // 上两级目录
|
||||
viper.SetEnvPrefix("GO") // 环境变量前缀 GO_xxx (仅对环境变量生效)
|
||||
viper.AutomaticEnv()
|
||||
_ = viper.MergeInConfig() // 忽略错误,可能没有 .env
|
||||
|
||||
// 处理 .env 文件中的键名(去掉 GO_ 前缀映射)
|
||||
envToConfig := map[string]string{
|
||||
"GO_PORT": "port",
|
||||
"GO_DATABASE_TYPE": "database_type",
|
||||
"GO_DATABASE_HOST": "database_host",
|
||||
"GO_DATABASE_PORT": "database_port",
|
||||
"GO_DATABASE_NAME": "database_name",
|
||||
"GO_DATABASE_USER": "database_user",
|
||||
"GO_DATABASE_PASSWORD": "database_password",
|
||||
"GO_SQLITE_PATH": "sqlite_path",
|
||||
}
|
||||
for envKey, configKey := range envToConfig {
|
||||
if val := viper.GetString(envKey); val != "" {
|
||||
viper.Set(configKey, val)
|
||||
}
|
||||
}
|
||||
|
||||
// 拼接数据库连接字符串
|
||||
dbHost := viper.GetString("database_host")
|
||||
dbPort := viper.GetString("database_port")
|
||||
dbUser := viper.GetString("database_user")
|
||||
dbPassword := viper.GetString("database_password")
|
||||
dbName := viper.GetString("database_name")
|
||||
databaseURL := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
dbUser, dbPassword, dbHost, dbPort, dbName)
|
||||
log.Printf("Loaded config: database_type=%s, port=%s",
|
||||
viper.GetString("database_type"), viper.GetString("port"))
|
||||
|
||||
// 获取数据库类型
|
||||
dbType := viper.GetString("database_type")
|
||||
var databaseURL string
|
||||
|
||||
if dbType == "sqlite" {
|
||||
sqlitePath := viper.GetString("sqlite_path")
|
||||
// 确保 SQLite 数据目录存在 (跨平台处理)
|
||||
dir := filepath.Dir(sqlitePath)
|
||||
if dir != "." && dir != "" {
|
||||
os.MkdirAll(dir, 0755)
|
||||
}
|
||||
databaseURL = sqlitePath
|
||||
} else {
|
||||
// MySQL 连接字符串
|
||||
dbHost := viper.GetString("database_host")
|
||||
dbPort := viper.GetString("database_port")
|
||||
dbUser := viper.GetString("database_user")
|
||||
dbPassword := viper.GetString("database_password")
|
||||
dbName := viper.GetString("database_name")
|
||||
databaseURL = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
dbUser, dbPassword, dbHost, dbPort, dbName)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Port: viper.GetString("port"),
|
||||
JWTSecret: viper.GetString("jwt_secret"),
|
||||
DatabaseType: dbType,
|
||||
DatabaseURL: databaseURL,
|
||||
DatabaseHost: dbHost,
|
||||
DatabasePort: dbPort,
|
||||
DatabaseUser: dbUser,
|
||||
DatabasePassword: dbPassword,
|
||||
DatabaseName: dbName,
|
||||
DatabaseHost: viper.GetString("database_host"),
|
||||
DatabasePort: viper.GetString("database_port"),
|
||||
DatabaseUser: viper.GetString("database_user"),
|
||||
DatabasePassword: viper.GetString("database_password"),
|
||||
DatabaseName: viper.GetString("database_name"),
|
||||
SQLitePath: viper.GetString("sqlite_path"),
|
||||
PythonServiceURL: viper.GetString("python_service_url"),
|
||||
AICoreServiceAddr: viper.GetString("ai_core_service_addr"),
|
||||
AICoreServiceAddr: viper.GetString("ai_core_service_addr"),
|
||||
// 文件上传配置
|
||||
UploadMode: viper.GetString("upload_mode"),
|
||||
UploadLocalPath: viper.GetString("upload_local_path"),
|
||||
@@ -108,9 +180,23 @@ func InitDB(cfg *Config) (*gorm.DB, error) {
|
||||
return nil, fmt.Errorf("database URL is empty")
|
||||
}
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
|
||||
if cfg.DatabaseType == "sqlite" {
|
||||
// SQLite 不需要创建目录逻辑,因为 Load 函数已经处理了
|
||||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
log.Printf("Using SQLite database: %s", dsn)
|
||||
} else {
|
||||
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
log.Printf("Using MySQL database: %s:%s/%s",
|
||||
cfg.DatabaseHost, cfg.DatabasePort, cfg.DatabaseName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect database: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
@@ -22,7 +23,7 @@ func NewAgentHandler(agentService *service.AgentService) *AgentHandler {
|
||||
|
||||
// ChatRequest 对话请求
|
||||
type ChatRequest struct {
|
||||
AgentID int `json:"agent_id" binding:"required"`
|
||||
AgentID string `json:"agent_id" binding:"required"` // 字符串类型
|
||||
Message string `json:"message" binding:"required"`
|
||||
SessionID string `json:"session_id"`
|
||||
ModelID string `json:"model_id"`
|
||||
@@ -69,10 +70,14 @@ func (h *AgentHandler) Chat(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取用户 ID(从认证中间件获取)
|
||||
userID := 1 // TODO: 从 c.Get("user_id") 获取
|
||||
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
|
||||
userID, _ := strconv.Atoi(userIDStr)
|
||||
|
||||
// 将前端传来的字符串 agent_id 转换为 int
|
||||
agentID, _ := strconv.Atoi(req.AgentID)
|
||||
|
||||
pythonReq := service.AgentChatRequest{
|
||||
AgentID: req.AgentID,
|
||||
AgentID: agentID,
|
||||
Message: req.Message,
|
||||
UserID: userID,
|
||||
SessionID: req.SessionID,
|
||||
@@ -122,7 +127,11 @@ func (h *AgentHandler) ChatStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取用户 ID
|
||||
userID := 1 // TODO: 从 c.Get("user_id") 获取
|
||||
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
|
||||
userID, _ := strconv.Atoi(userIDStr)
|
||||
|
||||
// 将前端传来的字符串 agent_id 转换为 int
|
||||
agentID, _ := strconv.Atoi(req.AgentID)
|
||||
|
||||
// 构建 SSE 流
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -131,7 +140,7 @@ func (h *AgentHandler) ChatStream(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// 调用 Python 服务的流式端点
|
||||
err := h.agentService.ChatStream(c, req.AgentID, req.Message, req.SessionID, req.ModelID, userID)
|
||||
err := h.agentService.ChatStream(c, agentID, req.Message, req.SessionID, req.ModelID, userID)
|
||||
if err != nil && !c.IsAborted() {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
@@ -173,7 +182,8 @@ func (h *AgentHandler) TeamChat(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取用户 ID
|
||||
userID := 1 // TODO: 从 c.Get("user_id") 获取
|
||||
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
|
||||
userID, _ := strconv.Atoi(userIDStr)
|
||||
|
||||
pythonReq := service.TeamChatRequest{
|
||||
SupervisorAgentID: req.SupervisorAgentID,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -140,10 +141,13 @@ func (h *ChatGroupHandler) GroupChat(c *gin.Context) {
|
||||
}
|
||||
|
||||
response, err := h.groupService.GroupChat(userID, req.Message, agentIDs, req.SessionID)
|
||||
log.Printf("[ChatGroupHandler] Got response, err: %v", err)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[ChatGroupHandler] Response subtask_results: %+v", response.SubtaskResults)
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatHandler struct {
|
||||
chatService *service.ChatService
|
||||
}
|
||||
|
||||
func NewChatHandler(chatService *service.ChatService) *ChatHandler {
|
||||
return &ChatHandler{chatService: chatService}
|
||||
}
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
var req model.AgentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取用户ID(由中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.chatService.Chat(c.Request.Context(), userID.(string), req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ListAgents 获取 Agent 列表
|
||||
func (h *ChatHandler) ListAgents(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
agents, err := h.chatService.ListAgents(userID.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if agents == nil {
|
||||
agents = []model.Agent{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"agents": agents})
|
||||
}
|
||||
|
||||
// CreateAgent 创建 Agent
|
||||
func (h *ChatHandler) CreateAgent(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
agent, err := h.chatService.CreateAgent(userID.(string), req.Name, req.Description)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, agent)
|
||||
}
|
||||
@@ -164,3 +164,46 @@ func (h *ToolHandler) Delete(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "tool deleted"})
|
||||
}
|
||||
|
||||
// SyncFromPythonRequest 从Python同步工具请求
|
||||
type SyncFromPythonRequest struct {
|
||||
Tools []PythonTool `json:"tools" binding:"required"`
|
||||
}
|
||||
|
||||
// PythonTool Python端工具结构
|
||||
type PythonTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters string `json:"parameters"`
|
||||
Category string `json:"category"`
|
||||
}
|
||||
|
||||
// SyncFromPython 从Python端同步工具
|
||||
// @Summary 从Python端同步工具
|
||||
// @Description 接收Python Agent同步过来的工具列表并存储到数据库
|
||||
// @Tags 工具管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param tools body SyncFromPythonRequest true "工具列表"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /tool/sync-from-python [post]
|
||||
func (h *ToolHandler) SyncFromPython(c *gin.Context) {
|
||||
var req struct {
|
||||
Tools []map[string]interface{} `json:"tools" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.toolService.SyncToolsFromPython(req.Tools)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "tools synced from python",
|
||||
"synced_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -63,12 +63,30 @@ type AgentMemory struct {
|
||||
AgentID string `json:"agent_id" gorm:"size:191;index"`
|
||||
UserID string `json:"user_id" gorm:"size:191;index"`
|
||||
Content string `json:"content" gorm:"type:text"`
|
||||
MemoryType string `json:"memory_type" gorm:"size:20"` // experience/preference/conversation
|
||||
Importance int `json:"importance" gorm:"default:5"`
|
||||
MemoryType string `json:"memory_type" gorm:"size:20"` // experience/preference/conversation/fact
|
||||
Category string `json:"category" gorm:"size:50;index"` // 记忆分类
|
||||
Tags string `json:"tags" gorm:"size:500"` // 标签,JSON数组格式
|
||||
Keywords string `json:"keywords" gorm:"size:500"` // 关键词,用于搜索
|
||||
Importance int `json:"importance" gorm:"default:5;index"` // 重要性等级 1-10
|
||||
IsPinned bool `json:"is_pinned" gorm:"default:false"` // 是否置顶
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (AgentMemory) TableName() string {
|
||||
return "agent_memories"
|
||||
}
|
||||
|
||||
// ImportItem 导入记忆项
|
||||
type ImportItem struct {
|
||||
Content string `json:"content"`
|
||||
MemoryType string `json:"memory_type"`
|
||||
Category string `json:"category"`
|
||||
Tags string `json:"tags"`
|
||||
Keywords string `json:"keywords"`
|
||||
Importance int `json:"importance"`
|
||||
}
|
||||
|
||||
// AgentTeam 多智能体协作配置
|
||||
type AgentTeam struct {
|
||||
ID string `json:"id" gorm:"primaryKey"`
|
||||
|
||||
@@ -39,6 +39,13 @@ func (r *AgentRepository) FindAll() ([]model.Agent, error) {
|
||||
return agents, err
|
||||
}
|
||||
|
||||
// FindByIDs 根据多个ID查询
|
||||
func (r *AgentRepository) FindByIDs(ids []string) ([]model.Agent, error) {
|
||||
var agents []model.Agent
|
||||
err := r.db.Where("id IN ?", ids).Find(&agents).Error
|
||||
return agents, err
|
||||
}
|
||||
|
||||
func (r *AgentRepository) Update(agent *model.Agent) error {
|
||||
return r.db.Save(agent).Error
|
||||
}
|
||||
@@ -101,6 +108,103 @@ func (r *AgentRepository) DeleteMemory(id string) error {
|
||||
return r.db.Delete(&model.AgentMemory{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// FindMemories 通用查询记忆,支持分类和类型过滤
|
||||
func (r *AgentRepository) FindMemories(agentID, userID, category, memoryType string, limit, offset int) ([]model.AgentMemory, int64, error) {
|
||||
var memories []model.AgentMemory
|
||||
query := r.db.Model(&model.AgentMemory{})
|
||||
|
||||
if agentID != "" {
|
||||
query = query.Where("agent_id = ?", agentID)
|
||||
}
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if category != "" {
|
||||
query = query.Where("category = ?", category)
|
||||
}
|
||||
if memoryType != "" {
|
||||
query = query.Where("memory_type = ?", memoryType)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
countQuery := query
|
||||
if err := countQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 置顶的记忆优先,然后按重要性降序,最后按创建时间降序
|
||||
err := query.Order("is_pinned DESC, importance DESC, created_at DESC").Limit(limit).Offset(offset).Find(&memories).Error
|
||||
return memories, total, err
|
||||
}
|
||||
|
||||
// SearchMemories 搜索记忆
|
||||
func (r *AgentRepository) SearchMemories(agentID, userID, keyword, tags, category, memoryType string, minScore, limit, offset int) ([]model.AgentMemory, int64, error) {
|
||||
var memories []model.AgentMemory
|
||||
query := r.db.Model(&model.AgentMemory{})
|
||||
|
||||
if agentID != "" {
|
||||
query = query.Where("agent_id = ?", agentID)
|
||||
}
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if keyword != "" {
|
||||
keyword = "%" + keyword + "%"
|
||||
query = query.Where("content LIKE ? OR keywords LIKE ? OR tags LIKE ?", keyword, keyword, keyword)
|
||||
}
|
||||
if category != "" {
|
||||
query = query.Where("category = ?", category)
|
||||
}
|
||||
if memoryType != "" {
|
||||
query = query.Where("memory_type = ?", memoryType)
|
||||
}
|
||||
if minScore > 0 {
|
||||
query = query.Where("importance >= ?", minScore)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
var total int64
|
||||
countQuery := query
|
||||
if err := countQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Order("is_pinned DESC, importance DESC, created_at DESC").Limit(limit).Offset(offset).Find(&memories).Error
|
||||
return memories, total, err
|
||||
}
|
||||
|
||||
// FindMemoryByID 根据ID查询记忆
|
||||
func (r *AgentRepository) FindMemoryByID(id string) (*model.AgentMemory, error) {
|
||||
var memory model.AgentMemory
|
||||
err := r.db.Where("id = ?", id).First(&memory).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &memory, nil
|
||||
}
|
||||
|
||||
// UpdateMemory 更新记忆
|
||||
func (r *AgentRepository) UpdateMemory(memory *model.AgentMemory) error {
|
||||
return r.db.Save(memory).Error
|
||||
}
|
||||
|
||||
// FindMemoryCategories 获取记忆分类列表
|
||||
func (r *AgentRepository) FindMemoryCategories(agentID, userID string) ([]string, error) {
|
||||
var categories []string
|
||||
query := r.db.Model(&model.AgentMemory{}).Distinct("category")
|
||||
|
||||
if agentID != "" {
|
||||
query = query.Where("agent_id = ?", agentID)
|
||||
}
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
err := query.Pluck("category", &categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// AgentTeam 相关方法
|
||||
|
||||
func (r *AgentRepository) CreateAgentTeam(team *model.AgentTeam) error {
|
||||
|
||||
@@ -44,6 +44,15 @@ func (r *ToolRepository) FindByID(id string) (*model.Tool, error) {
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
func (r *ToolRepository) FindByName(name string) (*model.Tool, error) {
|
||||
var tool model.Tool
|
||||
err := r.db.First(&tool, "name = ?", name).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
func (r *ToolRepository) Update(tool *model.Tool) error {
|
||||
return r.db.Save(tool).Error
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ func (s *AgentService) Chat(req AgentChatRequest) (*AgentChatResponse, error) {
|
||||
log.Printf("[AgentService] Sending to Python: model_id=%s, api_key=%s, base_url=%s, provider=%s, model=%s",
|
||||
req.ModelID, apiKeyPreview, req.BaseURL, req.ModelProvider, req.ModelName)
|
||||
|
||||
url := fmt.Sprintf("%s/agent/chat", s.pythonURL)
|
||||
url := fmt.Sprintf("%s/api/v1/agent/chat", s.pythonURL)
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
@@ -153,7 +153,7 @@ func (s *AgentService) Chat(req AgentChatRequest) (*AgentChatResponse, error) {
|
||||
|
||||
// TeamChat 多智能体群聊
|
||||
func (s *AgentService) TeamChat(req TeamChatRequest) (*TeamChatResponse, error) {
|
||||
url := fmt.Sprintf("%s/agent/team/chat", s.pythonURL)
|
||||
url := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)
|
||||
|
||||
// 设置默认策略
|
||||
if req.Strategy == "" {
|
||||
@@ -228,7 +228,7 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID
|
||||
log.Printf("[ChatStream] modelID is empty or modelRepo is nil: modelID=%s, modelRepo=%v", modelID, s.modelRepo != nil)
|
||||
}
|
||||
|
||||
streamURL := fmt.Sprintf("%s/agent/chat/stream", s.pythonURL)
|
||||
streamURL := fmt.Sprintf("%s/api/v1/agent/chat/stream", s.pythonURL)
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"x-agents/server/internal/model"
|
||||
"x-agents/server/internal/repository"
|
||||
|
||||
@@ -10,21 +20,27 @@ import (
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrNoAgents = errors.New("no agents provided")
|
||||
ErrNoAgents = errors.New("no agents provided")
|
||||
ErrAgentNotFound = errors.New("agent not found")
|
||||
)
|
||||
|
||||
// ChatGroupService 群聊服务
|
||||
type ChatGroupService struct {
|
||||
groupRepo *repository.ChatGroupRepository
|
||||
agentRepo *repository.AgentRepository
|
||||
groupRepo *repository.ChatGroupRepository
|
||||
agentRepo *repository.AgentRepository
|
||||
pythonURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewChatGroupService 创建群聊服务
|
||||
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository) *ChatGroupService {
|
||||
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository, pythonURL string) *ChatGroupService {
|
||||
return &ChatGroupService{
|
||||
groupRepo: groupRepo,
|
||||
agentRepo: agentRepo,
|
||||
pythonURL: pythonURL,
|
||||
client: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,37 +124,122 @@ func (s *ChatGroupService) GroupChat(userID, message, agentIDs, sessionID string
|
||||
return nil, ErrAgentNotFound
|
||||
}
|
||||
|
||||
// 并行调用所有 Agent
|
||||
results := make(chan model.SubtaskResult, len(agents))
|
||||
for _, agent := range agents {
|
||||
go func(agentID, agentName string) {
|
||||
// TODO: 调用实际的 Agent 对话逻辑
|
||||
// 这里暂时返回模拟结果
|
||||
result := model.SubtaskResult{
|
||||
AgentID: agentID,
|
||||
AgentName: agentName,
|
||||
Reply: "Agent response placeholder",
|
||||
TokensUsed: 100,
|
||||
DurationMs: 500,
|
||||
}
|
||||
results <- result
|
||||
}(agent.ID, agent.Name)
|
||||
// 解析 userID 为整数
|
||||
userIDInt, err := strconv.Atoi(userID)
|
||||
if err != nil {
|
||||
userIDInt = 1 // 默认值
|
||||
}
|
||||
|
||||
// 收集结果
|
||||
subtaskResults := make([]model.SubtaskResult, 0, len(agents))
|
||||
for i := 0; i < len(agents); i++ {
|
||||
subtaskResults = append(subtaskResults, <-results)
|
||||
// 将 agent UUIDs 转换为整数 IDs
|
||||
memberAgentIDs := make([]int, len(agents))
|
||||
for i := range agents {
|
||||
// 使用索引+1 作为 agent ID(因为 Python 端使用整数 ID)
|
||||
memberAgentIDs[i] = i + 1
|
||||
log.Printf("[ChatGroupService] Agent: %s (ID: %s)", agents[i].Name, agents[i].ID)
|
||||
}
|
||||
|
||||
// 调用 Python TeamAgent 进行群聊
|
||||
teamReq := TeamChatRequest{
|
||||
SupervisorAgentID: 0, // 没有 supervisor,使用并行策略
|
||||
MemberAgentIDs: memberAgentIDs,
|
||||
Message: message,
|
||||
UserID: userIDInt,
|
||||
SessionID: sessionID,
|
||||
Strategy: "parallel",
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/agent/team/chat", s.pythonURL)
|
||||
jsonData, err := json.Marshal(teamReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call python team agent: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("python team agent error: %s", string(body))
|
||||
}
|
||||
|
||||
var teamResp TeamChatResponse
|
||||
if err := json.Unmarshal(body, &teamResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[ChatGroupService] Raw teamResp: %+v", teamResp)
|
||||
log.Printf("[ChatGroupService] SubtaskResults count: %d", len(teamResp.SubtaskResults))
|
||||
|
||||
// 转换 SubtaskResults - Python 返回格式: agent_id, status, result
|
||||
// 使用 agents 列表获取 agent 名称
|
||||
subtaskResults := make([]model.SubtaskResult, 0)
|
||||
for i, sr := range teamResp.SubtaskResults {
|
||||
srMap, ok := sr.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 agent 名称(从已查询的 agents 列表)
|
||||
agentName := ""
|
||||
if i < len(agents) {
|
||||
agentName = agents[i].Name
|
||||
}
|
||||
|
||||
// 处理 agent_id (Python 返回 int,Go 需要 string)
|
||||
agentIDVal := srMap["agent_id"]
|
||||
var agentIDStr string
|
||||
switch v := agentIDVal.(type) {
|
||||
case float64:
|
||||
agentIDStr = fmt.Sprintf("%d", int(v))
|
||||
case string:
|
||||
agentIDStr = v
|
||||
default:
|
||||
agentIDStr = "0"
|
||||
}
|
||||
|
||||
// 如果没有从 agents 列表获取到名称,使用 agent_id
|
||||
if agentName == "" {
|
||||
agentName = agentIDStr
|
||||
}
|
||||
|
||||
// 获取 result 作为 reply
|
||||
resultVal, _ := srMap["result"]
|
||||
var reply string
|
||||
switch v := resultVal.(type) {
|
||||
case string:
|
||||
reply = v
|
||||
default:
|
||||
reply = fmt.Sprintf("%v", resultVal)
|
||||
}
|
||||
|
||||
subtaskResults = append(subtaskResults, model.SubtaskResult{
|
||||
AgentID: agentIDStr,
|
||||
AgentName: agentName,
|
||||
Reply: reply,
|
||||
TokensUsed: 0, // Python 暂时未返回
|
||||
DurationMs: 0, // Python 暂时未返回
|
||||
})
|
||||
}
|
||||
close(results)
|
||||
|
||||
// 汇总结果
|
||||
response := &model.GroupChatResponse{
|
||||
SessionID: sessionID,
|
||||
Reply: "Group chat completed",
|
||||
DurationMs: 1000,
|
||||
TokensUsed: 500,
|
||||
Strategy: "parallel",
|
||||
Reply: teamResp.Response,
|
||||
DurationMs: teamResp.DurationMs,
|
||||
Strategy: teamResp.Strategy,
|
||||
SubtaskResults: subtaskResults,
|
||||
}
|
||||
|
||||
@@ -150,8 +251,15 @@ func parseAgentIDs(agentIDs string) []string {
|
||||
if agentIDs == "" {
|
||||
return []string{}
|
||||
}
|
||||
// 简单解析,假设是 JSON 数组格式
|
||||
// 实际应该使用 json.Unmarshal
|
||||
// 这里简化处理,直接返回
|
||||
// 尝试解析 JSON 数组格式
|
||||
var ids []string
|
||||
if err := json.Unmarshal([]byte(agentIDs), &ids); err == nil {
|
||||
return ids
|
||||
}
|
||||
// 如果解析失败,可能是逗号分隔的字符串
|
||||
if strings.Contains(agentIDs, ",") {
|
||||
return strings.Split(agentIDs, ",")
|
||||
}
|
||||
// 简单处理,直接返回
|
||||
return []string{agentIDs}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,56 @@ func (s *ToolService) InitDefaultTools() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncToolsFromPython 从Python端同步工具
|
||||
func (s *ToolService) SyncToolsFromPython(pythonTools []map[string]interface{}) (int, error) {
|
||||
count := 0
|
||||
for _, pt := range pythonTools {
|
||||
// 提取工具信息
|
||||
name, _ := pt["name"].(string)
|
||||
description, _ := pt["description"].(string)
|
||||
parameters, _ := pt["parameters"].(string)
|
||||
category, _ := pt["category"].(string)
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查工具是否已存在
|
||||
existing, err := s.toolRepo.FindByName(name)
|
||||
if err == nil && existing != nil {
|
||||
// 更新现有工具
|
||||
existing.Description = description
|
||||
existing.Parameters = parameters
|
||||
if category != "" {
|
||||
existing.Category = category
|
||||
}
|
||||
if err := s.toolRepo.Update(existing); err != nil {
|
||||
log.Printf("[ToolService] Failed to update tool %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// 创建新工具
|
||||
tool := model.Tool{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: parameters,
|
||||
Category: category,
|
||||
Provider: "python",
|
||||
Status: "active",
|
||||
SecurityLevel: "safe",
|
||||
RequireApproval: false,
|
||||
}
|
||||
if err := s.toolRepo.Create(&tool); err != nil {
|
||||
log.Printf("[ToolService] Failed to create tool %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
count++
|
||||
}
|
||||
log.Printf("[ToolService] Synced %d tools from Python", count)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getDefaultTools 获取默认工具列表
|
||||
func (s *ToolService) getDefaultTools() []model.Tool {
|
||||
return []model.Tool{
|
||||
|
||||
Reference in New Issue
Block a user