refactor: align framework cleanup with new bootstrap flow
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/response"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/validator"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -52,7 +52,16 @@ func setupAuthTestApp(t *testing.T, rdb *redis.Client) *fiber.App {
|
||||
|
||||
// Add authentication middleware
|
||||
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
|
||||
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
|
||||
app.Use(middleware.Auth(middleware.AuthConfig{
|
||||
TokenValidator: func(token string) (uint, int, uint, error) {
|
||||
_, err := tokenValidator.Validate(token)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
// 测试中简化处理:userID 设为 1,userType 设为普通用户
|
||||
return 1, 0, 0, nil
|
||||
},
|
||||
}))
|
||||
|
||||
// Add protected test routes
|
||||
app.Get("/api/v1/test", func(c *fiber.Ctx) error {
|
||||
@@ -342,14 +351,23 @@ func TestKeyAuthMiddleware_UserIDPropagation(t *testing.T) {
|
||||
|
||||
// Add authentication middleware
|
||||
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
|
||||
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
|
||||
app.Use(middleware.Auth(middleware.AuthConfig{
|
||||
TokenValidator: func(token string) (uint, int, uint, error) {
|
||||
_, err := tokenValidator.Validate(token)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
// 测试中简化处理:userID 设为 1,userType 设为普通用户
|
||||
return 1, 0, 0, nil
|
||||
},
|
||||
}))
|
||||
|
||||
// Add test route that checks user ID
|
||||
var capturedUserID string
|
||||
var capturedUserID uint
|
||||
app.Get("/api/v1/check-user", func(c *fiber.Ctx) error {
|
||||
userID, ok := c.Locals(constants.ContextKeyUserID).(string)
|
||||
userID, ok := c.Locals(constants.ContextKeyUserID).(uint)
|
||||
if !ok {
|
||||
return response.Error(c, 500, errors.CodeInternalError, "User ID not found in context")
|
||||
return errors.New(errors.CodeInternalError, "User ID not found in context")
|
||||
}
|
||||
capturedUserID = userID
|
||||
return response.Success(c, fiber.Map{
|
||||
|
||||
Reference in New Issue
Block a user