package storage

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/break/junhong_cmp_fiber/pkg/config"
)

// S3Provider stores files in an S3-compatible object store reached through a
// configurable endpoint, using aws-sdk-go v1.
type S3Provider struct {
	// client issues GetObject/HeadObject/DeleteObject calls and builds the
	// requests that are presigned.
	client *s3.S3

	// uploader performs the uploads done by Upload.
	uploader *s3manager.Uploader

	// bucket is the bucket every key is resolved against.
	bucket string

	// tempDir is the directory DownloadToTemp creates its files in.
	tempDir string
}

// NewS3Provider builds an S3Provider from cfg.S3 and cfg.TempDir.
//
// Endpoint, Bucket, AccessKeyID and SecretAccessKey are required; an error is
// returned when any of them is empty. An empty Region falls back to
// "us-east-1", and an empty TempDir falls back to "/tmp/junhong-storage".
func NewS3Provider(cfg *config.StorageConfig) (*S3Provider, error) {
	if cfg.S3.Endpoint == "" || cfg.S3.Bucket == "" {
		return nil, fmt.Errorf("S3 配置不完整:endpoint 和 bucket 必填")
	}
	if cfg.S3.AccessKeyID == "" || cfg.S3.SecretAccessKey == "" {
		return nil, fmt.Errorf("S3 凭证未配置:access_key_id 和 secret_access_key 必填")
	}

	region := cfg.S3.Region
	if region == "" {
		// aws-sdk-go v1 rejects every request with MissingRegion when the
		// region is empty; us-east-1 is the conventional default region of
		// S3-compatible servers such as MinIO.
		region = "us-east-1"
	}

	sess, err := session.NewSession(&aws.Config{
		Endpoint:         aws.String(cfg.S3.Endpoint),
		Region:           aws.String(region),
		Credentials:      credentials.NewStaticCredentials(cfg.S3.AccessKeyID, cfg.S3.SecretAccessKey, ""),
		DisableSSL:       aws.Bool(!cfg.S3.UseSSL),
		S3ForcePathStyle: aws.Bool(cfg.S3.PathStyle),
	})
	if err != nil {
		return nil, fmt.Errorf("创建 S3 session 失败: %w", err)
	}

	tempDir := cfg.TempDir
	if tempDir == "" {
		tempDir = "/tmp/junhong-storage"
	}

	return &S3Provider{
		client:   s3.New(sess),
		uploader: s3manager.NewUploader(sess),
		bucket:   cfg.S3.Bucket,
		tempDir:  tempDir,
	}, nil
}

// Upload writes the contents of reader to key in the bucket.
//
// contentType is stored on the object when non-empty. An empty value is left
// unset instead of being sent as a blank Content-Type header.
func (p *S3Provider) Upload(ctx context.Context, key string, reader io.Reader, contentType string) error {
	input := &s3manager.UploadInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
		Body:   reader,
	}
	if contentType != "" {
		input.ContentType = aws.String(contentType)
	}

	if _, err := p.uploader.UploadWithContext(ctx, input); err != nil {
		return fmt.Errorf("上传文件失败: %w", err)
	}
	return nil
}

// Download copies the object stored at key into writer.
//
// When the object does not exist (S3 error code NoSuchKey), the returned error
// is "文件不存在: <key>". Any other request failure, and any failure while
// copying the body, is wrapped and returned.
func (p *S3Provider) Download(ctx context.Context, key string, writer io.Writer) error {
	result, err := p.client.GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		// Classify by the structured SDK error code, not the message text.
		var aerr awserr.Error
		if errors.As(err, &aerr) && aerr.Code() == s3.ErrCodeNoSuchKey {
			return fmt.Errorf("文件不存在: %s", key)
		}
		return fmt.Errorf("下载文件失败: %w", err)
	}
	defer result.Body.Close()

	if _, err := io.Copy(writer, result.Body); err != nil {
		return fmt.Errorf("写入文件内容失败: %w", err)
	}
	return nil
}

// DownloadToTemp downloads key into a new file inside the provider's temp
// directory. It returns the file path and a cleanup func that removes the file.
//
// The file name keeps the key's extension, or uses ".tmp" when the key has
// none. The temp directory is created if it does not exist yet. On error,
// the partially written file is removed, and both the path and cleanup are
// zero values.
func (p *S3Provider) DownloadToTemp(ctx context.Context, key string) (string, func(), error) {
	ext := filepath.Ext(key)
	if ext == "" {
		ext = ".tmp"
	}

	// os.CreateTemp does not create missing directories, and the configured
	// (or default) temp dir is not created anywhere else.
	if err := os.MkdirAll(p.tempDir, 0o755); err != nil {
		return "", nil, fmt.Errorf("创建临时目录失败: %w", err)
	}

	tempFile, err := os.CreateTemp(p.tempDir, "download-*"+ext)
	if err != nil {
		return "", nil, fmt.Errorf("创建临时文件失败: %w", err)
	}
	tempPath := tempFile.Name()
	// Best-effort removal: a failure to delete a temp file is not reported.
	cleanup := func() { _ = os.Remove(tempPath) }

	if err := p.Download(ctx, key, tempFile); err != nil {
		_ = tempFile.Close() // the download error is the one worth reporting
		cleanup()
		return "", nil, err
	}
	if err := tempFile.Close(); err != nil {
		cleanup()
		return "", nil, fmt.Errorf("关闭临时文件失败: %w", err)
	}
	return tempPath, cleanup, nil
}

// Delete removes the object stored at key from the bucket.
func (p *S3Provider) Delete(ctx context.Context, key string) error {
	_, err := p.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return fmt.Errorf("删除文件失败: %w", err)
	}
	return nil
}

// Exists reports whether an object is stored at key.
//
// A HeadObject failure with HTTP status 404, or with S3 error code NotFound
// or NoSuchKey, means the object is absent and yields (false, nil). Every
// other failure is wrapped and returned.
func (p *S3Provider) Exists(ctx context.Context, key string) (bool, error) {
	_, err := p.client.HeadObjectWithContext(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
	})
	if err == nil {
		return true, nil
	}

	// Inspect the structured SDK error rather than the message text: a
	// substring match on "404" also matches unrelated transport errors such
	// as "dial tcp 10.0.0.1:4043: connection refused".
	var reqErr awserr.RequestFailure
	if errors.As(err, &reqErr) && reqErr.StatusCode() == http.StatusNotFound {
		return false, nil
	}
	var aerr awserr.Error
	if errors.As(err, &aerr) {
		switch aerr.Code() {
		case "NotFound", s3.ErrCodeNoSuchKey:
			return false, nil
		}
	}
	return false, fmt.Errorf("检查文件存在性失败: %w", err)
}

// GetUploadURL returns a presigned PUT URL for key that is valid for expires.
//
// When contentType is non-empty it is included in the presigned request, so
// the client performing the PUT is expected to send the same Content-Type.
// An empty contentType is left unset rather than signed as a blank header.
// ctx is currently unused.
func (p *S3Provider) GetUploadURL(ctx context.Context, key string, contentType string, expires time.Duration) (string, error) {
	input := &s3.PutObjectInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
	}
	if contentType != "" {
		input.ContentType = aws.String(contentType)
	}

	// The output value is not valid until the request is sent, which
	// presigning never does, so it is discarded.
	req, _ := p.client.PutObjectRequest(input)
	signedURL, err := req.Presign(expires)
	if err != nil {
		return "", fmt.Errorf("生成上传预签名 URL 失败: %w", err)
	}
	return signedURL, nil
}

// GetDownloadURL returns a presigned GET URL for key that is valid for
// expires. ctx is currently unused.
func (p *S3Provider) GetDownloadURL(ctx context.Context, key string, expires time.Duration) (string, error) {
	// The output value is not valid until the request is sent, which
	// presigning never does, so it is discarded.
	req, _ := p.client.GetObjectRequest(&s3.GetObjectInput{
		Bucket: aws.String(p.bucket),
		Key:    aws.String(key),
	})
	signedURL, err := req.Presign(expires)
	if err != nil {
		return "", fmt.Errorf("生成下载预签名 URL 失败: %w", err)
	}
	return signedURL, nil
}