Files
junhong_cmp_fiber/pkg/database/postgres.go
2025-12-15 14:37:34 +08:00

179 lines
4.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/break/junhong_cmp_fiber/pkg/config"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"go.uber.org/zap"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
// InitPostgreSQL opens a GORM connection to PostgreSQL and configures the
// underlying connection pool.
//
// The DSN is assembled from cfg in keyword/value form. Pool settings (max
// open connections, max idle connections, connection max lifetime) are read
// from cfg; any of them that is zero or negative is replaced by the matching
// default from the constants package. The connection is verified with a Ping
// before the handle is returned.
//
// On failure the error is logged through log and returned wrapped, and the
// returned *gorm.DB is nil. If the final Ping fails, the already-opened pool
// is closed so that its connections are not leaked.
func InitPostgreSQL(cfg *config.DatabaseConfig, log *zap.Logger) (*gorm.DB, error) {
	// Build the DSN (data source name).
	// NOTE(review): values are interpolated unquoted; an empty value or one
	// containing whitespace would presumably yield a malformed DSN — confirm
	// the configuration never contains such values.
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		cfg.Host,
		cfg.Port,
		cfg.User,
		cfg.Password,
		cfg.DBName,
		cfg.SSLMode,
	)

	// Configure GORM.
	gormConfig := &gorm.Config{
		// Route GORM's log output through the Zap-backed logger.
		Logger: newGormLogger(log),
		// Keep GORM's automatic ping on open enabled.
		DisableAutomaticPing:   false,
		SkipDefaultTransaction: true, // better performance; transactions are managed manually
		PrepareStmt:            true, // use prepared statements
	}

	// Connect to the database.
	db, err := gorm.Open(postgres.Open(dsn), gormConfig)
	if err != nil {
		log.Error("PostgreSQL 连接失败",
			zap.String("host", cfg.Host),
			zap.Int("port", cfg.Port),
			zap.String("dbname", cfg.DBName),
			zap.Error(err))
		return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
	}

	// Fetch the underlying SQL DB object so the pool can be tuned.
	sqlDB, err := db.DB()
	if err != nil {
		log.Error("获取 SQL DB 失败", zap.Error(err))
		return nil, fmt.Errorf("failed to get SQL DB: %w", err)
	}

	// Configure the connection pool, falling back to defaults for unset
	// (zero or negative) values.
	maxOpenConns := cfg.MaxOpenConns
	if maxOpenConns <= 0 {
		maxOpenConns = constants.DefaultMaxOpenConns
	}
	maxIdleConns := cfg.MaxIdleConns
	if maxIdleConns <= 0 {
		maxIdleConns = constants.DefaultMaxIdleConns
	}
	connMaxLifetime := cfg.ConnMaxLifetime
	if connMaxLifetime <= 0 {
		connMaxLifetime = constants.DefaultConnMaxLifetime
	}
	sqlDB.SetMaxOpenConns(maxOpenConns)
	sqlDB.SetMaxIdleConns(maxIdleConns)
	sqlDB.SetConnMaxLifetime(connMaxLifetime)

	// Verify the connection.
	if err := sqlDB.Ping(); err != nil {
		log.Error("PostgreSQL Ping 失败", zap.Error(err))
		// Release the pool opened above: the caller only receives an error
		// and has no handle through which it could close the pool itself.
		if closeErr := sqlDB.Close(); closeErr != nil {
			log.Warn("PostgreSQL 连接池关闭失败", zap.Error(closeErr))
		}
		return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err)
	}

	log.Info("PostgreSQL 连接成功",
		zap.String("host", cfg.Host),
		zap.Int("port", cfg.Port),
		zap.String("dbname", cfg.DBName),
		zap.Int("max_open_conns", maxOpenConns),
		zap.Int("max_idle_conns", maxIdleConns),
		zap.Duration("conn_max_lifetime", connMaxLifetime))

	// Automatic table creation is left disabled (the original note says the
	// schema is managed with migration scripts):
	// db.AutoMigrate(
	// &model.Account{},
	// &model.Role{},
	// &model.Permission{},
	// &model.RolePermission{},
	// )
	return db, nil
}
// gormLogger is a custom GORM logger that forwards GORM's log output to a
// Zap logger.
type gormLogger struct {
// zap is the Zap logger that all output is written to.
zap *zap.Logger
// slowQueryThreshold is the elapsed time above which Trace reports a
// statement as a slow query.
slowQueryThreshold time.Duration
// ignoreRecordNotFound, when true, keeps Trace from reporting
// gorm.ErrRecordNotFound as a failed query.
ignoreRecordNotFound bool
// logLevel is the GORM log level that gates which messages are emitted.
logLevel logger.LogLevel
}
// newGormLogger builds a gormLogger that writes to log. The slow-query
// threshold comes from constants.SlowQueryThreshold, record-not-found errors
// are ignored, and the initial level is logger.Info.
func newGormLogger(log *zap.Logger) logger.Interface {
	gl := new(gormLogger)
	gl.zap = log
	gl.slowQueryThreshold = constants.SlowQueryThreshold
	gl.ignoreRecordNotFound = true
	gl.logLevel = logger.Info
	return gl
}
// LogMode returns a new logger that carries the same settings as the
// receiver but uses level as its log level. The receiver is not modified.
func (l *gormLogger) LogMode(level logger.LogLevel) logger.Interface {
	return &gormLogger{
		zap:                  l.zap,
		slowQueryThreshold:   l.slowQueryThreshold,
		ignoreRecordNotFound: l.ignoreRecordNotFound,
		logLevel:             level,
	}
}
// Info writes an Info-level message when the configured level is at least
// logger.Info. msg is used as a format template and data as its arguments.
// ctx is not used.
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < logger.Info {
		return
	}
	sugar := l.zap.Sugar()
	sugar.Infof(msg, data...)
}
// Warn writes a Warn-level message when the configured level is at least
// logger.Warn. msg is used as a format template and data as its arguments.
// ctx is not used.
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < logger.Warn {
		return
	}
	sugar := l.zap.Sugar()
	sugar.Warnf(msg, data...)
}
// Error writes an Error-level message when the configured level is at least
// logger.Error. msg is used as a format template and data as its arguments.
// ctx is not used.
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < logger.Error {
		return
	}
	sugar := l.zap.Sugar()
	sugar.Errorf(msg, data...)
}
// Trace logs a single executed SQL statement.
//
// begin is the time the statement started, fc returns the SQL text and the
// number of affected rows, and err is the statement's result. ctx is not
// used. Nothing is logged when the level is logger.Silent or lower.
// Otherwise exactly one of the following applies, checked in order:
//
//   - err is non-nil and the level is at least logger.Error: logged as an
//     error, unless ignoreRecordNotFound is set and err is (or wraps)
//     gorm.ErrRecordNotFound;
//   - the elapsed time exceeds slowQueryThreshold and the level is at least
//     logger.Warn: logged as a warning;
//   - the level is at least logger.Info: logged through Zap at Debug level.
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.logLevel <= logger.Silent {
		return
	}

	elapsed := time.Since(begin)

	// fc is only invoked inside the branch that actually logs, so the SQL
	// text is not built for statements that end up not being reported.
	switch {
	case err != nil && l.logLevel >= logger.Error &&
		// errors.Is also matches a wrapped ErrRecordNotFound, which a plain
		// != comparison would report as a failed query.
		(!l.ignoreRecordNotFound || !errors.Is(err, gorm.ErrRecordNotFound)):
		// Failed query.
		sql, rows := fc()
		l.zap.Error("SQL 查询失败",
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Duration("elapsed", elapsed),
			zap.Error(err))
	case elapsed > l.slowQueryThreshold && l.logLevel >= logger.Warn:
		// Slow query.
		sql, rows := fc()
		l.zap.Warn("慢查询检测",
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Duration("elapsed", elapsed),
			zap.Duration("threshold", l.slowQueryThreshold))
	case l.logLevel >= logger.Info:
		// Normal query.
		sql, rows := fc()
		l.zap.Debug("SQL 查询",
			zap.String("sql", sql),
			zap.Int64("rows", rows),
			zap.Duration("elapsed", elapsed))
	}
}