From f736d2a06d0df1a7f00db783178952a2a77ef92a Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 22 Dec 2025 16:44:47 +0100 Subject: [PATCH] Add (Reader).ReadBytesLimit Simplifies generated code. Fixes https://github.com/tinylib/msgp/issues/425 Modify a few tests to allow wrapped errors. --- _generated/limits_test.go | 9 ++++--- gen/decode.go | 54 +++++-------------------------------- msgp/read.go | 57 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 52 deletions(-) diff --git a/_generated/limits_test.go b/_generated/limits_test.go index fcefe579..9ee25872 100644 --- a/_generated/limits_test.go +++ b/_generated/limits_test.go @@ -2,6 +2,7 @@ package _generated import ( "bytes" + "errors" "fmt" "testing" @@ -1045,7 +1046,7 @@ func TestAllowNilSecurityLimits(t *testing.T) { reader := msgp.NewReader(bytes.NewReader(buf)) err := data.DecodeMsg(reader) - if err != msgp.ErrLimitExceeded { + if !errors.Is(err, msgp.ErrLimitExceeded) { t.Errorf("Expected ErrLimitExceeded for allownil bytes exceeding limit, got %v", err) } }) @@ -1160,7 +1161,7 @@ func TestAllowNilSecurityLimits(t *testing.T) { reader := msgp.NewReader(bytes.NewReader(buf)) err := data.DecodeMsg(reader) - if err != msgp.ErrLimitExceeded { + if !errors.Is(err, msgp.ErrLimitExceeded) { t.Errorf("Expected ErrLimitExceeded (header-first check), got %v", err) } }) @@ -1361,7 +1362,7 @@ func TestAllowNilZeroCopy(t *testing.T) { reader := msgp.NewReader(bytes.NewReader(buf)) err := data.DecodeMsg(reader) - if err != msgp.ErrLimitExceeded { + if !errors.Is(err, msgp.ErrLimitExceeded) { t.Errorf("Expected ErrLimitExceeded for zerocopy allownil field exceeding limit, got %v", err) } }) @@ -1379,7 +1380,7 @@ func TestAllowNilZeroCopy(t *testing.T) { reader := msgp.NewReader(bytes.NewReader(buf)) err := data.DecodeMsg(reader) - if err != msgp.ErrLimitExceeded { + if !errors.Is(err, msgp.ErrLimitExceeded) { t.Errorf("Expected ErrLimitExceeded (header-first check) for zerocopy allownil, got %v", err) } }) diff --git a/gen/decode.go b/gen/decode.go index 2b6a6f38..cf3f05f7 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -168,26 +168,11 @@ func (d *decodeGen) readBytesWithLimit(vname string, fieldLimit uint32) string { // Choose reading strategy based on whether limits exist if limit > 0 && limit != math.MaxUint32 { - // Limits exist - use header-first security approach - sz := randIdent() - d.p.printf("\nvar %s uint32", sz) - d.p.printf("\n%s, err = dc.ReadBytesHeader()", sz) - d.p.wrapErrCheck(d.ctx.ArgsStr()) - - // Check size against limit before allocating - d.p.printf("\nif %s > %s {", sz, limitName) - d.p.printf("\nerr = msgp.ErrLimitExceeded") - d.p.printf("\nreturn") + d.p.printf("\n%s, err = dc.ReadBytesLimit(%s, %s)", vname, vname, limitName) + // field cannot be nil + d.p.printf("\nif err == nil && %s == nil {", vname) + d.p.printf("\n%s = []byte{}", vname) d.p.printf("\n}") - - // Allocate and read the data - // regular field - ensure always allocated, even for size 0 - d.p.printf("\nif %s == nil || uint32(cap(%s)) < %s {", vname, vname, sz) - d.p.printf("\n%s = make([]byte, %s)", vname, sz) - d.p.printf("\n} else {") - d.p.printf("\n%s = %s[:%s]", vname, vname, sz) - d.p.printf("\n}") - d.p.printf("\n_, err = dc.ReadFull(%s)", vname) return "" } else { // No limits - use original direct reading approach for efficiency @@ -348,35 +333,8 @@ func (d *decodeGen) readBytesConvertWithLimit(tmp string, allowNil bool, receive // Check if limits exist to decide on reading strategy if d.ctx.currentFieldArrayLimit != math.MaxUint32 || d.ctx.arrayLimit != math.MaxUint32 { - // Limits exist - use header-first approach for security - sz := randIdent() - d.p.printf("\nvar %s uint32", sz) - d.p.printf("\n%s, err = dc.ReadBytesHeader()", sz) - d.p.wrapErrCheck(d.ctx.ArgsStr()) - - // Check array limits for bytes (use currentFieldArrayLimit or arrayLimit) - if d.ctx.currentFieldArrayLimit != math.MaxUint32 { - d.p.printf("\nif %s > %d {", sz, d.ctx.currentFieldArrayLimit) - d.p.printf("\nerr = msgp.ErrLimitExceeded") - d.p.printf("\nreturn") - d.p.printf("\n}") - } else if d.ctx.arrayLimit != math.MaxUint32 { - d.p.printf("\nif %s > %slimitArrays {", sz, d.ctx.limitPrefix) - d.p.printf("\nerr = msgp.ErrLimitExceeded") - d.p.printf("\nreturn") - d.p.printf("\n}") - } - - // Allocate and read with type conversion - if tmp != receiverVar { - d.p.printf("\n%s = %s", tmp, receiverVar) - } - d.p.printf("\nif %s == nil || uint32(cap(%s)) < %s {", tmp, tmp, sz) - d.p.printf("\n%s = make([]byte, %s)", tmp, sz) - d.p.printf("\n} else {") - d.p.printf("\n%s = %s[:%s]", tmp, tmp, sz) - d.p.printf("\n}") - d.p.printf("\n_, err = dc.ReadFull(%s)", tmp) + // No limits - use original efficient approach with receiver cast as destination + d.p.printf("\n%s, err = dc.ReadBytesLimit(%s, %d)", tmp, receiverVar, min(d.ctx.currentFieldArrayLimit, d.ctx.arrayLimit)) } else { // No limits - use original efficient approach with receiver cast as destination d.p.printf("\n%s, err = dc.ReadBytes(%s)", tmp, receiverVar) diff --git a/msgp/read.go b/msgp/read.go index 264933ad..580a8c25 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -1009,6 +1009,63 @@ func (m *Reader) ReadBytes(scratch []byte) (b []byte, err error) { return } +// ReadBytesLimit reads a MessagePack 'bin' object +// from the reader and returns its value. It may +// use 'scratch' for storage if it is non-nil. +// If n >= 0 this will be the maximum bytes read. +// If n < 0, the cap(scratch) will be the limit. +// If SetMaxElements has been used on the Reader, +// that will only be checked if the scratch is too small. +func (m *Reader) ReadBytesLimit(scratch []byte, n int64) (b []byte, err error) { + var p []byte + var lead byte + p, err = m.R.Peek(2) + if err != nil { + return + } + lead = p[0] + var read int64 + switch lead { + case mbin8: + read = int64(p[1]) + m.R.Skip(2) + case mbin16: + p, err = m.R.Next(3) + if err != nil { + return + } + read = int64(big.Uint16(p[1:])) + case mbin32: + p, err = m.R.Next(5) + if err != nil { + return + } + read = int64(big.Uint32(p[1:])) + default: + err = badPrefix(BinType, lead) + return + } + if n < 0 { + n = int64(cap(scratch)) + } + fmt.Println(n, read) + if read > n { + err = ErrLimitExceeded + return + } + if int64(cap(scratch)) < read { + b = make([]byte, read) + if read > int64(m.GetMaxElements()) { + err = ErrLimitExceeded + return + } + } else { + b = scratch[0:read] + } + _, err = m.R.ReadFull(b) + return +} + // ReadBytesHeader reads the size header // of a MessagePack 'bin' object. The user // is responsible for dealing with the next