Skip to content

Commit

Permalink
chore: add test for getPrimitiveValue util
Browse files Browse the repository at this point in the history
  • Loading branch information
Dat Nguyen committed Dec 15, 2024
1 parent 0f275e9 commit b1cd38c
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 10 deletions.
15 changes: 13 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@ var (

// ErrUnsupportedType describes an unsupported field type
type ErrUnsupportedType struct {
typ reflect.Type
typ reflect.Kind
}

// Error returns the UnsupportedType error text
func (e *ErrUnsupportedType) Error() string {
return fmt.Sprintf("mold: unsupported field type: %s", e.typ.Kind())
return fmt.Sprintf("mold: unsupported field type: %s", e.typ)
}

// ErrFailedToParseValue describes an error while parsing a value
type ErrFailedToParseValue struct {
typ reflect.Kind
err error
}

// Error returns the FailedToParseValue error text
func (e *ErrFailedToParseValue) Error() string {
return fmt.Sprintf("mold: failed to parse value for type %s: %s", e.typ, e.err.Error())
}

// ErrUndefinedTag defines a tag that does not exist
Expand Down
2 changes: 1 addition & 1 deletion modifiers/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func setValue(ctx context.Context, fl mold.FieldLevel) error {
// 2. Attempting to set its underlying value
// Try to convert the parameter string to the appropriate primitive type
// that the pointer references (e.g., *string, *int, *bool)
value, err := mold.GetPrimitiveValue(fl.Field().Type().Elem(), fl.Param())
value, err := mold.GetPrimitiveValue(fl.Field().Type().Elem().Kind(), fl.Param())
if err != nil {
// If ErrUnsupportedType: leave as zero value
if _, isUnsupportedType := err.(*mold.ErrUnsupportedType); isUnsupportedType {
Expand Down
14 changes: 7 additions & 7 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,44 +38,44 @@ func HasValue(field reflect.Value) bool {
}
}

func GetPrimitiveValue(typ reflect.Type, value string) (reflect.Value, error) {
switch typ.Kind() {
func GetPrimitiveValue(typ reflect.Kind, value string) (reflect.Value, error) {
switch typ {

case reflect.String:
return reflect.ValueOf(value), nil

case reflect.Int:
value, err := strconv.Atoi(value)
if err != nil {
return reflect.Value{}, err
return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err}
}
return reflect.ValueOf(value), nil

case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
value, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return reflect.Value{}, err
return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err}
}
return reflect.ValueOf(int64(value)), nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
value, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return reflect.Value{}, err
return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err}
}
return reflect.ValueOf(uint64(value)), nil

case reflect.Float32, reflect.Float64:
value, err := strconv.ParseFloat(value, 64)
if err != nil {
return reflect.Value{}, err
return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err}
}
return reflect.ValueOf(value), nil

case reflect.Bool:
value, err := strconv.ParseBool(value)
if err != nil {
return reflect.Value{}, err
return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err}
}
return reflect.ValueOf(value), nil
}
Expand Down
199 changes: 199 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package mold

import (
"math"
"reflect"
"testing"

. "github.com/go-playground/assert/v2"
)

func TestGetPrimitiveValue(t *testing.T) {
tests := []struct {
name string
typ reflect.Kind
value string
expected reflect.Value
expectError string
}{
{
name: "string",
typ: reflect.String,
value: "test",
expected: reflect.ValueOf("test"),
},
{
name: "int",
typ: reflect.Int,
value: "123",
expected: reflect.ValueOf(123),
},
{
name: "int8",
typ: reflect.Int8,
value: "123",
expected: reflect.ValueOf(int8(123)),
},
{
name: "int16",
typ: reflect.Int16,
value: "123",
expected: reflect.ValueOf(int16(123)),
},
{
name: "bool",
typ: reflect.Bool,
value: "true",
expected: reflect.ValueOf(true),
},
{
name: "error while parsing int",
typ: reflect.Int,
value: "abc",
expectError: "mold: failed to parse value for type int: strconv.Atoi: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing int8",
typ: reflect.Int8,
value: "abc",
expectError: "mold: failed to parse value for type int8: strconv.ParseInt: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing int16",
typ: reflect.Int16,
value: "abc",
expectError: "mold: failed to parse value for type int16: strconv.ParseInt: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing int32",
typ: reflect.Int32,
value: "abc",
expectError: "mold: failed to parse value for type int32: strconv.ParseInt: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing int64",
typ: reflect.Int64,
value: "abc",
expectError: "mold: failed to parse value for type int64: strconv.ParseInt: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing uint",
typ: reflect.Uint,
value: "abc",
expectError: "mold: failed to parse value for type uint: strconv.ParseUint: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing uint8",
typ: reflect.Uint8,
value: "abc",
expectError: "mold: failed to parse value for type uint8: strconv.ParseUint: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing uint16",
typ: reflect.Uint16,
value: "12.34",
expectError: "mold: failed to parse value for type uint16: strconv.ParseUint: parsing \"12.34\": invalid syntax",
},
{
name: "error while parsing uint32",
typ: reflect.Uint32,
value: "12.34",
expectError: "mold: failed to parse value for type uint32: strconv.ParseUint: parsing \"12.34\": invalid syntax",
},
{
name: "error while parsing uint64",
typ: reflect.Uint64,
value: "12.34",
expectError: "mold: failed to parse value for type uint64: strconv.ParseUint: parsing \"12.34\": invalid syntax",
},
{
name: "error while parsing float32",
typ: reflect.Float32,
value: "abc",
expectError: "mold: failed to parse value for type float32: strconv.ParseFloat: parsing \"abc\": invalid syntax",
},
{
name: "error while parsing bool",
typ: reflect.Bool,
value: "invalid bool",
expectError: "mold: failed to parse value for type bool: strconv.ParseBool: parsing \"invalid bool\": invalid syntax",
},
{
name: "unsupported type",
typ: reflect.Struct,
value: "abc",
expectError: "mold: unsupported field type: struct",
},
{
name: "uint",
typ: reflect.Uint,
value: "123",
expected: reflect.ValueOf(uint(123)),
},
{
name: "uint8",
typ: reflect.Uint8,
value: "123",
expected: reflect.ValueOf(uint8(123)),
},
{
name: "uint16",
typ: reflect.Uint16,
value: "123",
expected: reflect.ValueOf(uint16(123)),
},
{
name: "uint32",
typ: reflect.Uint32,
value: "123",
expected: reflect.ValueOf(uint32(123)),
},
{
name: "uint64",
typ: reflect.Uint64,
value: "123",
expected: reflect.ValueOf(uint64(123)),
},
{
name: "float32",
typ: reflect.Float32,
value: "123.45",
expected: reflect.ValueOf(float32(123.45)),
},
{
name: "float64",
typ: reflect.Float64,
value: "123.45",
expected: reflect.ValueOf(float64(123.45)),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
actual, err := GetPrimitiveValue(tc.typ, tc.value)
if tc.expectError != "" {
NotEqual(t, nil, err)
Equal(t, tc.expectError, err.Error())
} else {
Equal(t, nil, err)
switch tc.typ {
case reflect.String:
Equal(t, tc.expected.String(), actual.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
Equal(t, tc.expected.Int(), actual.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Equal(t, tc.expected.Uint(), actual.Uint())
case reflect.Float32, reflect.Float64:
// could not assert equal float because of precision issues
// so we just check in 4 decimal places
decimalMask := math.Pow(10, 4)
Equal(t, math.Round(tc.expected.Float()*decimalMask), math.Round(actual.Float()*decimalMask))
case reflect.Bool:
Equal(t, tc.expected.Bool(), actual.Bool())
default:
Equal(t, tc.expected.Interface(), actual.Interface())
}
}
})
}
}

0 comments on commit b1cd38c

Please sign in to comment.