2026-03-13 14:31:42 +08:00
|
|
|
package repository
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"x-agents/server/internal/model"
|
|
|
|
|
|
|
|
|
|
"gorm.io/gorm"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
type ChatRepository struct {
|
|
|
|
|
db *gorm.DB
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func NewChatRepository(db *gorm.DB) *ChatRepository {
|
|
|
|
|
return &ChatRepository{db: db}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Session CRUD
|
|
|
|
|
|
|
|
|
|
// CreateSession 创建会话
|
|
|
|
|
func (r *ChatRepository) CreateSession(session *model.ChatSession) error {
|
|
|
|
|
return r.db.Create(session).Error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetSessionByID 根据ID获取会话
|
|
|
|
|
func (r *ChatRepository) GetSessionByID(id string) (*model.ChatSession, error) {
|
|
|
|
|
var session model.ChatSession
|
|
|
|
|
err := r.db.Where("id = ?", id).First(&session).Error
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return &session, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetSessionsByUserID 获取用户的所有会话
|
|
|
|
|
func (r *ChatRepository) GetSessionsByUserID(userID string, limit, offset int) ([]model.ChatSession, int64, error) {
|
|
|
|
|
var sessions []model.ChatSession
|
|
|
|
|
var total int64
|
|
|
|
|
|
|
|
|
|
query := r.db.Model(&model.ChatSession{}).Where("user_id = ?", userID)
|
|
|
|
|
query.Count(&total)
|
|
|
|
|
|
|
|
|
|
err := query.Order("updated_at DESC").Limit(limit).Offset(offset).Find(&sessions).Error
|
|
|
|
|
return sessions, total, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// UpdateSession 更新会话
|
|
|
|
|
func (r *ChatRepository) UpdateSession(session *model.ChatSession) error {
|
|
|
|
|
return r.db.Save(session).Error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteSession 删除会话
|
|
|
|
|
func (r *ChatRepository) DeleteSession(id string) error {
|
|
|
|
|
// 先删除会话下的所有消息
|
|
|
|
|
r.db.Where("session_id = ?", id).Delete(&model.ChatMessage{})
|
|
|
|
|
// 再删除会话
|
|
|
|
|
return r.db.Delete(&model.ChatSession{}, "id = ?", id).Error
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-15 19:49:27 +08:00
|
|
|
// DeleteSessionsByAgentID 删除智能体的所有会话
|
|
|
|
|
func (r *ChatRepository) DeleteSessionsByAgentID(agentID string) error {
|
|
|
|
|
// 先查询该智能体的所有会话
|
|
|
|
|
var sessions []model.ChatSession
|
|
|
|
|
if err := r.db.Where("agent_id = ?", agentID).Find(&sessions).Error; err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
// 删除每个会话下的所有消息
|
|
|
|
|
for _, session := range sessions {
|
|
|
|
|
r.db.Where("session_id = ?", session.ID).Delete(&model.ChatMessage{})
|
|
|
|
|
}
|
|
|
|
|
// 再删除所有会话
|
|
|
|
|
return r.db.Where("agent_id = ?", agentID).Delete(&model.ChatSession{}).Error
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-13 14:31:42 +08:00
|
|
|
// Message CRUD
|
|
|
|
|
|
|
|
|
|
// CreateMessage 创建消息
|
|
|
|
|
func (r *ChatRepository) CreateMessage(message *model.ChatMessage) error {
|
|
|
|
|
return r.db.Create(message).Error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetMessagesBySessionID 获取会话的所有消息
|
|
|
|
|
func (r *ChatRepository) GetMessagesBySessionID(sessionID string, limit, offset int) ([]model.ChatMessage, int64, error) {
|
|
|
|
|
var messages []model.ChatMessage
|
|
|
|
|
var total int64
|
|
|
|
|
|
|
|
|
|
query := r.db.Model(&model.ChatMessage{}).Where("session_id = ?", sessionID)
|
|
|
|
|
query.Count(&total)
|
|
|
|
|
|
|
|
|
|
err := query.Order("created_at ASC").Limit(limit).Offset(offset).Find(&messages).Error
|
|
|
|
|
return messages, total, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteMessage 删除消息
|
|
|
|
|
func (r *ChatRepository) DeleteMessage(id string) error {
|
|
|
|
|
return r.db.Delete(&model.ChatMessage{}, "id = ?", id).Error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// DeleteMessagesBySessionID 删除会话的所有消息
|
|
|
|
|
func (r *ChatRepository) DeleteMessagesBySessionID(sessionID string) error {
|
|
|
|
|
return r.db.Where("session_id = ?", sessionID).Delete(&model.ChatMessage{}).Error
|
|
|
|
|
}
|