From d4f1294d61441e49e665af3f07a7ba4a944a0fc7 Mon Sep 17 00:00:00 2001 From: EwenQuim Date: Wed, 11 Dec 2024 01:29:59 +0100 Subject: [PATCH] Decorrelates openapi registration from server --- examples/petstore/lib/server_test.go | 2 +- .../lib/testdata/doc/openapi.golden.json | 2 +- mux.go | 2 +- mux_test.go | 16 ++-- openapi.go | 83 ++++++++++++++----- openapi_operations.go | 2 +- openapi_test.go | 2 +- option.go | 8 +- option_test.go | 16 ++-- server.go | 17 ++-- server_test.go | 40 ++++----- 11 files changed, 115 insertions(+), 75 deletions(-) diff --git a/examples/petstore/lib/server_test.go b/examples/petstore/lib/server_test.go index 0d990920..9260b76e 100644 --- a/examples/petstore/lib/server_test.go +++ b/examples/petstore/lib/server_test.go @@ -21,7 +21,7 @@ func TestPetstoreOpenAPIGeneration(t *testing.T) { ) server.OutputOpenAPISpec() - err := server.OpenApiSpec.Validate(context.Background()) + err := server.OpenAPIzer.OpenAPIDescription().Validate(context.Background()) require.NoError(t, err) generatedSpec, err := os.ReadFile("testdata/doc/openapi.json") diff --git a/examples/petstore/lib/testdata/doc/openapi.golden.json b/examples/petstore/lib/testdata/doc/openapi.golden.json index bed2ff64..16eaede4 100644 --- a/examples/petstore/lib/testdata/doc/openapi.golden.json +++ b/examples/petstore/lib/testdata/doc/openapi.golden.json @@ -1,5 +1,4 @@ { - "openapi": "3.1.0", "components": { "schemas": { "HTTPError": { @@ -149,6 +148,7 @@ "title": "OpenAPI", "version": "0.0.1" }, + "openapi": "3.1.0", "paths": { "/pets/": { "get": { diff --git a/mux.go b/mux.go index 2fcccb0e..4e208370 100644 --- a/mux.go +++ b/mux.go @@ -112,7 +112,7 @@ func Register[T, B any](s *Server, route Route[T, B], controller http.Handler, o route.Path = s.basePath + route.Path var err error - route.Operation, err = RegisterOpenAPIOperation(s, route) + route.Operation, err = RegisterOpenAPIOperation(s.OpenAPIzer, route) if err != nil { slog.Warn("error documenting openapi operation", "error", err) } diff --git a/mux_test.go b/mux_test.go index 7d5f5b48..f63c8b96 100644 --- a/mux_test.go +++ b/mux_test.go @@ -488,8 +488,8 @@ func TestHideOpenapiRoutes(t *testing.T) { Get(s, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil }) require.Equal(t, s.DisableOpenapi, true) - require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil) - require.True(t, s.OpenApiSpec.Paths.Find("/test") == nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test") == nil) }) t.Run("hide group", func(t *testing.T) { @@ -500,8 +500,8 @@ func TestHideOpenapiRoutes(t *testing.T) { Get(g, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil }) require.Equal(t, g.DisableOpenapi, true) - require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil) - require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil) }) t.Run("hide group but not other group", func(t *testing.T) { @@ -514,8 +514,8 @@ func TestHideOpenapiRoutes(t *testing.T) { require.Equal(t, true, g.DisableOpenapi) require.Equal(t, false, g2.DisableOpenapi) - require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil) - require.True(t, s.OpenApiSpec.Paths.Find("/group2/test") != nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group2/test") != nil) }) t.Run("hide group but show sub group", func(t *testing.T) { @@ -527,8 +527,8 @@ func TestHideOpenapiRoutes(t *testing.T) { Get(g2, "/test", func(ctx *ContextNoBody) (string, error) { return "test", nil }) require.Equal(t, true, g.DisableOpenapi) - require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil) - require.True(t, s.OpenApiSpec.Paths.Find("/group/sub/test") != nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil) + require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/sub/test") != nil) }) } diff --git a/openapi.go b/openapi.go index dc40087f..a1c1f44d 100644 --- a/openapi.go +++ b/openapi.go @@ -16,8 +16,45 @@ import ( "strings" "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3gen" ) +type OpenAPIzer interface { + OpenAPIDescription() *openapi3.T + Generator() *openapi3gen.Generator + GlobalOpenAPIResponses() *[]openAPIError +} + +func NewSpec() *Spec { + desc := NewOpenApiSpec() + return &Spec{ + description: &desc, + generator: openapi3gen.NewGenerator(), + globalOpenAPIResponses: &[]openAPIError{}, + } +} + +// Holds the OpenAPI OpenAPIDescription (OAD) and OpenAPI capabilities. +type Spec struct { + description *openapi3.T + generator *openapi3gen.Generator + globalOpenAPIResponses *[]openAPIError +} + +func (d *Spec) OpenAPIDescription() *openapi3.T { + return d.description +} + +func (d *Spec) Generator() *openapi3gen.Generator { + return d.generator +} + +func (d *Spec) GlobalOpenAPIResponses() *[]openAPIError { + return d.globalOpenAPIResponses +} + +var _ OpenAPIzer = &Spec{} + func NewOpenApiSpec() openapi3.T { info := &openapi3.Info{ Title: "OpenAPI", @@ -53,11 +90,11 @@ func (s *Server) Show() *Server { } func declareAllTagsFromOperations(s *Server) { - for _, pathItem := range s.OpenApiSpec.Paths.Map() { + for _, pathItem := range s.OpenAPIzer.OpenAPIDescription().Paths.Map() { for _, op := range pathItem.Operations() { for _, tag := range op.Tags { - if s.OpenApiSpec.Tags.Get(tag) == nil { - s.OpenApiSpec.Tags = append(s.OpenApiSpec.Tags, &openapi3.Tag{ + if s.OpenAPIzer.OpenAPIDescription().Tags.Get(tag) == nil { + s.OpenAPIzer.OpenAPIDescription().Tags = append(s.OpenAPIzer.OpenAPIDescription().Tags, &openapi3.Tag{ Name: tag, }) } @@ -73,7 +110,7 @@ func (s *Server) OutputOpenAPISpec() openapi3.T { declareAllTagsFromOperations(s) // Validate - err := s.OpenApiSpec.Validate(context.Background()) + err := s.OpenAPIzer.OpenAPIDescription().Validate(context.Background()) if err != nil { slog.Error("Error validating spec", "error", err) } @@ -95,14 +132,14 @@ func (s *Server) OutputOpenAPISpec() openapi3.T { } } - return s.OpenApiSpec + return *s.OpenAPIzer.OpenAPIDescription() } func (s *Server) marshalSpec() ([]byte, error) { if s.OpenAPIConfig.PrettyFormatJson { - return json.MarshalIndent(s.OpenApiSpec, "", " ") + return json.MarshalIndent(s.OpenAPIzer.OpenAPIDescription(), "", " ") } - return json.Marshal(s.OpenApiSpec) + return json.Marshal(s.OpenAPIzer.OpenAPIDescription()) } func (s *Server) saveOpenAPIToFile(jsonSpecLocalPath string, jsonSpec []byte) error { @@ -164,7 +201,7 @@ func validateSwaggerUrl(swaggerUrl string) bool { } // RegisterOpenAPIOperation registers an OpenAPI operation. -func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3.Operation, error) { +func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*openapi3.Operation, error) { if route.Operation == nil { route.Operation = openapi3.NewOperation() } @@ -184,7 +221,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3 } // Response - globals - for _, openAPIGlobalResponse := range s.globalOpenAPIResponses { + for _, openAPIGlobalResponse := range *s.GlobalOpenAPIResponses() { addResponseIfNotSet(s, route.Operation, openAPIGlobalResponse.Code, openAPIGlobalResponse.Description, openAPIGlobalResponse.ErrorType) } @@ -228,7 +265,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3 } } - s.OpenApiSpec.AddOperation(route.Path, route.Method, route.Operation) + s.OpenAPIDescription().AddOperation(route.Path, route.Method, route.Operation) return route.Operation, nil } @@ -247,10 +284,10 @@ type SchemaTag struct { Name string } -func SchemaTagFromType(s *Server, v any) SchemaTag { +func SchemaTagFromType(s OpenAPIzer, v any) SchemaTag { if v == nil { // ensure we add unknown-interface to our schemas - schema := s.getOrCreateSchema("unknown-interface", struct{}{}) + schema := getOrCreateSchema(s, "unknown-interface", struct{}{}) return SchemaTag{ Name: "unknown-interface", SchemaRef: openapi3.SchemaRef{ @@ -270,7 +307,7 @@ func SchemaTagFromType(s *Server, v any) SchemaTag { // If the type is a slice or array type it will dive into the type as well as // build and openapi3.Schema where Type is array and Ref is set to the proper // components Schema -func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag { +func dive(s OpenAPIzer, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag { if maxDepth == 0 { return SchemaTag{ Name: "default", @@ -297,7 +334,7 @@ func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag { return dive(s, t.Field(0).Type, tag, maxDepth-1) } tag.Ref = "#/components/schemas/" + tag.Name - tag.Value = s.getOrCreateSchema(tag.Name, reflect.New(t).Interface()) + tag.Value = getOrCreateSchema(s, tag.Name, reflect.New(t).Interface()) return tag } @@ -305,18 +342,18 @@ func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag { // getOrCreateSchema is used to get a schema from the OpenAPI spec. // If the schema does not exist, it will create a new schema and add it to the OpenAPI spec. -func (s *Server) getOrCreateSchema(key string, v any) *openapi3.Schema { - schemaRef, ok := s.OpenApiSpec.Components.Schemas[key] +func getOrCreateSchema(s OpenAPIzer, key string, v any) *openapi3.Schema { + schemaRef, ok := s.OpenAPIDescription().Components.Schemas[key] if !ok { - schemaRef = s.createSchema(key, v) + schemaRef = createSchema(s, key, v) } return schemaRef.Value } // createSchema is used to create a new schema and add it to the OpenAPI spec. // Relies on the openapi3gen package to generate the schema, and adds custom struct tags. -func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef { - schemaRef, err := s.openAPIGenerator.NewSchemaRefForValue(v, s.OpenApiSpec.Components.Schemas) +func createSchema(s OpenAPIzer, key string, v any) *openapi3.SchemaRef { + schemaRef, err := s.Generator().NewSchemaRefForValue(v, s.OpenAPIDescription().Components.Schemas) if err != nil { slog.Error("Error generating schema", "key", key, "error", err) } @@ -327,9 +364,9 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef { schemaRef.Value.Description = descriptionable.Description() } - s.parseStructTags(reflect.TypeOf(v), schemaRef) + parseStructTags(reflect.TypeOf(v), schemaRef) - s.OpenApiSpec.Components.Schemas[key] = schemaRef + s.OpenAPIDescription().Components.Schemas[key] = schemaRef return schemaRef } @@ -346,7 +383,7 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef { // - min=1 => minLength=1 (for strings) // - max=100 => max=100 (for integers) // - max=100 => maxLength=100 (for strings) -func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) { +func parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) { if t.Kind() == reflect.Ptr { t = t.Elem() } @@ -360,7 +397,7 @@ func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) if field.Anonymous { fieldType := field.Type - s.parseStructTags(fieldType, schemaRef) + parseStructTags(fieldType, schemaRef) continue } diff --git a/openapi_operations.go b/openapi_operations.go index 10e20073..1bd5c4d5 100644 --- a/openapi_operations.go +++ b/openapi_operations.go @@ -62,7 +62,7 @@ func (r Route[ResponseBody, RequestBody]) Param(paramType ParamType, name, descr } // Registers a response for the route, only if error for this code is not already set. -func addResponseIfNotSet(s *Server, operation *openapi3.Operation, code int, description string, errorType ...any) { +func addResponseIfNotSet(s OpenAPIzer, operation *openapi3.Operation, code int, description string, errorType ...any) { var responseSchema SchemaTag if len(errorType) > 0 { diff --git a/openapi_test.go b/openapi_test.go index 89667f92..616c5a4c 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -545,7 +545,7 @@ func TestDeclareCustom200Response(t *testing.T) { w.Write([]byte("PNG image")) }, optionReturnsPNG) - openAPIResponse := s.OpenApiSpec.Paths.Find("/image").Get.Responses.Value("200") + openAPIResponse := s.OpenAPIzer.OpenAPIDescription().Paths.Find("/image").Get.Responses.Value("200") require.Nil(t, openAPIResponse.Value.Content.Get("application/json")) require.NotNil(t, openAPIResponse.Value.Content.Get("image/png")) require.Equal(t, "Generated image", *openAPIResponse.Value.Description) diff --git a/option.go b/option.go index f9eed7f9..0f263209 100644 --- a/option.go +++ b/option.go @@ -302,9 +302,9 @@ func OptionAddError(code int, description string, errorType ...any) func(*BaseRo } if len(errorType) > 0 { - responseSchema = SchemaTagFromType(r.mainRouter, errorType[0]) + responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, errorType[0]) } else { - responseSchema = SchemaTagFromType(r.mainRouter, HTTPError{}) + responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, HTTPError{}) } content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"}) @@ -374,14 +374,14 @@ func OptionDefaultStatusCode(defaultStatusCode int) func(*BaseRoute) { // }) func OptionSecurity(securityRequirements ...openapi3.SecurityRequirement) func(*BaseRoute) { return func(r *BaseRoute) { - if r.mainRouter.OpenApiSpec.Components == nil { + if r.mainRouter.OpenAPIzer.OpenAPIDescription().Components == nil { panic("zero security schemes have been registered with the server") } // Validate the security scheme exists in components for _, req := range securityRequirements { for schemeName := range req { - if _, exists := r.mainRouter.OpenApiSpec.Components.SecuritySchemes[schemeName]; !exists { + if _, exists := r.mainRouter.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes[schemeName]; !exists { panic(fmt.Sprintf("security scheme '%s' not defined in components", schemeName)) } } diff --git a/option_test.go b/option_test.go index 6586e98a..1df04900 100644 --- a/option_test.go +++ b/option_test.go @@ -299,8 +299,8 @@ func TestPath(t *testing.T) { fuego.Get(s, "/test/{id}", helloWorld) - require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name) - require.Equal(t, "", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description) + require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name) + require.Equal(t, "", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description) }) t.Run("Declare explicitly an existing path parameter for the route", func(t *testing.T) { @@ -310,9 +310,9 @@ func TestPath(t *testing.T) { fuego.OptionPath("id", "some id", param.Example("123", "123"), param.Nullable()), ) - require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name) - require.Equal(t, "some id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description) - require.Equal(t, true, s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required") + require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name) + require.Equal(t, "some id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description) + require.Equal(t, true, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required") }) t.Run("Declare explicitly a non-existing path parameter for the route panics", func(t *testing.T) { @@ -353,7 +353,7 @@ func TestRequestContentType(t *testing.T) { require.NotNil(t, content.Get("application/json")) require.Nil(t, content.Get("application/xml")) require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref) - _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + _, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"] require.False(t, ok) }) @@ -369,7 +369,7 @@ func TestRequestContentType(t *testing.T) { require.Nil(t, content.Get("application/xml")) require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref) require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref) - _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + _, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"] require.False(t, ok) }) @@ -385,7 +385,7 @@ func TestRequestContentType(t *testing.T) { require.Nil(t, content.Get("application/xml")) require.NotNil(t, content.Get("my/content-type")) require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref) - _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + _, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"] require.False(t, ok) }) } diff --git a/server.go b/server.go index a0215c23..535a0bae 100644 --- a/server.go +++ b/server.go @@ -61,7 +61,8 @@ type Server struct { globalOpenAPIResponses []openAPIError // Global error responses - OpenApiSpec openapi3.T // OpenAPI spec generated by the server + // OpenAPIzer handles the OpenAPI spec generation. + OpenAPIzer Security Security @@ -105,10 +106,10 @@ func NewServer(options ...func(*Server)) *Server { WriteTimeout: 30 * time.Second, IdleTimeout: 30 * time.Second, }, - Mux: http.NewServeMux(), - OpenApiSpec: NewOpenApiSpec(), + Mux: http.NewServeMux(), OpenAPIConfig: defaultOpenAPIConfig, + OpenAPIzer: NewSpec(), openAPIGenerator: openapi3gen.NewGenerator( openapi3gen.UseAllExportedFields(), @@ -137,7 +138,7 @@ func NewServer(options ...func(*Server)) *Server { option(s) } - s.OpenApiSpec.Servers = append(s.OpenApiSpec.Servers, &openapi3.Server{ + s.OpenAPIzer.OpenAPIDescription().Servers = append(s.OpenAPIzer.OpenAPIDescription().Servers, &openapi3.Server{ URL: "http://" + s.Addr, Description: "local server", }) @@ -214,6 +215,8 @@ func WithGlobalResponseTypes(code int, description string, errorType ...any) fun errorType = append(errorType, HTTPError{}) return func(c *Server) { c.globalOpenAPIResponses = append(c.globalOpenAPIResponses, openAPIError{code, description, errorType[0]}) + truc := c.OpenAPIzer.GlobalOpenAPIResponses() + *truc = append(*truc, openAPIError{code, description, errorType[0]}) } } @@ -234,11 +237,11 @@ func WithGlobalResponseTypes(code int, description string, errorType ...any) fun // ) func WithSecurity(schemes openapi3.SecuritySchemes) func(*Server) { return func(s *Server) { - if s.OpenApiSpec.Components.SecuritySchemes == nil { - s.OpenApiSpec.Components.SecuritySchemes = openapi3.SecuritySchemes{} + if s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes == nil { + s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes = openapi3.SecuritySchemes{} } for name, scheme := range schemes { - s.OpenApiSpec.Components.SecuritySchemes[name] = scheme + s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes[name] = scheme } } } diff --git a/server_test.go b/server_test.go index 51d75764..75801820 100644 --- a/server_test.go +++ b/server_test.go @@ -334,7 +334,7 @@ func TestWithRequestContentType(t *testing.T) { require.NotNil(t, content.Get("application/xml")) require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref) require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/xml").Schema.Ref) - _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + _, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"] require.False(t, ok) }) } @@ -503,10 +503,10 @@ func TestWithSecurity(t *testing.T) { }), ) - require.NotNil(t, s.OpenApiSpec.Components.SecuritySchemes) - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "bearerAuth") + require.NotNil(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes) + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "bearerAuth") - scheme := s.OpenApiSpec.Components.SecuritySchemes["bearerAuth"].Value + scheme := s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes["bearerAuth"].Value require.Equal(t, "http", scheme.Type) require.Equal(t, "bearer", scheme.Scheme) require.Equal(t, "JWT", scheme.BearerFormat) @@ -530,12 +530,12 @@ func TestWithSecurity(t *testing.T) { }), ) - require.NotNil(t, s.OpenApiSpec.Components.SecuritySchemes) - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "bearerAuth") - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "apiKey") + require.NotNil(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes) + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "bearerAuth") + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "apiKey") - bearerScheme := s.OpenApiSpec.Components.SecuritySchemes["bearerAuth"].Value - apiKeyScheme := s.OpenApiSpec.Components.SecuritySchemes["apiKey"].Value + bearerScheme := s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes["bearerAuth"].Value + apiKeyScheme := s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes["apiKey"].Value require.Equal(t, "http", bearerScheme.Type) require.Equal(t, "bearer", bearerScheme.Scheme) @@ -559,16 +559,16 @@ func TestWithSecurity(t *testing.T) { ) // Add another security scheme to the existing server - s.OpenApiSpec.Components.SecuritySchemes["oauth2"] = &openapi3.SecuritySchemeRef{ + s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes["oauth2"] = &openapi3.SecuritySchemeRef{ Value: openapi3.NewOIDCSecurityScheme("https://example.com/.well-known/openid-configuration"). WithType("oauth2"), } - require.NotNil(t, s.OpenApiSpec.Components.SecuritySchemes) - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "bearerAuth") - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "oauth2") + require.NotNil(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes) + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "bearerAuth") + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "oauth2") - oauth2Scheme := s.OpenApiSpec.Components.SecuritySchemes["oauth2"].Value + oauth2Scheme := s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes["oauth2"].Value require.Equal(t, "oauth2", oauth2Scheme.Type) require.Equal(t, "https://example.com/.well-known/openid-configuration", oauth2Scheme.OpenIdConnectUrl) }) @@ -593,14 +593,14 @@ func TestWithSecurity(t *testing.T) { }), ) - require.NotNil(t, s.OpenApiSpec.Components.SecuritySchemes) - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "bearerAuth") - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "apiKey") + require.NotNil(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes) + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "bearerAuth") + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "apiKey") }) t.Run("initialize security schemes if nil", func(t *testing.T) { s := NewServer() - s.OpenApiSpec.Components.SecuritySchemes = nil + s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes = nil s = NewServer( WithSecurity(openapi3.SecuritySchemes{ @@ -613,7 +613,7 @@ func TestWithSecurity(t *testing.T) { }), ) - require.NotNil(t, s.OpenApiSpec.Components.SecuritySchemes) - require.Contains(t, s.OpenApiSpec.Components.SecuritySchemes, "bearerAuth") + require.NotNil(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes) + require.Contains(t, s.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes, "bearerAuth") }) }