diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index cd06f6e..afea6f3 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -147,7 +147,7 @@ func main() { } // 3. 自动迁移表 - if err := db.AutoMigrate(&model.DatabaseInfo{}, &model.SubTableInfo{}, &model.ModelInfo{}, &model.KnowledgeBase{}, &model.KnowledgeDocument{}, &model.User{}, &model.Role{}, &model.Tool{}, &model.MCP{}, &model.Skill{}, &model.Agent{}, &model.AgentSkill{}, &model.AgentKnowledgeBase{}, &model.AgentMemory{}, &model.AgentTeam{}, &model.AgentTask{}).Error; err != nil { + if err := db.AutoMigrate(&model.DatabaseInfo{}, &model.SubTableInfo{}, &model.ModelInfo{}, &model.KnowledgeBase{}, &model.KnowledgeDocument{}, &model.User{}, &model.Role{}, &model.Tool{}, &model.MCP{}, &model.Skill{}, &model.Agent{}, &model.AgentSkill{}, &model.AgentKnowledgeBase{}, &model.AgentMemory{}, &model.AgentTeam{}, &model.AgentTask{}, &model.ChatSession{}, &model.ChatMessage{}, &model.ChatGroup{}).Error; err != nil { log.Printf("Warning: AutoMigrate error: %v", err) } @@ -255,6 +255,57 @@ func main() { `) log.Println("Skills table verified/created") + // 3.5 确保 chat_sessions 表存在 + db.Exec(` + CREATE TABLE IF NOT EXISTS chat_sessions ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + agent_id VARCHAR(36), + title VARCHAR(255), + model_id VARCHAR(36), + status VARCHAR(20) DEFAULT 'active', + created_at DATETIME(3), + updated_at DATETIME(3), + INDEX idx_chat_sessions_user (user_id), + INDEX idx_chat_sessions_agent (agent_id), + INDEX idx_chat_sessions_updated (updated_at DESC) + ) + `) + log.Println("Chat sessions table verified/created") + + // 3.6 确保 chat_messages 表存在 + db.Exec(` + CREATE TABLE IF NOT EXISTS chat_messages ( + id VARCHAR(36) PRIMARY KEY, + session_id VARCHAR(36) NOT NULL, + role VARCHAR(20), + content TEXT, + tokens_used INT DEFAULT 0, + duration_ms INT DEFAULT 0, + metadata TEXT, + created_at DATETIME(3), + INDEX idx_chat_messages_session (session_id), + INDEX idx_chat_messages_created (created_at ASC) + ) + `) + log.Println("Chat messages table verified/created") + + // 3.7 确保 chat_groups 表存在 + db.Exec(` + CREATE TABLE IF NOT EXISTS chat_groups ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + name VARCHAR(100) NOT NULL, + description TEXT, + agent_ids TEXT, + status VARCHAR(20) DEFAULT 'active', + created_at DATETIME(3), + updated_at DATETIME(3), + INDEX idx_chat_groups_user (user_id) + ) + `) + log.Println("Chat groups table verified/created") + // 使用GORM Migrator添加缺失的列 migrator := db.Migrator() @@ -300,6 +351,7 @@ func main() { mcpRepo := repository.NewMCPRepository(db) skillRepo := repository.NewSkillRepository(db) agentRepo := repository.NewAgentRepository(db) + chatRepo := repository.NewChatRepository(db) // 4.1 初始化默认管理员用户 initDefaultAdmin(userRepo) @@ -348,6 +400,13 @@ func main() { skillHandler := handler.NewSkillHandler(skillService) agentHandler := handler.NewAgentHandler(agentService) memoryHandler := handler.NewMemoryHandler(memoryService) + sessionHandler := handler.NewSessionHandler(chatRepo) + + // 初始化群聊服务 + chatGroupRepo := repository.NewChatGroupRepository(db) + chatGroupService := service.NewChatGroupService(chatGroupRepo, agentRepo) + chatGroupHandler := handler.NewChatGroupHandler(chatGroupService) + var uploadHandler *handler.UploadHandler if uploadService != nil { uploadHandler = handler.NewUploadHandler(uploadService, knowledgeRepo) @@ -511,8 +570,8 @@ func main() { { skillGroup.GET("/list", skillHandler.List) skillGroup.GET("/sync", skillHandler.Sync) - skillGroup.GET("/:id", skillHandler.GetByID) skillGroup.GET("/content", skillHandler.GetSkillContent) + skillGroup.GET("/:id", skillHandler.GetByID) skillGroup.POST("/add", skillHandler.Create) skillGroup.PUT("/:id", skillHandler.Update) skillGroup.DELETE("/:id", skillHandler.Delete) @@ -531,12 +590,47 @@ func main() { agentGroup.POST("/team/chat", agentHandler.TeamChat) } + // 会话管理模块 + chatGroup := r.Group("/api/chat") + { + chatGroup.POST("/sessions", sessionHandler.CreateSession) + chatGroup.GET("/sessions", sessionHandler.ListSessions) + chatGroup.GET("/sessions/:id", sessionHandler.GetSession) + chatGroup.PUT("/sessions/:id", sessionHandler.UpdateSession) + chatGroup.DELETE("/sessions/:id", sessionHandler.DeleteSession) + chatGroup.GET("/sessions/:id/messages", sessionHandler.GetMessages) + chatGroup.POST("/messages", sessionHandler.CreateMessage) + } + + // 群聊管理模块 + groupChat := r.Group("/api/chat/groups") + { + groupChat.POST("", chatGroupHandler.CreateGroup) + groupChat.GET("", chatGroupHandler.ListGroups) + groupChat.GET("/:id", chatGroupHandler.GetGroup) + groupChat.PUT("/:id", chatGroupHandler.UpdateGroup) + groupChat.DELETE("/:id", chatGroupHandler.DeleteGroup) + groupChat.POST("/:id/chat", chatGroupHandler.GroupChat) + } + // 记忆管理模块 memoryGroup := r.Group("/api/agent/:id/memories") { memoryGroup.GET("", memoryHandler.GetMemories) memoryGroup.POST("", memoryHandler.CreateMemory) - memoryGroup.DELETE("/:memory_id", memoryHandler.DeleteMemory) + memoryGroup.GET("/search", memoryHandler.SearchMemories) + memoryGroup.GET("/categories", memoryHandler.GetMemoryCategories) + memoryGroup.GET("/tags", memoryHandler.GetMemoryTags) + memoryGroup.GET("/export", memoryHandler.ExportMemories) + memoryGroup.POST("/import", memoryHandler.ImportMemories) + } + + // 单个记忆操作 + memoryItemGroup := r.Group("/api/agent/memories/:memory_id") + { + memoryItemGroup.GET("", memoryHandler.GetMemory) + memoryItemGroup.PUT("", memoryHandler.UpdateMemory) + memoryItemGroup.DELETE("", memoryHandler.DeleteMemory) } // Swagger 文档 diff --git a/server/internal/handler/skill_handler.go b/server/internal/handler/skill_handler.go index 46d58d3..57f6a75 100644 --- a/server/internal/handler/skill_handler.go +++ b/server/internal/handler/skill_handler.go @@ -1,7 +1,9 @@ package handler import ( + "archive/zip" "fmt" + "io" "log" "net/http" "os" @@ -155,7 +157,7 @@ func (h *SkillHandler) Create(c *gin.Context) { // 处理文件上传 file, err := c.FormFile("file") if err == nil { - // 读取文件内容,解析 YAML front matter 获取 name + // 重新打开文件以便多次读取 fileContent, err := file.Open() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read file: " + err.Error()}) @@ -163,51 +165,172 @@ func (h *SkillHandler) Create(c *gin.Context) { } defer fileContent.Close() - content := make([]byte, 1024*1024) // 最多读取 1MB - n, _ := fileContent.Read(content) - contentStr := string(content[:n]) + // 检测是否是 zip 文件 + isZip := strings.HasSuffix(strings.ToLower(file.Filename), ".zip") - // 解析 name - parsedName, parsedDesc := parseSkillMeta(contentStr) - log.Printf("[SkillHandler] Original skill_name from form: %s, parsed name from file: %s", originalSkillName, parsedName) + var contentStr string - // 优先使用文件解析出的 name - if parsedName != "" { - skillName = parsedName - } else if skillName == "" { - // 如果解析不到且表单也没传,用文件名 - skillName = filepath.Base(file.Filename) - } + if isZip { + // 解压 zip 文件 + log.Printf("[SkillHandler] Processing ZIP file: %s", file.Filename) - if parsedDesc != "" { - skillDesc = parsedDesc - } + // 读取整个 zip 内容 + zipData, err := io.ReadAll(fileContent) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read zip file: " + err.Error()}) + return + } - // 清理 skillName,只保留纯文件名,去除所有路径 - skillName = filepath.Base(skillName) - // 去除 .md 后缀 - skillName = strings.TrimSuffix(skillName, ".md") - // 去除空格 - skillName = strings.TrimSpace(skillName) - // 如果包含路径分隔符,取最后一部分 - if idx := strings.LastIndexAny(skillName, "/\\"); idx >= 0 { - skillName = skillName[idx+1:] - } + // 创建临时 zip reader + zipReader, err := zip.NewReader(strings.NewReader(string(zipData)), int64(len(zipData))) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse zip file: " + err.Error()}) + return + } - log.Printf("[SkillHandler] Final skill name: %s", skillName) + // 先创建技能目录(使用传入的 skill_name 或从 zip 文件名推断) + if skillName == "" { + skillName = strings.TrimSuffix(filepath.Base(file.Filename), ".zip") + } + skillName = strings.TrimSpace(skillName) - // 创建技能目录 - skillPath = filepath.Join(projectRoot, "core", "agents", "skills", skillDir, skillName) - if err := os.MkdirAll(skillPath, 0755); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create skill directory: " + err.Error()}) - return - } + skillPath = filepath.Join(projectRoot, "core", "agents", "skills", skillDir, skillName) + if err := os.MkdirAll(skillPath, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create skill directory: " + err.Error()}) + return + } - // 保存文件(使用之前读取的内容) - skillFilePath := filepath.Join(skillPath, "SKILL.md") - if err := os.WriteFile(skillFilePath, []byte(contentStr), 0644); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save skill file: " + err.Error()}) - return + // 解压所有文件 + for _, zipFile := range zipReader.File { + fileName := zipFile.Name + // 跳过目录 + if strings.HasSuffix(fileName, "/") { + continue + } + + // 确保文件路径安全,去除任何绝对路径或路径遍历 + fileName = strings.TrimPrefix(fileName, "/") + if strings.Contains(fileName, "..") { + continue + } + + targetPath := filepath.Join(skillPath, fileName) + + // 创建父目录 + parentDir := filepath.Dir(targetPath) + if err := os.MkdirAll(parentDir, 0755); err != nil { + log.Printf("[SkillHandler] Warning: failed to create directory %s: %v", parentDir, err) + continue + } + + // 读取并解压文件 + zipFileContent, err := zipFile.Open() + if err != nil { + log.Printf("[SkillHandler] Warning: failed to open %s in zip: %v", fileName, err) + continue + } + + content, err := io.ReadAll(zipFileContent) + zipFileContent.Close() + + if err != nil { + log.Printf("[SkillHandler] Warning: failed to read %s from zip: %v", fileName, err) + continue + } + + if err := os.WriteFile(targetPath, content, 0644); err != nil { + log.Printf("[SkillHandler] Warning: failed to write %s: %v", targetPath, err) + continue + } + + log.Printf("[SkillHandler] Extracted: %s", targetPath) + + // 如果是 SKILL.md,解析元数据 + if strings.HasSuffix(strings.ToLower(fileName), "skill.md") { + contentStr = string(content) + } + } + + // 如果没有找到 SKILL.md,尝试找其他 .md 文件 + if contentStr == "" { + files, _ := filepath.Glob(filepath.Join(skillPath, "*.md")) + if len(files) > 0 { + content, err := os.ReadFile(files[0]) + if err == nil { + contentStr = string(content) + } + } + } + + // 解析元数据 + if contentStr != "" { + parsedName, parsedDesc := parseSkillMeta(contentStr) + if parsedName != "" { + // 如果从 zip 中解析到 skill name,可能需要重命名目录 + if parsedName != skillName { + newSkillPath := filepath.Join(projectRoot, "core", "agents", "skills", skillDir, parsedName) + if err := os.Rename(skillPath, newSkillPath); err != nil { + log.Printf("[SkillHandler] Warning: failed to rename skill directory: %v", err) + } else { + skillPath = newSkillPath + } + skillName = parsedName + } + } + if parsedDesc != "" { + skillDesc = parsedDesc + } + } + + log.Printf("[SkillHandler] ZIP imported successfully: %s", skillName) + } else { + // 普通 md 文件处理(原有逻辑) + content := make([]byte, 1024*1024) // 最多读取 1MB + n, _ := fileContent.Read(content) + contentStr = string(content[:n]) + + // 解析 name + parsedName, parsedDesc := parseSkillMeta(contentStr) + log.Printf("[SkillHandler] Original skill_name from form: %s, parsed name from file: %s", originalSkillName, parsedName) + + // 优先使用文件解析出的 name + if parsedName != "" { + skillName = parsedName + } else if skillName == "" { + // 如果解析不到且表单也没传,用文件名 + skillName = filepath.Base(file.Filename) + } + + if parsedDesc != "" { + skillDesc = parsedDesc + } + + // 清理 skillName,只保留纯文件名,去除所有路径 + skillName = filepath.Base(skillName) + // 去除 .md 后缀 + skillName = strings.TrimSuffix(skillName, ".md") + // 去除空格 + skillName = strings.TrimSpace(skillName) + // 如果包含路径分隔符,取最后一部分 + if idx := strings.LastIndexAny(skillName, "/\\"); idx >= 0 { + skillName = skillName[idx+1:] + } + + log.Printf("[SkillHandler] Final skill name: %s", skillName) + + // 创建技能目录 + skillPath = filepath.Join(projectRoot, "core", "agents", "skills", skillDir, skillName) + if err := os.MkdirAll(skillPath, 0755); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create skill directory: " + err.Error()}) + return + } + + // 保存文件(使用之前读取的内容) + skillFilePath := filepath.Join(skillPath, "SKILL.md") + if err := os.WriteFile(skillFilePath, []byte(contentStr), 0644); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save skill file: " + err.Error()}) + return + } } } else { // 如果没有上传文件但提供了 name 和 desc,创建默认文件 diff --git a/server/internal/model/model_info.go b/server/internal/model/model_info.go index b1666bf..ca3e650 100644 --- a/server/internal/model/model_info.go +++ b/server/internal/model/model_info.go @@ -12,7 +12,7 @@ type ModelInfo struct { APIKey string `json:"api_key" gorm:"type:text"` // API 密钥 BaseURL string `json:"base_url" gorm:"type:varchar(500)"` // 基础 URL APIEndpoint string `json:"api_endpoint" gorm:"type:varchar(500)"` // API 端点路径 - Status string `json:"status" gorm:"type:varchar(20);default:active"` // active/inactive + Status int `json:"status" gorm:"type:tinyint;default:0"` // 1:active, 0:inactive CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -39,7 +39,7 @@ type CreateModelRequest struct { APIKey string `json:"api_key" binding:"required"` BaseURL string `json:"base_url" binding:"required"` APIEndpoint string `json:"api_endpoint"` - Status string `json:"status"` + Status int `json:"status"` } // UpdateModelRequest 更新模型请求 @@ -51,7 +51,7 @@ type UpdateModelRequest struct { APIKey string `json:"api_key"` BaseURL string `json:"base_url"` APIEndpoint string `json:"api_endpoint"` - Status string `json:"status"` + Status int `json:"status"` } // TestModelRequest 测试模型连接请求 diff --git a/server/internal/service/model_service.go b/server/internal/service/model_service.go index 37a146d..a7db9d6 100644 --- a/server/internal/service/model_service.go +++ b/server/internal/service/model_service.go @@ -41,10 +41,10 @@ func (s *ModelService) Create(req model.CreateModelRequest) (*model.ModelInfo, e return nil, fmt.Errorf("model with name '%s' already exists", req.Name) } - // 如果没有提供状态,默认设置为 inactive + // 如果没有提供状态,默认设置为 inactive (0) status := req.Status - if status == "" { - status = "inactive" + if status == 0 { + status = 0 // inactive } info := &model.ModelInfo{ @@ -96,7 +96,9 @@ func (s *ModelService) Update(id string, req model.UpdateModelRequest) (*model.M if req.APIEndpoint != "" { fields["api_endpoint"] = req.APIEndpoint } - if req.Status != "" { + // Status为int类型,0表示inactive,1表示active + // 只在明确传入Status值时才更新(Status > 0 表示传入了值) + if req.Status > 0 { fields["status"] = req.Status } diff --git a/server/migrations/agent_system.sql b/server/migrations/agent_system.sql index 39d474b..9c7589f 100644 --- a/server/migrations/agent_system.sql +++ b/server/migrations/agent_system.sql @@ -49,6 +49,48 @@ CREATE INDEX IF NOT EXISTS idx_agent_tasks_created ON agent_tasks(created_at DES CREATE INDEX IF NOT EXISTS idx_agent_team_supervisor ON agent_teams(supervisor_agent_id); CREATE INDEX IF NOT EXISTS idx_agent_team_member ON agent_teams(member_agent_id); +-- Chat Sessions Table +CREATE TABLE IF NOT EXISTS chat_sessions ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + agent_id VARCHAR(36), + title VARCHAR(255), + model_id VARCHAR(36), + status VARCHAR(20) DEFAULT 'active', + created_at DATETIME(3), + updated_at DATETIME(3), + INDEX idx_chat_sessions_user (user_id), + INDEX idx_chat_sessions_agent (agent_id), + INDEX idx_chat_sessions_updated (updated_at DESC) +); + +-- Chat Messages Table +CREATE TABLE IF NOT EXISTS chat_messages ( + id VARCHAR(36) PRIMARY KEY, + session_id VARCHAR(36) NOT NULL, + role VARCHAR(20), + content TEXT, + tokens_used INT DEFAULT 0, + duration_ms INT DEFAULT 0, + metadata TEXT, + created_at DATETIME(3), + INDEX idx_chat_messages_session (session_id), + INDEX idx_chat_messages_created (created_at ASC) +); + +-- Chat Groups Table +CREATE TABLE IF NOT EXISTS chat_groups ( + id VARCHAR(36) PRIMARY KEY, + user_id VARCHAR(36) NOT NULL, + name VARCHAR(100) NOT NULL, + description TEXT, + agent_ids TEXT, + status VARCHAR(20) DEFAULT 'active', + created_at DATETIME(3), + updated_at DATETIME(3), + INDEX idx_chat_groups_user (user_id) +); + -- Agent Memory Indexes CREATE INDEX IF NOT EXISTS idx_agent_memory_agent ON agent_memories(agent_id); CREATE INDEX IF NOT EXISTS idx_agent_memory_user ON agent_memories(agent_id, user_id);