Skip to content

Commit

Permalink
Removed interface and use a direct struct
Browse files Browse the repository at this point in the history
  • Loading branch information
EwenQuim committed Dec 12, 2024
1 parent 0c474a1 commit c833e87
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 95 deletions.
2 changes: 1 addition & 1 deletion examples/full-app-gourmet/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (rs Resources) Setup(
// Create server with some options
app := fuego.NewServer(options...)

app.OpenApiSpec.Info.Title = "Gourmet API"
app.OpenAPI.Description().Info.Title = "Gourmet API"

rs.API.Security = app.Security

Expand Down
2 changes: 1 addition & 1 deletion examples/petstore/lib/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestPetstoreOpenAPIGeneration(t *testing.T) {
)

server.OutputOpenAPISpec()
err := server.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
err := server.OpenAPI.Description().Validate(context.Background())
require.NoError(t, err)

generatedSpec, err := os.ReadFile("testdata/doc/openapi.json")
Expand Down
2 changes: 1 addition & 1 deletion mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func Register[T, B any](s *Server, route Route[T, B], controller http.Handler, o
route.Path = s.basePath + route.Path

var err error
route.Operation, err = RegisterOpenAPIOperation(s.OpenAPIzer, route)
route.Operation, err = RegisterOpenAPIOperation(s.OpenAPI, route)
if err != nil {
slog.Warn("error documenting openapi operation", "error", err)
}
Expand Down
16 changes: 8 additions & 8 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(s, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })

require.Equal(t, s.DisableOpenapi, true)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test") == nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/test") == nil)
})

t.Run("hide group", func(t *testing.T) {
Expand All @@ -500,8 +500,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(g, "/test", func(ctx *ContextNoBody) (string, error) { return "", nil })

require.Equal(t, g.DisableOpenapi, true)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/not-hidden") != nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/group/test") == nil)
})

t.Run("hide group but not other group", func(t *testing.T) {
Expand All @@ -514,8 +514,8 @@ func TestHideOpenapiRoutes(t *testing.T) {

require.Equal(t, true, g.DisableOpenapi)
require.Equal(t, false, g2.DisableOpenapi)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group2/test") != nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/group2/test") != nil)
})

t.Run("hide group but show sub group", func(t *testing.T) {
Expand All @@ -527,8 +527,8 @@ func TestHideOpenapiRoutes(t *testing.T) {
Get(g2, "/test", func(ctx *ContextNoBody) (string, error) { return "test", nil })

require.Equal(t, true, g.DisableOpenapi)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/group/sub/test") != nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/group/test") == nil)
require.True(t, s.OpenAPI.Description().Paths.Find("/group/sub/test") != nil)
})
}

Expand Down
60 changes: 24 additions & 36 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,30 @@ import (
"github.com/getkin/kin-openapi/openapi3gen"
)

type OpenAPIzer interface {
OpenAPIDescription() *openapi3.T
Generator() *openapi3gen.Generator
GlobalOpenAPIResponses() *[]openAPIError
}

func NewSpec() *OpenAPI {
func NewOpenAPI() *OpenAPI {
desc := NewOpenApiSpec()
return &OpenAPI{
description: &desc,
generator: openapi3gen.NewGenerator(),
globalOpenAPIResponses: &[]openAPIError{},
globalOpenAPIResponses: []openAPIError{},
}
}

// Holds the OpenAPI OpenAPIDescription (OAD) and OpenAPI capabilities.
type OpenAPI struct {
description *openapi3.T
generator *openapi3gen.Generator
globalOpenAPIResponses *[]openAPIError
globalOpenAPIResponses []openAPIError
}

func (d *OpenAPI) OpenAPIDescription() *openapi3.T {
func (d *OpenAPI) Description() *openapi3.T {
return d.description
}

func (d *OpenAPI) Generator() *openapi3gen.Generator {
return d.generator
}

func (d *OpenAPI) GlobalOpenAPIResponses() *[]openAPIError {
return d.globalOpenAPIResponses
}

var _ OpenAPIzer = &OpenAPI{}

func NewOpenApiSpec() openapi3.T {
info := &openapi3.Info{
Title: "OpenAPI",
Expand Down Expand Up @@ -90,11 +78,11 @@ func (s *Server) Show() *Server {
}

func declareAllTagsFromOperations(s *Server) {
for _, pathItem := range s.OpenAPIzer.OpenAPIDescription().Paths.Map() {
for _, pathItem := range s.OpenAPI.Description().Paths.Map() {
for _, op := range pathItem.Operations() {
for _, tag := range op.Tags {
if s.OpenAPIzer.OpenAPIDescription().Tags.Get(tag) == nil {
s.OpenAPIzer.OpenAPIDescription().Tags = append(s.OpenAPIzer.OpenAPIDescription().Tags, &openapi3.Tag{
if s.OpenAPI.Description().Tags.Get(tag) == nil {
s.OpenAPI.Description().Tags = append(s.OpenAPI.Description().Tags, &openapi3.Tag{
Name: tag,
})
}
Expand All @@ -110,7 +98,7 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
declareAllTagsFromOperations(s)

// Validate
err := s.OpenAPIzer.OpenAPIDescription().Validate(context.Background())
err := s.OpenAPI.Description().Validate(context.Background())
if err != nil {
slog.Error("Error validating spec", "error", err)
}
Expand All @@ -132,14 +120,14 @@ func (s *Server) OutputOpenAPISpec() openapi3.T {
}
}

return *s.OpenAPIzer.OpenAPIDescription()
return *s.OpenAPI.Description()
}

func (s *Server) marshalSpec() ([]byte, error) {
if s.OpenAPIConfig.PrettyFormatJson {
return json.MarshalIndent(s.OpenAPIzer.OpenAPIDescription(), "", " ")
return json.MarshalIndent(s.OpenAPI.Description(), "", " ")
}
return json.Marshal(s.OpenAPIzer.OpenAPIDescription())
return json.Marshal(s.OpenAPI.Description())
}

func (s *Server) saveOpenAPIToFile(jsonSpecLocalPath string, jsonSpec []byte) error {
Expand Down Expand Up @@ -201,7 +189,7 @@ func validateSwaggerUrl(swaggerUrl string) bool {
}

// RegisterOpenAPIOperation registers an OpenAPI operation.
func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*openapi3.Operation, error) {
func RegisterOpenAPIOperation[T, B any](s *OpenAPI, route Route[T, B]) (*openapi3.Operation, error) {
if route.Operation == nil {
route.Operation = openapi3.NewOperation()
}
Expand All @@ -221,7 +209,7 @@ func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*opena
}

// Response - globals
for _, openAPIGlobalResponse := range *s.GlobalOpenAPIResponses() {
for _, openAPIGlobalResponse := range s.globalOpenAPIResponses {
addResponseIfNotSet(s, route.Operation, openAPIGlobalResponse.Code, openAPIGlobalResponse.Description, openAPIGlobalResponse.ErrorType)
}

Expand Down Expand Up @@ -265,7 +253,7 @@ func RegisterOpenAPIOperation[T, B any](s OpenAPIzer, route Route[T, B]) (*opena
}
}

s.OpenAPIDescription().AddOperation(route.Path, route.Method, route.Operation)
s.Description().AddOperation(route.Path, route.Method, route.Operation)

return route.Operation, nil
}
Expand All @@ -284,10 +272,10 @@ type SchemaTag struct {
Name string
}

func SchemaTagFromType(s OpenAPIzer, v any) SchemaTag {
func SchemaTagFromType(s *OpenAPI, v any) SchemaTag {
if v == nil {
// ensure we add unknown-interface to our schemas
schema := getOrCreateSchema(s, "unknown-interface", struct{}{})
schema := s.getOrCreateSchema("unknown-interface", struct{}{})
return SchemaTag{
Name: "unknown-interface",
SchemaRef: openapi3.SchemaRef{
Expand All @@ -307,7 +295,7 @@ func SchemaTagFromType(s OpenAPIzer, v any) SchemaTag {
// If the type is a slice or array type it will dive into the type as well as
// build and openapi3.Schema where Type is array and Ref is set to the proper
// components Schema
func dive(s OpenAPIzer, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
func dive(s *OpenAPI, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
if maxDepth == 0 {
return SchemaTag{
Name: "default",
Expand All @@ -334,26 +322,26 @@ func dive(s OpenAPIzer, t reflect.Type, tag SchemaTag, maxDepth int) SchemaTag {
return dive(s, t.Field(0).Type, tag, maxDepth-1)
}
tag.Ref = "#/components/schemas/" + tag.Name
tag.Value = getOrCreateSchema(s, tag.Name, reflect.New(t).Interface())
tag.Value = s.getOrCreateSchema(tag.Name, reflect.New(t).Interface())

return tag
}
}

// getOrCreateSchema is used to get a schema from the OpenAPI spec.
// If the schema does not exist, it will create a new schema and add it to the OpenAPI spec.
func getOrCreateSchema(s OpenAPIzer, key string, v any) *openapi3.Schema {
schemaRef, ok := s.OpenAPIDescription().Components.Schemas[key]
func (s *OpenAPI) getOrCreateSchema(key string, v any) *openapi3.Schema {
schemaRef, ok := s.Description().Components.Schemas[key]
if !ok {
schemaRef = createSchema(s, key, v)
schemaRef = s.createSchema(key, v)
}
return schemaRef.Value
}

// createSchema is used to create a new schema and add it to the OpenAPI spec.
// Relies on the openapi3gen package to generate the schema, and adds custom struct tags.
func createSchema(s OpenAPIzer, key string, v any) *openapi3.SchemaRef {
schemaRef, err := s.Generator().NewSchemaRefForValue(v, s.OpenAPIDescription().Components.Schemas)
func (s *OpenAPI) createSchema(key string, v any) *openapi3.SchemaRef {
schemaRef, err := s.Generator().NewSchemaRefForValue(v, s.Description().Components.Schemas)
if err != nil {
slog.Error("Error generating schema", "key", key, "error", err)
}
Expand All @@ -366,7 +354,7 @@ func createSchema(s OpenAPIzer, key string, v any) *openapi3.SchemaRef {

parseStructTags(reflect.TypeOf(v), schemaRef)

s.OpenAPIDescription().Components.Schemas[key] = schemaRef
s.Description().Components.Schemas[key] = schemaRef

return schemaRef
}
Expand Down
2 changes: 1 addition & 1 deletion openapi_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r Route[ResponseBody, RequestBody]) Param(paramType ParamType, name, descr
}

// Registers a response for the route, only if error for this code is not already set.
func addResponseIfNotSet(s OpenAPIzer, operation *openapi3.Operation, code int, description string, errorType ...any) {
func addResponseIfNotSet(s *OpenAPI, operation *openapi3.Operation, code int, description string, errorType ...any) {
var responseSchema SchemaTag

if len(errorType) > 0 {
Expand Down
4 changes: 2 additions & 2 deletions openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func Test_tagFromType(t *testing.T) {

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
tag := SchemaTagFromType(s, tc.inputType)
tag := SchemaTagFromType(s.OpenAPI, tc.inputType)
require.Equal(t, tc.expectedTagValue, tag.Name, tc.description)
if tc.expectedTagValueType != nil {
require.NotNil(t, tag.Value)
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestDeclareCustom200Response(t *testing.T) {
w.Write([]byte("PNG image"))
}, optionReturnsPNG)

openAPIResponse := s.OpenAPIzer.OpenAPIDescription().Paths.Find("/image").Get.Responses.Value("200")
openAPIResponse := s.OpenAPI.Description().Paths.Find("/image").Get.Responses.Value("200")
require.Nil(t, openAPIResponse.Value.Content.Get("application/json"))
require.NotNil(t, openAPIResponse.Value.Content.Get("image/png"))
require.Equal(t, "Generated image", *openAPIResponse.Value.Description)
Expand Down
10 changes: 5 additions & 5 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func buildParam(name string, options ...func(*OpenAPIParam)) (OpenAPIParam, *ope
option(&param)
}

// Applies OpenAPIParam to openapi3.Parameter
// Applies *OpenAPIParam to openapi3.Parameter
// Why not use openapi3.NewHeaderParameter(name) directly?
// Because we might change the openapi3 library in the future,
// and we want to keep the flexibility to change the implementation without changing the API.
Expand Down Expand Up @@ -302,9 +302,9 @@ func OptionAddError(code int, description string, errorType ...any) func(*BaseRo
}

if len(errorType) > 0 {
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, errorType[0])
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPI, errorType[0])
} else {
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPIzer, HTTPError{})
responseSchema = SchemaTagFromType(r.mainRouter.OpenAPI, HTTPError{})
}
content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"})

Expand Down Expand Up @@ -374,14 +374,14 @@ func OptionDefaultStatusCode(defaultStatusCode int) func(*BaseRoute) {
// })
func OptionSecurity(securityRequirements ...openapi3.SecurityRequirement) func(*BaseRoute) {
return func(r *BaseRoute) {
if r.mainRouter.OpenAPIzer.OpenAPIDescription().Components == nil {
if r.mainRouter.OpenAPI.Description().Components == nil {
panic("zero security schemes have been registered with the server")
}

// Validate the security scheme exists in components
for _, req := range securityRequirements {
for schemeName := range req {
if _, exists := r.mainRouter.OpenAPIzer.OpenAPIDescription().Components.SecuritySchemes[schemeName]; !exists {
if _, exists := r.mainRouter.OpenAPI.Description().Components.SecuritySchemes[schemeName]; !exists {
panic(fmt.Sprintf("security scheme '%s' not defined in components", schemeName))
}
}
Expand Down
16 changes: 8 additions & 8 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ func TestPath(t *testing.T) {

fuego.Get(s, "/test/{id}", helloWorld)

require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, "id", s.OpenAPI.Description().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "", s.OpenAPI.Description().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
})

t.Run("Declare explicitly an existing path parameter for the route", func(t *testing.T) {
Expand All @@ -310,9 +310,9 @@ func TestPath(t *testing.T) {
fuego.OptionPath("id", "some id", param.Example("123", "123"), param.Nullable()),
)

require.Equal(t, "id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "some id", s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, true, s.OpenAPIzer.OpenAPIDescription().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
require.Equal(t, "id", s.OpenAPI.Description().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Name)
require.Equal(t, "some id", s.OpenAPI.Description().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Description)
require.Equal(t, true, s.OpenAPI.Description().Paths.Find("/test/{id}").Get.Parameters.GetByInAndName("path", "id").Required, "path parameter is forced to be required")
})

t.Run("Declare explicitly a non-existing path parameter for the route panics", func(t *testing.T) {
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestRequestContentType(t *testing.T) {
require.NotNil(t, content.Get("application/json"))
require.Nil(t, content.Get("application/xml"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPI.Description().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})

Expand All @@ -369,7 +369,7 @@ func TestRequestContentType(t *testing.T) {
require.Nil(t, content.Get("application/xml"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref)
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPI.Description().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})

Expand All @@ -385,7 +385,7 @@ func TestRequestContentType(t *testing.T) {
require.Nil(t, content.Get("application/xml"))
require.NotNil(t, content.Get("my/content-type"))
require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref)
_, ok := s.OpenAPIzer.OpenAPIDescription().Components.RequestBodies["ReqBody"]
_, ok := s.OpenAPI.Description().Components.RequestBodies["ReqBody"]
require.False(t, ok)
})
}
Expand Down
Loading

0 comments on commit c833e87

Please sign in to comment.