diff --git a/openapi.go b/openapi.go index 93222c03..eb37e4fb 100644 --- a/openapi.go +++ b/openapi.go @@ -161,58 +161,31 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap } // Request body - bodyTag := tagFromType(*new(B)) - if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag != "unknown-interface" && bodyTag != "string" { - - bodySchema, ok := s.OpenApiSpec.Components.Schemas[bodyTag] - if !ok { - var err error - bodySchema, err = generator.NewSchemaRefForValue(new(B), s.OpenApiSpec.Components.Schemas) - if err != nil { - return operation, err - } - s.OpenApiSpec.Components.Schemas[bodyTag] = bodySchema - } - + bodyTag := schemaTagFromType(s, *new(B)) + if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag.name != "unknown-interface" && bodyTag.name != "string" { + content := openapi3.NewContentWithSchemaRef(&bodyTag.SchemaRef, []string{"application/json"}) requestBody := openapi3.NewRequestBody(). WithRequired(true). - WithDescription("Request body for " + reflect.TypeOf(*new(B)).String()) + WithDescription("Request body for " + reflect.TypeOf(*new(B)).String()). + WithContent(content) - if bodySchema != nil { - content := openapi3.NewContentWithSchema(bodySchema.Value, []string{"application/json"}) - content["application/json"].Schema.Ref = "#/components/schemas/" + bodyTag - requestBody.WithContent(content) - } - - s.OpenApiSpec.Components.RequestBodies[bodyTag] = &openapi3.RequestBodyRef{ + s.OpenApiSpec.Components.RequestBodies[bodyTag.name] = &openapi3.RequestBodyRef{ Value: requestBody, } // add request body to operation operation.RequestBody = &openapi3.RequestBodyRef{ - Ref: "#/components/requestBodies/" + bodyTag, + Ref: "#/components/requestBodies/" + bodyTag.name, Value: requestBody, } } - tag := tagFromType(*new(T)) - // Response body - responseSchema, ok := s.OpenApiSpec.Components.Schemas[tag] - if !ok { - var err error - responseSchema, err = generator.NewSchemaRefForValue(new(T), s.OpenApiSpec.Components.Schemas) - if err != nil { - return operation, err - } - s.OpenApiSpec.Components.Schemas[tag] = responseSchema - } + responseSchema := schemaTagFromType(s, *new(T)) + content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"}) + response := openapi3.NewResponse(). + WithDescription("OK"). + WithContent(content) - response := openapi3.NewResponse().WithDescription("OK") - if responseSchema != nil { - content := openapi3.NewContentWithSchema(responseSchema.Value, []string{"application/json"}) - content["application/json"].Schema.Ref = "#/components/schemas/" + tag - response.WithContent(content) - } operation.AddResponse(200, response) // Path parameters @@ -227,25 +200,70 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap return operation, nil } -func tagFromType(v any) string { +// schemaTag is a struct that holds the name of the struct and the associated openapi3.SchemaRef +type schemaTag struct { + openapi3.SchemaRef + name string +} + +func schemaTagFromType(s *Server, v any) schemaTag { if v == nil { - return "unknown-interface" + // ensure we add unknown-interface to our schemas + s.getOrCreateSchema("unknown-interface", struct{}{}) + return schemaTag{ + name: "unknown-interface", + SchemaRef: openapi3.SchemaRef{ + Ref: "#/components/schemas/unknown-interface", + }, + } } - return dive(reflect.TypeOf(v), 4) + return dive(s, reflect.TypeOf(v), schemaTag{}, 5) } -// dive returns the name of the type of the given reflect.Type. -// If the type is a pointer, slice, array, map, channel, function, or unsafe pointer, +// dive returns a schemaTag which includes the generated openapi3.SchemaRef and +// the name of the struct being passed in. +// If the type is a pointer, map, channel, function, or unsafe pointer, // it will dive into the type and return the name of the type it points to. -func dive(t reflect.Type, maxDepth int) string { +// If the type is a slice or array type it will dive into the type as well as +// build and openapi3.Schema where Type is array and Ref is set to the proper +// components Schema +func dive(s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag { + if maxDepth == 0 { + return schemaTag{ + name: "default", + SchemaRef: openapi3.SchemaRef{ + Ref: "#/components/schemas/default", + }, + } + } + switch t.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Array, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer: - if maxDepth == 0 { - return "default" + case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer: + return dive(s, t.Elem(), tag, maxDepth-1) + + case reflect.Slice, reflect.Array: + item := dive(s, t.Elem(), tag, maxDepth-1) + tag.name = item.name + tag.Value = &openapi3.Schema{ + Type: "array", + Items: &item.SchemaRef, } - return dive(t.Elem(), maxDepth-1) + return tag + default: - return t.Name() + tag.name = t.Name() + tag.Ref = "#/components/schemas/" + tag.name + tag.Value = s.getOrCreateSchema(tag.name, reflect.New(t).Interface()) + return tag + } +} + +func (s *Server) getOrCreateSchema(key string, v any) *openapi3.Schema { + schemaRef, ok := s.OpenApiSpec.Components.Schemas[key] + if !ok { + schemaRef, _ = generator.NewSchemaRefForValue(v, s.OpenApiSpec.Components.Schemas) + s.OpenApiSpec.Components.Schemas[key] = schemaRef } + return schemaRef.Value } diff --git a/openapi_test.go b/openapi_test.go index a4f10259..442dc6e5 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -8,6 +8,7 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -22,33 +23,120 @@ type MyOutputStruct struct { Quantity int `json:"quantity"` } -func TestTagFromType(t *testing.T) { - require.Equal(t, "unknown-interface", tagFromType(*new(any)), "behind any interface") - require.Equal(t, "MyStruct", tagFromType(MyStruct{})) +type testCaseForTagType[V any] struct { + name string + description string + inputType V + s *Server - t.Run("behind pointers or pointers-like", func(t *testing.T) { - require.Equal(t, "MyStruct", tagFromType(&MyStruct{})) - require.Equal(t, "MyStruct", tagFromType([]MyStruct{})) - require.Equal(t, "MyStruct", tagFromType(&[]MyStruct{})) - type DeeplyNested *[]MyStruct - require.Equal(t, "MyStruct", tagFromType(new(DeeplyNested)), "behind 4 pointers") - }) - - t.Run("safety against recursion", func(t *testing.T) { - type DeeplyNested *[]MyStruct - type MoreDeeplyNested *[]DeeplyNested - require.Equal(t, "MyStruct", tagFromType(*new(MoreDeeplyNested)), "behind 5 pointers") + expectedTagValue string +} - require.Equal(t, "default", tagFromType(new(MoreDeeplyNested)), "behind 6 pointers") - require.Equal(t, "default", tagFromType([]*MoreDeeplyNested{}), "behind 7 pointers") - }) +func Test_tagFromType(t *testing.T) { + s := NewServer() + type DeeplyNested *[]MyStruct + type MoreDeeplyNested *[]DeeplyNested + + tcs := []testCaseForTagType[any]{ + { + name: "unknown_interface", + description: "behind any interface", + inputType: *new(any), + expectedTagValue: "unknown-interface", + s: s, + }, + { + name: "simple_struct", + description: "basic struct", + inputType: MyStruct{}, + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "is_pointer", + description: "", + inputType: &MyStruct{}, + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "is_array", + description: "", + inputType: []MyStruct{}, + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "is_reference_to_array", + description: "", + inputType: &[]MyStruct{}, + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "is_deeply_nested", + description: "behind 4 pointers", + inputType: new(DeeplyNested), + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "5_pointers", + description: "behind 5 pointers", + inputType: *new(MoreDeeplyNested), + expectedTagValue: "MyStruct", + s: s, + }, + { + name: "6_pointers", + description: "behind 6 pointers", + inputType: new(MoreDeeplyNested), + expectedTagValue: "default", + s: s, + }, + { + name: "7_pointers", + description: "behind 7 pointers", + inputType: []*MoreDeeplyNested{}, + expectedTagValue: "default", + s: s, + }, + { + name: "detecting_string", + description: "", + inputType: "string", + expectedTagValue: "string", + s: s, + }, + { + name: "new_string", + description: "", + inputType: new(string), + expectedTagValue: "string", + s: s, + }, + { + name: "string_array", + description: "", + inputType: []string{}, + expectedTagValue: "string", + s: s, + }, + { + name: "pointer_string_array", + description: "", + inputType: &[]string{}, + expectedTagValue: "string", + s: s, + }, + } - t.Run("detecting string", func(t *testing.T) { - require.Equal(t, "string", tagFromType("string")) - require.Equal(t, "string", tagFromType(new(string))) - require.Equal(t, "string", tagFromType([]string{})) - require.Equal(t, "string", tagFromType(&[]string{})) - }) + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + tag := schemaTagFromType(tc.s, tc.inputType) + assert.Equal(t, tc.expectedTagValue, tag.name, tc.description) + }) + } } func TestServer_generateOpenAPI(t *testing.T) { @@ -67,6 +155,8 @@ func TestServer_generateOpenAPI(t *testing.T) { require.NotNil(t, document.Paths.Find("/")) require.Nil(t, document.Paths.Find("/unknown")) require.NotNil(t, document.Paths.Find("/post")) + require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Type, "array") + require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Items.Ref, "#/components/schemas/MyStruct") require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200")) require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"]) require.Nil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"].Schema.Value.Properties["unknown"])