diff --git a/internal/store/postgres/recharge_store.go b/internal/store/postgres/recharge_store.go new file mode 100644 index 0000000..130ab1d --- /dev/null +++ b/internal/store/postgres/recharge_store.go @@ -0,0 +1,166 @@ +package postgres + +import ( + "context" + "time" + + "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +type RechargeStore struct { + db *gorm.DB + redis *redis.Client +} + +// NewRechargeStore 创建充值订单 Store 实例 +func NewRechargeStore(db *gorm.DB, redis *redis.Client) *RechargeStore { + return &RechargeStore{ + db: db, + redis: redis, + } +} + +// Create 创建充值订单 +func (s *RechargeStore) Create(ctx context.Context, recharge *model.RechargeRecord) error { + return s.db.WithContext(ctx).Create(recharge).Error +} + +// GetByRechargeNo 根据充值订单号查询充值订单 +// 不存在时返回 nil, nil +func (s *RechargeStore) GetByRechargeNo(ctx context.Context, rechargeNo string) (*model.RechargeRecord, error) { + var recharge model.RechargeRecord + err := s.db.WithContext(ctx).Where("recharge_no = ?", rechargeNo).First(&recharge).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } + return nil, err + } + return &recharge, nil +} + +// GetByID 根据 ID 查询充值订单 +func (s *RechargeStore) GetByID(ctx context.Context, id uint) (*model.RechargeRecord, error) { + var recharge model.RechargeRecord + if err := s.db.WithContext(ctx).First(&recharge, id).Error; err != nil { + return nil, err + } + return &recharge, nil +} + +// ListRechargeParams 充值订单列表查询参数 +type ListRechargeParams struct { + Page int // 页码(从 1 开始) + PageSize int // 每页数量 + UserID *uint // 用户 ID 筛选 + WalletID *uint // 钱包 ID 筛选 + Status *int // 状态筛选 + StartTime *time.Time // 开始时间 + EndTime *time.Time // 结束时间 +} + +// List 查询充值订单列表(支持分页和筛选) +func (s *RechargeStore) List(ctx context.Context, params *ListRechargeParams) ([]*model.RechargeRecord, int64, error) { + var recharges []*model.RechargeRecord + var total int64 + + query := s.db.WithContext(ctx).Model(&model.RechargeRecord{}) + + // 应用筛选条件 + if params.UserID != nil { + query = query.Where("user_id = ?", *params.UserID) + } + if params.WalletID != nil { + query = query.Where("wallet_id = ?", *params.WalletID) + } + if params.Status != nil { + query = query.Where("status = ?", *params.Status) + } + if params.StartTime != nil { + query = query.Where("created_at >= ?", *params.StartTime) + } + if params.EndTime != nil { + query = query.Where("created_at <= ?", *params.EndTime) + } + + // 统计总数 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // 分页查询 + page := params.Page + if page < 1 { + page = 1 + } + pageSize := params.PageSize + if pageSize < 1 { + pageSize = 20 + } + + offset := (page - 1) * pageSize + if err := query.Order("id DESC").Offset(offset).Limit(pageSize).Find(&recharges).Error; err != nil { + return nil, 0, err + } + + return recharges, total, nil +} + +// UpdateStatus 更新充值订单状态(支持乐观锁检查) +// oldStatus: 原状态(用于乐观锁检查,传 nil 则跳过检查) +// newStatus: 新状态 +// paidAt: 支付时间(状态变为已支付时传入) +// completedAt: 完成时间(状态变为已完成时传入) +func (s *RechargeStore) UpdateStatus(ctx context.Context, id uint, oldStatus *int, newStatus int, paidAt *time.Time, completedAt *time.Time) error { + updates := map[string]interface{}{ + "status": newStatus, + } + if paidAt != nil { + updates["paid_at"] = paidAt + } + if completedAt != nil { + updates["completed_at"] = completedAt + } + + query := s.db.WithContext(ctx).Model(&model.RechargeRecord{}).Where("id = ?", id) + + // 乐观锁检查 + if oldStatus != nil { + query = query.Where("status = ?", *oldStatus) + } + + result := query.Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil +} + +// UpdatePaymentInfo 更新支付信息 +func (s *RechargeStore) UpdatePaymentInfo(ctx context.Context, id uint, paymentChannel *string, paymentTransactionID *string) error { + updates := map[string]interface{}{} + if paymentChannel != nil { + updates["payment_channel"] = paymentChannel + } + if paymentTransactionID != nil { + updates["payment_transaction_id"] = paymentTransactionID + } + + if len(updates) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Model(&model.RechargeRecord{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil +} diff --git a/internal/store/postgres/recharge_store_test.go b/internal/store/postgres/recharge_store_test.go new file mode 100644 index 0000000..30ee2d1 --- /dev/null +++ b/internal/store/postgres/recharge_store_test.go @@ -0,0 +1,395 @@ +package postgres + +import ( + "context" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/break/junhong_cmp_fiber/tests/testutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRechargeStore_Create(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + recharge := &model.RechargeRecord{ + UserID: 100, + WalletID: 200, + RechargeNo: "RCH20260131120000000001", + Amount: 10000, + PaymentMethod: "wechat", + Status: 1, // 待支付 + } + + err := s.Create(ctx, recharge) + require.NoError(t, err) + assert.NotZero(t, recharge.ID) + assert.NotZero(t, recharge.CreatedAt) + assert.NotZero(t, recharge.UpdatedAt) +} + +func TestRechargeStore_GetByRechargeNo(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + rechargeNo := "RCH20260131120000000002" + recharge := &model.RechargeRecord{ + UserID: 101, + WalletID: 201, + RechargeNo: rechargeNo, + Amount: 20000, + PaymentMethod: "alipay", + Status: 1, + } + require.NoError(t, s.Create(ctx, recharge)) + + t.Run("查询存在的充值订单", func(t *testing.T) { + result, err := s.GetByRechargeNo(ctx, rechargeNo) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, recharge.ID, result.ID) + assert.Equal(t, recharge.UserID, result.UserID) + assert.Equal(t, recharge.Amount, result.Amount) + }) + + t.Run("查询不存在的充值订单返回 nil", func(t *testing.T) { + result, err := s.GetByRechargeNo(ctx, "NOT_EXISTS_RECHARGE_NO") + require.NoError(t, err) + assert.Nil(t, result) + }) +} + +func TestRechargeStore_GetByID(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + recharge := &model.RechargeRecord{ + UserID: 102, + WalletID: 202, + RechargeNo: "RCH20260131120000000003", + Amount: 30000, + PaymentMethod: "wechat", + Status: 2, // 已支付 + } + require.NoError(t, s.Create(ctx, recharge)) + + t.Run("查询存在的充值订单", func(t *testing.T) { + result, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + assert.Equal(t, recharge.RechargeNo, result.RechargeNo) + assert.Equal(t, recharge.Status, result.Status) + }) + + t.Run("查询不存在的充值订单", func(t *testing.T) { + _, err := s.GetByID(ctx, 99999) + require.Error(t, err) + }) +} + +func TestRechargeStore_List(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + // 创建测试数据 + now := time.Now() + yesterday := now.Add(-24 * time.Hour) + tomorrow := now.Add(24 * time.Hour) + + recharges := []*model.RechargeRecord{ + {UserID: 200, WalletID: 300, RechargeNo: "RCH20260131120000000010", Amount: 10000, PaymentMethod: "wechat", Status: 1}, + {UserID: 200, WalletID: 300, RechargeNo: "RCH20260131120000000011", Amount: 20000, PaymentMethod: "alipay", Status: 2}, + {UserID: 201, WalletID: 301, RechargeNo: "RCH20260131120000000012", Amount: 30000, PaymentMethod: "wechat", Status: 3}, + {UserID: 201, WalletID: 302, RechargeNo: "RCH20260131120000000013", Amount: 40000, PaymentMethod: "alipay", Status: 1}, + } + for _, r := range recharges { + require.NoError(t, s.Create(ctx, r)) + } + + t.Run("查询所有充值订单", func(t *testing.T) { + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(4)) + assert.GreaterOrEqual(t, len(result), 4) + }) + + t.Run("按用户 ID 筛选", func(t *testing.T) { + userID := uint(200) + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + UserID: &userID, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(2)) + for _, r := range result { + assert.Equal(t, uint(200), r.UserID) + } + }) + + t.Run("按钱包 ID 筛选", func(t *testing.T) { + walletID := uint(300) + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + WalletID: &walletID, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(2)) + for _, r := range result { + assert.Equal(t, uint(300), r.WalletID) + } + }) + + t.Run("按状态筛选", func(t *testing.T) { + status := 1 // 待支付 + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + Status: &status, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(2)) + for _, r := range result { + assert.Equal(t, 1, r.Status) + } + }) + + t.Run("按时间范围筛选", func(t *testing.T) { + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + StartTime: &yesterday, + EndTime: &tomorrow, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(4)) + for _, r := range result { + assert.True(t, r.CreatedAt.After(yesterday) || r.CreatedAt.Equal(yesterday)) + assert.True(t, r.CreatedAt.Before(tomorrow) || r.CreatedAt.Equal(tomorrow)) + } + }) + + t.Run("组合筛选条件", func(t *testing.T) { + userID := uint(201) + status := 1 + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + UserID: &userID, + Status: &status, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(1)) + for _, r := range result { + assert.Equal(t, uint(201), r.UserID) + assert.Equal(t, 1, r.Status) + } + }) + + t.Run("分页查询", func(t *testing.T) { + params := &ListRechargeParams{ + Page: 1, + PageSize: 2, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(4)) + assert.LessOrEqual(t, len(result), 2) + }) + + t.Run("默认分页参数", func(t *testing.T) { + params := &ListRechargeParams{ + Page: 0, // 无效值,应使用默认值 1 + PageSize: 0, // 无效值,应使用默认值 20 + } + result, _, err := s.List(ctx, params) + require.NoError(t, err) + assert.NotNil(t, result) + }) + + t.Run("按 ID 降序排列", func(t *testing.T) { + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + } + result, _, err := s.List(ctx, params) + require.NoError(t, err) + require.GreaterOrEqual(t, len(result), 2) + // 验证降序排列 + for i := 0; i < len(result)-1; i++ { + assert.GreaterOrEqual(t, result[i].ID, result[i+1].ID) + } + }) +} + +func TestRechargeStore_UpdateStatus(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + recharge := &model.RechargeRecord{ + UserID: 300, + WalletID: 400, + RechargeNo: "RCH20260131120000000020", + Amount: 50000, + PaymentMethod: "wechat", + Status: 1, // 待支付 + } + require.NoError(t, s.Create(ctx, recharge)) + + t.Run("更新状态为已支付(无乐观锁)", func(t *testing.T) { + now := time.Now() + err := s.UpdateStatus(ctx, recharge.ID, nil, 2, &now, nil) + require.NoError(t, err) + + updated, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + assert.Equal(t, 2, updated.Status) + assert.NotNil(t, updated.PaidAt) + }) + + t.Run("更新状态为已完成(带乐观锁)", func(t *testing.T) { + oldStatus := 2 + now := time.Now() + err := s.UpdateStatus(ctx, recharge.ID, &oldStatus, 3, nil, &now) + require.NoError(t, err) + + updated, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + assert.Equal(t, 3, updated.Status) + assert.NotNil(t, updated.CompletedAt) + }) + + t.Run("乐观锁检查失败", func(t *testing.T) { + oldStatus := 1 // 当前状态是 3,不是 1 + err := s.UpdateStatus(ctx, recharge.ID, &oldStatus, 4, nil, nil) + require.Error(t, err) + }) + + t.Run("更新不存在的充值订单", func(t *testing.T) { + err := s.UpdateStatus(ctx, 99999, nil, 2, nil, nil) + require.Error(t, err) + }) +} + +func TestRechargeStore_UpdatePaymentInfo(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + recharge := &model.RechargeRecord{ + UserID: 400, + WalletID: 500, + RechargeNo: "RCH20260131120000000030", + Amount: 60000, + PaymentMethod: "wechat", + Status: 1, + } + require.NoError(t, s.Create(ctx, recharge)) + + t.Run("更新支付渠道和交易号", func(t *testing.T) { + channel := "wechat_jsapi" + transactionID := "WX1234567890" + err := s.UpdatePaymentInfo(ctx, recharge.ID, &channel, &transactionID) + require.NoError(t, err) + + updated, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + require.NotNil(t, updated.PaymentChannel) + assert.Equal(t, "wechat_jsapi", *updated.PaymentChannel) + require.NotNil(t, updated.PaymentTransactionID) + assert.Equal(t, "WX1234567890", *updated.PaymentTransactionID) + }) + + t.Run("只更新支付渠道", func(t *testing.T) { + channel := "alipay_h5" + err := s.UpdatePaymentInfo(ctx, recharge.ID, &channel, nil) + require.NoError(t, err) + + updated, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + require.NotNil(t, updated.PaymentChannel) + assert.Equal(t, "alipay_h5", *updated.PaymentChannel) + }) + + t.Run("只更新交易号", func(t *testing.T) { + transactionID := "ALI9876543210" + err := s.UpdatePaymentInfo(ctx, recharge.ID, nil, &transactionID) + require.NoError(t, err) + + updated, err := s.GetByID(ctx, recharge.ID) + require.NoError(t, err) + require.NotNil(t, updated.PaymentTransactionID) + assert.Equal(t, "ALI9876543210", *updated.PaymentTransactionID) + }) + + t.Run("不更新任何字段", func(t *testing.T) { + err := s.UpdatePaymentInfo(ctx, recharge.ID, nil, nil) + require.NoError(t, err) + }) + + t.Run("更新不存在的充值订单", func(t *testing.T) { + channel := "test_channel" + err := s.UpdatePaymentInfo(ctx, 99999, &channel, nil) + require.Error(t, err) + }) +} + +func TestRechargeStore_ConcurrentOperations(t *testing.T) { + tx := testutils.NewTestTransaction(t) + rdb := testutils.GetTestRedis(t) + testutils.CleanTestRedisKeys(t, rdb) + s := NewRechargeStore(tx, rdb) + ctx := context.Background() + + // 创建多个充值订单 + for i := 0; i < 10; i++ { + recharge := &model.RechargeRecord{ + UserID: uint(500 + i), + WalletID: uint(600 + i), + RechargeNo: "RCH20260131120000000040" + string(rune('0'+i)), + Amount: int64(10000 * (i + 1)), + PaymentMethod: "wechat", + Status: 1, + } + require.NoError(t, s.Create(ctx, recharge)) + } + + // 验证查询 + params := &ListRechargeParams{ + Page: 1, + PageSize: 20, + } + result, total, err := s.List(ctx, params) + require.NoError(t, err) + assert.GreaterOrEqual(t, total, int64(10)) + assert.GreaterOrEqual(t, len(result), 10) +}