package openapi import ( "encoding/json" "os" "path/filepath" "github.com/swaggest/openapi-go/openapi3" "gopkg.in/yaml.v3" ) // Generator OpenAPI 文档生成器 type Generator struct { Reflector *openapi3.Reflector } // NewGenerator 创建一个新的生成器实例 func NewGenerator(title, version string) *Generator { reflector := openapi3.Reflector{} reflector.Spec = &openapi3.Spec{ Openapi: "3.0.3", Info: openapi3.Info{ Title: title, Version: version, }, } g := &Generator{Reflector: &reflector} g.addBearerAuth() return g } // addBearerAuth 添加 Bearer Token 认证定义 func (g *Generator) addBearerAuth() { bearerFormat := "JWT" g.Reflector.Spec.ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem( "BearerAuth", openapi3.SecuritySchemeOrRef{ SecurityScheme: &openapi3.SecurityScheme{ HTTPSecurityScheme: &openapi3.HTTPSecurityScheme{ Scheme: "bearer", BearerFormat: &bearerFormat, }, }, }, ) g.addErrorResponseSchema() } // addErrorResponseSchema 添加错误响应 Schema 定义 func (g *Generator) addErrorResponseSchema() { objectType := openapi3.SchemaType("object") integerType := openapi3.SchemaType("integer") stringType := openapi3.SchemaType("string") dateTimeFormat := "date-time" errorSchema := openapi3.SchemaOrRef{ Schema: &openapi3.Schema{ Type: &objectType, Properties: map[string]openapi3.SchemaOrRef{ "code": { Schema: &openapi3.Schema{ Type: &integerType, Description: ptrString("错误码"), }, }, "message": { Schema: &openapi3.Schema{ Type: &stringType, Description: ptrString("错误消息"), }, }, "timestamp": { Schema: &openapi3.Schema{ Type: &stringType, Format: &dateTimeFormat, Description: ptrString("时间戳"), }, }, }, Required: []string{"code", "message", "timestamp"}, }, } g.Reflector.Spec.ComponentsEns().SchemasEns().WithMapOfSchemaOrRefValuesItem("ErrorResponse", errorSchema) } // ptrString 返回字符串指针 func ptrString(s string) *string { return &s } // AddOperation 向 OpenAPI 规范中添加一个操作 // 参数: // - method: HTTP 方法(GET, POST, PUT, DELETE 等) // - path: API 路径 // - summary: 操作摘要 // - input: 请求参数结构体(可为 nil) // - output: 响应结构体(可为 nil) // - tags: 标签列表 // - requiresAuth: 是否需要认证 func (g *Generator) AddOperation(method, path, summary string, input interface{}, output interface{}, requiresAuth bool, tags ...string) { op := openapi3.Operation{ Summary: &summary, Tags: tags, } // 反射输入 (请求参数/Body) if input != nil { // SetRequest 根据结构体标签自动检测 Body、Query 或 Path 参数 if err := g.Reflector.SetRequest(&op, input, method); err != nil { panic(err) // 生成过程中出错直接 panic,以便快速发现问题 } } // 反射输出 (响应 Body) if output != nil { if err := g.Reflector.SetJSONResponse(&op, output, 200); err != nil { panic(err) } } // 添加认证要求 if requiresAuth { g.addSecurityRequirement(&op) } // 添加标准错误响应 g.addStandardErrorResponses(&op, requiresAuth) // 将操作添加到规范中 if err := g.Reflector.Spec.AddOperation(method, path, op); err != nil { panic(err) } } // addSecurityRequirement 为操作添加认证要求 func (g *Generator) addSecurityRequirement(op *openapi3.Operation) { op.Security = []map[string][]string{ {"BearerAuth": {}}, } } // addStandardErrorResponses 添加标准错误响应 func (g *Generator) addStandardErrorResponses(op *openapi3.Operation, requiresAuth bool) { if op.Responses.MapOfResponseOrRefValues == nil { op.Responses.MapOfResponseOrRefValues = make(map[string]openapi3.ResponseOrRef) } // 400 Bad Request - 所有端点都可能返回 desc400 := "请求参数错误" op.Responses.MapOfResponseOrRefValues["400"] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Description: desc400, Content: map[string]openapi3.MediaType{ "application/json": { Schema: &openapi3.SchemaOrRef{ SchemaReference: &openapi3.SchemaReference{ Ref: "#/components/schemas/ErrorResponse", }, }, }, }, }, } // 401 Unauthorized - 仅认证端点返回 if requiresAuth { desc401 := "未认证或认证已过期" op.Responses.MapOfResponseOrRefValues["401"] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Description: desc401, Content: map[string]openapi3.MediaType{ "application/json": { Schema: &openapi3.SchemaOrRef{ SchemaReference: &openapi3.SchemaReference{ Ref: "#/components/schemas/ErrorResponse", }, }, }, }, }, } // 403 Forbidden - 仅认证端点返回 desc403 := "无权访问" op.Responses.MapOfResponseOrRefValues["403"] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Description: desc403, Content: map[string]openapi3.MediaType{ "application/json": { Schema: &openapi3.SchemaOrRef{ SchemaReference: &openapi3.SchemaReference{ Ref: "#/components/schemas/ErrorResponse", }, }, }, }, }, } } // 500 Internal Server Error - 所有端点都可能返回 desc500 := "服务器内部错误" op.Responses.MapOfResponseOrRefValues["500"] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Description: desc500, Content: map[string]openapi3.MediaType{ "application/json": { Schema: &openapi3.SchemaOrRef{ SchemaReference: &openapi3.SchemaReference{ Ref: "#/components/schemas/ErrorResponse", }, }, }, }, }, } } // Save 将规范导出为 YAML 文件 func (g *Generator) Save(filename string) error { // 确保目录存在 dir := filepath.Dir(filename) if err := os.MkdirAll(dir, 0755); err != nil { return err } // 安全的方法:MarshalJSON -> Unmarshal -> MarshalYAML // 这确保了我们遵守 openapi3 库中定义的 `json` 标签 jsonBytes, err := g.Reflector.Spec.MarshalJSON() if err != nil { return err } var obj interface{} if err := json.Unmarshal(jsonBytes, &obj); err != nil { return err } yamlBytes, err := yaml.Marshal(obj) if err != nil { return err } return os.WriteFile(filename, yamlBytes, 0644) }