Files
2025-11-17 13:32:54 +08:00

178 lines
5.2 KiB
Go

package repository
import (
	"fmt"
	"time"

	"dianshang/internal/model"

	"gorm.io/gorm"
)
// BannerRepository is the data-access layer (repository) for carousel banners.
// Every method issues its queries through db against model.Banner.
// Construct it with NewBannerRepository; the zero value has a nil db and is
// not usable.
type BannerRepository struct {
db *gorm.DB
}
// NewBannerRepository returns a BannerRepository that runs its queries on db.
// The handle is stored as-is; no copy or session is created here.
func NewBannerRepository(db *gorm.DB) *BannerRepository {
	repo := new(BannerRepository)
	repo.db = db
	return repo
}
// GetActiveBanners returns every banner with status 1, ordered by sort
// ascending. It does not look at start_time/end_time; see
// GetActiveBannersWithTimeRange for the time-window-aware variant.
func (r *BannerRepository) GetActiveBanners() ([]model.Banner, error) {
	var result []model.Banner
	tx := r.db.Where("status = ?", 1).Order("sort ASC")
	err := tx.Find(&result).Error
	return result, err
}
// GetBannerList returns one page of banners together with the total number of
// banners that match the optional status filter.
//
// page is 1-based; values below 1 are treated as 1 so the computed offset is
// never negative. pageSize is passed to LIMIT unchanged — NOTE(review): callers
// are assumed to validate it; a non-positive value is not guarded here.
// status, when non-nil, restricts both the count and the page to banners with
// that status. Rows are ordered by sort ascending, then newest created_at first.
//
// On any database error the returned slice is nil and the total is 0.
func (r *BannerRepository) GetBannerList(page, pageSize int, status *int) ([]model.Banner, int64, error) {
	if page < 1 {
		page = 1
	}

	query := r.db.Model(&model.Banner{})
	if status != nil {
		query = query.Where("status = ?", *status)
	}

	// Count with the same filter so total describes the full result set, not
	// just the returned page.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var banners []model.Banner
	err := query.Order("sort ASC, created_at DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&banners).Error
	if err != nil {
		// Do not hand back a partially filled page alongside an error.
		return nil, 0, err
	}
	return banners, total, nil
}
// GetBannerByID loads the banner whose primary key is id.
// On failure (including GORM's record-not-found error from First) it returns
// a nil banner and the error unchanged.
func (r *BannerRepository) GetBannerByID(id uint) (*model.Banner, error) {
	banner := new(model.Banner)
	if err := r.db.First(banner, id).Error; err != nil {
		return nil, err
	}
	return banner, nil
}
// CreateBanner inserts banner as a new row and returns the database error, if
// any. GORM writes generated values (such as the primary key) back into banner.
func (r *BannerRepository) CreateBanner(banner *model.Banner) error {
	result := r.db.Create(banner)
	return result.Error
}
// UpdateBanner persists banner with GORM's Save. Save writes every field,
// including zero values; per the GORM docs, a value without a primary key is
// inserted rather than updated.
func (r *BannerRepository) UpdateBanner(banner *model.Banner) error {
	result := r.db.Save(banner)
	return result.Error
}
// DeleteBanner removes the banner whose primary key is id.
// NOTE(review): if model.Banner embeds gorm.DeletedAt this is a soft delete —
// confirm against the model definition.
func (r *BannerRepository) DeleteBanner(id uint) error {
	result := r.db.Delete(&model.Banner{}, id)
	return result.Error
}
// BatchDeleteBanners deletes every banner whose primary key is in ids.
//
// An empty or nil ids slice is a no-op and returns nil. The guard matters:
// GORM adds no primary-key condition for an empty slice, so the delete would
// either be rejected as a global delete (gorm.ErrMissingWhereClause) or, on a
// DB configured with AllowGlobalUpdate, remove every banner.
func (r *BannerRepository) BatchDeleteBanners(ids []uint) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.Delete(&model.Banner{}, ids).Error
}
// UpdateBannerStatus sets the status column of the banner with the given id.
// Matching no row is not reported as an error.
func (r *BannerRepository) UpdateBannerStatus(id uint, status int) error {
	tx := r.db.Model(&model.Banner{}).Where("id = ?", id)
	return tx.Update("status", status).Error
}
// BatchUpdateBannerStatus sets the status column of every banner whose id is
// in ids, using a single UPDATE ... WHERE id IN (...) statement.
//
// An empty or nil ids slice is a no-op that returns nil without contacting the
// database; such a statement could never match any row.
func (r *BannerRepository) BatchUpdateBannerStatus(ids []uint, status int) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.Model(&model.Banner{}).Where("id IN ?", ids).Update("status", status).Error
}
// UpdateBannerSort sets the sort column of the banner with the given id.
// Matching no row is not reported as an error.
func (r *BannerRepository) UpdateBannerSort(id uint, sort int) error {
	tx := r.db.Model(&model.Banner{}).Where("id = ?", id)
	return tx.Update("sort", sort).Error
}
// BatchUpdateBannerSort updates the sort column of several banners inside one
// transaction.
//
// Each element of sortData must provide a non-nil "id" (the banner's primary
// key) and a non-nil "sort" (the new sort value). If any element is missing
// either key, or any UPDATE fails, the transaction is rolled back and the
// error is returned, so no partial reordering is committed.
// NOTE(review): the concrete value types depend on how callers decode the
// request (e.g. JSON numbers) — confirm against the handler.
func (r *BannerRepository) BatchUpdateBannerSort(sortData []map[string]interface{}) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		for i, item := range sortData {
			id, hasID := item["id"]
			sortVal, hasSort := item["sort"]
			// Reject incomplete items: a nil id renders "id = NULL" and silently
			// matches nothing, while a nil sort would overwrite the column with NULL.
			if !hasID || id == nil {
				return fmt.Errorf("batch update banner sort: item %d has no id", i)
			}
			if !hasSort || sortVal == nil {
				return fmt.Errorf("batch update banner sort: item %d has no sort", i)
			}
			if err := tx.Model(&model.Banner{}).
				Where("id = ?", id).
				Update("sort", sortVal).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// GetBannersByDateRange returns banners whose created_at lies within
// [startDate, endDate] (SQL BETWEEN is inclusive on both ends), ordered by
// sort ascending, then newest first.
func (r *BannerRepository) GetBannersByDateRange(startDate, endDate time.Time) ([]model.Banner, error) {
	var list []model.Banner
	tx := r.db.Where("created_at BETWEEN ? AND ?", startDate, endDate)
	tx = tx.Order("sort ASC, created_at DESC")
	err := tx.Find(&list).Error
	return list, err
}
// GetBannersByStatus returns all banners with the given status, ordered by
// sort ascending, then newest created_at first.
func (r *BannerRepository) GetBannersByStatus(status int) ([]model.Banner, error) {
	var list []model.Banner
	tx := r.db.Where("status = ?", status)
	tx = tx.Order("sort ASC, created_at DESC")
	err := tx.Find(&list).Error
	return list, err
}
// GetBannerCount returns the number of banner rows, with no filter applied.
func (r *BannerRepository) GetBannerCount() (int64, error) {
	var total int64
	result := r.db.Model(&model.Banner{}).Count(&total)
	return total, result.Error
}
// GetBannerCountByStatus returns the number of banners with the given status.
func (r *BannerRepository) GetBannerCountByStatus(status int) (int64, error) {
	var total int64
	tx := r.db.Model(&model.Banner{}).Where("status = ?", status)
	result := tx.Count(&total)
	return total, result.Error
}
// GetMaxSort returns the largest sort value among banners, or 0 when there are
// none (COALESCE turns the NULL from MAX over an empty set into 0).
func (r *BannerRepository) GetMaxSort() (int, error) {
	var highest int
	tx := r.db.Model(&model.Banner{}).Select("COALESCE(MAX(sort), 0)")
	err := tx.Scan(&highest).Error
	return highest, err
}
// CheckBannerExists reports whether a banner with the given id exists, by
// counting matching rows.
func (r *BannerRepository) CheckBannerExists(id uint) (bool, error) {
	var matches int64
	tx := r.db.Model(&model.Banner{}).Where("id = ?", id)
	err := tx.Count(&matches).Error
	exists := matches > 0
	return exists, err
}
// GetExpiredBanners returns banners that still have status 1 but whose
// end_time is set and already earlier than the current time.
func (r *BannerRepository) GetExpiredBanners() ([]model.Banner, error) {
	var expired []model.Banner
	cutoff := time.Now()
	tx := r.db.Where("end_time IS NOT NULL AND end_time < ? AND status = 1", cutoff)
	err := tx.Find(&expired).Error
	return expired, err
}
// GetActiveBannersWithTimeRange returns banners with status 1 whose display
// window contains the current time, ordered by sort ascending. A NULL
// start_time or end_time is treated as an open bound on that side.
func (r *BannerRepository) GetActiveBannersWithTimeRange() ([]model.Banner, error) {
	const inWindow = "(start_time IS NULL OR start_time <= ?) AND (end_time IS NULL OR end_time >= ?)"
	var active []model.Banner
	now := time.Now()
	err := r.db.Where("status = 1").
		Where(inWindow, now, now).
		Order("sort ASC").
		Find(&active).Error
	return active, err
}