package database import ( "context" "fmt" "time" "github.com/break/junhong_cmp_fiber/pkg/config" "github.com/break/junhong_cmp_fiber/pkg/constants" "go.uber.org/zap" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" ) // InitPostgreSQL 初始化 PostgreSQL 数据库连接 func InitPostgreSQL(cfg *config.DatabaseConfig, log *zap.Logger) (*gorm.DB, error) { // 构建 DSN (数据源名称) dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, ) // 配置 GORM gormConfig := &gorm.Config{ // 使用自定义日志器(集成 Zap) Logger: newGormLogger(log), // 禁用自动创建表(使用迁移脚本管理) DisableAutomaticPing: false, SkipDefaultTransaction: true, // 提高性能,手动管理事务 PrepareStmt: true, // 预编译语句 } // 连接数据库 db, err := gorm.Open(postgres.Open(dsn), gormConfig) if err != nil { log.Error("PostgreSQL 连接失败", zap.String("host", cfg.Host), zap.Int("port", cfg.Port), zap.String("dbname", cfg.DBName), zap.Error(err)) return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } // 获取底层 SQL DB 对象 sqlDB, err := db.DB() if err != nil { log.Error("获取 SQL DB 失败", zap.Error(err)) return nil, fmt.Errorf("failed to get SQL DB: %w", err) } // 配置连接池 maxOpenConns := cfg.MaxOpenConns if maxOpenConns <= 0 { maxOpenConns = constants.DefaultMaxOpenConns } maxIdleConns := cfg.MaxIdleConns if maxIdleConns <= 0 { maxIdleConns = constants.DefaultMaxIdleConns } connMaxLifetime := cfg.ConnMaxLifetime if connMaxLifetime <= 0 { connMaxLifetime = constants.DefaultConnMaxLifetime } sqlDB.SetMaxOpenConns(maxOpenConns) sqlDB.SetMaxIdleConns(maxIdleConns) sqlDB.SetConnMaxLifetime(connMaxLifetime) // 验证连接 if err := sqlDB.Ping(); err != nil { log.Error("PostgreSQL Ping 失败", zap.Error(err)) return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err) } log.Info("PostgreSQL 连接成功", zap.String("host", cfg.Host), zap.Int("port", cfg.Port), zap.String("dbname", cfg.DBName), zap.Int("max_open_conns", maxOpenConns), zap.Int("max_idle_conns", maxIdleConns), zap.Duration("conn_max_lifetime", connMaxLifetime)) // db.AutoMigrate( // &model.Account{}, // &model.Role{}, // &model.Permission{}, // &model.RolePermission{}, // ) return db, nil } // gormLogger 自定义 GORM 日志器,集成 Zap type gormLogger struct { zap *zap.Logger slowQueryThreshold time.Duration ignoreRecordNotFound bool logLevel logger.LogLevel } // newGormLogger 创建新的 GORM 日志器 func newGormLogger(log *zap.Logger) logger.Interface { return &gormLogger{ zap: log, slowQueryThreshold: constants.SlowQueryThreshold, ignoreRecordNotFound: true, logLevel: logger.Info, } } // LogMode 设置日志级别 func (l *gormLogger) LogMode(level logger.LogLevel) logger.Interface { newLogger := *l newLogger.logLevel = level return &newLogger } // Info 记录 Info 级别日志 func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.logLevel >= logger.Info { l.zap.Sugar().Infof(msg, data...) } } // Warn 记录 Warn 级别日志 func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.logLevel >= logger.Warn { l.zap.Sugar().Warnf(msg, data...) } } // Error 记录 Error 级别日志 func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.logLevel >= logger.Error { l.zap.Sugar().Errorf(msg, data...) } } // Trace 记录 SQL 查询日志 func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.logLevel <= logger.Silent { return } elapsed := time.Since(begin) sql, rows := fc() switch { case err != nil && l.logLevel >= logger.Error && (!l.ignoreRecordNotFound || err != gorm.ErrRecordNotFound): // 查询错误 l.zap.Error("SQL 查询失败", zap.String("sql", sql), zap.Int64("rows", rows), zap.Duration("elapsed", elapsed), zap.Error(err)) case elapsed > l.slowQueryThreshold && l.logLevel >= logger.Warn: // 慢查询 l.zap.Warn("慢查询检测", zap.String("sql", sql), zap.Int64("rows", rows), zap.Duration("elapsed", elapsed), zap.Duration("threshold", l.slowQueryThreshold)) case l.logLevel >= logger.Info: // 正常查询 l.zap.Debug("SQL 查询", zap.String("sql", sql), zap.Int64("rows", rows), zap.Duration("elapsed", elapsed)) } }