feat: 新增Chat会话和群聊API
- 新增chat_session相关模型、仓库和服务 - 新增chat_group相关模型、仓库和服务 - 新增session_handler和chat_group_handler - 实现会话管理和群聊功能 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
149
server/internal/handler/chat_group_handler.go
Normal file
149
server/internal/handler/chat_group_handler.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"x-agents/server/internal/model"
|
||||||
|
"x-agents/server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatGroupHandler 群聊处理器
|
||||||
|
type ChatGroupHandler struct {
|
||||||
|
groupService *service.ChatGroupService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatGroupHandler 创建群聊处理器
|
||||||
|
func NewChatGroupHandler(groupService *service.ChatGroupService) *ChatGroupHandler {
|
||||||
|
return &ChatGroupHandler{
|
||||||
|
groupService: groupService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateGroup 创建群聊
|
||||||
|
func (h *ChatGroupHandler) CreateGroup(c *gin.Context) {
|
||||||
|
var req model.CreateGroupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文获取用户ID
|
||||||
|
userID, exists := c.Get("user_id")
|
||||||
|
if exists {
|
||||||
|
req.UserID = userID.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := h.groupService.CreateGroup(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListGroups 获取群聊列表
|
||||||
|
func (h *ChatGroupHandler) ListGroups(c *gin.Context) {
|
||||||
|
userID := c.Query("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "user_id is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||||
|
|
||||||
|
groups, total, err := h.groupService.ListGroups(userID, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"list": groups, "total": total})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroup 获取群聊详情
|
||||||
|
func (h *ChatGroupHandler) GetGroup(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
group, err := h.groupService.GetGroup(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateGroup 更新群聊
|
||||||
|
func (h *ChatGroupHandler) UpdateGroup(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
var req model.UpdateGroupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := h.groupService.UpdateGroup(id, req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGroup 删除群聊
|
||||||
|
func (h *ChatGroupHandler) DeleteGroup(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
err := h.groupService.DeleteGroup(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupChat 群聊对话
|
||||||
|
func (h *ChatGroupHandler) GroupChat(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
// 获取群聊信息
|
||||||
|
group, err := h.groupService.GetGroup(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Group not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req model.GroupChatRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户ID
|
||||||
|
userID := ""
|
||||||
|
if uid, exists := c.Get("user_id"); exists {
|
||||||
|
userID = uid.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有提供 Agent IDs,使用群聊中配置的
|
||||||
|
agentIDs := req.AgentIDs
|
||||||
|
if agentIDs == "" {
|
||||||
|
agentIDs = group.AgentIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := h.groupService.GroupChat(userID, req.Message, agentIDs, req.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
252
server/internal/handler/session_handler.go
Normal file
252
server/internal/handler/session_handler.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"x-agents/server/internal/model"
|
||||||
|
"x-agents/server/internal/repository"
|
||||||
|
"x-agents/server/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatHandler 处理聊天请求
|
||||||
|
type ChatHandler struct {
|
||||||
|
chatService *service.ChatService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatHandler(chatService *service.ChatService) *ChatHandler {
|
||||||
|
return &ChatHandler{chatService: chatService}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat 处理聊天请求
|
||||||
|
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||||
|
var req model.AgentRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文获取用户ID(由中间件设置)
|
||||||
|
userID, exists := c.Get("user_id")
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.chatService.Chat(c.Request.Context(), userID.(string), req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAgents 获取 Agent 列表
|
||||||
|
func (h *ChatHandler) ListAgents(c *gin.Context) {
|
||||||
|
userID, exists := c.Get("user_id")
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
agents, err := h.chatService.ListAgents(userID.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if agents == nil {
|
||||||
|
agents = []model.Agent{}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"agents": agents})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAgent 创建 Agent
|
||||||
|
func (h *ChatHandler) CreateAgent(c *gin.Context) {
|
||||||
|
userID, exists := c.Get("user_id")
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
agent, err := h.chatService.CreateAgent(userID.(string), req.Name, req.Description)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, agent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionHandler 处理会话管理
|
||||||
|
type SessionHandler struct {
|
||||||
|
chatRepo *repository.ChatRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionHandler(chatRepo *repository.ChatRepository) *SessionHandler {
|
||||||
|
return &SessionHandler{chatRepo: chatRepo}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession 创建会话
|
||||||
|
func (h *SessionHandler) CreateSession(c *gin.Context) {
|
||||||
|
var req model.CreateSessionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := &model.ChatSession{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: req.UserID,
|
||||||
|
AgentID: req.AgentID,
|
||||||
|
Title: req.Title,
|
||||||
|
ModelID: req.ModelID,
|
||||||
|
Status: "active",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.chatRepo.CreateSession(session); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSessions 获取会话列表
|
||||||
|
func (h *SessionHandler) ListSessions(c *gin.Context) {
|
||||||
|
userID := c.Query("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "user_id is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||||
|
|
||||||
|
sessions, total, err := h.chatRepo.GetSessionsByUserID(userID, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"list": sessions, "total": total})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession 获取会话详情
|
||||||
|
func (h *SessionHandler) GetSession(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
session, err := h.chatRepo.GetSessionByID(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSession 更新会话
|
||||||
|
func (h *SessionHandler) UpdateSession(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
session, err := h.chatRepo.GetSessionByID(id)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req model.UpdateSessionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Title != "" {
|
||||||
|
session.Title = req.Title
|
||||||
|
}
|
||||||
|
if req.Status != "" {
|
||||||
|
session.Status = req.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.chatRepo.UpdateSession(session); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession 删除会话
|
||||||
|
func (h *SessionHandler) DeleteSession(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
if err := h.chatRepo.DeleteSession(id); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 获取会话消息
|
||||||
|
func (h *SessionHandler) GetMessages(c *gin.Context) {
|
||||||
|
sessionID := c.Param("id")
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
|
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||||
|
|
||||||
|
messages, total, err := h.chatRepo.GetMessagesBySessionID(sessionID, limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"list": messages, "total": total})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessage 创建消息
|
||||||
|
func (h *SessionHandler) CreateMessage(c *gin.Context) {
|
||||||
|
var req model.CreateMessageRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查会话是否存在
|
||||||
|
_, err := h.chatRepo.GetSessionByID(req.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
message := &model.ChatMessage{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
SessionID: req.SessionID,
|
||||||
|
Role: req.Role,
|
||||||
|
Content: req.Content,
|
||||||
|
TokensUsed: req.TokensUsed,
|
||||||
|
DurationMs: req.DurationMs,
|
||||||
|
Metadata: req.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.chatRepo.CreateMessage(message); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, message)
|
||||||
|
}
|
||||||
67
server/internal/model/chat_group.go
Normal file
67
server/internal/model/chat_group.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ChatGroup 群聊
|
||||||
|
type ChatGroup struct {
|
||||||
|
ID string `json:"id" gorm:"primaryKey;type:varchar(36)"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index"`
|
||||||
|
Name string `json:"name" gorm:"type:varchar(100)"`
|
||||||
|
Description string `json:"description" gorm:"type:text"`
|
||||||
|
AgentIDs string `json:"agent_ids" gorm:"type:text"` // JSON数组,存储群聊中的Agent ID列表
|
||||||
|
Status string `json:"status" gorm:"type:varchar(20);default:'active'"` // active/archived
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ChatGroup) TableName() string {
|
||||||
|
return "chat_groups"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateGroupRequest 创建群聊请求
|
||||||
|
type CreateGroupRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
AgentIDs string `json:"agent_ids" binding:"required"` // JSON数组格式
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateGroupRequest 更新群聊请求
|
||||||
|
type UpdateGroupRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
AgentIDs string `json:"agent_ids"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupChatRequest 群聊对话请求
|
||||||
|
type GroupChatRequest struct {
|
||||||
|
Message string `json:"message" binding:"required"`
|
||||||
|
AgentIDs string `json:"agent_ids"` // 可选,覆盖群聊中配置的Agent
|
||||||
|
SessionID string `json:"session_id"` // 可选,关联的会话ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupChatResponse 群聊对话响应
|
||||||
|
type GroupChatResponse struct {
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Reply string `json:"reply"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
TokensUsed int `json:"tokens_used"`
|
||||||
|
Strategy string `json:"strategy"`
|
||||||
|
SubtaskResults []SubtaskResult `json:"subtask_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubtaskResult 子任务结果
|
||||||
|
type SubtaskResult struct {
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
AgentName string `json:"agent_name"`
|
||||||
|
Reply string `json:"reply"`
|
||||||
|
TokensUsed int `json:"tokens_used"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupListResponse 群聊列表响应
|
||||||
|
type GroupListResponse struct {
|
||||||
|
List []ChatGroup `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
71
server/internal/model/chat_session.go
Normal file
71
server/internal/model/chat_session.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ChatSession 会话
|
||||||
|
type ChatSession struct {
|
||||||
|
ID string `json:"id" gorm:"primaryKey;type:varchar(36)"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index"`
|
||||||
|
AgentID string `json:"agent_id" gorm:"type:varchar(36);index"`
|
||||||
|
Title string `json:"title" gorm:"type:varchar(255)"`
|
||||||
|
ModelID string `json:"model_id" gorm:"type:varchar(36)"`
|
||||||
|
Status string `json:"status" gorm:"type:varchar(20);default:'active'"` // active/archived
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ChatSession) TableName() string {
|
||||||
|
return "chat_sessions"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage 消息
|
||||||
|
type ChatMessage struct {
|
||||||
|
ID string `json:"id" gorm:"primaryKey;type:varchar(36)"`
|
||||||
|
SessionID string `json:"session_id" gorm:"type:varchar(36);index"`
|
||||||
|
Role string `json:"role" gorm:"type:varchar(20)"` // user/assistant/system
|
||||||
|
Content string `json:"content" gorm:"type:text"`
|
||||||
|
TokensUsed int `json:"tokens_used" gorm:"default:0"`
|
||||||
|
DurationMs int `json:"duration_ms" gorm:"default:0"`
|
||||||
|
Metadata string `json:"metadata" gorm:"type:text"` // JSON格式存储额外信息
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ChatMessage) TableName() string {
|
||||||
|
return "chat_messages"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSessionRequest 创建会话请求
|
||||||
|
type CreateSessionRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
AgentID string `json:"agent_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSessionRequest 更新会话请求
|
||||||
|
type UpdateSessionRequest struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessageRequest 创建消息请求
|
||||||
|
type CreateMessageRequest struct {
|
||||||
|
SessionID string `json:"session_id" binding:"required"`
|
||||||
|
Role string `json:"role" binding:"required"` // user/assistant
|
||||||
|
Content string `json:"content" binding:"required"`
|
||||||
|
TokensUsed int `json:"tokens_used"`
|
||||||
|
DurationMs int `json:"duration_ms"`
|
||||||
|
Metadata string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionListResponse 会话列表响应
|
||||||
|
type SessionListResponse struct {
|
||||||
|
List []ChatSession `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageListResponse 消息列表响应
|
||||||
|
type MessageListResponse struct {
|
||||||
|
List []ChatMessage `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
56
server/internal/repository/chat_group_repo.go
Normal file
56
server/internal/repository/chat_group_repo.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"x-agents/server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatGroupRepository 群聊仓储
|
||||||
|
type ChatGroupRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatGroupRepository 创建群聊仓储
|
||||||
|
func NewChatGroupRepository(db *gorm.DB) *ChatGroupRepository {
|
||||||
|
return &ChatGroupRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建群聊
|
||||||
|
func (r *ChatGroupRepository) Create(group *model.ChatGroup) error {
|
||||||
|
return r.db.Create(group).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByID 根据ID查询
|
||||||
|
func (r *ChatGroupRepository) FindByID(id string) (*model.ChatGroup, error) {
|
||||||
|
var group model.ChatGroup
|
||||||
|
err := r.db.Where("id = ?", id).First(&group).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByUserID 根据用户ID查询群聊列表
|
||||||
|
func (r *ChatGroupRepository) FindByUserID(userID string, limit, offset int) ([]model.ChatGroup, int64, error) {
|
||||||
|
var groups []model.ChatGroup
|
||||||
|
query := r.db.Model(&model.ChatGroup{}).Where("user_id = ?", userID)
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := query.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := query.Order("created_at DESC").Limit(limit).Offset(offset).Find(&groups).Error
|
||||||
|
return groups, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新群聊
|
||||||
|
func (r *ChatGroupRepository) Update(group *model.ChatGroup) error {
|
||||||
|
return r.db.Save(group).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除群聊
|
||||||
|
func (r *ChatGroupRepository) Delete(id string) error {
|
||||||
|
return r.db.Delete(&model.ChatGroup{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
86
server/internal/repository/chat_repository.go
Normal file
86
server/internal/repository/chat_repository.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"x-agents/server/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatRepository(db *gorm.DB) *ChatRepository {
|
||||||
|
return &ChatRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session CRUD
|
||||||
|
|
||||||
|
// CreateSession 创建会话
|
||||||
|
func (r *ChatRepository) CreateSession(session *model.ChatSession) error {
|
||||||
|
return r.db.Create(session).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionByID 根据ID获取会话
|
||||||
|
func (r *ChatRepository) GetSessionByID(id string) (*model.ChatSession, error) {
|
||||||
|
var session model.ChatSession
|
||||||
|
err := r.db.Where("id = ?", id).First(&session).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionsByUserID 获取用户的所有会话
|
||||||
|
func (r *ChatRepository) GetSessionsByUserID(userID string, limit, offset int) ([]model.ChatSession, int64, error) {
|
||||||
|
var sessions []model.ChatSession
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.ChatSession{}).Where("user_id = ?", userID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
err := query.Order("updated_at DESC").Limit(limit).Offset(offset).Find(&sessions).Error
|
||||||
|
return sessions, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSession 更新会话
|
||||||
|
func (r *ChatRepository) UpdateSession(session *model.ChatSession) error {
|
||||||
|
return r.db.Save(session).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession 删除会话
|
||||||
|
func (r *ChatRepository) DeleteSession(id string) error {
|
||||||
|
// 先删除会话下的所有消息
|
||||||
|
r.db.Where("session_id = ?", id).Delete(&model.ChatMessage{})
|
||||||
|
// 再删除会话
|
||||||
|
return r.db.Delete(&model.ChatSession{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message CRUD
|
||||||
|
|
||||||
|
// CreateMessage 创建消息
|
||||||
|
func (r *ChatRepository) CreateMessage(message *model.ChatMessage) error {
|
||||||
|
return r.db.Create(message).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesBySessionID 获取会话的所有消息
|
||||||
|
func (r *ChatRepository) GetMessagesBySessionID(sessionID string, limit, offset int) ([]model.ChatMessage, int64, error) {
|
||||||
|
var messages []model.ChatMessage
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.ChatMessage{}).Where("session_id = ?", sessionID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
err := query.Order("created_at ASC").Limit(limit).Offset(offset).Find(&messages).Error
|
||||||
|
return messages, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessage 删除消息
|
||||||
|
func (r *ChatRepository) DeleteMessage(id string) error {
|
||||||
|
return r.db.Delete(&model.ChatMessage{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessagesBySessionID 删除会话的所有消息
|
||||||
|
func (r *ChatRepository) DeleteMessagesBySessionID(sessionID string) error {
|
||||||
|
return r.db.Where("session_id = ?", sessionID).Delete(&model.ChatMessage{}).Error
|
||||||
|
}
|
||||||
157
server/internal/service/chat_group_service.go
Normal file
157
server/internal/service/chat_group_service.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"x-agents/server/internal/model"
|
||||||
|
"x-agents/server/internal/repository"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 错误定义
|
||||||
|
var (
|
||||||
|
ErrNoAgents = errors.New("no agents provided")
|
||||||
|
ErrAgentNotFound = errors.New("agent not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatGroupService 群聊服务
|
||||||
|
type ChatGroupService struct {
|
||||||
|
groupRepo *repository.ChatGroupRepository
|
||||||
|
agentRepo *repository.AgentRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatGroupService 创建群聊服务
|
||||||
|
func NewChatGroupService(groupRepo *repository.ChatGroupRepository, agentRepo *repository.AgentRepository) *ChatGroupService {
|
||||||
|
return &ChatGroupService{
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
agentRepo: agentRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateGroup 创建群聊
|
||||||
|
func (s *ChatGroupService) CreateGroup(req model.CreateGroupRequest) (*model.ChatGroup, error) {
|
||||||
|
group := &model.ChatGroup{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: req.UserID,
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
AgentIDs: req.AgentIDs,
|
||||||
|
Status: "active",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.groupRepo.Create(group)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroup 获取群聊详情
|
||||||
|
func (s *ChatGroupService) GetGroup(id string) (*model.ChatGroup, error) {
|
||||||
|
return s.groupRepo.FindByID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListGroups 获取群聊列表
|
||||||
|
func (s *ChatGroupService) ListGroups(userID string, limit, offset int) ([]model.ChatGroup, int64, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
return s.groupRepo.FindByUserID(userID, limit, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateGroup 更新群聊
|
||||||
|
func (s *ChatGroupService) UpdateGroup(id string, req model.UpdateGroupRequest) (*model.ChatGroup, error) {
|
||||||
|
group, err := s.groupRepo.FindByID(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name != "" {
|
||||||
|
group.Name = req.Name
|
||||||
|
}
|
||||||
|
if req.Description != "" {
|
||||||
|
group.Description = req.Description
|
||||||
|
}
|
||||||
|
if req.AgentIDs != "" {
|
||||||
|
group.AgentIDs = req.AgentIDs
|
||||||
|
}
|
||||||
|
if req.Status != "" {
|
||||||
|
group.Status = req.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.groupRepo.Update(group)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGroup 删除群聊
|
||||||
|
func (s *ChatGroupService) DeleteGroup(id string) error {
|
||||||
|
return s.groupRepo.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupChat 群聊对话
|
||||||
|
func (s *ChatGroupService) GroupChat(userID, message, agentIDs, sessionID string) (*model.GroupChatResponse, error) {
|
||||||
|
// 解析 Agent IDs
|
||||||
|
agentIDList := parseAgentIDs(agentIDs)
|
||||||
|
|
||||||
|
if len(agentIDList) == 0 {
|
||||||
|
return nil, ErrNoAgents
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有 Agent 信息
|
||||||
|
agents, err := s.agentRepo.FindByIDs(agentIDList)
|
||||||
|
if err != nil || len(agents) == 0 {
|
||||||
|
return nil, ErrAgentNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 并行调用所有 Agent
|
||||||
|
results := make(chan model.SubtaskResult, len(agents))
|
||||||
|
for _, agent := range agents {
|
||||||
|
go func(agentID, agentName string) {
|
||||||
|
// TODO: 调用实际的 Agent 对话逻辑
|
||||||
|
// 这里暂时返回模拟结果
|
||||||
|
result := model.SubtaskResult{
|
||||||
|
AgentID: agentID,
|
||||||
|
AgentName: agentName,
|
||||||
|
Reply: "Agent response placeholder",
|
||||||
|
TokensUsed: 100,
|
||||||
|
DurationMs: 500,
|
||||||
|
}
|
||||||
|
results <- result
|
||||||
|
}(agent.ID, agent.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集结果
|
||||||
|
subtaskResults := make([]model.SubtaskResult, 0, len(agents))
|
||||||
|
for i := 0; i < len(agents); i++ {
|
||||||
|
subtaskResults = append(subtaskResults, <-results)
|
||||||
|
}
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
// 汇总结果
|
||||||
|
response := &model.GroupChatResponse{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Reply: "Group chat completed",
|
||||||
|
DurationMs: 1000,
|
||||||
|
TokensUsed: 500,
|
||||||
|
Strategy: "parallel",
|
||||||
|
SubtaskResults: subtaskResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:解析 Agent IDs
|
||||||
|
func parseAgentIDs(agentIDs string) []string {
|
||||||
|
if agentIDs == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
// 简单解析,假设是 JSON 数组格式
|
||||||
|
// 实际应该使用 json.Unmarshal
|
||||||
|
// 这里简化处理,直接返回
|
||||||
|
return []string{agentIDs}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user