package unit import ( "context" "testing" "time" "github.com/bytedance/sonic" "github.com/hibiken/asynq" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/break/junhong_cmp_fiber/pkg/constants" ) // MockEmailPayload 邮件任务载荷(测试用) type MockEmailPayload struct { RequestID string `json:"request_id"` To string `json:"to"` Subject string `json:"subject"` Body string `json:"body"` CC []string `json:"cc,omitempty"` } // TestHandlerIdempotency 测试处理器幂等性逻辑 func TestHandlerIdempotency(t *testing.T) { redisClient := redis.NewClient(&redis.Options{ Addr: "localhost:6379", }) defer redisClient.Close() ctx := context.Background() redisClient.FlushDB(ctx) requestID := "test-req-001" lockKey := constants.RedisTaskLockKey(requestID) // 测试场景1: 第一次执行(未加锁) t.Run("First Execution - Should Acquire Lock", func(t *testing.T) { result, err := redisClient.SetNX(ctx, lockKey, "1", 24*time.Hour).Result() require.NoError(t, err) assert.True(t, result, "第一次执行应该成功获取锁") }) // 测试场景2: 重复执行(已加锁) t.Run("Duplicate Execution - Should Skip", func(t *testing.T) { result, err := redisClient.SetNX(ctx, lockKey, "1", 24*time.Hour).Result() require.NoError(t, err) assert.False(t, result, "重复执行应该跳过(锁已存在)") }) // 清理 redisClient.Del(ctx, lockKey) } // TestHandlerErrorHandling 测试处理器错误处理 func TestHandlerErrorHandling(t *testing.T) { tests := []struct { name string payload MockEmailPayload shouldError bool errorMsg string }{ { name: "Valid Payload", payload: MockEmailPayload{ RequestID: "valid-001", To: "test@example.com", Subject: "Test", Body: "Test Body", }, shouldError: false, }, { name: "Missing RequestID", payload: MockEmailPayload{ RequestID: "", To: "test@example.com", Subject: "Test", Body: "Test Body", }, shouldError: true, errorMsg: "request_id 不能为空", }, { name: "Missing To", payload: MockEmailPayload{ RequestID: "test-002", To: "", Subject: "Test", Body: "Test Body", }, shouldError: true, errorMsg: "收件人不能为空", }, { name: "Invalid Email Format", payload: MockEmailPayload{ RequestID: "test-003", To: "invalid-email", Subject: "Test", Body: "Test Body", }, shouldError: true, errorMsg: "邮箱格式无效", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 验证载荷 err := validateEmailPayload(&tt.payload) if tt.shouldError { require.Error(t, err) assert.Contains(t, err.Error(), tt.errorMsg) } else { require.NoError(t, err) } }) } } // validateEmailPayload 验证邮件载荷(模拟实际处理器中的验证逻辑) func validateEmailPayload(payload *MockEmailPayload) error { if payload.RequestID == "" { return asynq.SkipRetry // 参数错误不重试 } if payload.To == "" { return asynq.SkipRetry } // 简单的邮箱格式验证 if payload.To != "" && !contains(payload.To, "@") { return asynq.SkipRetry } return nil } func contains(s, substr string) bool { for i := 0; i < len(s)-len(substr)+1; i++ { if s[i:i+len(substr)] == substr { return true } } return false } // TestHandlerRetryLogic 测试重试逻辑 func TestHandlerRetryLogic(t *testing.T) { tests := []struct { name string error error shouldRetry bool }{ { name: "Retryable Error - Network Issue", error: assert.AnError, shouldRetry: true, }, { name: "Non-Retryable Error - Invalid Params", error: asynq.SkipRetry, shouldRetry: false, }, { name: "No Error", error: nil, shouldRetry: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { shouldRetry := tt.error != nil && tt.error != asynq.SkipRetry assert.Equal(t, tt.shouldRetry, shouldRetry) }) } } // TestPayloadDeserialization 测试载荷反序列化 func TestPayloadDeserialization(t *testing.T) { tests := []struct { name string jsonPayload string expectError bool }{ { name: "Valid JSON", jsonPayload: `{"request_id":"test-001","to":"test@example.com","subject":"Test","body":"Body"}`, expectError: false, }, { name: "Invalid JSON", jsonPayload: `{invalid json}`, expectError: true, }, { name: "Empty JSON", jsonPayload: `{}`, expectError: false, // JSON 解析成功,但验证会失败 }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var payload MockEmailPayload err := sonic.Unmarshal([]byte(tt.jsonPayload), &payload) if tt.expectError { require.Error(t, err) } else { require.NoError(t, err) } }) } } // TestTaskStatusTransition 测试任务状态转换 func TestTaskStatusTransition(t *testing.T) { redisClient := redis.NewClient(&redis.Options{ Addr: "localhost:6379", }) defer redisClient.Close() ctx := context.Background() redisClient.FlushDB(ctx) taskID := "task-transition-001" statusKey := constants.RedisTaskStatusKey(taskID) // 状态转换序列 transitions := []struct { status string valid bool }{ {"pending", true}, {"processing", true}, {"completed", true}, {"failed", false}, // completed 后不应该转到 failed } currentStatus := "" for _, tr := range transitions { t.Run("Transition to "+tr.status, func(t *testing.T) { // 检查状态转换是否合法 if isValidTransition(currentStatus, tr.status) == tr.valid { err := redisClient.Set(ctx, statusKey, tr.status, 7*24*time.Hour).Err() require.NoError(t, err) currentStatus = tr.status } else { // 不合法的转换应该被拒绝 assert.False(t, tr.valid) } }) } } // isValidTransition 检查状态转换是否合法 func isValidTransition(from, to string) bool { validTransitions := map[string][]string{ "": {"pending"}, "pending": {"processing", "failed"}, "processing": {"completed", "failed"}, "completed": {}, // 终态 "failed": {}, // 终态 } allowed, exists := validTransitions[from] if !exists { return false } for _, valid := range allowed { if valid == to { return true } } return false } // TestConcurrentTaskExecution 测试并发任务执行 func TestConcurrentTaskExecution(t *testing.T) { redisClient := redis.NewClient(&redis.Options{ Addr: "localhost:6379", }) defer redisClient.Close() ctx := context.Background() redisClient.FlushDB(ctx) // 模拟多个并发任务尝试获取同一个锁 requestID := "concurrent-test-001" lockKey := constants.RedisTaskLockKey(requestID) concurrency := 10 successCount := 0 done := make(chan bool, concurrency) // 并发执行 for i := 0; i < concurrency; i++ { go func() { result, err := redisClient.SetNX(ctx, lockKey, "1", 24*time.Hour).Result() if err == nil && result { successCount++ } done <- true }() } // 等待所有 goroutine 完成 for i := 0; i < concurrency; i++ { <-done } // 验证只有一个成功获取锁 assert.Equal(t, 1, successCount, "只有一个任务应该成功获取锁") } // TestTaskTimeout 测试任务超时处理 func TestTaskTimeout(t *testing.T) { tests := []struct { name string taskDuration time.Duration timeout time.Duration shouldTimeout bool }{ { name: "Normal Execution", taskDuration: 100 * time.Millisecond, timeout: 1 * time.Second, shouldTimeout: false, }, { name: "Timeout Execution", taskDuration: 2 * time.Second, timeout: 500 * time.Millisecond, shouldTimeout: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() // 模拟任务执行 done := make(chan bool) go func() { time.Sleep(tt.taskDuration) done <- true }() select { case <-done: assert.False(t, tt.shouldTimeout, "任务应该正常完成") case <-ctx.Done(): assert.True(t, tt.shouldTimeout, "任务应该超时") } }) } } // TestLockExpiration 测试锁过期机制 func TestLockExpiration(t *testing.T) { redisClient := redis.NewClient(&redis.Options{ Addr: "localhost:6379", }) defer redisClient.Close() ctx := context.Background() redisClient.FlushDB(ctx) requestID := "expiration-test-001" lockKey := constants.RedisTaskLockKey(requestID) // 设置短 TTL 的锁 result, err := redisClient.SetNX(ctx, lockKey, "1", 100*time.Millisecond).Result() require.NoError(t, err) assert.True(t, result) // 等待锁过期 time.Sleep(200 * time.Millisecond) // 验证锁已过期,可以重新获取 result, err = redisClient.SetNX(ctx, lockKey, "1", 24*time.Hour).Result() require.NoError(t, err) assert.True(t, result, "锁过期后应该可以重新获取") }