From 0f275e93b7ff0edd976022fe3db4345826ba5dde Mon Sep 17 00:00:00 2001 From: Dat Nguyen Date: Sun, 15 Dec 2024 14:09:37 +0700 Subject: [PATCH 1/2] perf: enhance default tag for pointer type --- errors.go | 10 +++++ modifiers/multi.go | 20 ++++++++- modifiers/multi_test.go | 96 +++++++++++++++++++++++++++++++++++++++-- mold.go | 2 +- util.go | 45 +++++++++++++++++++ 5 files changed, 167 insertions(+), 6 deletions(-) diff --git a/errors.go b/errors.go index b51bb5b..279db37 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,16 @@ var ( ErrInvalidKeysTag = errors.New("'" + keysTag + "' tag must be immediately preceeded by the '" + diveTag + "' tag") ) +// ErrUnsupportedType describes an unsupported field type +type ErrUnsupportedType struct { + typ reflect.Type +} + +// Error returns the UnsupportedType error text +func (e *ErrUnsupportedType) Error() string { + return fmt.Sprintf("mold: unsupported field type: %s", e.typ.Kind()) +} + // ErrUndefinedTag defines a tag that does not exist type ErrUndefinedTag struct { tag string diff --git a/modifiers/multi.go b/modifiers/multi.go index 9320f00..4e55088 100644 --- a/modifiers/multi.go +++ b/modifiers/multi.go @@ -17,7 +17,7 @@ var ( // defaultValue allows setting of a default value IF no value is already present. func defaultValue(ctx context.Context, fl mold.FieldLevel) error { - if !fl.Field().IsZero() { + if mold.HasValue(fl.Field()) { return nil } return setValue(ctx, fl) @@ -125,7 +125,25 @@ func setValue(ctx context.Context, fl mold.FieldLevel) error { fl.Field().Set(reflect.MakeChan(fl.Field().Type(), buffer)) case reflect.Ptr: + // Handle pointer fields by: + // 1. Creating a new pointer point to empty value with reflect.New() fl.Field().Set(reflect.New(fl.Field().Type().Elem())) + + // 2. Attempting to set its underlying value + // Try to convert the parameter string to the appropriate primitive type + // that the pointer references (e.g., *string, *int, *bool) + value, err := mold.GetPrimitiveValue(fl.Field().Type().Elem(), fl.Param()) + if err != nil { + // If ErrUnsupportedType: leave as zero value + if _, isUnsupportedType := err.(*mold.ErrUnsupportedType); isUnsupportedType { + break + } + // For all other errors except ErrUnsupportedType: propagate the error + return err + } + // If no error: set the underlying value + fl.Field().Elem().Set(value) + } return nil } diff --git a/modifiers/multi_test.go b/modifiers/multi_test.go index 201a46c..455cf40 100644 --- a/modifiers/multi_test.go +++ b/modifiers/multi_test.go @@ -174,8 +174,8 @@ func TestDefaultSetSpecialTypes(t *testing.T) { field: (*[]string)(nil), tags: "default", vf: func(field interface{}) { - m := field.([]string) - Equal(t, len(m), 0) + m := field.(*[]string) + Equal(t, len(*m), 0) }, }, { @@ -183,8 +183,44 @@ func TestDefaultSetSpecialTypes(t *testing.T) { field: (*[]string)(nil), tags: "set", vf: func(field interface{}) { - m := field.([]string) - Equal(t, len(m), 0) + m := field.(*[]string) + Equal(t, len(*m), 0) + }, + }, + { + name: "default pointer to int", + field: (*int)(nil), + tags: "default=5", + vf: func(field interface{}) { + m := field.(*int) + Equal(t, m, 5) + }, + }, + { + name: "set pointer to int", + field: (*int)(nil), + tags: "set=5", + vf: func(field interface{}) { + m := field.(*int) + Equal(t, *m, 5) + }, + }, + { + name: "default pointer to string", + field: (*string)(nil), + tags: "default=test", + vf: func(field interface{}) { + m := field.(*string) + Equal(t, *m, "test") + }, + }, + { + name: "set pointer to string", + field: (*string)(nil), + tags: "set", + vf: func(field interface{}) { + m := field.(*string) + Equal(t, *m, "") }, }, } @@ -379,6 +415,42 @@ func TestDefault(t *testing.T) { tags: "default=1s", expected: time.Duration(1_000_000_000), }, + { + name: "default nil pointer to int", + field: (*int)(nil), + tags: "default=3", + expected: 3, + }, + { + name: "default not nil pointer to int", + field: newPointer(1), + tags: "default=3", + expected: 1, + }, + { + name: "default nil pointer to string", + field: (*string)(nil), + tags: "default=test", + expected: "test", + }, + { + name: "default not nil pointer to string", + field: newPointer("existing_value"), + tags: "default=test", + expected: "existing_value", + }, + { + name: "default nil pointer to bool", + field: (*bool)(nil), + tags: "default=true", + expected: true, + }, + { + name: "default not nil pointer to bool", + field: newPointer(true), + tags: "default=true", + expected: true, + }, { name: "bad default time.Duration", field: time.Duration(0), @@ -409,6 +481,18 @@ func TestDefault(t *testing.T) { tags: "default=blue", expectError: true, }, + { + name: "bad default pointer to int", + field: (*int)(nil), + tags: "default=abc", + expectError: true, + }, + { + name: "bad default pointer to bool", + field: (*bool)(nil), + tags: "default=abc", + expectError: true, + }, } for _, tc := range tests { @@ -500,3 +584,7 @@ func TestEmpty(t *testing.T) { }) } } + +func newPointer[T any](value T) *T { + return &value +} diff --git a/mold.go b/mold.go index dee4ee6..b1d05e9 100644 --- a/mold.go +++ b/mold.go @@ -264,7 +264,7 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cT }); err != nil { return } - orig.Set(reflect.Indirect(newVal)) + orig.Set(newVal) current, kind = t.extractType(orig) } else { if err = ct.fn(ctx, fieldLevel{ diff --git a/util.go b/util.go index 1f47769..dfeb3e3 100644 --- a/util.go +++ b/util.go @@ -2,6 +2,7 @@ package mold import ( "reflect" + "strconv" ) // extractType gets the actual underlying type of field value. @@ -36,3 +37,47 @@ func HasValue(field reflect.Value) bool { return field.IsValid() && field.Interface() != reflect.Zero(field.Type()).Interface() } } + +func GetPrimitiveValue(typ reflect.Type, value string) (reflect.Value, error) { + switch typ.Kind() { + + case reflect.String: + return reflect.ValueOf(value), nil + + case reflect.Int: + value, err := strconv.Atoi(value) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(value), nil + + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(int64(value)), nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + value, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(uint64(value)), nil + + case reflect.Float32, reflect.Float64: + value, err := strconv.ParseFloat(value, 64) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(value), nil + + case reflect.Bool: + value, err := strconv.ParseBool(value) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(value), nil + } + return reflect.Value{}, &ErrUnsupportedType{typ: typ} +} From b1cd38c9094da49154371f995a2a1236d4a94667 Mon Sep 17 00:00:00 2001 From: Dat Nguyen Date: Sun, 15 Dec 2024 21:29:21 +0700 Subject: [PATCH 2/2] chore: add test for getPrimitiveValue util --- errors.go | 15 +++- modifiers/multi.go | 2 +- util.go | 14 ++-- util_test.go | 199 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 220 insertions(+), 10 deletions(-) create mode 100644 util_test.go diff --git a/errors.go b/errors.go index 279db37..42871f0 100644 --- a/errors.go +++ b/errors.go @@ -20,12 +20,23 @@ var ( // ErrUnsupportedType describes an unsupported field type type ErrUnsupportedType struct { - typ reflect.Type + typ reflect.Kind } // Error returns the UnsupportedType error text func (e *ErrUnsupportedType) Error() string { - return fmt.Sprintf("mold: unsupported field type: %s", e.typ.Kind()) + return fmt.Sprintf("mold: unsupported field type: %s", e.typ) +} + +// ErrFailedToParseValue describes an error while parsing a value +type ErrFailedToParseValue struct { + typ reflect.Kind + err error +} + +// Error returns the FailedToParseValue error text +func (e *ErrFailedToParseValue) Error() string { + return fmt.Sprintf("mold: failed to parse value for type %s: %s", e.typ, e.err.Error()) } // ErrUndefinedTag defines a tag that does not exist diff --git a/modifiers/multi.go b/modifiers/multi.go index 4e55088..6752ad4 100644 --- a/modifiers/multi.go +++ b/modifiers/multi.go @@ -132,7 +132,7 @@ func setValue(ctx context.Context, fl mold.FieldLevel) error { // 2. Attempting to set its underlying value // Try to convert the parameter string to the appropriate primitive type // that the pointer references (e.g., *string, *int, *bool) - value, err := mold.GetPrimitiveValue(fl.Field().Type().Elem(), fl.Param()) + value, err := mold.GetPrimitiveValue(fl.Field().Type().Elem().Kind(), fl.Param()) if err != nil { // If ErrUnsupportedType: leave as zero value if _, isUnsupportedType := err.(*mold.ErrUnsupportedType); isUnsupportedType { diff --git a/util.go b/util.go index dfeb3e3..8a02dfa 100644 --- a/util.go +++ b/util.go @@ -38,8 +38,8 @@ func HasValue(field reflect.Value) bool { } } -func GetPrimitiveValue(typ reflect.Type, value string) (reflect.Value, error) { - switch typ.Kind() { +func GetPrimitiveValue(typ reflect.Kind, value string) (reflect.Value, error) { + switch typ { case reflect.String: return reflect.ValueOf(value), nil @@ -47,35 +47,35 @@ func GetPrimitiveValue(typ reflect.Type, value string) (reflect.Value, error) { case reflect.Int: value, err := strconv.Atoi(value) if err != nil { - return reflect.Value{}, err + return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err} } return reflect.ValueOf(value), nil case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: value, err := strconv.ParseInt(value, 10, 64) if err != nil { - return reflect.Value{}, err + return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err} } return reflect.ValueOf(int64(value)), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: value, err := strconv.ParseUint(value, 10, 64) if err != nil { - return reflect.Value{}, err + return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err} } return reflect.ValueOf(uint64(value)), nil case reflect.Float32, reflect.Float64: value, err := strconv.ParseFloat(value, 64) if err != nil { - return reflect.Value{}, err + return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err} } return reflect.ValueOf(value), nil case reflect.Bool: value, err := strconv.ParseBool(value) if err != nil { - return reflect.Value{}, err + return reflect.Value{}, &ErrFailedToParseValue{typ: typ, err: err} } return reflect.ValueOf(value), nil } diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..3235f0a --- /dev/null +++ b/util_test.go @@ -0,0 +1,199 @@ +package mold + +import ( + "math" + "reflect" + "testing" + + . "github.com/go-playground/assert/v2" +) + +func TestGetPrimitiveValue(t *testing.T) { + tests := []struct { + name string + typ reflect.Kind + value string + expected reflect.Value + expectError string + }{ + { + name: "string", + typ: reflect.String, + value: "test", + expected: reflect.ValueOf("test"), + }, + { + name: "int", + typ: reflect.Int, + value: "123", + expected: reflect.ValueOf(123), + }, + { + name: "int8", + typ: reflect.Int8, + value: "123", + expected: reflect.ValueOf(int8(123)), + }, + { + name: "int16", + typ: reflect.Int16, + value: "123", + expected: reflect.ValueOf(int16(123)), + }, + { + name: "bool", + typ: reflect.Bool, + value: "true", + expected: reflect.ValueOf(true), + }, + { + name: "error while parsing int", + typ: reflect.Int, + value: "abc", + expectError: "mold: failed to parse value for type int: strconv.Atoi: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing int8", + typ: reflect.Int8, + value: "abc", + expectError: "mold: failed to parse value for type int8: strconv.ParseInt: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing int16", + typ: reflect.Int16, + value: "abc", + expectError: "mold: failed to parse value for type int16: strconv.ParseInt: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing int32", + typ: reflect.Int32, + value: "abc", + expectError: "mold: failed to parse value for type int32: strconv.ParseInt: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing int64", + typ: reflect.Int64, + value: "abc", + expectError: "mold: failed to parse value for type int64: strconv.ParseInt: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing uint", + typ: reflect.Uint, + value: "abc", + expectError: "mold: failed to parse value for type uint: strconv.ParseUint: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing uint8", + typ: reflect.Uint8, + value: "abc", + expectError: "mold: failed to parse value for type uint8: strconv.ParseUint: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing uint16", + typ: reflect.Uint16, + value: "12.34", + expectError: "mold: failed to parse value for type uint16: strconv.ParseUint: parsing \"12.34\": invalid syntax", + }, + { + name: "error while parsing uint32", + typ: reflect.Uint32, + value: "12.34", + expectError: "mold: failed to parse value for type uint32: strconv.ParseUint: parsing \"12.34\": invalid syntax", + }, + { + name: "error while parsing uint64", + typ: reflect.Uint64, + value: "12.34", + expectError: "mold: failed to parse value for type uint64: strconv.ParseUint: parsing \"12.34\": invalid syntax", + }, + { + name: "error while parsing float32", + typ: reflect.Float32, + value: "abc", + expectError: "mold: failed to parse value for type float32: strconv.ParseFloat: parsing \"abc\": invalid syntax", + }, + { + name: "error while parsing bool", + typ: reflect.Bool, + value: "invalid bool", + expectError: "mold: failed to parse value for type bool: strconv.ParseBool: parsing \"invalid bool\": invalid syntax", + }, + { + name: "unsupported type", + typ: reflect.Struct, + value: "abc", + expectError: "mold: unsupported field type: struct", + }, + { + name: "uint", + typ: reflect.Uint, + value: "123", + expected: reflect.ValueOf(uint(123)), + }, + { + name: "uint8", + typ: reflect.Uint8, + value: "123", + expected: reflect.ValueOf(uint8(123)), + }, + { + name: "uint16", + typ: reflect.Uint16, + value: "123", + expected: reflect.ValueOf(uint16(123)), + }, + { + name: "uint32", + typ: reflect.Uint32, + value: "123", + expected: reflect.ValueOf(uint32(123)), + }, + { + name: "uint64", + typ: reflect.Uint64, + value: "123", + expected: reflect.ValueOf(uint64(123)), + }, + { + name: "float32", + typ: reflect.Float32, + value: "123.45", + expected: reflect.ValueOf(float32(123.45)), + }, + { + name: "float64", + typ: reflect.Float64, + value: "123.45", + expected: reflect.ValueOf(float64(123.45)), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual, err := GetPrimitiveValue(tc.typ, tc.value) + if tc.expectError != "" { + NotEqual(t, nil, err) + Equal(t, tc.expectError, err.Error()) + } else { + Equal(t, nil, err) + switch tc.typ { + case reflect.String: + Equal(t, tc.expected.String(), actual.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + Equal(t, tc.expected.Int(), actual.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + Equal(t, tc.expected.Uint(), actual.Uint()) + case reflect.Float32, reflect.Float64: + // could not assert equal float because of precision issues + // so we just check in 4 decimal places + decimalMask := math.Pow(10, 4) + Equal(t, math.Round(tc.expected.Float()*decimalMask), math.Round(actual.Float()*decimalMask)) + case reflect.Bool: + Equal(t, tc.expected.Bool(), actual.Bool()) + default: + Equal(t, tc.expected.Interface(), actual.Interface()) + } + } + }) + } +}