feat: 实现运营商模块重构,添加冗余字段优化查询性能
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m16s

主要变更:
- 新增 Carrier CRUD API(创建、列表、详情、更新、删除、状态更新)
- IotCard/IotCardImportTask 添加 carrier_type/carrier_name 冗余字段
- 移除 Carrier 表的 channel_name/channel_code 字段
- 查询时直接使用冗余字段,避免 JOIN Carrier 表
- 添加数据库迁移脚本(000021-000023)
- 添加单元测试和集成测试
- 同步更新 OpenAPI 文档和 specs
This commit is contained in:
2026-01-27 12:18:19 +08:00
parent 5a179ba16b
commit d104d297ca
42 changed files with 2431 additions and 122 deletions

View File

@@ -0,0 +1,182 @@
package carrier
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
)
type Service struct {
carrierStore *postgres.CarrierStore
}
func New(carrierStore *postgres.CarrierStore) *Service {
return &Service{carrierStore: carrierStore}
}
func (s *Service) Create(ctx context.Context, req *dto.CreateCarrierRequest) (*dto.CarrierResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
existing, _ := s.carrierStore.GetByCode(ctx, req.CarrierCode)
if existing != nil {
return nil, errors.New(errors.CodeCarrierCodeExists, "运营商编码已存在")
}
carrier := &model.Carrier{
CarrierCode: req.CarrierCode,
CarrierName: req.CarrierName,
CarrierType: req.CarrierType,
Description: req.Description,
Status: constants.StatusEnabled,
}
carrier.Creator = currentUserID
if err := s.carrierStore.Create(ctx, carrier); err != nil {
return nil, fmt.Errorf("创建运营商失败: %w", err)
}
return s.toResponse(carrier), nil
}
func (s *Service) Get(ctx context.Context, id uint) (*dto.CarrierResponse, error) {
carrier, err := s.carrierStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeCarrierNotFound, "运营商不存在")
}
return nil, fmt.Errorf("获取运营商失败: %w", err)
}
return s.toResponse(carrier), nil
}
func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateCarrierRequest) (*dto.CarrierResponse, error) {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return nil, errors.New(errors.CodeUnauthorized, "未授权访问")
}
carrier, err := s.carrierStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeCarrierNotFound, "运营商不存在")
}
return nil, fmt.Errorf("获取运营商失败: %w", err)
}
if req.CarrierName != nil {
carrier.CarrierName = *req.CarrierName
}
if req.Description != nil {
carrier.Description = *req.Description
}
carrier.Updater = currentUserID
if err := s.carrierStore.Update(ctx, carrier); err != nil {
return nil, fmt.Errorf("更新运营商失败: %w", err)
}
return s.toResponse(carrier), nil
}
func (s *Service) Delete(ctx context.Context, id uint) error {
_, err := s.carrierStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeCarrierNotFound, "运营商不存在")
}
return fmt.Errorf("获取运营商失败: %w", err)
}
if err := s.carrierStore.Delete(ctx, id); err != nil {
return fmt.Errorf("删除运营商失败: %w", err)
}
return nil
}
func (s *Service) List(ctx context.Context, req *dto.CarrierListRequest) ([]*dto.CarrierResponse, int64, error) {
opts := &store.QueryOptions{
Page: req.Page,
PageSize: req.PageSize,
OrderBy: "id DESC",
}
if opts.Page == 0 {
opts.Page = 1
}
if opts.PageSize == 0 {
opts.PageSize = constants.DefaultPageSize
}
filters := make(map[string]interface{})
if req.CarrierType != nil {
filters["carrier_type"] = *req.CarrierType
}
if req.CarrierName != nil {
filters["carrier_name"] = *req.CarrierName
}
if req.Status != nil {
filters["status"] = *req.Status
}
carriers, total, err := s.carrierStore.List(ctx, opts, filters)
if err != nil {
return nil, 0, fmt.Errorf("查询运营商列表失败: %w", err)
}
responses := make([]*dto.CarrierResponse, len(carriers))
for i, c := range carriers {
responses[i] = s.toResponse(c)
}
return responses, total, nil
}
func (s *Service) UpdateStatus(ctx context.Context, id uint, status int) error {
currentUserID := middleware.GetUserIDFromContext(ctx)
if currentUserID == 0 {
return errors.New(errors.CodeUnauthorized, "未授权访问")
}
carrier, err := s.carrierStore.GetByID(ctx, id)
if err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeCarrierNotFound, "运营商不存在")
}
return fmt.Errorf("获取运营商失败: %w", err)
}
carrier.Status = status
carrier.Updater = currentUserID
if err := s.carrierStore.Update(ctx, carrier); err != nil {
return fmt.Errorf("更新运营商状态失败: %w", err)
}
return nil
}
func (s *Service) toResponse(c *model.Carrier) *dto.CarrierResponse {
return &dto.CarrierResponse{
ID: c.ID,
CarrierCode: c.CarrierCode,
CarrierName: c.CarrierName,
CarrierType: c.CarrierType,
Description: c.Description,
Status: c.Status,
CreatedAt: c.CreatedAt.Format(time.RFC3339),
UpdatedAt: c.UpdatedAt.Format(time.RFC3339),
}
}

View File

@@ -0,0 +1,268 @@
package carrier
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/internal/model/dto"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/tests/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCarrierService_Create(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
t.Run("创建成功", func(t *testing.T) {
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_CMCC_001",
CarrierName: "中国移动-服务测试",
CarrierType: constants.CarrierTypeCMCC,
Description: "服务层测试",
}
resp, err := svc.Create(ctx, req)
require.NoError(t, err)
assert.NotZero(t, resp.ID)
assert.Equal(t, req.CarrierCode, resp.CarrierCode)
assert.Equal(t, req.CarrierName, resp.CarrierName)
assert.Equal(t, constants.StatusEnabled, resp.Status)
})
t.Run("编码重复失败", func(t *testing.T) {
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_CMCC_001",
CarrierName: "中国移动-重复",
CarrierType: constants.CarrierTypeCMCC,
}
_, err := svc.Create(ctx, req)
require.Error(t, err)
appErr, ok := err.(*errors.AppError)
require.True(t, ok)
assert.Equal(t, errors.CodeCarrierCodeExists, appErr.Code)
})
t.Run("未授权失败", func(t *testing.T) {
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_CMCC_002",
CarrierName: "未授权测试",
CarrierType: constants.CarrierTypeCMCC,
}
_, err := svc.Create(context.Background(), req)
require.Error(t, err)
})
}
func TestCarrierService_Get(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_GET_001",
CarrierName: "查询测试",
CarrierType: constants.CarrierTypeCUCC,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
t.Run("查询存在的运营商", func(t *testing.T) {
resp, err := svc.Get(ctx, created.ID)
require.NoError(t, err)
assert.Equal(t, created.CarrierCode, resp.CarrierCode)
})
t.Run("查询不存在的运营商", func(t *testing.T) {
_, err := svc.Get(ctx, 99999)
require.Error(t, err)
appErr, ok := err.(*errors.AppError)
require.True(t, ok)
assert.Equal(t, errors.CodeCarrierNotFound, appErr.Code)
})
}
func TestCarrierService_Update(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_UPD_001",
CarrierName: "更新测试",
CarrierType: constants.CarrierTypeCTCC,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
t.Run("更新成功", func(t *testing.T) {
newName := "更新后的名称"
newDesc := "更新后的描述"
updateReq := &dto.UpdateCarrierRequest{
CarrierName: &newName,
Description: &newDesc,
}
resp, err := svc.Update(ctx, created.ID, updateReq)
require.NoError(t, err)
assert.Equal(t, newName, resp.CarrierName)
assert.Equal(t, newDesc, resp.Description)
})
t.Run("更新不存在的运营商", func(t *testing.T) {
newName := "test"
updateReq := &dto.UpdateCarrierRequest{
CarrierName: &newName,
}
_, err := svc.Update(ctx, 99999, updateReq)
require.Error(t, err)
appErr, ok := err.(*errors.AppError)
require.True(t, ok)
assert.Equal(t, errors.CodeCarrierNotFound, appErr.Code)
})
}
func TestCarrierService_Delete(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_DEL_001",
CarrierName: "删除测试",
CarrierType: constants.CarrierTypeCBN,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
t.Run("删除成功", func(t *testing.T) {
err := svc.Delete(ctx, created.ID)
require.NoError(t, err)
_, err = svc.Get(ctx, created.ID)
require.Error(t, err)
})
t.Run("删除不存在的运营商", func(t *testing.T) {
err := svc.Delete(ctx, 99999)
require.Error(t, err)
})
}
func TestCarrierService_List(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
carriers := []dto.CreateCarrierRequest{
{CarrierCode: "SVC_LIST_001", CarrierName: "移动列表", CarrierType: constants.CarrierTypeCMCC},
{CarrierCode: "SVC_LIST_002", CarrierName: "联通列表", CarrierType: constants.CarrierTypeCUCC},
{CarrierCode: "SVC_LIST_003", CarrierName: "电信列表", CarrierType: constants.CarrierTypeCTCC},
}
for _, c := range carriers {
_, err := svc.Create(ctx, &c)
require.NoError(t, err)
}
t.Run("查询列表", func(t *testing.T) {
req := &dto.CarrierListRequest{
Page: 1,
PageSize: 20,
}
result, total, err := svc.List(ctx, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(3))
assert.GreaterOrEqual(t, len(result), 3)
})
t.Run("按类型过滤", func(t *testing.T) {
carrierType := constants.CarrierTypeCMCC
req := &dto.CarrierListRequest{
Page: 1,
PageSize: 20,
CarrierType: &carrierType,
}
result, total, err := svc.List(ctx, req)
require.NoError(t, err)
assert.GreaterOrEqual(t, total, int64(1))
for _, c := range result {
assert.Equal(t, constants.CarrierTypeCMCC, c.CarrierType)
}
})
}
func TestCarrierService_UpdateStatus(t *testing.T) {
tx := testutils.NewTestTransaction(t)
store := postgres.NewCarrierStore(tx)
svc := New(store)
ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
})
req := &dto.CreateCarrierRequest{
CarrierCode: "SVC_STATUS_001",
CarrierName: "状态测试",
CarrierType: constants.CarrierTypeCMCC,
}
created, err := svc.Create(ctx, req)
require.NoError(t, err)
assert.Equal(t, constants.StatusEnabled, created.Status)
t.Run("禁用运营商", func(t *testing.T) {
err := svc.UpdateStatus(ctx, created.ID, constants.StatusDisabled)
require.NoError(t, err)
updated, err := svc.Get(ctx, created.ID)
require.NoError(t, err)
assert.Equal(t, constants.StatusDisabled, updated.Status)
})
t.Run("启用运营商", func(t *testing.T) {
err := svc.UpdateStatus(ctx, created.ID, constants.StatusEnabled)
require.NoError(t, err)
updated, err := svc.Get(ctx, created.ID)
require.NoError(t, err)
assert.Equal(t, constants.StatusEnabled, updated.Status)
})
t.Run("更新不存在的运营商状态", func(t *testing.T) {
err := svc.UpdateStatus(ctx, 99999, 1)
require.Error(t, err)
})
}

View File

@@ -41,8 +41,6 @@ func (s *Service) ListBindings(ctx context.Context, deviceID uint) (*dto.ListDev
cardMap[card.ID] = card
}
carrierMap := s.loadCarrierData(ctx, cards)
responses := make([]*dto.DeviceCardBindingResponse, 0, len(bindings))
for _, binding := range bindings {
card := cardMap[binding.IotCardID]
@@ -56,7 +54,7 @@ func (s *Service) ListBindings(ctx context.Context, deviceID uint) (*dto.ListDev
IotCardID: binding.IotCardID,
ICCID: card.ICCID,
MSISDN: card.MSISDN,
CarrierName: carrierMap[card.CarrierID],
CarrierName: card.CarrierName, // 直接使用 IotCard 的冗余字段
Status: card.Status,
BindTime: binding.BindTime,
}
@@ -147,26 +145,3 @@ func (s *Service) UnbindCard(ctx context.Context, deviceID uint, cardID uint) (*
Message: "解绑成功",
}, nil
}
func (s *Service) loadCarrierData(ctx context.Context, cards []*model.IotCard) map[uint]string {
carrierIDs := make([]uint, 0)
carrierIDSet := make(map[uint]bool)
for _, card := range cards {
if card.CarrierID > 0 && !carrierIDSet[card.CarrierID] {
carrierIDs = append(carrierIDs, card.CarrierID)
carrierIDSet[card.CarrierID] = true
}
}
carrierMap := make(map[uint]string)
if len(carrierIDs) > 0 {
var carriers []model.Carrier
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
for _, c := range carriers {
carrierMap[c.ID] = c.CarrierName
}
}
return carrierMap
}

View File

@@ -88,11 +88,11 @@ func (s *Service) ListStandalone(ctx context.Context, req *dto.ListStandaloneIot
return nil, err
}
carrierMap, shopMap := s.loadRelatedData(ctx, cards)
shopMap := s.loadShopNames(ctx, cards)
list := make([]*dto.StandaloneIotCardResponse, 0, len(cards))
for _, card := range cards {
item := s.toStandaloneResponse(card, carrierMap, shopMap)
item := s.toStandaloneResponse(card, shopMap)
list = append(list, item)
}
@@ -120,40 +120,25 @@ func (s *Service) GetByICCID(ctx context.Context, iccid string) (*dto.IotCardDet
return nil, err
}
carrierMap, shopMap := s.loadRelatedData(ctx, []*model.IotCard{card})
standaloneResp := s.toStandaloneResponse(card, carrierMap, shopMap)
shopMap := s.loadShopNames(ctx, []*model.IotCard{card})
standaloneResp := s.toStandaloneResponse(card, shopMap)
return &dto.IotCardDetailResponse{
StandaloneIotCardResponse: *standaloneResp,
}, nil
}
func (s *Service) loadRelatedData(ctx context.Context, cards []*model.IotCard) (map[uint]string, map[uint]string) {
carrierIDs := make([]uint, 0)
func (s *Service) loadShopNames(ctx context.Context, cards []*model.IotCard) map[uint]string {
shopIDs := make([]uint, 0)
carrierIDSet := make(map[uint]bool)
shopIDSet := make(map[uint]bool)
for _, card := range cards {
if card.CarrierID > 0 && !carrierIDSet[card.CarrierID] {
carrierIDs = append(carrierIDs, card.CarrierID)
carrierIDSet[card.CarrierID] = true
}
if card.ShopID != nil && *card.ShopID > 0 && !shopIDSet[*card.ShopID] {
shopIDs = append(shopIDs, *card.ShopID)
shopIDSet[*card.ShopID] = true
}
}
carrierMap := make(map[uint]string)
if len(carrierIDs) > 0 {
var carriers []model.Carrier
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
for _, c := range carriers {
carrierMap[c.ID] = c.CarrierName
}
}
shopMap := make(map[uint]string)
if len(shopIDs) > 0 {
var shops []model.Shop
@@ -163,17 +148,18 @@ func (s *Service) loadRelatedData(ctx context.Context, cards []*model.IotCard) (
}
}
return carrierMap, shopMap
return shopMap
}
func (s *Service) toStandaloneResponse(card *model.IotCard, carrierMap map[uint]string, shopMap map[uint]string) *dto.StandaloneIotCardResponse {
func (s *Service) toStandaloneResponse(card *model.IotCard, shopMap map[uint]string) *dto.StandaloneIotCardResponse {
resp := &dto.StandaloneIotCardResponse{
ID: card.ID,
ICCID: card.ICCID,
CardType: card.CardType,
CardCategory: card.CardCategory,
CarrierID: card.CarrierID,
CarrierName: carrierMap[card.CarrierID],
CarrierType: card.CarrierType,
CarrierName: card.CarrierName,
IMSI: card.IMSI,
MSISDN: card.MSISDN,
BatchNo: card.BatchNo,

View File

@@ -76,6 +76,7 @@ func (s *Service) CreateImportTask(ctx context.Context, req *dto.ImportIotCardRe
Status: model.ImportTaskStatusPending,
CarrierID: req.CarrierID,
CarrierType: carrier.CarrierType,
CarrierName: carrier.CarrierName,
BatchNo: req.BatchNo,
FileName: fileName,
StorageKey: req.FileKey,
@@ -138,11 +139,9 @@ func (s *Service) List(ctx context.Context, req *dto.ListImportTaskRequest) (*dt
return nil, err
}
carrierMap := s.loadCarriers(ctx, tasks)
list := make([]*dto.ImportTaskResponse, 0, len(tasks))
for _, task := range tasks {
list = append(list, s.toTaskResponse(task, carrierMap))
list = append(list, s.toTaskResponse(task))
}
totalPages := int(total) / pageSize
@@ -165,14 +164,8 @@ func (s *Service) GetByID(ctx context.Context, id uint) (*dto.ImportTaskDetailRe
return nil, errors.New(errors.CodeNotFound, "导入任务不存在")
}
carrierMap := make(map[uint]string)
var carrier model.Carrier
if s.db.WithContext(ctx).First(&carrier, task.CarrierID).Error == nil {
carrierMap[carrier.ID] = carrier.CarrierName
}
resp := &dto.ImportTaskDetailResponse{
ImportTaskResponse: *s.toTaskResponse(task, carrierMap),
ImportTaskResponse: *s.toTaskResponse(task),
SkippedItems: make([]*dto.ImportResultItemDTO, 0),
FailedItems: make([]*dto.ImportResultItemDTO, 0),
}
@@ -198,28 +191,7 @@ func (s *Service) GetByID(ctx context.Context, id uint) (*dto.ImportTaskDetailRe
return resp, nil
}
func (s *Service) loadCarriers(ctx context.Context, tasks []*model.IotCardImportTask) map[uint]string {
carrierIDs := make([]uint, 0)
carrierIDSet := make(map[uint]bool)
for _, task := range tasks {
if task.CarrierID > 0 && !carrierIDSet[task.CarrierID] {
carrierIDs = append(carrierIDs, task.CarrierID)
carrierIDSet[task.CarrierID] = true
}
}
carrierMap := make(map[uint]string)
if len(carrierIDs) > 0 {
var carriers []model.Carrier
s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers)
for _, c := range carriers {
carrierMap[c.ID] = c.CarrierName
}
}
return carrierMap
}
func (s *Service) toTaskResponse(task *model.IotCardImportTask, carrierMap map[uint]string) *dto.ImportTaskResponse {
func (s *Service) toTaskResponse(task *model.IotCardImportTask) *dto.ImportTaskResponse {
var startedAt, completedAt *time.Time
if task.StartedAt != nil {
startedAt = task.StartedAt
@@ -234,7 +206,8 @@ func (s *Service) toTaskResponse(task *model.IotCardImportTask, carrierMap map[u
Status: task.Status,
StatusText: getStatusText(task.Status),
CarrierID: task.CarrierID,
CarrierName: carrierMap[task.CarrierID],
CarrierType: task.CarrierType,
CarrierName: task.CarrierName,
BatchNo: task.BatchNo,
FileName: task.FileName,
TotalCount: task.TotalCount,