重构: 店铺套餐分配系统从加价模式改为返佣模式
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m18s

主要变更:
- 重构分配模型:从加价模式(pricing_mode/pricing_value)改为返佣模式(base_commission + tier_commission)
- 删除独立的 my_package 接口,统一到 /api/admin/packages(通过数据权限自动过滤)
- 新增批量分配和批量调价功能,支持事务和性能优化
- 新增配置版本管理,订单创建时锁定返佣配置
- 新增成本价历史记录,支持审计和纠纷处理
- 新增统计缓存系统(Redis + 异步任务),优化梯度返佣计算性能
- 删除冗余的梯度佣金独立 CRUD 接口(合并到分配配置中)
- 归档 3 个已完成的 OpenSpec changes 并同步 8 个新 capabilities 到 main specs

技术细节:
- 数据库迁移:000026_refactor_shop_package_allocation
- 新增 Store:AllocationConfigStore, PriceHistoryStore, CommissionStatsStore
- 新增 Service:BatchAllocationService, BatchPricingService, CommissionStatsService
- 新增异步任务:统计更新、定时同步、周期归档
- 测试覆盖:批量操作集成测试、梯度佣金 CRUD 清理验证

影响:
- API 变更:删除 4 个梯度 CRUD 接口(POST/GET/PUT/DELETE /:id/tiers)
- API 新增:批量分配、批量调价接口
- 数据模型:重构 shop_series_allocation 表结构
- 性能优化:批量操作使用 CreateInBatches,统计使用 Redis 缓存

相关文档:
- openspec/changes/archive/2026-01-28-refactor-shop-package-allocation/
- openspec/specs/agent-available-packages/
- openspec/specs/allocation-config-versioning/
- 等 8 个新 capability specs
This commit is contained in:
2026-01-28 17:11:55 +08:00
parent 23eb0307bb
commit 1da680a790
97 changed files with 6810 additions and 3622 deletions

View File

@@ -0,0 +1,98 @@
package commission_stats
import (
"context"
"fmt"
"time"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"gorm.io/gorm"
)
type Service struct {
statsStore *postgres.ShopSeriesCommissionStatsStore
}
func New(statsStore *postgres.ShopSeriesCommissionStatsStore) *Service {
return &Service{
statsStore: statsStore,
}
}
func (s *Service) GetCurrentStats(ctx context.Context, allocationID uint, periodType string) (*model.ShopSeriesCommissionStats, error) {
now := time.Now()
stats, err := s.statsStore.GetCurrent(ctx, allocationID, periodType, now)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "统计数据不存在")
}
return nil, fmt.Errorf("获取统计数据失败: %w", err)
}
return stats, nil
}
func (s *Service) UpdateStats(ctx context.Context, allocationID uint, periodType string, salesCount int64, salesAmount int64) error {
now := time.Now()
periodStart, periodEnd := calculatePeriod(now, periodType)
stats, err := s.statsStore.GetCurrent(ctx, allocationID, periodType, now)
if err != nil && err != gorm.ErrRecordNotFound {
return fmt.Errorf("查询统计数据失败: %w", err)
}
if stats == nil {
stats = &model.ShopSeriesCommissionStats{
AllocationID: allocationID,
PeriodType: periodType,
PeriodStart: periodStart,
PeriodEnd: periodEnd,
TotalSalesCount: salesCount,
TotalSalesAmount: salesAmount,
Status: "active",
LastUpdatedAt: now,
Version: 1,
}
return s.statsStore.Create(ctx, stats)
}
return s.statsStore.IncrementSales(ctx, stats.ID, salesCount, salesAmount, stats.Version)
}
func (s *Service) ArchiveCompletedPeriod(ctx context.Context, allocationID uint, periodType string) error {
now := time.Now()
stats, err := s.statsStore.GetCurrent(ctx, allocationID, periodType, now)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil
}
return fmt.Errorf("查询统计数据失败: %w", err)
}
return s.statsStore.CompletePeriod(ctx, stats.ID)
}
func calculatePeriod(now time.Time, periodType string) (time.Time, time.Time) {
var periodStart, periodEnd time.Time
switch periodType {
case "monthly":
periodStart = time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
periodEnd = periodStart.AddDate(0, 1, 0).Add(-time.Second)
case "quarterly":
quarter := (int(now.Month()) - 1) / 3
periodStart = time.Date(now.Year(), time.Month(quarter*3+1), 1, 0, 0, 0, 0, now.Location())
periodEnd = periodStart.AddDate(0, 3, 0).Add(-time.Second)
case "yearly":
periodStart = time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
periodEnd = periodStart.AddDate(1, 0, 0).Add(-time.Second)
default:
periodStart = time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
periodEnd = periodStart.AddDate(0, 1, 0).Add(-time.Second)
}
return periodStart, periodEnd
}

View File

@@ -1,306 +0,0 @@
package my_package
import (
"context"
"fmt"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
)
type Service struct {
seriesAllocationStore *postgres.ShopSeriesAllocationStore
packageAllocationStore *postgres.ShopPackageAllocationStore
packageSeriesStore *postgres.PackageSeriesStore
packageStore *postgres.PackageStore
shopStore *postgres.ShopStore
}
func New(
seriesAllocationStore *postgres.ShopSeriesAllocationStore,
packageAllocationStore *postgres.ShopPackageAllocationStore,
packageSeriesStore *postgres.PackageSeriesStore,
packageStore *postgres.PackageStore,
shopStore *postgres.ShopStore,
) *Service {
return &Service{
seriesAllocationStore: seriesAllocationStore,
packageAllocationStore: packageAllocationStore,
packageSeriesStore: packageSeriesStore,
packageStore: packageStore,
shopStore: shopStore,
}
}
func (s *Service) ListMyPackages(ctx context.Context, req *dto.MyPackageListRequest) ([]*dto.MyPackageResponse, int64, error) {
shopID := middleware.GetShopIDFromContext(ctx)
if shopID == 0 {
return nil, 0, errors.New(errors.CodeUnauthorized, "当前用户不属于任何店铺")
}
seriesAllocations, err := s.seriesAllocationStore.GetByShopID(ctx, shopID)
if err != nil {
return nil, 0, fmt.Errorf("获取系列分配失败: %w", err)
}
if len(seriesAllocations) == 0 {
return []*dto.MyPackageResponse{}, 0, nil
}
seriesIDs := make([]uint, 0, len(seriesAllocations))
for _, sa := range seriesAllocations {
seriesIDs = append(seriesIDs, sa.SeriesID)
}
opts := &store.QueryOptions{
Page: req.Page,
PageSize: req.PageSize,
OrderBy: "id DESC",
}
if opts.Page == 0 {
opts.Page = 1
}
if opts.PageSize == 0 {
opts.PageSize = constants.DefaultPageSize
}
filters := make(map[string]interface{})
filters["series_ids"] = seriesIDs
filters["status"] = constants.StatusEnabled
filters["shelf_status"] = 1
if req.SeriesID != nil {
found := false
for _, sid := range seriesIDs {
if sid == *req.SeriesID {
found = true
break
}
}
if !found {
return []*dto.MyPackageResponse{}, 0, nil
}
filters["series_id"] = *req.SeriesID
}
if req.PackageType != nil {
filters["package_type"] = *req.PackageType
}
packages, total, err := s.packageStore.List(ctx, opts, filters)
if err != nil {
return nil, 0, fmt.Errorf("查询套餐列表失败: %w", err)
}
packageOverrides, _ := s.packageAllocationStore.GetByShopID(ctx, shopID)
overrideMap := make(map[uint]*model.ShopPackageAllocation)
for _, po := range packageOverrides {
overrideMap[po.PackageID] = po
}
allocationMap := make(map[uint]*model.ShopSeriesAllocation)
for _, sa := range seriesAllocations {
allocationMap[sa.SeriesID] = sa
}
responses := make([]*dto.MyPackageResponse, len(packages))
for i, pkg := range packages {
series, _ := s.packageSeriesStore.GetByID(ctx, pkg.SeriesID)
seriesName := ""
if series != nil {
seriesName = series.SeriesName
}
costPrice, priceSource := s.GetCostPrice(ctx, shopID, pkg, allocationMap, overrideMap)
responses[i] = &dto.MyPackageResponse{
ID: pkg.ID,
PackageCode: pkg.PackageCode,
PackageName: pkg.PackageName,
PackageType: pkg.PackageType,
SeriesID: pkg.SeriesID,
SeriesName: seriesName,
CostPrice: costPrice,
SuggestedRetailPrice: pkg.SuggestedRetailPrice,
ProfitMargin: pkg.SuggestedRetailPrice - costPrice,
PriceSource: priceSource,
Status: pkg.Status,
ShelfStatus: pkg.ShelfStatus,
}
}
return responses, total, nil
}
func (s *Service) GetMyPackage(ctx context.Context, packageID uint) (*dto.MyPackageDetailResponse, error) {
shopID := middleware.GetShopIDFromContext(ctx)
if shopID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "当前用户不属于任何店铺")
}
pkg, err := s.packageStore.GetByID(ctx, packageID)
if err != nil {
return nil, errors.New(errors.CodeNotFound, "套餐不存在")
}
seriesAllocation, err := s.seriesAllocationStore.GetByShopAndSeries(ctx, shopID, pkg.SeriesID)
if err != nil {
return nil, errors.New(errors.CodeForbidden, "您没有该套餐的销售权限")
}
series, _ := s.packageSeriesStore.GetByID(ctx, pkg.SeriesID)
seriesName := ""
if series != nil {
seriesName = series.SeriesName
}
allocationMap := map[uint]*model.ShopSeriesAllocation{pkg.SeriesID: seriesAllocation}
packageOverride, _ := s.packageAllocationStore.GetByShopAndPackage(ctx, shopID, packageID)
overrideMap := make(map[uint]*model.ShopPackageAllocation)
if packageOverride != nil {
overrideMap[packageID] = packageOverride
}
costPrice, priceSource := s.GetCostPrice(ctx, shopID, pkg, allocationMap, overrideMap)
return &dto.MyPackageDetailResponse{
ID: pkg.ID,
PackageCode: pkg.PackageCode,
PackageName: pkg.PackageName,
PackageType: pkg.PackageType,
Description: "",
SeriesID: pkg.SeriesID,
SeriesName: seriesName,
CostPrice: costPrice,
SuggestedRetailPrice: pkg.SuggestedRetailPrice,
ProfitMargin: pkg.SuggestedRetailPrice - costPrice,
PriceSource: priceSource,
Status: pkg.Status,
ShelfStatus: pkg.ShelfStatus,
}, nil
}
func (s *Service) ListMySeriesAllocations(ctx context.Context, req *dto.MySeriesAllocationListRequest) ([]*dto.MySeriesAllocationResponse, int64, error) {
shopID := middleware.GetShopIDFromContext(ctx)
if shopID == 0 {
return nil, 0, errors.New(errors.CodeUnauthorized, "当前用户不属于任何店铺")
}
allocations, err := s.seriesAllocationStore.GetByShopID(ctx, shopID)
if err != nil {
return nil, 0, fmt.Errorf("获取系列分配失败: %w", err)
}
total := int64(len(allocations))
page := req.Page
pageSize := req.PageSize
if page == 0 {
page = 1
}
if pageSize == 0 {
pageSize = constants.DefaultPageSize
}
start := (page - 1) * pageSize
end := start + pageSize
if start >= int(total) {
return []*dto.MySeriesAllocationResponse{}, total, nil
}
if end > int(total) {
end = int(total)
}
allocations = allocations[start:end]
responses := make([]*dto.MySeriesAllocationResponse, len(allocations))
for i, a := range allocations {
series, _ := s.packageSeriesStore.GetByID(ctx, a.SeriesID)
seriesCode := ""
seriesName := ""
if series != nil {
seriesCode = series.SeriesCode
seriesName = series.SeriesName
}
allocatorShop, _ := s.shopStore.GetByID(ctx, a.AllocatorShopID)
allocatorShopName := ""
if allocatorShop != nil {
allocatorShopName = allocatorShop.ShopName
}
availableCount := 0
filters := map[string]interface{}{
"series_id": a.SeriesID,
"status": constants.StatusEnabled,
"shelf_status": 1,
}
packages, _, _ := s.packageStore.List(ctx, &store.QueryOptions{Page: 1, PageSize: 1000}, filters)
availableCount = len(packages)
responses[i] = &dto.MySeriesAllocationResponse{
ID: a.ID,
SeriesID: a.SeriesID,
SeriesCode: seriesCode,
SeriesName: seriesName,
PricingMode: a.PricingMode,
PricingValue: a.PricingValue,
AvailablePackageCount: availableCount,
AllocatorShopName: allocatorShopName,
Status: a.Status,
}
}
return responses, total, nil
}
func (s *Service) GetCostPrice(ctx context.Context, shopID uint, pkg *model.Package, allocationMap map[uint]*model.ShopSeriesAllocation, overrideMap map[uint]*model.ShopPackageAllocation) (int64, string) {
if override, ok := overrideMap[pkg.ID]; ok && override.Status == constants.StatusEnabled {
return override.CostPrice, dto.PriceSourcePackageOverride
}
allocation, ok := allocationMap[pkg.SeriesID]
if !ok {
return 0, ""
}
parentCostPrice := s.getParentCostPriceRecursive(ctx, allocation.AllocatorShopID, pkg)
costPrice := s.calculateCostPrice(parentCostPrice, allocation.PricingMode, allocation.PricingValue)
return costPrice, dto.PriceSourceSeriesPricing
}
func (s *Service) getParentCostPriceRecursive(ctx context.Context, shopID uint, pkg *model.Package) int64 {
shop, err := s.shopStore.GetByID(ctx, shopID)
if err != nil {
return pkg.SuggestedCostPrice
}
if shop.ParentID == nil || *shop.ParentID == 0 {
return pkg.SuggestedCostPrice
}
allocation, err := s.seriesAllocationStore.GetByShopAndSeries(ctx, shopID, pkg.SeriesID)
if err != nil {
return pkg.SuggestedCostPrice
}
parentCostPrice := s.getParentCostPriceRecursive(ctx, allocation.AllocatorShopID, pkg)
return s.calculateCostPrice(parentCostPrice, allocation.PricingMode, allocation.PricingValue)
}
func (s *Service) calculateCostPrice(parentCostPrice int64, pricingMode string, pricingValue int64) int64 {
switch pricingMode {
case model.PricingModeFixed:
return parentCostPrice + pricingValue
case model.PricingModePercent:
return parentCostPrice + (parentCostPrice * pricingValue / 1000)
default:
return parentCostPrice
}
}

View File

@@ -1,820 +0,0 @@
package my_package
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/tests/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_GetCostPrice_Priority(t *testing.T) {
tx := testutils.NewTestTransaction(t)
ctx := context.Background()
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
// 创建测试数据:套餐系列
series := &model.PackageSeries{
SeriesCode: "TEST_SERIES_001",
SeriesName: "测试系列",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series))
// 创建测试数据:套餐
pkg := &model.Package{
PackageCode: "TEST_PKG_001",
PackageName: "测试套餐",
SeriesID: series.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000, // 基础成本价50元
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg))
// 创建测试数据:上级店铺
allocatorShop := &model.Shop{
ShopName: "上级店铺",
ShopCode: "ALLOCATOR_001",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
// 创建测试数据:下级店铺
shop := &model.Shop{
ShopName: "下级店铺",
ShopCode: "SHOP_001",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 创建测试数据:系列分配(系列加价模式)
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000, // 固定加价10元
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
t.Run("套餐覆盖优先级最高", func(t *testing.T) {
// 创建套餐覆盖覆盖成本价80元
packageOverride := &model.ShopPackageAllocation{
ShopID: shop.ID,
PackageID: pkg.ID,
AllocationID: seriesAllocation.ID,
CostPrice: 8000,
Status: constants.StatusEnabled,
}
require.NoError(t, packageAllocationStore.Create(ctx, packageOverride))
allocationMap := map[uint]*model.ShopSeriesAllocation{series.ID: seriesAllocation}
overrideMap := map[uint]*model.ShopPackageAllocation{pkg.ID: packageOverride}
costPrice, priceSource := svc.GetCostPrice(ctx, shop.ID, pkg, allocationMap, overrideMap)
// 应该返回套餐覆盖的成本价
assert.Equal(t, int64(8000), costPrice)
assert.Equal(t, dto.PriceSourcePackageOverride, priceSource)
})
t.Run("套餐覆盖禁用时使用系列加价", func(t *testing.T) {
pkg2 := &model.Package{
PackageCode: "TEST_PKG_001_DISABLED",
PackageName: "测试套餐禁用",
SeriesID: series.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000,
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg2))
packageOverride := &model.ShopPackageAllocation{
ShopID: shop.ID,
PackageID: pkg2.ID,
AllocationID: seriesAllocation.ID,
CostPrice: 8000,
Status: constants.StatusDisabled,
}
allocationMap := map[uint]*model.ShopSeriesAllocation{series.ID: seriesAllocation}
overrideMap := map[uint]*model.ShopPackageAllocation{pkg2.ID: packageOverride}
costPrice, priceSource := svc.GetCostPrice(ctx, shop.ID, pkg2, allocationMap, overrideMap)
assert.Equal(t, int64(6000), costPrice)
assert.Equal(t, dto.PriceSourceSeriesPricing, priceSource)
})
t.Run("无套餐覆盖时使用系列加价", func(t *testing.T) {
allocationMap := map[uint]*model.ShopSeriesAllocation{series.ID: seriesAllocation}
overrideMap := make(map[uint]*model.ShopPackageAllocation)
costPrice, priceSource := svc.GetCostPrice(ctx, shop.ID, pkg, allocationMap, overrideMap)
// 应该返回系列加价的成本价5000 + 1000 = 6000
assert.Equal(t, int64(6000), costPrice)
assert.Equal(t, dto.PriceSourceSeriesPricing, priceSource)
})
t.Run("无系列分配时返回0", func(t *testing.T) {
allocationMap := make(map[uint]*model.ShopSeriesAllocation)
overrideMap := make(map[uint]*model.ShopPackageAllocation)
costPrice, priceSource := svc.GetCostPrice(ctx, shop.ID, pkg, allocationMap, overrideMap)
// 应该返回0和空的价格来源
assert.Equal(t, int64(0), costPrice)
assert.Equal(t, "", priceSource)
})
}
func TestService_calculateCostPrice(t *testing.T) {
tx := testutils.NewTestTransaction(t)
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
tests := []struct {
name string
parentCostPrice int64
pricingMode string
pricingValue int64
expectedCostPrice int64
description string
}{
{
name: "固定金额加价模式",
parentCostPrice: 5000, // 50元
pricingMode: model.PricingModeFixed,
pricingValue: 1000, // 加价10元
expectedCostPrice: 6000, // 60元
description: "固定加价5000 + 1000 = 6000",
},
{
name: "百分比加价模式",
parentCostPrice: 5000, // 50元
pricingMode: model.PricingModePercent,
pricingValue: 200, // 20%千分比200/1000 = 20%
expectedCostPrice: 6000, // 50 + 50*20% = 60元
description: "百分比加价5000 + (5000 * 200 / 1000) = 6000",
},
{
name: "百分比加价模式-10%",
parentCostPrice: 10000, // 100元
pricingMode: model.PricingModePercent,
pricingValue: 100, // 10%千分比100/1000 = 10%
expectedCostPrice: 11000, // 100 + 100*10% = 110元
description: "百分比加价10000 + (10000 * 100 / 1000) = 11000",
},
{
name: "未知加价模式返回原价",
parentCostPrice: 5000,
pricingMode: "unknown",
pricingValue: 1000,
expectedCostPrice: 5000, // 返回原价不变
description: "未知模式:返回 parentCostPrice 不变",
},
{
name: "零加价",
parentCostPrice: 5000,
pricingMode: model.PricingModeFixed,
pricingValue: 0,
expectedCostPrice: 5000,
description: "零加价5000 + 0 = 5000",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
costPrice := svc.calculateCostPrice(tt.parentCostPrice, tt.pricingMode, tt.pricingValue)
assert.Equal(t, tt.expectedCostPrice, costPrice, tt.description)
})
}
}
func TestService_ListMyPackages_Authorization(t *testing.T) {
tx := testutils.NewTestTransaction(t)
ctx := context.Background()
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
t.Run("店铺ID为0时返回错误", func(t *testing.T) {
// 创建不包含店铺ID的context
ctxWithoutShop := context.WithValue(ctx, constants.ContextKeyShopID, uint(0))
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
}
packages, total, err := svc.ListMyPackages(ctxWithoutShop, req)
// 应该返回错误
require.Error(t, err)
assert.Nil(t, packages)
assert.Equal(t, int64(0), total)
assert.Contains(t, err.Error(), "当前用户不属于任何店铺")
})
t.Run("无系列分配时返回空列表", func(t *testing.T) {
// 创建店铺
shop := &model.Shop{
ShopName: "测试店铺",
ShopCode: "SHOP_TEST_001",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 创建包含店铺ID的context
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
// 应该返回空列表,无错误
require.NoError(t, err)
assert.NotNil(t, packages)
assert.Equal(t, 0, len(packages))
assert.Equal(t, int64(0), total)
})
t.Run("有系列分配时返回套餐列表", func(t *testing.T) {
// 创建套餐系列
series := &model.PackageSeries{
SeriesCode: "TEST_SERIES_002",
SeriesName: "测试系列2",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series))
// 创建套餐
pkg := &model.Package{
PackageCode: "TEST_PKG_002",
PackageName: "测试套餐2",
SeriesID: series.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000,
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg))
// 创建上级店铺
allocatorShop := &model.Shop{
ShopName: "上级店铺2",
ShopCode: "ALLOCATOR_002",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
// 创建下级店铺
shop := &model.Shop{
ShopName: "下级店铺2",
ShopCode: "SHOP_002",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 创建系列分配
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
// 创建包含店铺ID的context
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
// 应该返回套餐列表
require.NoError(t, err)
assert.NotNil(t, packages)
assert.Equal(t, 1, len(packages))
assert.Equal(t, int64(1), total)
assert.Equal(t, pkg.ID, packages[0].ID)
assert.Equal(t, pkg.PackageName, packages[0].PackageName)
// 验证成本价计算5000 + 1000 = 6000
assert.Equal(t, int64(6000), packages[0].CostPrice)
assert.Equal(t, dto.PriceSourceSeriesPricing, packages[0].PriceSource)
})
t.Run("分页参数默认值", func(t *testing.T) {
series := &model.PackageSeries{
SeriesCode: "TEST_SERIES_PAGING",
SeriesName: "分页测试系列",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series))
for i := range 5 {
pkg := &model.Package{
PackageCode: "TEST_PKG_PAGING_" + string(byte('0'+byte(i))),
PackageName: "分页测试套餐_" + string(byte('0'+byte(i))),
SeriesID: series.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000,
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg))
}
allocatorShop := &model.Shop{
ShopName: "分页上级店铺",
ShopCode: "ALLOCATOR_PAGING",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
shop := &model.Shop{
ShopName: "分页下级店铺",
ShopCode: "SHOP_PAGING",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
req := &dto.MyPackageListRequest{}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
require.NoError(t, err)
assert.NotNil(t, packages)
assert.GreaterOrEqual(t, total, int64(5))
assert.LessOrEqual(t, len(packages), constants.DefaultPageSize)
})
}
func TestService_ListMyPackages_Filtering(t *testing.T) {
tx := testutils.NewTestTransaction(t)
ctx := context.Background()
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
// 创建两个套餐系列
series1 := &model.PackageSeries{
SeriesCode: "SERIES_FILTER_001",
SeriesName: "系列1",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series1))
series2 := &model.PackageSeries{
SeriesCode: "SERIES_FILTER_002",
SeriesName: "系列2",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series2))
// 创建不同类型的套餐
pkg1 := &model.Package{
PackageCode: "PKG_FILTER_001",
PackageName: "正式套餐1",
SeriesID: series1.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000,
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg1))
pkg2 := &model.Package{
PackageCode: "PKG_FILTER_002",
PackageName: "附加套餐1",
SeriesID: series2.ID,
PackageType: "addon",
DurationMonths: 1,
DataType: "real",
RealDataMB: 512,
DataAmountMB: 512,
Price: 4900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 2500,
SuggestedRetailPrice: 4900,
}
require.NoError(t, packageStore.Create(ctx, pkg2))
// 创建上级店铺
allocatorShop := &model.Shop{
ShopName: "上级店铺过滤",
ShopCode: "ALLOCATOR_FILTER",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
// 创建下级店铺
shop := &model.Shop{
ShopName: "下级店铺过滤",
ShopCode: "SHOP_FILTER",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 为两个系列都创建分配
for _, series := range []*model.PackageSeries{series1, series2} {
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
}
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
t.Run("按系列ID过滤", func(t *testing.T) {
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
SeriesID: &series1.ID,
}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
require.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Equal(t, 1, len(packages))
assert.Equal(t, pkg1.ID, packages[0].ID)
})
t.Run("按套餐类型过滤", func(t *testing.T) {
packageType := "addon"
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
PackageType: &packageType,
}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
require.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Equal(t, 1, len(packages))
assert.Equal(t, pkg2.ID, packages[0].ID)
})
t.Run("无效的系列ID返回空列表", func(t *testing.T) {
invalidSeriesID := uint(99999)
req := &dto.MyPackageListRequest{
Page: 1,
PageSize: 20,
SeriesID: &invalidSeriesID,
}
packages, total, err := svc.ListMyPackages(ctxWithShop, req)
require.NoError(t, err)
assert.Equal(t, int64(0), total)
assert.Equal(t, 0, len(packages))
})
}
func TestService_GetMyPackage(t *testing.T) {
tx := testutils.NewTestTransaction(t)
ctx := context.Background()
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
// 创建套餐系列
series := &model.PackageSeries{
SeriesCode: "DETAIL_SERIES",
SeriesName: "详情系列",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series))
// 创建套餐
pkg := &model.Package{
PackageCode: "DETAIL_PKG",
PackageName: "详情套餐",
SeriesID: series.ID,
PackageType: "formal",
DurationMonths: 1,
DataType: "real",
RealDataMB: 1024,
DataAmountMB: 1024,
Price: 9900,
Status: constants.StatusEnabled,
ShelfStatus: 1,
SuggestedCostPrice: 5000,
SuggestedRetailPrice: 9900,
}
require.NoError(t, packageStore.Create(ctx, pkg))
// 创建上级店铺
allocatorShop := &model.Shop{
ShopName: "上级店铺详情",
ShopCode: "ALLOCATOR_DETAIL",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
// 创建下级店铺
shop := &model.Shop{
ShopName: "下级店铺详情",
ShopCode: "SHOP_DETAIL",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 创建系列分配
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
t.Run("店铺ID为0时返回错误", func(t *testing.T) {
ctxWithoutShop := context.WithValue(ctx, constants.ContextKeyShopID, uint(0))
_, err := svc.GetMyPackage(ctxWithoutShop, pkg.ID)
require.Error(t, err)
assert.Contains(t, err.Error(), "当前用户不属于任何店铺")
})
t.Run("成功获取套餐详情", func(t *testing.T) {
detail, err := svc.GetMyPackage(ctxWithShop, pkg.ID)
require.NoError(t, err)
assert.NotNil(t, detail)
assert.Equal(t, pkg.ID, detail.ID)
assert.Equal(t, pkg.PackageName, detail.PackageName)
assert.Equal(t, series.SeriesName, detail.SeriesName)
// 验证成本价5000 + 1000 = 6000
assert.Equal(t, int64(6000), detail.CostPrice)
assert.Equal(t, dto.PriceSourceSeriesPricing, detail.PriceSource)
})
t.Run("无权限访问套餐时返回错误", func(t *testing.T) {
// 创建另一个没有系列分配的店铺
otherShop := &model.Shop{
ShopName: "其他店铺",
ShopCode: "OTHER_SHOP",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000002",
}
require.NoError(t, shopStore.Create(ctx, otherShop))
ctxWithOtherShop := context.WithValue(ctx, constants.ContextKeyShopID, otherShop.ID)
_, err := svc.GetMyPackage(ctxWithOtherShop, pkg.ID)
require.Error(t, err)
assert.Contains(t, err.Error(), "您没有该套餐的销售权限")
})
}
func TestService_ListMySeriesAllocations(t *testing.T) {
tx := testutils.NewTestTransaction(t)
ctx := context.Background()
seriesAllocationStore := postgres.NewShopSeriesAllocationStore(tx)
packageAllocationStore := postgres.NewShopPackageAllocationStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
shopStore := postgres.NewShopStore(tx, nil)
// 创建 Service
svc := New(seriesAllocationStore, packageAllocationStore, packageSeriesStore, packageStore, shopStore)
t.Run("店铺ID为0时返回错误", func(t *testing.T) {
ctxWithoutShop := context.WithValue(ctx, constants.ContextKeyShopID, uint(0))
req := &dto.MySeriesAllocationListRequest{
Page: 1,
PageSize: 20,
}
_, _, err := svc.ListMySeriesAllocations(ctxWithoutShop, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "当前用户不属于任何店铺")
})
t.Run("无系列分配时返回空列表", func(t *testing.T) {
shop := &model.Shop{
ShopName: "分配测试店铺",
ShopCode: "ALLOC_SHOP",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, shop))
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
req := &dto.MySeriesAllocationListRequest{
Page: 1,
PageSize: 20,
}
allocations, total, err := svc.ListMySeriesAllocations(ctxWithShop, req)
require.NoError(t, err)
assert.NotNil(t, allocations)
assert.Equal(t, 0, len(allocations))
assert.Equal(t, int64(0), total)
})
t.Run("成功列表系列分配", func(t *testing.T) {
// 创建套餐系列
series := &model.PackageSeries{
SeriesCode: "ALLOC_SERIES",
SeriesName: "分配系列",
Status: constants.StatusEnabled,
}
require.NoError(t, packageSeriesStore.Create(ctx, series))
// 创建上级店铺
allocatorShop := &model.Shop{
ShopName: "分配者店铺",
ShopCode: "ALLOCATOR_ALLOC",
Status: constants.StatusEnabled,
Level: 1,
ContactName: "联系人",
ContactPhone: "13800000000",
}
require.NoError(t, shopStore.Create(ctx, allocatorShop))
// 创建下级店铺
shop := &model.Shop{
ShopName: "被分配店铺",
ShopCode: "ALLOCATED_SHOP",
Status: constants.StatusEnabled,
Level: 2,
ParentID: &allocatorShop.ID,
ContactName: "联系人",
ContactPhone: "13800000001",
}
require.NoError(t, shopStore.Create(ctx, shop))
// 创建系列分配
seriesAllocation := &model.ShopSeriesAllocation{
ShopID: shop.ID,
SeriesID: series.ID,
AllocatorShopID: allocatorShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
Status: constants.StatusEnabled,
}
require.NoError(t, seriesAllocationStore.Create(ctx, seriesAllocation))
ctxWithShop := context.WithValue(ctx, constants.ContextKeyShopID, shop.ID)
req := &dto.MySeriesAllocationListRequest{
Page: 1,
PageSize: 20,
}
allocations, total, err := svc.ListMySeriesAllocations(ctxWithShop, req)
require.NoError(t, err)
assert.NotNil(t, allocations)
assert.Equal(t, 1, len(allocations))
assert.Equal(t, int64(1), total)
assert.Equal(t, series.SeriesName, allocations[0].SeriesName)
assert.Equal(t, allocatorShop.ShopName, allocations[0].AllocatorShopName)
})
}

View File

@@ -17,14 +17,26 @@ import (
)
type Service struct {
packageStore *postgres.PackageStore
packageSeriesStore *postgres.PackageSeriesStore
packageStore *postgres.PackageStore
packageSeriesStore *postgres.PackageSeriesStore
packageAllocationStore *postgres.ShopPackageAllocationStore
seriesAllocationStore *postgres.ShopSeriesAllocationStore
commissionTierStore *postgres.ShopSeriesCommissionTierStore
}
func New(packageStore *postgres.PackageStore, packageSeriesStore *postgres.PackageSeriesStore) *Service {
func New(
packageStore *postgres.PackageStore,
packageSeriesStore *postgres.PackageSeriesStore,
packageAllocationStore *postgres.ShopPackageAllocationStore,
seriesAllocationStore *postgres.ShopSeriesAllocationStore,
commissionTierStore *postgres.ShopSeriesCommissionTierStore,
) *Service {
return &Service{
packageStore: packageStore,
packageSeriesStore: packageSeriesStore,
packageStore: packageStore,
packageSeriesStore: packageSeriesStore,
packageAllocationStore: packageAllocationStore,
seriesAllocationStore: seriesAllocationStore,
commissionTierStore: commissionTierStore,
}
}
@@ -39,14 +51,16 @@ func (s *Service) Create(ctx context.Context, req *dto.CreatePackageRequest) (*d
return nil, errors.New(errors.CodeConflict, "套餐编码已存在")
}
var seriesName *string
if req.SeriesID != nil && *req.SeriesID > 0 {
_, err := s.packageSeriesStore.GetByID(ctx, *req.SeriesID)
series, err := s.packageSeriesStore.GetByID(ctx, *req.SeriesID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "套餐系列不存在")
}
return nil, fmt.Errorf("获取套餐系列失败: %w", err)
}
seriesName = &series.SeriesName
}
pkg := &model.Package{
@@ -85,7 +99,9 @@ func (s *Service) Create(ctx context.Context, req *dto.CreatePackageRequest) (*d
return nil, fmt.Errorf("创建套餐失败: %w", err)
}
return s.toResponse(pkg), nil
resp := s.toResponse(ctx, pkg)
resp.SeriesName = seriesName
return resp, nil
}
func (s *Service) Get(ctx context.Context, id uint) (*dto.PackageResponse, error) {
@@ -96,7 +112,16 @@ func (s *Service) Get(ctx context.Context, id uint) (*dto.PackageResponse, error
}
return nil, fmt.Errorf("获取套餐失败: %w", err)
}
return s.toResponse(pkg), nil
resp := s.toResponse(ctx, pkg)
// 查询系列名称
if pkg.SeriesID > 0 {
series, err := s.packageSeriesStore.GetByID(ctx, pkg.SeriesID)
if err == nil {
resp.SeriesName = &series.SeriesName
}
}
return resp, nil
}
func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdatePackageRequest) (*dto.PackageResponse, error) {
@@ -113,8 +138,9 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdatePackageReq
return nil, fmt.Errorf("获取套餐失败: %w", err)
}
var seriesName *string
if req.SeriesID != nil && *req.SeriesID > 0 {
_, err := s.packageSeriesStore.GetByID(ctx, *req.SeriesID)
series, err := s.packageSeriesStore.GetByID(ctx, *req.SeriesID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "套餐系列不存在")
@@ -122,6 +148,13 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdatePackageReq
return nil, fmt.Errorf("获取套餐系列失败: %w", err)
}
pkg.SeriesID = *req.SeriesID
seriesName = &series.SeriesName
} else if pkg.SeriesID > 0 {
// 如果没有更新 SeriesID但现有套餐有 SeriesID则查询当前的系列名称
series, err := s.packageSeriesStore.GetByID(ctx, pkg.SeriesID)
if err == nil {
seriesName = &series.SeriesName
}
}
if req.PackageName != nil {
@@ -160,7 +193,9 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdatePackageReq
return nil, fmt.Errorf("更新套餐失败: %w", err)
}
return s.toResponse(pkg), nil
resp := s.toResponse(ctx, pkg)
resp.SeriesName = seriesName
return resp, nil
}
func (s *Service) Delete(ctx context.Context, id uint) error {
@@ -214,9 +249,40 @@ func (s *Service) List(ctx context.Context, req *dto.PackageListRequest) ([]*dto
return nil, 0, fmt.Errorf("查询套餐列表失败: %w", err)
}
// 收集所有唯一的 series_id
seriesIDMap := make(map[uint]bool)
for _, pkg := range packages {
if pkg.SeriesID > 0 {
seriesIDMap[pkg.SeriesID] = true
}
}
// 批量查询套餐系列
seriesMap := make(map[uint]string)
if len(seriesIDMap) > 0 {
seriesIDs := make([]uint, 0, len(seriesIDMap))
for id := range seriesIDMap {
seriesIDs = append(seriesIDs, id)
}
seriesList, err := s.packageSeriesStore.GetByIDs(ctx, seriesIDs)
if err != nil {
return nil, 0, fmt.Errorf("批量查询套餐系列失败: %w", err)
}
for _, series := range seriesList {
seriesMap[series.ID] = series.SeriesName
}
}
// 构建响应,填充系列名称
responses := make([]*dto.PackageResponse, len(packages))
for i, pkg := range packages {
responses[i] = s.toResponse(pkg)
resp := s.toResponse(ctx, pkg)
if pkg.SeriesID > 0 {
if seriesName, ok := seriesMap[pkg.SeriesID]; ok {
resp.SeriesName = &seriesName
}
}
responses[i] = resp
}
return responses, total, nil
@@ -278,12 +344,13 @@ func (s *Service) UpdateShelfStatus(ctx context.Context, id uint, shelfStatus in
return nil
}
func (s *Service) toResponse(pkg *model.Package) *dto.PackageResponse {
func (s *Service) toResponse(ctx context.Context, pkg *model.Package) *dto.PackageResponse {
var seriesID *uint
if pkg.SeriesID > 0 {
seriesID = &pkg.SeriesID
}
return &dto.PackageResponse{
resp := &dto.PackageResponse{
ID: pkg.ID,
PackageCode: pkg.PackageCode,
PackageName: pkg.PackageName,
@@ -302,4 +369,55 @@ func (s *Service) toResponse(pkg *model.Package) *dto.PackageResponse {
CreatedAt: pkg.CreatedAt.Format(time.RFC3339),
UpdatedAt: pkg.UpdatedAt.Format(time.RFC3339),
}
userType := middleware.GetUserTypeFromContext(ctx)
shopID := middleware.GetShopIDFromContext(ctx)
if userType == constants.UserTypeAgent && shopID > 0 {
allocation, err := s.packageAllocationStore.GetByShopAndPackage(ctx, shopID, pkg.ID)
if err == nil && allocation != nil {
resp.CostPrice = &allocation.CostPrice
profitMargin := pkg.SuggestedRetailPrice - allocation.CostPrice
resp.ProfitMargin = &profitMargin
commissionInfo := s.getCommissionInfo(ctx, allocation.AllocationID)
if commissionInfo != nil {
resp.CurrentCommissionRate = commissionInfo.CurrentRate
resp.TierInfo = commissionInfo
}
}
}
return resp
}
func (s *Service) getCommissionInfo(ctx context.Context, allocationID uint) *dto.CommissionTierInfo {
seriesAllocation, err := s.seriesAllocationStore.GetByID(ctx, allocationID)
if err != nil {
return nil
}
info := &dto.CommissionTierInfo{}
if seriesAllocation.BaseCommissionMode == constants.CommissionModeFixed {
info.CurrentRate = fmt.Sprintf("%.2f元/单", float64(seriesAllocation.BaseCommissionValue)/100)
} else {
info.CurrentRate = fmt.Sprintf("%.1f%%", float64(seriesAllocation.BaseCommissionValue)/10)
}
if seriesAllocation.EnableTierCommission {
tiers, err := s.commissionTierStore.ListByAllocationID(ctx, allocationID)
if err == nil && len(tiers) > 0 {
tier := tiers[0]
info.NextThreshold = &tier.ThresholdValue
if tier.CommissionMode == constants.CommissionModeFixed {
nextRate := fmt.Sprintf("%.2f元/单", float64(tier.CommissionValue)/100)
info.NextRate = nextRate
} else {
nextRate := fmt.Sprintf("%.1f%%", float64(tier.CommissionValue)/10)
info.NextRate = nextRate
}
}
}
return info
}

View File

@@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
@@ -24,7 +25,7 @@ func TestPackageService_Create(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -97,7 +98,7 @@ func TestPackageService_UpdateStatus(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -167,7 +168,7 @@ func TestPackageService_UpdateShelfStatus(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -254,7 +255,7 @@ func TestPackageService_Get(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -292,7 +293,7 @@ func TestPackageService_Update(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -341,7 +342,7 @@ func TestPackageService_Delete(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -376,7 +377,7 @@ func TestPackageService_List(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
@@ -454,3 +455,135 @@ func TestPackageService_List(t *testing.T) {
}
})
}
func TestPackageService_SeriesNameInResponse(t *testing.T) {
tx := testutils.NewTestTransaction(t)
packageStore := postgres.NewPackageStore(tx)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
svc := New(packageStore, packageSeriesStore, nil, nil, nil)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
// 创建套餐系列
series := &model.PackageSeries{
SeriesCode: fmt.Sprintf("SERIES_%d", time.Now().UnixNano()),
SeriesName: "测试套餐系列",
Description: "用于测试系列名称字段",
Status: constants.StatusEnabled,
}
series.Creator = 1
err := packageSeriesStore.Create(ctx, series)
require.NoError(t, err)
t.Run("创建套餐时返回系列名称", func(t *testing.T) {
req := &dto.CreatePackageRequest{
PackageCode: generateUniquePackageCode("PKG_SERIES"),
PackageName: "带系列的套餐",
SeriesID: &series.ID,
PackageType: "formal",
DurationMonths: 1,
Price: 9900,
}
resp, err := svc.Create(ctx, req)
require.NoError(t, err)
assert.NotNil(t, resp.SeriesName)
assert.Equal(t, series.SeriesName, *resp.SeriesName)
})
t.Run("获取套餐时返回系列名称", func(t *testing.T) {
// 先创建一个套餐
req := &dto.CreatePackageRequest{
PackageCode: generateUniquePackageCode("PKG_GET_SERIES"),
PackageName: "获取测试套餐",
SeriesID: &series.ID,
PackageType: "formal",
DurationMonths: 1,
Price: 9900,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
// 获取套餐
resp, err := svc.Get(ctx, created.ID)
require.NoError(t, err)
assert.NotNil(t, resp.SeriesName)
assert.Equal(t, series.SeriesName, *resp.SeriesName)
})
t.Run("更新套餐时返回系列名称", func(t *testing.T) {
// 先创建一个套餐
req := &dto.CreatePackageRequest{
PackageCode: generateUniquePackageCode("PKG_UPDATE_SERIES"),
PackageName: "更新测试套餐",
SeriesID: &series.ID,
PackageType: "formal",
DurationMonths: 1,
Price: 9900,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
// 更新套餐
newName := "更新后的套餐"
updateReq := &dto.UpdatePackageRequest{
PackageName: &newName,
}
resp, err := svc.Update(ctx, created.ID, updateReq)
require.NoError(t, err)
assert.NotNil(t, resp.SeriesName)
assert.Equal(t, series.SeriesName, *resp.SeriesName)
})
t.Run("列表查询时返回系列名称", func(t *testing.T) {
// 创建多个带系列的套餐
for i := 0; i < 3; i++ {
req := &dto.CreatePackageRequest{
PackageCode: generateUniquePackageCode(fmt.Sprintf("PKG_LIST_SERIES_%d", i)),
PackageName: fmt.Sprintf("列表测试套餐%d", i),
SeriesID: &series.ID,
PackageType: "formal",
DurationMonths: 1,
Price: 9900,
}
_, err := svc.Create(ctx, req)
require.NoError(t, err)
}
// 查询列表
listReq := &dto.PackageListRequest{
Page: 1,
PageSize: 10,
SeriesID: &series.ID,
}
resp, _, err := svc.List(ctx, listReq)
require.NoError(t, err)
assert.Greater(t, len(resp), 0)
// 验证所有套餐都有系列名称
for _, pkg := range resp {
if pkg.SeriesID != nil && *pkg.SeriesID == series.ID {
assert.NotNil(t, pkg.SeriesName)
assert.Equal(t, series.SeriesName, *pkg.SeriesName)
}
}
})
t.Run("没有系列的套餐SeriesName为空", func(t *testing.T) {
req := &dto.CreatePackageRequest{
PackageCode: generateUniquePackageCode("PKG_NO_SERIES"),
PackageName: "无系列套餐",
PackageType: "formal",
DurationMonths: 1,
Price: 9900,
}
resp, err := svc.Create(ctx, req)
require.NoError(t, err)
assert.Nil(t, resp.SeriesID)
assert.Nil(t, resp.SeriesName)
})
}

View File

@@ -18,6 +18,7 @@ import (
type Service struct {
packageAllocationStore *postgres.ShopPackageAllocationStore
seriesAllocationStore *postgres.ShopSeriesAllocationStore
priceHistoryStore *postgres.ShopPackageAllocationPriceHistoryStore
shopStore *postgres.ShopStore
packageStore *postgres.PackageStore
}
@@ -25,12 +26,14 @@ type Service struct {
func New(
packageAllocationStore *postgres.ShopPackageAllocationStore,
seriesAllocationStore *postgres.ShopSeriesAllocationStore,
priceHistoryStore *postgres.ShopPackageAllocationPriceHistoryStore,
shopStore *postgres.ShopStore,
packageStore *postgres.PackageStore,
) *Service {
return &Service{
packageAllocationStore: packageAllocationStore,
seriesAllocationStore: seriesAllocationStore,
priceHistoryStore: priceHistoryStore,
shopStore: shopStore,
packageStore: packageStore,
}
@@ -271,3 +274,76 @@ func (s *Service) buildResponse(ctx context.Context, a *model.ShopPackageAllocat
UpdatedAt: a.UpdatedAt.Format(time.RFC3339),
}, nil
}
func (s *Service) UpdateCostPrice(ctx context.Context, id uint, newCostPrice int64, changeReason string) (*dto.ShopPackageAllocationResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
allocation, err := s.packageAllocationStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "分配记录不存在")
}
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
if allocation.CostPrice == newCostPrice {
return nil, errors.New(errors.CodeInvalidParam, "新成本价与当前成本价相同")
}
oldCostPrice := allocation.CostPrice
now := time.Now()
priceHistory := &model.ShopPackageAllocationPriceHistory{
AllocationID: allocation.ID,
OldCostPrice: oldCostPrice,
NewCostPrice: newCostPrice,
ChangeReason: changeReason,
ChangedBy: currentUserID,
EffectiveFrom: now,
}
if err := s.priceHistoryStore.Create(ctx, priceHistory); err != nil {
return nil, fmt.Errorf("创建价格历史记录失败: %w", err)
}
allocation.CostPrice = newCostPrice
allocation.Updater = currentUserID
if err := s.packageAllocationStore.Update(ctx, allocation); err != nil {
return nil, fmt.Errorf("更新成本价失败: %w", err)
}
shop, _ := s.shopStore.GetByID(ctx, allocation.ShopID)
pkg, _ := s.packageStore.GetByID(ctx, allocation.PackageID)
shopName := ""
packageName := ""
packageCode := ""
if shop != nil {
shopName = shop.ShopName
}
if pkg != nil {
packageName = pkg.PackageName
packageCode = pkg.PackageCode
}
return s.buildResponse(ctx, allocation, shopName, packageName, packageCode)
}
func (s *Service) GetPriceHistory(ctx context.Context, allocationID uint) ([]*model.ShopPackageAllocationPriceHistory, error) {
_, err := s.packageAllocationStore.GetByID(ctx, allocationID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "分配记录不存在")
}
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
history, err := s.priceHistoryStore.ListByAllocation(ctx, allocationID)
if err != nil {
return nil, fmt.Errorf("获取价格历史失败: %w", err)
}
return history, nil
}

View File

@@ -0,0 +1,193 @@
package shop_package_batch_allocation
import (
"context"
"fmt"
"time"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"gorm.io/gorm"
)
type Service struct {
db *gorm.DB
packageStore *postgres.PackageStore
seriesAllocationStore *postgres.ShopSeriesAllocationStore
packageAllocationStore *postgres.ShopPackageAllocationStore
configStore *postgres.ShopSeriesAllocationConfigStore
commissionTierStore *postgres.ShopSeriesCommissionTierStore
commissionStatsStore *postgres.ShopSeriesCommissionStatsStore
shopStore *postgres.ShopStore
}
func New(
db *gorm.DB,
packageStore *postgres.PackageStore,
seriesAllocationStore *postgres.ShopSeriesAllocationStore,
packageAllocationStore *postgres.ShopPackageAllocationStore,
configStore *postgres.ShopSeriesAllocationConfigStore,
commissionTierStore *postgres.ShopSeriesCommissionTierStore,
commissionStatsStore *postgres.ShopSeriesCommissionStatsStore,
shopStore *postgres.ShopStore,
) *Service {
return &Service{
db: db,
packageStore: packageStore,
seriesAllocationStore: seriesAllocationStore,
packageAllocationStore: packageAllocationStore,
configStore: configStore,
commissionTierStore: commissionTierStore,
commissionStatsStore: commissionStatsStore,
shopStore: shopStore,
}
}
func (s *Service) BatchAllocate(ctx context.Context, req *dto.BatchAllocatePackagesRequest) error {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
userType := middleware.GetUserTypeFromContext(ctx)
allocatorShopID := middleware.GetShopIDFromContext(ctx)
if userType == constants.UserTypeAgent && allocatorShopID == 0 {
return errors.New(errors.CodeUnauthorized, "当前用户不属于任何店铺")
}
targetShop, err := s.shopStore.GetByID(ctx, req.ShopID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeNotFound, "目标店铺不存在")
}
return fmt.Errorf("获取目标店铺失败: %w", err)
}
if userType == constants.UserTypeAgent {
if targetShop.ParentID == nil || *targetShop.ParentID != allocatorShopID {
return errors.New(errors.CodeForbidden, "只能分配给直属下级店铺")
}
}
packages, err := s.getEnabledPackagesBySeries(ctx, req.SeriesID)
if err != nil {
return err
}
if len(packages) == 0 {
return errors.New(errors.CodeInvalidParam, "该系列下没有启用的套餐")
}
return s.db.Transaction(func(tx *gorm.DB) error {
seriesAllocation := &model.ShopSeriesAllocation{
BaseModel: model.BaseModel{Creator: currentUserID, Updater: currentUserID},
ShopID: req.ShopID,
SeriesID: req.SeriesID,
AllocatorShopID: allocatorShopID,
BaseCommissionMode: req.BaseCommission.Mode,
BaseCommissionValue: req.BaseCommission.Value,
EnableTierCommission: req.EnableTierCommission,
Status: constants.StatusEnabled,
}
if err := tx.Create(seriesAllocation).Error; err != nil {
return fmt.Errorf("创建系列分配失败: %w", err)
}
now := time.Now()
config := &model.ShopSeriesAllocationConfig{
AllocationID: seriesAllocation.ID,
Version: 1,
BaseCommissionMode: req.BaseCommission.Mode,
BaseCommissionValue: req.BaseCommission.Value,
EnableTierCommission: req.EnableTierCommission,
EffectiveFrom: now,
}
if err := tx.Create(config).Error; err != nil {
return fmt.Errorf("创建配置版本失败: %w", err)
}
packageAllocations := make([]*model.ShopPackageAllocation, 0, len(packages))
for _, pkg := range packages {
costPrice := pkg.SuggestedCostPrice
if req.PriceAdjustment != nil {
costPrice = s.calculateAdjustedPrice(pkg.SuggestedCostPrice, req.PriceAdjustment)
}
allocation := &model.ShopPackageAllocation{
BaseModel: model.BaseModel{Creator: currentUserID, Updater: currentUserID},
ShopID: req.ShopID,
PackageID: pkg.ID,
AllocationID: seriesAllocation.ID,
CostPrice: costPrice,
Status: constants.StatusEnabled,
}
packageAllocations = append(packageAllocations, allocation)
}
if err := tx.CreateInBatches(packageAllocations, 100).Error; err != nil {
return fmt.Errorf("批量创建套餐分配失败: %w", err)
}
if req.EnableTierCommission && req.TierConfig != nil {
if err := s.createCommissionTiers(tx, seriesAllocation.ID, req.TierConfig, currentUserID); err != nil {
return err
}
}
return nil
})
}
func (s *Service) getEnabledPackagesBySeries(ctx context.Context, seriesID uint) ([]*model.Package, error) {
filters := map[string]interface{}{
"series_id": seriesID,
"status": constants.StatusEnabled,
"shelf_status": 1,
}
packages, _, err := s.packageStore.List(ctx, nil, filters)
if err != nil {
return nil, fmt.Errorf("获取套餐列表失败: %w", err)
}
return packages, nil
}
func (s *Service) calculateAdjustedPrice(basePrice int64, adjustment *dto.PriceAdjustment) int64 {
if adjustment == nil {
return basePrice
}
if adjustment.Type == "fixed" {
return basePrice + adjustment.Value
}
return basePrice + (basePrice * adjustment.Value / 1000)
}
func (s *Service) createCommissionTiers(tx *gorm.DB, allocationID uint, config *dto.TierCommissionConfig, creatorID uint) error {
for _, tierReq := range config.Tiers {
tier := &model.ShopSeriesCommissionTier{
BaseModel: model.BaseModel{Creator: creatorID, Updater: creatorID},
AllocationID: allocationID,
PeriodType: config.PeriodType,
TierType: config.TierType,
ThresholdValue: tierReq.Threshold,
CommissionMode: tierReq.Mode,
CommissionValue: tierReq.Value,
}
if err := tx.Create(tier).Error; err != nil {
return fmt.Errorf("创建佣金梯度失败: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,129 @@
package shop_package_batch_pricing
import (
"context"
"fmt"
"time"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"gorm.io/gorm"
)
type Service struct {
db *gorm.DB
packageAllocationStore *postgres.ShopPackageAllocationStore
priceHistoryStore *postgres.ShopPackageAllocationPriceHistoryStore
shopStore *postgres.ShopStore
}
func New(
db *gorm.DB,
packageAllocationStore *postgres.ShopPackageAllocationStore,
priceHistoryStore *postgres.ShopPackageAllocationPriceHistoryStore,
shopStore *postgres.ShopStore,
) *Service {
return &Service{
db: db,
packageAllocationStore: packageAllocationStore,
priceHistoryStore: priceHistoryStore,
shopStore: shopStore,
}
}
func (s *Service) BatchUpdatePricing(ctx context.Context, req *dto.BatchUpdateCostPriceRequest) (*dto.BatchUpdateCostPriceResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
userType := middleware.GetUserTypeFromContext(ctx)
shopID := middleware.GetShopIDFromContext(ctx)
if userType == constants.UserTypeAgent && shopID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "当前用户不属于任何店铺")
}
filters := map[string]interface{}{
"shop_id": req.ShopID,
"status": constants.StatusEnabled,
}
if req.SeriesID != nil {
filters["series_id"] = *req.SeriesID
}
allocations, _, err := s.packageAllocationStore.List(ctx, nil, filters)
if err != nil {
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
if len(allocations) == 0 {
return nil, errors.New(errors.CodeInvalidParam, "没有找到符合条件的分配记录")
}
updatedCount := 0
now := time.Now()
affectedIDs := make([]uint, 0)
err = s.db.Transaction(func(tx *gorm.DB) error {
for _, allocation := range allocations {
oldPrice := allocation.CostPrice
newPrice := s.calculateAdjustedPrice(oldPrice, &req.PriceAdjustment)
if newPrice == oldPrice {
continue
}
history := &model.ShopPackageAllocationPriceHistory{
AllocationID: allocation.ID,
OldCostPrice: oldPrice,
NewCostPrice: newPrice,
ChangeReason: req.ChangeReason,
ChangedBy: currentUserID,
EffectiveFrom: now,
}
if err := tx.Create(history).Error; err != nil {
return fmt.Errorf("创建价格历史失败: %w", err)
}
allocation.CostPrice = newPrice
allocation.Updater = currentUserID
if err := tx.Save(allocation).Error; err != nil {
return fmt.Errorf("更新成本价失败: %w", err)
}
affectedIDs = append(affectedIDs, allocation.ID)
updatedCount++
}
return nil
})
if err != nil {
return nil, err
}
return &dto.BatchUpdateCostPriceResponse{
UpdatedCount: updatedCount,
AffectedIDs: affectedIDs,
}, nil
}
func (s *Service) calculateAdjustedPrice(basePrice int64, adjustment *dto.PriceAdjustment) int64 {
if adjustment == nil {
return basePrice
}
if adjustment.Type == "fixed" {
return basePrice + adjustment.Value
}
return basePrice + (basePrice * adjustment.Value / 1000)
}

View File

@@ -18,6 +18,7 @@ import (
type Service struct {
allocationStore *postgres.ShopSeriesAllocationStore
tierStore *postgres.ShopSeriesCommissionTierStore
configStore *postgres.ShopSeriesAllocationConfigStore
shopStore *postgres.ShopStore
packageSeriesStore *postgres.PackageSeriesStore
packageStore *postgres.PackageStore
@@ -26,6 +27,7 @@ type Service struct {
func New(
allocationStore *postgres.ShopSeriesAllocationStore,
tierStore *postgres.ShopSeriesCommissionTierStore,
configStore *postgres.ShopSeriesAllocationConfigStore,
shopStore *postgres.ShopStore,
packageSeriesStore *postgres.PackageSeriesStore,
packageStore *postgres.PackageStore,
@@ -33,6 +35,7 @@ func New(
return &Service{
allocationStore: allocationStore,
tierStore: tierStore,
configStore: configStore,
shopStore: shopStore,
packageSeriesStore: packageSeriesStore,
packageStore: packageStore,
@@ -97,15 +100,13 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateShopSeriesAllocatio
}
allocation := &model.ShopSeriesAllocation{
ShopID: req.ShopID,
SeriesID: req.SeriesID,
AllocatorShopID: allocatorShopID,
PricingMode: req.PricingMode,
PricingValue: req.PricingValue,
OneTimeCommissionTrigger: req.OneTimeCommissionTrigger,
OneTimeCommissionThreshold: req.OneTimeCommissionThreshold,
OneTimeCommissionAmount: req.OneTimeCommissionAmount,
Status: constants.StatusEnabled,
ShopID: req.ShopID,
SeriesID: req.SeriesID,
AllocatorShopID: allocatorShopID,
BaseCommissionMode: req.BaseCommission.Mode,
BaseCommissionValue: req.BaseCommission.Value,
EnableTierCommission: req.EnableTierCommission,
Status: constants.StatusEnabled,
}
allocation.Creator = currentUserID
@@ -154,23 +155,29 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateShopSeries
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
if req.PricingMode != nil {
allocation.PricingMode = *req.PricingMode
configChanged := false
if req.BaseCommission != nil {
if allocation.BaseCommissionMode != req.BaseCommission.Mode ||
allocation.BaseCommissionValue != req.BaseCommission.Value {
configChanged = true
}
allocation.BaseCommissionMode = req.BaseCommission.Mode
allocation.BaseCommissionValue = req.BaseCommission.Value
}
if req.PricingValue != nil {
allocation.PricingValue = *req.PricingValue
}
if req.OneTimeCommissionTrigger != nil {
allocation.OneTimeCommissionTrigger = *req.OneTimeCommissionTrigger
}
if req.OneTimeCommissionThreshold != nil {
allocation.OneTimeCommissionThreshold = *req.OneTimeCommissionThreshold
}
if req.OneTimeCommissionAmount != nil {
allocation.OneTimeCommissionAmount = *req.OneTimeCommissionAmount
if req.EnableTierCommission != nil {
if allocation.EnableTierCommission != *req.EnableTierCommission {
configChanged = true
}
allocation.EnableTierCommission = *req.EnableTierCommission
}
allocation.Updater = currentUserID
if configChanged {
if err := s.createNewConfigVersion(ctx, allocation); err != nil {
return nil, fmt.Errorf("创建配置版本失败: %w", err)
}
}
if err := s.allocationStore.Update(ctx, allocation); err != nil {
return nil, fmt.Errorf("更新分配失败: %w", err)
}
@@ -306,177 +313,7 @@ func (s *Service) GetParentCostPrice(ctx context.Context, shopID, packageID uint
return pkg.SuggestedCostPrice, nil
}
allocation, err := s.allocationStore.GetByShopAndSeries(ctx, shopID, pkg.SeriesID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return 0, errors.New(errors.CodeNotFound, "未找到分配记录")
}
return 0, fmt.Errorf("获取分配记录失败: %w", err)
}
parentCostPrice, err := s.GetParentCostPrice(ctx, allocation.AllocatorShopID, packageID)
if err != nil {
return 0, err
}
return s.CalculateCostPrice(parentCostPrice, allocation.PricingMode, allocation.PricingValue), nil
}
func (s *Service) CalculateCostPrice(parentCostPrice int64, pricingMode string, pricingValue int64) int64 {
switch pricingMode {
case model.PricingModeFixed:
return parentCostPrice + pricingValue
case model.PricingModePercent:
return parentCostPrice + (parentCostPrice * pricingValue / 1000)
default:
return parentCostPrice
}
}
func (s *Service) AddTier(ctx context.Context, allocationID uint, req *dto.CreateCommissionTierRequest) (*dto.CommissionTierResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
_, err := s.allocationStore.GetByID(ctx, allocationID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "分配记录不存在")
}
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
if req.PeriodType == model.PeriodTypeCustom {
if req.PeriodStartDate == nil || req.PeriodEndDate == nil {
return nil, errors.New(errors.CodeInvalidParam, "自定义周期必须指定开始和结束日期")
}
}
tier := &model.ShopSeriesCommissionTier{
AllocationID: allocationID,
TierType: req.TierType,
PeriodType: req.PeriodType,
ThresholdValue: req.ThresholdValue,
CommissionAmount: req.CommissionAmount,
}
tier.Creator = currentUserID
if req.PeriodStartDate != nil {
t, err := time.Parse("2006-01-02", *req.PeriodStartDate)
if err != nil {
return nil, errors.New(errors.CodeInvalidParam, "开始日期格式无效")
}
tier.PeriodStartDate = &t
}
if req.PeriodEndDate != nil {
t, err := time.Parse("2006-01-02", *req.PeriodEndDate)
if err != nil {
return nil, errors.New(errors.CodeInvalidParam, "结束日期格式无效")
}
tier.PeriodEndDate = &t
}
if err := s.tierStore.Create(ctx, tier); err != nil {
return nil, fmt.Errorf("创建梯度配置失败: %w", err)
}
return s.buildTierResponse(tier), nil
}
func (s *Service) UpdateTier(ctx context.Context, allocationID, tierID uint, req *dto.UpdateCommissionTierRequest) (*dto.CommissionTierResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
tier, err := s.tierStore.GetByID(ctx, tierID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "梯度配置不存在")
}
return nil, fmt.Errorf("获取梯度配置失败: %w", err)
}
if tier.AllocationID != allocationID {
return nil, errors.New(errors.CodeForbidden, "梯度配置不属于该分配")
}
if req.TierType != nil {
tier.TierType = *req.TierType
}
if req.PeriodType != nil {
tier.PeriodType = *req.PeriodType
}
if req.ThresholdValue != nil {
tier.ThresholdValue = *req.ThresholdValue
}
if req.CommissionAmount != nil {
tier.CommissionAmount = *req.CommissionAmount
}
if req.PeriodStartDate != nil {
t, err := time.Parse("2006-01-02", *req.PeriodStartDate)
if err != nil {
return nil, errors.New(errors.CodeInvalidParam, "开始日期格式无效")
}
tier.PeriodStartDate = &t
}
if req.PeriodEndDate != nil {
t, err := time.Parse("2006-01-02", *req.PeriodEndDate)
if err != nil {
return nil, errors.New(errors.CodeInvalidParam, "结束日期格式无效")
}
tier.PeriodEndDate = &t
}
tier.Updater = currentUserID
if err := s.tierStore.Update(ctx, tier); err != nil {
return nil, fmt.Errorf("更新梯度配置失败: %w", err)
}
return s.buildTierResponse(tier), nil
}
func (s *Service) DeleteTier(ctx context.Context, allocationID, tierID uint) error {
tier, err := s.tierStore.GetByID(ctx, tierID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeNotFound, "梯度配置不存在")
}
return fmt.Errorf("获取梯度配置失败: %w", err)
}
if tier.AllocationID != allocationID {
return errors.New(errors.CodeForbidden, "梯度配置不属于该分配")
}
if err := s.tierStore.Delete(ctx, tierID); err != nil {
return fmt.Errorf("删除梯度配置失败: %w", err)
}
return nil
}
func (s *Service) ListTiers(ctx context.Context, allocationID uint) ([]*dto.CommissionTierResponse, error) {
_, err := s.allocationStore.GetByID(ctx, allocationID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "分配记录不存在")
}
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
tiers, err := s.tierStore.ListByAllocationID(ctx, allocationID)
if err != nil {
return nil, fmt.Errorf("查询梯度配置失败: %w", err)
}
responses := make([]*dto.CommissionTierResponse, len(tiers))
for i, t := range tiers {
responses[i] = s.buildTierResponse(t)
}
return responses, nil
return 0, errors.New(errors.CodeInvalidParam, "自动计算成本价功能已移除,请手动设置成本价")
}
func (s *Service) buildResponse(ctx context.Context, a *model.ShopSeriesAllocation, shopName, seriesName string) (*dto.ShopSeriesAllocationResponse, error) {
@@ -486,46 +323,78 @@ func (s *Service) buildResponse(ctx context.Context, a *model.ShopSeriesAllocati
allocatorShopName = allocatorShop.ShopName
}
var calculatedCostPrice int64 = 0
return &dto.ShopSeriesAllocationResponse{
ID: a.ID,
ShopID: a.ShopID,
ShopName: shopName,
SeriesID: a.SeriesID,
SeriesName: seriesName,
AllocatorShopID: a.AllocatorShopID,
AllocatorShopName: allocatorShopName,
PricingMode: a.PricingMode,
PricingValue: a.PricingValue,
CalculatedCostPrice: calculatedCostPrice,
OneTimeCommissionTrigger: a.OneTimeCommissionTrigger,
OneTimeCommissionThreshold: a.OneTimeCommissionThreshold,
OneTimeCommissionAmount: a.OneTimeCommissionAmount,
Status: a.Status,
CreatedAt: a.CreatedAt.Format(time.RFC3339),
UpdatedAt: a.UpdatedAt.Format(time.RFC3339),
ID: a.ID,
ShopID: a.ShopID,
ShopName: shopName,
SeriesID: a.SeriesID,
SeriesName: seriesName,
AllocatorShopID: a.AllocatorShopID,
AllocatorShopName: allocatorShopName,
BaseCommission: dto.BaseCommissionConfig{
Mode: a.BaseCommissionMode,
Value: a.BaseCommissionValue,
},
EnableTierCommission: a.EnableTierCommission,
Status: a.Status,
CreatedAt: a.CreatedAt.Format(time.RFC3339),
UpdatedAt: a.UpdatedAt.Format(time.RFC3339),
}, nil
}
func (s *Service) buildTierResponse(t *model.ShopSeriesCommissionTier) *dto.CommissionTierResponse {
resp := &dto.CommissionTierResponse{
ID: t.ID,
AllocationID: t.AllocationID,
TierType: t.TierType,
PeriodType: t.PeriodType,
ThresholdValue: t.ThresholdValue,
CommissionAmount: t.CommissionAmount,
CreatedAt: t.CreatedAt.Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.Format(time.RFC3339),
func (s *Service) createNewConfigVersion(ctx context.Context, allocation *model.ShopSeriesAllocation) error {
now := time.Now()
if err := s.configStore.InvalidateCurrent(ctx, allocation.ID, now); err != nil {
return fmt.Errorf("失效当前配置版本失败: %w", err)
}
if t.PeriodStartDate != nil {
resp.PeriodStartDate = t.PeriodStartDate.Format("2006-01-02")
}
if t.PeriodEndDate != nil {
resp.PeriodEndDate = t.PeriodEndDate.Format("2006-01-02")
latestVersion, err := s.configStore.GetLatestVersion(ctx, allocation.ID)
newVersion := 1
if err == nil && latestVersion != nil {
newVersion = latestVersion.Version + 1
}
return resp
newConfig := &model.ShopSeriesAllocationConfig{
AllocationID: allocation.ID,
Version: newVersion,
BaseCommissionMode: allocation.BaseCommissionMode,
BaseCommissionValue: allocation.BaseCommissionValue,
EnableTierCommission: allocation.EnableTierCommission,
EffectiveFrom: now,
}
if err := s.configStore.Create(ctx, newConfig); err != nil {
return fmt.Errorf("创建新配置版本失败: %w", err)
}
return nil
}
func (s *Service) GetEffectiveConfig(ctx context.Context, allocationID uint, at time.Time) (*model.ShopSeriesAllocationConfig, error) {
config, err := s.configStore.GetEffective(ctx, allocationID, at)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "未找到生效的配置版本")
}
return nil, fmt.Errorf("获取生效配置失败: %w", err)
}
return config, nil
}
func (s *Service) ListConfigVersions(ctx context.Context, allocationID uint) ([]*model.ShopSeriesAllocationConfig, error) {
_, err := s.allocationStore.GetByID(ctx, allocationID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeNotFound, "分配记录不存在")
}
return nil, fmt.Errorf("获取分配记录失败: %w", err)
}
configs, err := s.configStore.List(ctx, allocationID)
if err != nil {
return nil, fmt.Errorf("获取配置版本列表失败: %w", err)
}
return configs, nil
}

View File

@@ -1,595 +0,0 @@
package shop_series_allocation
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/tests/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func createTestService(t *testing.T) (*Service, *postgres.ShopSeriesAllocationStore, *postgres.ShopStore, *postgres.PackageSeriesStore, *postgres.PackageStore, *postgres.ShopSeriesCommissionTierStore) {
tx := testutils.NewTestTransaction(t)
rdb := testutils.GetTestRedis(t)
testutils.CleanTestRedisKeys(t, rdb)
allocationStore := postgres.NewShopSeriesAllocationStore(tx)
tierStore := postgres.NewShopSeriesCommissionTierStore(tx)
shopStore := postgres.NewShopStore(tx, rdb)
packageSeriesStore := postgres.NewPackageSeriesStore(tx)
packageStore := postgres.NewPackageStore(tx)
svc := New(allocationStore, tierStore, shopStore, packageSeriesStore, packageStore)
return svc, allocationStore, shopStore, packageSeriesStore, packageStore, tierStore
}
func createContextWithUser(userID uint, userType int, shopID uint) context.Context {
ctx := context.Background()
info := &middleware.UserContextInfo{
UserID: userID,
UserType: userType,
ShopID: shopID,
}
return middleware.SetUserContext(ctx, info)
}
func createTestShop(t *testing.T, store *postgres.ShopStore, ctx context.Context, shopName string, parentID *uint) *model.Shop {
shop := &model.Shop{
ShopName: shopName,
ShopCode: shopName,
ParentID: parentID,
Status: constants.StatusEnabled,
}
shop.Creator = 1
err := store.Create(ctx, shop)
require.NoError(t, err)
return shop
}
func createTestSeries(t *testing.T, store *postgres.PackageSeriesStore, ctx context.Context, seriesName string) *model.PackageSeries {
series := &model.PackageSeries{
SeriesName: seriesName,
SeriesCode: seriesName,
Status: constants.StatusEnabled,
}
series.Creator = 1
err := store.Create(ctx, series)
require.NoError(t, err)
return series
}
func TestService_CalculateCostPrice(t *testing.T) {
svc, _, _, _, _, _ := createTestService(t)
tests := []struct {
name string
parentCostPrice int64
pricingMode string
pricingValue int64
expectedCostPrice int64
description string
}{
{
name: "固定加价模式10000 + 500 = 10500",
parentCostPrice: 10000,
pricingMode: model.PricingModeFixed,
pricingValue: 500,
expectedCostPrice: 10500,
description: "固定金额加价",
},
{
name: "百分比加价模式10000 + 10000*100/1000 = 11000",
parentCostPrice: 10000,
pricingMode: model.PricingModePercent,
pricingValue: 100,
expectedCostPrice: 11000,
description: "百分比加价100 = 10%",
},
{
name: "百分比加价模式5000 + 5000*50/1000 = 5250",
parentCostPrice: 5000,
pricingMode: model.PricingModePercent,
pricingValue: 50,
expectedCostPrice: 5250,
description: "百分比加价50 = 5%",
},
{
name: "未知加价模式:返回原价",
parentCostPrice: 10000,
pricingMode: "unknown",
pricingValue: 500,
expectedCostPrice: 10000,
description: "未知加价模式返回原价",
},
{
name: "固定加价为010000 + 0 = 10000",
parentCostPrice: 10000,
pricingMode: model.PricingModeFixed,
pricingValue: 0,
expectedCostPrice: 10000,
description: "固定加价为0",
},
{
name: "百分比加价为010000 + 0 = 10000",
parentCostPrice: 10000,
pricingMode: model.PricingModePercent,
pricingValue: 0,
expectedCostPrice: 10000,
description: "百分比加价为0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.CalculateCostPrice(tt.parentCostPrice, tt.pricingMode, tt.pricingValue)
assert.Equal(t, tt.expectedCostPrice, result, tt.description)
})
}
}
func TestService_Create_Validation(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop := createTestShop(t, shopStore, ctx, "二级代理", &parentShop.ID)
unrelatedShop := createTestShop(t, shopStore, ctx, "无关店铺", nil)
series := createTestSeries(t, seriesStore, ctx, "测试系列")
t.Run("未授权访问:无用户上下文", func(t *testing.T) {
emptyCtx := context.Background()
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop.ID,
SeriesID: series.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
}
_, err := svc.Create(emptyCtx, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeUnauthorized, appErr.Code)
})
t.Run("代理账号无店铺上下文", func(t *testing.T) {
ctxWithoutShop := createContextWithUser(1, constants.UserTypeAgent, 0)
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop.ID,
SeriesID: series.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
}
_, err := svc.Create(ctxWithoutShop, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeUnauthorized, appErr.Code)
})
t.Run("分配给非直属下级店铺", func(t *testing.T) {
ctxParent := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: unrelatedShop.ID,
SeriesID: series.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
}
_, err := svc.Create(ctxParent, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeForbidden, appErr.Code)
})
t.Run("代理账号无该系列分配权限", func(t *testing.T) {
ctxParent := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
series2 := createTestSeries(t, seriesStore, ctx, "测试系列2")
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop.ID,
SeriesID: series2.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
}
_, err := svc.Create(ctxParent, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeForbidden, appErr.Code)
})
t.Run("重复分配:同一店铺和系列已分配", func(t *testing.T) {
series3 := createTestSeries(t, seriesStore, ctx, "测试系列3")
childShop2 := createTestShop(t, shopStore, ctx, "二级代理2", &parentShop.ID)
ctxParent := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
parentAllocation := &model.ShopSeriesAllocation{
ShopID: parentShop.ID,
SeriesID: series3.ID,
AllocatorShopID: 0,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
parentAllocation.Creator = 1
err := allocationStore.Create(ctx, parentAllocation)
require.NoError(t, err)
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop2.ID,
SeriesID: series3.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
}
resp1, err := svc.Create(ctxParent, req)
require.NoError(t, err)
assert.NotNil(t, resp1)
_, err = svc.Create(ctxParent, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeConflict, appErr.Code)
})
t.Run("成功创建分配:代理有该系列权限", func(t *testing.T) {
series4 := createTestSeries(t, seriesStore, ctx, "测试系列4")
childShop3 := createTestShop(t, shopStore, ctx, "二级代理3", &parentShop.ID)
ctxParent := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
parentAllocation := &model.ShopSeriesAllocation{
ShopID: parentShop.ID,
SeriesID: series4.ID,
AllocatorShopID: 0,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
parentAllocation.Creator = 1
err := allocationStore.Create(ctx, parentAllocation)
require.NoError(t, err)
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop3.ID,
SeriesID: series4.ID,
PricingMode: model.PricingModePercent,
PricingValue: 100,
}
resp, err := svc.Create(ctxParent, req)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, childShop3.ID, resp.ShopID)
assert.Equal(t, series4.ID, resp.SeriesID)
assert.Equal(t, model.PricingModePercent, resp.PricingMode)
assert.Equal(t, int64(100), resp.PricingValue)
})
t.Run("平台用户需要有店铺上下文才能分配", func(t *testing.T) {
series5 := createTestSeries(t, seriesStore, ctx, "测试系列5")
childShop4 := createTestShop(t, shopStore, ctx, "二级代理4", &parentShop.ID)
ctxPlatform := createContextWithUser(2, constants.UserTypePlatform, 0)
req := &dto.CreateShopSeriesAllocationRequest{
ShopID: childShop4.ID,
SeriesID: series5.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 1000,
}
_, err := svc.Create(ctxPlatform, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeForbidden, appErr.Code)
})
}
func TestService_Delete_WithDependency(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop := createTestShop(t, shopStore, ctx, "二级代理", &parentShop.ID)
_ = createTestShop(t, shopStore, ctx, "三级代理", &childShop.ID)
series := createTestSeries(t, seriesStore, ctx, "测试系列")
t.Run("删除无依赖的分配成功", func(t *testing.T) {
allocation := &model.ShopSeriesAllocation{
ShopID: childShop.ID,
SeriesID: series.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation.Creator = 1
err := allocationStore.Create(ctx, allocation)
require.NoError(t, err)
err = svc.Delete(ctx, allocation.ID)
require.NoError(t, err)
_, err = allocationStore.GetByID(ctx, allocation.ID)
require.Error(t, err)
assert.Equal(t, gorm.ErrRecordNotFound, err)
})
t.Run("删除分配成功(无依赖关系)", func(t *testing.T) {
series2 := createTestSeries(t, seriesStore, ctx, "测试系列2")
allocation1 := &model.ShopSeriesAllocation{
ShopID: childShop.ID,
SeriesID: series2.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation1.Creator = 1
err := allocationStore.Create(ctx, allocation1)
require.NoError(t, err)
err = svc.Delete(ctx, allocation1.ID)
require.NoError(t, err)
_, err = allocationStore.GetByID(ctx, allocation1.ID)
require.Error(t, err)
assert.Equal(t, gorm.ErrRecordNotFound, err)
})
t.Run("删除不存在的分配返回错误", func(t *testing.T) {
err := svc.Delete(ctx, 99999)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeNotFound, appErr.Code)
})
}
func TestService_Get(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop := createTestShop(t, shopStore, ctx, "二级代理", &parentShop.ID)
series := createTestSeries(t, seriesStore, ctx, "测试系列")
allocation := &model.ShopSeriesAllocation{
ShopID: childShop.ID,
SeriesID: series.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation.Creator = 1
err := allocationStore.Create(ctx, allocation)
require.NoError(t, err)
t.Run("获取存在的分配", func(t *testing.T) {
resp, err := svc.Get(ctx, allocation.ID)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, allocation.ID, resp.ID)
assert.Equal(t, childShop.ID, resp.ShopID)
assert.Equal(t, series.ID, resp.SeriesID)
})
t.Run("获取不存在的分配", func(t *testing.T) {
_, err := svc.Get(ctx, 99999)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeNotFound, appErr.Code)
})
}
func TestService_Update(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop := createTestShop(t, shopStore, ctx, "二级代理", &parentShop.ID)
series := createTestSeries(t, seriesStore, ctx, "测试系列")
allocation := &model.ShopSeriesAllocation{
ShopID: childShop.ID,
SeriesID: series.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation.Creator = 1
err := allocationStore.Create(ctx, allocation)
require.NoError(t, err)
t.Run("更新加价模式和加价值", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
newMode := model.PricingModePercent
newValue := int64(100)
req := &dto.UpdateShopSeriesAllocationRequest{
PricingMode: &newMode,
PricingValue: &newValue,
}
resp, err := svc.Update(ctxWithUser, allocation.ID, req)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, model.PricingModePercent, resp.PricingMode)
assert.Equal(t, int64(100), resp.PricingValue)
})
t.Run("更新不存在的分配", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
newMode := model.PricingModeFixed
req := &dto.UpdateShopSeriesAllocationRequest{
PricingMode: &newMode,
}
_, err := svc.Update(ctxWithUser, 99999, req)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeNotFound, appErr.Code)
})
}
func TestService_UpdateStatus(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop := createTestShop(t, shopStore, ctx, "二级代理", &parentShop.ID)
series := createTestSeries(t, seriesStore, ctx, "测试系列")
allocation := &model.ShopSeriesAllocation{
ShopID: childShop.ID,
SeriesID: series.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation.Creator = 1
err := allocationStore.Create(ctx, allocation)
require.NoError(t, err)
t.Run("禁用分配", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
err := svc.UpdateStatus(ctxWithUser, allocation.ID, constants.StatusDisabled)
require.NoError(t, err)
updated, err := allocationStore.GetByID(ctx, allocation.ID)
require.NoError(t, err)
assert.Equal(t, constants.StatusDisabled, updated.Status)
})
t.Run("启用分配", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
err := svc.UpdateStatus(ctxWithUser, allocation.ID, constants.StatusEnabled)
require.NoError(t, err)
updated, err := allocationStore.GetByID(ctx, allocation.ID)
require.NoError(t, err)
assert.Equal(t, constants.StatusEnabled, updated.Status)
})
t.Run("更新不存在的分配状态", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
err := svc.UpdateStatus(ctxWithUser, 99999, constants.StatusDisabled)
require.Error(t, err)
appErr := err.(*errors.AppError)
assert.Equal(t, errors.CodeNotFound, appErr.Code)
})
}
func TestService_List(t *testing.T) {
svc, allocationStore, shopStore, seriesStore, _, _ := createTestService(t)
ctx := context.Background()
parentShop := createTestShop(t, shopStore, ctx, "一级代理", nil)
childShop1 := createTestShop(t, shopStore, ctx, "二级代理1", &parentShop.ID)
childShop2 := createTestShop(t, shopStore, ctx, "二级代理2", &parentShop.ID)
series1 := createTestSeries(t, seriesStore, ctx, "测试系列1")
series2 := createTestSeries(t, seriesStore, ctx, "测试系列2")
allocation1 := &model.ShopSeriesAllocation{
ShopID: childShop1.ID,
SeriesID: series1.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModeFixed,
PricingValue: 500,
Status: constants.StatusEnabled,
}
allocation1.Creator = 1
err := allocationStore.Create(ctx, allocation1)
require.NoError(t, err)
allocation2 := &model.ShopSeriesAllocation{
ShopID: childShop2.ID,
SeriesID: series2.ID,
AllocatorShopID: parentShop.ID,
PricingMode: model.PricingModePercent,
PricingValue: 100,
Status: constants.StatusEnabled,
}
allocation2.Creator = 1
err = allocationStore.Create(ctx, allocation2)
require.NoError(t, err)
t.Run("查询所有分配", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
req := &dto.ShopSeriesAllocationListRequest{
Page: 1,
PageSize: 20,
}
resp, total, err := svc.List(ctxWithUser, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(2))
assert.GreaterOrEqual(t, len(resp), 2)
})
t.Run("按店铺ID过滤", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
req := &dto.ShopSeriesAllocationListRequest{
Page: 1,
PageSize: 20,
ShopID: &childShop1.ID,
}
resp, total, err := svc.List(ctxWithUser, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(1))
for _, a := range resp {
assert.Equal(t, childShop1.ID, a.ShopID)
}
})
t.Run("按系列ID过滤", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
req := &dto.ShopSeriesAllocationListRequest{
Page: 1,
PageSize: 20,
SeriesID: &series1.ID,
}
resp, total, err := svc.List(ctxWithUser, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(1))
for _, a := range resp {
assert.Equal(t, series1.ID, a.SeriesID)
}
})
t.Run("按状态过滤", func(t *testing.T) {
ctxWithUser := createContextWithUser(1, constants.UserTypeAgent, parentShop.ID)
status := constants.StatusEnabled
req := &dto.ShopSeriesAllocationListRequest{
Page: 1,
PageSize: 20,
Status: &status,
}
resp, total, err := svc.List(ctxWithUser, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(2))
for _, a := range resp {
assert.Equal(t, constants.StatusEnabled, a.Status)
}
})
}