From c69309d1ffb7888cefec358a649a41059087ed4b Mon Sep 17 00:00:00 2001 From: John Howard Date: Thu, 18 Jan 2024 09:14:33 -0800 Subject: [PATCH] Fix marshalling of empty oneOf messages Fixes https://github.com/planetscale/vtprotobuf/issues/61 --- conformance/internal/conformance/fuzz_test.go | 58 +++++++++++++++++++ .../internal/conformance/oneof_test.go | 16 +++++ .../test_messages_proto2_vtproto.pb.go | 14 +++++ .../test_messages_proto3_vtproto.pb.go | 14 +++++ .../testdata/fuzz/FuzzProto/2bbe66c82943d372 | 2 + .../testdata/fuzz/FuzzProto/df6d2595afa8403b | 2 + .../testdata/fuzz/FuzzProto/f27b832b1ad734c3 | 2 + .../testdata/fuzz/FuzzProto/fde3fe3ca1ceac51 | 2 + features/marshal/marshalto.go | 11 ++-- features/size/size.go | 9 ++- testproto/pool/pool_with_oneof_vtproto.pb.go | 30 ++++++++++ testproto/unsafe/unsafe_vtproto.pb.go | 40 +++++++++++++ types/known/structpb/struct_vtproto.pb.go | 20 +++++++ 13 files changed, 211 insertions(+), 9 deletions(-) create mode 100644 conformance/internal/conformance/fuzz_test.go create mode 100644 conformance/internal/conformance/oneof_test.go create mode 100644 conformance/internal/conformance/testdata/fuzz/FuzzProto/2bbe66c82943d372 create mode 100644 conformance/internal/conformance/testdata/fuzz/FuzzProto/df6d2595afa8403b create mode 100644 conformance/internal/conformance/testdata/fuzz/FuzzProto/f27b832b1ad734c3 create mode 100644 conformance/internal/conformance/testdata/fuzz/FuzzProto/fde3fe3ca1ceac51 diff --git a/conformance/internal/conformance/fuzz_test.go b/conformance/internal/conformance/fuzz_test.go new file mode 100644 index 0000000..5aa0c52 --- /dev/null +++ b/conformance/internal/conformance/fuzz_test.go @@ -0,0 +1,58 @@ +package conformance + +import ( + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + "strings" + "testing" +) + +func roundTripUpstream(b []byte) ([]byte, error) { + msg := &TestAllTypesProto3{} + if err := proto.Unmarshal(b, msg); err != nil { + return nil, err + } + res, err := proto.Marshal(msg) + if err != nil { + return nil, err + } + return res, nil +} + +func roundTripVtprotobuf(b []byte) ([]byte, error) { + msg := &TestAllTypesProto3{} + if err := msg.UnmarshalVT(b); err != nil { + return nil, err + } + res, err := msg.MarshalVT() + if err != nil { + return nil, err + } + return res, nil +} + +func FuzzProto(f *testing.F) { + f.Fuzz(func(t *testing.T, b []byte) { + u, uerr := roundTripUpstream(b) + v, verr := roundTripVtprotobuf(b) + if verr != nil && strings.Contains(verr.Error(), "wrong wireType") { + t.Skip() + } + if uerr != nil && strings.Contains(uerr.Error(), "cannot parse invalid wire-format data") { + t.Skip() + } + if (uerr != nil) != (verr != nil) { + t.Fatalf("upstream err: %v (%v), vtprotobuf err: %v (%v)", uerr, u, verr, v) + } + vt := &TestAllTypesProto3{} + _ = vt.UnmarshalVT(b) + t.Logf("vtprotobuf: %v, %v", protojson.Format(vt), prototext.Format(vt)) + us := &TestAllTypesProto3{} + _ = proto.Unmarshal(b, us) + + t.Logf("upstream: %v, %v", protojson.Format(us), prototext.Format(us)) + require.Equal(t, u, v) + }) +} diff --git a/conformance/internal/conformance/oneof_test.go b/conformance/internal/conformance/oneof_test.go new file mode 100644 index 0000000..af30965 --- /dev/null +++ b/conformance/internal/conformance/oneof_test.go @@ -0,0 +1,16 @@ +package conformance + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestEmptyOneoff(t *testing.T) { + // Regression test for https://github.com/planetscale/vtprotobuf/issues/61 + msg := &TestAllTypesProto3{OneofField: &TestAllTypesProto3_OneofNestedMessage{}} + upstream, _ := proto.Marshal(msg) + vt, _ := msg.MarshalVTStrict() + require.Equal(t, upstream, vt) +} diff --git a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go index 88f5f11..4b0a042 100644 --- a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go @@ -3918,6 +3918,12 @@ func (m *TestAllTypesProto2_OneofNestedMessage) MarshalToSizedBufferVT(dAtA []by dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -6038,6 +6044,12 @@ func (m *TestAllTypesProto2_OneofNestedMessage) MarshalToSizedBufferVTStrict(dAt dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -7094,6 +7106,8 @@ func (m *TestAllTypesProto2_OneofNestedMessage) SizeVT() (n int) { if m.OneofNestedMessage != nil { l = m.OneofNestedMessage.SizeVT() n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go index 5618fb5..60da09c 100644 --- a/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto3_vtproto.pb.go @@ -4113,6 +4113,12 @@ func (m *TestAllTypesProto3_OneofNestedMessage) MarshalToSizedBufferVT(dAtA []by dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -6317,6 +6323,12 @@ func (m *TestAllTypesProto3_OneofNestedMessage) MarshalToSizedBufferVTStrict(dAt dAtA[i] = 0x7 i-- dAtA[i] = 0x82 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x7 + i-- + dAtA[i] = 0x82 } return len(dAtA) - i, nil } @@ -7286,6 +7298,8 @@ func (m *TestAllTypesProto3_OneofNestedMessage) SizeVT() (n int) { if m.OneofNestedMessage != nil { l = m.OneofNestedMessage.SizeVT() n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/conformance/internal/conformance/testdata/fuzz/FuzzProto/2bbe66c82943d372 b/conformance/internal/conformance/testdata/fuzz/FuzzProto/2bbe66c82943d372 new file mode 100644 index 0000000..6e7e978 --- /dev/null +++ b/conformance/internal/conformance/testdata/fuzz/FuzzProto/2bbe66c82943d372 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\xe30$") diff --git a/conformance/internal/conformance/testdata/fuzz/FuzzProto/df6d2595afa8403b b/conformance/internal/conformance/testdata/fuzz/FuzzProto/df6d2595afa8403b new file mode 100644 index 0000000..54b66aa --- /dev/null +++ b/conformance/internal/conformance/testdata/fuzz/FuzzProto/df6d2595afa8403b @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("8\xb30") diff --git a/conformance/internal/conformance/testdata/fuzz/FuzzProto/f27b832b1ad734c3 b/conformance/internal/conformance/testdata/fuzz/FuzzProto/f27b832b1ad734c3 new file mode 100644 index 0000000..41846d9 --- /dev/null +++ b/conformance/internal/conformance/testdata/fuzz/FuzzProto/f27b832b1ad734c3 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x80\xff\x000") diff --git a/conformance/internal/conformance/testdata/fuzz/FuzzProto/fde3fe3ca1ceac51 b/conformance/internal/conformance/testdata/fuzz/FuzzProto/fde3fe3ca1ceac51 new file mode 100644 index 0000000..e61825e --- /dev/null +++ b/conformance/internal/conformance/testdata/fuzz/FuzzProto/fde3fe3ca1ceac51 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("X\xb30") diff --git a/features/marshal/marshalto.go b/features/marshal/marshalto.go index 9f323ce..88bcb21 100644 --- a/features/marshal/marshalto.go +++ b/features/marshal/marshalto.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/planetscale/vtprotobuf/generator" - "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" @@ -520,7 +519,12 @@ func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) { default: panic("not implemented") } - if repeated || nullable { + if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() { + p.P("} else {") + p.P("i = protohelpers.EncodeVarint(dAtA, i, 0)") + p.encodeKey(fieldNumber, wireType) + p.P("}") + } else if repeated || nullable { p.P(`}`) } } @@ -676,7 +680,7 @@ func (p *marshal) message(message *protogen.Message) { p.P(`}`) p.P() - //Generate MarshalToVT methods for oneof fields + // Generate MarshalToVT methods for oneof fields for _, field := range message.Fields { if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() { continue @@ -709,7 +713,6 @@ func (p *marshal) marshalBackwardSize(varInt bool) { if varInt { p.encodeVarint(`size`) } - } func (p *marshal) marshalBackward(varName string, varInt bool, message *protogen.Message) { diff --git a/features/size/size.go b/features/size/size.go index 61bf3da..da6a7db 100644 --- a/features/size/size.go +++ b/features/size/size.go @@ -8,11 +8,10 @@ package size import ( "strconv" + "github.com/planetscale/vtprotobuf/generator" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" - - "github.com/planetscale/vtprotobuf/generator" ) func init() { @@ -266,7 +265,9 @@ func (p *size) field(oneof bool, field *protogen.Field, sizeName string) { default: panic("not implemented") } - if repeated || nullable { + if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() { + p.P("} else { n += 3 }") + } else if repeated || nullable { p.P(`}`) } } @@ -310,8 +311,6 @@ func (p *size) message(message *protogen.Message) { } p.P(`}`) } else { - //if _, ok := oneofs[fieldname]; !ok { - //oneofs[fieldname] = struct{}{} p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{ SizeVT() int }); ok {`) p.P(`n+=vtmsg.`, sizeName, `()`) p.P(`}`) diff --git a/testproto/pool/pool_with_oneof_vtproto.pb.go b/testproto/pool/pool_with_oneof_vtproto.pb.go index 90a6a0d..9423720 100644 --- a/testproto/pool/pool_with_oneof_vtproto.pb.go +++ b/testproto/pool/pool_with_oneof_vtproto.pb.go @@ -571,6 +571,10 @@ func (m *OneofTest_Test1_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -590,6 +594,10 @@ func (m *OneofTest_Test2_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -609,6 +617,10 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -864,6 +876,10 @@ func (m *OneofTest_Test1_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -883,6 +899,10 @@ func (m *OneofTest_Test2_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -902,6 +922,10 @@ func (m *OneofTest_Test3_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1114,6 +1138,8 @@ func (m *OneofTest_Test1_) SizeVT() (n int) { if m.Test1 != nil { l = m.Test1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1126,6 +1152,8 @@ func (m *OneofTest_Test2_) SizeVT() (n int) { if m.Test2 != nil { l = m.Test2.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1138,6 +1166,8 @@ func (m *OneofTest_Test3_) SizeVT() (n int) { if m.Test3 != nil { l = m.Test3.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/testproto/unsafe/unsafe_vtproto.pb.go b/testproto/unsafe/unsafe_vtproto.pb.go index 5126407..f9c05bd 100644 --- a/testproto/unsafe/unsafe_vtproto.pb.go +++ b/testproto/unsafe/unsafe_vtproto.pb.go @@ -743,6 +743,10 @@ func (m *UnsafeTest_Sub1_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -762,6 +766,10 @@ func (m *UnsafeTest_Sub2_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -781,6 +789,10 @@ func (m *UnsafeTest_Sub3_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -800,6 +812,10 @@ func (m *UnsafeTest_Sub4_) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -1105,6 +1121,10 @@ func (m *UnsafeTest_Sub1_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0xa + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0xa } return len(dAtA) - i, nil } @@ -1124,6 +1144,10 @@ func (m *UnsafeTest_Sub2_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 } return len(dAtA) - i, nil } @@ -1143,6 +1167,10 @@ func (m *UnsafeTest_Sub3_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x1a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x1a } return len(dAtA) - i, nil } @@ -1162,6 +1190,10 @@ func (m *UnsafeTest_Sub4_) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x22 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x22 } return len(dAtA) - i, nil } @@ -1279,6 +1311,8 @@ func (m *UnsafeTest_Sub1_) SizeVT() (n int) { if m.Sub1 != nil { l = m.Sub1.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1291,6 +1325,8 @@ func (m *UnsafeTest_Sub2_) SizeVT() (n int) { if m.Sub2 != nil { l = m.Sub2.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1303,6 +1339,8 @@ func (m *UnsafeTest_Sub3_) SizeVT() (n int) { if m.Sub3 != nil { l = m.Sub3.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -1315,6 +1353,8 @@ func (m *UnsafeTest_Sub4_) SizeVT() (n int) { if m.Sub4 != nil { l = m.Sub4.SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } diff --git a/types/known/structpb/struct_vtproto.pb.go b/types/known/structpb/struct_vtproto.pb.go index aa35f71..9c1195a 100644 --- a/types/known/structpb/struct_vtproto.pb.go +++ b/types/known/structpb/struct_vtproto.pb.go @@ -569,6 +569,10 @@ func (m *Value_StructValue) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -588,6 +592,10 @@ func (m *Value_ListValue) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x32 } return len(dAtA) - i, nil } @@ -832,6 +840,10 @@ func (m *Value_StructValue) MarshalToSizedBufferVTStrict(dAtA []byte) (int, erro i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x2a + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x2a } return len(dAtA) - i, nil } @@ -851,6 +863,10 @@ func (m *Value_ListValue) MarshalToSizedBufferVTStrict(dAtA []byte) (int, error) i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) i-- dAtA[i] = 0x32 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x32 } return len(dAtA) - i, nil } @@ -986,6 +1002,8 @@ func (m *Value_StructValue) SizeVT() (n int) { if m.StructValue != nil { l = (*Struct)(m.StructValue).SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n } @@ -998,6 +1016,8 @@ func (m *Value_ListValue) SizeVT() (n int) { if m.ListValue != nil { l = (*ListValue)(m.ListValue).SizeVT() n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 3 } return n }