feat: 增强会话管理和 Agent 服务

- 优化 session_handler 会话处理逻辑
- 增强 agent_service Agent 服务功能
- 新增 chat_repository 仓储方法
- 更新 agent_handler 和 chat_group_handler
- 更新数据模型 agent 和 chat_session

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-15 19:49:27 +08:00
parent bce8b9240b
commit 31f0feafb5
8 changed files with 146 additions and 30 deletions

View File

@@ -377,7 +377,7 @@ func main() {
toolService := service.NewToolService(toolRepo)
mcpService := service.NewMCPService(mcpRepo)
skillService := service.NewSkillService(skillRepo)
agentService := service.NewAgentService(cfg.PythonServiceURL, modelRepo, agentRepo)
agentService := service.NewAgentService(cfg.PythonServiceURL, modelRepo, agentRepo, chatRepo)
memoryService := service.NewMemoryService(agentRepo)
// 4.2 初始化默认工具
@@ -407,7 +407,7 @@ func main() {
skillHandler := handler.NewSkillHandler(skillService)
agentHandler := handler.NewAgentHandler(agentService)
memoryHandler := handler.NewMemoryHandler(memoryService)
sessionHandler := handler.NewSessionHandler(chatRepo)
sessionHandler := handler.NewSessionHandler(chatRepo, agentService)
// 初始化群聊服务
chatGroupRepo := repository.NewChatGroupRepository(db)
@@ -608,6 +608,7 @@ func main() {
chatGroup.DELETE("/sessions/:id", sessionHandler.DeleteSession)
chatGroup.GET("/sessions/:id/messages", sessionHandler.GetMessages)
chatGroup.POST("/messages", sessionHandler.CreateMessage)
chatGroup.POST("/sessions/generate-title", sessionHandler.GenerateSessionTitle)
}
// 群聊管理模块

View File

@@ -32,7 +32,7 @@ type ChatRequest struct {
// ChatResponse 对话响应
type ChatResponse struct {
AgentID int `json:"agent_id"`
AgentID string `json:"agent_id"` // 支持 UUID 字符串
Reply string `json:"reply"`
ToolsUsed []string `json:"tools_used"`
SessionID string `json:"session_id"`
@@ -73,11 +73,9 @@ func (h *AgentHandler) Chat(c *gin.Context) {
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
userID, _ := strconv.Atoi(userIDStr)
// 将前端传来的字符串 agent_id 转换为 int
agentID, _ := strconv.Atoi(req.AgentID)
// 直接使用字符串类型的 agent_id,支持 UUID
pythonReq := service.AgentChatRequest{
AgentID: agentID,
AgentID: req.AgentID,
Message: req.Message,
UserID: userID,
SessionID: req.SessionID,
@@ -130,8 +128,8 @@ func (h *AgentHandler) ChatStream(c *gin.Context) {
userIDStr := "1" // TODO: 从 c.Get("user_id") 获取
userID, _ := strconv.Atoi(userIDStr)
// 将前端传来的字符串 agent_id 转换为 int
agentID, _ := strconv.Atoi(req.AgentID)
// 直接使用字符串类型的 agent_id,支持 UUID
agentID := req.AgentID
// 构建 SSE 流
c.Header("Content-Type", "text/event-stream")
@@ -317,6 +315,7 @@ func (h *AgentHandler) DeleteAgent(c *gin.Context) {
type UpdateAgentRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Avatar string `json:"avatar"`
Skills []string `json:"skills"`
RoleDescription string `json:"role_description"`
ModelProvider string `json:"model_provider"`
@@ -345,7 +344,7 @@ func (h *AgentHandler) UpdateAgent(c *gin.Context) {
return
}
err := h.agentService.UpdateAgent(agentID, req.Name, req.Description, req.Skills, req.RoleDescription, req.ModelProvider, req.ModelName)
err := h.agentService.UpdateAgent(agentID, req.Name, req.Description, req.Avatar, req.Skills, req.RoleDescription, req.ModelProvider, req.ModelName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

View File

@@ -31,12 +31,18 @@ func (h *ChatGroupHandler) CreateGroup(c *gin.Context) {
return
}
// 从上下文获取用户ID
// 从上下文获取用户ID如果存在则覆盖请求中的user_id
userID, exists := c.Get("user_id")
if exists {
req.UserID = userID.(string)
}
// 验证user_id
if req.UserID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "user_id is required"})
return
}
group, err := h.groupService.CreateGroup(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View File

@@ -1,8 +1,10 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"x-agents/server/internal/model"
"x-agents/server/internal/repository"
@@ -95,10 +97,11 @@ func (h *ChatHandler) CreateAgent(c *gin.Context) {
// SessionHandler 处理会话管理
type SessionHandler struct {
chatRepo *repository.ChatRepository
agentService *service.AgentService
}
func NewSessionHandler(chatRepo *repository.ChatRepository) *SessionHandler {
return &SessionHandler{chatRepo: chatRepo}
func NewSessionHandler(chatRepo *repository.ChatRepository, agentService *service.AgentService) *SessionHandler {
return &SessionHandler{chatRepo: chatRepo, agentService: agentService}
}
// CreateSession 创建会话
@@ -226,6 +229,16 @@ func (h *SessionHandler) CreateMessage(c *gin.Context) {
return
}
// Debug: 打印请求内容
fmt.Printf("[CreateMessage] Request: session_id=%s, role=%s, content_len=%d\n",
req.SessionID, req.Role, len(req.Content))
// 验证 content 不为空
if req.Content == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Content cannot be empty"})
return
}
// 检查会话是否存在
_, err := h.chatRepo.GetSessionByID(req.SessionID)
if err != nil {
@@ -250,3 +263,65 @@ func (h *SessionHandler) CreateMessage(c *gin.Context) {
c.JSON(http.StatusOK, message)
}
// GenerateSessionTitleRequest 生成会话标题请求
type GenerateSessionTitleRequest struct {
SessionID string `json:"session_id" binding:"required"`
}
// GenerateSessionTitle 生成会话标题
func (h *SessionHandler) GenerateSessionTitle(c *gin.Context) {
var req GenerateSessionTitleRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取会话的所有消息
messages, _, err := h.chatRepo.GetMessagesBySessionID(req.SessionID, 100, 0)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get messages"})
return
}
if len(messages) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Not enough messages to generate title"})
return
}
// 获取会话信息
session, err := h.chatRepo.GetSessionByID(req.SessionID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
return
}
// 简单方式从用户的第一条消息中提取前15个字符作为标题
var title string
for _, msg := range messages {
if msg.Role == "user" {
// 清理内容,去除换行和多余空格
content := strings.ReplaceAll(msg.Content, "\n", " ")
content = strings.TrimSpace(content)
// 限制长度
if len(content) > 15 {
title = content[:15] + "..."
} else {
title = content
}
break
}
}
if title == "" {
title = "新会话"
}
fmt.Printf("[GenerateSessionTitle] Generated title: %s\n", title)
// 更新会话标题
session.Title = title
h.chatRepo.UpdateSession(session)
c.JSON(http.StatusOK, gin.H{"title": title})
}

View File

@@ -19,6 +19,7 @@ type Agent struct {
Name string `json:"name" gorm:"size:100;not null"`
Description string `json:"description" gorm:"type:text"`
OwnerID string `json:"owner_id" gorm:"size:50;not null;index"`
Avatar string `json:"avatar" gorm:"size:50"` // 头像 (emoji)
// 技能列表(JSON数组)
Skills []string `json:"skills" gorm:"type:text;serializer:json"`

View File

@@ -52,7 +52,7 @@ type UpdateSessionRequest struct {
type CreateMessageRequest struct {
SessionID string `json:"session_id" binding:"required"`
Role string `json:"role" binding:"required"` // user/assistant
Content string `json:"content" binding:"required"`
Content string `json:"content"`
TokensUsed int `json:"tokens_used"`
DurationMs int `json:"duration_ms"`
Metadata string `json:"metadata"`

View File

@@ -56,6 +56,21 @@ func (r *ChatRepository) DeleteSession(id string) error {
return r.db.Delete(&model.ChatSession{}, "id = ?", id).Error
}
// DeleteSessionsByAgentID 删除智能体的所有会话
func (r *ChatRepository) DeleteSessionsByAgentID(agentID string) error {
// 先查询该智能体的所有会话
var sessions []model.ChatSession
if err := r.db.Where("agent_id = ?", agentID).Find(&sessions).Error; err != nil {
return err
}
// 删除每个会话下的所有消息
for _, session := range sessions {
r.db.Where("session_id = ?", session.ID).Delete(&model.ChatMessage{})
}
// 再删除所有会话
return r.db.Where("agent_id = ?", agentID).Delete(&model.ChatSession{}).Error
}
// Message CRUD
// CreateMessage 创建消息

View File

@@ -18,7 +18,7 @@ import (
// AgentChatRequest Python Agent 对话请求
type AgentChatRequest struct {
AgentID int `json:"agent_id"`
AgentID string `json:"agent_id"` // 支持 UUID 字符串
Message string `json:"message"`
UserID int `json:"user_id"`
SessionID string `json:"session_id,omitempty"`
@@ -32,7 +32,7 @@ type AgentChatRequest struct {
// AgentChatResponse Python Agent 对话响应
type AgentChatResponse struct {
AgentID int `json:"agent_id"`
AgentID string `json:"agent_id"` // 支持 UUID 字符串
Response string `json:"response"`
ToolCalls []interface{} `json:"tool_calls"`
TokensUsed int `json:"tokens_used"`
@@ -66,10 +66,11 @@ type AgentService struct {
client *http.Client
modelRepo *repository.ModelRepository
agentRepo *repository.AgentRepository
chatRepo *repository.ChatRepository
}
// NewAgentService 创建 Agent 服务
func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, agentRepo *repository.AgentRepository) *AgentService {
func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, agentRepo *repository.AgentRepository, chatRepo *repository.ChatRepository) *AgentService {
return &AgentService{
pythonURL: pythonURL,
client: &http.Client{
@@ -77,6 +78,7 @@ func NewAgentService(pythonURL string, modelRepo *repository.ModelRepository, ag
},
modelRepo: modelRepo,
agentRepo: agentRepo,
chatRepo: chatRepo,
}
}
@@ -195,16 +197,19 @@ func (s *AgentService) TeamChat(req TeamChatRequest) (*TeamChatResponse, error)
}
// ChatStream 流式对话
func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID, modelID string, userID int) error {
func (s *AgentService) ChatStream(c interface{}, agentID string, message, sessionID, modelID string, userID int) error {
// 获取 gin.Context
ginCtx, ok := c.(*gin.Context)
if !ok {
return fmt.Errorf("invalid context type")
}
log.Printf("[ChatStream] Request: agentID=%s, message=%s, sessionID=%s, modelID=%s, userID=%d",
agentID, message, sessionID, modelID, userID)
// 初始化请求体
reqBody := map[string]interface{}{
"agent_id": agentID,
"agent_id": agentID, // 传递字符串类型的 agent_id支持 UUID
"message": message,
"user_id": userID,
"session_id": sessionID,
@@ -267,8 +272,10 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID
for {
n, err := resp.Body.Read(buf)
if n > 0 {
log.Printf("[ChatStream] Received %d bytes from Python", n)
_, writeErr := ginCtx.Writer.Write(buf[:n])
if writeErr != nil {
log.Printf("[ChatStream] Write error: %v", writeErr)
break
}
// 强制刷新到客户端
@@ -277,6 +284,7 @@ func (s *AgentService) ChatStream(c interface{}, agentID int, message, sessionID
}
}
if err != nil {
log.Printf("[ChatStream] Done reading from Python, err: %v", err)
break
}
}
@@ -300,7 +308,8 @@ type CreateAgentRequest struct {
// CreateAgentResponse 创建智能体响应
type CreateAgentResponse struct {
AgentID int `json:"agent_id"`
AgentID int `json:"agent_id"` // 保留兼容性
AgentIDStr string `json:"agent_id_str"` // 返回实际的 UUID
Name string `json:"name"`
Message string `json:"message"`
}
@@ -329,6 +338,7 @@ func (s *AgentService) CreateAgent(req CreateAgentRequest, userID int) (*CreateA
Name: req.Name,
Description: req.Description,
OwnerID: fmt.Sprintf("%d", userID),
Avatar: req.Avatar,
Skills: skills,
RoleDescription: req.Prompt,
ModelProvider: req.ModelProvider,
@@ -347,11 +357,9 @@ func (s *AgentService) CreateAgent(req CreateAgentRequest, userID int) (*CreateA
log.Printf("[AgentService] Agent created in database: %s (ID: %s)", agent.Name, agent.ID)
// 解析 agent ID 为整数返回
agentIDInt := int(time.Now().Unix()) % 100000
// 返回数据库中实际的 Agent ID (UUID字符串)
return &CreateAgentResponse{
AgentID: agentIDInt,
AgentIDStr: agent.ID,
Name: agent.Name,
Message: "Agent created successfully",
}, nil
@@ -429,6 +437,14 @@ func (s *AgentService) DeleteAgent(agentID string) error {
return fmt.Errorf("agent not found: %w", err)
}
// 先删除该智能体的所有会话和消息
if s.chatRepo != nil {
if err := s.chatRepo.DeleteSessionsByAgentID(agentID); err != nil {
log.Printf("[AgentService] DeleteAgent: failed to delete sessions: %v", err)
// 继续尝试删除 agent不因为 session 删除失败而中止
}
}
if err := s.agentRepo.Delete(agentID); err != nil {
return fmt.Errorf("failed to delete agent: %w", err)
}
@@ -438,7 +454,7 @@ func (s *AgentService) DeleteAgent(agentID string) error {
}
// UpdateAgent 更新智能体
func (s *AgentService) UpdateAgent(agentID, name, description string, skills []string, roleDescription, modelProvider, modelName string) error {
func (s *AgentService) UpdateAgent(agentID, name, description, avatar string, skills []string, roleDescription, modelProvider, modelName string) error {
if s.agentRepo == nil {
return fmt.Errorf("agent repository not initialized")
}
@@ -458,6 +474,9 @@ func (s *AgentService) UpdateAgent(agentID, name, description string, skills []s
if description != "" {
agent.Description = description
}
if avatar != "" {
agent.Avatar = avatar
}
if skills != nil {
agent.Skills = skills
}