package auth

import (
	"context"
	stderrors "errors"
	"sort"
	"sync"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/internal/model/dto"
	"github.com/break/junhong_cmp_fiber/internal/store/postgres"
	"github.com/break/junhong_cmp_fiber/pkg/auth"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/errors"
	pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
	"go.uber.org/zap"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// Service implements login/logout, token refresh, password changes and
// permission/menu resolution on top of the account, role and permission
// stores and the token manager.
type Service struct {
	accountStore     *postgres.AccountStore
	accountRoleStore *postgres.AccountRoleStore
	rolePermStore    *postgres.RolePermissionStore
	permissionStore  *postgres.PermissionStore
	tokenManager     *auth.TokenManager
	logger           *zap.Logger
}

// New constructs a Service from its dependencies.
func New(
	accountStore *postgres.AccountStore,
	accountRoleStore *postgres.AccountRoleStore,
	rolePermStore *postgres.RolePermissionStore,
	permissionStore *postgres.PermissionStore,
	tokenManager *auth.TokenManager,
	logger *zap.Logger,
) *Service {
	return &Service{
		accountStore:     accountStore,
		accountRoleStore: accountRoleStore,
		rolePermStore:    rolePermStore,
		permissionStore:  permissionStore,
		tokenManager:     tokenManager,
		logger:           logger,
	}
}

// loginTimingHash is a throwaway bcrypt hash used only to equalize the cost
// of "unknown username" and "wrong password" login failures.
var (
	loginTimingHashOnce sync.Once
	loginTimingHash     []byte
)

// equalizeLoginTiming performs one bcrypt comparison against a dummy hash so
// that a login for a non-existent account takes roughly as long as one with
// a wrong password, preventing username enumeration via response latency.
// The result of the comparison is deliberately ignored.
func equalizeLoginTiming(password string) {
	loginTimingHashOnce.Do(func() {
		// Same cost as the real hashes produced by ChangePassword.
		h, err := bcrypt.GenerateFromPassword([]byte("login-timing-dummy"), bcrypt.DefaultCost)
		if err == nil {
			loginTimingHash = h
		}
	})
	if loginTimingHash != nil {
		_ = bcrypt.CompareHashAndPassword(loginTimingHash, []byte(password))
	}
}

// Login authenticates req.Username (looked up as username or phone) with
// req.Password and, on success, issues an access/refresh token pair bound to
// req.Device (defaulting to "web") and clientIP.
//
// Unknown usernames and wrong passwords both return CodeInvalidCredentials.
// An account whose Status is not 1 returns CodeAccountDisabled, but only after
// the password has been verified. Token generation errors are returned as-is.
// Failure to load permissions/menus is logged and degrades to empty lists
// instead of failing the login.
func (s *Service) Login(ctx context.Context, req *dto.LoginRequest, clientIP string) (*dto.LoginResponse, error) {
	// The account lookup must not be restricted by data-permission scoping
	// (presumably because no authenticated user exists yet at login time).
	ctx = pkgGorm.SkipDataPermission(ctx)

	account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username)
	if err != nil {
		if stderrors.Is(err, gorm.ErrRecordNotFound) {
			equalizeLoginTiming(req.Password)
			s.logger.Warn("登录失败:用户名不存在", zap.String("username", req.Username), zap.String("ip", clientIP))
			return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误")
		}
		return nil, errors.Wrap(errors.CodeInternalError, err, "查询账号失败")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(req.Password)); err != nil {
		s.logger.Warn("登录失败:密码错误", zap.String("username", req.Username), zap.String("ip", clientIP))
		return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误")
	}

	// Checked after the password so that callers without valid credentials
	// cannot learn whether an account is disabled.
	// NOTE(review): 1 is assumed to mean "enabled"; consider a named constant.
	if account.Status != 1 {
		s.logger.Warn("登录失败:账号已禁用", zap.String("username", req.Username), zap.Uint("user_id", account.ID))
		return nil, errors.New(errors.CodeAccountDisabled, "账号已禁用")
	}

	device := req.Device
	if device == "" {
		device = "web"
	}

	var shopID, enterpriseID uint
	if account.ShopID != nil {
		shopID = *account.ShopID
	}
	if account.EnterpriseID != nil {
		enterpriseID = *account.EnterpriseID
	}

	tokenInfo := &auth.TokenInfo{
		UserID:       account.ID,
		UserType:     account.UserType,
		ShopID:       shopID,
		EnterpriseID: enterpriseID,
		Username:     account.Username,
		Device:       device,
		IP:           clientIP,
	}

	accessToken, refreshToken, err := s.tokenManager.GenerateTokenPair(ctx, tokenInfo)
	if err != nil {
		return nil, err
	}

	permissions, menus, buttons, err := s.getUserPermissionsAndMenus(ctx, account.ID, account.UserType, device)
	if err != nil {
		// Best effort: the user is authenticated even if permission loading
		// fails, so log and fall back to empty (non-null) lists.
		s.logger.Error("查询用户权限失败", zap.Uint("user_id", account.ID), zap.Error(err))
		permissions = []string{}
		menus = []dto.MenuNode{}
		buttons = []string{}
	}

	userInfo := s.buildUserInfo(account)

	s.logger.Info("用户登录成功",
		zap.Uint("user_id", account.ID),
		zap.String("username", account.Username),
		zap.String("device", device),
		zap.String("ip", clientIP),
	)

	return &dto.LoginResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresIn:    int64(constants.DefaultAccessTokenTTL.Seconds()),
		User:         userInfo,
		Permissions:  permissions,
		Menus:        menus,
		Buttons:      buttons,
	}, nil
}

// Logout revokes accessToken and, when non-empty, refreshToken.
// A failure to revoke the access token is returned; a failure to revoke the
// refresh token is only logged.
func (s *Service) Logout(ctx context.Context, accessToken, refreshToken string) error {
	if err := s.tokenManager.RevokeToken(ctx, accessToken); err != nil {
		return err
	}

	if refreshToken != "" {
		if err := s.tokenManager.RevokeToken(ctx, refreshToken); err != nil {
			s.logger.Warn("撤销 refresh token 失败", zap.Error(err))
		}
	}

	return nil
}

// RefreshToken exchanges refreshToken for a new access token via the token
// manager.
func (s *Service) RefreshToken(ctx context.Context, refreshToken string) (string, error) {
	return s.tokenManager.RefreshAccessToken(ctx, refreshToken)
}

// GetCurrentUser returns the profile and enabled permission codes of userID.
// A missing account yields CodeAccountNotFound; a permission lookup failure
// is logged and degrades to an empty permission list.
//
// NOTE(review): unlike Login, super admins are not special-cased here, so they
// only receive permissions granted through their roles — confirm intent.
func (s *Service) GetCurrentUser(ctx context.Context, userID uint) (*dto.UserInfo, []string, error) {
	account, err := s.accountStore.GetByID(ctx, userID)
	if err != nil {
		if stderrors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil, errors.New(errors.CodeAccountNotFound, "账号不存在")
		}
		return nil, nil, errors.Wrap(errors.CodeInternalError, err, "查询账号失败")
	}

	permissions, err := s.getUserPermissions(ctx, userID)
	if err != nil {
		s.logger.Error("查询用户权限失败", zap.Uint("user_id", userID), zap.Error(err))
		permissions = []string{}
	}

	userInfo := s.buildUserInfo(account)
	return &userInfo, permissions, nil
}

// ChangePassword verifies oldPassword for userID, stores a bcrypt hash of
// newPassword (recording userID as the updater) and then revokes all of the
// user's tokens. Token revocation failure is logged, not returned, because
// the password change itself has already been committed.
func (s *Service) ChangePassword(ctx context.Context, userID uint, oldPassword, newPassword string) error {
	account, err := s.accountStore.GetByID(ctx, userID)
	if err != nil {
		if stderrors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New(errors.CodeAccountNotFound, "账号不存在")
		}
		return errors.Wrap(errors.CodeInternalError, err, "查询账号失败")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(oldPassword)); err != nil {
		return errors.New(errors.CodeInvalidOldPassword, "旧密码错误")
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return errors.Wrap(errors.CodeInternalError, err, "密码加密失败")
	}

	if err := s.accountStore.UpdatePassword(ctx, userID, string(hashedPassword), userID); err != nil {
		return errors.Wrap(errors.CodeInternalError, err, "更新密码失败")
	}

	if err := s.tokenManager.RevokeAllUserTokens(ctx, userID); err != nil {
		s.logger.Warn("撤销用户所有 token 失败", zap.Uint("user_id", userID), zap.Error(err))
	}

	s.logger.Info("用户修改密码成功", zap.Uint("user_id", userID))
	return nil
}

// loadRolePermissions resolves userID -> roles -> permission IDs -> permission
// records. It returns (nil, nil) when the user has no roles or the roles grant
// no permissions. Store errors are wrapped as CodeInternalError.
func (s *Service) loadRolePermissions(ctx context.Context, userID uint) ([]*model.Permission, error) {
	accountRoles, err := s.accountRoleStore.GetByAccountID(ctx, userID)
	if err != nil {
		return nil, errors.Wrap(errors.CodeInternalError, err, "查询用户角色失败")
	}
	if len(accountRoles) == 0 {
		return nil, nil
	}

	roleIDs := make([]uint, 0, len(accountRoles))
	for _, ar := range accountRoles {
		roleIDs = append(roleIDs, ar.RoleID)
	}

	permIDs, err := s.rolePermStore.GetPermIDsByRoleIDs(ctx, roleIDs)
	if err != nil {
		return nil, errors.Wrap(errors.CodeInternalError, err, "查询角色权限失败")
	}
	if len(permIDs) == 0 {
		return nil, nil
	}

	permissions, err := s.permissionStore.GetByIDs(ctx, permIDs)
	if err != nil {
		return nil, errors.Wrap(errors.CodeInternalError, err, "查询权限详情失败")
	}
	return permissions, nil
}

// getUserPermissions returns the codes of the enabled permissions granted to
// userID through its roles. The returned slice is never nil on success.
func (s *Service) getUserPermissions(ctx context.Context, userID uint) ([]string, error) {
	permissions, err := s.loadRolePermissions(ctx, userID)
	if err != nil {
		return nil, err
	}

	permCodes := make([]string, 0, len(permissions))
	for _, perm := range permissions {
		// Same rule as classifyPermissions: disabled permissions are not granted.
		if perm.Status != constants.StatusEnabled {
			continue
		}
		permCodes = append(permCodes, perm.PermCode)
	}
	return permCodes, nil
}

// buildUserInfo converts an account into its public DTO. Nil ShopID /
// EnterpriseID are reported as 0.
func (s *Service) buildUserInfo(account *model.Account) dto.UserInfo {
	userTypeName := s.getUserTypeName(account.UserType)

	var shopID, enterpriseID uint
	if account.ShopID != nil {
		shopID = *account.ShopID
	}
	if account.EnterpriseID != nil {
		enterpriseID = *account.EnterpriseID
	}

	return dto.UserInfo{
		ID:           account.ID,
		Username:     account.Username,
		Phone:        account.Phone,
		UserType:     account.UserType,
		UserTypeName: userTypeName,
		ShopID:       shopID,
		EnterpriseID: enterpriseID,
	}
}

// getUserTypeName maps a user type constant to its display name, or "未知"
// (unknown) for unrecognized values.
func (s *Service) getUserTypeName(userType int) string {
	switch userType {
	case constants.UserTypeSuperAdmin:
		return "超级管理员"
	case constants.UserTypePlatform:
		return "平台用户"
	case constants.UserTypeAgent:
		return "代理账号"
	case constants.UserTypeEnterprise:
		return "企业账号"
	default:
		return "未知"
	}
}

// getUserPermissionsAndMenus returns (all permission codes, menu tree, button
// codes) visible to the user on device. Super admins get every permission;
// other users get the permissions granted through their roles. On success the
// returned slices are never nil.
func (s *Service) getUserPermissionsAndMenus(ctx context.Context, userID uint, userType int, device string) ([]string, []dto.MenuNode, []string, error) {
	if userType == constants.UserTypeSuperAdmin {
		return s.getAllPermissionsForSuperAdmin(ctx, device)
	}

	permissions, err := s.loadRolePermissions(ctx, userID)
	if err != nil {
		return nil, nil, nil, err
	}
	// classifyPermissions yields empty (non-nil) results for a nil input.
	return s.classifyPermissions(permissions, device)
}

// getAllPermissionsForSuperAdmin loads every permission and classifies it
// for device.
func (s *Service) getAllPermissionsForSuperAdmin(ctx context.Context, device string) ([]string, []dto.MenuNode, []string, error) {
	permissions, err := s.permissionStore.GetAll(ctx, nil)
	if err != nil {
		return nil, nil, nil, errors.Wrap(errors.CodeInternalError, err, "查询所有权限失败")
	}

	return s.classifyPermissions(permissions, device)
}

// classifyPermissions keeps only enabled permissions whose Platform is
// PlatformAll or equals device, and splits them into: all permission codes,
// a menu tree built from menu-type permissions, and button-type codes.
// The returned slices are always non-nil so they serialize as [] not null.
func (s *Service) classifyPermissions(permissions []*model.Permission, device string) ([]string, []dto.MenuNode, []string, error) {
	allCodes := make([]string, 0, len(permissions))
	buttonCodes := make([]string, 0)
	var menuPerms []*model.Permission

	for _, perm := range permissions {
		if perm.Status != constants.StatusEnabled {
			continue
		}
		if perm.Platform != constants.PlatformAll && perm.Platform != device {
			continue
		}

		allCodes = append(allCodes, perm.PermCode)

		switch perm.PermType {
		case constants.PermissionTypeMenu:
			menuPerms = append(menuPerms, perm)
		case constants.PermissionTypeButton:
			buttonCodes = append(buttonCodes, perm.PermCode)
		}
	}

	return allCodes, s.buildMenuTree(menuPerms), buttonCodes, nil
}

// buildMenuTree assembles menu permissions into a forest ordered by Sort
// (then ID). Nodes without a parent (nil or 0) are roots; nodes whose parent
// is not in the input are logged as orphans and promoted to roots. Nodes that
// are only reachable through a parent cycle are not included. Duplicate IDs
// are collapsed (the last occurrence wins). Never returns nil.
func (s *Service) buildMenuTree(permissions []*model.Permission) []dto.MenuNode {
	if len(permissions) == 0 {
		return []dto.MenuNode{}
	}

	byID := make(map[uint]*model.Permission, len(permissions))
	for _, p := range permissions {
		byID[p.ID] = p
	}

	// Index children by parent ID once instead of rescanning every
	// permission for each node.
	childrenOf := make(map[uint][]*model.Permission, len(permissions))
	var rootPerms []*model.Permission
	for _, p := range permissions {
		if byID[p.ID] != p {
			continue // duplicate ID superseded by a later entry
		}
		if p.ParentID == nil || *p.ParentID == 0 {
			rootPerms = append(rootPerms, p)
			continue
		}
		if _, ok := byID[*p.ParentID]; !ok {
			s.logger.Warn("检测到孤儿节点",
				zap.Uint("child_id", p.ID),
				zap.String("perm_code", p.PermCode),
				zap.Uint("parent_id", *p.ParentID),
			)
			rootPerms = append(rootPerms, p)
			continue
		}
		childrenOf[*p.ParentID] = append(childrenOf[*p.ParentID], p)
	}

	roots := make([]dto.MenuNode, 0, len(rootPerms))
	for _, p := range rootPerms {
		roots = append(roots, s.buildNode(p, childrenOf))
	}

	s.sortMenuNodes(roots)
	return roots
}

// buildNode converts perm and, recursively, its children (looked up in the
// parent-ID index childrenOf) into a MenuNode. Children is always non-nil.
// Recursion terminates because each permission has a single parent and
// traversal only starts from roots, so no node is reached twice.
func (s *Service) buildNode(perm *model.Permission, childrenOf map[uint][]*model.Permission) dto.MenuNode {
	kids := childrenOf[perm.ID]
	node := dto.MenuNode{
		ID:       perm.ID,
		PermCode: perm.PermCode,
		Name:     perm.PermName,
		URL:      perm.URL,
		Sort:     perm.Sort,
		Children: make([]dto.MenuNode, 0, len(kids)),
	}

	for _, child := range kids {
		node.Children = append(node.Children, s.buildNode(child, childrenOf))
	}

	return node
}

// sortMenuNodes sorts nodes (and, recursively, their children) by ascending
// Sort, breaking ties by ascending ID so the order is deterministic.
func (s *Service) sortMenuNodes(nodes []dto.MenuNode) {
	sort.Slice(nodes, func(i, j int) bool {
		if nodes[i].Sort != nodes[j].Sort {
			return nodes[i].Sort < nodes[j].Sort
		}
		// sort.Slice is not stable; without this tie-break, siblings with
		// equal Sort would come back in a different order per request.
		return nodes[i].ID < nodes[j].ID
	})

	for i := range nodes {
		if len(nodes[i].Children) > 0 {
			s.sortMenuNodes(nodes[i].Children)
		}
	}
}