Skip to content

Commit 9d34a76

Browse files
authored
fix: Nested generic fields not fully working, if generic type is from… (#1305)
* fix: Nested generic fields not fully working, if generic type is from another package - change full name generation and support SelectorExpr - prepend package only, if no name does not contain package fixes #1304 * test: New tests added increase code coverage for generics
1 parent 732c087 commit 9d34a76

File tree

5 files changed

+204
-20
lines changed

5 files changed

+204
-20
lines changed

generics.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
167167
// splitStructName splits a generic struct name in his parts
168168
func splitStructName(fullGenericForm string) (string, []string) {
169169
// split only at the first '[' and remove the last ']'
170+
if fullGenericForm[len(fullGenericForm)-1] != ']' {
171+
return "", nil
172+
}
173+
170174
genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2)
171175
if len(genericParams) == 1 {
172176
return "", nil
@@ -224,12 +228,11 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin
224228
func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
225229
switch fieldType := field.(type) {
226230
case *ast.IndexListExpr:
227-
spec := &TypeSpecDef{
228-
File: file,
229-
TypeSpec: getGenericTypeSpec(fieldType.X),
230-
PkgPath: file.Name.Name,
231+
fullName, err := getGenericTypeName(file, fieldType.X)
232+
if err != nil {
233+
return "", err
231234
}
232-
fullName := spec.FullName() + "["
235+
fullName += "["
233236

234237
for _, index := range fieldType.Indices {
235238
var fieldName string
@@ -252,11 +255,6 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
252255

253256
return strings.TrimRight(fullName, ",") + "]", nil
254257
case *ast.IndexExpr:
255-
if file.Name == nil {
256-
return "", errors.New("file name is nil")
257-
}
258-
packageName, _ := getFieldType(file, file.Name)
259-
260258
x, err := getFieldType(file, fieldType.X)
261259
if err != nil {
262260
return "", err
@@ -267,18 +265,38 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
267265
return "", err
268266
}
269267

268+
packageName := ""
269+
if !strings.Contains(x, ".") {
270+
if file.Name == nil {
271+
return "", errors.New("file name is nil")
272+
}
273+
packageName, _ = getFieldType(file, file.Name)
274+
}
275+
270276
return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
271277
}
272278

273279
return "", fmt.Errorf("unknown field type %#v", field)
274280
}
275281

276-
func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec {
282+
func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
277283
switch indexType := field.(type) {
278284
case *ast.Ident:
279-
return indexType.Obj.Decl.(*ast.TypeSpec)
285+
spec := &TypeSpecDef{
286+
File: file,
287+
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
288+
PkgPath: file.Name.Name,
289+
}
290+
return spec.FullName(), nil
280291
case *ast.ArrayType:
281-
return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec)
292+
spec := &TypeSpecDef{
293+
File: file,
294+
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
295+
PkgPath: file.Name.Name,
296+
}
297+
return spec.FullName(), nil
298+
case *ast.SelectorExpr:
299+
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
282300
}
283-
return nil
301+
return "", fmt.Errorf("unknown type %#v", field)
284302
}

generics_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,87 @@ func TestParseGenericsNames(t *testing.T) {
9393
assert.Equal(t, string(expected), string(b))
9494
}
9595

96+
func TestParametrizeStruct(t *testing.T) {
97+
pd := PackagesDefinitions{
98+
packages: make(map[string]*PackageDefinitions),
99+
}
100+
// valid
101+
typeSpec := pd.parametrizeStruct(&TypeSpecDef{
102+
TypeSpec: &ast.TypeSpec{
103+
Name: &ast.Ident{Name: "Field"},
104+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
105+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
106+
}}, "test.Field[string, []string]", false)
107+
assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name())
108+
109+
// definition contains one type params, but two type params are provided
110+
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
111+
TypeSpec: &ast.TypeSpec{
112+
Name: &ast.Ident{Name: "Field"},
113+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}},
114+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
115+
}}, "test.Field[string, string]", false)
116+
assert.Nil(t, typeSpec)
117+
118+
// definition contains two type params, but only one is used
119+
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
120+
TypeSpec: &ast.TypeSpec{
121+
Name: &ast.Ident{Name: "Field"},
122+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
123+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
124+
}}, "test.Field[string]", false)
125+
assert.Nil(t, typeSpec)
126+
127+
// name is not a valid type name
128+
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
129+
TypeSpec: &ast.TypeSpec{
130+
Name: &ast.Ident{Name: "Field"},
131+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
132+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
133+
}}, "test.Field[string", false)
134+
assert.Nil(t, typeSpec)
135+
136+
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
137+
TypeSpec: &ast.TypeSpec{
138+
Name: &ast.Ident{Name: "Field"},
139+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
140+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
141+
}}, "test.Field[string, [string]", false)
142+
assert.Nil(t, typeSpec)
143+
144+
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
145+
TypeSpec: &ast.TypeSpec{
146+
Name: &ast.Ident{Name: "Field"},
147+
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
148+
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
149+
}}, "test.Field[string, ]string]", false)
150+
assert.Nil(t, typeSpec)
151+
}
152+
153+
func TestSplitStructNames(t *testing.T) {
154+
t.Parallel()
155+
156+
field, params := splitStructName("test.Field")
157+
assert.Empty(t, field)
158+
assert.Nil(t, params)
159+
160+
field, params = splitStructName("test.Field]")
161+
assert.Empty(t, field)
162+
assert.Nil(t, params)
163+
164+
field, params = splitStructName("test.Field[string")
165+
assert.Empty(t, field)
166+
assert.Nil(t, params)
167+
168+
field, params = splitStructName("test.Field[string]")
169+
assert.Equal(t, "test.Field", field)
170+
assert.Equal(t, []string{"string"}, params)
171+
172+
field, params = splitStructName("test.Field[string, []string]")
173+
assert.Equal(t, "test.Field", field)
174+
assert.Equal(t, []string{"string", "[]string"}, params)
175+
}
176+
96177
func TestGetGenericFieldType(t *testing.T) {
97178
field, err := getFieldType(
98179
&ast.File{Name: &ast.Ident{Name: "test"}},
@@ -124,6 +205,34 @@ func TestGetGenericFieldType(t *testing.T) {
124205
assert.NoError(t, err)
125206
assert.Equal(t, "test.Field[string,int]", field)
126207

208+
field, err = getFieldType(
209+
&ast.File{Name: &ast.Ident{Name: "test"}},
210+
&ast.IndexListExpr{
211+
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
212+
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}},
213+
},
214+
)
215+
assert.NoError(t, err)
216+
assert.Equal(t, "test.Field[string,[]int]", field)
217+
218+
field, err = getFieldType(
219+
&ast.File{Name: &ast.Ident{Name: "test"}},
220+
&ast.IndexListExpr{
221+
X: &ast.BadExpr{},
222+
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
223+
},
224+
)
225+
assert.Error(t, err)
226+
227+
field, err = getFieldType(
228+
&ast.File{Name: &ast.Ident{Name: "test"}},
229+
&ast.IndexListExpr{
230+
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
231+
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}},
232+
},
233+
)
234+
assert.Error(t, err)
235+
127236
field, err = getFieldType(
128237
&ast.File{Name: &ast.Ident{Name: "test"}},
129238
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
@@ -148,4 +257,40 @@ func TestGetGenericFieldType(t *testing.T) {
148257
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}},
149258
)
150259
assert.Error(t, err)
260+
261+
field, err = getFieldType(
262+
&ast.File{Name: &ast.Ident{Name: "test"}},
263+
&ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}},
264+
)
265+
assert.NoError(t, err)
266+
assert.Equal(t, "field.Name[string]", field)
267+
}
268+
269+
func TestGetGenericTypeName(t *testing.T) {
270+
field, err := getGenericTypeName(
271+
&ast.File{Name: &ast.Ident{Name: "test"}},
272+
&ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
273+
)
274+
assert.NoError(t, err)
275+
assert.Equal(t, "test.Field", field)
276+
277+
field, err = getGenericTypeName(
278+
&ast.File{Name: &ast.Ident{Name: "test"}},
279+
&ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}},
280+
)
281+
assert.NoError(t, err)
282+
assert.Equal(t, "test.Field", field)
283+
284+
field, err = getGenericTypeName(
285+
&ast.File{Name: &ast.Ident{Name: "test"}},
286+
&ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}},
287+
)
288+
assert.NoError(t, err)
289+
assert.Equal(t, "field.Name", field)
290+
291+
_, err = getGenericTypeName(
292+
&ast.File{Name: &ast.Ident{Name: "test"}},
293+
&ast.BadExpr{},
294+
)
295+
assert.Error(t, err)
151296
}

testdata/generics_property/api/api.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package api
22

33
import (
4+
"github.com/swaggo/swag/testdata/generics_property/web"
45
"net/http"
56
)
67

8+
type NestedResponse struct {
9+
web.GenericResponse[[]string, *uint8]
10+
}
11+
712
// @Summary List Posts
813
// @Description Get All of the Posts
914
// @Accept json
@@ -12,6 +17,7 @@ import (
1217
// @Success 200 {object} web.PostResponse "ok"
1318
// @Success 201 {object} web.PostResponses "ok"
1419
// @Success 202 {object} web.StringResponse "ok"
20+
// @Success 203 {object} NestedResponse "ok"
1521
// @Router /posts [get]
1622
func GetPosts(w http.ResponseWriter, r *http.Request) {
1723
}

testdata/generics_property/expected.json

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@
3939
"type": "integer",
4040
"name": "rows",
4141
"in": "query"
42-
},
43-
{
44-
"type": "string",
45-
"name": "search",
46-
"in": "query"
4742
}
4843
],
4944
"responses": {
@@ -64,12 +59,32 @@
6459
"schema": {
6560
"$ref": "#/definitions/web.StringResponse"
6661
}
62+
},
63+
"203": {
64+
"description": "ok",
65+
"schema": {
66+
"$ref": "#/definitions/api.NestedResponse"
67+
}
6768
}
6869
}
6970
}
7071
}
7172
},
7273
"definitions": {
74+
"api.NestedResponse": {
75+
"type": "object",
76+
"properties": {
77+
"items": {
78+
"type": "array",
79+
"items": {
80+
"type": "string"
81+
}
82+
},
83+
"items2": {
84+
"type": "integer"
85+
}
86+
}
87+
},
7388
"types.Field-string": {
7489
"type": "object",
7590
"properties": {

testdata/generics_property/web/handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (String) Where(ps ...PostSelector) String {
2828

2929
type PostPager struct {
3030
Pager[String, PostSelector]
31-
Search string `json:"search" form:"search"`
31+
Search types.Field[string] `json:"search" form:"search"`
3232
}
3333

3434
type PostResponse struct {

0 commit comments

Comments
 (0)