Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: differentiate between array types and standard structs #134

Merged
merged 9 commits into from
Jul 16, 2024
118 changes: 68 additions & 50 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,58 +161,31 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
}

// Request body
bodyTag := tagFromType(*new(B))
if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag != "unknown-interface" && bodyTag != "string" {

bodySchema, ok := s.OpenApiSpec.Components.Schemas[bodyTag]
if !ok {
var err error
bodySchema, err = generator.NewSchemaRefForValue(new(B), s.OpenApiSpec.Components.Schemas)
if err != nil {
return operation, err
}
s.OpenApiSpec.Components.Schemas[bodyTag] = bodySchema
}

bodyTag := schemaTagFromType(s, *new(B))
if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag.name != "unknown-interface" && bodyTag.name != "string" {
content := openapi3.NewContentWithSchemaRef(&bodyTag.SchemaRef, []string{"application/json"})
requestBody := openapi3.NewRequestBody().
WithRequired(true).
WithDescription("Request body for " + reflect.TypeOf(*new(B)).String())
WithDescription("Request body for " + reflect.TypeOf(*new(B)).String()).
WithContent(content)

if bodySchema != nil {
content := openapi3.NewContentWithSchema(bodySchema.Value, []string{"application/json"})
content["application/json"].Schema.Ref = "#/components/schemas/" + bodyTag
requestBody.WithContent(content)
}

s.OpenApiSpec.Components.RequestBodies[bodyTag] = &openapi3.RequestBodyRef{
s.OpenApiSpec.Components.RequestBodies[bodyTag.name] = &openapi3.RequestBodyRef{
Value: requestBody,
}

// add request body to operation
operation.RequestBody = &openapi3.RequestBodyRef{
Ref: "#/components/requestBodies/" + bodyTag,
Ref: "#/components/requestBodies/" + bodyTag.name,
Value: requestBody,
}
}

tag := tagFromType(*new(T))
// Response body
responseSchema, ok := s.OpenApiSpec.Components.Schemas[tag]
if !ok {
var err error
responseSchema, err = generator.NewSchemaRefForValue(new(T), s.OpenApiSpec.Components.Schemas)
if err != nil {
return operation, err
}
s.OpenApiSpec.Components.Schemas[tag] = responseSchema
}
responseSchema := schemaTagFromType(s, *new(T))
content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"})
response := openapi3.NewResponse().
WithDescription("OK").
WithContent(content)

response := openapi3.NewResponse().WithDescription("OK")
if responseSchema != nil {
content := openapi3.NewContentWithSchema(responseSchema.Value, []string{"application/json"})
content["application/json"].Schema.Ref = "#/components/schemas/" + tag
response.WithContent(content)
}
operation.AddResponse(200, response)

// Path parameters
Expand All @@ -227,25 +200,70 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
return operation, nil
}

func tagFromType(v any) string {
// schemaTag is a struct that holds the name of the struct and the associated openapi3.SchemaRef
type schemaTag struct {
openapi3.SchemaRef
name string
}

func schemaTagFromType(s *Server, v any) schemaTag {
if v == nil {
return "unknown-interface"
// ensure we add unknown-interface to our schemas
s.getOrCreateSchema("unknown-interface", struct{}{})
Copy link
Collaborator Author

@dylanhitt dylanhitt Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need test to ensure an unknown-interface component schema is being added to the spec via test. I did not pick this up until my own diff

return schemaTag{
name: "unknown-interface",
SchemaRef: openapi3.SchemaRef{
Ref: "#/components/schemas/unknown-interface",
},
}
}

return dive(reflect.TypeOf(v), 4)
return dive(s, reflect.TypeOf(v), schemaTag{}, 5)
}

// dive returns the name of the type of the given reflect.Type.
// If the type is a pointer, slice, array, map, channel, function, or unsafe pointer,
// dive returns a schemaTag which includes the generated openapi3.SchemaRef and
// the name of the struct being passed in.
// If the type is a pointer, map, channel, function, or unsafe pointer,
// it will dive into the type and return the name of the type it points to.
func dive(t reflect.Type, maxDepth int) string {
// If the type is a slice or array type it will dive into the type as well as
// build and openapi3.Schema where Type is array and Ref is set to the proper
// components Schema
func dive(s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag {
if maxDepth == 0 {
return schemaTag{
name: "default",
SchemaRef: openapi3.SchemaRef{
Ref: "#/components/schemas/default",
},
}
}

switch t.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Array, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer:
if maxDepth == 0 {
return "default"
case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer:
return dive(s, t.Elem(), tag, maxDepth-1)

case reflect.Slice, reflect.Array:
item := dive(s, t.Elem(), tag, maxDepth-1)
tag.name = item.name
tag.Value = &openapi3.Schema{
Type: "array",
Items: &item.SchemaRef,
}
return dive(t.Elem(), maxDepth-1)
return tag
Comment on lines +245 to +252
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect, you really dived into the code! 😉


default:
return t.Name()
tag.name = t.Name()
tag.Ref = "#/components/schemas/" + tag.name
tag.Value = s.getOrCreateSchema(tag.name, reflect.New(t).Interface())
return tag
}
}

func (s *Server) getOrCreateSchema(key string, v any) *openapi3.Schema {
schemaRef, ok := s.OpenApiSpec.Components.Schemas[key]
if !ok {
schemaRef, _ = generator.NewSchemaRefForValue(v, s.OpenApiSpec.Components.Schemas)
s.OpenApiSpec.Components.Schemas[key] = schemaRef
}
return schemaRef.Value
}
138 changes: 114 additions & 24 deletions openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -22,33 +23,120 @@ type MyOutputStruct struct {
Quantity int `json:"quantity"`
}

func TestTagFromType(t *testing.T) {
require.Equal(t, "unknown-interface", tagFromType(*new(any)), "behind any interface")
require.Equal(t, "MyStruct", tagFromType(MyStruct{}))
type testCaseForTagType[V any] struct {
name string
description string
inputType V
s *Server

t.Run("behind pointers or pointers-like", func(t *testing.T) {
require.Equal(t, "MyStruct", tagFromType(&MyStruct{}))
require.Equal(t, "MyStruct", tagFromType([]MyStruct{}))
require.Equal(t, "MyStruct", tagFromType(&[]MyStruct{}))
type DeeplyNested *[]MyStruct
require.Equal(t, "MyStruct", tagFromType(new(DeeplyNested)), "behind 4 pointers")
})

t.Run("safety against recursion", func(t *testing.T) {
type DeeplyNested *[]MyStruct
type MoreDeeplyNested *[]DeeplyNested
require.Equal(t, "MyStruct", tagFromType(*new(MoreDeeplyNested)), "behind 5 pointers")
expectedTagValue string
}

require.Equal(t, "default", tagFromType(new(MoreDeeplyNested)), "behind 6 pointers")
require.Equal(t, "default", tagFromType([]*MoreDeeplyNested{}), "behind 7 pointers")
})
func Test_tagFromType(t *testing.T) {
s := NewServer()
type DeeplyNested *[]MyStruct
type MoreDeeplyNested *[]DeeplyNested

tcs := []testCaseForTagType[any]{
{
name: "unknown_interface",
description: "behind any interface",
inputType: *new(any),
expectedTagValue: "unknown-interface",
s: s,
},
{
name: "simple_struct",
description: "basic struct",
inputType: MyStruct{},
expectedTagValue: "MyStruct",
s: s,
},
{
name: "is_pointer",
description: "",
inputType: &MyStruct{},
expectedTagValue: "MyStruct",
s: s,
},
{
name: "is_array",
description: "",
inputType: []MyStruct{},
expectedTagValue: "MyStruct",
s: s,
},
{
name: "is_reference_to_array",
description: "",
inputType: &[]MyStruct{},
expectedTagValue: "MyStruct",
s: s,
},
{
name: "is_deeply_nested",
description: "behind 4 pointers",
inputType: new(DeeplyNested),
expectedTagValue: "MyStruct",
s: s,
},
{
name: "5_pointers",
description: "behind 5 pointers",
inputType: *new(MoreDeeplyNested),
expectedTagValue: "MyStruct",
s: s,
},
{
name: "6_pointers",
description: "behind 6 pointers",
inputType: new(MoreDeeplyNested),
expectedTagValue: "default",
s: s,
},
{
name: "7_pointers",
description: "behind 7 pointers",
inputType: []*MoreDeeplyNested{},
expectedTagValue: "default",
s: s,
},
{
name: "detecting_string",
description: "",
inputType: "string",
expectedTagValue: "string",
s: s,
},
{
name: "new_string",
description: "",
inputType: new(string),
expectedTagValue: "string",
s: s,
},
{
name: "string_array",
description: "",
inputType: []string{},
expectedTagValue: "string",
s: s,
},
{
name: "pointer_string_array",
description: "",
inputType: &[]string{},
expectedTagValue: "string",
s: s,
},
}

t.Run("detecting string", func(t *testing.T) {
require.Equal(t, "string", tagFromType("string"))
require.Equal(t, "string", tagFromType(new(string)))
require.Equal(t, "string", tagFromType([]string{}))
require.Equal(t, "string", tagFromType(&[]string{}))
})
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
tag := schemaTagFromType(tc.s, tc.inputType)
assert.Equal(t, tc.expectedTagValue, tag.name, tc.description)
})
}
}

func TestServer_generateOpenAPI(t *testing.T) {
Expand All @@ -67,6 +155,8 @@ func TestServer_generateOpenAPI(t *testing.T) {
require.NotNil(t, document.Paths.Find("/"))
require.Nil(t, document.Paths.Find("/unknown"))
require.NotNil(t, document.Paths.Find("/post"))
require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Type, "array")
require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Items.Ref, "#/components/schemas/MyStruct")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to do more to validate the actual schema reference here

require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200"))
require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"])
require.Nil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"].Schema.Value.Properties["unknown"])
Expand Down