Skip to content

Commit

Permalink
Custom primitive types (#198)
Browse files Browse the repository at this point in the history
* dealing with custom primitive types

* deal with array of custom types

* code doc

* testing custom types
  • Loading branch information
Abdul Alkhatib authored and pei0804 committed Sep 1, 2018
1 parent c62a15d commit d18ec72
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 27 deletions.
19 changes: 15 additions & 4 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type Parser struct {
// TypeDefinitions is a map that stores [package name][type name][*ast.TypeSpec]
TypeDefinitions map[string]map[string]*ast.TypeSpec

// CustomPrimitiveTypes is a map that stores custom primitive types to actual golang types [type name][string]
CustomPrimitiveTypes map[string]string

//registerTypes is a map that stores [refTypeName][*ast.TypeSpec]
registerTypes map[string]*ast.TypeSpec

Expand All @@ -64,9 +67,10 @@ func New() *Parser {
Definitions: make(map[string]spec.Schema),
},
},
files: make(map[string]*ast.File),
TypeDefinitions: make(map[string]map[string]*ast.TypeSpec),
registerTypes: make(map[string]*ast.TypeSpec),
files: make(map[string]*ast.File),
TypeDefinitions: make(map[string]map[string]*ast.TypeSpec),
CustomPrimitiveTypes: make(map[string]string),
registerTypes: make(map[string]*ast.TypeSpec),
}
return parser
}
Expand Down Expand Up @@ -333,7 +337,14 @@ func (parser *Parser) ParseType(astFile *ast.File) {
if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE {
for _, astSpec := range generalDeclaration.Specs {
if typeSpec, ok := astSpec.(*ast.TypeSpec); ok {
parser.TypeDefinitions[astFile.Name.String()][typeSpec.Name.String()] = typeSpec
typeName := fmt.Sprintf("%v", typeSpec.Type)
// check if its a custom primitive type
if IsGolangPrimitiveType(typeName) {
parser.CustomPrimitiveTypes[typeSpec.Name.String()] = typeName
} else {
parser.TypeDefinitions[astFile.Name.String()][typeSpec.Name.String()] = typeSpec
}

}
}
}
Expand Down
9 changes: 9 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,15 @@ func TestParseSimpleApi_ForSnakecase(t *testing.T) {
}
}
},
"custom_string": {
"type": "string"
},
"custom_string_arr": {
"type": "array",
"items": {
"type": "string"
}
},
"data": {
"type": "object"
},
Expand Down
14 changes: 12 additions & 2 deletions property.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ func getPropertyName(field *ast.Field, parser *Parser) propertyName {
if astTypeSelectorExpr, ok := field.Type.(*ast.SelectorExpr); ok {
return parseFieldSelectorExpr(astTypeSelectorExpr, parser, newProperty)
}

// check if it is a custom type
typeName := fmt.Sprintf("%v", field.Type)
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[typeName]; isCustomType {
return propertyName{SchemaType: actualPrimitiveType, ArrayType: actualPrimitiveType}
}

if astTypeIdent, ok := field.Type.(*ast.Ident); ok {
name := astTypeIdent.Name
schemeType := TransToValidSchemeType(name)
Expand Down Expand Up @@ -101,8 +108,11 @@ func getPropertyName(field *ast.Field, parser *Parser) propertyName {
return propertyName{SchemaType: "array", ArrayType: name}
}
}
str := fmt.Sprintf("%s", astTypeArray.Elt)
return propertyName{SchemaType: "array", ArrayType: str}
itemTypeName := fmt.Sprintf("%s", astTypeArray.Elt)
if actualPrimitiveType, isCustomType := parser.CustomPrimitiveTypes[itemTypeName]; isCustomType {
itemTypeName = actualPrimitiveType
}
return propertyName{SchemaType: "array", ArrayType: itemTypeName}
}
if _, ok := field.Type.(*ast.MapType); ok { // if map
//TODO: support map
Expand Down
18 changes: 9 additions & 9 deletions property_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestGetPropertyNameSelectorExpr(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameIdentObjectId(t *testing.T) {
Expand All @@ -50,7 +50,7 @@ func TestGetPropertyNameIdentObjectId(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameIdentUUID(t *testing.T) {
Expand All @@ -73,7 +73,7 @@ func TestGetPropertyNameIdentUUID(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameIdentDecimal(t *testing.T) {
Expand All @@ -96,7 +96,7 @@ func TestGetPropertyNameIdentDecimal(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameIdentTime(t *testing.T) {
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestGetPropertyNameStarExprIdent(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameArrayStarExpr(t *testing.T) {
Expand All @@ -160,7 +160,7 @@ func TestGetPropertyNameArrayStarExpr(t *testing.T) {
"string",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameMap(t *testing.T) {
Expand All @@ -179,7 +179,7 @@ func TestGetPropertyNameMap(t *testing.T) {
"object",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameStruct(t *testing.T) {
Expand All @@ -191,7 +191,7 @@ func TestGetPropertyNameStruct(t *testing.T) {
"object",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}

func TestGetPropertyNameInterface(t *testing.T) {
Expand All @@ -203,5 +203,5 @@ func TestGetPropertyNameInterface(t *testing.T) {
"object",
"",
}
assert.Equal(t, expected, getPropertyName(input, nil))
assert.Equal(t, expected, getPropertyName(input, New()))
}
25 changes: 25 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,28 @@ func TransToValidSchemeType(typeName string) string {
return typeName // to support user defined types
}
}

// IsGolangPrimitiveType determine whether the type name is a golang primitive type
func IsGolangPrimitiveType(typeName string) bool {
switch typeName {
case "uint",
"int",
"uint8",
"int8",
"uint16",
"int16",
"byte",
"uint32",
"int32",
"rune",
"uint64",
"int64",
"float32",
"float64",
"bool",
"string":
return true
default:
return false
}
}
29 changes: 17 additions & 12 deletions testdata/simple2/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package web
import (
"time"

uuid "github.com/satori/go.uuid"
"github.com/shopspring/decimal"
)

Expand All @@ -18,20 +19,24 @@ type Pet struct {
PhotoUrls []string `example:"http://test/image/1.jpg,http://test/image/2.jpg"`
}
}
Name string `example:"poti"`
PhotoUrls []string `example:"http://test/image/1.jpg,http://test/image/2.jpg"`
Tags []Tag
Pets *[]Pet2
Pets2 []*Pet2
Status string
Price float32 `example:"3.25" validate:"required,gte=0,lte=130"`
IsAlive bool `example:"true"`
Data interface{}
Hidden string `json:"-"`
UUID uuid.UUID
Decimal decimal.Decimal
Name string `example:"poti"`
PhotoUrls []string `example:"http://test/image/1.jpg,http://test/image/2.jpg"`
Tags []Tag
Pets *[]Pet2
Pets2 []*Pet2
Status string
Price float32 `example:"3.25" validate:"required,gte=0,lte=130"`
IsAlive bool `example:"true"`
Data interface{}
Hidden string `json:"-"`
UUID uuid.UUID
Decimal decimal.Decimal
customString CustomString
customStringArr []CustomString
}

type CustomString string

type Tag struct {
ID int `format:"int64"`
Name string
Expand Down

0 comments on commit d18ec72

Please sign in to comment.