package carrier

import (
	"context"
	"testing"

	"github.com/break/junhong_cmp_fiber/internal/model/dto"
	"github.com/break/junhong_cmp_fiber/internal/store/postgres"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/errors"
	"github.com/break/junhong_cmp_fiber/pkg/middleware"
	"github.com/break/junhong_cmp_fiber/tests/testutils"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// nonExistentCarrierID is an ID the "not found" subtests use.
// NOTE(review): assumes no carrier with this ID exists in the test database — confirm.
const nonExistentCarrierID = 99999

// newPlatformUserContext returns a background context carrying a platform
// user (UserID 1), the caller identity used by every authorized call below.
func newPlatformUserContext() context.Context {
	return middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:   1,
		UserType: constants.UserTypePlatform,
	})
}

// requireAppErrorCode fails the test unless err is, or wraps, an
// *errors.AppError whose Code equals expected. ErrorAs is used instead of a
// bare type assertion so the check keeps working if the service wraps errors.
func requireAppErrorCode(t *testing.T, err error, expected any) {
	t.Helper()
	require.Error(t, err)
	var appErr *errors.AppError
	require.ErrorAs(t, err, &appErr)
	assert.Equal(t, expected, appErr.Code)
}

// TestCarrierService_Create covers successful creation (a new carrier starts
// out enabled), rejection of a duplicate carrier code, and rejection of a call
// whose context carries no user.
func TestCarrierService_Create(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	t.Run("创建成功", func(t *testing.T) {
		req := &dto.CreateCarrierRequest{
			CarrierCode: "SVC_CMCC_001",
			CarrierName: "中国移动-服务测试",
			CarrierType: constants.CarrierTypeCMCC,
			Description: "服务层测试",
		}
		resp, err := svc.Create(ctx, req)
		require.NoError(t, err)
		assert.NotZero(t, resp.ID)
		assert.Equal(t, req.CarrierCode, resp.CarrierCode)
		assert.Equal(t, req.CarrierName, resp.CarrierName)
		assert.Equal(t, constants.StatusEnabled, resp.Status)
	})

	// Relies on the previous subtest having created SVC_CMCC_001.
	t.Run("编码重复失败", func(t *testing.T) {
		req := &dto.CreateCarrierRequest{
			CarrierCode: "SVC_CMCC_001",
			CarrierName: "中国移动-重复",
			CarrierType: constants.CarrierTypeCMCC,
		}
		_, err := svc.Create(ctx, req)
		requireAppErrorCode(t, err, errors.CodeCarrierCodeExists)
	})

	// A plain background context has no user attached, so the call must fail.
	t.Run("未授权失败", func(t *testing.T) {
		req := &dto.CreateCarrierRequest{
			CarrierCode: "SVC_CMCC_002",
			CarrierName: "未授权测试",
			CarrierType: constants.CarrierTypeCMCC,
		}
		_, err := svc.Create(context.Background(), req)
		require.Error(t, err)
	})
}

// TestCarrierService_Get checks lookup of an existing carrier and the
// CodeCarrierNotFound error for a missing one.
func TestCarrierService_Get(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	created, err := svc.Create(ctx, &dto.CreateCarrierRequest{
		CarrierCode: "SVC_GET_001",
		CarrierName: "查询测试",
		CarrierType: constants.CarrierTypeCUCC,
	})
	require.NoError(t, err)

	t.Run("查询存在的运营商", func(t *testing.T) {
		resp, err := svc.Get(ctx, created.ID)
		require.NoError(t, err)
		assert.Equal(t, created.CarrierCode, resp.CarrierCode)
	})

	t.Run("查询不存在的运营商", func(t *testing.T) {
		_, err := svc.Get(ctx, nonExistentCarrierID)
		requireAppErrorCode(t, err, errors.CodeCarrierNotFound)
	})
}

// TestCarrierService_Update checks a partial update of name and description,
// and the CodeCarrierNotFound error when the target does not exist.
func TestCarrierService_Update(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	created, err := svc.Create(ctx, &dto.CreateCarrierRequest{
		CarrierCode: "SVC_UPD_001",
		CarrierName: "更新测试",
		CarrierType: constants.CarrierTypeCTCC,
	})
	require.NoError(t, err)

	t.Run("更新成功", func(t *testing.T) {
		newName := "更新后的名称"
		newDesc := "更新后的描述"
		updateReq := &dto.UpdateCarrierRequest{
			CarrierName: &newName,
			Description: &newDesc,
		}
		resp, err := svc.Update(ctx, created.ID, updateReq)
		require.NoError(t, err)
		assert.Equal(t, newName, resp.CarrierName)
		assert.Equal(t, newDesc, resp.Description)
	})

	t.Run("更新不存在的运营商", func(t *testing.T) {
		newName := "test"
		updateReq := &dto.UpdateCarrierRequest{
			CarrierName: &newName,
		}
		_, err := svc.Update(ctx, nonExistentCarrierID, updateReq)
		requireAppErrorCode(t, err, errors.CodeCarrierNotFound)
	})
}

// TestCarrierService_Delete checks that a deleted carrier can no longer be
// fetched, and that deleting a missing carrier returns an error.
func TestCarrierService_Delete(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	created, err := svc.Create(ctx, &dto.CreateCarrierRequest{
		CarrierCode: "SVC_DEL_001",
		CarrierName: "删除测试",
		CarrierType: constants.CarrierTypeCBN,
	})
	require.NoError(t, err)

	t.Run("删除成功", func(t *testing.T) {
		require.NoError(t, svc.Delete(ctx, created.ID))

		_, err := svc.Get(ctx, created.ID)
		require.Error(t, err)
	})

	t.Run("删除不存在的运营商", func(t *testing.T) {
		require.Error(t, svc.Delete(ctx, nonExistentCarrierID))
	})
}

// TestCarrierService_List checks unfiltered pagination and filtering by
// carrier type. Assertions use lower bounds because other rows may already
// exist in the table.
func TestCarrierService_List(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	carriers := []dto.CreateCarrierRequest{
		{CarrierCode: "SVC_LIST_001", CarrierName: "移动列表", CarrierType: constants.CarrierTypeCMCC},
		{CarrierCode: "SVC_LIST_002", CarrierName: "联通列表", CarrierType: constants.CarrierTypeCUCC},
		{CarrierCode: "SVC_LIST_003", CarrierName: "电信列表", CarrierType: constants.CarrierTypeCTCC},
	}
	// Index the slice rather than taking &loopVar: before Go 1.22 the range
	// variable is a single shared variable, so its address aliases every item.
	for i := range carriers {
		_, err := svc.Create(ctx, &carriers[i])
		require.NoError(t, err)
	}

	t.Run("查询列表", func(t *testing.T) {
		req := &dto.CarrierListRequest{
			Page:     1,
			PageSize: 20,
		}
		result, total, err := svc.List(ctx, req)
		require.NoError(t, err)
		assert.GreaterOrEqual(t, total, int64(3))
		assert.GreaterOrEqual(t, len(result), 3)
	})

	t.Run("按类型过滤", func(t *testing.T) {
		carrierType := constants.CarrierTypeCMCC
		req := &dto.CarrierListRequest{
			Page:        1,
			PageSize:    20,
			CarrierType: &carrierType,
		}
		result, total, err := svc.List(ctx, req)
		require.NoError(t, err)
		assert.GreaterOrEqual(t, total, int64(1))
		// Guard against the loop below passing vacuously on an empty page.
		require.NotEmpty(t, result)
		for _, c := range result {
			assert.Equal(t, constants.CarrierTypeCMCC, c.CarrierType)
		}
	})
}

// TestCarrierService_UpdateStatus toggles a carrier disabled and back to
// enabled, and checks that updating a missing carrier returns an error. The
// enable subtest relies on the disable subtest having run first.
func TestCarrierService_UpdateStatus(t *testing.T) {
	tx := testutils.NewTestTransaction(t)
	svc := New(postgres.NewCarrierStore(tx))
	ctx := newPlatformUserContext()

	created, err := svc.Create(ctx, &dto.CreateCarrierRequest{
		CarrierCode: "SVC_STATUS_001",
		CarrierName: "状态测试",
		CarrierType: constants.CarrierTypeCMCC,
	})
	require.NoError(t, err)
	assert.Equal(t, constants.StatusEnabled, created.Status)

	t.Run("禁用运营商", func(t *testing.T) {
		require.NoError(t, svc.UpdateStatus(ctx, created.ID, constants.StatusDisabled))

		updated, err := svc.Get(ctx, created.ID)
		require.NoError(t, err)
		assert.Equal(t, constants.StatusDisabled, updated.Status)
	})

	t.Run("启用运营商", func(t *testing.T) {
		require.NoError(t, svc.UpdateStatus(ctx, created.ID, constants.StatusEnabled))

		updated, err := svc.Get(ctx, created.ID)
		require.NoError(t, err)
		assert.Equal(t, constants.StatusEnabled, updated.Status)
	})

	// Use a valid status so the failure can only come from the missing ID,
	// not from status validation (the original passed a bare literal 1).
	t.Run("更新不存在的运营商状态", func(t *testing.T) {
		require.Error(t, svc.UpdateStatus(ctx, nonExistentCarrierID, constants.StatusEnabled))
	})
}