package openapi

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"

	"github.com/swaggest/openapi-go/openapi3"
	"gopkg.in/yaml.v3"
)

// errorResponseRef points at the shared error schema registered by
// addErrorResponseSchema.
const errorResponseRef = "#/components/schemas/ErrorResponse"

// Generator builds an OpenAPI document.
type Generator struct {
	Reflector *openapi3.Reflector
}

// NewGenerator creates a Generator whose spec declares OpenAPI 3.0.3 with the
// given title and version, a "BearerAuth" HTTP bearer security scheme, and an
// "ErrorResponse" component schema used by the standard error responses.
func NewGenerator(title, version string) *Generator {
	reflector := &openapi3.Reflector{}
	reflector.Spec = &openapi3.Spec{
		Openapi: "3.0.3",
		Info: openapi3.Info{
			Title:   title,
			Version: version,
		},
	}

	g := &Generator{Reflector: reflector}
	g.addBearerAuth()
	g.addErrorResponseSchema()
	return g
}

// addBearerAuth registers the "BearerAuth" security scheme (HTTP "bearer"
// scheme, bearer format "JWT") that addSecurityRequirement refers to.
func (g *Generator) addBearerAuth() {
	bearerFormat := "JWT"
	g.Reflector.Spec.ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem(
		"BearerAuth",
		openapi3.SecuritySchemeOrRef{
			SecurityScheme: &openapi3.SecurityScheme{
				HTTPSecurityScheme: &openapi3.HTTPSecurityScheme{
					Scheme:       "bearer",
					BearerFormat: &bearerFormat,
				},
			},
		},
	)
}

// addErrorResponseSchema registers the "ErrorResponse" component schema: an
// object with required properties code (integer), message (string) and
// timestamp (string, date-time).
func (g *Generator) addErrorResponseSchema() {
	objectType := openapi3.SchemaType("object")
	integerType := openapi3.SchemaType("integer")
	stringType := openapi3.SchemaType("string")
	dateTimeFormat := "date-time"

	errorSchema := openapi3.SchemaOrRef{
		Schema: &openapi3.Schema{
			Type: &objectType,
			Properties: map[string]openapi3.SchemaOrRef{
				"code": {
					Schema: &openapi3.Schema{
						Type:        &integerType,
						Description: ptrString("错误码"),
					},
				},
				"message": {
					Schema: &openapi3.Schema{
						Type:        &stringType,
						Description: ptrString("错误消息"),
					},
				},
				"timestamp": {
					Schema: &openapi3.Schema{
						Type:        &stringType,
						Format:      &dateTimeFormat,
						Description: ptrString("时间戳"),
					},
				},
			},
			Required: []string{"code", "message", "timestamp"},
		},
	}

	g.Reflector.Spec.ComponentsEns().SchemasEns().WithMapOfSchemaOrRefValuesItem("ErrorResponse", errorSchema)
}

// ptrString returns a pointer to a copy of s.
func ptrString(s string) *string {
	return &s
}

// ptrBool returns a pointer to a copy of b.
func ptrBool(b bool) *bool {
	return &b
}

// FileUploadField describes a file part of a multipart/form-data request.
type FileUploadField struct {
	Name        string
	Description string
	Required    bool
}

// AddOperation adds an operation to the OpenAPI spec.
//
// Parameters:
//   - method: HTTP method (GET, POST, PUT, DELETE, ...)
//   - path: API path
//   - summary: operation summary
//   - description: detailed description, Markdown allowed (may be empty)
//   - input: request parameter struct (may be nil)
//   - output: response struct (may be nil)
//   - requiresAuth: whether the operation requires authentication
//   - tags: tag list
//
// It panics when reflecting input/output fails or the spec rejects the
// operation, so that mistakes surface immediately during generation. The
// panic value is an error that wraps the underlying error and names the
// method and path.
func (g *Generator) AddOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
	op := newOperation(summary, description, tags)

	// SetRequest detects body, query and path parameters from struct tags.
	if input != nil {
		if err := g.Reflector.SetRequest(&op, input, method); err != nil {
			panic(fmt.Errorf("openapi: %s %s: reflect request: %w", method, path, err))
		}
	}

	g.finishOperation(method, path, &op, output, requiresAuth)
}

// AddMultipartOperation adds an operation whose request body is
// multipart/form-data. Each entry in fileFields becomes a binary string
// property; the `form`-tagged fields of input (may be nil) become plain form
// properties. The remaining parameters and the panic behavior are the same as
// for AddOperation.
func (g *Generator) AddMultipartOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, fileFields []FileUploadField, tags ...string) {
	op := newOperation(summary, description, tags)

	objectType := openapi3.SchemaType("object")
	stringType := openapi3.SchemaType("string")
	binaryFormat := "binary"

	properties := make(map[string]openapi3.SchemaOrRef, len(fileFields))
	var requiredFields []string

	for _, f := range fileFields {
		properties[f.Name] = openapi3.SchemaOrRef{
			Schema: &openapi3.Schema{
				Type:        &stringType,
				Format:      &binaryFormat,
				Description: ptrString(f.Description),
			},
		}
		if f.Required {
			requiredFields = appendUnique(requiredFields, f.Name)
		}
	}

	// A form field with the same name as a file field replaces its schema.
	if input != nil {
		for _, field := range parseFormFields(input) {
			schemaType := openapi3.SchemaType(field.Type)
			properties[field.Name] = openapi3.SchemaOrRef{
				Schema: &openapi3.Schema{
					Type:        &schemaType,
					Description: ptrString(field.Description),
					Minimum:     field.Min,
					MaxLength:   field.MaxLength,
				},
			}
			if field.Required {
				requiredFields = appendUnique(requiredFields, field.Name)
			}
		}
	}

	op.RequestBody = &openapi3.RequestBodyOrRef{
		RequestBody: &openapi3.RequestBody{
			Required: ptrBool(true),
			Content: map[string]openapi3.MediaType{
				"multipart/form-data": {
					Schema: &openapi3.SchemaOrRef{
						Schema: &openapi3.Schema{
							Type:       &objectType,
							Properties: properties,
							Required:   requiredFields,
						},
					},
				},
			},
		},
	}

	g.finishOperation(method, path, &op, output, requiresAuth)
}

// newOperation returns an operation with the given summary and tags; the
// description is set only when it is non-empty.
func newOperation(summary, description string, tags []string) openapi3.Operation {
	op := openapi3.Operation{
		Summary: &summary,
		Tags:    tags,
	}
	if description != "" {
		op.Description = &description
	}
	return op
}

// finishOperation attaches the 200 JSON response (when output is non-nil),
// the bearer security requirement (when requiresAuth), and the standard error
// responses, then registers op under method and path. It panics on failure.
func (g *Generator) finishOperation(method, path string, op *openapi3.Operation, output interface{}, requiresAuth bool) {
	if output != nil {
		if err := g.Reflector.SetJSONResponse(op, output, 200); err != nil {
			panic(fmt.Errorf("openapi: %s %s: reflect response: %w", method, path, err))
		}
	}

	if requiresAuth {
		g.addSecurityRequirement(op)
	}

	g.addStandardErrorResponses(op, requiresAuth)

	if err := g.Reflector.Spec.AddOperation(method, path, *op); err != nil {
		panic(fmt.Errorf("openapi: add operation %s %s: %w", method, path, err))
	}
}

// appendUnique appends name to list unless it is already present. JSON Schema
// requires the entries of "required" to be unique.
func appendUnique(list []string, name string) []string {
	for _, existing := range list {
		if existing == name {
			return list
		}
	}
	return append(list, name)
}

// formFieldInfo describes one multipart form value derived from a struct field.
type formFieldInfo struct {
	Name        string   // form field name (the `form` tag without options)
	Type        string   // OpenAPI type: "integer", "number", "boolean" or "string"
	Description string   // from the `description` tag
	Required    bool     // `validate` tag contains the exact "required" rule
	Min         *float64 // from the `minimum` tag; nil when absent or unparsable
	MaxLength   *int64   // from the `maxLength` tag; nil when absent or unparsable
}

// parseFormFields collects metadata for the fields of input (a struct or a
// pointer to one) that carry a `form` tag. Fields without a form tag or tagged
// `form:"-"` are skipped. A nil or non-struct input yields nil.
func parseFormFields(input interface{}) []formFieldInfo {
	t := reflect.TypeOf(input)
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t == nil || t.Kind() != reflect.Struct {
		return nil
	}

	var fields []formFieldInfo
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		tag := field.Tag.Get("form")
		if tag == "" || tag == "-" {
			continue
		}
		// Options such as `form:"name,omitempty"` follow the name; an empty
		// name part (`form:",omitempty"`) falls back to the Go field name,
		// following the encoding/json convention.
		name := tag
		if idx := strings.IndexByte(tag, ','); idx >= 0 {
			name = tag[:idx]
		}
		if name == "" {
			name = field.Name
		}

		// Document *T fields the same way as T.
		fieldType := field.Type
		for fieldType.Kind() == reflect.Ptr {
			fieldType = fieldType.Elem()
		}

		info := formFieldInfo{
			Name:        name,
			Type:        formValueType(fieldType.Kind()),
			Description: field.Tag.Get("description"),
			Required:    hasValidateRule(field.Tag.Get("validate"), "required"),
		}

		if minStr := field.Tag.Get("minimum"); minStr != "" {
			if minimum, err := strconv.ParseFloat(minStr, 64); err == nil {
				info.Min = &minimum
			}
		}

		if maxLenStr := field.Tag.Get("maxLength"); maxLenStr != "" {
			if maxLen, err := strconv.ParseInt(maxLenStr, 10, 64); err == nil {
				info.MaxLength = &maxLen
			}
		}

		fields = append(fields, info)
	}

	return fields
}

// formValueType maps a Go kind to the OpenAPI primitive type of a form value.
func formValueType(k reflect.Kind) string {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "integer"
	case reflect.Float32, reflect.Float64:
		return "number"
	case reflect.Bool:
		return "boolean"
	default:
		return "string"
	}
}

// hasValidateRule reports whether the comma-separated `validate` tag contains
// rule as an exact token. A substring match would also accept conditional
// rules such as "required_if" or "required_with". Rules after "dive" apply to
// the elements of a slice/map rather than to the field itself, so scanning
// stops there.
func hasValidateRule(tag, rule string) bool {
	for _, r := range strings.Split(tag, ",") {
		r = strings.TrimSpace(r)
		if r == "dive" {
			return false
		}
		if r == rule {
			return true
		}
	}
	return false
}

// addSecurityRequirement marks op as requiring the "BearerAuth" scheme,
// replacing any security requirements already set on it.
func (g *Generator) addSecurityRequirement(op *openapi3.Operation) {
	op.Security = []map[string][]string{
		{"BearerAuth": {}},
	}
}

// addStandardErrorResponses sets the 400 and 500 responses on op, plus 401 and
// 403 when requiresAuth is true. Each references the ErrorResponse schema and
// overwrites any response already registered under the same status code.
func (g *Generator) addStandardErrorResponses(op *openapi3.Operation, requiresAuth bool) {
	if op.Responses.MapOfResponseOrRefValues == nil {
		op.Responses.MapOfResponseOrRefValues = make(map[string]openapi3.ResponseOrRef)
	}
	responses := op.Responses.MapOfResponseOrRefValues

	// 400 Bad Request: any endpoint may return it.
	responses["400"] = errorResponse("请求参数错误")

	// 401 Unauthorized and 403 Forbidden: authenticated endpoints only.
	if requiresAuth {
		responses["401"] = errorResponse("未认证或认证已过期")
		responses["403"] = errorResponse("无权访问")
	}

	// 500 Internal Server Error: any endpoint may return it.
	responses["500"] = errorResponse("服务器内部错误")
}

// errorResponse builds an application/json response with the given
// description whose body references the shared ErrorResponse schema.
func errorResponse(description string) openapi3.ResponseOrRef {
	return openapi3.ResponseOrRef{
		Response: &openapi3.Response{
			Description: description,
			Content: map[string]openapi3.MediaType{
				"application/json": {
					Schema: &openapi3.SchemaOrRef{
						SchemaReference: &openapi3.SchemaReference{
							Ref: errorResponseRef,
						},
					},
				},
			},
		},
	}
}

// Save writes the spec to filename as YAML. It creates the parent directory
// (mode 0755) if needed, and the file is written with mode 0644.
//
// The spec is marshaled to JSON via Spec.MarshalJSON and decoded into generic
// values before YAML encoding, so the `json` tags and custom JSON marshaling
// of the openapi3 types decide the output field names.
func (g *Generator) Save(filename string) error {
	if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
		return fmt.Errorf("create directory for %q: %w", filename, err)
	}

	jsonBytes, err := g.Reflector.Spec.MarshalJSON()
	if err != nil {
		return fmt.Errorf("marshal spec to JSON: %w", err)
	}

	var doc interface{}
	if err := json.Unmarshal(jsonBytes, &doc); err != nil {
		return fmt.Errorf("decode spec JSON: %w", err)
	}

	yamlBytes, err := yaml.Marshal(doc)
	if err != nil {
		return fmt.Errorf("marshal spec to YAML: %w", err)
	}

	if err := os.WriteFile(filename, yamlBytes, 0o644); err != nil {
		return fmt.Errorf("write %q: %w", filename, err)
	}
	return nil
}