diff --git a/json/codec.go b/json/codec.go index 908c3f6..37ec9cf 100644 --- a/json/codec.go +++ b/json/codec.go @@ -4,6 +4,7 @@ import ( "encoding" "encoding/json" "fmt" + "math/big" "reflect" "sort" "strconv" @@ -838,6 +839,7 @@ func constructInlineValueEncodeFunc(encode encodeFunc) encodeFunc { // compiles down to zero instructions. // USE CAREFULLY! // This was copied from the runtime; see issues 23382 and 7921. +// //go:nosplit func noescape(p unsafe.Pointer) unsafe.Pointer { x := uintptr(p) @@ -1078,6 +1080,7 @@ var ( float32Type = reflect.TypeOf(float32(0)) float64Type = reflect.TypeOf(float64(0)) + bigIntType = reflect.TypeOf(new(big.Int)) numberType = reflect.TypeOf(json.Number("")) stringType = reflect.TypeOf("") stringsType = reflect.TypeOf([]string(nil)) @@ -1104,6 +1107,8 @@ var ( jsonUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + + bigIntDecoder = constructJSONUnmarshalerDecodeFunc(bigIntType, false) ) // ============================================================================= diff --git a/json/decode.go b/json/decode.go index 4f7f7a0..9792af0 100644 --- a/json/decode.go +++ b/json/decode.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "math/big" "reflect" "strconv" "time" @@ -16,6 +17,11 @@ import ( "github.com/segmentio/encoding/iso8601" ) +// anyFlagsSet reports whether at least one of the bits in flags is set in d.flags. +func (d decoder) anyFlagsSet(flags ParseFlags) bool { + return d.flags&flags != 0 +} + func (d decoder) decodeNull(b []byte, p unsafe.Pointer) ([]byte, error) { if hasNullPrefix(b) { return b[4:], nil @@ -1306,15 +1312,7 @@ func (d decoder) decodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) { v, val = nil, k == True case Num: - if (d.flags & UseNumber) != 0 { - n := Number("") - v, err = d.decodeNumber(v, unsafe.Pointer(&n)) - val = n - } else { - f := 0.0 - v, err = d.decodeFloat64(v, unsafe.Pointer(&f)) - val = f - } + v, 
err = d.decodeDynamicNumber(v, unsafe.Pointer(&val)) default: return b, syntaxError(v, "expected token but found '%c'", v[0]) @@ -1332,6 +1330,71 @@ func (d decoder) decodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) { return b, nil } +// decodeDynamicNumber decodes the number at the start of b for the any that p points to (p is cast to *any). +// The UseUint64, UseInt64, UseBigInt and UseNumber flags select the Go type that is tried, +// and float64 is used when none of the requested decodings applies. +func (d decoder) decodeDynamicNumber(b []byte, p unsafe.Pointer) ([]byte, error) { + kind := Float + var err error + + // Only pre-parse for numeric kind if a conditional decode + // has been requested. + if d.anyFlagsSet(UseBigInt | UseInt64 | UseUint64) { + _, _, kind, err = d.parseNumber(b) + if err != nil { + return b, err + } + } + + var rem []byte + anyPtr := (*any)(p) + + // Mutually exclusive integer handling cases. + switch { + // If requested, attempt decode of non-negative integers as uint64. + case kind == Uint && d.anyFlagsSet(UseUint64): + rem, err = decodeInto[uint64](anyPtr, b, d, decoder.decodeUint64) + if err == nil { + return rem, err + } + + // If uint64 decode was not requested but int64 decode was requested, + // then attempt decode of non-negative integers as int64. + case kind == Uint && d.anyFlagsSet(UseInt64): + fallthrough + + // If int64 decode was requested, + // attempt decode of negative integers as int64. + case kind == Int && d.anyFlagsSet(UseInt64): + rem, err = decodeInto[int64](anyPtr, b, d, decoder.decodeInt64) + if err == nil { + return rem, err + } + } + + // Fallback numeric handling cases: + // these cannot be combined into the above switch, + // since these cases also handle overflow + // from the above cases, if decode was already attempted. + switch { + // If *big.Int decode was requested, handle that case for any integer. + case kind == Uint && d.anyFlagsSet(UseBigInt): + fallthrough + case kind == Int && d.anyFlagsSet(UseBigInt): + rem, err = decodeInto[*big.Int](anyPtr, b, d, bigIntDecoder) + + // If json.Number decode was requested, handle that for any number. 
+ case d.anyFlagsSet(UseNumber): + rem, err = decodeInto[Number](anyPtr, b, d, decoder.decodeNumber) + + // Fall back to float64 decode when none of the requested decodings applies. + default: + rem, err = decodeInto[float64](anyPtr, b, d, decoder.decodeFloat64) + } + + return rem, err +} + func (d decoder) decodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) { if hasNullPrefix(b) { *(*interface{})(p) = nil diff --git a/json/json.go b/json/json.go index 47f3ba1..d5f6f9d 100644 --- a/json/json.go +++ b/json/json.go @@ -128,6 +128,19 @@ const ( // mode. DontMatchCaseInsensitiveStructFields + // UseBigInt decodes integers into *big.Int. + // Takes precedence over UseNumber for integers. + UseBigInt + + // UseInt64 decodes in-range integers to int64. + // Takes precedence over UseNumber and UseBigInt for in-range integers. + UseInt64 + + // UseUint64 decodes in-range non-negative integers to uint64. + // Takes precedence over UseNumber, UseBigInt, and UseInt64 + // for non-negative, in-range integers. + UseUint64 + // ZeroCopy is a parsing flag that combines all the copy optimizations // available in the package. // diff --git a/json/json_test.go b/json/json_test.go index 9af95b3..fb77868 100644 --- a/json/json_test.go +++ b/json/json_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "math" + "math/big" "os" "path/filepath" "reflect" @@ -85,6 +86,13 @@ type tree struct { Right *tree } +var ( + // bigPos128 and bigNeg128 are 1<<128 and -(1<<128); + // neither is representable as a uint64 or an int64. 
+ bigPos128 = new(big.Int).Lsh(big.NewInt(1), 128) + bigNeg128 = new(big.Int).Neg(bigPos128) +) + var testValues = [...]interface{}{ // constants nil, @@ -126,6 +134,9 @@ var testValues = [...]interface{}{ float64(math.SmallestNonzeroFloat64), float64(math.MaxFloat64), + bigPos128, + bigNeg128, + // number Number("0"), Number("1234567890"), @@ -484,6 +495,134 @@ func TestCodecDuration(t *testing.T) { } } +var numericParseTests = [...]struct { + name string + input string + flags ParseFlags + want any +}{ + { + name: "zero_flags_default", + input: `0`, + flags: 0, + want: float64(0), + }, + { + name: "zero_flags_int_uint_bigint_number", + input: `0`, + flags: UseInt64 | UseUint64 | UseBigInt | UseNumber, + want: uint64(0), + }, + { + name: "zero_flags_int_bigint_number", + input: `0`, + flags: UseInt64 | UseBigInt | UseNumber, + want: int64(0), + }, + { + name: "zero_flags_bigint_number", + input: `0`, + flags: UseBigInt | UseNumber, + want: big.NewInt(0), + }, + { + name: "zero_flags_number", + input: `0`, + flags: UseNumber, + want: json.Number(`0`), + }, + { + name: "max_uint64_flags_default", + input: fmt.Sprint(uint64(math.MaxUint64)), + flags: 0, + want: float64(math.MaxUint64), + }, + { + name: "max_uint64_flags_int_uint_bigint_number", + input: fmt.Sprint(uint64(math.MaxUint64)), + flags: UseInt64 | UseUint64 | UseBigInt | UseNumber, + want: uint64(math.MaxUint64), + }, + { + name: "min_int64_flags_uint_int_bigint_number", + input: fmt.Sprint(int64(math.MinInt64)), + flags: UseUint64 | UseInt64 | UseBigInt | UseNumber, + want: int64(math.MinInt64), + }, + { + name: "max_uint64_flags_int_bigint_number", + input: fmt.Sprint(uint64(math.MaxUint64)), + flags: UseInt64 | UseBigInt | UseNumber, + want: new(big.Int).SetUint64(math.MaxUint64), + }, + { + name: "overflow_uint64_flags_uint_int_bigint_number", + input: bigPos128.String(), + flags: UseUint64 | UseInt64 | UseBigInt | UseNumber, + want: bigPos128, + }, + { + name: "underflow_uint64_flags_uint_int_bigint_number", + 
input: bigNeg128.String(), + flags: UseUint64 | UseInt64 | UseBigInt | UseNumber, + want: bigNeg128, + }, + { + name: "overflow_uint64_flags_uint_int_number", + input: bigPos128.String(), + flags: UseUint64 | UseInt64 | UseNumber, + want: json.Number(bigPos128.String()), + }, + { + name: "underflow_uint64_flags_uint_int_number", + input: bigNeg128.String(), + flags: UseUint64 | UseInt64 | UseNumber, + want: json.Number(bigNeg128.String()), + }, + { + name: "overflow_uint64_flags_uint_int", + input: bigPos128.String(), + flags: UseUint64 | UseInt64, + want: float64(1 << 128), + }, + { + name: "underflow_uint64_flags_uint_int", + input: bigNeg128.String(), + flags: UseUint64 | UseInt64, + want: float64(-1 << 128), + }, +} + +func TestParse_numeric(t *testing.T) { + t.Parallel() + + for _, test := range numericParseTests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var got any + + rem, err := Parse([]byte(test.input), &got, test.flags) + if err != nil { + format := "Parse(%#q, ..., %#b) = %q [error], want nil" + t.Errorf(format, test.input, test.flags, err) + } + + if len(rem) != 0 { + format := "Parse(%#q, ..., %#b) = %#q, want zero length" + t.Errorf(format, test.input, test.flags, rem) + } + + if !reflect.DeepEqual(got, test.want) { + format := "Parse(%#q, %#b) -> %T(%#[3]v), want %T(%#[4]v)" + t.Errorf(format, test.input, test.flags, got, test.want) + } + }) + } +} + func newValue(model interface{}) reflect.Value { if model == nil { return reflect.New(reflect.TypeOf(&model).Elem())