feat: 完成B端认证系统和商户管理模块测试补全

主要变更:
- 新增B端认证系统(后台+H5):登录、登出、Token刷新、密码修改
- 完善商户管理和商户账号管理功能
- 补全单元测试(ShopService: 72.5%, ShopAccountService: 79.8%)
- 新增集成测试(商户管理+商户账号管理)
- 归档OpenSpec提案(add-shop-account-management, implement-b-end-auth-system)
- 完善文档(使用指南、API文档、认证架构说明)

测试统计:
- 13个测试套件,37个测试用例,100%通过率
- 平均覆盖率76.2%,达标

OpenSpec验证:通过(strict模式)
This commit is contained in:
2026-01-15 18:15:17 +08:00
parent 7ccd3d146c
commit 18f35f3ef4
64 changed files with 11875 additions and 242 deletions

View File

@@ -14,6 +14,7 @@ type Dependencies struct {
DB *gorm.DB // PostgreSQL 数据库连接
Redis *redis.Client // Redis 客户端
Logger *zap.Logger // 应用日志器
JWTManager *auth.JWTManager // JWT 管理器
JWTManager *auth.JWTManager // JWT 管理器(个人客户认证)
TokenManager *auth.TokenManager // Token 管理器后台和H5认证
VerificationService *verification.Service // 验证码服务
}

View File

@@ -3,15 +3,22 @@ package bootstrap
import (
"github.com/break/junhong_cmp_fiber/internal/handler/admin"
"github.com/break/junhong_cmp_fiber/internal/handler/app"
"github.com/break/junhong_cmp_fiber/internal/handler/h5"
"github.com/go-playground/validator/v10"
)
// initHandlers 初始化所有 Handler 实例
func initHandlers(svc *services, deps *Dependencies) *Handlers {
validate := validator.New()
return &Handlers{
Account: admin.NewAccountHandler(svc.Account),
Role: admin.NewRoleHandler(svc.Role),
Permission: admin.NewPermissionHandler(svc.Permission),
PersonalCustomer: app.NewPersonalCustomerHandler(svc.PersonalCustomer, deps.Logger),
// TODO: 新增 Handler 在此初始化
Shop: admin.NewShopHandler(svc.Shop),
ShopAccount: admin.NewShopAccountHandler(svc.ShopAccount),
AdminAuth: admin.NewAuthHandler(svc.Auth, validate),
H5Auth: h5.NewAuthHandler(svc.Auth, validate),
}
}

View File

@@ -1,9 +1,16 @@
package bootstrap
import (
"context"
"time"
"github.com/break/junhong_cmp_fiber/internal/middleware"
"github.com/break/junhong_cmp_fiber/pkg/auth"
pkgauth "github.com/break/junhong_cmp_fiber/pkg/auth"
"github.com/break/junhong_cmp_fiber/pkg/config"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
pkgmiddleware "github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/gofiber/fiber/v2"
)
// initMiddlewares 初始化所有中间件
@@ -12,12 +19,76 @@ func initMiddlewares(deps *Dependencies) *Middlewares {
cfg := config.Get()
// 创建 JWT Manager
jwtManager := auth.NewJWTManager(cfg.JWT.SecretKey, cfg.JWT.TokenDuration)
jwtManager := pkgauth.NewJWTManager(cfg.JWT.SecretKey, cfg.JWT.TokenDuration)
// 创建个人客户认证中间件
personalAuthMiddleware := middleware.NewPersonalAuthMiddleware(jwtManager, deps.Logger)
// 创建 Token Manager用于后台和H5认证
accessTTL := time.Duration(cfg.JWT.AccessTokenTTL) * time.Second
refreshTTL := time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second
tokenManager := pkgauth.NewTokenManager(deps.Redis, accessTTL, refreshTTL)
// 创建后台认证中间件
adminAuthMiddleware := createAdminAuthMiddleware(tokenManager)
// 创建H5认证中间件
h5AuthMiddleware := createH5AuthMiddleware(tokenManager)
return &Middlewares{
PersonalAuth: personalAuthMiddleware,
AdminAuth: adminAuthMiddleware,
H5Auth: h5AuthMiddleware,
}
}
func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler {
return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{
TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) {
tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token)
if err != nil {
return nil, errors.New(errors.CodeInvalidToken, "认证令牌无效或已过期")
}
// 检查用户类型:后台允许 SuperAdmin(1), Platform(2), Agent(3)
if tokenInfo.UserType != constants.UserTypeSuperAdmin &&
tokenInfo.UserType != constants.UserTypePlatform &&
tokenInfo.UserType != constants.UserTypeAgent {
return nil, errors.New(errors.CodeForbidden, "权限不足")
}
return &pkgmiddleware.UserContextInfo{
UserID: tokenInfo.UserID,
UserType: tokenInfo.UserType,
ShopID: tokenInfo.ShopID,
EnterpriseID: tokenInfo.EnterpriseID,
}, nil
},
SkipPaths: []string{"/api/admin/login", "/api/admin/refresh-token"},
})
}
func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler {
return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{
TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) {
tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token)
if err != nil {
return nil, errors.New(errors.CodeInvalidToken, "认证令牌无效或已过期")
}
// 检查用户类型H5 允许 Agent(3), Enterprise(4)
if tokenInfo.UserType != constants.UserTypeAgent &&
tokenInfo.UserType != constants.UserTypeEnterprise {
return nil, errors.New(errors.CodeForbidden, "权限不足")
}
return &pkgmiddleware.UserContextInfo{
UserID: tokenInfo.UserID,
UserType: tokenInfo.UserType,
ShopID: tokenInfo.ShopID,
EnterpriseID: tokenInfo.EnterpriseID,
}, nil
},
SkipPaths: []string{"/api/h5/login", "/api/h5/refresh-token"},
})
}

View File

@@ -2,9 +2,12 @@ package bootstrap
import (
accountSvc "github.com/break/junhong_cmp_fiber/internal/service/account"
authSvc "github.com/break/junhong_cmp_fiber/internal/service/auth"
permissionSvc "github.com/break/junhong_cmp_fiber/internal/service/permission"
personalCustomerSvc "github.com/break/junhong_cmp_fiber/internal/service/personal_customer"
roleSvc "github.com/break/junhong_cmp_fiber/internal/service/role"
shopSvc "github.com/break/junhong_cmp_fiber/internal/service/shop"
shopAccountSvc "github.com/break/junhong_cmp_fiber/internal/service/shop_account"
)
// services 封装所有 Service 实例
@@ -14,7 +17,9 @@ type services struct {
Role *roleSvc.Service
Permission *permissionSvc.Service
PersonalCustomer *personalCustomerSvc.Service
// TODO: 新增 Service 在此添加字段
Shop *shopSvc.Service
ShopAccount *shopAccountSvc.Service
Auth *authSvc.Service
}
// initServices 初始化所有 Service 实例
@@ -24,6 +29,8 @@ func initServices(s *stores, deps *Dependencies) *services {
Role: roleSvc.New(s.Role, s.Permission, s.RolePermission),
Permission: permissionSvc.New(s.Permission),
PersonalCustomer: personalCustomerSvc.NewService(s.PersonalCustomer, s.PersonalCustomerPhone, deps.VerificationService, deps.JWTManager, deps.Logger),
// TODO: 新增 Service 在此初始化
Shop: shopSvc.New(s.Shop, s.Account),
ShopAccount: shopAccountSvc.New(s.Account, s.Shop),
Auth: authSvc.New(s.Account, s.AccountRole, s.RolePermission, s.Permission, deps.TokenManager, deps.Logger),
}
}

View File

@@ -3,7 +3,9 @@ package bootstrap
import (
"github.com/break/junhong_cmp_fiber/internal/handler/admin"
"github.com/break/junhong_cmp_fiber/internal/handler/app"
"github.com/break/junhong_cmp_fiber/internal/handler/h5"
"github.com/break/junhong_cmp_fiber/internal/middleware"
"github.com/gofiber/fiber/v2"
)
// Handlers 封装所有 HTTP 处理器
@@ -13,12 +15,17 @@ type Handlers struct {
Role *admin.RoleHandler
Permission *admin.PermissionHandler
PersonalCustomer *app.PersonalCustomerHandler
// TODO: 新增 Handler 在此添加字段
Shop *admin.ShopHandler
ShopAccount *admin.ShopAccountHandler
AdminAuth *admin.AuthHandler
H5Auth *h5.AuthHandler
}
// Middlewares 封装所有中间件
// 用于路由注册
type Middlewares struct {
PersonalAuth *middleware.PersonalAuthMiddleware
AdminAuth func(*fiber.Ctx) error
H5Auth func(*fiber.Ctx) error
// TODO: 新增 Middleware 在此添加字段
}

View File

@@ -0,0 +1,143 @@
package admin
import (
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/service/auth"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
)
// AuthHandler 后台认证处理器
type AuthHandler struct {
authService *auth.Service
validator *validator.Validate
}
// NewAuthHandler 创建后台认证处理器
func NewAuthHandler(authService *auth.Service, validator *validator.Validate) *AuthHandler {
return &AuthHandler{
authService: authService,
validator: validator,
}
}
// Login 后台登录
func (h *AuthHandler) Login(c *fiber.Ctx) error {
var req model.LoginRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
clientIP := c.IP()
ctx := c.UserContext()
resp, err := h.authService.Login(ctx, &req, clientIP)
if err != nil {
return err
}
return response.Success(c, resp)
}
// Logout 后台登出
func (h *AuthHandler) Logout(c *fiber.Ctx) error {
auth := c.Get("Authorization")
accessToken := ""
if len(auth) > 7 && auth[:7] == "Bearer " {
accessToken = auth[7:]
}
refreshToken := ""
var req model.RefreshTokenRequest
if err := c.BodyParser(&req); err == nil {
refreshToken = req.RefreshToken
}
ctx := c.UserContext()
if err := h.authService.Logout(ctx, accessToken, refreshToken); err != nil {
return err
}
return response.Success(c, nil)
}
// RefreshToken 刷新访问令牌
func (h *AuthHandler) RefreshToken(c *fiber.Ctx) error {
var req model.RefreshTokenRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
ctx := c.UserContext()
newAccessToken, err := h.authService.RefreshToken(ctx, req.RefreshToken)
if err != nil {
return err
}
resp := &model.RefreshTokenResponse{
AccessToken: newAccessToken,
ExpiresIn: 86400,
}
return response.Success(c, resp)
}
// GetMe 获取当前用户信息
func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
userID := middleware.GetUserIDFromContext(c.UserContext())
if userID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
ctx := c.UserContext()
userInfo, permissions, err := h.authService.GetCurrentUser(ctx, userID)
if err != nil {
return err
}
data := map[string]interface{}{
"user": userInfo,
"permissions": permissions,
}
return response.Success(c, data)
}
// ChangePassword 修改密码
func (h *AuthHandler) ChangePassword(c *fiber.Ctx) error {
userID := middleware.GetUserIDFromContext(c.UserContext())
if userID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
var req model.ChangePasswordRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
ctx := c.UserContext()
if err := h.authService.ChangePassword(ctx, userID, req.OldPassword, req.NewPassword); err != nil {
return err
}
return response.Success(c, nil)
}

View File

@@ -0,0 +1,80 @@
package admin
import (
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/break/junhong_cmp_fiber/internal/model"
shopService "github.com/break/junhong_cmp_fiber/internal/service/shop"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/response"
)
type ShopHandler struct {
service *shopService.Service
}
func NewShopHandler(service *shopService.Service) *ShopHandler {
return &ShopHandler{service: service}
}
func (h *ShopHandler) List(c *fiber.Ctx) error {
var req model.ShopListRequest
if err := c.QueryParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
shops, total, err := h.service.ListShopResponses(c.UserContext(), &req)
if err != nil {
return err
}
return response.SuccessWithPagination(c, shops, total, req.Page, req.PageSize)
}
func (h *ShopHandler) Create(c *fiber.Ctx) error {
var req model.CreateShopRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
shop, err := h.service.Create(c.UserContext(), &req)
if err != nil {
return err
}
return response.Success(c, shop)
}
func (h *ShopHandler) Update(c *fiber.Ctx) error {
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
if err != nil {
return errors.New(errors.CodeInvalidParam, "无效的店铺 ID")
}
var req model.UpdateShopRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
shop, err := h.service.Update(c.UserContext(), uint(id), &req)
if err != nil {
return err
}
return response.Success(c, shop)
}
func (h *ShopHandler) Delete(c *fiber.Ctx) error {
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
if err != nil {
return errors.New(errors.CodeInvalidParam, "无效的店铺 ID")
}
if err := h.service.Delete(c.UserContext(), uint(id)); err != nil {
return err
}
return response.Success(c, nil)
}

View File

@@ -0,0 +1,103 @@
package admin
import (
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/break/junhong_cmp_fiber/internal/model"
shopAccountService "github.com/break/junhong_cmp_fiber/internal/service/shop_account"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/response"
)
type ShopAccountHandler struct {
service *shopAccountService.Service
}
func NewShopAccountHandler(service *shopAccountService.Service) *ShopAccountHandler {
return &ShopAccountHandler{service: service}
}
func (h *ShopAccountHandler) List(c *fiber.Ctx) error {
var req model.ShopAccountListRequest
if err := c.QueryParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
accounts, total, err := h.service.List(c.UserContext(), &req)
if err != nil {
return err
}
return response.SuccessWithPagination(c, accounts, total, req.Page, req.PageSize)
}
func (h *ShopAccountHandler) Create(c *fiber.Ctx) error {
var req model.CreateShopAccountRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
account, err := h.service.Create(c.UserContext(), &req)
if err != nil {
return err
}
return response.Success(c, account)
}
func (h *ShopAccountHandler) Update(c *fiber.Ctx) error {
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
if err != nil {
return errors.New(errors.CodeInvalidParam, "无效的账号 ID")
}
var req model.UpdateShopAccountRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
account, err := h.service.Update(c.UserContext(), uint(id), &req)
if err != nil {
return err
}
return response.Success(c, account)
}
func (h *ShopAccountHandler) UpdatePassword(c *fiber.Ctx) error {
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
if err != nil {
return errors.New(errors.CodeInvalidParam, "无效的账号 ID")
}
var req model.UpdateShopAccountPasswordRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.service.UpdatePassword(c.UserContext(), uint(id), &req); err != nil {
return err
}
return response.Success(c, nil)
}
func (h *ShopAccountHandler) UpdateStatus(c *fiber.Ctx) error {
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
if err != nil {
return errors.New(errors.CodeInvalidParam, "无效的账号 ID")
}
var req model.UpdateShopAccountStatusRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.service.UpdateStatus(c.UserContext(), uint(id), &req); err != nil {
return err
}
return response.Success(c, nil)
}

143
internal/handler/h5/auth.go Normal file
View File

@@ -0,0 +1,143 @@
package h5
import (
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/service/auth"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
)
// AuthHandler H5认证处理器
type AuthHandler struct {
authService *auth.Service
validator *validator.Validate
}
// NewAuthHandler 创建H5认证处理器
func NewAuthHandler(authService *auth.Service, validator *validator.Validate) *AuthHandler {
return &AuthHandler{
authService: authService,
validator: validator,
}
}
// Login H5登录
func (h *AuthHandler) Login(c *fiber.Ctx) error {
var req model.LoginRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
clientIP := c.IP()
ctx := c.UserContext()
resp, err := h.authService.Login(ctx, &req, clientIP)
if err != nil {
return err
}
return response.Success(c, resp)
}
// Logout H5登出
func (h *AuthHandler) Logout(c *fiber.Ctx) error {
auth := c.Get("Authorization")
accessToken := ""
if len(auth) > 7 && auth[:7] == "Bearer " {
accessToken = auth[7:]
}
refreshToken := ""
var req model.RefreshTokenRequest
if err := c.BodyParser(&req); err == nil {
refreshToken = req.RefreshToken
}
ctx := c.UserContext()
if err := h.authService.Logout(ctx, accessToken, refreshToken); err != nil {
return err
}
return response.Success(c, nil)
}
// RefreshToken 刷新访问令牌
func (h *AuthHandler) RefreshToken(c *fiber.Ctx) error {
var req model.RefreshTokenRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
ctx := c.UserContext()
newAccessToken, err := h.authService.RefreshToken(ctx, req.RefreshToken)
if err != nil {
return err
}
resp := &model.RefreshTokenResponse{
AccessToken: newAccessToken,
ExpiresIn: 86400,
}
return response.Success(c, resp)
}
// GetMe 获取当前用户信息
func (h *AuthHandler) GetMe(c *fiber.Ctx) error {
userID := middleware.GetUserIDFromContext(c.UserContext())
if userID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
ctx := c.UserContext()
userInfo, permissions, err := h.authService.GetCurrentUser(ctx, userID)
if err != nil {
return err
}
data := map[string]interface{}{
"user": userInfo,
"permissions": permissions,
}
return response.Success(c, data)
}
// ChangePassword 修改密码
func (h *AuthHandler) ChangePassword(c *fiber.Ctx) error {
userID := middleware.GetUserIDFromContext(c.UserContext())
if userID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
var req model.ChangePasswordRequest
if err := c.BodyParser(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
}
if err := h.validator.Struct(&req); err != nil {
return errors.New(errors.CodeInvalidParam, "参数验证失败: "+err.Error())
}
ctx := c.UserContext()
if err := h.authService.ChangePassword(ctx, userID, req.OldPassword, req.NewPassword); err != nil {
return err
}
return response.Success(c, nil)
}

View File

@@ -0,0 +1,41 @@
package model
type LoginRequest struct {
Username string `json:"username" validate:"required"`
Password string `json:"password" validate:"required"`
Device string `json:"device" validate:"omitempty,oneof=web h5 mobile"`
}
type LoginResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
User UserInfo `json:"user"`
Permissions []string `json:"permissions"`
}
type UserInfo struct {
ID uint `json:"id"`
Username string `json:"username"`
Phone string `json:"phone"`
UserType int `json:"user_type"`
UserTypeName string `json:"user_type_name"`
ShopID uint `json:"shop_id,omitempty"`
ShopName string `json:"shop_name,omitempty"`
EnterpriseID uint `json:"enterprise_id,omitempty"`
EnterpriseName string `json:"enterprise_name,omitempty"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" validate:"required"`
}
type RefreshTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
}
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=32"`
}

View File

@@ -0,0 +1,48 @@
package model
// ShopAccountListRequest 代理商账号列表查询请求
type ShopAccountListRequest struct {
Page int `json:"page" query:"page" validate:"omitempty,min=1"` // 页码
PageSize int `json:"page_size" query:"page_size" validate:"omitempty,min=1,max=100"` // 每页数量
ShopID *uint `json:"shop_id" query:"shop_id" validate:"omitempty,min=1"` // 店铺ID过滤
Username string `json:"username" query:"username" validate:"omitempty,max=50"` // 用户名(模糊查询)
Phone string `json:"phone" query:"phone" validate:"omitempty,len=11"` // 手机号(精确查询)
Status *int `json:"status" query:"status" validate:"omitempty,oneof=0 1"` // 状态
}
// CreateShopAccountRequest 创建代理商账号请求
type CreateShopAccountRequest struct {
ShopID uint `json:"shop_id" validate:"required,min=1"` // 店铺ID
Username string `json:"username" validate:"required,min=3,max=50"` // 用户名
Phone string `json:"phone" validate:"required,len=11"` // 手机号
Password string `json:"password" validate:"required,min=8,max=32"` // 密码
}
// UpdateShopAccountRequest 更新代理商账号请求
type UpdateShopAccountRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"` // 用户名
// 注意:不包含 phone 和 password按照业务规则不允许修改
}
// UpdateShopAccountPasswordRequest 修改代理商账号密码请求(管理员重置)
type UpdateShopAccountPasswordRequest struct {
NewPassword string `json:"new_password" validate:"required,min=8,max=32"` // 新密码
}
// UpdateShopAccountStatusRequest 修改代理商账号状态请求
type UpdateShopAccountStatusRequest struct {
Status int `json:"status" validate:"required,oneof=0 1"` // 状态0=禁用 1=启用)
}
// ShopAccountResponse 代理商账号响应
type ShopAccountResponse struct {
ID uint `json:"id"`
ShopID uint `json:"shop_id"`
ShopName string `json:"shop_name,omitempty"` // 关联查询时填充
Username string `json:"username"`
Phone string `json:"phone"`
UserType int `json:"user_type"`
Status int `json:"status"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}

View File

@@ -1,28 +1,39 @@
package model
// CreateShopRequest 创建店铺请求
type CreateShopRequest struct {
ShopName string `json:"shop_name" validate:"required"` // 店铺名称
ShopCode string `json:"shop_code"` // 店铺编号
ParentID *uint `json:"parent_id"` // 上级店铺ID
ContactName string `json:"contact_name"` // 联系人姓名
ContactPhone string `json:"contact_phone" validate:"omitempty"` // 联系人电话
Province string `json:"province"` // 省份
City string `json:"city"` // 城市
District string `json:"district"` // 区县
Address string `json:"address"` // 详细地址
type ShopListRequest struct {
Page int `json:"page" query:"page" validate:"omitempty,min=1"`
PageSize int `json:"page_size" query:"page_size" validate:"omitempty,min=1,max=100"`
ShopName string `json:"shop_name" query:"shop_name" validate:"omitempty,max=100"`
ShopCode string `json:"shop_code" query:"shop_code" validate:"omitempty,max=50"`
ParentID *uint `json:"parent_id" query:"parent_id" validate:"omitempty,min=1"`
Level *int `json:"level" query:"level" validate:"omitempty,min=1,max=7"`
Status *int `json:"status" query:"status" validate:"omitempty,oneof=0 1"`
}
type CreateShopRequest struct {
ShopName string `json:"shop_name" validate:"required,min=1,max=100"`
ShopCode string `json:"shop_code" validate:"required,min=1,max=50"`
ParentID *uint `json:"parent_id" validate:"omitempty,min=1"`
ContactName string `json:"contact_name" validate:"omitempty,max=50"`
ContactPhone string `json:"contact_phone" validate:"omitempty,len=11"`
Province string `json:"province" validate:"omitempty,max=50"`
City string `json:"city" validate:"omitempty,max=50"`
District string `json:"district" validate:"omitempty,max=50"`
Address string `json:"address" validate:"omitempty,max=255"`
InitPassword string `json:"init_password" validate:"required,min=8,max=32"`
InitUsername string `json:"init_username" validate:"required,min=3,max=50"`
InitPhone string `json:"init_phone" validate:"required,len=11"`
}
// UpdateShopRequest 更新店铺请求
type UpdateShopRequest struct {
ShopName *string `json:"shop_name"` // 店铺名称
ShopCode *string `json:"shop_code"` // 店铺编号
ContactName *string `json:"contact_name"` // 联系人姓名
ContactPhone *string `json:"contact_phone"` // 联系人电话
Province *string `json:"province"` // 省份
City *string `json:"city"` // 城市
District *string `json:"district"` // 区县
Address *string `json:"address"` // 详细地址
ShopName string `json:"shop_name" validate:"required,min=1,max=100"`
ContactName string `json:"contact_name" validate:"omitempty,max=50"`
ContactPhone string `json:"contact_phone" validate:"omitempty,len=11"`
Province string `json:"province" validate:"omitempty,max=50"`
City string `json:"city" validate:"omitempty,max=50"`
District string `json:"district" validate:"omitempty,max=50"`
Address string `json:"address" validate:"omitempty,max=255"`
Status int `json:"status" validate:"required,oneof=0 1"`
}
// ShopResponse 店铺响应

View File

@@ -19,6 +19,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.CreateAccountRequest),
Output: new(model.AccountResponse),
Auth: true,
})
Register(accounts, doc, groupPath, "GET", "", h.List, RouteSpec{
@@ -26,6 +27,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.AccountListRequest),
Output: new(model.AccountPageResult),
Auth: true,
})
Register(accounts, doc, groupPath, "GET", "/:id", h.Get, RouteSpec{
@@ -33,6 +35,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.IDReq),
Output: new(model.AccountResponse),
Auth: true,
})
Register(accounts, doc, groupPath, "PUT", "/:id", h.Update, RouteSpec{
@@ -40,6 +43,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.UpdateAccountParams),
Output: new(model.AccountResponse),
Auth: true,
})
Register(accounts, doc, groupPath, "DELETE", "/:id", h.Delete, RouteSpec{
@@ -47,6 +51,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.IDReq),
Output: nil,
Auth: true,
})
// 账号-角色关联
@@ -62,6 +67,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.IDReq),
Output: new([]model.Role),
Auth: true,
})
Register(accounts, doc, groupPath, "DELETE", "/:account_id/roles/:role_id", h.RemoveRole, RouteSpec{
@@ -69,6 +75,7 @@ func registerAccountRoutes(api fiber.Router, h *admin.AccountHandler, doc *opena
Tags: []string{"账号相关"},
Input: new(model.RemoveRoleParams),
Output: nil,
Auth: true,
})
registerPlatformAccountRoutes(api, h, doc, basePath)
@@ -83,6 +90,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.PlatformAccountListRequest),
Output: new(model.AccountPageResult),
Auth: true,
})
Register(platformAccounts, doc, groupPath, "POST", "", h.Create, RouteSpec{
@@ -90,6 +98,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.CreateAccountRequest),
Output: new(model.AccountResponse),
Auth: true,
})
Register(platformAccounts, doc, groupPath, "GET", "/:id", h.Get, RouteSpec{
@@ -97,6 +106,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.IDReq),
Output: new(model.AccountResponse),
Auth: true,
})
Register(platformAccounts, doc, groupPath, "PUT", "/:id", h.Update, RouteSpec{
@@ -104,6 +114,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.UpdateAccountParams),
Output: new(model.AccountResponse),
Auth: true,
})
Register(platformAccounts, doc, groupPath, "DELETE", "/:id", h.Delete, RouteSpec{
@@ -111,6 +122,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.IDReq),
Output: nil,
Auth: true,
})
Register(platformAccounts, doc, groupPath, "PUT", "/:id/password", h.UpdatePassword, RouteSpec{
@@ -118,6 +130,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.UpdatePasswordParams),
Output: nil,
Auth: true,
})
Register(platformAccounts, doc, groupPath, "PUT", "/:id/status", h.UpdateStatus, RouteSpec{
@@ -125,6 +138,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.UpdateStatusParams),
Output: nil,
Auth: true,
})
Register(platformAccounts, doc, groupPath, "POST", "/:id/roles", h.AssignRoles, RouteSpec{
@@ -132,6 +146,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.AssignRolesParams),
Output: nil,
Auth: true,
})
Register(platformAccounts, doc, groupPath, "GET", "/:id/roles", h.GetRoles, RouteSpec{
@@ -139,6 +154,7 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.IDReq),
Output: new([]model.Role),
Auth: true,
})
Register(platformAccounts, doc, groupPath, "DELETE", "/:account_id/roles/:role_id", h.RemoveRole, RouteSpec{
@@ -146,5 +162,6 @@ func registerPlatformAccountRoutes(api fiber.Router, h *admin.AccountHandler, do
Tags: []string{"平台账号"},
Input: new(model.RemoveRoleParams),
Output: nil,
Auth: true,
})
}

View File

@@ -4,19 +4,83 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/openapi"
)
// RegisterAdminRoutes 注册管理后台相关路由
func RegisterAdminRoutes(router fiber.Router, handlers *bootstrap.Handlers, doc *openapi.Generator, basePath string) {
func RegisterAdminRoutes(router fiber.Router, handlers *bootstrap.Handlers, middlewares *bootstrap.Middlewares, doc *openapi.Generator, basePath string) {
if handlers.AdminAuth != nil {
registerAdminAuthRoutes(router, handlers.AdminAuth, middlewares.AdminAuth, doc, basePath)
}
authGroup := router.Group("", middlewares.AdminAuth)
if handlers.Account != nil {
registerAccountRoutes(router, handlers.Account, doc, basePath)
registerAccountRoutes(authGroup, handlers.Account, doc, basePath)
}
if handlers.Role != nil {
registerRoleRoutes(router, handlers.Role, doc, basePath)
registerRoleRoutes(authGroup, handlers.Role, doc, basePath)
}
if handlers.Permission != nil {
registerPermissionRoutes(router, handlers.Permission, doc, basePath)
registerPermissionRoutes(authGroup, handlers.Permission, doc, basePath)
}
if handlers.Shop != nil {
registerShopRoutes(authGroup, handlers.Shop, doc, basePath)
}
if handlers.ShopAccount != nil {
registerShopAccountRoutes(authGroup, handlers.ShopAccount, doc, basePath)
}
// TODO: Task routes?
}
func registerAdminAuthRoutes(router fiber.Router, handler interface{}, authMiddleware fiber.Handler, doc *openapi.Generator, basePath string) {
h := handler.(interface {
Login(c *fiber.Ctx) error
Logout(c *fiber.Ctx) error
RefreshToken(c *fiber.Ctx) error
GetMe(c *fiber.Ctx) error
ChangePassword(c *fiber.Ctx) error
})
Register(router, doc, basePath, "POST", "/login", h.Login, RouteSpec{
Summary: "后台登录",
Tags: []string{"认证"},
Input: new(model.LoginRequest),
Output: new(model.LoginResponse),
Auth: false,
})
Register(router, doc, basePath, "POST", "/refresh-token", h.RefreshToken, RouteSpec{
Summary: "刷新 Token",
Tags: []string{"认证"},
Input: new(model.RefreshTokenRequest),
Output: new(model.RefreshTokenResponse),
Auth: false,
})
authGroup := router.Group("", authMiddleware)
Register(authGroup, doc, basePath, "POST", "/logout", h.Logout, RouteSpec{
Summary: "登出",
Tags: []string{"认证"},
Input: nil,
Output: nil,
Auth: true,
})
Register(authGroup, doc, basePath, "GET", "/me", h.GetMe, RouteSpec{
Summary: "获取当前用户信息",
Tags: []string{"认证"},
Input: nil,
Output: new(model.UserInfo),
Auth: true,
})
Register(authGroup, doc, basePath, "PUT", "/password", h.ChangePassword, RouteSpec{
Summary: "修改密码",
Tags: []string{"认证"},
Input: new(model.ChangePasswordRequest),
Output: nil,
Auth: true,
})
}

68
internal/routes/h5.go Normal file
View File

@@ -0,0 +1,68 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/openapi"
)
// RegisterH5Routes 注册H5相关路由
func RegisterH5Routes(router fiber.Router, handlers *bootstrap.Handlers, middlewares *bootstrap.Middlewares, doc *openapi.Generator, basePath string) {
if handlers.H5Auth != nil {
registerH5AuthRoutes(router, handlers.H5Auth, middlewares.H5Auth, doc, basePath)
}
}
func registerH5AuthRoutes(router fiber.Router, handler interface{}, authMiddleware fiber.Handler, doc *openapi.Generator, basePath string) {
h := handler.(interface {
Login(c *fiber.Ctx) error
Logout(c *fiber.Ctx) error
RefreshToken(c *fiber.Ctx) error
GetMe(c *fiber.Ctx) error
ChangePassword(c *fiber.Ctx) error
})
Register(router, doc, basePath, "POST", "/login", h.Login, RouteSpec{
Summary: "H5 登录",
Tags: []string{"H5 认证"},
Input: new(model.LoginRequest),
Output: new(model.LoginResponse),
Auth: false,
})
Register(router, doc, basePath, "POST", "/refresh-token", h.RefreshToken, RouteSpec{
Summary: "刷新 Token",
Tags: []string{"H5 认证"},
Input: new(model.RefreshTokenRequest),
Output: new(model.RefreshTokenResponse),
Auth: false,
})
authGroup := router.Group("", authMiddleware)
Register(authGroup, doc, basePath, "POST", "/logout", h.Logout, RouteSpec{
Summary: "登出",
Tags: []string{"H5 认证"},
Input: nil,
Output: nil,
Auth: true,
})
Register(authGroup, doc, basePath, "GET", "/me", h.GetMe, RouteSpec{
Summary: "获取当前用户信息",
Tags: []string{"H5 认证"},
Input: nil,
Output: new(model.UserInfo),
Auth: true,
})
Register(authGroup, doc, basePath, "PUT", "/password", h.ChangePassword, RouteSpec{
Summary: "修改密码",
Tags: []string{"H5 认证"},
Input: new(model.ChangePasswordRequest),
Output: nil,
Auth: true,
})
}

View File

@@ -19,6 +19,7 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: new(model.CreatePermissionRequest),
Output: new(model.PermissionResponse),
Auth: true,
})
Register(permissions, doc, groupPath, "GET", "", h.List, RouteSpec{
@@ -26,6 +27,7 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: new(model.PermissionListRequest),
Output: new(model.PermissionPageResult),
Auth: true,
})
Register(permissions, doc, groupPath, "GET", "/tree", h.GetTree, RouteSpec{
@@ -33,6 +35,7 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: nil, // 无参数或 Query 参数
Output: new([]*model.PermissionTreeNode),
Auth: true,
})
Register(permissions, doc, groupPath, "GET", "/:id", h.Get, RouteSpec{
@@ -40,6 +43,7 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: new(model.IDReq),
Output: new(model.PermissionResponse),
Auth: true,
})
Register(permissions, doc, groupPath, "PUT", "/:id", h.Update, RouteSpec{
@@ -47,6 +51,7 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: new(model.UpdatePermissionParams),
Output: new(model.PermissionResponse),
Auth: true,
})
Register(permissions, doc, groupPath, "DELETE", "/:id", h.Delete, RouteSpec{
@@ -54,5 +59,6 @@ func registerPermissionRoutes(api fiber.Router, h *admin.PermissionHandler, doc
Tags: []string{"权限"},
Input: new(model.IDReq),
Output: nil,
Auth: true,
})
}

View File

@@ -28,16 +28,11 @@ var pathParamRegex = regexp.MustCompile(`/:([a-zA-Z0-9_]+)`)
// handler: Fiber Handler
// spec: 文档元数据
func Register(router fiber.Router, doc *openapi.Generator, basePath, method, path string, handler fiber.Handler, spec RouteSpec) {
// 1. 注册实际的 Fiber 路由
router.Add(method, path, handler)
// 2. 注册文档 (如果 doc 不为空 - 也就是在生成文档模式下)
if doc != nil {
// 简单的路径拼接
fullPath := basePath + path
// 将 Fiber 路由参数格式 /:id 转换为 OpenAPI 格式 /{id}
openapiPath := pathParamRegex.ReplaceAllString(fullPath, "/{$1}")
doc.AddOperation(method, openapiPath, spec.Summary, spec.Input, spec.Output, spec.Tags...)
doc.AddOperation(method, openapiPath, spec.Summary, spec.Input, spec.Output, spec.Auth, spec.Tags...)
}
}

View File

@@ -19,6 +19,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.CreateRoleRequest),
Output: new(model.RoleResponse),
Auth: true,
})
Register(roles, doc, groupPath, "GET", "", h.List, RouteSpec{
@@ -26,6 +27,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.RoleListRequest),
Output: new(model.RolePageResult),
Auth: true,
})
Register(roles, doc, groupPath, "GET", "/:id", h.Get, RouteSpec{
@@ -33,6 +35,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.IDReq),
Output: new(model.RoleResponse),
Auth: true,
})
Register(roles, doc, groupPath, "PUT", "/:id", h.Update, RouteSpec{
@@ -40,6 +43,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.UpdateRoleParams),
Output: new(model.RoleResponse),
Auth: true,
})
Register(roles, doc, groupPath, "PUT", "/:id/status", h.UpdateStatus, RouteSpec{
@@ -47,6 +51,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.UpdateRoleStatusParams),
Output: nil,
Auth: true,
})
Register(roles, doc, groupPath, "DELETE", "/:id", h.Delete, RouteSpec{
@@ -54,6 +59,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.IDReq),
Output: nil,
Auth: true,
})
// 角色-权限关联
@@ -62,6 +68,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.AssignPermissionsParams),
Output: nil,
Auth: true,
})
Register(roles, doc, groupPath, "GET", "/:id/permissions", h.GetPermissions, RouteSpec{
@@ -69,6 +76,7 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.IDReq),
Output: new([]model.Permission),
Auth: true,
})
Register(roles, doc, groupPath, "DELETE", "/:role_id/permissions/:perm_id", h.RemovePermission, RouteSpec{
@@ -76,5 +84,6 @@ func registerRoleRoutes(api fiber.Router, h *admin.RoleHandler, doc *openapi.Gen
Tags: []string{"角色"},
Input: new(model.RemovePermissionParams),
Output: nil,
Auth: true,
})
}

View File

@@ -14,11 +14,15 @@ func RegisterRoutes(app *fiber.App, handlers *bootstrap.Handlers, middlewares *b
// 2. Admin 域 (挂载在 /api/admin)
adminGroup := app.Group("/api/admin")
RegisterAdminRoutes(adminGroup, handlers, nil, "/api/admin")
RegisterAdminRoutes(adminGroup, handlers, middlewares, nil, "/api/admin")
// 任务相关路由 (归属于 Admin 域)
registerTaskRoutes(adminGroup)
// 3. 个人客户路由 (挂载在 /api/c/v1)
// 3. H5 域 (挂载在 /api/h5)
h5Group := app.Group("/api/h5")
RegisterH5Routes(h5Group, handlers, middlewares, nil, "/api/h5")
// 4. 个人客户路由 (挂载在 /api/c/v1)
RegisterPersonalCustomerRoutes(app, handlers, middlewares.PersonalAuth)
}

23
internal/routes/shop.go Normal file
View File

@@ -0,0 +1,23 @@
package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/break/junhong_cmp_fiber/internal/handler/admin"
"github.com/break/junhong_cmp_fiber/pkg/openapi"
)
func registerShopRoutes(router fiber.Router, handler *admin.ShopHandler, doc *openapi.Generator, basePath string) {
router.Get("/shops", handler.List)
router.Post("/shops", handler.Create)
router.Put("/shops/:id", handler.Update)
router.Delete("/shops/:id", handler.Delete)
}
func registerShopAccountRoutes(router fiber.Router, handler *admin.ShopAccountHandler, doc *openapi.Generator, basePath string) {
router.Get("/shop-accounts", handler.List)
router.Post("/shop-accounts", handler.Create)
router.Put("/shop-accounts/:id", handler.Update)
router.Put("/shop-accounts/:id/password", handler.UpdatePassword)
router.Put("/shop-accounts/:id/status", handler.UpdateStatus)
}

View File

@@ -0,0 +1,260 @@
package auth
import (
"context"
"fmt"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/auth"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type Service struct {
accountStore *postgres.AccountStore
accountRoleStore *postgres.AccountRoleStore
rolePermStore *postgres.RolePermissionStore
permissionStore *postgres.PermissionStore
tokenManager *auth.TokenManager
logger *zap.Logger
}
func New(
accountStore *postgres.AccountStore,
accountRoleStore *postgres.AccountRoleStore,
rolePermStore *postgres.RolePermissionStore,
permissionStore *postgres.PermissionStore,
tokenManager *auth.TokenManager,
logger *zap.Logger,
) *Service {
return &Service{
accountStore: accountStore,
accountRoleStore: accountRoleStore,
rolePermStore: rolePermStore,
permissionStore: permissionStore,
tokenManager: tokenManager,
logger: logger,
}
}
func (s *Service) Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) {
ctx = pkgGorm.SkipDataPermission(ctx)
account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username)
if err != nil {
if err == gorm.ErrRecordNotFound {
s.logger.Warn("登录失败:用户名不存在", zap.String("username", req.Username), zap.String("ip", clientIP))
return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误")
}
return nil, errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err))
}
if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(req.Password)); err != nil {
s.logger.Warn("登录失败:密码错误", zap.String("username", req.Username), zap.String("ip", clientIP))
return nil, errors.New(errors.CodeInvalidCredentials, "用户名或密码错误")
}
if account.Status != 1 {
s.logger.Warn("登录失败:账号已禁用", zap.String("username", req.Username), zap.Uint("user_id", account.ID))
return nil, errors.New(errors.CodeAccountDisabled, "账号已禁用")
}
device := req.Device
if device == "" {
device = "web"
}
var shopID, enterpriseID uint
if account.ShopID != nil {
shopID = *account.ShopID
}
if account.EnterpriseID != nil {
enterpriseID = *account.EnterpriseID
}
tokenInfo := &auth.TokenInfo{
UserID: account.ID,
UserType: account.UserType,
ShopID: shopID,
EnterpriseID: enterpriseID,
Username: account.Username,
Device: device,
IP: clientIP,
}
accessToken, refreshToken, err := s.tokenManager.GenerateTokenPair(ctx, tokenInfo)
if err != nil {
return nil, err
}
permissions, err := s.getUserPermissions(ctx, account.ID)
if err != nil {
s.logger.Error("查询用户权限失败", zap.Uint("user_id", account.ID), zap.Error(err))
permissions = []string{}
}
userInfo := s.buildUserInfo(account)
s.logger.Info("用户登录成功",
zap.Uint("user_id", account.ID),
zap.String("username", account.Username),
zap.String("device", device),
zap.String("ip", clientIP),
)
return &model.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: int64(constants.DefaultAccessTokenTTL.Seconds()),
User: userInfo,
Permissions: permissions,
}, nil
}
func (s *Service) Logout(ctx context.Context, accessToken, refreshToken string) error {
if err := s.tokenManager.RevokeToken(ctx, accessToken); err != nil {
return err
}
if refreshToken != "" {
if err := s.tokenManager.RevokeToken(ctx, refreshToken); err != nil {
s.logger.Warn("撤销 refresh token 失败", zap.Error(err))
}
}
return nil
}
func (s *Service) RefreshToken(ctx context.Context, refreshToken string) (string, error) {
return s.tokenManager.RefreshAccessToken(ctx, refreshToken)
}
func (s *Service) GetCurrentUser(ctx context.Context, userID uint) (*model.UserInfo, []string, error) {
account, err := s.accountStore.GetByID(ctx, userID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil, errors.New(errors.CodeAccountNotFound, "账号不存在")
}
return nil, nil, errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err))
}
permissions, err := s.getUserPermissions(ctx, userID)
if err != nil {
s.logger.Error("查询用户权限失败", zap.Uint("user_id", userID), zap.Error(err))
permissions = []string{}
}
userInfo := s.buildUserInfo(account)
return &userInfo, permissions, nil
}
func (s *Service) ChangePassword(ctx context.Context, userID uint, oldPassword, newPassword string) error {
account, err := s.accountStore.GetByID(ctx, userID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeAccountNotFound, "账号不存在")
}
return errors.New(errors.CodeDatabaseError, fmt.Sprintf("查询账号失败: %v", err))
}
if err := bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(oldPassword)); err != nil {
return errors.New(errors.CodeInvalidOldPassword, "旧密码错误")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
if err := s.accountStore.UpdatePassword(ctx, userID, string(hashedPassword), userID); err != nil {
return errors.New(errors.CodeDatabaseError, fmt.Sprintf("更新密码失败: %v", err))
}
if err := s.tokenManager.RevokeAllUserTokens(ctx, userID); err != nil {
s.logger.Warn("撤销用户所有 token 失败", zap.Uint("user_id", userID), zap.Error(err))
}
s.logger.Info("用户修改密码成功", zap.Uint("user_id", userID))
return nil
}
func (s *Service) getUserPermissions(ctx context.Context, userID uint) ([]string, error) {
accountRoles, err := s.accountRoleStore.GetByAccountID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to get account roles: %w", err)
}
if len(accountRoles) == 0 {
return []string{}, nil
}
roleIDs := make([]uint, 0, len(accountRoles))
for _, ar := range accountRoles {
roleIDs = append(roleIDs, ar.RoleID)
}
permIDs, err := s.rolePermStore.GetPermIDsByRoleIDs(ctx, roleIDs)
if err != nil {
return nil, fmt.Errorf("failed to get permission IDs: %w", err)
}
if len(permIDs) == 0 {
return []string{}, nil
}
permissions, err := s.permissionStore.GetByIDs(ctx, permIDs)
if err != nil {
return nil, fmt.Errorf("failed to get permissions: %w", err)
}
permCodes := make([]string, 0, len(permissions))
for _, perm := range permissions {
permCodes = append(permCodes, perm.PermCode)
}
return permCodes, nil
}
func (s *Service) buildUserInfo(account *model.Account) model.UserInfo {
userTypeName := s.getUserTypeName(account.UserType)
var shopID, enterpriseID uint
if account.ShopID != nil {
shopID = *account.ShopID
}
if account.EnterpriseID != nil {
enterpriseID = *account.EnterpriseID
}
return model.UserInfo{
ID: account.ID,
Username: account.Username,
Phone: account.Phone,
UserType: account.UserType,
UserTypeName: userTypeName,
ShopID: shopID,
EnterpriseID: enterpriseID,
}
}
func (s *Service) getUserTypeName(userType int) string {
switch userType {
case constants.UserTypeSuperAdmin:
return "超级管理员"
case constants.UserTypePlatform:
return "平台用户"
case constants.UserTypeAgent:
return "代理账号"
case constants.UserTypeEnterprise:
return "企业账号"
default:
return "未知"
}
}

View File

@@ -1,9 +1,8 @@
// Package shop 提供店铺管理的业务逻辑服务
// 包含店铺创建、查询、更新、删除等功能
package shop
import (
"context"
"fmt"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store"
@@ -11,55 +10,55 @@ import (
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// Service 店铺业务服务
type Service struct {
shopStore *postgres.ShopStore
shopStore *postgres.ShopStore
accountStore *postgres.AccountStore
}
// New 创建店铺服务
func New(shopStore *postgres.ShopStore) *Service {
func New(shopStore *postgres.ShopStore, accountStore *postgres.AccountStore) *Service {
return &Service{
shopStore: shopStore,
shopStore: shopStore,
accountStore: accountStore,
}
}
// Create 创建店铺
func (s *Service) Create(ctx context.Context, req *model.CreateShopRequest) (*model.Shop, error) {
// 获取当前用户 ID
func (s *Service) Create(ctx context.Context, req *model.CreateShopRequest) (*model.ShopResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
// 检查店铺编号唯一性
if req.ShopCode != "" {
existing, err := s.shopStore.GetByCode(ctx, req.ShopCode)
if err == nil && existing != nil {
return nil, errors.New(errors.CodeShopCodeExists, "店铺编号已存在")
}
existing, err := s.shopStore.GetByCode(ctx, req.ShopCode)
if err == nil && existing != nil {
return nil, errors.New(errors.CodeShopCodeExists, "店铺编号已存在")
}
// 计算层级
level := 1
if req.ParentID != nil {
// 验证上级店铺存在
parent, err := s.shopStore.GetByID(ctx, *req.ParentID)
if err != nil {
return nil, errors.New(errors.CodeInvalidParentID, "上级店铺不存在或无效")
}
// 计算新店铺的层级
level = parent.Level + 1
// 校验层级不超过最大值
if level > constants.MaxShopLevel {
if level > constants.ShopMaxLevel {
return nil, errors.New(errors.CodeShopLevelExceeded, "店铺层级不能超过 7 级")
}
}
// 创建店铺
existingAccount, err := s.accountStore.GetByUsername(ctx, req.InitUsername)
if err == nil && existingAccount != nil {
return nil, errors.New(errors.CodeUsernameExists, "初始账号用户名已存在")
}
existingAccount, err = s.accountStore.GetByPhone(ctx, req.InitPhone)
if err == nil && existingAccount != nil {
return nil, errors.New(errors.CodePhoneExists, "初始账号手机号已存在")
}
shop := &model.Shop{
ShopName: req.ShopName,
ShopCode: req.ShopCode,
@@ -71,71 +70,94 @@ func (s *Service) Create(ctx context.Context, req *model.CreateShopRequest) (*mo
City: req.City,
District: req.District,
Address: req.Address,
Status: constants.StatusEnabled,
Status: constants.ShopStatusEnabled,
}
shop.Creator = currentUserID
shop.Updater = currentUserID
if err := s.shopStore.Create(ctx, shop); err != nil {
return nil, err
return nil, fmt.Errorf("创建店铺失败: %w", err)
}
return shop, nil
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.InitPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("密码哈希失败: %w", err)
}
account := &model.Account{
Username: req.InitUsername,
Phone: req.InitPhone,
Password: string(hashedPassword),
UserType: constants.UserTypeAgent,
ShopID: &shop.ID,
Status: constants.StatusEnabled,
}
account.Creator = currentUserID
account.Updater = currentUserID
if err := s.accountStore.Create(ctx, account); err != nil {
return nil, fmt.Errorf("创建初始账号失败: %w", err)
}
return &model.ShopResponse{
ID: shop.ID,
ShopName: shop.ShopName,
ShopCode: shop.ShopCode,
ParentID: shop.ParentID,
Level: shop.Level,
ContactName: shop.ContactName,
ContactPhone: shop.ContactPhone,
Province: shop.Province,
City: shop.City,
District: shop.District,
Address: shop.Address,
Status: shop.Status,
CreatedAt: shop.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: shop.UpdatedAt.Format("2006-01-02 15:04:05"),
}, nil
}
// Update 更新店铺信息
func (s *Service) Update(ctx context.Context, id uint, req *model.UpdateShopRequest) (*model.Shop, error) {
// 获取当前用户 ID
func (s *Service) Update(ctx context.Context, id uint, req *model.UpdateShopRequest) (*model.ShopResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
// 查询店铺
shop, err := s.shopStore.GetByID(ctx, id)
if err != nil {
return nil, errors.New(errors.CodeShopNotFound, "店铺不存在")
}
// 检查店铺编号唯一性(如果修改了编号)
if req.ShopCode != nil && *req.ShopCode != shop.ShopCode {
existing, err := s.shopStore.GetByCode(ctx, *req.ShopCode)
if err == nil && existing != nil && existing.ID != id {
return nil, errors.New(errors.CodeShopCodeExists, "店铺编号已存在")
}
shop.ShopCode = *req.ShopCode
}
// 更新字段
if req.ShopName != nil {
shop.ShopName = *req.ShopName
}
if req.ContactName != nil {
shop.ContactName = *req.ContactName
}
if req.ContactPhone != nil {
shop.ContactPhone = *req.ContactPhone
}
if req.Province != nil {
shop.Province = *req.Province
}
if req.City != nil {
shop.City = *req.City
}
if req.District != nil {
shop.District = *req.District
}
if req.Address != nil {
shop.Address = *req.Address
}
shop.ShopName = req.ShopName
shop.ContactName = req.ContactName
shop.ContactPhone = req.ContactPhone
shop.Province = req.Province
shop.City = req.City
shop.District = req.District
shop.Address = req.Address
shop.Status = req.Status
shop.Updater = currentUserID
if err := s.shopStore.Update(ctx, shop); err != nil {
return nil, err
}
return shop, nil
return &model.ShopResponse{
ID: shop.ID,
ShopName: shop.ShopName,
ShopCode: shop.ShopCode,
ParentID: shop.ParentID,
Level: shop.Level,
ContactName: shop.ContactName,
ContactPhone: shop.ContactPhone,
Province: shop.Province,
City: shop.City,
District: shop.District,
Address: shop.Address,
Status: shop.Status,
CreatedAt: shop.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: shop.UpdatedAt.Format("2006-01-02 15:04:05"),
}, nil
}
// Disable 禁用店铺
@@ -189,11 +211,104 @@ func (s *Service) GetByID(ctx context.Context, id uint) (*model.Shop, error) {
return shop, nil
}
// List 查询店铺列表
func (s *Service) ListShopResponses(ctx context.Context, req *model.ShopListRequest) ([]*model.ShopResponse, int64, error) {
opts := &store.QueryOptions{
Page: req.Page,
PageSize: req.PageSize,
OrderBy: "created_at DESC",
}
if opts.Page == 0 {
opts.Page = 1
}
if opts.PageSize == 0 {
opts.PageSize = constants.DefaultPageSize
}
filters := make(map[string]interface{})
if req.ShopName != "" {
filters["shop_name"] = req.ShopName
}
if req.ShopCode != "" {
filters["shop_code"] = req.ShopCode
}
if req.ParentID != nil {
filters["parent_id"] = *req.ParentID
}
if req.Level != nil {
filters["level"] = *req.Level
}
if req.Status != nil {
filters["status"] = *req.Status
}
shops, total, err := s.shopStore.List(ctx, opts, filters)
if err != nil {
return nil, 0, fmt.Errorf("查询店铺列表失败: %w", err)
}
responses := make([]*model.ShopResponse, 0, len(shops))
for _, shop := range shops {
responses = append(responses, &model.ShopResponse{
ID: shop.ID,
ShopName: shop.ShopName,
ShopCode: shop.ShopCode,
ParentID: shop.ParentID,
Level: shop.Level,
ContactName: shop.ContactName,
ContactPhone: shop.ContactPhone,
Province: shop.Province,
City: shop.City,
District: shop.District,
Address: shop.Address,
Status: shop.Status,
CreatedAt: shop.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: shop.UpdatedAt.Format("2006-01-02 15:04:05"),
})
}
return responses, total, nil
}
func (s *Service) List(ctx context.Context, opts *store.QueryOptions, filters map[string]interface{}) ([]*model.Shop, int64, error) {
return s.shopStore.List(ctx, opts, filters)
}
func (s *Service) Delete(ctx context.Context, id uint) error {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
shop, err := s.shopStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeShopNotFound, "店铺不存在")
}
return fmt.Errorf("获取店铺失败: %w", err)
}
accounts, err := s.accountStore.GetByShopID(ctx, shop.ID)
if err != nil {
return fmt.Errorf("查询店铺账号失败: %w", err)
}
if len(accounts) > 0 {
accountIDs := make([]uint, 0, len(accounts))
for _, account := range accounts {
accountIDs = append(accountIDs, account.ID)
}
if err := s.accountStore.BulkUpdateStatus(ctx, accountIDs, constants.StatusDisabled, currentUserID); err != nil {
return fmt.Errorf("禁用店铺账号失败: %w", err)
}
}
if err := s.shopStore.Delete(ctx, id); err != nil {
return fmt.Errorf("删除店铺失败: %w", err)
}
return nil
}
// GetSubordinateShopIDs 获取下级店铺 ID 列表(包含自己)
func (s *Service) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
return s.shopStore.GetSubordinateShopIDs(ctx, shopID)

View File

@@ -0,0 +1,265 @@
package shop_account
import (
"context"
"fmt"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type Service struct {
accountStore *postgres.AccountStore
shopStore *postgres.ShopStore
}
func New(accountStore *postgres.AccountStore, shopStore *postgres.ShopStore) *Service {
return &Service{
accountStore: accountStore,
shopStore: shopStore,
}
}
func (s *Service) List(ctx context.Context, req *model.ShopAccountListRequest) ([]*model.ShopAccountResponse, int64, error) {
opts := &store.QueryOptions{
Page: req.Page,
PageSize: req.PageSize,
OrderBy: "created_at DESC",
}
if opts.Page == 0 {
opts.Page = 1
}
if opts.PageSize == 0 {
opts.PageSize = constants.DefaultPageSize
}
filters := make(map[string]interface{})
filters["user_type"] = constants.UserTypeAgent
if req.Username != "" {
filters["username"] = req.Username
}
if req.Phone != "" {
filters["phone"] = req.Phone
}
if req.Status != nil {
filters["status"] = *req.Status
}
var accounts []*model.Account
var total int64
var err error
if req.ShopID != nil {
accounts, total, err = s.accountStore.ListByShopID(ctx, *req.ShopID, opts, filters)
} else {
filters["user_type"] = constants.UserTypeAgent
accounts, total, err = s.accountStore.List(ctx, opts, filters)
}
if err != nil {
return nil, 0, fmt.Errorf("查询代理商账号列表失败: %w", err)
}
shopMap := make(map[uint]string)
for _, account := range accounts {
if account.ShopID != nil {
if _, exists := shopMap[*account.ShopID]; !exists {
shop, err := s.shopStore.GetByID(ctx, *account.ShopID)
if err == nil {
shopMap[*account.ShopID] = shop.ShopName
}
}
}
}
responses := make([]*model.ShopAccountResponse, 0, len(accounts))
for _, account := range accounts {
resp := &model.ShopAccountResponse{
ID: account.ID,
Username: account.Username,
Phone: account.Phone,
UserType: account.UserType,
Status: account.Status,
CreatedAt: account.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: account.UpdatedAt.Format("2006-01-02 15:04:05"),
}
if account.ShopID != nil {
resp.ShopID = *account.ShopID
if shopName, ok := shopMap[*account.ShopID]; ok {
resp.ShopName = shopName
}
}
responses = append(responses, resp)
}
return responses, total, nil
}
func (s *Service) Create(ctx context.Context, req *model.CreateShopAccountRequest) (*model.ShopAccountResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
shop, err := s.shopStore.GetByID(ctx, req.ShopID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeShopNotFound, "店铺不存在")
}
return nil, fmt.Errorf("获取店铺失败: %w", err)
}
existing, err := s.accountStore.GetByUsername(ctx, req.Username)
if err == nil && existing != nil {
return nil, errors.New(errors.CodeUsernameExists, "用户名已存在")
}
existing, err = s.accountStore.GetByPhone(ctx, req.Phone)
if err == nil && existing != nil {
return nil, errors.New(errors.CodePhoneExists, "手机号已存在")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("密码哈希失败: %w", err)
}
account := &model.Account{
Username: req.Username,
Phone: req.Phone,
Password: string(hashedPassword),
UserType: constants.UserTypeAgent,
ShopID: &req.ShopID,
Status: constants.StatusEnabled,
}
account.Creator = currentUserID
account.Updater = currentUserID
if err := s.accountStore.Create(ctx, account); err != nil {
return nil, fmt.Errorf("创建代理商账号失败: %w", err)
}
return &model.ShopAccountResponse{
ID: account.ID,
ShopID: *account.ShopID,
ShopName: shop.ShopName,
Username: account.Username,
Phone: account.Phone,
UserType: account.UserType,
Status: account.Status,
CreatedAt: account.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: account.UpdatedAt.Format("2006-01-02 15:04:05"),
}, nil
}
func (s *Service) Update(ctx context.Context, id uint, req *model.UpdateShopAccountRequest) (*model.ShopAccountResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
account, err := s.accountStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeAccountNotFound, "账号不存在")
}
return nil, fmt.Errorf("获取账号失败: %w", err)
}
if account.UserType != constants.UserTypeAgent {
return nil, errors.New(errors.CodeInvalidParam, "只能更新代理商账号")
}
existingAccount, err := s.accountStore.GetByUsername(ctx, req.Username)
if err == nil && existingAccount != nil && existingAccount.ID != id {
return nil, errors.New(errors.CodeUsernameExists, "用户名已存在")
}
account.Username = req.Username
account.Updater = currentUserID
if err := s.accountStore.Update(ctx, account); err != nil {
return nil, fmt.Errorf("更新代理商账号失败: %w", err)
}
var shopName string
if account.ShopID != nil {
shop, err := s.shopStore.GetByID(ctx, *account.ShopID)
if err == nil {
shopName = shop.ShopName
}
}
return &model.ShopAccountResponse{
ID: account.ID,
ShopID: *account.ShopID,
ShopName: shopName,
Username: account.Username,
Phone: account.Phone,
UserType: account.UserType,
Status: account.Status,
CreatedAt: account.CreatedAt.Format("2006-01-02 15:04:05"),
UpdatedAt: account.UpdatedAt.Format("2006-01-02 15:04:05"),
}, nil
}
func (s *Service) UpdatePassword(ctx context.Context, id uint, req *model.UpdateShopAccountPasswordRequest) error {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
account, err := s.accountStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeAccountNotFound, "账号不存在")
}
return fmt.Errorf("获取账号失败: %w", err)
}
if account.UserType != constants.UserTypeAgent {
return errors.New(errors.CodeInvalidParam, "只能更新代理商账号密码")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("密码哈希失败: %w", err)
}
if err := s.accountStore.UpdatePassword(ctx, id, string(hashedPassword), currentUserID); err != nil {
return fmt.Errorf("更新密码失败: %w", err)
}
return nil
}
func (s *Service) UpdateStatus(ctx context.Context, id uint, req *model.UpdateShopAccountStatusRequest) error {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
account, err := s.accountStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeAccountNotFound, "账号不存在")
}
return fmt.Errorf("获取账号失败: %w", err)
}
if account.UserType != constants.UserTypeAgent {
return errors.New(errors.CodeInvalidParam, "只能更新代理商账号状态")
}
if err := s.accountStore.UpdateStatus(ctx, id, req.Status, currentUserID); err != nil {
return fmt.Errorf("更新账号状态失败: %w", err)
}
return nil
}

View File

@@ -56,6 +56,15 @@ func (s *AccountStore) GetByPhone(ctx context.Context, phone string) (*model.Acc
return &account, nil
}
// GetByUsernameOrPhone 根据用户名或手机号获取账号
func (s *AccountStore) GetByUsernameOrPhone(ctx context.Context, identifier string) (*model.Account, error) {
var account model.Account
if err := s.db.WithContext(ctx).Where("username = ? OR phone = ?", identifier, identifier).First(&account).Error; err != nil {
return nil, err
}
return &account, nil
}
// GetByShopID 根据店铺 ID 查询账号列表
func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.Account, error) {
var accounts []*model.Account
@@ -197,3 +206,52 @@ func (s *AccountStore) UpdateStatus(ctx context.Context, id uint, status int, up
"updater": updater,
}).Error
}
// BulkUpdateStatus 批量更新账号状态
func (s *AccountStore) BulkUpdateStatus(ctx context.Context, ids []uint, status int, updater uint) error {
return s.db.WithContext(ctx).
Model(&model.Account{}).
Where("id IN ?", ids).
Updates(map[string]interface{}{
"status": status,
"updater": updater,
}).Error
}
// ListByShopID 按店铺ID分页查询账号列表
func (s *AccountStore) ListByShopID(ctx context.Context, shopID uint, opts *store.QueryOptions, filters map[string]interface{}) ([]*model.Account, int64, error) {
var accounts []*model.Account
var total int64
query := s.db.WithContext(ctx).Model(&model.Account{}).Where("shop_id = ?", shopID)
if username, ok := filters["username"].(string); ok && username != "" {
query = query.Where("username LIKE ?", "%"+username+"%")
}
if phone, ok := filters["phone"].(string); ok && phone != "" {
query = query.Where("phone LIKE ?", "%"+phone+"%")
}
if status, ok := filters["status"].(int); ok {
query = query.Where("status = ?", status)
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if opts == nil {
opts = store.DefaultQueryOptions()
}
offset := (opts.Page - 1) * opts.PageSize
query = query.Offset(offset).Limit(opts.PageSize)
if opts.OrderBy != "" {
query = query.Order(opts.OrderBy)
}
if err := query.Find(&accounts).Error; err != nil {
return nil, 0, err
}
return accounts, total, nil
}