Skip to content

Commit

Permalink
Decorrelates openapi registration from server
Browse files Browse the repository at this point in the history
  • Loading branch information
EwenQuim committed Dec 12, 2024
1 parent 19a76d1 commit d4f1294
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 75 deletions.
2 changes: 1 addition & 1 deletion examples/petstore/lib/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestPetstoreOpenAPIGeneration(t *testing.T) {
)

server.OutputOpenAPISpec()
err := server.OpenApiSpec.Validate(context.Background())
err := server.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
require.NoError(t, err)

generatedSpec, err := os.ReadFile("testdata/doc/openapi.json")
Expand Down
2 changes: 1 addition & 1 deletion examples/petstore/lib/testdata/doc/openapi.golden.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"openapi": "3.1.0",
"components": {
"schemas": {
"HTTPError": {
Expand Down Expand Up @@ -149,6 +148,7 @@
"title": "OpenAPI",
"version": "0.0.1"
},
"openapi": "3.1.0",
"paths": {
"/pets/": {
"get": {
Expand Down
2 changes: 1 addition & 1 deletion mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func Register[T, B any](s *Server, route Route[T, B], controller http.Handler, o
route.Path = s.basePath + route.Path

var err error
route.Operation, err = RegisterOpenAPIOperation(s, route)
route.Operation, err = RegisterOpenAPIOperation(s.OpenAPIzer, route)
if err != nil {
slog.Warn("error documenting openapi operation", "error", err)
}
Expand Down
16 changes: 8 additions & 8 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(s, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })

require.Equal(t, s.DisableOpenapi, true)
require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenApiSpec.Paths.Find("/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test") == nil)
})

t.Run("hide group", func(t *testing.T) {
Expand All @@ -500,8 +500,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(g, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })

require.Equal(t, g.DisableOpenapi, true)
require.True(t, s.OpenApiSpec.Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
})

t.Run("hide group but not other group", func(t *testing.T) {
Expand All @@ -514,8 +514,8 @@ func TestHideOpenapiRoutes(t *testing.T) {

require.Equal(t, true, g.DisableOpenapi)
require.Equal(t, false, g2.DisableOpenapi)
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
require.True(t, s.OpenApiSpec.Paths.Find("/group2/test") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group2/test") != nil)
})

t.Run("hide group but show sub group", func(t *testing.T) {
Expand All @@ -527,8 +527,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(g2, "/test", func(ctx *ContextNoBody) (string, error) { return "test", nil })

require.Equal(t, true, g.DisableOpenapi)
require.True(t, s.OpenApiSpec.Paths.Find("/group/test") == nil)
require.True(t, s.OpenApiSpec.Paths.Find("/group/sub/test") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/sub/test") != nil)
})
}

Expand Down
83 changes: 60 additions & 23 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,45 @@ import (
"strings"

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3gen"
)

type OpenAPIzer interface {
OpenAPIDescription() *openapi3.T
Generator() *openapi3gen.Generator
GlobalOpenAPIResponses() *[]openAPIError
}

func NewSpec() *Spec {
desc := NewOpenApiSpec()
return &Spec{
description: &desc,
generator: openapi3gen.NewGenerator(),
globalOpenAPIResponses: &[]openAPIError{},
}
}

// Holds the OpenAPI OpenAPIDescription (OAD) and OpenAPI capabilities.
type Spec struct {
description *openapi3.T
generator *openapi3gen.Generator
globalOpenAPIResponses *[]openAPIError
}

func (d *Spec) OpenAPIDescription() *openapi3.T {
return d.description
}

func (d *Spec) Generator() *openapi3gen.Generator {
return d.generator
}

func (d *Spec) GlobalOpenAPIResponses() *[]openAPIError {
return d.globalOpenAPIResponses
}

var _ OpenAPIzer = &Spec{}

func NewOpenApiSpec() openapi3.T {
info := &openapi3.Info{
Title: "OpenAPI",
Expand Down Expand Up @@ -53,11 +90,11 @@ func (s *Server) Show() *Server {
}

func declareAllTagsFromOperations(s *Server) {
for _, pathItem := range s.OpenApiSpec.Paths.Map() {
for _, pathItem := range s.OpenAPIzer.OpenAPIDescription().Paths.Map() {
for _, op := range pathItem.Operations() {
for _, tag := range op.Tags {
if s.OpenApiSpec.Tags.Get(tag) == nil {
s.OpenApiSpec.Tags = append(s.OpenApiSpec.Tags, &openapi3.Tag{
if s.OpenAPIzer.OpenAPIDescription().Tags.Get(tag) == nil {
s.OpenAPIzer.OpenAPIDescription().Tags = append(s.OpenAPIzer.OpenAPIDescription().Tags, &openapi3.Tag{
Name: tag,
})
}
Expand All @@ -73,7 +110,7 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
declareAllTagsFromOperations(s)

// Validate
err := s.OpenApiSpec.Validate(context.Background())
err := s.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
if err != nil {
slog.Error("Error validating spec", "error", err)
}
Expand All @@ -95,14 +132,14 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
}
}

return s.OpenApiSpec
return *s.OpenAPIzer.OpenAPIDescription()
}

func (s *Server) marshalSpec() ([]byte, error) {
if s.OpenAPIConfig.PrettyFormatJson {
return json.MarshalIndent(s.OpenApiSpec, "", " ")
return json.MarshalIndent(s.OpenAPIzer.OpenAPIDescription(), "", " ")
}
return json.Marshal(s.OpenApiSpec)
return json.Marshal(s.OpenAPIzer.OpenAPIDescription())
}

func (s *Server) saveOpenAPIToFile(jsonSpecLocalPath string, jsonSpec []byte) error {
Expand Down Expand Up @@ -164,7 +201,7 @@ func validateSwaggerUrl(swaggerUrl string) bool {
}

// RegisterOpenAPIOperation registers an OpenAPI operation.
func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3.Operation, error) {
func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*openapi3.Operation, error) {
if route.Operation == nil {
route.Operation = openapi3.NewOperation()
}
Expand All @@ -184,7 +221,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
}

// Response - globals
for _, openAPIGlobalResponse := range s.globalOpenAPIResponses {
for _, openAPIGlobalResponse := range *s.GlobalOpenAPIResponses() {
addResponseIfNotSet(s, route.Operation, openAPIGlobalResponse.Code, openAPIGlobalResponse.Description, openAPIGlobalResponse.ErrorType)
}

Expand Down Expand Up @@ -228,7 +265,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, route Route[T, B]) (*openapi3
}
}

s.OpenApiSpec.AddOperation(route.Path, route.Method, route.Operation)
s.OpenAPIDescription().AddOperation(route.Path, route.Method, route.Operation)

return route.Operation, nil
}
Expand All @@ -247,10 +284,10 @@ type SchemaTag struct {
Name string
}

func SchemaTagFromType(s *Server, v any) SchemaTag {
func SchemaTagFromType(s OpenAPIzer, v any) SchemaTag {
if v == nil {
// ensure we add unknown-interface to our schemas
schema := s.getOrCreateSchema("unknown-interface", struct{}{})
schema := getOrCreateSchema(s, "unknown-interface", struct{}{})
return SchemaTag{
Name: "unknown-interface",
SchemaRef: openapi3.SchemaRef{
Expand All @@ -270,7 +307,7 @@ func SchemaTagFromType(s *Server, v any) SchemaTag {
// If the type is a slice or array type it will dive into the type as well as
// build and openapi3.Schema where Type is array and Ref is set to the proper
// components Schema
func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
func dive(s OpenAPIzer, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
if maxDepth == 0 {
return SchemaTag{
Name: "default",
Expand All @@ -297,26 +334,26 @@ func dive(s *Server, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
return dive(s, t.Field(0).Type, tag, maxDepth-1)
}
tag.Ref = "#/components/schemas/" + tag.Name
tag.Value = s.getOrCreateSchema(tag.Name, reflect.New(t).Interface())
tag.Value = getOrCreateSchema(s, tag.Name, reflect.New(t).Interface())

return tag
}
}

// getOrCreateSchema is used to get a schema from the OpenAPI spec.
// If the schema does not exist, it will create a new schema and add it to the OpenAPI spec.
func (s *Server) getOrCreateSchema(key string, v any) *openapi3.Schema {
schemaRef, ok := s.OpenApiSpec.Components.Schemas[key]
func getOrCreateSchema(s OpenAPIzer, key string, v any) *openapi3.Schema {
schemaRef, ok := s.OpenAPIDescription().Components.Schemas[key]
if !ok {
schemaRef = s.createSchema(key, v)
schemaRef = createSchema(s, key, v)
}
return schemaRef.Value
}

// createSchema is used to create a new schema and add it to the OpenAPI spec.
// Relies on the openapi3gen package to generate the schema, and adds custom struct tags.
func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
schemaRef, err := s.openAPIGenerator.NewSchemaRefForValue(v, s.OpenApiSpec.Components.Schemas)
func createSchema(s OpenAPIzer, key string, v any) *openapi3.SchemaRef {
schemaRef, err := s.Generator().NewSchemaRefForValue(v, s.OpenAPIDescription().Components.Schemas)
if err != nil {
slog.Error("Error generating schema", "key", key, "error", err)
}
Expand All @@ -327,9 +364,9 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
schemaRef.Value.Description = descriptionable.Description()
}

s.parseStructTags(reflect.TypeOf(v), schemaRef)
parseStructTags(reflect.TypeOf(v), schemaRef)

s.OpenApiSpec.Components.Schemas[key] = schemaRef
s.OpenAPIDescription().Components.Schemas[key] = schemaRef

return schemaRef
}
Expand All @@ -346,7 +383,7 @@ func (s *Server) createSchema(key string, v any) *openapi3.SchemaRef {
// - min=1 => minLength=1 (for strings)
// - max=100 => max=100 (for integers)
// - max=100 => maxLength=100 (for strings)
func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) {
func parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef) {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
Expand All @@ -360,7 +397,7 @@ func (s *Server) parseStructTags(t reflect.Type, schemaRef *openapi3.SchemaRef)

if field.Anonymous {
fieldType := field.Type
s.parseStructTags(fieldType, schemaRef)
parseStructTags(fieldType, schemaRef)
continue
}

Expand Down
2 changes: 1 addition & 1 deletion openapi_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r Route[ResponseBody, RequestBody]) Param(paramType ParamType, name, descr
}

// Registers a response for the route, only if error for this code is not already set.
func addResponseIfNotSet(s *Server, operation *openapi3.Operation, code int, description string, errorType ...any) {
func addResponseIfNotSet(s OpenAPIzer, operation *openapi3.Operation, code int, description string, errorType ...any) {
var responseSchema SchemaTag

if len(errorType) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ func TestDeclareCustom200Response(t *testing.T) {
w.Write([]byte("PNG image"))
}, optionReturnsPNG)

openAPIResponse := s.OpenApiSpec.Paths.Find("/image").Get.Responses.Value("200")
openAPIResponse := s.OpenAPIzer.OpenAPIDescription().Paths.Find("/image").Get.Responses.Value("200")
require.Nil(t, openAPIResponse.Value.Content.Get("application/json"))
require.NotNil(t, openAPIResponse.Value.Content.Get("image/png"))
require.Equal(t, "Generated image", *openAPIResponse.Value.Description)
Expand Down
8 changes: 4 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ func OptionAddError(code int, description string, errorType ...any) func(*BaseRo
}

if len(errorType) > 0 {
responseSchema = SchemaTagFromType(r.mainRouter, errorType[0])
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, errorType[0])
} else {
responseSchema = SchemaTagFromType(r.mainRouter, HTTPError{})
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, HTTPError{})
}
content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"})

Expand Down Expand Up @@ -374,14 +374,14 @@ func OptionDefaultStatusCode(defaultStatusCode int) func(*BaseRoute) {
// })
func OptionSecurity(securityRequirements ...openapi3.SecurityRequirement) func(*BaseRoute) {
return func(r *BaseRoute) {
if r.mainRouter.OpenApiSpec.Components == nil {
if r.mainRouter.OpenAPIzer.OpenAPIDescription().Components == nil {
panic("zero security schemes have been registered with the server")
}

// Validate the security scheme exists in components
for _, req := range securityRequirements {
for schemeName := range req {
if _, exists := r.mainRouter.OpenApiSpec.Components.SecuritySchemes[schemeName]; !exists {
if _, exists := r.mainRouter.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes[schemeName]; !exists {
panic(fmt.Sprintf("security scheme '%s' not defined in components", schemeName))
}
}
Expand Down
16 changes: 8 additions & 8 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ func TestPath(t *testing.T) {

fuego.Get(s, "/test/{id}", helloWorld)

require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
})

t.Run("Declare explicitly an existing path parameter for the route", func(t *testing.T) {
Expand All @@ -310,9 +310,9 @@ func TestPath(t *testing.T) {
fuego.OptionPath("id", "some id", param.Example("123", "123"), param.Nullable()),
)

require.Equal(t, "id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "some id", s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, true, s.OpenApiSpec.Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "some id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, true, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
})

t.Run("Declare explicitly a non-existing path parameter for the route panics", func(t *testing.T) {
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestRequestContentType(t *testing.T) {
require.NotNil(t, content.Get("application/json"))
require.Nil(t, content.Get("application/xml"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})

Expand All @@ -369,7 +369,7 @@ func TestRequestContentType(t *testing.T) {
require.Nil(t, content.Get("application/xml"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})

Expand All @@ -385,7 +385,7 @@ func TestRequestContentType(t *testing.T) {
require.Nil(t, content.Get("application/xml"))
require.NotNil(t, content.Get("my/content-type"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
_, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})
}
Expand Down
Loading

0 comments on commit d4f1294

Please sign in to comment.