package enterprise import ( "context" "fmt" "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/middleware" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) type Service struct { db *gorm.DB enterpriseStore *postgres.EnterpriseStore shopStore *postgres.ShopStore accountStore *postgres.AccountStore } func New(db *gorm.DB, enterpriseStore *postgres.EnterpriseStore, shopStore *postgres.ShopStore, accountStore *postgres.AccountStore) *Service { return &Service{ db: db, enterpriseStore: enterpriseStore, shopStore: shopStore, accountStore: accountStore, } } func (s *Service) Create(ctx context.Context, req *model.CreateEnterpriseReq) (*model.CreateEnterpriseResp, error) { currentUserID := middleware.GetUserIDFromContext(ctx) if currentUserID == 0 { return nil, errors.New(errors.CodeUnauthorized, "未授权访问") } if req.EnterpriseCode != "" { existing, _ := s.enterpriseStore.GetByCode(ctx, req.EnterpriseCode) if existing != nil { return nil, errors.New(errors.CodeEnterpriseCodeExists, "企业编号已存在") } } existingAccount, _ := s.accountStore.GetByPhone(ctx, req.LoginPhone) if existingAccount != nil { return nil, errors.New(errors.CodePhoneExists, "手机号已被使用") } if req.OwnerShopID != nil { _, err := s.shopStore.GetByID(ctx, *req.OwnerShopID) if err != nil { return nil, errors.New(errors.CodeShopNotFound, "归属店铺不存在或无效") } } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("密码加密失败: %w", err) } var enterprise *model.Enterprise var account *model.Account err = s.db.Transaction(func(tx *gorm.DB) error { enterprise = &model.Enterprise{ EnterpriseName: req.EnterpriseName, EnterpriseCode: req.EnterpriseCode, OwnerShopID: req.OwnerShopID, LegalPerson: req.LegalPerson, ContactName: req.ContactName, ContactPhone: req.ContactPhone, BusinessLicense: req.BusinessLicense, Province: req.Province, City: req.City, District: req.District, Address: req.Address, Status: constants.StatusEnabled, } enterprise.Creator = currentUserID enterprise.Updater = currentUserID if err := tx.WithContext(ctx).Create(enterprise).Error; err != nil { return fmt.Errorf("创建企业失败: %w", err) } account = &model.Account{ Username: req.EnterpriseName, Phone: req.LoginPhone, Password: string(hashedPassword), UserType: constants.UserTypeEnterprise, EnterpriseID: &enterprise.ID, Status: constants.StatusEnabled, } account.Creator = currentUserID account.Updater = currentUserID if err := tx.WithContext(ctx).Create(account).Error; err != nil { return fmt.Errorf("创建企业账号失败: %w", err) } return nil }) if err != nil { return nil, err } ownerShopName := "" if enterprise.OwnerShopID != nil { if shop, err := s.shopStore.GetByID(ctx, *enterprise.OwnerShopID); err == nil { ownerShopName = shop.ShopName } } return &model.CreateEnterpriseResp{ Enterprise: model.EnterpriseItem{ ID: enterprise.ID, EnterpriseName: enterprise.EnterpriseName, EnterpriseCode: enterprise.EnterpriseCode, OwnerShopID: enterprise.OwnerShopID, OwnerShopName: ownerShopName, LegalPerson: enterprise.LegalPerson, ContactName: enterprise.ContactName, ContactPhone: enterprise.ContactPhone, LoginPhone: req.LoginPhone, BusinessLicense: enterprise.BusinessLicense, Province: enterprise.Province, City: enterprise.City, District: enterprise.District, Address: enterprise.Address, Status: enterprise.Status, StatusName: getStatusName(enterprise.Status), CreatedAt: enterprise.CreatedAt.Format("2006-01-02 15:04:05"), }, AccountID: account.ID, }, nil } // Update 更新企业信息 func (s *Service) Update(ctx context.Context, id uint, req *model.UpdateEnterpriseRequest) (*model.Enterprise, error) { // 获取当前用户 ID currentUserID := middleware.GetUserIDFromContext(ctx) if currentUserID == 0 { return nil, errors.New(errors.CodeUnauthorized, "未授权访问") } // 查询企业 enterprise, err := s.enterpriseStore.GetByID(ctx, id) if err != nil { return nil, errors.New(errors.CodeEnterpriseNotFound, "企业不存在") } // 检查企业编号唯一性(如果修改了编号) if req.EnterpriseCode != nil && *req.EnterpriseCode != enterprise.EnterpriseCode { existing, err := s.enterpriseStore.GetByCode(ctx, *req.EnterpriseCode) if err == nil && existing != nil && existing.ID != id { return nil, errors.New(errors.CodeEnterpriseCodeExists, "企业编号已存在") } enterprise.EnterpriseCode = *req.EnterpriseCode } // 更新字段 if req.EnterpriseName != nil { enterprise.EnterpriseName = *req.EnterpriseName } if req.LegalPerson != nil { enterprise.LegalPerson = *req.LegalPerson } if req.ContactName != nil { enterprise.ContactName = *req.ContactName } if req.ContactPhone != nil { enterprise.ContactPhone = *req.ContactPhone } if req.BusinessLicense != nil { enterprise.BusinessLicense = *req.BusinessLicense } if req.Province != nil { enterprise.Province = *req.Province } if req.City != nil { enterprise.City = *req.City } if req.District != nil { enterprise.District = *req.District } if req.Address != nil { enterprise.Address = *req.Address } enterprise.Updater = currentUserID if err := s.enterpriseStore.Update(ctx, enterprise); err != nil { return nil, err } return enterprise, nil } func (s *Service) UpdateStatus(ctx context.Context, id uint, status int) error { currentUserID := middleware.GetUserIDFromContext(ctx) if currentUserID == 0 { return errors.New(errors.CodeUnauthorized, "未授权访问") } enterprise, err := s.enterpriseStore.GetByID(ctx, id) if err != nil { return errors.New(errors.CodeEnterpriseNotFound, "企业不存在") } return s.db.Transaction(func(tx *gorm.DB) error { enterprise.Status = status enterprise.Updater = currentUserID if err := tx.WithContext(ctx).Save(enterprise).Error; err != nil { return fmt.Errorf("更新企业状态失败: %w", err) } if err := tx.WithContext(ctx).Model(&model.Account{}). Where("enterprise_id = ?", id). Updates(map[string]interface{}{ "status": status, "updater": currentUserID, }).Error; err != nil { return fmt.Errorf("同步更新企业账号状态失败: %w", err) } return nil }) } func (s *Service) UpdatePassword(ctx context.Context, id uint, password string) error { currentUserID := middleware.GetUserIDFromContext(ctx) if currentUserID == 0 { return errors.New(errors.CodeUnauthorized, "未授权访问") } _, err := s.enterpriseStore.GetByID(ctx, id) if err != nil { return errors.New(errors.CodeEnterpriseNotFound, "企业不存在") } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("密码加密失败: %w", err) } return s.db.WithContext(ctx).Model(&model.Account{}). Where("enterprise_id = ?", id). Updates(map[string]interface{}{ "password": string(hashedPassword), "updater": currentUserID, }).Error } func (s *Service) GetByID(ctx context.Context, id uint) (*model.Enterprise, error) { enterprise, err := s.enterpriseStore.GetByID(ctx, id) if err != nil { return nil, errors.New(errors.CodeEnterpriseNotFound, "企业不存在") } return enterprise, nil } func (s *Service) List(ctx context.Context, req *model.EnterpriseListReq) (*model.EnterprisePageResult, error) { opts := &store.QueryOptions{ Page: req.Page, PageSize: req.PageSize, } if opts.Page == 0 { opts.Page = 1 } if opts.PageSize == 0 { opts.PageSize = constants.DefaultPageSize } filters := make(map[string]interface{}) if req.EnterpriseName != "" { filters["enterprise_name"] = req.EnterpriseName } if req.ContactPhone != "" { filters["contact_phone"] = req.ContactPhone } if req.OwnerShopID != nil { filters["owner_shop_id"] = *req.OwnerShopID } if req.Status != nil { filters["status"] = *req.Status } enterprises, total, err := s.enterpriseStore.List(ctx, opts, filters) if err != nil { return nil, fmt.Errorf("查询企业列表失败: %w", err) } enterpriseIDs := make([]uint, 0, len(enterprises)) shopIDs := make([]uint, 0) for _, e := range enterprises { enterpriseIDs = append(enterpriseIDs, e.ID) if e.OwnerShopID != nil { shopIDs = append(shopIDs, *e.OwnerShopID) } } accountMap := make(map[uint]string) if len(enterpriseIDs) > 0 { var accounts []model.Account s.db.WithContext(ctx).Where("enterprise_id IN ?", enterpriseIDs).Find(&accounts) for _, acc := range accounts { if acc.EnterpriseID != nil { accountMap[*acc.EnterpriseID] = acc.Phone } } } shopMap := make(map[uint]string) if len(shopIDs) > 0 { var shops []model.Shop s.db.WithContext(ctx).Where("id IN ?", shopIDs).Find(&shops) for _, shop := range shops { shopMap[shop.ID] = shop.ShopName } } items := make([]model.EnterpriseItem, 0, len(enterprises)) for _, e := range enterprises { ownerShopName := "" if e.OwnerShopID != nil { ownerShopName = shopMap[*e.OwnerShopID] } items = append(items, model.EnterpriseItem{ ID: e.ID, EnterpriseName: e.EnterpriseName, EnterpriseCode: e.EnterpriseCode, OwnerShopID: e.OwnerShopID, OwnerShopName: ownerShopName, LegalPerson: e.LegalPerson, ContactName: e.ContactName, ContactPhone: e.ContactPhone, LoginPhone: accountMap[e.ID], BusinessLicense: e.BusinessLicense, Province: e.Province, City: e.City, District: e.District, Address: e.Address, Status: e.Status, StatusName: getStatusName(e.Status), CreatedAt: e.CreatedAt.Format("2006-01-02 15:04:05"), }) } return &model.EnterprisePageResult{ Items: items, Total: total, Page: opts.Page, Size: opts.PageSize, }, nil } func getStatusName(status int) string { if status == constants.StatusEnabled { return "启用" } return "禁用" }