package packagepkg

import (
	"context"
	stderrors "errors"
	"time"

	"gorm.io/gorm"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/internal/model/dto"
	"github.com/break/junhong_cmp_fiber/internal/store"
	"github.com/break/junhong_cmp_fiber/internal/store/postgres"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/errors"
	"github.com/break/junhong_cmp_fiber/pkg/middleware"
)

// Shelf status values used by this service: shelfStatusOn means the package is
// listed (on shelf), shelfStatusOff means it is not listed.
// NOTE(review): consider moving these into pkg/constants next to StatusEnabled/StatusDisabled.
const (
	shelfStatusOn  = 1
	shelfStatusOff = 2
)

// Service implements package (套餐) management on top of the package,
// package-series and shop-package-allocation stores.
type Service struct {
	packageStore           *postgres.PackageStore
	packageSeriesStore     *postgres.PackageSeriesStore
	packageAllocationStore *postgres.ShopPackageAllocationStore
}

// New returns a Service backed by the given stores.
func New(
	packageStore *postgres.PackageStore,
	packageSeriesStore *postgres.PackageSeriesStore,
	packageAllocationStore *postgres.ShopPackageAllocationStore,
) *Service {
	return &Service{
		packageStore:           packageStore,
		packageSeriesStore:     packageSeriesStore,
		packageAllocationStore: packageAllocationStore,
	}
}

// Create creates a new package on behalf of the user found in ctx.
//
// The package code must not already exist. When EnableVirtualData is set,
// VirtualDataMB must be > 0 and must not exceed RealDataMB (a nil RealDataMB
// counts as 0). If a positive SeriesID is given, the series must exist.
// New packages are created enabled and off the shelf.
//
// Errors: CodeUnauthorized when ctx carries no user ID, CodeConflict on a
// duplicate code, CodeInvalidParam on an invalid virtual-data configuration,
// CodeNotFound for an unknown series, CodeInternalError on store failures.
func (s *Service) Create(ctx context.Context, req *dto.CreatePackageRequest) (*dto.PackageResponse, error) {
	currentUserID := middleware.GetUserIDFromContext(ctx)
	if currentUserID == 0 {
		return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
	}

	// A "record not found" error means the code is free. Any other store error
	// must not be treated as "free", otherwise the uniqueness check is bypassed.
	existing, err := s.packageStore.GetByCode(ctx, req.PackageCode)
	if err != nil && !isRecordNotFound(err) {
		return nil, errors.Wrap(errors.CodeInternalError, err, "查询套餐编码失败")
	}
	if existing != nil {
		return nil, errors.New(errors.CodeConflict, "套餐编码已存在")
	}

	var realDataMB, virtualDataMB int64
	if req.RealDataMB != nil {
		realDataMB = *req.RealDataMB
	}
	if req.VirtualDataMB != nil {
		virtualDataMB = *req.VirtualDataMB
	}
	if err = validateVirtualData(req.EnableVirtualData, virtualDataMB, realDataMB); err != nil {
		return nil, err
	}

	var seriesName *string
	if req.SeriesID != nil && *req.SeriesID > 0 {
		seriesName, err = s.lookupSeriesName(ctx, *req.SeriesID)
		if err != nil {
			return nil, err
		}
	}

	pkg := &model.Package{
		PackageCode:       req.PackageCode,
		PackageName:       req.PackageName,
		PackageType:       req.PackageType,
		DurationMonths:    req.DurationMonths,
		CostPrice:         req.CostPrice,
		EnableVirtualData: req.EnableVirtualData,
		Status:            constants.StatusEnabled,
		ShelfStatus:       shelfStatusOff,
	}
	if req.SeriesID != nil {
		pkg.SeriesID = *req.SeriesID
	}
	pkg.RealDataMB = realDataMB
	pkg.VirtualDataMB = virtualDataMB
	if req.SuggestedRetailPrice != nil {
		pkg.SuggestedRetailPrice = *req.SuggestedRetailPrice
	}
	pkg.Creator = currentUserID

	if err = s.packageStore.Create(ctx, pkg); err != nil {
		return nil, errors.Wrap(errors.CodeInternalError, err, "创建套餐失败")
	}

	resp := s.toResponse(ctx, pkg)
	resp.SeriesName = seriesName
	return resp, nil
}

// Get returns the package with the given ID.
//
// The series name is filled in on a best-effort basis: if the series lookup
// fails, SeriesName is left nil. Returns CodeNotFound when the package does
// not exist and CodeInternalError on other store failures.
func (s *Service) Get(ctx context.Context, id uint) (*dto.PackageResponse, error) {
	pkg, err := s.loadPackage(ctx, id)
	if err != nil {
		return nil, err
	}

	resp := s.toResponse(ctx, pkg)
	resp.SeriesName = s.seriesNameBestEffort(ctx, pkg.SeriesID)
	return resp, nil
}

// Update applies the non-nil fields of req to the package with the given ID.
//
// A positive req.SeriesID must reference an existing series. After the
// fields are merged, the resulting virtual-data configuration is validated
// with the same rules as Create.
//
// Errors: CodeUnauthorized when ctx carries no user ID, CodeNotFound for an
// unknown package or series, CodeInvalidParam on an invalid virtual-data
// configuration, CodeInternalError on store failures.
func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdatePackageRequest) (*dto.PackageResponse, error) {
	currentUserID := middleware.GetUserIDFromContext(ctx)
	if currentUserID == 0 {
		return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
	}

	pkg, err := s.loadPackage(ctx, id)
	if err != nil {
		return nil, err
	}

	var seriesName *string
	if req.SeriesID != nil && *req.SeriesID > 0 {
		seriesName, err = s.lookupSeriesName(ctx, *req.SeriesID)
		if err != nil {
			return nil, err
		}
		pkg.SeriesID = *req.SeriesID
	} else {
		// Series unchanged: report the current series name if it can be found.
		seriesName = s.seriesNameBestEffort(ctx, pkg.SeriesID)
	}

	if req.PackageName != nil {
		pkg.PackageName = *req.PackageName
	}
	if req.PackageType != nil {
		pkg.PackageType = *req.PackageType
	}
	if req.DurationMonths != nil {
		pkg.DurationMonths = *req.DurationMonths
	}
	if req.RealDataMB != nil {
		pkg.RealDataMB = *req.RealDataMB
	}
	if req.VirtualDataMB != nil {
		pkg.VirtualDataMB = *req.VirtualDataMB
	}
	if req.EnableVirtualData != nil {
		pkg.EnableVirtualData = *req.EnableVirtualData
	}
	if req.CostPrice != nil {
		pkg.CostPrice = *req.CostPrice
	}
	if req.SuggestedRetailPrice != nil {
		pkg.SuggestedRetailPrice = *req.SuggestedRetailPrice
	}

	// Validate the merged state, not just the request, so partial updates
	// cannot leave the package in an invalid configuration.
	if err = validateVirtualData(pkg.EnableVirtualData, pkg.VirtualDataMB, pkg.RealDataMB); err != nil {
		return nil, err
	}

	pkg.Updater = currentUserID
	if err = s.packageStore.Update(ctx, pkg); err != nil {
		return nil, errors.Wrap(errors.CodeInternalError, err, "更新套餐失败")
	}

	resp := s.toResponse(ctx, pkg)
	resp.SeriesName = seriesName
	return resp, nil
}

// Delete deletes the package with the given ID.
//
// Returns CodeNotFound when the package does not exist and CodeInternalError
// on store failures. Unlike the other mutating methods, Delete does not
// require a user ID in ctx.
// NOTE(review): confirm whether Delete should also reject unauthenticated callers.
func (s *Service) Delete(ctx context.Context, id uint) error {
	if _, err := s.loadPackage(ctx, id); err != nil {
		return err
	}

	if err := s.packageStore.Delete(ctx, id); err != nil {
		return errors.Wrap(errors.CodeInternalError, err, "删除套餐失败")
	}
	return nil
}

// List returns one page of packages matching the non-nil filters in req,
// ordered by ID descending, together with the total match count.
//
// Series names are fetched in one batch query. For agent users with a shop
// in ctx, each package's cost price and profit margin reflect that shop's
// allocation when one exists (allocation lookup failures are ignored).
func (s *Service) List(ctx context.Context, req *dto.PackageListRequest) ([]*dto.PackageResponse, int64, error) {
	opts := &store.QueryOptions{
		Page:     req.Page,
		PageSize: req.PageSize,
		OrderBy:  "id DESC",
	}
	// Missing or non-positive paging values fall back to the defaults.
	if opts.Page <= 0 {
		opts.Page = 1
	}
	if opts.PageSize <= 0 {
		opts.PageSize = constants.DefaultPageSize
	}

	filters := make(map[string]interface{})
	if req.PackageName != nil {
		filters["package_name"] = *req.PackageName
	}
	if req.SeriesID != nil {
		filters["series_id"] = *req.SeriesID
	}
	if req.Status != nil {
		filters["status"] = *req.Status
	}
	if req.ShelfStatus != nil {
		filters["shelf_status"] = *req.ShelfStatus
	}
	if req.PackageType != nil {
		filters["package_type"] = *req.PackageType
	}

	packages, total, err := s.packageStore.List(ctx, opts, filters)
	if err != nil {
		return nil, 0, errors.Wrap(errors.CodeInternalError, err, "查询套餐列表失败")
	}

	// Collect package IDs and the distinct series IDs so related data can be
	// loaded with batch queries instead of one query per package.
	packageIDs := make([]uint, 0, len(packages))
	seriesIDSet := make(map[uint]struct{})
	for _, pkg := range packages {
		packageIDs = append(packageIDs, pkg.ID)
		if pkg.SeriesID > 0 {
			seriesIDSet[pkg.SeriesID] = struct{}{}
		}
	}

	seriesNames := make(map[uint]string, len(seriesIDSet))
	if len(seriesIDSet) > 0 {
		seriesIDs := make([]uint, 0, len(seriesIDSet))
		for seriesID := range seriesIDSet {
			seriesIDs = append(seriesIDs, seriesID)
		}
		seriesList, err := s.packageSeriesStore.GetByIDs(ctx, seriesIDs)
		if err != nil {
			return nil, 0, errors.Wrap(errors.CodeInternalError, err, "批量查询套餐系列失败")
		}
		for _, series := range seriesList {
			seriesNames[series.ID] = series.SeriesName
		}
	}

	userType := middleware.GetUserTypeFromContext(ctx)
	shopID := middleware.GetShopIDFromContext(ctx)
	var allocationMap map[uint]*model.ShopPackageAllocation
	if userType == constants.UserTypeAgent && shopID > 0 && len(packageIDs) > 0 {
		allocationMap = s.batchGetAllocationsForShop(ctx, shopID, packageIDs)
	}

	responses := make([]*dto.PackageResponse, 0, len(packages))
	for _, pkg := range packages {
		resp := s.toResponseWithAllocation(pkg, allocationMap)
		if name, ok := seriesNames[pkg.SeriesID]; ok {
			resp.SeriesName = &name
		}
		responses = append(responses, resp)
	}
	return responses, total, nil
}

// UpdateStatus sets the enabled/disabled status of the package with the
// given ID. Disabling a package also takes it off the shelf.
//
// Errors: CodeUnauthorized when ctx carries no user ID, CodeNotFound for an
// unknown package, CodeInternalError on store failures.
func (s *Service) UpdateStatus(ctx context.Context, id uint, status int) error {
	currentUserID := middleware.GetUserIDFromContext(ctx)
	if currentUserID == 0 {
		return errors.New(errors.CodeUnauthorized, "未授权访问")
	}

	pkg, err := s.loadPackage(ctx, id)
	if err != nil {
		return err
	}

	pkg.Status = status
	pkg.Updater = currentUserID
	// A disabled package must not remain listed.
	if status == constants.StatusDisabled {
		pkg.ShelfStatus = shelfStatusOff
	}

	if err := s.packageStore.Update(ctx, pkg); err != nil {
		return errors.Wrap(errors.CodeInternalError, err, "更新套餐状态失败")
	}
	return nil
}

// UpdateShelfStatus sets the shelf status of the package with the given ID.
// A disabled package cannot be put on the shelf.
//
// Errors: CodeUnauthorized when ctx carries no user ID, CodeNotFound for an
// unknown package, CodeInvalidStatus when shelving a disabled package,
// CodeInternalError on store failures.
func (s *Service) UpdateShelfStatus(ctx context.Context, id uint, shelfStatus int) error {
	currentUserID := middleware.GetUserIDFromContext(ctx)
	if currentUserID == 0 {
		return errors.New(errors.CodeUnauthorized, "未授权访问")
	}

	pkg, err := s.loadPackage(ctx, id)
	if err != nil {
		return err
	}

	if shelfStatus == shelfStatusOn && pkg.Status == constants.StatusDisabled {
		return errors.New(errors.CodeInvalidStatus, "禁用的套餐不能上架,请先启用")
	}

	pkg.ShelfStatus = shelfStatus
	pkg.Updater = currentUserID

	if err := s.packageStore.Update(ctx, pkg); err != nil {
		return errors.Wrap(errors.CodeInternalError, err, "更新套餐上架状态失败")
	}
	return nil
}

// toResponse converts pkg to a response DTO. For agent users with a shop in
// ctx, the shop's allocation (looked up individually) overrides CostPrice and
// sets ProfitMargin. Allocation lookup failures are ignored (best effort).
func (s *Service) toResponse(ctx context.Context, pkg *model.Package) *dto.PackageResponse {
	resp := newPackageResponse(pkg)

	userType := middleware.GetUserTypeFromContext(ctx)
	shopID := middleware.GetShopIDFromContext(ctx)
	if userType == constants.UserTypeAgent && shopID > 0 {
		allocation, err := s.packageAllocationStore.GetByShopAndPackage(ctx, shopID, pkg.ID)
		if err == nil {
			applyAllocation(resp, pkg, allocation)
		}
	}
	return resp
}

// batchGetAllocationsForShop loads the shop's allocations for packageIDs and
// indexes them by package ID. It is best effort: on a store error it returns
// an empty (non-nil) map so callers fall back to the package's own prices.
func (s *Service) batchGetAllocationsForShop(ctx context.Context, shopID uint, packageIDs []uint) map[uint]*model.ShopPackageAllocation {
	allocations, err := s.packageAllocationStore.GetByShopAndPackages(ctx, shopID, packageIDs)
	if err != nil {
		return make(map[uint]*model.ShopPackageAllocation)
	}

	allocationMap := make(map[uint]*model.ShopPackageAllocation, len(allocations))
	for _, alloc := range allocations {
		allocationMap[alloc.PackageID] = alloc
	}
	return allocationMap
}

// toResponseWithAllocation converts pkg to a response DTO, applying the
// allocation for pkg.ID from allocationMap when present. A nil map is allowed
// and simply means "no allocations".
func (s *Service) toResponseWithAllocation(pkg *model.Package, allocationMap map[uint]*model.ShopPackageAllocation) *dto.PackageResponse {
	resp := newPackageResponse(pkg)
	if allocation, ok := allocationMap[pkg.ID]; ok {
		applyAllocation(resp, pkg, allocation)
	}
	return resp
}

// loadPackage fetches a package by ID, mapping "not found" to CodeNotFound
// and any other store error to CodeInternalError.
func (s *Service) loadPackage(ctx context.Context, id uint) (*model.Package, error) {
	pkg, err := s.packageStore.GetByID(ctx, id)
	if err != nil {
		if isRecordNotFound(err) {
			return nil, errors.New(errors.CodeNotFound, "套餐不存在")
		}
		return nil, errors.Wrap(errors.CodeInternalError, err, "获取套餐失败")
	}
	return pkg, nil
}

// lookupSeriesName fetches the name of a series that must exist, mapping
// "not found" to CodeNotFound and any other store error to CodeInternalError.
func (s *Service) lookupSeriesName(ctx context.Context, seriesID uint) (*string, error) {
	series, err := s.packageSeriesStore.GetByID(ctx, seriesID)
	if err != nil {
		if isRecordNotFound(err) {
			return nil, errors.New(errors.CodeNotFound, "套餐系列不存在")
		}
		return nil, errors.Wrap(errors.CodeInternalError, err, "获取套餐系列失败")
	}
	return &series.SeriesName, nil
}

// seriesNameBestEffort returns the series name for seriesID, or nil when
// seriesID is 0 or the lookup fails for any reason.
func (s *Service) seriesNameBestEffort(ctx context.Context, seriesID uint) *string {
	if seriesID == 0 {
		return nil
	}
	series, err := s.packageSeriesStore.GetByID(ctx, seriesID)
	if err != nil {
		return nil
	}
	return &series.SeriesName
}

// validateVirtualData enforces the virtual-data rules: when enabled, the
// virtual quota must be > 0 and must not exceed the real quota.
func validateVirtualData(enabled bool, virtualDataMB, realDataMB int64) error {
	if !enabled {
		return nil
	}
	if virtualDataMB <= 0 {
		return errors.New(errors.CodeInvalidParam, "启用虚流量时,虚流量额度必须大于0")
	}
	if virtualDataMB > realDataMB {
		return errors.New(errors.CodeInvalidParam, "虚流量额度不能大于真流量额度")
	}
	return nil
}

// newPackageResponse copies pkg's fields into a response DTO. SeriesID is nil
// when the package has no series (SeriesID == 0); timestamps use RFC 3339.
func newPackageResponse(pkg *model.Package) *dto.PackageResponse {
	var seriesID *uint
	if pkg.SeriesID > 0 {
		seriesID = &pkg.SeriesID
	}
	return &dto.PackageResponse{
		ID:                   pkg.ID,
		PackageCode:          pkg.PackageCode,
		PackageName:          pkg.PackageName,
		SeriesID:             seriesID,
		PackageType:          pkg.PackageType,
		DurationMonths:       pkg.DurationMonths,
		RealDataMB:           pkg.RealDataMB,
		VirtualDataMB:        pkg.VirtualDataMB,
		EnableVirtualData:    pkg.EnableVirtualData,
		CostPrice:            pkg.CostPrice,
		SuggestedRetailPrice: pkg.SuggestedRetailPrice,
		Status:               pkg.Status,
		ShelfStatus:          pkg.ShelfStatus,
		CreatedAt:            pkg.CreatedAt.Format(time.RFC3339),
		UpdatedAt:            pkg.UpdatedAt.Format(time.RFC3339),
	}
}

// applyAllocation overrides resp.CostPrice with the allocation's cost price
// and sets ProfitMargin = SuggestedRetailPrice - allocation cost price.
// A nil allocation leaves resp unchanged.
func applyAllocation(resp *dto.PackageResponse, pkg *model.Package, allocation *model.ShopPackageAllocation) {
	if allocation == nil {
		return
	}
	resp.CostPrice = allocation.CostPrice
	profitMargin := pkg.SuggestedRetailPrice - allocation.CostPrice
	resp.ProfitMargin = &profitMargin
}

// isRecordNotFound reports whether err is, or wraps, gorm.ErrRecordNotFound.
func isRecordNotFound(err error) bool {
	return stderrors.Is(err, gorm.ErrRecordNotFound)
}