package postgres

import (
	"context"
	"strings"
	"time"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/internal/store"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)

// DeviceStore persists model.Device records through GORM.
//
// NOTE(review): the redis client is stored but not used by any method shown
// here; confirm whether other code in the package relies on it.
type DeviceStore struct {
	db    *gorm.DB
	redis *redis.Client
}

// NewDeviceStore builds a DeviceStore from a GORM handle and a Redis client.
func NewDeviceStore(db *gorm.DB, redis *redis.Client) *DeviceStore {
	return &DeviceStore{
		db:    db,
		redis: redis,
	}
}

// Create inserts a single device.
func (s *DeviceStore) Create(ctx context.Context, device *model.Device) error {
	return s.db.WithContext(ctx).Create(device).Error
}

// CreateBatch inserts devices in chunks; an empty slice is a no-op.
func (s *DeviceStore) CreateBatch(ctx context.Context, devices []*model.Device) error {
	// Number of rows per INSERT statement issued by CreateInBatches.
	const batchSize = 100

	if len(devices) == 0 {
		return nil
	}
	return s.db.WithContext(ctx).CreateInBatches(devices, batchSize).Error
}

// GetByID loads a device by primary key. GORM's error (for example
// gorm.ErrRecordNotFound) is returned unchanged.
func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, error) {
	var device model.Device
	if err := s.db.WithContext(ctx).First(&device, id).Error; err != nil {
		return nil, err
	}
	return &device, nil
}

// GetByDeviceNo loads a device by its device_no column. GORM's error (for
// example gorm.ErrRecordNotFound) is returned unchanged.
func (s *DeviceStore) GetByDeviceNo(ctx context.Context, deviceNo string) (*model.Device, error) {
	var device model.Device
	if err := s.db.WithContext(ctx).Where("device_no = ?", deviceNo).First(&device).Error; err != nil {
		return nil, err
	}
	return &device, nil
}

// GetByIDs loads all devices whose id is in ids. Missing ids are silently
// skipped; an empty ids slice returns nil without querying.
func (s *DeviceStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.Device, error) {
	var devices []*model.Device
	if len(ids) == 0 {
		return devices, nil
	}
	if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&devices).Error; err != nil {
		return nil, err
	}
	return devices, nil
}

// Update persists device with GORM Save, which writes every column
// (including zero values).
func (s *DeviceStore) Update(ctx context.Context, device *model.Device) error {
	return s.db.WithContext(ctx).Save(device).Error
}

// Delete removes the device with the given id.
//
// NOTE(review): whether this is a soft delete depends on model.Device having
// a gorm.DeletedAt field — confirm against the model.
func (s *DeviceStore) Delete(ctx context.Context, id uint) error {
	return s.db.WithContext(ctx).Delete(&model.Device{}, id).Error
}

// List returns one page of devices matching filters, together with the total
// number of matching rows (ignoring pagination).
//
// Recognised filter keys (a key holding a value of a different dynamic type
// is ignored):
//   - "device_no", "device_name", "manufacturer" (string): literal substring match
//   - "status" (int): exact match when > 0
//   - "shop_id" (*uint): exact match; a nil *uint matches shop_id IS NULL
//   - "batch_no", "device_type" (string): exact match
//   - "created_at_start", "created_at_end" (time.Time): inclusive bounds
//   - "series_id" (uint): exact match when > 0
//
// A nil opts, or a non-positive Page/PageSize, falls back to page 1 and
// constants.DefaultPageSize. Results are ordered by opts.OrderBy when set,
// otherwise newest first with id as a tie-breaker.
func (s *DeviceStore) List(ctx context.Context, opts *store.QueryOptions, filters map[string]any) ([]*model.Device, int64, error) {
	// Escape LIKE metacharacters so input such as "A_1" or "50%" matches
	// literally; backslash is PostgreSQL's default LIKE escape character.
	likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
	contains := func(v string) string { return "%" + likeEscaper.Replace(v) + "%" }

	query := s.db.WithContext(ctx).Model(&model.Device{})

	if deviceNo, ok := filters["device_no"].(string); ok && deviceNo != "" {
		query = query.Where("device_no LIKE ?", contains(deviceNo))
	}
	if deviceName, ok := filters["device_name"].(string); ok && deviceName != "" {
		query = query.Where("device_name LIKE ?", contains(deviceName))
	}
	if status, ok := filters["status"].(int); ok && status > 0 {
		query = query.Where("status = ?", status)
	}
	if shopID, ok := filters["shop_id"].(*uint); ok {
		if shopID == nil {
			query = query.Where("shop_id IS NULL")
		} else {
			query = query.Where("shop_id = ?", *shopID)
		}
	}
	if batchNo, ok := filters["batch_no"].(string); ok && batchNo != "" {
		query = query.Where("batch_no = ?", batchNo)
	}
	if deviceType, ok := filters["device_type"].(string); ok && deviceType != "" {
		query = query.Where("device_type = ?", deviceType)
	}
	if manufacturer, ok := filters["manufacturer"].(string); ok && manufacturer != "" {
		query = query.Where("manufacturer LIKE ?", contains(manufacturer))
	}
	if createdAtStart, ok := filters["created_at_start"].(time.Time); ok && !createdAtStart.IsZero() {
		query = query.Where("created_at >= ?", createdAtStart)
	}
	if createdAtEnd, ok := filters["created_at_end"].(time.Time); ok && !createdAtEnd.IsZero() {
		query = query.Where("created_at <= ?", createdAtEnd)
	}
	if seriesID, ok := filters["series_id"].(uint); ok && seriesID > 0 {
		query = query.Where("series_id = ?", seriesID)
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Normalize paging into locals: a non-positive page would produce a
	// negative OFFSET and a non-positive size a LIMIT 0 (empty page).
	// The caller's opts is never mutated.
	page := 1
	pageSize := constants.DefaultPageSize
	orderBy := "created_at DESC, id DESC"
	if opts != nil {
		if opts.Page > 0 {
			page = opts.Page
		}
		if opts.PageSize > 0 {
			pageSize = opts.PageSize
		}
		if opts.OrderBy != "" {
			// NOTE(review): OrderBy is interpolated into SQL verbatim; callers
			// must pass only trusted/whitelisted expressions, never raw input.
			orderBy = opts.OrderBy
		}
	}
	// The default ordering includes id because CreateInBatches gives a whole
	// batch the same created_at; without a unique tie-breaker, pages over
	// such rows are nondeterministic (rows repeat or go missing).

	var devices []*model.Device
	if err := query.
		Order(orderBy).
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}

// UpdateShopID sets shop_id on one device; a nil shopID writes NULL.
func (s *DeviceStore) UpdateShopID(ctx context.Context, id uint, shopID *uint) error {
	return s.db.WithContext(ctx).Model(&model.Device{}).Where("id = ?", id).Update("shop_id", shopID).Error
}

// BatchUpdateShopIDAndStatus sets shop_id, status and updated_at on every
// device in ids with a single UPDATE. An empty ids slice is a no-op; a nil
// shopID writes NULL.
func (s *DeviceStore) BatchUpdateShopIDAndStatus(ctx context.Context, ids []uint, shopID *uint, status int) error {
	if len(ids) == 0 {
		return nil
	}
	updates := map[string]any{
		"shop_id":    shopID,
		"status":     status,
		"updated_at": time.Now(),
	}
	return s.db.WithContext(ctx).Model(&model.Device{}).Where("id IN ?", ids).Updates(updates).Error
}

// ExistsByDeviceNoBatch reports which of deviceNos already exist. The result
// contains only the numbers that were found (each mapped to true). It is
// never nil on success.
func (s *DeviceStore) ExistsByDeviceNoBatch(ctx context.Context, deviceNos []string) (map[string]bool, error) {
	if len(deviceNos) == 0 {
		return map[string]bool{}, nil
	}

	var existing []string
	if err := s.db.WithContext(ctx).Model(&model.Device{}).
		Where("device_no IN ?", deviceNos).
		Pluck("device_no", &existing).Error; err != nil {
		return nil, err
	}

	result := make(map[string]bool, len(existing))
	for _, no := range existing {
		result[no] = true
	}
	return result, nil
}

// GetByDeviceNos loads all devices whose device_no is in deviceNos. Unknown
// numbers are silently skipped; an empty input returns nil without querying.
func (s *DeviceStore) GetByDeviceNos(ctx context.Context, deviceNos []string) ([]*model.Device, error) {
	var devices []*model.Device
	if len(deviceNos) == 0 {
		return devices, nil
	}
	if err := s.db.WithContext(ctx).Where("device_no IN ?", deviceNos).Find(&devices).Error; err != nil {
		return nil, err
	}
	return devices, nil
}

// BatchUpdateSeriesID sets the package-series ID (series_id) on all given
// devices in one UPDATE. A nil seriesID writes NULL; an empty deviceIDs slice
// is a no-op.
func (s *DeviceStore) BatchUpdateSeriesID(ctx context.Context, deviceIDs []uint, seriesID *uint) error {
	if len(deviceIDs) == 0 {
		return nil
	}
	return s.db.WithContext(ctx).Model(&model.Device{}).
		Where("id IN ?", deviceIDs).
		Update("series_id", seriesID).Error
}

// ListBySeriesID returns every device bound to the given package-series ID
// (unpaginated).
func (s *DeviceStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.Device, error) {
	var devices []*model.Device
	if err := s.db.WithContext(ctx).Where("series_id = ?", seriesID).Find(&devices).Error; err != nil {
		return nil, err
	}
	return devices, nil
}

// UpdateRechargeTrackingFields overwrites the two per-series recharge
// tracking columns of one device with the given strings.
//
// NOTE(review): both arguments are presumably pre-serialized JSON (per their
// names); they are written as-is without validation.
func (s *DeviceStore) UpdateRechargeTrackingFields(ctx context.Context, deviceID uint, accumulatedJSON, triggeredJSON string) error {
	return s.db.WithContext(ctx).Model(&model.Device{}).
		Where("id = ?", deviceID).
		Updates(map[string]any{
			"accumulated_recharge_by_series":     accumulatedJSON,
			"first_recharge_triggered_by_series": triggeredJSON,
		}).Error
}