Skip to content

Commit

Permalink
fix: differentiate between array typres and standard structs
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhitt committed Jun 26, 2024
1 parent 11ee6f7 commit 1f35838
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 33 deletions.
94 changes: 61 additions & 33 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,26 +163,15 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
// Request body
bodyTag := tagFromType(*new(B))
if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag != "unknown-interface" && bodyTag != "string" {

bodySchema, ok := s.OpenApiSpec.Components.Schemas[bodyTag]
if !ok {
var err error
bodySchema, err = generator.NewSchemaRefForValue(new(B), s.OpenApiSpec.Components.Schemas)
if err != nil {
return operation, err
}
s.OpenApiSpec.Components.Schemas[bodyTag] = bodySchema
bodySchema, err := schemaRefFromType[B](s, *new(B))
if err != nil {
return operation, err
}

content := openapi3.NewContentWithSchemaRef(bodySchema, []string{"application/json"})
requestBody := openapi3.NewRequestBody().
WithRequired(true).
WithDescription("Request body for " + reflect.TypeOf(*new(B)).String())

if bodySchema != nil {
content := openapi3.NewContentWithSchema(bodySchema.Value, []string{"application/json"})
content["application/json"].Schema.Ref = "#/components/schemas/" + bodyTag
requestBody.WithContent(content)
}
WithDescription("Request body for " + reflect.TypeOf(*new(B)).String()).
WithContent(content)

s.OpenApiSpec.Components.RequestBodies[bodyTag] = &openapi3.RequestBodyRef{
Value: requestBody,
Expand All @@ -195,24 +184,16 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
}
}

tag := tagFromType(*new(T))
// Response body
responseSchema, ok := s.OpenApiSpec.Components.Schemas[tag]
if !ok {
var err error
responseSchema, err = generator.NewSchemaRefForValue(new(T), s.OpenApiSpec.Components.Schemas)
if err != nil {
return operation, err
}
s.OpenApiSpec.Components.Schemas[tag] = responseSchema
responseSchema, err := schemaRefFromType[T](s, *new(T))
if err != nil {
return operation, err
}

response := openapi3.NewResponse().WithDescription("OK")
if responseSchema != nil {
content := openapi3.NewContentWithSchema(responseSchema.Value, []string{"application/json"})
content["application/json"].Schema.Ref = "#/components/schemas/" + tag
response.WithContent(content)
}
content := openapi3.NewContentWithSchemaRef(responseSchema, []string{"application/json"})
response := openapi3.NewResponse().
WithDescription("OK").
WithContent(content)

operation.AddResponse(200, response)

// Path parameters
Expand All @@ -227,6 +208,53 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
return operation, nil
}

func schemaRefFromType[V any](s *Server, v any) (*openapi3.SchemaRef, error) {
if v == nil {
return &openapi3.SchemaRef{
Ref: "#/components/schemas/unknown-interface",
}, nil
}

schemaRef, err := schemaDive[V](s, reflect.TypeOf(v), openapi3.SchemaRef{}, 4)
return &schemaRef, err
}

func schemaDive[V any](s *Server, t reflect.Type, schemaRef openapi3.SchemaRef, maxDepth int) (openapi3.SchemaRef, error) {
if maxDepth == 0 {
return openapi3.SchemaRef{
Ref: "#/components/schemas/default",
}, nil
}

switch t.Kind() {
case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer:
return schemaDive[V](s, t.Elem(), schemaRef, maxDepth-1)
case reflect.Slice, reflect.Array:
item, err := schemaDive[V](s, t.Elem(), schemaRef, maxDepth-1)
if err != nil {
return schemaRef, err
}
schemaRef.Value = &openapi3.Schema{
Type: "array",
Items: &item,
}
return schemaRef, nil
default:
componentRef, ok := s.OpenApiSpec.Components.Schemas[t.Name()]
if !ok {
var err error
componentRef, err = generator.NewSchemaRefForValue(new(V), s.OpenApiSpec.Components.Schemas)
if err != nil {
return openapi3.SchemaRef{}, err
}
s.OpenApiSpec.Components.Schemas[t.Name()] = componentRef
}
schemaRef.Value = componentRef.Value
schemaRef.Ref = "#/components/schemas/" + t.Name()
return schemaRef, nil
}
}

func tagFromType(v any) string {
if v == nil {
return "unknown-interface"
Expand Down
2 changes: 2 additions & 0 deletions openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func TestServer_generateOpenAPI(t *testing.T) {
require.NotNil(t, document.Paths.Find("/"))
require.Nil(t, document.Paths.Find("/unknown"))
require.NotNil(t, document.Paths.Find("/post"))
require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Type, "array")
require.Equal(t, document.Paths.Find("/post").Post.Responses.Value("200").Value.Content["application/json"].Schema.Value.Items.Ref, "#/components/schemas/MyStruct")
require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200"))
require.NotNil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"])
require.Nil(t, document.Paths.Find("/post/{id}").Get.Responses.Value("200").Value.Content["application/json"].Schema.Value.Properties["unknown"])
Expand Down

0 comments on commit 1f35838

Please sign in to comment.