package postgres

import (
	"context"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/internal/store"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"
)

// EnterpriseStore is the data-access layer for enterprises.
//
// It holds a GORM handle used by every method in this file and a Redis
// client that is stored by the constructor but not read by any method
// visible here.
type EnterpriseStore struct {
	db    *gorm.DB
	redis *redis.Client
}

// NewEnterpriseStore builds an EnterpriseStore around the given GORM handle
// and Redis client. Neither argument is validated.
func NewEnterpriseStore(db *gorm.DB, redis *redis.Client) *EnterpriseStore {
	return &EnterpriseStore{
		db:    db,
		redis: redis,
	}
}

// Create inserts enterprise as a new row and returns the GORM error, if any.
func (s *EnterpriseStore) Create(ctx context.Context, enterprise *model.Enterprise) error {
	return s.db.WithContext(ctx).Create(enterprise).Error
}

// GetByID loads the enterprise whose primary key equals id.
//
// The GORM error is returned unwrapped, so a missing row surfaces as
// gorm.ErrRecordNotFound (the documented behavior of First).
func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterprise, error) {
	var enterprise model.Enterprise
	if err := s.db.WithContext(ctx).First(&enterprise, id).Error; err != nil {
		return nil, err
	}
	return &enterprise, nil
}

// GetByCode loads the first enterprise whose enterprise_code column equals
// code. The GORM error is returned unwrapped.
func (s *EnterpriseStore) GetByCode(ctx context.Context, code string) (*model.Enterprise, error) {
	var enterprise model.Enterprise
	if err := s.db.WithContext(ctx).Where("enterprise_code = ?", code).First(&enterprise).Error; err != nil {
		return nil, err
	}
	return &enterprise, nil
}

// Update persists enterprise through GORM's Save, which writes every field
// of the struct rather than only the changed ones.
func (s *EnterpriseStore) Update(ctx context.Context, enterprise *model.Enterprise) error {
	return s.db.WithContext(ctx).Save(enterprise).Error
}

// Delete removes the enterprise with the given primary key.
//
// NOTE(review): the original comment describes this as a soft delete; that
// holds only if model.Enterprise carries a gorm.DeletedAt field, which is
// not visible from this file — confirm against the model definition.
func (s *EnterpriseStore) Delete(ctx context.Context, id uint) error {
	return s.db.WithContext(ctx).Delete(&model.Enterprise{}, id).Error
}

// List returns one page of enterprises matching filters together with the
// total number of matching rows (counted before pagination is applied).
//
// Recognized filter keys; an entry is ignored when it is absent or its
// dynamic type differs from the one listed:
//   - "enterprise_name" (string, non-empty): substring match via LIKE
//   - "enterprise_code" (string, non-empty): exact match
//   - "owner_shop_id"   (uint): exact match
//   - "status"          (int): exact match
//
// A nil opts selects page 1 with constants.DefaultPageSize. For a non-nil
// opts, a Page below 1 is treated as 1 and a PageSize of 0 falls back to
// constants.DefaultPageSize, so a zero-valued QueryOptions no longer issues
// LIMIT 0 and comes back empty. The caller's opts is never modified.
// Results are ordered by opts.OrderBy when set, otherwise by
// "created_at DESC".
func (s *EnterpriseStore) List(ctx context.Context, opts *store.QueryOptions, filters map[string]interface{}) ([]*model.Enterprise, int64, error) {
	var enterprises []*model.Enterprise
	var total int64

	query := s.db.WithContext(ctx).Model(&model.Enterprise{})

	// Apply the optional filters.
	if enterpriseName, ok := filters["enterprise_name"].(string); ok && enterpriseName != "" {
		query = query.Where("enterprise_name LIKE ?", "%"+enterpriseName+"%")
	}
	if enterpriseCode, ok := filters["enterprise_code"].(string); ok && enterpriseCode != "" {
		query = query.Where("enterprise_code = ?", enterpriseCode)
	}
	if ownerShopID, ok := filters["owner_shop_id"].(uint); ok {
		query = query.Where("owner_shop_id = ?", ownerShopID)
	}
	if status, ok := filters["status"].(int); ok {
		query = query.Where("status = ?", status)
	}

	// Count before paginating so total reflects the whole filtered set.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	if opts == nil {
		opts = &store.QueryOptions{
			Page:     1,
			PageSize: constants.DefaultPageSize,
		}
	}

	// Normalize into locals so the caller's options are left untouched.
	page, pageSize := opts.Page, opts.PageSize
	if page < 1 {
		page = 1
	}
	// Only the zero value is replaced: GORM documents a negative limit as
	// "no limit", and callers may rely on that to fetch everything.
	if pageSize == 0 {
		pageSize = constants.DefaultPageSize
	}

	offset := (page - 1) * pageSize
	query = query.Offset(offset).Limit(pageSize)

	// NOTE(review): OrderBy is forwarded verbatim into ORDER BY. If it can
	// originate from request input, callers must restrict it to known
	// column names — confirm at the call sites.
	if opts.OrderBy != "" {
		query = query.Order(opts.OrderBy)
	} else {
		query = query.Order("created_at DESC")
	}

	if err := query.Find(&enterprises).Error; err != nil {
		return nil, 0, err
	}

	return enterprises, total, nil
}

// GetByOwnerShopID returns every enterprise whose owner_shop_id column
// equals ownerShopID. No ordering or pagination is applied.
func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint) ([]*model.Enterprise, error) {
	var enterprises []*model.Enterprise
	if err := s.db.WithContext(ctx).Where("owner_shop_id = ?", ownerShopID).Find(&enterprises).Error; err != nil {
		return nil, err
	}
	return enterprises, nil
}

// GetPlatformEnterprises returns the enterprises that belong directly to the
// platform, i.e. rows whose owner_shop_id is NULL. No ordering or pagination
// is applied.
func (s *EnterpriseStore) GetPlatformEnterprises(ctx context.Context) ([]*model.Enterprise, error) {
	var enterprises []*model.Enterprise
	if err := s.db.WithContext(ctx).Where("owner_shop_id IS NULL").Find(&enterprises).Error; err != nil {
		return nil, err
	}
	return enterprises, nil
}