From 31f0feafb5503200341c5fccc26a499b9a918595 Mon Sep 17 00:00:00 2001 From: "DESKTOP-72TV0V4\\caoxiaozhu" Date: Sun, 15 Mar 2026 19:49:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=92=8C=20Agent=20=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化 session_handler 会话处理逻辑 - 增强 agent_service Agent 服务功能 - 新增 chat_repository 仓储方法 - 更新 agent_handler 和 chat_group_handler - 更新数据模型 agent 和 chat_session Co-Authored-By: Claude Opus 4.6 --- server/cmd/api/main.go | 5 +- server/internal/handler/agent_handler.go | 15 ++-- server/internal/handler/chat_group_handler.go | 8 +- server/internal/handler/session_handler.go | 81 ++++++++++++++++++- server/internal/model/agent.go | 1 + server/internal/model/chat_session.go | 2 +- server/internal/repository/chat_repository.go | 15 ++++ server/internal/service/agent_service.go | 49 +++++++---- 8 files changed, 146 insertions(+), 30 deletions(-) diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index e645f92..d26a80e 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -377,7 +377,7 @@ func main() { toolService := service.NewToolService(toolRepo) mcpService := service.NewMCPService(mcpRepo) skillService := service.NewSkillService(skillRepo) - agentService := service.NewAgentService(cfg.PythonServiceURL, modelRepo, agentRepo) + agentService := service.NewAgentService(cfg.PythonServiceURL, modelRepo, agentRepo, chatRepo) memoryService := service.NewMemoryService(agentRepo) // 4.2 初始化默认工具 @@ -407,7 +407,7 @@ func main() { skillHandler := handler.NewSkillHandler(skillService) agentHandler := handler.NewAgentHandler(agentService) memoryHandler := handler.NewMemoryHandler(memoryService) - sessionHandler := handler.NewSessionHandler(chatRepo) + sessionHandler := handler.NewSessionHandler(chatRepo, agentService) // 初始化群聊服务 chatGroupRepo := repository.NewChatGroupRepository(db) @@ -608,6 +608,7 @@ func main() { chatGroup.DELETE("/sessions/:id", sessionHandler.DeleteSession) chatGroup.GET("/sessions/:id/messages", sessionHandler.GetMessages) chatGroup.POST("/messages", sessionHandler.CreateMessage) + chatGroup.POST("/sessions/generate-title", sessionHandler.GenerateSessionTitle) } // 群聊管理模块 diff --git a/server/internal/handler/agent_handler.go b/server/internal/handler/agent_handler.go index 35050db..16cd869 100644 --- a/server/internal/handler/agent_handler.go +++ b/server/internal/handler/agent_handler.go @@ -32,7 +32,7 @@ type ChatRequest struct { // ChatResponse 对话响应 type ChatResponse struct { - AgentID int `json:"agent_id"` + AgentID string `json:"agent_id"` // 支持 UUID 字符串 Reply string `json:"reply"` ToolsUsed []string `json:"tools_used"` SessionID string `json:"session_id"` @@ -73,11 +73,9 @@ func (h *AgentHandler) Chat(c *gin.Context) { userIDStr := "1" // TODO: 从 c.Get("user_id") 获取 userID, _ := strconv.Atoi(userIDStr) - // 将前端传来的字符串 agent_id 转换为 int - agentID, _ := strconv.Atoi(req.AgentID) - + // 直接使用字符串类型的 agent_id,支持 UUID pythonReq := service.AgentChatRequest{ - AgentID: agentID, + AgentID: req.AgentID, Message: req.Message, UserID: userID, SessionID: req.SessionID, @@ -130,8 +128,8 @@ func (h *AgentHandler) ChatStream(c *gin.Context) { userIDStr := "1" // TODO: 从 c.Get("user_id") 获取 userID, _ := strconv.Atoi(userIDStr) - // 将前端传来的字符串 agent_id 转换为 int - agentID, _ := strconv.Atoi(req.AgentID) + // 直接使用字符串类型的 agent_id,支持 UUID + agentID := req.AgentID // 构建 SSE 流 c.Header("Content-Type", "text/event-stream") @@ -317,6 +315,7 @@ func (h *AgentHandler) DeleteAgent(c *gin.Context) { type UpdateAgentRequest struct { Name string `json:"name"` Description string `json:"description"` + Avatar string `json:"avatar"` Skills []string `json:"skills"` RoleDescription string `json:"role_description"` ModelProvider string `json:"model_provider"` @@ -345,7 +344,7 @@ func (h *AgentHandler) UpdateAgent(c *gin.Context) { return } - err := h.agentService.UpdateAgent(agentID, req.Name, req.Description, req.Skills, req.RoleDescription, req.ModelProvider, req.ModelName) + err := h.agentService.UpdateAgent(agentID, req.Name, req.Description, req.Avatar, req.Skills, req.RoleDescription, req.ModelProvider, req.ModelName) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/server/internal/handler/chat_group_handler.go b/server/internal/handler/chat_group_handler.go index cf8d745..75aabc7 100644 --- a/server/internal/handler/chat_group_handler.go +++ b/server/internal/handler/chat_group_handler.go @@ -31,12 +31,18 @@ func (h *ChatGroupHandler) CreateGroup(c *gin.Context) { return } - // 从上下文获取用户ID + // 从上下文获取用户ID,如果存在则覆盖请求中的user_id userID, exists := c.Get("user_id") if exists { req.UserID = userID.(string) } + // 验证user_id + if req.UserID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "user_id is required"}) + return + } + group, err := h.groupService.CreateGroup(req) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/server/internal/handler/session_handler.go b/server/internal/handler/session_handler.go index a6fa9fe..3cf00f3 100644 --- a/server/internal/handler/session_handler.go +++ b/server/internal/handler/session_handler.go @@ -1,8 +1,10 @@ package handler import ( + "fmt" "net/http" "strconv" + "strings" "x-agents/server/internal/model" "x-agents/server/internal/repository" @@ -94,11 +96,12 @@ func (h *ChatHandler) CreateAgent(c *gin.Context) { // SessionHandler 处理会话管理 type SessionHandler struct { - chatRepo *repository.ChatRepository + chatRepo *repository.ChatRepository + agentService *service.AgentService } -func NewSessionHandler(chatRepo *repository.ChatRepository) *SessionHandler { - return &SessionHandler{chatRepo: chatRepo} +func NewSessionHandler(chatRepo *repository.ChatRepository, agentService *service.AgentService) *SessionHandler { + return &SessionHandler{chatRepo: chatRepo, agentService: agentService} } // CreateSession 创建会话 @@ -226,6 +229,16 @@ func (h *SessionHandler) CreateMessage(c *gin.Context) { return } + // Debug: 打印请求内容 + fmt.Printf("[CreateMessage] Request: session_id=%s, role=%s, content_len=%d\n", + req.SessionID, req.Role, len(req.Content)) + + // 验证 content 不为空 + if req.Content == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Content cannot be empty"}) + return + } + // 检查会话是否存在 _, err := h.chatRepo.GetSessionByID(req.SessionID) if err != nil { @@ -250,3 +263,65 @@ func (h *SessionHandler) CreateMessage(c *gin.Context) { c.JSON(http.StatusOK, message) } + +// GenerateSessionTitleRequest 生成会话标题请求 +type GenerateSessionTitleRequest struct { + SessionID string `json:"session_id" binding:"required"` +} + +// GenerateSessionTitle 生成会话标题 +func (h *SessionHandler) GenerateSessionTitle(c *gin.Context) { + var req GenerateSessionTitleRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 获取会话的所有消息 + messages, _, err := h.chatRepo.GetMessagesBySessionID(req.SessionID, 100, 0) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get messages"}) + return + } + + if len(messages) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Not enough messages to generate title"}) + return + } + + // 获取会话信息 + session, err := h.chatRepo.GetSessionByID(req.SessionID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"}) + return + } + + // 简单方式:从用户的第一条消息中提取前15个字符作为标题 + var title string + for _, msg := range messages { + if msg.Role == "user" { + // 清理内容,去除换行和多余空格 + content := strings.ReplaceAll(msg.Content, "\n", " ") + content = strings.TrimSpace(content) + // 限制长度 + if len(content) > 15 { + title = content[:15] + "..." + } else { + title = content + } + break + } + } + + if title == "" { + title = "新会话" + } + + fmt.Printf("[GenerateSessionTitle] Generated title: %s\n", title) + + // 更新会话标题 + session.Title = title + h.chatRepo.UpdateSession(session) + + c.JSON(http.StatusOK, gin.H{"title": title}) +} diff --git a/server/internal/model/agent.go b/server/internal/model/agent.go index 6be3d10..45ad6d2 100644 --- a/server/internal/model/agent.go +++ b/server/internal/model/agent.go @@ -19,6 +19,7 @@ type Agent struct { Name string `json:"name" gorm:"size:100;not null"` Description string `json:"description" gorm:"type:text"` OwnerID string `json:"owner_id" gorm:"size:50;not null;index"` + Avatar string `json:"avatar" gorm:"size:50"` // 头像 (emoji) // 技能列表(JSON数组) Skills []string `json:"skills" gorm:"type:text;serializer:json"` diff --git a/server/internal/model/chat_session.go b/server/internal/model/chat_session.go index 35e2e3a..ca5f11f 100644 --- a/server/internal/model/chat_session.go +++ b/server/internal/model/chat_session.go @@ -52,7 +52,7 @@ type UpdateSessionRequest struct { type CreateMessageRequest struct { SessionID string `json:"session_id" binding:"required"` Role string `json:"role" binding:"required"` // user/assistant - Content string `json:"content" binding:"required"` + Content string `json:"content"` TokensUsed int `json:"tokens_used"` DurationMs int `json:"duration_ms"` Metadata string `json:"metadata"` diff --git a/server/internal/repository/chat_repository.go b/server/internal/repository/chat_repository.go index 74b7366..07abb36 100644 --- a/server/internal/repository/chat_repository.go +++ b/server/internal/repository/chat_repository.go @@ -56,6 +56,21 @@ func (r *ChatRepository) DeleteSession(id string) error { return r.db.Delete(&model.ChatSession{}, "id = ?", id).Error } +// DeleteSessionsByAgentID 删除智能体的所有会话 +func (r *ChatRepository) DeleteSessionsByAgentID(agentID string) error { + // 先查询该智能体的所有会话 + var sessions []model.ChatSession + if err := r.db.Where("agent_id = ?", agentID).Find(&sessions).Error; err != nil { + return err + } + // 删除每个会话下的所有消息 + for _, session := range sessions { + r.db.Where("session_id = ?", session.ID).Delete(&model.ChatMessage{}) + } + // 再删除所有会话 + return r.db.Where("agent_id = ?", agentID).Delete(&model.ChatSession{}).Error +} + // Message CRUD // CreateMessage 创建消息 diff --git a/server/internal/service/agent_service.go b/server/internal/service/agent_service.go index 7d307ae..72d9e83 100644 --- a/server/internal/service/agent_service.go +++ b/server/internal/service/agent_service.go @@ -18,7 +18,7 @@ import ( // AgentChatRequest Python Agent 对话请求 type AgentChatRequest struct { - AgentID int `json:"agent_id"` + AgentID string `json:"agent_id"` // 支持 UUID 字符串 Message string `json:"message"` UserID int `json:"user_id"` SessionID string `json:"session_id,omitempty"` @@ -32,7 +32,7 @@ type AgentChatRequest struct { // AgentChatResponse Python Agent 对话响应 type AgentChatResponse struct { - AgentID int `json:"agent_id"` + AgentID string `json:"agent_id"` // 支持 UUID 字符串 Response string `json:"response"` ToolCalls []interface{} `json:"tool_calls"` TokensUsed int `json:"tokens_used"` @@ -66,10 +66,11 @@ type AgentService struct { client *http.Client modelRepo *repository.ModelRepository agentRepo *repository.AgentRepository + chatRepo *repository.ChatRepository } // NewAgentService 创建 Agent 服务 -func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, agentRepo *repository.AgentRepository) *AgentService { +func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, agentRepo *repository.AgentRepository, chatRepo *repository.ChatRepository) *AgentService { return &AgentService{ pythonURL: pythonURL, client: &http.Client{ @@ -77,6 +78,7 @@ func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, ag }, modelRepo: modelRepo, agentRepo: agentRepo, + chatRepo: chatRepo, } } @@ -195,16 +197,19 @@ func (s *AgentService) TeamChat(req TeamChatRequest) (*TeamChatResponse, error) } // ChatStream 流式对话 -func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID, modelID string, userID int) error { +func (s *AgentService) ChatStream(c interface{}, agentID string, message, sessionID, modelID string, userID int) error { // 获取 gin.Context ginCtx, ok := c.(*gin.Context) if !ok { return fmt.Errorf("invalid context type") } + log.Printf("[ChatStream] Request: agentID=%s, message=%s, sessionID=%s, modelID=%s, userID=%d", + agentID, message, sessionID, modelID, userID) + // 初始化请求体 reqBody := map[string]interface{}{ - "agent_id": agentID, + "agent_id": agentID, // 传递字符串类型的 agent_id,支持 UUID "message": message, "user_id": userID, "session_id": sessionID, @@ -267,8 +272,10 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID for { n, err := resp.Body.Read(buf) if n > 0 { + log.Printf("[ChatStream] Received %d bytes from Python", n) _, writeErr := ginCtx.Writer.Write(buf[:n]) if writeErr != nil { + log.Printf("[ChatStream] Write error: %v", writeErr) break } // 强制刷新到客户端 @@ -277,6 +284,7 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID } } if err != nil { + log.Printf("[ChatStream] Done reading from Python, err: %v", err) break } } @@ -300,9 +308,10 @@ type CreateAgentRequest struct { // CreateAgentResponse 创建智能体响应 type CreateAgentResponse struct { - AgentID int `json:"agent_id"` - Name string `json:"name"` - Message string `json:"message"` + AgentID int `json:"agent_id"` // 保留兼容性 + AgentIDStr string `json:"agent_id_str"` // 返回实际的 UUID + Name string `json:"name"` + Message string `json:"message"` } // CreateAgent 创建智能体 @@ -329,6 +338,7 @@ func (s *AgentService) CreateAgent(req CreateAgentRequest, userID int) (*CreateA Name: req.Name, Description: req.Description, OwnerID: fmt.Sprintf("%d", userID), + Avatar: req.Avatar, Skills: skills, RoleDescription: req.Prompt, ModelProvider: req.ModelProvider, @@ -347,13 +357,11 @@ func (s *AgentService) CreateAgent(req CreateAgentRequest, userID int) (*CreateA log.Printf("[AgentService] Agent created in database: %s (ID: %s)", agent.Name, agent.ID) - // 解析 agent ID 为整数返回 - agentIDInt := int(time.Now().Unix()) % 100000 - + // 返回数据库中实际的 Agent ID (UUID字符串) return &CreateAgentResponse{ - AgentID: agentIDInt, - Name: agent.Name, - Message: "Agent created successfully", + AgentIDStr: agent.ID, + Name: agent.Name, + Message: "Agent created successfully", }, nil } @@ -429,6 +437,14 @@ func (s *AgentService) DeleteAgent(agentID string) error { return fmt.Errorf("agent not found: %w", err) } + // 先删除该智能体的所有会话和消息 + if s.chatRepo != nil { + if err := s.chatRepo.DeleteSessionsByAgentID(agentID); err != nil { + log.Printf("[AgentService] DeleteAgent: failed to delete sessions: %v", err) + // 继续尝试删除 agent,不因为 session 删除失败而中止 + } + } + if err := s.agentRepo.Delete(agentID); err != nil { return fmt.Errorf("failed to delete agent: %w", err) } @@ -438,7 +454,7 @@ func (s *AgentService) DeleteAgent(agentID string) error { } // UpdateAgent 更新智能体 -func (s *AgentService) UpdateAgent(agentID, name, description string, skills []string, roleDescription, modelProvider, modelName string) error { +func (s *AgentService) UpdateAgent(agentID, name, description, avatar string, skills []string, roleDescription, modelProvider, modelName string) error { if s.agentRepo == nil { return fmt.Errorf("agent repository not initialized") } @@ -458,6 +474,9 @@ func (s *AgentService) UpdateAgent(agentID, name, description string, skills []s if description != "" { agent.Description = description } + if avatar != "" { + agent.Avatar = avatar + } if skills != nil { agent.Skills = skills }