From 2b06ddbf28e2855adaf707a0b6b868fd706a843c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 9 Dec 2025 15:59:15 +0100 Subject: [PATCH 1/5] perf: improve checksum calculation Optimize the checksum calculation that is used for read/write transactions. New implementation: ``` goos: darwin goarch: arm64 pkg: github.com/googleapis/go-sql-spanner cpu: Apple M3 BenchmarkChecksumRowIteratorRandom BenchmarkChecksumRowIteratorRandom/num-rows-1 BenchmarkChecksumRowIteratorRandom/num-rows-1-8 134152 8184 ns/op 13528 B/op 123 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-10 BenchmarkChecksumRowIteratorRandom/num-rows-10-8 16498 72459 ns/op 120690 B/op 1335 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-100 BenchmarkChecksumRowIteratorRandom/num-rows-100-8 1725 692326 ns/op 1153975 B/op 12781 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-1000 BenchmarkChecksumRowIteratorRandom/num-rows-1000-8 157 7566917 ns/op 11375170 B/op 130466 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-10000 BenchmarkChecksumRowIteratorRandom/num-rows-10000-8 14 83402393 ns/op 112045544 B/op 1294480 allocs/op ``` Original implementation: ``` goos: darwin goarch: arm64 pkg: github.com/googleapis/go-sql-spanner cpu: Apple M3 BenchmarkChecksumRowIteratorRandom BenchmarkChecksumRowIteratorRandom/num-rows-1 BenchmarkChecksumRowIteratorRandom/num-rows-1-8 33792 34931 ns/op 908 B/op 54 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-10 BenchmarkChecksumRowIteratorRandom/num-rows-10-8 4807 249917 ns/op 8768 B/op 531 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-100 BenchmarkChecksumRowIteratorRandom/num-rows-100-8 494 2408323 ns/op 87580 B/op 5301 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-1000 BenchmarkChecksumRowIteratorRandom/num-rows-1000-8 44 25243051 ns/op 870935 B/op 53003 allocs/op BenchmarkChecksumRowIteratorRandom/num-rows-10000 BenchmarkChecksumRowIteratorRandom/num-rows-10000-8 4 259823406 ns/op 8691096 B/op 530018 allocs/op ``` --- checksum_row_iterator.go | 120 ++++++++++++------- checksum_row_iterator_test.go | 202 ++++++++++++++++++++++++++------ testutil/mocked_inmem_server.go | 131 ++++++++++++++++++++- transaction.go | 7 +- 4 files changed, 375 insertions(+), 85 deletions(-) diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index c6980a27..6637f1ea 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -18,8 +18,11 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/gob" + "encoding/binary" + "hash" + "math" "reflect" + "sort" "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" @@ -32,16 +35,6 @@ import ( var errNextAfterSTop = status.Errorf(codes.FailedPrecondition, "Next called after Stop") -// init registers the protobuf types with gob so they can be encoded. -func init() { - gob.Register(structpb.Value_BoolValue{}) - gob.Register(structpb.Value_NumberValue{}) - gob.Register(structpb.Value_StringValue{}) - gob.Register(structpb.Value_NullValue{}) - gob.Register(structpb.Value_ListValue{}) - gob.Register(structpb.Value_StructValue{}) -} - var _ rowIterator = &checksumRowIterator{} // checksumRowIterator implements rowIterator and keeps track of a running @@ -66,12 +59,10 @@ type checksumRowIterator struct { // the retry has finished. stopped bool - // checksum contains the current checksum for the results that have been + // hash contains the current hash for the results that have been // seen. It is calculated as a SHA256 checksum over all rows that so far // have been returned. - checksum *[32]byte - buffer *bytes.Buffer - enc *gob.Encoder + hash hash.Hash // errIndex and err indicate any error and the index in the result set // where the error occurred. @@ -110,7 +101,7 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { // checksum of the columns that are included in this result. This is // also used to detect the possible difference between two empty // result sets with a different set of columns. - it.checksum, err = createMetadataChecksum(it.enc, it.buffer, it.metadata) + it.hash, err = createMetadataChecksum(it.metadata) if err != nil { return err } @@ -119,45 +110,82 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { return it.err } // Update the current checksum. - it.checksum, err = updateChecksum(it.enc, it.buffer, it.checksum, row) - return err + return updateChecksum(it.hash, row) }) return row, err } // updateChecksum calculates the following checksum based on a current checksum // and a new row. -func updateChecksum(enc *gob.Encoder, buffer *bytes.Buffer, currentChecksum *[32]byte, row *spanner.Row) (*[32]byte, error) { - buffer.Reset() - buffer.Write(currentChecksum[:]) +func updateChecksum(hash hash.Hash, row *spanner.Row) error { for i := 0; i < row.Size(); i++ { var v spanner.GenericColumnValue err := row.Column(i, &v) if err != nil { - return nil, err + return err } - err = enc.Encode(v) - if err != nil { - return nil, err + hashValue(v.Value, hash) + } + return nil +} + +var int32Buf [4]byte +var float64Buf [8]byte + +func hashValue(value *structpb.Value, digest hash.Hash) { + switch value.GetKind().(type) { + case *structpb.Value_StringValue: + digest.Write(intToByte(int32Buf, len(value.GetStringValue()))) + digest.Write([]byte(value.GetStringValue())) + case *structpb.Value_NullValue: + digest.Write([]byte{0}) + case *structpb.Value_NumberValue: + digest.Write(float64ToByte(float64Buf, value.GetNumberValue())) + case *structpb.Value_BoolValue: + if value.GetBoolValue() { + digest.Write([]byte{1}) + } else { + digest.Write([]byte{0}) + } + case *structpb.Value_StructValue: + fields := make([]string, 0, len(value.GetStructValue().Fields)) + for field, _ := range value.GetStructValue().Fields { + fields = append(fields, field) + } + sort.Strings(fields) + for _, field := range fields { + digest.Write(intToByte(int32Buf, len(field))) + digest.Write([]byte(field)) + hashValue(value.GetStructValue().Fields[field], digest) + } + case *structpb.Value_ListValue: + for _, v := range value.GetListValue().GetValues() { + hashValue(v, digest) } } - res := sha256.Sum256(buffer.Bytes()) - return &res, nil +} + +func intToByte(buf [4]byte, v int) []byte { + binary.BigEndian.PutUint32(buf[:], uint32(v)) + return buf[:] +} + +func float64ToByte(buf [8]byte, f float64) []byte { + binary.BigEndian.PutUint64(buf[:], math.Float64bits(f)) + return buf[:] } // createMetadataChecksum calculates the checksum of the metadata of a result. // Only the column names and types are included in the checksum. Any transaction // metadata is not included. -func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sppb.ResultSetMetadata) (*[32]byte, error) { - buffer.Reset() +func createMetadataChecksum(metadata *sppb.ResultSetMetadata) (hash.Hash, error) { + digest := sha256.New() for _, field := range metadata.RowType.Fields { - err := enc.Encode(field) - if err != nil { - return nil, err - } + digest.Write(intToByte(int32Buf, len(field.Name))) + digest.Write([]byte(field.Name)) + digest.Write(intToByte(int32Buf, int(field.Type.Code.Number()))) } - res := sha256.Sum256(buffer.Bytes()) - return &res, nil + return digest, nil } // retry implements retriableStatement.retry for queries. It will execute the @@ -167,8 +195,6 @@ func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sp // initial iterator was also returned by the new iterator, and that the errors // were returned by the same row index. func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - buffer := &bytes.Buffer{} - enc := gob.NewEncoder(buffer) retryIt := tx.QueryWithOptions(ctx, it.stmt, it.options) // If the original iterator had been stopped, we should also always stop the // new iterator. @@ -193,12 +219,13 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS // Iterate over the new result set as many times as we iterated over the initial // result set. The checksums of the two should be equal. Also, the new result set // should return any error on the same index as the original. - var newChecksum *[32]byte + // var newChecksum *[32]byte var checksumErr error + newHash := sha256.New() for n := int64(0); n < it.nc; n++ { row, err := retryIt.Next() if n == 0 && (err == nil || err == iterator.Done) { - newChecksum, checksumErr = createMetadataChecksum(enc, buffer, retryIt.Metadata) + newHash, checksumErr = createMetadataChecksum(retryIt.Metadata) if checksumErr != nil { return failRetry(checksumErr) } @@ -211,14 +238,14 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS } if errorsEqualForRetry(err, it.err) && n == it.errIndex { // Check that the checksums are also equal. - if !checksumsEqual(newChecksum, it.checksum) { + if !checksumsEqual(newHash, it.hash) { return failRetry(ErrAbortedDueToConcurrentModification) } return replaceIt(nil) } return failRetry(ErrAbortedDueToConcurrentModification) } - newChecksum, err = updateChecksum(enc, buffer, newChecksum, row) + err = updateChecksum(newHash, row) if err != nil { return failRetry(err) } @@ -230,7 +257,7 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS if it.err != nil { return failRetry(ErrAbortedDueToConcurrentModification) } - if !checksumsEqual(newChecksum, it.checksum) { + if !checksumsEqual(newHash, it.hash) { return failRetry(ErrAbortedDueToConcurrentModification) } // Everything seems to be equal, replace the underlying iterator and return @@ -238,8 +265,13 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS return replaceIt(nil) } -func checksumsEqual(c1, c2 *[32]byte) bool { - return (reflect.ValueOf(c1).IsNil() && reflect.ValueOf(c2).IsNil()) || *c1 == *c2 +func checksumsEqual(h1, h2 hash.Hash) bool { + if reflect.ValueOf(h1).IsNil() && reflect.ValueOf(h2).IsNil() { + return true + } + c1 := h1.Sum(nil) + c2 := h2.Sum(nil) + return bytes.Equal(c1, c2) } func (it *checksumRowIterator) Stop() { diff --git a/checksum_row_iterator_test.go b/checksum_row_iterator_test.go index cf31f36d..cc853919 100644 --- a/checksum_row_iterator_test.go +++ b/checksum_row_iterator_test.go @@ -16,23 +16,20 @@ package spannerdriver import ( "bytes" - "encoding/gob" + "crypto/sha256" + "fmt" "math/big" "testing" "time" "cloud.google.com/go/civil" "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/protobuf/types/known/structpb" ) func TestUpdateChecksum(t *testing.T) { - buffer1 := &bytes.Buffer{} - enc1 := gob.NewEncoder(buffer1) - buffer2 := &bytes.Buffer{} - enc2 := gob.NewEncoder(buffer2) - buffer3 := &bytes.Buffer{} - enc3 := gob.NewEncoder(buffer3) - row1, err := spanner.NewRow( []string{ "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", @@ -110,59 +107,69 @@ func TestUpdateChecksum(t *testing.T) { if err != nil { t.Fatalf("could not create row 3: %v", err) } - initial1 := new([32]byte) - checksum1, err := updateChecksum(enc1, buffer1, initial1, row1) + + hash1 := sha256.New() + err = updateChecksum(hash1, row1) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } - initial2 := new([32]byte) - checksum2, err := updateChecksum(enc2, buffer2, initial2, row2) + checksum1 := hash1.Sum(nil) + + hash2 := sha256.New() + err = updateChecksum(hash2, row2) if err != nil { t.Fatalf("could not calculate checksum 2: %v", err) } - initial3 := new([32]byte) - checksum3, err := updateChecksum(enc3, buffer3, initial3, row3) + checksum2 := hash2.Sum(nil) + + hash3 := sha256.New() + err = updateChecksum(hash3, row3) if err != nil { t.Fatalf("could not calculate checksum 3: %v", err) } // row1 and row2 are different, so the checksums should be different. - if *checksum1 == *checksum2 { + checksum3 := hash3.Sum(nil) + + if bytes.Equal(checksum1, checksum2) { t.Fatalf("checksum1 should not be equal to checksum2") } // row1 and row3 are equal, and should return the same checksum. - if *checksum1 != *checksum3 { + if !bytes.Equal(checksum1, checksum3) { t.Fatalf("checksum1 should be equal to checksum3") } // Updating checksums 1 and 3 with the data from row 2 should also produce // the same checksum. - checksum1_2, err := updateChecksum(enc1, buffer1, checksum1, row2) + err = updateChecksum(hash1, row2) if err != nil { t.Fatalf("could not calculate checksum 1_2: %v", err) } - checksum3_2, err := updateChecksum(enc3, buffer3, checksum3, row2) + checksum1_2 := hash1.Sum(nil) + + err = updateChecksum(hash3, row2) if err != nil { - t.Fatalf("could not calculate checksum 1_2: %v", err) + t.Fatalf("could not calculate checksum 3_2: %v", err) } - if *checksum1_2 != *checksum3_2 { + checksum3_2 := hash3.Sum(nil) + + if !bytes.Equal(checksum1_2, checksum3_2) { t.Fatalf("checksum1_2 should be equal to checksum3_2") } // The combination of row 3 and 2 will produce a different checksum than the // combination 2 and 3, because they are in a different order. - checksum2_3, err := updateChecksum(enc2, buffer2, checksum2, row3) + err = updateChecksum(hash2, row3) if err != nil { t.Fatalf("could not calculate checksum 2_3: %v", err) } - if *checksum2_3 == *checksum3_2 { + checksum2_3 := hash2.Sum(nil) + + if bytes.Equal(checksum2_3, checksum3_2) { t.Fatalf("checksum2_3 should not be equal to checksum3_2") } } func TestUpdateChecksumForNullValues(t *testing.T) { - buffer := &bytes.Buffer{} - enc := gob.NewEncoder(buffer) - row, err := spanner.NewRow( []string{ "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", @@ -180,26 +187,155 @@ func TestUpdateChecksumForNullValues(t *testing.T) { if err != nil { t.Fatalf("could not create row: %v", err) } - initial := new([32]byte) + hash1 := sha256.New() + initial := hash1.Sum(nil) // Create the initial checksum. - checksum, err := updateChecksum(enc, buffer, initial, row) + err = updateChecksum(hash1, row) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } + checksum1 := hash1.Sum(nil) // The calculated checksum should not be equal to the initial value, even though it only // contains null values. - if *checksum == *initial { + if bytes.Equal(initial, checksum1) { t.Fatalf("checksum value should not be equal to the initial value") } // Calculating the same checksum again should yield the same result. - buffer2 := &bytes.Buffer{} - enc2 := gob.NewEncoder(buffer2) - initial2 := new([32]byte) - checksum2, err := updateChecksum(enc2, buffer2, initial2, row) + hash2 := sha256.New() + err = updateChecksum(hash2, row) if err != nil { t.Fatalf("failed to update checksum: %v", err) } - if *checksum != *checksum2 { + checksum2 := hash2.Sum(nil) + if !bytes.Equal(checksum1, checksum2) { t.Fatalf("recalculated checksum does not match the initial calculation") } } + +func BenchmarkChecksumRowIterator(b *testing.B) { + row1, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(1), 3.14, numeric("6.626"), "test", []byte("testbytes"), civil.Date{Year: 2021, Month: 8, Day: 5}, + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "value", "other-key": ["value1", "value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1", "test2"}, [][]byte{[]byte("testbytes1"), []byte("testbytes1")}, + []civil.Date{{Year: 2021, Month: 8, Day: 5}, {Year: 2021, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2021, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "value1", "other-key1": ["value1", "value2"]}`), + nullJson(true, `"key2": "value2", "other-key2": ["value1", "value2"]}`), + }, + }, + ) + row2, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(2), 6.626, numeric("3.14"), "test2", []byte("testbytes2"), civil.Date{Year: 2020, Month: 8, Day: 5}, + time.Date(2020, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "other-value", "other-key": ["other-value1", "other-value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1_", "test2_"}, [][]byte{[]byte("testbytes1_"), []byte("testbytes1_")}, + []civil.Date{{Year: 2020, Month: 8, Day: 5}, {Year: 2020, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2020, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2020, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "other-value1", "other-key1": ["other-value1", "other-value2"]}`), + nullJson(true, `"key2": "other-value2", "other-key2": ["other-value1", "other-value2"]}`), + }, + }, + ) + row3, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(1), 3.14, numeric("6.626"), "test", []byte("testbytes"), civil.Date{Year: 2021, Month: 8, Day: 5}, + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "value", "other-key": ["value1", "value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1", "test2"}, [][]byte{[]byte("testbytes1"), []byte("testbytes1")}, + []civil.Date{{Year: 2021, Month: 8, Day: 5}, {Year: 2021, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2021, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "value1", "other-key1": ["value1", "value2"]}`), + nullJson(true, `"key2": "value2", "other-key2": ["value1", "value2"]}`), + }, + }, + ) + + for b.Loop() { + hash := sha256.New() + if err := updateChecksum(hash, row1); err != nil { + b.Fatal(err) + } + if err := updateChecksum(hash, row2); err != nil { + b.Fatal(err) + } + if err := updateChecksum(hash, row3); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkChecksumRowIteratorRandom(b *testing.B) { + for _, numRows := range []int{1, 10, 100, 1000, 10000} { + resultSet := testutil.CreateRandomResultSet(numRows) + columnNames := make([]string, len(resultSet.Metadata.RowType.Fields)) + for i := range columnNames { + columnNames[i] = resultSet.Metadata.RowType.Fields[i].Name + } + var err error + rows := make([]*spanner.Row, numRows) + for row, values := range resultSet.Rows { + columnValues := convertListValuesToColumnValues(resultSet.Metadata, values) + c := make([]interface{}, len(columnValues)) + for i := range columnValues { + c[i] = columnValues[i] + } + rows[row], err = spanner.NewRow(columnNames, c) + if err != nil { + b.Fatal(err) + } + } + + b.Run(fmt.Sprintf("num-rows-%d", numRows), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + hash, err := createMetadataChecksum(resultSet.Metadata) + if err != nil { + b.Fatal(err) + } + for _, row := range rows { + if err := updateChecksum(hash, row); err != nil { + b.Fatal(err) + } + } + } + }) + } +} + +func convertListValuesToColumnValues(metadata *spannerpb.ResultSetMetadata, values *structpb.ListValue) []spanner.GenericColumnValue { + res := make([]spanner.GenericColumnValue, len(values.Values)) + for i := range values.Values { + res[i] = spanner.GenericColumnValue{Value: values.Values[i], Type: metadata.RowType.Fields[i].Type} + } + return res +} diff --git a/testutil/mocked_inmem_server.go b/testutil/mocked_inmem_server.go index 1bb70452..97ba4997 100644 --- a/testutil/mocked_inmem_server.go +++ b/testutil/mocked_inmem_server.go @@ -15,17 +15,21 @@ package testutil import ( + crypto "crypto/rand" "encoding/base64" "fmt" "math" + "math/rand" "net" "strconv" "testing" + "time" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" "cloud.google.com/go/spanner/apiv1/spannerpb" pb "cloud.google.com/go/spanner/testdata/protos" + "github.com/google/uuid" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -236,7 +240,7 @@ func createSingersRow(idx int64) *structpb.ListValue { } } -func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb.ResultSet { +func CreateResultSetMetadataWithAllTypes() *spannerpb.ResultSetMetadata { index := 0 fields := make([]*spannerpb.StructType_Field, 26) fields[index] = &spannerpb.StructType_Field{ @@ -410,9 +414,14 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb rowType := &spannerpb.StructType{ Fields: fields, } - metadata := &spannerpb.ResultSetMetadata{ + return &spannerpb.ResultSetMetadata{ RowType: rowType, } +} + +func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb.ResultSet { + metadata := CreateResultSetMetadataWithAllTypes() + fields := metadata.RowType.Fields rows := make([]*structpb.ListValue, 1) rowValue := make([]*structpb.Value, len(fields)) if nullValues { @@ -436,7 +445,7 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb Genre: &singer2ProtoEnum, } - index = 0 + index := 0 rowValue[index] = &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: true}} index++ rowValue[index] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "test"}} @@ -576,6 +585,122 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb } } +func CreateRandomResultSet(numRows int) *spannerpb.ResultSet { + metadata := CreateResultSetMetadataWithAllTypes() + fields := metadata.RowType.Fields + rows := make([]*structpb.ListValue, numRows) + + for i := 0; i < numRows; i++ { + rowValue := make([]*structpb.Value, len(fields)) + for col := range fields { + rowValue[col] = randomValue(fields[col].Type) + } + rows[i] = &structpb.ListValue{Values: rowValue} + } + return &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } +} + +var nullValue *structpb.Value + +func init() { + nullValue = &structpb.Value{Kind: &structpb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}} +} + +func randomValue(t *spannerpb.Type) *structpb.Value { + if rand.Intn(10) == 5 { + return nullValue + } + switch t.Code { + case spannerpb.TypeCode_BOOL: + return randomBoolValue() + case spannerpb.TypeCode_BYTES: + return randomBytesValue() + case spannerpb.TypeCode_DATE: + return randomDateValue() + case spannerpb.TypeCode_FLOAT32: + return randomFloat32Value() + case spannerpb.TypeCode_FLOAT64: + return randomFloat64Value() + case spannerpb.TypeCode_INT64: + return randomInt64Value() + case spannerpb.TypeCode_JSON: + return randomJsonValue() + case spannerpb.TypeCode_NUMERIC: + return randomNumericValue() + case spannerpb.TypeCode_STRING: + return randomStringValue() + case spannerpb.TypeCode_TIMESTAMP: + return randomTimestampValue() + case spannerpb.TypeCode_UUID: + return randomUuidValue() + case spannerpb.TypeCode_ARRAY: + numElements := rand.Intn(10) + value := &structpb.Value{Kind: &structpb.Value_ListValue{ListValue: &structpb.ListValue{Values: make([]*structpb.Value, numElements)}}} + for i := range numElements { + value.GetListValue().Values[i] = randomValue(t.ArrayElementType) + } + return value + } + return nullValue +} + +func randomBoolValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: rand.Intn(2) == 1}} +} + +func randomString() string { + b := make([]byte, rand.Intn(1024)) + _, _ = crypto.Read(b) + return base64.StdEncoding.EncodeToString(b) +} + +func randomBytesValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: randomString()}} +} + +func randomDateValue() *structpb.Value { + year := rand.Intn(2100) + 1 + month := rand.Intn(12) + 1 + day := rand.Intn(28) + 1 + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%04d-%02d-%02d", year, month, day)}} +} + +func randomFloat32Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(rand.Float32())}} +} + +func randomFloat64Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(rand.Float32())}} +} + +func randomJsonValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf(`{"key": "%s"}`, randomString())}} +} + +func randomInt64Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%d", rand.Int63())}} +} + +func randomNumericValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%d.%d", rand.Intn(10000000), rand.Intn(1000))}} +} + +func randomStringValue() *structpb.Value { + return randomBytesValue() +} + +func randomTimestampValue() *structpb.Value { + t := time.UnixMilli(time.Now().UnixMilli() + int64(rand.Intn(1000000))) + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: t.Format(time.RFC3339)}} +} + +func randomUuidValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: uuid.New().String()}} +} + func nullValueOrAlt(nullValue bool, alt *structpb.Value) *structpb.Value { if nullValue { return &structpb.Value{Kind: &structpb.Value_NullValue{}} diff --git a/transaction.go b/transaction.go index 041c84a3..0affbc5a 100644 --- a/transaction.go +++ b/transaction.go @@ -15,10 +15,9 @@ package spannerdriver import ( - "bytes" "context" + "crypto/sha256" "database/sql/driver" - "encoding/gob" "fmt" "log/slog" "math/rand" @@ -625,7 +624,6 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen // If retries are enabled, we need to use a row iterator that will keep // track of a running checksum of all the results that we see. - buffer := &bytes.Buffer{} it := &checksumRowIterator{ RowIterator: tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), ctx: ctx, @@ -633,8 +631,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen stmt: stmt, stmtType: stmtType, options: execOptions.QueryOptions, - buffer: buffer, - enc: gob.NewEncoder(buffer), + hash: sha256.New(), } tx.statements = append(tx.statements, it) return it, nil From 8d531dd5fb72b44faa83df6d7374cb892f57037a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 9 Dec 2025 16:32:23 +0100 Subject: [PATCH 2/5] chore: improve nil check Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- checksum_row_iterator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 6637f1ea..886aa93b 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -266,8 +266,8 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS } func checksumsEqual(h1, h2 hash.Hash) bool { - if reflect.ValueOf(h1).IsNil() && reflect.ValueOf(h2).IsNil() { - return true + if h1 == nil || h2 == nil { + return h1 == h2 } c1 := h1.Sum(nil) c2 := h2.Sum(nil) From 0751ee7610bd19a533206fb2abec4bd04233f1e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 9 Dec 2025 16:33:07 +0100 Subject: [PATCH 3/5] chore: remove redundant identifier Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- checksum_row_iterator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 886aa93b..a0b74727 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -149,7 +149,7 @@ func hashValue(value *structpb.Value, digest hash.Hash) { } case *structpb.Value_StructValue: fields := make([]string, 0, len(value.GetStructValue().Fields)) - for field, _ := range value.GetStructValue().Fields { +for field := range value.GetStructValue().Fields { fields = append(fields, field) } sort.Strings(fields) From 2983ed9a71121355f635c2baa6bc8c8ef32d7b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 9 Dec 2025 17:32:47 +0100 Subject: [PATCH 4/5] chore: address review comments --- testutil/mocked_inmem_server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testutil/mocked_inmem_server.go b/testutil/mocked_inmem_server.go index 97ba4997..c76376f1 100644 --- a/testutil/mocked_inmem_server.go +++ b/testutil/mocked_inmem_server.go @@ -673,7 +673,7 @@ func randomFloat32Value() *structpb.Value { } func randomFloat64Value() *structpb.Value { - return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(rand.Float32())}} + return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: rand.Float64()}} } func randomJsonValue() *structpb.Value { From cf60b9b3ef786ca01e7ef2b7644807c93177c593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 10 Dec 2025 18:57:51 +0100 Subject: [PATCH 5/5] fix: move buffers to iterator to prevent concurrent use The int and float buffers were defined as global variables, which meant that they could have been used concurrently if multiple checksumIterators were used at the same time. --- checksum_row_iterator.go | 37 ++++++++++---------- checksum_row_iterator_test.go | 30 ++++++++++------- driver_with_mockserver_test.go | 22 ++++++------ testutil/mocked_inmem_server.go | 22 ++++++------ transaction.go | 3 ++ transaction_test.go | 60 +++++++++++++++++++++++++++++++++ 6 files changed, 121 insertions(+), 53 deletions(-) diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 040966c2..32cedd14 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -61,7 +61,9 @@ type checksumRowIterator struct { // hash contains the current hash for the results that have been // seen. It is calculated as a SHA256 checksum over all rows that so far // have been returned. - hash hash.Hash + hash hash.Hash + int32Buf [4]byte + float64Buf [8]byte // errIndex and err indicate any error and the index in the result set // where the error occurred. @@ -100,7 +102,7 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { // checksum of the columns that are included in this result. This is // also used to detect the possible difference between two empty // result sets with a different set of columns. - it.hash, err = createMetadataChecksum(it.metadata) + it.hash, err = it.createMetadataChecksum(it.metadata) if err != nil { return err } @@ -109,37 +111,34 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { return it.err } // Update the current checksum. - return updateChecksum(it.hash, row) + return it.updateChecksum(it.hash, row) }) return row, err } // updateChecksum calculates the following checksum based on a current checksum // and a new row. -func updateChecksum(hash hash.Hash, row *spanner.Row) error { +func (it *checksumRowIterator) updateChecksum(hash hash.Hash, row *spanner.Row) error { for i := 0; i < row.Size(); i++ { var v spanner.GenericColumnValue err := row.Column(i, &v) if err != nil { return err } - hashValue(v.Value, hash) + it.hashValue(v.Value, hash) } return nil } -var int32Buf [4]byte -var float64Buf [8]byte - -func hashValue(value *structpb.Value, digest hash.Hash) { +func (it *checksumRowIterator) hashValue(value *structpb.Value, digest hash.Hash) { switch value.GetKind().(type) { case *structpb.Value_StringValue: - digest.Write(intToByte(int32Buf, len(value.GetStringValue()))) + digest.Write(intToByte(it.int32Buf, len(value.GetStringValue()))) digest.Write([]byte(value.GetStringValue())) case *structpb.Value_NullValue: digest.Write([]byte{0}) case *structpb.Value_NumberValue: - digest.Write(float64ToByte(float64Buf, value.GetNumberValue())) + digest.Write(float64ToByte(it.float64Buf, value.GetNumberValue())) case *structpb.Value_BoolValue: if value.GetBoolValue() { digest.Write([]byte{1}) @@ -153,13 +152,13 @@ func hashValue(value *structpb.Value, digest hash.Hash) { } sort.Strings(fields) for _, field := range fields { - digest.Write(intToByte(int32Buf, len(field))) + digest.Write(intToByte(it.int32Buf, len(field))) digest.Write([]byte(field)) - hashValue(value.GetStructValue().Fields[field], digest) + it.hashValue(value.GetStructValue().Fields[field], digest) } case *structpb.Value_ListValue: for _, v := range value.GetListValue().GetValues() { - hashValue(v, digest) + it.hashValue(v, digest) } } } @@ -177,12 +176,12 @@ func float64ToByte(buf [8]byte, f float64) []byte { // createMetadataChecksum calculates the checksum of the metadata of a result. // Only the column names and types are included in the checksum. Any transaction // metadata is not included. -func createMetadataChecksum(metadata *sppb.ResultSetMetadata) (hash.Hash, error) { +func (it *checksumRowIterator) createMetadataChecksum(metadata *sppb.ResultSetMetadata) (hash.Hash, error) { digest := sha256.New() for _, field := range metadata.RowType.Fields { - digest.Write(intToByte(int32Buf, len(field.Name))) + digest.Write(intToByte(it.int32Buf, len(field.Name))) digest.Write([]byte(field.Name)) - digest.Write(intToByte(int32Buf, int(field.Type.Code.Number()))) + digest.Write(intToByte(it.int32Buf, int(field.Type.Code.Number()))) } return digest, nil } @@ -224,7 +223,7 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS for n := int64(0); n < it.nc; n++ { row, err := retryIt.Next() if n == 0 && (err == nil || err == iterator.Done) { - newHash, checksumErr = createMetadataChecksum(retryIt.Metadata) + newHash, checksumErr = it.createMetadataChecksum(retryIt.Metadata) if checksumErr != nil { return failRetry(checksumErr) } @@ -244,7 +243,7 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS } return failRetry(ErrAbortedDueToConcurrentModification) } - err = updateChecksum(newHash, row) + err = it.updateChecksum(newHash, row) if err != nil { return failRetry(err) } diff --git a/checksum_row_iterator_test.go b/checksum_row_iterator_test.go index cc853919..7548a0ed 100644 --- a/checksum_row_iterator_test.go +++ b/checksum_row_iterator_test.go @@ -108,22 +108,23 @@ func TestUpdateChecksum(t *testing.T) { t.Fatalf("could not create row 3: %v", err) } + it := &checksumRowIterator{} hash1 := sha256.New() - err = updateChecksum(hash1, row1) + err = it.updateChecksum(hash1, row1) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } checksum1 := hash1.Sum(nil) hash2 := sha256.New() - err = updateChecksum(hash2, row2) + err = it.updateChecksum(hash2, row2) if err != nil { t.Fatalf("could not calculate checksum 2: %v", err) } checksum2 := hash2.Sum(nil) hash3 := sha256.New() - err = updateChecksum(hash3, row3) + err = it.updateChecksum(hash3, row3) if err != nil { t.Fatalf("could not calculate checksum 3: %v", err) } @@ -140,13 +141,13 @@ func TestUpdateChecksum(t *testing.T) { // Updating checksums 1 and 3 with the data from row 2 should also produce // the same checksum. - err = updateChecksum(hash1, row2) + err = it.updateChecksum(hash1, row2) if err != nil { t.Fatalf("could not calculate checksum 1_2: %v", err) } checksum1_2 := hash1.Sum(nil) - err = updateChecksum(hash3, row2) + err = it.updateChecksum(hash3, row2) if err != nil { t.Fatalf("could not calculate checksum 3_2: %v", err) } @@ -158,7 +159,7 @@ func TestUpdateChecksum(t *testing.T) { // The combination of row 3 and 2 will produce a different checksum than the // combination 2 and 3, because they are in a different order. - err = updateChecksum(hash2, row3) + err = it.updateChecksum(hash2, row3) if err != nil { t.Fatalf("could not calculate checksum 2_3: %v", err) } @@ -187,10 +188,11 @@ func TestUpdateChecksumForNullValues(t *testing.T) { if err != nil { t.Fatalf("could not create row: %v", err) } + it := &checksumRowIterator{} hash1 := sha256.New() initial := hash1.Sum(nil) // Create the initial checksum. - err = updateChecksum(hash1, row) + err = it.updateChecksum(hash1, row) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } @@ -202,7 +204,7 @@ func TestUpdateChecksumForNullValues(t *testing.T) { } // Calculating the same checksum again should yield the same result. hash2 := sha256.New() - err = updateChecksum(hash2, row) + err = it.updateChecksum(hash2, row) if err != nil { t.Fatalf("failed to update checksum: %v", err) } @@ -281,14 +283,15 @@ func BenchmarkChecksumRowIterator(b *testing.B) { ) for b.Loop() { + it := &checksumRowIterator{} hash := sha256.New() - if err := updateChecksum(hash, row1); err != nil { + if err := it.updateChecksum(hash, row1); err != nil { b.Fatal(err) } - if err := updateChecksum(hash, row2); err != nil { + if err := it.updateChecksum(hash, row2); err != nil { b.Fatal(err) } - if err := updateChecksum(hash, row3); err != nil { + if err := it.updateChecksum(hash, row3); err != nil { b.Fatal(err) } } @@ -318,12 +321,13 @@ func BenchmarkChecksumRowIteratorRandom(b *testing.B) { b.Run(fmt.Sprintf("num-rows-%d", numRows), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - hash, err := createMetadataChecksum(resultSet.Metadata) + it := &checksumRowIterator{} + hash, err := it.createMetadataChecksum(resultSet.Metadata) if err != nil { b.Fatal(err) } for _, row := range rows { - if err := updateChecksum(hash, row); err != nil { + if err := it.updateChecksum(hash, row); err != nil { b.Fatal(err) } } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index fc0a9fbb..f2120609 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -5636,19 +5636,19 @@ func nullUuid(valid bool, v string) spanner.NullUUID { return spanner.NullUUID{Valid: true, UUID: uuid.MustParse(v)} } -func setupTestDBConnection(t *testing.T) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnection(t testing.TB) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParams(t, "") } -func setupTestDBConnectionWithDialect(t *testing.T, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithDialect(t testing.TB, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParamsAndDialect(t, "", dialect) } -func setupTestDBConnectionWithParams(t *testing.T, params string) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithParams(t testing.TB, params string) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParamsAndDialect(t, params, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) } -func setupTestDBConnectionWithParamsAndDialect(t *testing.T, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithParamsAndDialect(t testing.TB, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServerWithDialect(t, dialect) db, err := sql.Open( "spanner", @@ -5663,7 +5663,7 @@ func setupTestDBConnectionWithParamsAndDialect(t *testing.T, params string, dial } } -func setupTestDBConnectionWithConfigurator(t *testing.T, params string, configurator func(config *spanner.ClientConfig, opts *[]option.ClientOption)) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithConfigurator(t testing.TB, params string, configurator func(config *spanner.ClientConfig, opts *[]option.ClientOption)) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServer(t) dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;%s", server.Address, params) config, err := ExtractConnectorConfig(dsn) @@ -5684,7 +5684,7 @@ func setupTestDBConnectionWithConfigurator(t *testing.T, params string, configur } } -func setupTestDBConnectionWithConnectorConfig(t *testing.T, config ConnectorConfig) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithConnectorConfig(t testing.TB, config ConnectorConfig) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServer(t) config.Host = server.Address if config.Params == nil { @@ -5703,23 +5703,23 @@ func setupTestDBConnectionWithConnectorConfig(t *testing.T, config ConnectorConf } } -func setupMockedTestServer(t *testing.T) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServer(t testing.TB) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfig(t, spanner.ClientConfig{}) } -func setupMockedTestServerWithDialect(t *testing.T, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithDialect(t testing.TB, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptionsAndDialect(t, spanner.ClientConfig{}, []option.ClientOption{}, dialect) } -func setupMockedTestServerWithConfig(t *testing.T, config spanner.ClientConfig) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfig(t testing.TB, config spanner.ClientConfig) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{}) } -func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config spanner.ClientConfig, clientOptions []option.ClientOption) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfigAndClientOptions(t testing.TB, config spanner.ClientConfig, clientOptions []option.ClientOption) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptionsAndDialect(t, config, clientOptions, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) } -func setupMockedTestServerWithConfigAndClientOptionsAndDialect(t *testing.T, config spanner.ClientConfig, clientOptions []option.ClientOption, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfigAndClientOptionsAndDialect(t testing.TB, config spanner.ClientConfig, clientOptions []option.ClientOption, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { server, opts, serverTeardown := testutil.NewMockedSpannerInMemTestServer(t) server.SetupSelectDialectResult(dialect) diff --git a/testutil/mocked_inmem_server.go b/testutil/mocked_inmem_server.go index c76376f1..eda7ef6b 100644 --- a/testutil/mocked_inmem_server.go +++ b/testutil/mocked_inmem_server.go @@ -90,14 +90,14 @@ type MockedSpannerInMemTestServer struct { // NewMockedSpannerInMemTestServer creates a MockedSpannerInMemTestServer at // localhost with a random port and returns client options that can be used // to connect to it. -func NewMockedSpannerInMemTestServer(t *testing.T) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { +func NewMockedSpannerInMemTestServer(t testing.TB) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { return NewMockedSpannerInMemTestServerWithAddr(t, "localhost:0") } // NewMockedSpannerInMemTestServerWithAddr creates a MockedSpannerInMemTestServer // at a given listening address and returns client options that can be used // to connect to it. -func NewMockedSpannerInMemTestServerWithAddr(t *testing.T, addr string) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { +func NewMockedSpannerInMemTestServerWithAddr(t testing.TB, addr string) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { mockedServer = &MockedSpannerInMemTestServer{} opts = mockedServer.setupMockedServerWithAddr(t, addr) return mockedServer, opts, func() { @@ -108,7 +108,7 @@ func NewMockedSpannerInMemTestServerWithAddr(t *testing.T, addr string) (mockedS } } -func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, addr string) []option.ClientOption { +func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t testing.TB, addr string) []option.ClientOption { s.TestSpanner = NewInMemSpannerServer() s.TestInstanceAdmin = NewInMemInstanceAdminServer() s.TestDatabaseAdmin = NewInMemDatabaseAdminServer() @@ -128,7 +128,9 @@ func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, a if err != nil { t.Fatal(err) } - go s.server.Serve(lis) + go func() { + _ = s.server.Serve(lis) + }() s.Address = lis.Addr().String() opts := []option.ClientOption{ @@ -141,23 +143,23 @@ func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, a func (s *MockedSpannerInMemTestServer) SetupSelectDialectResult(dialect databasepb.DatabaseDialect) { result := &StatementResult{Type: StatementResultResultSet, ResultSet: CreateSelectDialectResultSet(dialect)} - s.TestSpanner.PutStatementResult(selectDialect, result) + _ = s.TestSpanner.PutStatementResult(selectDialect, result) } func (s *MockedSpannerInMemTestServer) setupSelect1Result() { result := &StatementResult{Type: StatementResultResultSet, ResultSet: CreateSelect1ResultSet()} - s.TestSpanner.PutStatementResult("SELECT 1", result) + _ = s.TestSpanner.PutStatementResult("SELECT 1", result) } func (s *MockedSpannerInMemTestServer) setupFooResults() { resultSet := CreateSingleColumnInt64ResultSet(selectFooFromBarResults, "FOO") result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} - s.TestSpanner.PutStatementResult(SelectFooFromBar, result) - s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ + _ = s.TestSpanner.PutStatementResult(SelectFooFromBar, result) + _ = s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ Type: StatementResultUpdateCount, UpdateCount: UpdateBarSetFooRowCount, }) - s.TestSpanner.PutStatementResult(UpdateSingersSetLastName, &StatementResult{ + _ = s.TestSpanner.PutStatementResult(UpdateSingersSetLastName, &StatementResult{ Type: StatementResultUpdateCount, UpdateCount: UpdateSingersSetLastNameRowCount, }) @@ -175,7 +177,7 @@ func (s *MockedSpannerInMemTestServer) setupSingersResults() { Rows: rows, } result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} - s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) + _ = s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) } // CreateSingleRowSingersResult creates a result set containing a single row of diff --git a/transaction.go b/transaction.go index 0affbc5a..c256d20b 100644 --- a/transaction.go +++ b/transaction.go @@ -498,6 +498,9 @@ func (tx *readWriteTransaction) runWithRetry(ctx context.Context, f func(ctx con if err == nil { err = f(ctx) } + if err == nil { + return + } if err == ErrAbortedDueToConcurrentModification { tx.logger.Log(ctx, LevelNotice, "transaction retry failed due to a concurrent modification") return diff --git a/transaction_test.go b/transaction_test.go index ddd6d05a..7634ae21 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -357,3 +357,63 @@ func TestTransactionTimeoutSecondStatement(t *testing.T) { t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w) } } + +func BenchmarkReadWriteTransaction(b *testing.B) { + db, server, teardown := setupTestDBConnection(b) + defer teardown() + ctx := context.Background() + query := "select * from random_table" + + for _, numRows := range []int{1, 10, 100, 1000, 10000} { + resultSet := testutil.CreateRandomResultSet(numRows) + _ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + }) + + b.Run(fmt.Sprintf("num-rows-%d", numRows), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + b.Fatal(err) + } + //if _, err := tx.ExecContext(ctx, "set local retry_aborts_internally = false"); err != nil { + // b.Fatal(err) + //} + if _, err := tx.ExecContext(ctx, "set local transaction_tag = 'my_tag'"); err != nil { + b.Fatal(err) + } + rows, err := tx.QueryContext(ctx, query) + if err != nil { + b.Fatal(err) + } + for rows.Next() { + // Just iterate through the results + } + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + for range 10 { + if _, err = tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + b.Fatal(err) + } + } + if _, err := tx.ExecContext(ctx, "start batch dml"); err != nil { + b.Fatal(err) + } + for range 10 { + if _, err = tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + b.Fatal(err) + } + } + if _, err := tx.ExecContext(ctx, "run batch"); err != nil { + b.Fatal(err) + } + if err := tx.Commit(); err != nil { + b.Fatal(err) + } + } + }) + } +}