feat: 实现运营商模块重构,添加冗余字段优化查询性能
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m16s
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m16s
主要变更: - 新增 Carrier CRUD API(创建、列表、详情、更新、删除、状态更新) - IotCard/IotCardImportTask 添加 carrier_type/carrier_name 冗余字段 - 移除 Carrier 表的 channel_name/channel_code 字段 - 查询时直接使用冗余字段,避免 JOIN Carrier 表 - 添加数据库迁移脚本(000021-000023) - 添加单元测试和集成测试 - 同步更新 OpenAPI 文档和 specs
This commit is contained in:
@@ -33,5 +33,6 @@ func initHandlers(svc *services, deps *Dependencies) *Handlers {
|
||||
DeviceImport: admin.NewDeviceImportHandler(svc.DeviceImport),
|
||||
AssetAllocationRecord: admin.NewAssetAllocationRecordHandler(svc.AssetAllocationRecord),
|
||||
Storage: admin.NewStorageHandler(deps.StorageService),
|
||||
Carrier: admin.NewCarrierHandler(svc.Carrier),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
accountSvc "github.com/break/junhong_cmp_fiber/internal/service/account"
|
||||
assetAllocationRecordSvc "github.com/break/junhong_cmp_fiber/internal/service/asset_allocation_record"
|
||||
authSvc "github.com/break/junhong_cmp_fiber/internal/service/auth"
|
||||
carrierSvc "github.com/break/junhong_cmp_fiber/internal/service/carrier"
|
||||
commissionWithdrawalSvc "github.com/break/junhong_cmp_fiber/internal/service/commission_withdrawal"
|
||||
commissionWithdrawalSettingSvc "github.com/break/junhong_cmp_fiber/internal/service/commission_withdrawal_setting"
|
||||
customerAccountSvc "github.com/break/junhong_cmp_fiber/internal/service/customer_account"
|
||||
@@ -43,6 +44,7 @@ type services struct {
|
||||
Device *deviceSvc.Service
|
||||
DeviceImport *deviceImportSvc.Service
|
||||
AssetAllocationRecord *assetAllocationRecordSvc.Service
|
||||
Carrier *carrierSvc.Service
|
||||
}
|
||||
|
||||
func initServices(s *stores, deps *Dependencies) *services {
|
||||
@@ -67,5 +69,6 @@ func initServices(s *stores, deps *Dependencies) *services {
|
||||
Device: deviceSvc.New(deps.DB, s.Device, s.DeviceSimBinding, s.IotCard, s.Shop, s.AssetAllocationRecord),
|
||||
DeviceImport: deviceImportSvc.New(deps.DB, s.DeviceImportTask, deps.QueueClient),
|
||||
AssetAllocationRecord: assetAllocationRecordSvc.New(deps.DB, s.AssetAllocationRecord, s.Shop, s.Account),
|
||||
Carrier: carrierSvc.New(s.Carrier),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ type stores struct {
|
||||
DeviceSimBinding *postgres.DeviceSimBindingStore
|
||||
DeviceImportTask *postgres.DeviceImportTaskStore
|
||||
AssetAllocationRecord *postgres.AssetAllocationRecordStore
|
||||
Carrier *postgres.CarrierStore
|
||||
}
|
||||
|
||||
func initStores(deps *Dependencies) *stores {
|
||||
@@ -51,5 +52,6 @@ func initStores(deps *Dependencies) *stores {
|
||||
DeviceSimBinding: postgres.NewDeviceSimBindingStore(deps.DB, deps.Redis),
|
||||
DeviceImportTask: postgres.NewDeviceImportTaskStore(deps.DB, deps.Redis),
|
||||
AssetAllocationRecord: postgres.NewAssetAllocationRecordStore(deps.DB, deps.Redis),
|
||||
Carrier: postgres.NewCarrierStore(deps.DB),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ type Handlers struct {
|
||||
DeviceImport *admin.DeviceImportHandler
|
||||
AssetAllocationRecord *admin.AssetAllocationRecordHandler
|
||||
Storage *admin.StorageHandler
|
||||
Carrier *admin.CarrierHandler
|
||||
}
|
||||
|
||||
// Middlewares 封装所有中间件
|
||||
|
||||
112
internal/handler/admin/carrier.go
Normal file
112
internal/handler/admin/carrier.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model/dto"
|
||||
carrierService "github.com/break/junhong_cmp_fiber/internal/service/carrier"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/response"
|
||||
)
|
||||
|
||||
type CarrierHandler struct {
|
||||
service *carrierService.Service
|
||||
}
|
||||
|
||||
func NewCarrierHandler(service *carrierService.Service) *CarrierHandler {
|
||||
return &CarrierHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) List(c *fiber.Ctx) error {
|
||||
var req dto.CarrierListRequest
|
||||
if err := c.QueryParser(&req); err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
|
||||
}
|
||||
|
||||
carriers, total, err := h.service.List(c.UserContext(), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.SuccessWithPagination(c, carriers, total, req.Page, req.PageSize)
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) Create(c *fiber.Ctx) error {
|
||||
var req dto.CreateCarrierRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
|
||||
}
|
||||
|
||||
carrier, err := h.service.Create(c.UserContext(), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.Success(c, carrier)
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) Get(c *fiber.Ctx) error {
|
||||
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "无效的运营商 ID")
|
||||
}
|
||||
|
||||
carrier, err := h.service.Get(c.UserContext(), uint(id))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.Success(c, carrier)
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) Update(c *fiber.Ctx) error {
|
||||
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "无效的运营商 ID")
|
||||
}
|
||||
|
||||
var req dto.UpdateCarrierRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
|
||||
}
|
||||
|
||||
carrier, err := h.service.Update(c.UserContext(), uint(id), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.Success(c, carrier)
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) Delete(c *fiber.Ctx) error {
|
||||
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "无效的运营商 ID")
|
||||
}
|
||||
|
||||
if err := h.service.Delete(c.UserContext(), uint(id)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.Success(c, nil)
|
||||
}
|
||||
|
||||
func (h *CarrierHandler) UpdateStatus(c *fiber.Ctx) error {
|
||||
id, err := strconv.ParseUint(c.Params("id"), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "无效的运营商 ID")
|
||||
}
|
||||
|
||||
var req dto.UpdateCarrierStatusRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return errors.New(errors.CodeInvalidParam, "请求参数解析失败")
|
||||
}
|
||||
|
||||
if err := h.service.UpdateStatus(c.UserContext(), uint(id), req.Status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return response.Success(c, nil)
|
||||
}
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
type Carrier struct {
|
||||
gorm.Model
|
||||
BaseModel `gorm:"embedded"`
|
||||
CarrierCode string `gorm:"column:carrier_code;type:varchar(50);uniqueIndex:idx_carrier_code,where:deleted_at IS NULL;not null;comment:运营商编码(CMCC/CUCC/CTCC)" json:"carrier_code"`
|
||||
CarrierName string `gorm:"column:carrier_name;type:varchar(100);not null;comment:运营商名称(中国移动/中国联通/中国电信)" json:"carrier_name"`
|
||||
CarrierType string `gorm:"column:carrier_type;type:varchar(20);not null;default:'CMCC';uniqueIndex:idx_carrier_type_channel,priority:1,where:deleted_at IS NULL;comment:运营商类型" json:"carrier_type"`
|
||||
ChannelName *string `gorm:"column:channel_name;type:varchar(100);comment:渠道名称" json:"channel_name,omitempty"`
|
||||
ChannelCode *string `gorm:"column:channel_code;type:varchar(50);uniqueIndex:idx_carrier_type_channel,priority:2,where:deleted_at IS NULL;comment:渠道编码" json:"channel_code,omitempty"`
|
||||
Description string `gorm:"column:description;type:varchar(500);comment:运营商描述" json:"description"`
|
||||
Status int `gorm:"column:status;type:int;default:1;comment:状态 1-启用 2-禁用" json:"status"`
|
||||
CarrierCode string `gorm:"column:carrier_code;type:varchar(50);uniqueIndex:idx_carrier_code,where:deleted_at IS NULL;not null;comment:运营商编码" json:"carrier_code"`
|
||||
CarrierName string `gorm:"column:carrier_name;type:varchar(100);not null;comment:运营商名称" json:"carrier_name"`
|
||||
CarrierType string `gorm:"column:carrier_type;type:varchar(20);not null;default:'CMCC';comment:运营商类型(CMCC/CUCC/CTCC/CBN)" json:"carrier_type"`
|
||||
Description string `gorm:"column:description;type:varchar(500);comment:运营商描述" json:"description"`
|
||||
Status int `gorm:"column:status;type:int;default:1;comment:状态 1-启用 0-禁用" json:"status"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
|
||||
54
internal/model/dto/carrier_dto.go
Normal file
54
internal/model/dto/carrier_dto.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package dto
|
||||
|
||||
type CreateCarrierRequest struct {
|
||||
CarrierCode string `json:"carrier_code" validate:"required,min=1,max=50" required:"true" minLength:"1" maxLength:"50" description:"运营商编码"`
|
||||
CarrierName string `json:"carrier_name" validate:"required,min=1,max=100" required:"true" minLength:"1" maxLength:"100" description:"运营商名称"`
|
||||
CarrierType string `json:"carrier_type" validate:"required,oneof=CMCC CUCC CTCC CBN" required:"true" description:"运营商类型 (CMCC:中国移动, CUCC:中国联通, CTCC:中国电信, CBN:中国广电)"`
|
||||
Description string `json:"description" validate:"omitempty,max=500" maxLength:"500" description:"运营商描述"`
|
||||
}
|
||||
|
||||
type UpdateCarrierRequest struct {
|
||||
CarrierName *string `json:"carrier_name" validate:"omitempty,min=1,max=100" minLength:"1" maxLength:"100" description:"运营商名称"`
|
||||
Description *string `json:"description" validate:"omitempty,max=500" maxLength:"500" description:"运营商描述"`
|
||||
}
|
||||
|
||||
type CarrierListRequest struct {
|
||||
Page int `json:"page" query:"page" validate:"omitempty,min=1" minimum:"1" description:"页码"`
|
||||
PageSize int `json:"page_size" query:"page_size" validate:"omitempty,min=1,max=100" minimum:"1" maximum:"100" description:"每页数量"`
|
||||
CarrierType *string `json:"carrier_type" query:"carrier_type" validate:"omitempty,oneof=CMCC CUCC CTCC CBN" description:"运营商类型 (CMCC:中国移动, CUCC:中国联通, CTCC:中国电信, CBN:中国广电)"`
|
||||
CarrierName *string `json:"carrier_name" query:"carrier_name" validate:"omitempty,max=100" maxLength:"100" description:"运营商名称(模糊搜索)"`
|
||||
Status *int `json:"status" query:"status" validate:"omitempty,oneof=0 1" description:"状态 (1:启用, 0:禁用)"`
|
||||
}
|
||||
|
||||
type UpdateCarrierStatusRequest struct {
|
||||
Status int `json:"status" validate:"required,oneof=0 1" required:"true" description:"状态 (1:启用, 0:禁用)"`
|
||||
}
|
||||
|
||||
type CarrierResponse struct {
|
||||
ID uint `json:"id" description:"运营商ID"`
|
||||
CarrierCode string `json:"carrier_code" description:"运营商编码"`
|
||||
CarrierName string `json:"carrier_name" description:"运营商名称"`
|
||||
CarrierType string `json:"carrier_type" description:"运营商类型 (CMCC:中国移动, CUCC:中国联通, CTCC:中国电信, CBN:中国广电)"`
|
||||
Description string `json:"description" description:"运营商描述"`
|
||||
Status int `json:"status" description:"状态 (1:启用, 0:禁用)"`
|
||||
CreatedAt string `json:"created_at" description:"创建时间"`
|
||||
UpdatedAt string `json:"updated_at" description:"更新时间"`
|
||||
}
|
||||
|
||||
type UpdateCarrierParams struct {
|
||||
IDReq
|
||||
UpdateCarrierRequest
|
||||
}
|
||||
|
||||
type UpdateCarrierStatusParams struct {
|
||||
IDReq
|
||||
UpdateCarrierStatusRequest
|
||||
}
|
||||
|
||||
type CarrierPageResult struct {
|
||||
List []*CarrierResponse `json:"list" description:"运营商列表"`
|
||||
Total int64 `json:"total" description:"总数"`
|
||||
Page int `json:"page" description:"当前页"`
|
||||
PageSize int `json:"page_size" description:"每页数量"`
|
||||
TotalPages int `json:"total_pages" description:"总页数"`
|
||||
}
|
||||
@@ -24,6 +24,7 @@ type StandaloneIotCardResponse struct {
|
||||
CardType string `json:"card_type" description:"卡类型"`
|
||||
CardCategory string `json:"card_category" description:"卡业务类型 (normal:普通卡, industry:行业卡)"`
|
||||
CarrierID uint `json:"carrier_id" description:"运营商ID"`
|
||||
CarrierType string `json:"carrier_type,omitempty" description:"运营商类型 (CMCC:中国移动, CUCC:中国联通, CTCC:中国电信, CBN:中国广电)"`
|
||||
CarrierName string `json:"carrier_name,omitempty" description:"运营商名称"`
|
||||
IMSI string `json:"imsi,omitempty" description:"IMSI"`
|
||||
MSISDN string `json:"msisdn,omitempty" description:"卡接入号"`
|
||||
@@ -79,6 +80,7 @@ type ImportTaskResponse struct {
|
||||
Status int `json:"status" description:"任务状态 (1:待处理, 2:处理中, 3:已完成, 4:失败)"`
|
||||
StatusText string `json:"status_text" description:"任务状态文本"`
|
||||
CarrierID uint `json:"carrier_id" description:"运营商ID"`
|
||||
CarrierType string `json:"carrier_type,omitempty" description:"运营商类型 (CMCC:中国移动, CUCC:中国联通, CTCC:中国电信, CBN:中国广电)"`
|
||||
CarrierName string `json:"carrier_name,omitempty" description:"运营商名称"`
|
||||
BatchNo string `json:"batch_no,omitempty" description:"批次号"`
|
||||
FileName string `json:"file_name,omitempty" description:"文件名"`
|
||||
|
||||
@@ -16,6 +16,8 @@ type IotCard struct {
|
||||
CardType string `gorm:"column:card_type;type:varchar(50);not null;comment:卡类型" json:"card_type"`
|
||||
CardCategory string `gorm:"column:card_category;type:varchar(20);default:'normal';not null;comment:卡业务类型 normal-普通卡 industry-行业卡" json:"card_category"`
|
||||
CarrierID uint `gorm:"column:carrier_id;index;not null;comment:运营商ID" json:"carrier_id"`
|
||||
CarrierType string `gorm:"column:carrier_type;type:varchar(20);comment:运营商类型(CMCC/CUCC/CTCC/CBN),导入时快照" json:"carrier_type"`
|
||||
CarrierName string `gorm:"column:carrier_name;type:varchar(100);comment:运营商名称,导入时快照" json:"carrier_name"`
|
||||
IMSI string `gorm:"column:imsi;type:varchar(50);comment:IMSI" json:"imsi"`
|
||||
MSISDN string `gorm:"column:msisdn;type:varchar(20);comment:MSISDN(手机号码)" json:"msisdn"`
|
||||
BatchNo string `gorm:"column:batch_no;type:varchar(100);comment:批次号" json:"batch_no"`
|
||||
|
||||
@@ -15,6 +15,7 @@ type IotCardImportTask struct {
|
||||
Status int `gorm:"column:status;type:int;default:1;not null;comment:任务状态 1-待处理 2-处理中 3-已完成 4-失败" json:"status"`
|
||||
CarrierID uint `gorm:"column:carrier_id;index;not null;comment:运营商ID" json:"carrier_id"`
|
||||
CarrierType string `gorm:"column:carrier_type;type:varchar(20);not null;comment:运营商类型(CMCC/CUCC/CTCC/CBN)" json:"carrier_type"`
|
||||
CarrierName string `gorm:"column:carrier_name;type:varchar(100);comment:运营商名称,创建任务时快照" json:"carrier_name"`
|
||||
BatchNo string `gorm:"column:batch_no;type:varchar(100);comment:批次号" json:"batch_no"`
|
||||
FileName string `gorm:"column:file_name;type:varchar(255);comment:原始文件名" json:"file_name"`
|
||||
TotalCount int `gorm:"column:total_count;type:int;default:0;not null;comment:总数" json:"total_count"`
|
||||
|
||||
@@ -67,6 +67,9 @@ func RegisterAdminRoutes(router fiber.Router, handlers *bootstrap.Handlers, midd
|
||||
if handlers.Storage != nil {
|
||||
registerStorageRoutes(authGroup, handlers.Storage, doc, basePath)
|
||||
}
|
||||
if handlers.Carrier != nil {
|
||||
registerCarrierRoutes(authGroup, handlers.Carrier, doc, basePath)
|
||||
}
|
||||
}
|
||||
|
||||
func registerAdminAuthRoutes(router fiber.Router, handler interface{}, authMiddleware fiber.Handler, doc *openapi.Generator, basePath string) {
|
||||
|
||||
62
internal/routes/carrier.go
Normal file
62
internal/routes/carrier.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/handler/admin"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model/dto"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/openapi"
|
||||
)
|
||||
|
||||
func registerCarrierRoutes(router fiber.Router, handler *admin.CarrierHandler, doc *openapi.Generator, basePath string) {
|
||||
carriers := router.Group("/carriers")
|
||||
groupPath := basePath + "/carriers"
|
||||
|
||||
Register(carriers, doc, groupPath, "GET", "", handler.List, RouteSpec{
|
||||
Summary: "运营商列表",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.CarrierListRequest),
|
||||
Output: new(dto.CarrierPageResult),
|
||||
Auth: true,
|
||||
})
|
||||
|
||||
Register(carriers, doc, groupPath, "POST", "", handler.Create, RouteSpec{
|
||||
Summary: "创建运营商",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.CreateCarrierRequest),
|
||||
Output: new(dto.CarrierResponse),
|
||||
Auth: true,
|
||||
})
|
||||
|
||||
Register(carriers, doc, groupPath, "GET", "/:id", handler.Get, RouteSpec{
|
||||
Summary: "获取运营商详情",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.IDReq),
|
||||
Output: new(dto.CarrierResponse),
|
||||
Auth: true,
|
||||
})
|
||||
|
||||
Register(carriers, doc, groupPath, "PUT", "/:id", handler.Update, RouteSpec{
|
||||
Summary: "更新运营商",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.UpdateCarrierParams),
|
||||
Output: new(dto.CarrierResponse),
|
||||
Auth: true,
|
||||
})
|
||||
|
||||
Register(carriers, doc, groupPath, "DELETE", "/:id", handler.Delete, RouteSpec{
|
||||
Summary: "删除运营商",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.IDReq),
|
||||
Output: nil,
|
||||
Auth: true,
|
||||
})
|
||||
|
||||
Register(carriers, doc, groupPath, "PUT", "/:id/status", handler.UpdateStatus, RouteSpec{
|
||||
Summary: "更新运营商状态",
|
||||
Tags: []string{"运营商管理"},
|
||||
Input: new(dto.UpdateCarrierStatusParams),
|
||||
Output: nil,
|
||||
Auth: true,
|
||||
})
|
||||
}
|
||||
182
internal/service/carrier/service.go
Normal file
182
internal/service/carrier/service.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package carrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model/dto"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
carrierStore *postgres.CarrierStore
|
||||
}
|
||||
|
||||
func New(carrierStore *postgres.CarrierStore) *Service {
|
||||
return &Service{carrierStore: carrierStore}
|
||||
}
|
||||
|
||||
func (s *Service) Create(ctx context.Context, req *dto.CreateCarrierRequest) (*dto.CarrierResponse, error) {
|
||||
currentUserID := middleware.GetUserIDFromContext(ctx)
|
||||
if currentUserID == 0 {
|
||||
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
|
||||
}
|
||||
|
||||
existing, _ := s.carrierStore.GetByCode(ctx, req.CarrierCode)
|
||||
if existing != nil {
|
||||
return nil, errors.New(errors.CodeCarrierCodeExists, "运营商编码已存在")
|
||||
}
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: req.CarrierCode,
|
||||
CarrierName: req.CarrierName,
|
||||
CarrierType: req.CarrierType,
|
||||
Description: req.Description,
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
carrier.Creator = currentUserID
|
||||
|
||||
if err := s.carrierStore.Create(ctx, carrier); err != nil {
|
||||
return nil, fmt.Errorf("创建运营商失败: %w", err)
|
||||
}
|
||||
|
||||
return s.toResponse(carrier), nil
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, id uint) (*dto.CarrierResponse, error) {
|
||||
carrier, err := s.carrierStore.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, errors.New(errors.CodeCarrierNotFound, "运营商不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取运营商失败: %w", err)
|
||||
}
|
||||
return s.toResponse(carrier), nil
|
||||
}
|
||||
|
||||
func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateCarrierRequest) (*dto.CarrierResponse, error) {
|
||||
currentUserID := middleware.GetUserIDFromContext(ctx)
|
||||
if currentUserID == 0 {
|
||||
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
|
||||
}
|
||||
|
||||
carrier, err := s.carrierStore.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, errors.New(errors.CodeCarrierNotFound, "运营商不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取运营商失败: %w", err)
|
||||
}
|
||||
|
||||
if req.CarrierName != nil {
|
||||
carrier.CarrierName = *req.CarrierName
|
||||
}
|
||||
if req.Description != nil {
|
||||
carrier.Description = *req.Description
|
||||
}
|
||||
carrier.Updater = currentUserID
|
||||
|
||||
if err := s.carrierStore.Update(ctx, carrier); err != nil {
|
||||
return nil, fmt.Errorf("更新运营商失败: %w", err)
|
||||
}
|
||||
|
||||
return s.toResponse(carrier), nil
|
||||
}
|
||||
|
||||
func (s *Service) Delete(ctx context.Context, id uint) error {
|
||||
_, err := s.carrierStore.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return errors.New(errors.CodeCarrierNotFound, "运营商不存在")
|
||||
}
|
||||
return fmt.Errorf("获取运营商失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.carrierStore.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("删除运营商失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) List(ctx context.Context, req *dto.CarrierListRequest) ([]*dto.CarrierResponse, int64, error) {
|
||||
opts := &store.QueryOptions{
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
OrderBy: "id DESC",
|
||||
}
|
||||
if opts.Page == 0 {
|
||||
opts.Page = 1
|
||||
}
|
||||
if opts.PageSize == 0 {
|
||||
opts.PageSize = constants.DefaultPageSize
|
||||
}
|
||||
|
||||
filters := make(map[string]interface{})
|
||||
if req.CarrierType != nil {
|
||||
filters["carrier_type"] = *req.CarrierType
|
||||
}
|
||||
if req.CarrierName != nil {
|
||||
filters["carrier_name"] = *req.CarrierName
|
||||
}
|
||||
if req.Status != nil {
|
||||
filters["status"] = *req.Status
|
||||
}
|
||||
|
||||
carriers, total, err := s.carrierStore.List(ctx, opts, filters)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询运营商列表失败: %w", err)
|
||||
}
|
||||
|
||||
responses := make([]*dto.CarrierResponse, len(carriers))
|
||||
for i, c := range carriers {
|
||||
responses[i] = s.toResponse(c)
|
||||
}
|
||||
|
||||
return responses, total, nil
|
||||
}
|
||||
|
||||
func (s *Service) UpdateStatus(ctx context.Context, id uint, status int) error {
|
||||
currentUserID := middleware.GetUserIDFromContext(ctx)
|
||||
if currentUserID == 0 {
|
||||
return errors.New(errors.CodeUnauthorized, "未授权访问")
|
||||
}
|
||||
|
||||
carrier, err := s.carrierStore.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return errors.New(errors.CodeCarrierNotFound, "运营商不存在")
|
||||
}
|
||||
return fmt.Errorf("获取运营商失败: %w", err)
|
||||
}
|
||||
|
||||
carrier.Status = status
|
||||
carrier.Updater = currentUserID
|
||||
|
||||
if err := s.carrierStore.Update(ctx, carrier); err != nil {
|
||||
return fmt.Errorf("更新运营商状态失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) toResponse(c *model.Carrier) *dto.CarrierResponse {
|
||||
return &dto.CarrierResponse{
|
||||
ID: c.ID,
|
||||
CarrierCode: c.CarrierCode,
|
||||
CarrierName: c.CarrierName,
|
||||
CarrierType: c.CarrierType,
|
||||
Description: c.Description,
|
||||
Status: c.Status,
|
||||
CreatedAt: c.CreatedAt.Format(time.RFC3339),
|
||||
UpdatedAt: c.UpdatedAt.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
268
internal/service/carrier/service_test.go
Normal file
268
internal/service/carrier/service_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package carrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model/dto"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/tests/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCarrierService_Create(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
t.Run("创建成功", func(t *testing.T) {
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_CMCC_001",
|
||||
CarrierName: "中国移动-服务测试",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
Description: "服务层测试",
|
||||
}
|
||||
|
||||
resp, err := svc.Create(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, resp.ID)
|
||||
assert.Equal(t, req.CarrierCode, resp.CarrierCode)
|
||||
assert.Equal(t, req.CarrierName, resp.CarrierName)
|
||||
assert.Equal(t, constants.StatusEnabled, resp.Status)
|
||||
})
|
||||
|
||||
t.Run("编码重复失败", func(t *testing.T) {
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_CMCC_001",
|
||||
CarrierName: "中国移动-重复",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
}
|
||||
|
||||
_, err := svc.Create(ctx, req)
|
||||
require.Error(t, err)
|
||||
appErr, ok := err.(*errors.AppError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, errors.CodeCarrierCodeExists, appErr.Code)
|
||||
})
|
||||
|
||||
t.Run("未授权失败", func(t *testing.T) {
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_CMCC_002",
|
||||
CarrierName: "未授权测试",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
}
|
||||
|
||||
_, err := svc.Create(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierService_Get(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_GET_001",
|
||||
CarrierName: "查询测试",
|
||||
CarrierType: constants.CarrierTypeCUCC,
|
||||
}
|
||||
created, err := svc.Create(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("查询存在的运营商", func(t *testing.T) {
|
||||
resp, err := svc.Get(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.CarrierCode, resp.CarrierCode)
|
||||
})
|
||||
|
||||
t.Run("查询不存在的运营商", func(t *testing.T) {
|
||||
_, err := svc.Get(ctx, 99999)
|
||||
require.Error(t, err)
|
||||
appErr, ok := err.(*errors.AppError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, errors.CodeCarrierNotFound, appErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierService_Update(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_UPD_001",
|
||||
CarrierName: "更新测试",
|
||||
CarrierType: constants.CarrierTypeCTCC,
|
||||
}
|
||||
created, err := svc.Create(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("更新成功", func(t *testing.T) {
|
||||
newName := "更新后的名称"
|
||||
newDesc := "更新后的描述"
|
||||
updateReq := &dto.UpdateCarrierRequest{
|
||||
CarrierName: &newName,
|
||||
Description: &newDesc,
|
||||
}
|
||||
|
||||
resp, err := svc.Update(ctx, created.ID, updateReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newName, resp.CarrierName)
|
||||
assert.Equal(t, newDesc, resp.Description)
|
||||
})
|
||||
|
||||
t.Run("更新不存在的运营商", func(t *testing.T) {
|
||||
newName := "test"
|
||||
updateReq := &dto.UpdateCarrierRequest{
|
||||
CarrierName: &newName,
|
||||
}
|
||||
|
||||
_, err := svc.Update(ctx, 99999, updateReq)
|
||||
require.Error(t, err)
|
||||
appErr, ok := err.(*errors.AppError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, errors.CodeCarrierNotFound, appErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierService_Delete(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_DEL_001",
|
||||
CarrierName: "删除测试",
|
||||
CarrierType: constants.CarrierTypeCBN,
|
||||
}
|
||||
created, err := svc.Create(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("删除成功", func(t *testing.T) {
|
||||
err := svc.Delete(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.Get(ctx, created.ID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("删除不存在的运营商", func(t *testing.T) {
|
||||
err := svc.Delete(ctx, 99999)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierService_List(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
carriers := []dto.CreateCarrierRequest{
|
||||
{CarrierCode: "SVC_LIST_001", CarrierName: "移动列表", CarrierType: constants.CarrierTypeCMCC},
|
||||
{CarrierCode: "SVC_LIST_002", CarrierName: "联通列表", CarrierType: constants.CarrierTypeCUCC},
|
||||
{CarrierCode: "SVC_LIST_003", CarrierName: "电信列表", CarrierType: constants.CarrierTypeCTCC},
|
||||
}
|
||||
for _, c := range carriers {
|
||||
_, err := svc.Create(ctx, &c)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("查询列表", func(t *testing.T) {
|
||||
req := &dto.CarrierListRequest{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
result, total, err := svc.List(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(3))
|
||||
assert.GreaterOrEqual(t, len(result), 3)
|
||||
})
|
||||
|
||||
t.Run("按类型过滤", func(t *testing.T) {
|
||||
carrierType := constants.CarrierTypeCMCC
|
||||
req := &dto.CarrierListRequest{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
CarrierType: &carrierType,
|
||||
}
|
||||
result, total, err := svc.List(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(1))
|
||||
for _, c := range result {
|
||||
assert.Equal(t, constants.CarrierTypeCMCC, c.CarrierType)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierService_UpdateStatus(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
store := postgres.NewCarrierStore(tx)
|
||||
svc := New(store)
|
||||
|
||||
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
})
|
||||
|
||||
req := &dto.CreateCarrierRequest{
|
||||
CarrierCode: "SVC_STATUS_001",
|
||||
CarrierName: "状态测试",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
}
|
||||
created, err := svc.Create(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, constants.StatusEnabled, created.Status)
|
||||
|
||||
t.Run("禁用运营商", func(t *testing.T) {
|
||||
err := svc.UpdateStatus(ctx, created.ID, constants.StatusDisabled)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, constants.StatusDisabled, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("启用运营商", func(t *testing.T) {
|
||||
err := svc.UpdateStatus(ctx, created.ID, constants.StatusEnabled)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, constants.StatusEnabled, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("更新不存在的运营商状态", func(t *testing.T) {
|
||||
err := svc.UpdateStatus(ctx, 99999, 1)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -41,8 +41,6 @@ func (s *Service) ListBindings(ctx context.Context, deviceID uint) (*dto.ListDev
|
||||
cardMap[card.ID] = card
|
||||
}
|
||||
|
||||
carrierMap := s.loadCarrierData(ctx, cards)
|
||||
|
||||
responses := make([]*dto.DeviceCardBindingResponse, 0, len(bindings))
|
||||
for _, binding := range bindings {
|
||||
card := cardMap[binding.IotCardID]
|
||||
@@ -56,7 +54,7 @@ func (s *Service) ListBindings(ctx context.Context, deviceID uint) (*dto.ListDev
|
||||
IotCardID: binding.IotCardID,
|
||||
ICCID: card.ICCID,
|
||||
MSISDN: card.MSISDN,
|
||||
CarrierName: carrierMap[card.CarrierID],
|
||||
CarrierName: card.CarrierName, // 直接使用 IotCard 的冗余字段
|
||||
Status: card.Status,
|
||||
BindTime: binding.BindTime,
|
||||
}
|
||||
@@ -147,26 +145,3 @@ func (s *Service) UnbindCard(ctx context.Context, deviceID uint, cardID uint) (*
|
||||
Message: "解绑成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadCarrierData(ctx context.Context, cards []*model.IotCard) map[uint]string {
|
||||
carrierIDs := make([]uint, 0)
|
||||
carrierIDSet := make(map[uint]bool)
|
||||
|
||||
for _, card := range cards {
|
||||
if card.CarrierID > 0 && !carrierIDSet[card.CarrierID] {
|
||||
carrierIDs = append(carrierIDs, card.CarrierID)
|
||||
carrierIDSet[card.CarrierID] = true
|
||||
}
|
||||
}
|
||||
|
||||
carrierMap := make(map[uint]string)
|
||||
if len(carrierIDs) > 0 {
|
||||
var carriers []model.Carrier
|
||||
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
|
||||
for _, c := range carriers {
|
||||
carrierMap[c.ID] = c.CarrierName
|
||||
}
|
||||
}
|
||||
|
||||
return carrierMap
|
||||
}
|
||||
|
||||
@@ -88,11 +88,11 @@ func (s *Service) ListStandalone(ctx context.Context, req *dto.ListStandaloneIot
|
||||
return nil, err
|
||||
}
|
||||
|
||||
carrierMap, shopMap := s.loadRelatedData(ctx, cards)
|
||||
shopMap := s.loadShopNames(ctx, cards)
|
||||
|
||||
list := make([]*dto.StandaloneIotCardResponse, 0, len(cards))
|
||||
for _, card := range cards {
|
||||
item := s.toStandaloneResponse(card, carrierMap, shopMap)
|
||||
item := s.toStandaloneResponse(card, shopMap)
|
||||
list = append(list, item)
|
||||
}
|
||||
|
||||
@@ -120,40 +120,25 @@ func (s *Service) GetByICCID(ctx context.Context, iccid string) (*dto.IotCardDet
|
||||
return nil, err
|
||||
}
|
||||
|
||||
carrierMap, shopMap := s.loadRelatedData(ctx, []*model.IotCard{card})
|
||||
standaloneResp := s.toStandaloneResponse(card, carrierMap, shopMap)
|
||||
shopMap := s.loadShopNames(ctx, []*model.IotCard{card})
|
||||
standaloneResp := s.toStandaloneResponse(card, shopMap)
|
||||
|
||||
return &dto.IotCardDetailResponse{
|
||||
StandaloneIotCardResponse: *standaloneResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadRelatedData(ctx context.Context, cards []*model.IotCard) (map[uint]string, map[uint]string) {
|
||||
carrierIDs := make([]uint, 0)
|
||||
func (s *Service) loadShopNames(ctx context.Context, cards []*model.IotCard) map[uint]string {
|
||||
shopIDs := make([]uint, 0)
|
||||
carrierIDSet := make(map[uint]bool)
|
||||
shopIDSet := make(map[uint]bool)
|
||||
|
||||
for _, card := range cards {
|
||||
if card.CarrierID > 0 && !carrierIDSet[card.CarrierID] {
|
||||
carrierIDs = append(carrierIDs, card.CarrierID)
|
||||
carrierIDSet[card.CarrierID] = true
|
||||
}
|
||||
if card.ShopID != nil && *card.ShopID > 0 && !shopIDSet[*card.ShopID] {
|
||||
shopIDs = append(shopIDs, *card.ShopID)
|
||||
shopIDSet[*card.ShopID] = true
|
||||
}
|
||||
}
|
||||
|
||||
carrierMap := make(map[uint]string)
|
||||
if len(carrierIDs) > 0 {
|
||||
var carriers []model.Carrier
|
||||
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
|
||||
for _, c := range carriers {
|
||||
carrierMap[c.ID] = c.CarrierName
|
||||
}
|
||||
}
|
||||
|
||||
shopMap := make(map[uint]string)
|
||||
if len(shopIDs) > 0 {
|
||||
var shops []model.Shop
|
||||
@@ -163,17 +148,18 @@ func (s *Service) loadRelatedData(ctx context.Context, cards []*model.IotCard) (
|
||||
}
|
||||
}
|
||||
|
||||
return carrierMap, shopMap
|
||||
return shopMap
|
||||
}
|
||||
|
||||
func (s *Service) toStandaloneResponse(card *model.IotCard, carrierMap map[uint]string, shopMap map[uint]string) *dto.StandaloneIotCardResponse {
|
||||
func (s *Service) toStandaloneResponse(card *model.IotCard, shopMap map[uint]string) *dto.StandaloneIotCardResponse {
|
||||
resp := &dto.StandaloneIotCardResponse{
|
||||
ID: card.ID,
|
||||
ICCID: card.ICCID,
|
||||
CardType: card.CardType,
|
||||
CardCategory: card.CardCategory,
|
||||
CarrierID: card.CarrierID,
|
||||
CarrierName: carrierMap[card.CarrierID],
|
||||
CarrierType: card.CarrierType,
|
||||
CarrierName: card.CarrierName,
|
||||
IMSI: card.IMSI,
|
||||
MSISDN: card.MSISDN,
|
||||
BatchNo: card.BatchNo,
|
||||
|
||||
@@ -76,6 +76,7 @@ func (s *Service) CreateImportTask(ctx context.Context, req *dto.ImportIotCardRe
|
||||
Status: model.ImportTaskStatusPending,
|
||||
CarrierID: req.CarrierID,
|
||||
CarrierType: carrier.CarrierType,
|
||||
CarrierName: carrier.CarrierName,
|
||||
BatchNo: req.BatchNo,
|
||||
FileName: fileName,
|
||||
StorageKey: req.FileKey,
|
||||
@@ -138,11 +139,9 @@ func (s *Service) List(ctx context.Context, req *dto.ListImportTaskRequest) (*dt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
carrierMap := s.loadCarriers(ctx, tasks)
|
||||
|
||||
list := make([]*dto.ImportTaskResponse, 0, len(tasks))
|
||||
for _, task := range tasks {
|
||||
list = append(list, s.toTaskResponse(task, carrierMap))
|
||||
list = append(list, s.toTaskResponse(task))
|
||||
}
|
||||
|
||||
totalPages := int(total) / pageSize
|
||||
@@ -165,14 +164,8 @@ func (s *Service) GetByID(ctx context.Context, id uint) (*dto.ImportTaskDetailRe
|
||||
return nil, errors.New(errors.CodeNotFound, "导入任务不存在")
|
||||
}
|
||||
|
||||
carrierMap := make(map[uint]string)
|
||||
var carrier model.Carrier
|
||||
if s.db.WithContext(ctx).First(&carrier, task.CarrierID).Error == nil {
|
||||
carrierMap[carrier.ID] = carrier.CarrierName
|
||||
}
|
||||
|
||||
resp := &dto.ImportTaskDetailResponse{
|
||||
ImportTaskResponse: *s.toTaskResponse(task, carrierMap),
|
||||
ImportTaskResponse: *s.toTaskResponse(task),
|
||||
SkippedItems: make([]*dto.ImportResultItemDTO, 0),
|
||||
FailedItems: make([]*dto.ImportResultItemDTO, 0),
|
||||
}
|
||||
@@ -198,28 +191,7 @@ func (s *Service) GetByID(ctx context.Context, id uint) (*dto.ImportTaskDetailRe
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadCarriers(ctx context.Context, tasks []*model.IotCardImportTask) map[uint]string {
|
||||
carrierIDs := make([]uint, 0)
|
||||
carrierIDSet := make(map[uint]bool)
|
||||
for _, task := range tasks {
|
||||
if task.CarrierID > 0 && !carrierIDSet[task.CarrierID] {
|
||||
carrierIDs = append(carrierIDs, task.CarrierID)
|
||||
carrierIDSet[task.CarrierID] = true
|
||||
}
|
||||
}
|
||||
|
||||
carrierMap := make(map[uint]string)
|
||||
if len(carrierIDs) > 0 {
|
||||
var carriers []model.Carrier
|
||||
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
|
||||
for _, c := range carriers {
|
||||
carrierMap[c.ID] = c.CarrierName
|
||||
}
|
||||
}
|
||||
return carrierMap
|
||||
}
|
||||
|
||||
func (s *Service) toTaskResponse(task *model.IotCardImportTask, carrierMap map[uint]string) *dto.ImportTaskResponse {
|
||||
func (s *Service) toTaskResponse(task *model.IotCardImportTask) *dto.ImportTaskResponse {
|
||||
var startedAt, completedAt *time.Time
|
||||
if task.StartedAt != nil {
|
||||
startedAt = task.StartedAt
|
||||
@@ -234,7 +206,8 @@ func (s *Service) toTaskResponse(task *model.IotCardImportTask, carrierMap map[u
|
||||
Status: task.Status,
|
||||
StatusText: getStatusText(task.Status),
|
||||
CarrierID: task.CarrierID,
|
||||
CarrierName: carrierMap[task.CarrierID],
|
||||
CarrierType: task.CarrierType,
|
||||
CarrierName: task.CarrierName,
|
||||
BatchNo: task.BatchNo,
|
||||
FileName: task.FileName,
|
||||
TotalCount: task.TotalCount,
|
||||
|
||||
83
internal/store/postgres/carrier_store.go
Normal file
83
internal/store/postgres/carrier_store.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store"
|
||||
)
|
||||
|
||||
type CarrierStore struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewCarrierStore(db *gorm.DB) *CarrierStore {
|
||||
return &CarrierStore{db: db}
|
||||
}
|
||||
|
||||
func (s *CarrierStore) Create(ctx context.Context, carrier *model.Carrier) error {
|
||||
return s.db.WithContext(ctx).Create(carrier).Error
|
||||
}
|
||||
|
||||
func (s *CarrierStore) GetByID(ctx context.Context, id uint) (*model.Carrier, error) {
|
||||
var carrier model.Carrier
|
||||
if err := s.db.WithContext(ctx).First(&carrier, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &carrier, nil
|
||||
}
|
||||
|
||||
func (s *CarrierStore) GetByCode(ctx context.Context, code string) (*model.Carrier, error) {
|
||||
var carrier model.Carrier
|
||||
if err := s.db.WithContext(ctx).Where("carrier_code = ?", code).First(&carrier).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &carrier, nil
|
||||
}
|
||||
|
||||
func (s *CarrierStore) Update(ctx context.Context, carrier *model.Carrier) error {
|
||||
return s.db.WithContext(ctx).Save(carrier).Error
|
||||
}
|
||||
|
||||
func (s *CarrierStore) Delete(ctx context.Context, id uint) error {
|
||||
return s.db.WithContext(ctx).Delete(&model.Carrier{}, id).Error
|
||||
}
|
||||
|
||||
func (s *CarrierStore) List(ctx context.Context, opts *store.QueryOptions, filters map[string]interface{}) ([]*model.Carrier, int64, error) {
|
||||
var carriers []*model.Carrier
|
||||
var total int64
|
||||
|
||||
query := s.db.WithContext(ctx).Model(&model.Carrier{})
|
||||
|
||||
if carrierType, ok := filters["carrier_type"].(string); ok && carrierType != "" {
|
||||
query = query.Where("carrier_type = ?", carrierType)
|
||||
}
|
||||
if carrierName, ok := filters["carrier_name"].(string); ok && carrierName != "" {
|
||||
query = query.Where("carrier_name LIKE ?", "%"+carrierName+"%")
|
||||
}
|
||||
if status, ok := filters["status"]; ok {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if opts == nil {
|
||||
opts = store.DefaultQueryOptions()
|
||||
}
|
||||
offset := (opts.Page - 1) * opts.PageSize
|
||||
query = query.Offset(offset).Limit(opts.PageSize)
|
||||
|
||||
if opts.OrderBy != "" {
|
||||
query = query.Order(opts.OrderBy)
|
||||
}
|
||||
|
||||
if err := query.Find(&carriers).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return carriers, total, nil
|
||||
}
|
||||
204
internal/store/postgres/carrier_store_test.go
Normal file
204
internal/store/postgres/carrier_store_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/tests/testutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCarrierStore_Create(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: "CMCC_TEST_001",
|
||||
CarrierName: "中国移动测试",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
Description: "测试运营商",
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
|
||||
err := s.Create(ctx, carrier)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, carrier.ID)
|
||||
}
|
||||
|
||||
func TestCarrierStore_GetByID(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: "CUCC_TEST_001",
|
||||
CarrierName: "中国联通测试",
|
||||
CarrierType: constants.CarrierTypeCUCC,
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
require.NoError(t, s.Create(ctx, carrier))
|
||||
|
||||
t.Run("查询存在的运营商", func(t *testing.T) {
|
||||
result, err := s.GetByID(ctx, carrier.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, carrier.CarrierCode, result.CarrierCode)
|
||||
assert.Equal(t, carrier.CarrierName, result.CarrierName)
|
||||
})
|
||||
|
||||
t.Run("查询不存在的运营商", func(t *testing.T) {
|
||||
_, err := s.GetByID(ctx, 99999)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierStore_GetByCode(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: "CTCC_TEST_001",
|
||||
CarrierName: "中国电信测试",
|
||||
CarrierType: constants.CarrierTypeCTCC,
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
require.NoError(t, s.Create(ctx, carrier))
|
||||
|
||||
t.Run("查询存在的编码", func(t *testing.T) {
|
||||
result, err := s.GetByCode(ctx, "CTCC_TEST_001")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, carrier.ID, result.ID)
|
||||
})
|
||||
|
||||
t.Run("查询不存在的编码", func(t *testing.T) {
|
||||
_, err := s.GetByCode(ctx, "NOT_EXISTS")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCarrierStore_Update(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: "CBN_TEST_001",
|
||||
CarrierName: "中国广电测试",
|
||||
CarrierType: constants.CarrierTypeCBN,
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
require.NoError(t, s.Create(ctx, carrier))
|
||||
|
||||
carrier.CarrierName = "中国广电测试-更新"
|
||||
carrier.Description = "更新后的描述"
|
||||
err := s.Update(ctx, carrier)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.GetByID(ctx, carrier.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "中国广电测试-更新", updated.CarrierName)
|
||||
assert.Equal(t, "更新后的描述", updated.Description)
|
||||
}
|
||||
|
||||
func TestCarrierStore_Delete(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carrier := &model.Carrier{
|
||||
CarrierCode: "DEL_TEST_001",
|
||||
CarrierName: "待删除运营商",
|
||||
CarrierType: constants.CarrierTypeCMCC,
|
||||
Status: constants.StatusEnabled,
|
||||
}
|
||||
require.NoError(t, s.Create(ctx, carrier))
|
||||
|
||||
err := s.Delete(ctx, carrier.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.GetByID(ctx, carrier.ID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCarrierStore_List(t *testing.T) {
|
||||
tx := testutils.NewTestTransaction(t)
|
||||
s := NewCarrierStore(tx)
|
||||
ctx := context.Background()
|
||||
|
||||
carriers := []*model.Carrier{
|
||||
{CarrierCode: "LIST_001", CarrierName: "移动1", CarrierType: constants.CarrierTypeCMCC, Status: constants.StatusEnabled},
|
||||
{CarrierCode: "LIST_002", CarrierName: "联通1", CarrierType: constants.CarrierTypeCUCC, Status: constants.StatusEnabled},
|
||||
{CarrierCode: "LIST_003", CarrierName: "电信1", CarrierType: constants.CarrierTypeCTCC, Status: constants.StatusEnabled},
|
||||
}
|
||||
for _, c := range carriers {
|
||||
require.NoError(t, s.Create(ctx, c))
|
||||
}
|
||||
// 显式更新第三个 carrier 为禁用状态(GORM 不会写入零值)
|
||||
carriers[2].Status = constants.StatusDisabled
|
||||
require.NoError(t, s.Update(ctx, carriers[2]))
|
||||
|
||||
t.Run("查询所有运营商", func(t *testing.T) {
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 20}, nil)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(3))
|
||||
assert.GreaterOrEqual(t, len(result), 3)
|
||||
})
|
||||
|
||||
t.Run("按类型过滤", func(t *testing.T) {
|
||||
filters := map[string]interface{}{"carrier_type": constants.CarrierTypeCMCC}
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 20}, filters)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(1))
|
||||
for _, c := range result {
|
||||
assert.Equal(t, constants.CarrierTypeCMCC, c.CarrierType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("按名称模糊搜索", func(t *testing.T) {
|
||||
filters := map[string]interface{}{"carrier_name": "联通"}
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 20}, filters)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(1))
|
||||
for _, c := range result {
|
||||
assert.Contains(t, c.CarrierName, "联通")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("按状态过滤-禁用", func(t *testing.T) {
|
||||
filters := map[string]interface{}{"status": constants.StatusDisabled}
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 20}, filters)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(1))
|
||||
for _, c := range result {
|
||||
assert.Equal(t, constants.StatusDisabled, c.Status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("按状态过滤-启用", func(t *testing.T) {
|
||||
filters := map[string]interface{}{"status": constants.StatusEnabled}
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 20}, filters)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(2))
|
||||
for _, c := range result {
|
||||
assert.Equal(t, constants.StatusEnabled, c.Status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("分页查询", func(t *testing.T) {
|
||||
result, total, err := s.List(ctx, &store.QueryOptions{Page: 1, PageSize: 2}, nil)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, total, int64(3))
|
||||
assert.LessOrEqual(t, len(result), 2)
|
||||
})
|
||||
|
||||
t.Run("默认分页选项", func(t *testing.T) {
|
||||
result, _, err := s.List(ctx, nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user