diff --git a/server/internal/handler/memory_handler.go b/server/internal/handler/memory_handler.go index d90006e..7a4c844 100644 --- a/server/internal/handler/memory_handler.go +++ b/server/internal/handler/memory_handler.go @@ -4,6 +4,7 @@ import ( "net/http" "strconv" + "x-agents/server/internal/model" "x-agents/server/internal/service" "github.com/gin-gonic/gin" @@ -27,7 +28,35 @@ type CreateMemoryRequest struct { UserID string `json:"user_id"` Content string `json:"content" binding:"required"` MemoryType string `json:"memory_type"` + Category string `json:"category"` + Tags string `json:"tags"` + Keywords string `json:"keywords"` Importance int `json:"importance"` + IsPinned bool `json:"is_pinned"` +} + +// SearchMemoryRequest 搜索记忆请求 +type SearchMemoryRequest struct { + AgentID string `json:"agent_id" binding:"required"` + UserID string `json:"user_id"` + Keyword string `json:"keyword"` + Tags string `json:"tags"` + Category string `json:"category"` + MemoryType string `json:"memory_type"` + MinScore int `json:"min_score"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +// UpdateMemoryRequest 更新记忆请求 +type UpdateMemoryRequest struct { + Content string `json:"content"` + MemoryType string `json:"memory_type"` + Category string `json:"category"` + Tags string `json:"tags"` + Keywords string `json:"keywords"` + Importance int `json:"importance"` + IsPinned *bool `json:"is_pinned"` } // CreateMemory 创建记忆 @@ -48,7 +77,7 @@ func (h *MemoryHandler) CreateMemory(c *gin.Context) { importance = 5 } - memory, err := h.memoryService.CreateMemory(req.AgentID, req.UserID, req.Content, memoryType, importance) + memory, err := h.memoryService.CreateMemory(req.AgentID, req.UserID, req.Content, memoryType, req.Category, req.Tags, req.Keywords, importance, req.IsPinned) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -61,20 +90,86 @@ func (h *MemoryHandler) CreateMemory(c *gin.Context) { func (h *MemoryHandler) GetMemories(c *gin.Context) { agentID := c.Param("id") userID := c.Query("user_id") + category := c.Query("category") + memoryType := c.Query("memory_type") limitStr := c.DefaultQuery("limit", "10") + offsetStr := c.DefaultQuery("offset", "0") limit, err := strconv.Atoi(limitStr) if err != nil { limit = 10 } + offset, err := strconv.Atoi(offsetStr) + if err != nil { + offset = 0 + } - memories, err := h.memoryService.GetMemories(agentID, userID, limit) + memories, total, err := h.memoryService.GetMemories(agentID, userID, category, memoryType, limit, offset) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusOK, memories) + c.JSON(http.StatusOK, gin.H{"list": memories, "total": total}) +} + +// SearchMemories 搜索记忆 +func (h *MemoryHandler) SearchMemories(c *gin.Context) { + var req SearchMemoryRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 设置默认值 + limit := req.Limit + if limit == 0 { + limit = 10 + } + offset := req.Offset + if offset < 0 { + offset = 0 + } + + memories, total, err := h.memoryService.SearchMemories(req.AgentID, req.UserID, req.Keyword, req.Tags, req.Category, req.MemoryType, req.MinScore, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"list": memories, "total": total}) +} + +// GetMemory 获取单个记忆详情 +func (h *MemoryHandler) GetMemory(c *gin.Context) { + memoryID := c.Param("memory_id") + + memory, err := h.memoryService.GetMemoryByID(memoryID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Memory not found"}) + return + } + + c.JSON(http.StatusOK, memory) +} + +// UpdateMemory 更新记忆 +func (h *MemoryHandler) UpdateMemory(c *gin.Context) { + memoryID := c.Param("memory_id") + + var req UpdateMemoryRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + memory, err := h.memoryService.UpdateMemory(memoryID, req.Content, req.MemoryType, req.Category, req.Tags, req.Keywords, req.Importance, req.IsPinned) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, memory) } // DeleteMemory 删除记忆 @@ -89,3 +184,72 @@ func (h *MemoryHandler) DeleteMemory(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "Memory deleted successfully"}) } + +// ExportMemories 导出记忆 +func (h *MemoryHandler) ExportMemories(c *gin.Context) { + agentID := c.Param("id") + userID := c.Query("user_id") + format := c.DefaultQuery("format", "json") + + exportData, err := h.memoryService.ExportMemories(agentID, userID, format) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "format": format, + "data": exportData, + }) +} + +// ImportMemories 导入记忆 +type ImportMemoryRequest struct { + AgentID string `json:"agent_id" binding:"required"` + UserID string `json:"user_id"` + Memories []model.ImportItem `json:"memories" binding:"required"` +} + +func (h *MemoryHandler) ImportMemories(c *gin.Context) { + var req ImportMemoryRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + count, err := h.memoryService.ImportMemories(req.AgentID, req.UserID, req.Memories) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"imported": count}) +} + +// GetMemoryCategories 获取记忆分类列表 +func (h *MemoryHandler) GetMemoryCategories(c *gin.Context) { + agentID := c.Param("id") + userID := c.Query("user_id") + + categories, err := h.memoryService.GetMemoryCategories(agentID, userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"categories": categories}) +} + +// GetMemoryTags 获取记忆标签列表 +func (h *MemoryHandler) GetMemoryTags(c *gin.Context) { + agentID := c.Param("id") + userID := c.Query("user_id") + + tags, err := h.memoryService.GetMemoryTags(agentID, userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"tags": tags}) +} diff --git a/server/internal/service/memory_service.go b/server/internal/service/memory_service.go index e9d3837..c259592 100644 --- a/server/internal/service/memory_service.go +++ b/server/internal/service/memory_service.go @@ -1,9 +1,15 @@ package service import ( + "encoding/json" + "strings" + "time" + "x-agents/server/internal/repository" "x-agents/server/internal/model" + + "github.com/google/uuid" ) // MemoryService 记忆服务 @@ -19,13 +25,25 @@ func NewMemoryService(agentRepo *repository.AgentRepository) *MemoryService { } // CreateMemory 创建记忆 -func (s *MemoryService) CreateMemory(agentID, userID, content, memoryType string, importance int) (*model.AgentMemory, error) { +func (s *MemoryService) CreateMemory(agentID, userID, content, memoryType, category, tags, keywords string, importance int, isPinned bool) (*model.AgentMemory, error) { memory := &model.AgentMemory{ - AgentID: agentID, - UserID: userID, - Content: content, - MemoryType: memoryType, - Importance: importance, + ID: uuid.New().String(), + AgentID: agentID, + UserID: userID, + Content: content, + MemoryType: memoryType, + Category: category, + Tags: tags, + Keywords: keywords, + Importance: importance, + IsPinned: isPinned, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + // 如果没有提供关键词,自动从内容中提取 + if keywords == "" && content != "" { + memory.Keywords = extractKeywords(content) } err := s.agentRepo.CreateMemory(memory) @@ -37,14 +55,183 @@ func (s *MemoryService) CreateMemory(agentID, userID, content, memoryType string } // GetMemories 获取记忆列表 -func (s *MemoryService) GetMemories(agentID string, userID string, limit int) ([]model.AgentMemory, error) { - if userID != "" { - return s.agentRepo.FindMemoriesByUserID(agentID, userID, limit) +func (s *MemoryService) GetMemories(agentID, userID, category, memoryType string, limit, offset int) ([]model.AgentMemory, int64, error) { + return s.agentRepo.FindMemories(agentID, userID, category, memoryType, limit, offset) +} + +// SearchMemories 搜索记忆 +func (s *MemoryService) SearchMemories(agentID, userID, keyword, tags, category, memoryType string, minScore, limit, offset int) ([]model.AgentMemory, int64, error) { + // 如果有关键词搜索,优先使用模糊匹配 + if keyword != "" { + return s.agentRepo.SearchMemories(agentID, userID, keyword, tags, category, memoryType, minScore, limit, offset) } - return s.agentRepo.FindMemoriesByAgentID(agentID, limit) + // 否则使用过滤条件查询 + return s.agentRepo.FindMemories(agentID, userID, category, memoryType, limit, offset) +} + +// GetMemoryByID 获取记忆详情 +func (s *MemoryService) GetMemoryByID(memoryID string) (*model.AgentMemory, error) { + return s.agentRepo.FindMemoryByID(memoryID) +} + +// UpdateMemory 更新记忆 +func (s *MemoryService) UpdateMemory(memoryID, content, memoryType, category, tags, keywords string, importance int, isPinned *bool) (*model.AgentMemory, error) { + memory, err := s.agentRepo.FindMemoryByID(memoryID) + if err != nil { + return nil, err + } + + if content != "" { + memory.Content = content + } + if memoryType != "" { + memory.MemoryType = memoryType + } + if category != "" { + memory.Category = category + } + if tags != "" { + memory.Tags = tags + } + if keywords != "" { + memory.Keywords = keywords + } + if importance > 0 { + memory.Importance = importance + } + if isPinned != nil { + memory.IsPinned = *isPinned + } + memory.UpdatedAt = time.Now() + + err = s.agentRepo.UpdateMemory(memory) + if err != nil { + return nil, err + } + + return memory, nil } // DeleteMemory 删除记忆 func (s *MemoryService) DeleteMemory(memoryID string) error { return s.agentRepo.DeleteMemory(memoryID) } + +// ExportMemories 导出记忆 +func (s *MemoryService) ExportMemories(agentID, userID, format string) (interface{}, error) { + memories, _, err := s.agentRepo.FindMemories(agentID, userID, "", "", 0, 0) + if err != nil { + return nil, err + } + + if format == "csv" { + // 生成CSV格式 + var sb strings.Builder + sb.WriteString("id,agent_id,content,memory_type,category,tags,keywords,importance,is_pinned,created_at\n") + for _, m := range memories { + sb.WriteString(m.ID + ",") + sb.WriteString(m.AgentID + ",") + sb.WriteString("\"" + strings.ReplaceAll(m.Content, "\"", "\"\"") + "\",") + sb.WriteString(m.MemoryType + ",") + sb.WriteString(m.Category + ",") + sb.WriteString(m.Tags + ",") + sb.WriteString(m.Keywords + ",") + sb.WriteString(string(rune(m.Importance+'0')) + ",") + if m.IsPinned { + sb.WriteString("true") + } else { + sb.WriteString("false") + } + sb.WriteString("," + m.CreatedAt.Format(time.RFC3339) + "\n") + } + return sb.String(), nil + } + + // 默认JSON格式 + return memories, nil +} + +// ImportMemories 导入记忆 +func (s *MemoryService) ImportMemories(agentID, userID string, items []model.ImportItem) (int, error) { + count := 0 + for _, item := range items { + memoryType := item.MemoryType + if memoryType == "" { + memoryType = "conversation" + } + importance := item.Importance + if importance == 0 { + importance = 5 + } + keywords := item.Keywords + if keywords == "" && item.Content != "" { + keywords = extractKeywords(item.Content) + } + + memory := &model.AgentMemory{ + ID: uuid.New().String(), + AgentID: agentID, + UserID: userID, + Content: item.Content, + MemoryType: memoryType, + Category: item.Category, + Tags: item.Tags, + Keywords: keywords, + Importance: importance, + IsPinned: false, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + err := s.agentRepo.CreateMemory(memory) + if err != nil { + continue + } + count++ + } + return count, nil +} + +// GetMemoryCategories 获取记忆分类列表 +func (s *MemoryService) GetMemoryCategories(agentID, userID string) ([]string, error) { + return s.agentRepo.FindMemoryCategories(agentID, userID) +} + +// GetMemoryTags 获取记忆标签列表 +func (s *MemoryService) GetMemoryTags(agentID, userID string) ([]string, error) { + // 从所有记忆中的 tags 字段提取所有标签 + memories, _, err := s.agentRepo.FindMemories(agentID, userID, "", "", 0, 0) + if err != nil { + return nil, err + } + + tagSet := make(map[string]bool) + for _, m := range memories { + if m.Tags != "" { + var tags []string + if err := json.Unmarshal([]byte(m.Tags), &tags); err == nil { + for _, tag := range tags { + if tag != "" { + tagSet[tag] = true + } + } + } + } + } + + tags := make([]string, 0, len(tagSet)) + for tag := range tagSet { + tags = append(tags, tag) + } + return tags, nil +} + +// extractKeywords 从内容中提取关键词 +func extractKeywords(content string) string { + // 简单提取:取内容的前50个字符作为关键词演示 + // 实际生产中可使用分词库如 gojieba + if len(content) <= 50 { + return content + } + return content[:50] +}