Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions _generated/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package _generated

import (
"bytes"
"errors"
"fmt"
"testing"

Expand Down Expand Up @@ -1045,7 +1046,7 @@ func TestAllowNilSecurityLimits(t *testing.T) {

reader := msgp.NewReader(bytes.NewReader(buf))
err := data.DecodeMsg(reader)
if err != msgp.ErrLimitExceeded {
if !errors.Is(err, msgp.ErrLimitExceeded) {
t.Errorf("Expected ErrLimitExceeded for allownil bytes exceeding limit, got %v", err)
}
})
Expand Down Expand Up @@ -1160,7 +1161,7 @@ func TestAllowNilSecurityLimits(t *testing.T) {

reader := msgp.NewReader(bytes.NewReader(buf))
err := data.DecodeMsg(reader)
if err != msgp.ErrLimitExceeded {
if !errors.Is(err, msgp.ErrLimitExceeded) {
t.Errorf("Expected ErrLimitExceeded (header-first check), got %v", err)
}
})
Expand Down Expand Up @@ -1361,7 +1362,7 @@ func TestAllowNilZeroCopy(t *testing.T) {

reader := msgp.NewReader(bytes.NewReader(buf))
err := data.DecodeMsg(reader)
if err != msgp.ErrLimitExceeded {
if !errors.Is(err, msgp.ErrLimitExceeded) {
t.Errorf("Expected ErrLimitExceeded for zerocopy allownil field exceeding limit, got %v", err)
}
})
Expand All @@ -1379,7 +1380,7 @@ func TestAllowNilZeroCopy(t *testing.T) {

reader := msgp.NewReader(bytes.NewReader(buf))
err := data.DecodeMsg(reader)
if err != msgp.ErrLimitExceeded {
if !errors.Is(err, msgp.ErrLimitExceeded) {
t.Errorf("Expected ErrLimitExceeded (header-first check) for zerocopy allownil, got %v", err)
}
})
Expand Down
54 changes: 6 additions & 48 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,11 @@ func (d *decodeGen) readBytesWithLimit(vname string, fieldLimit uint32) string {

// Choose reading strategy based on whether limits exist
if limit > 0 && limit != math.MaxUint32 {
// Limits exist - use header-first security approach
sz := randIdent()
d.p.printf("\nvar %s uint32", sz)
d.p.printf("\n%s, err = dc.ReadBytesHeader()", sz)
d.p.wrapErrCheck(d.ctx.ArgsStr())

// Check size against limit before allocating
d.p.printf("\nif %s > %s {", sz, limitName)
d.p.printf("\nerr = msgp.ErrLimitExceeded")
d.p.printf("\nreturn")
d.p.printf("\n%s, err = dc.ReadBytesLimit(%s, %s)", vname, vname, limitName)
// field cannot be nil
d.p.printf("\nif err == nil && %s == nil {", vname)
d.p.printf("\n%s = []byte{}", vname)
d.p.printf("\n}")

// Allocate and read the data
// regular field - ensure always allocated, even for size 0
d.p.printf("\nif %s == nil || uint32(cap(%s)) < %s {", vname, vname, sz)
d.p.printf("\n%s = make([]byte, %s)", vname, sz)
d.p.printf("\n} else {")
d.p.printf("\n%s = %s[:%s]", vname, vname, sz)
d.p.printf("\n}")
d.p.printf("\n_, err = dc.ReadFull(%s)", vname)
return ""
} else {
// No limits - use original direct reading approach for efficiency
Expand Down Expand Up @@ -348,35 +333,8 @@ func (d *decodeGen) readBytesConvertWithLimit(tmp string, allowNil bool, receive

// Check if limits exist to decide on reading strategy
if d.ctx.currentFieldArrayLimit != math.MaxUint32 || d.ctx.arrayLimit != math.MaxUint32 {
// Limits exist - use header-first approach for security
sz := randIdent()
d.p.printf("\nvar %s uint32", sz)
d.p.printf("\n%s, err = dc.ReadBytesHeader()", sz)
d.p.wrapErrCheck(d.ctx.ArgsStr())

// Check array limits for bytes (use currentFieldArrayLimit or arrayLimit)
if d.ctx.currentFieldArrayLimit != math.MaxUint32 {
d.p.printf("\nif %s > %d {", sz, d.ctx.currentFieldArrayLimit)
d.p.printf("\nerr = msgp.ErrLimitExceeded")
d.p.printf("\nreturn")
d.p.printf("\n}")
} else if d.ctx.arrayLimit != math.MaxUint32 {
d.p.printf("\nif %s > %slimitArrays {", sz, d.ctx.limitPrefix)
d.p.printf("\nerr = msgp.ErrLimitExceeded")
d.p.printf("\nreturn")
d.p.printf("\n}")
}

// Allocate and read with type conversion
if tmp != receiverVar {
d.p.printf("\n%s = %s", tmp, receiverVar)
}
d.p.printf("\nif %s == nil || uint32(cap(%s)) < %s {", tmp, tmp, sz)
d.p.printf("\n%s = make([]byte, %s)", tmp, sz)
d.p.printf("\n} else {")
d.p.printf("\n%s = %s[:%s]", tmp, tmp, sz)
d.p.printf("\n}")
d.p.printf("\n_, err = dc.ReadFull(%s)", tmp)
// No limits - use original efficient approach with receiver cast as destination
d.p.printf("\n%s, err = dc.ReadBytesLimit(%s, %d)", tmp, receiverVar, min(d.ctx.currentFieldArrayLimit, d.ctx.arrayLimit))
} else {
// No limits - use original efficient approach with receiver cast as destination
d.p.printf("\n%s, err = dc.ReadBytes(%s)", tmp, receiverVar)
Expand Down
57 changes: 57 additions & 0 deletions msgp/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,63 @@ func (m *Reader) ReadBytes(scratch []byte) (b []byte, err error) {
return
}

// ReadBytesLimit reads a MessagePack 'bin' object
// from the reader and returns its value. It may
// use 'scratch' for storage if it is non-nil.
// If n >= 0 this will be the maximum bytes read.
// If n < 0, the cap(scratch) will be the limit.
// If SetMaxElements has been used on the Reader,
// that will only be checked if the scratch is too small.
func (m *Reader) ReadBytesLimit(scratch []byte, n int64) (b []byte, err error) {
var p []byte
var lead byte
p, err = m.R.Peek(2)
if err != nil {
return
}
lead = p[0]
var read int64
switch lead {
case mbin8:
read = int64(p[1])
m.R.Skip(2)
case mbin16:
p, err = m.R.Next(3)
if err != nil {
return
}
read = int64(big.Uint16(p[1:]))
case mbin32:
p, err = m.R.Next(5)
if err != nil {
return
}
read = int64(big.Uint32(p[1:]))
default:
err = badPrefix(BinType, lead)
return
}
if n < 0 {
n = int64(cap(scratch))
}
fmt.Println(n, read)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@klauspost you left a Println in here

if read > n {
err = ErrLimitExceeded
return
}
if int64(cap(scratch)) < read {
b = make([]byte, read)
if read > int64(m.GetMaxElements()) {
err = ErrLimitExceeded
return
}
} else {
b = scratch[0:read]
}
_, err = m.R.ReadFull(b)
return
}

// ReadBytesHeader reads the size header
// of a MessagePack 'bin' object. The user
// is responsible for dealing with the next
Expand Down
Loading