package auth import ( "context" "fmt" "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/auth" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) type Service struct { accountStore *postgres.AccountStore accountRoleStore *postgres.AccountRoleStore rolePermStore *postgres.RolePermissionStore permissionStore *postgres.PermissionStore tokenManager *auth.TokenManager logger *zap.Logger } func New( accountStore *postgres.AccountStore, accountRoleStore *postgres.AccountRoleStore, rolePermStore *postgres.RolePermissionStore, permissionStore *postgres.PermissionStore, tokenManager *auth.TokenManager, logger *zap.Logger, ) *Service { return &Service{ accountStore: accountStore, accountRoleStore: accountRoleStore, rolePermStore: rolePermStore, permissionStore: permissionStore, tokenManager: tokenManager, logger: logger, } } func (s *Service) Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) { ctx = pkgGorm.SkipDataPermission(ctx) account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username) if err != nil { if err == gorm.ErrRecordNotFound { s.logger.Warn("登录失败:用户名不存在", zap.String("username", req.Username), zap.String("ip", clientIP)) return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误") } return nil, errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err)) } if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(req.Password)); err != nil { s.logger.Warn("登录失败:密码错误", zap.String("username", req.Username), zap.String("ip", clientIP)) return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误") } if account.Status != 1 { s.logger.Warn("登录失败:账号已禁用", zap.String("username", req.Username), zap.Uint("user_id", account.ID)) return nil, errors.New(errors.CodeAccountDisabled, "账号已禁用") } device := req.Device if device == "" { device = "web" } var shopID, enterpriseID uint if account.ShopID != nil { shopID = *account.ShopID } if account.EnterpriseID != nil { enterpriseID = *account.EnterpriseID } tokenInfo := &auth.TokenInfo{ UserID: account.ID, UserType: account.UserType, ShopID: shopID, EnterpriseID: enterpriseID, Username: account.Username, Device: device, IP: clientIP, } accessToken, refreshToken, err := s.tokenManager.GenerateTokenPair(ctx, tokenInfo) if err != nil { return nil, err } permissions, err := s.getUserPermissions(ctx, account.ID) if err != nil { s.logger.Error("查询用户权限失败", zap.Uint("user_id", account.ID), zap.Error(err)) permissions = []string{} } userInfo := s.buildUserInfo(account) s.logger.Info("用户登录成功", zap.Uint("user_id", account.ID), zap.String("username", account.Username), zap.String("device", device), zap.String("ip", clientIP), ) return &model.LoginResponse{ AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: int64(constants.DefaultAccessTokenTTL.Seconds()), User: userInfo, Permissions: permissions, }, nil } func (s *Service) Logout(ctx context.Context, accessToken, refreshToken string) error { if err := s.tokenManager.RevokeToken(ctx, accessToken); err != nil { return err } if refreshToken != "" { if err := s.tokenManager.RevokeToken(ctx, refreshToken); err != nil { s.logger.Warn("撤销 refresh token 失败", zap.Error(err)) } } return nil } func (s *Service) RefreshToken(ctx context.Context, refreshToken string) (string, error) { return s.tokenManager.RefreshAccessToken(ctx, refreshToken) } func (s *Service) GetCurrentUser(ctx context.Context, userID uint) (*model.UserInfo, []string, error) { account, err := s.accountStore.GetByID(ctx, userID) if err != nil { if err == gorm.ErrRecordNotFound { return nil, nil, errors.New(errors.CodeAccountNotFound, "账号不存在") } return nil, nil, errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err)) } permissions, err := s.getUserPermissions(ctx, userID) if err != nil { s.logger.Error("查询用户权限失败", zap.Uint("user_id", userID), zap.Error(err)) permissions = []string{} } userInfo := s.buildUserInfo(account) return &userInfo, permissions, nil } func (s *Service) ChangePassword(ctx context.Context, userID uint, oldPassword, newPassword string) error { account, err := s.accountStore.GetByID(ctx, userID) if err != nil { if err == gorm.ErrRecordNotFound { return errors.New(errors.CodeAccountNotFound, "账号不存在") } return errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err)) } if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(oldPassword)); err != nil { return errors.New(errors.CodeInvalidOldPassword, "旧密码错误") } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %w", err) } if err := s.accountStore.UpdatePassword(ctx, userID, string(hashedPassword), userID); err != nil { return errors.New(errors.CodeDatabaseError, fmt.Sprintf("更新密码失败: %v", err)) } if err := s.tokenManager.RevokeAllUserTokens(ctx, userID); err != nil { s.logger.Warn("撤销用户所有 token 失败", zap.Uint("user_id", userID), zap.Error(err)) } s.logger.Info("用户修改密码成功", zap.Uint("user_id", userID)) return nil } func (s *Service) getUserPermissions(ctx context.Context, userID uint) ([]string, error) { accountRoles, err := s.accountRoleStore.GetByAccountID(ctx, userID) if err != nil { return nil, fmt.Errorf("failed to get account roles: %w", err) } if len(accountRoles) == 0 { return []string{}, nil } roleIDs := make([]uint, 0, len(accountRoles)) for _, ar := range accountRoles { roleIDs = append(roleIDs, ar.RoleID) } permIDs, err := s.rolePermStore.GetPermIDsByRoleIDs(ctx, roleIDs) if err != nil { return nil, fmt.Errorf("failed to get permission IDs: %w", err) } if len(permIDs) == 0 { return []string{}, nil } permissions, err := s.permissionStore.GetByIDs(ctx, permIDs) if err != nil { return nil, fmt.Errorf("failed to get permissions: %w", err) } permCodes := make([]string, 0, len(permissions)) for _, perm := range permissions { permCodes = append(permCodes, perm.PermCode) } return permCodes, nil } func (s *Service) buildUserInfo(account *model.Account) model.UserInfo { userTypeName := s.getUserTypeName(account.UserType) var shopID, enterpriseID uint if account.ShopID != nil { shopID = *account.ShopID } if account.EnterpriseID != nil { enterpriseID = *account.EnterpriseID } return model.UserInfo{ ID: account.ID, Username: account.Username, Phone: account.Phone, UserType: account.UserType, UserTypeName: userTypeName, ShopID: shopID, EnterpriseID: enterpriseID, } } func (s *Service) getUserTypeName(userType int) string { switch userType { case constants.UserTypeSuperAdmin: return "超级管理员" case constants.UserTypePlatform: return "平台用户" case constants.UserTypeAgent: return "代理账号" case constants.UserTypeEnterprise: return "企业账号" default: return "未知" } }