Skip to content

Commit 5c6c571

Browse files
committed
fix: relect.New(t) instead of using generics with dive so we only reflect the struct and not it's array/ptrs
1 parent e2c7d3f commit 5c6c571

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

openapi.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
161161
}
162162

163163
// Request body
164-
bodyTag := schemaTagFromType[B](s, *new(B))
164+
bodyTag := schemaTagFromType(s, *new(B))
165165
if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && bodyTag.name != "unknown-interface" && bodyTag.name != "string" {
166166
content := openapi3.NewContentWithSchemaRef(&bodyTag.SchemaRef, []string{"application/json"})
167167
requestBody := openapi3.NewRequestBody().
@@ -180,7 +180,7 @@ func RegisterOpenAPIOperation[T, B any](s *Server, method, path string) (*openap
180180
}
181181
}
182182

183-
responseSchema := schemaTagFromType[T](s, *new(T))
183+
responseSchema := schemaTagFromType(s, *new(T))
184184
content := openapi3.NewContentWithSchemaRef(&responseSchema.SchemaRef, []string{"application/json"})
185185
response := openapi3.NewResponse().
186186
WithDescription("OK").
@@ -205,7 +205,7 @@ type schemaTag struct {
205205
name string
206206
}
207207

208-
func schemaTagFromType[V any](s *Server, v any) schemaTag {
208+
func schemaTagFromType(s *Server, v any) schemaTag {
209209
if v == nil {
210210
// ensure we add unknown-interface to our schemas
211211
s.getOrCreateSchema("unknown-interface", struct{}{})
@@ -217,7 +217,7 @@ func schemaTagFromType[V any](s *Server, v any) schemaTag {
217217
}
218218
}
219219

220-
return dive[V](s, reflect.TypeOf(v), schemaTag{}, 5)
220+
return dive(s, reflect.TypeOf(v), schemaTag{}, 5)
221221
}
222222

223223
// dive returns a schemaTag which includes the generated openapi3.SchemaRef and
@@ -227,7 +227,7 @@ func schemaTagFromType[V any](s *Server, v any) schemaTag {
227227
// If the type is a slice or array type it will dive into the type as well as
228228
// build and openapi3.Schema where Type is array and Ref is set to the proper
229229
// components Schema
230-
func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag {
230+
func dive(s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaTag {
231231
if maxDepth == 0 {
232232
return schemaTag{
233233
name: "default",
@@ -239,10 +239,10 @@ func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaT
239239

240240
switch t.Kind() {
241241
case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Func, reflect.UnsafePointer:
242-
return dive[V](s, t.Elem(), tag, maxDepth-1)
242+
return dive(s, t.Elem(), tag, maxDepth-1)
243243

244244
case reflect.Slice, reflect.Array:
245-
item := dive[V](s, t.Elem(), tag, maxDepth-1)
245+
item := dive(s, t.Elem(), tag, maxDepth-1)
246246
tag.name = item.name
247247
tag.Value = &openapi3.Schema{
248248
Type: "array",
@@ -253,7 +253,7 @@ func dive[V any](s *Server, t reflect.Type, tag schemaTag, maxDepth int) schemaT
253253
default:
254254
tag.name = t.Name()
255255
tag.Ref = "#/components/schemas/" + tag.name
256-
tag.Value = s.getOrCreateSchema(tag.name, new(V))
256+
tag.Value = s.getOrCreateSchema(tag.name, reflect.New(t).Interface())
257257
return tag
258258
}
259259
}

openapi_test.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ type testCaseForTagType[V any] struct {
3232
expectedTagValue string
3333
}
3434

35-
func runTestCase[V any](tc testCaseForTagType[V]) func(t *testing.T) {
36-
return func(t *testing.T) {
37-
tag := schemaTagFromType[V](tc.s, tc.inputType)
38-
assert.Equal(t, tc.expectedTagValue, tag.name, tc.description)
39-
}
40-
}
41-
4235
func Test_tagFromType(t *testing.T) {
4336
s := NewServer()
4437
type DeeplyNested *[]MyStruct
@@ -139,7 +132,10 @@ func Test_tagFromType(t *testing.T) {
139132
}
140133

141134
for _, tc := range tcs {
142-
t.Run(tc.name, runTestCase(tc))
135+
t.Run(tc.name, func(t *testing.T) {
136+
tag := schemaTagFromType(tc.s, tc.inputType)
137+
assert.Equal(t, tc.expectedTagValue, tag.name, tc.description)
138+
})
143139
}
144140
}
145141

0 commit comments

Comments
 (0)