diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index c6980a27..32cedd14 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -18,8 +18,10 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/gob" - "reflect" + "encoding/binary" + "hash" + "math" + "sort" "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" @@ -32,16 +34,6 @@ import ( var errNextAfterSTop = status.Errorf(codes.FailedPrecondition, "Next called after Stop") -// init registers the protobuf types with gob so they can be encoded. -func init() { - gob.Register(structpb.Value_BoolValue{}) - gob.Register(structpb.Value_NumberValue{}) - gob.Register(structpb.Value_StringValue{}) - gob.Register(structpb.Value_NullValue{}) - gob.Register(structpb.Value_ListValue{}) - gob.Register(structpb.Value_StructValue{}) -} - var _ rowIterator = &checksumRowIterator{} // checksumRowIterator implements rowIterator and keeps track of a running @@ -66,12 +58,12 @@ type checksumRowIterator struct { // the retry has finished. stopped bool - // checksum contains the current checksum for the results that have been + // hash contains the current hash for the results that have been // seen. It is calculated as a SHA256 checksum over all rows that so far // have been returned. - checksum *[32]byte - buffer *bytes.Buffer - enc *gob.Encoder + hash hash.Hash + int32Buf [4]byte + float64Buf [8]byte // errIndex and err indicate any error and the index in the result set // where the error occurred. @@ -110,7 +102,7 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { // checksum of the columns that are included in this result. This is // also used to detect the possible difference between two empty // result sets with a different set of columns. - it.checksum, err = createMetadataChecksum(it.enc, it.buffer, it.metadata) + it.hash, err = it.createMetadataChecksum(it.metadata) if err != nil { return err } @@ -119,45 +111,79 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) { return it.err } // Update the current checksum. - it.checksum, err = updateChecksum(it.enc, it.buffer, it.checksum, row) - return err + return it.updateChecksum(it.hash, row) }) return row, err } // updateChecksum calculates the following checksum based on a current checksum // and a new row. -func updateChecksum(enc *gob.Encoder, buffer *bytes.Buffer, currentChecksum *[32]byte, row *spanner.Row) (*[32]byte, error) { - buffer.Reset() - buffer.Write(currentChecksum[:]) +func (it *checksumRowIterator) updateChecksum(hash hash.Hash, row *spanner.Row) error { for i := 0; i < row.Size(); i++ { var v spanner.GenericColumnValue err := row.Column(i, &v) if err != nil { - return nil, err + return err } - err = enc.Encode(v) - if err != nil { - return nil, err + it.hashValue(v.Value, hash) + } + return nil +} + +func (it *checksumRowIterator) hashValue(value *structpb.Value, digest hash.Hash) { + switch value.GetKind().(type) { + case *structpb.Value_StringValue: + digest.Write(intToByte(it.int32Buf, len(value.GetStringValue()))) + digest.Write([]byte(value.GetStringValue())) + case *structpb.Value_NullValue: + digest.Write([]byte{0}) + case *structpb.Value_NumberValue: + digest.Write(float64ToByte(it.float64Buf, value.GetNumberValue())) + case *structpb.Value_BoolValue: + if value.GetBoolValue() { + digest.Write([]byte{1}) + } else { + digest.Write([]byte{0}) + } + case *structpb.Value_StructValue: + fields := make([]string, 0, len(value.GetStructValue().Fields)) + for field := range value.GetStructValue().Fields { + fields = append(fields, field) + } + sort.Strings(fields) + for _, field := range fields { + digest.Write(intToByte(it.int32Buf, len(field))) + digest.Write([]byte(field)) + it.hashValue(value.GetStructValue().Fields[field], digest) + } + case *structpb.Value_ListValue: + for _, v := range value.GetListValue().GetValues() { + it.hashValue(v, digest) } } - res := sha256.Sum256(buffer.Bytes()) - return &res, nil +} + +func intToByte(buf [4]byte, v int) []byte { + binary.BigEndian.PutUint32(buf[:], uint32(v)) + return buf[:] +} + +func float64ToByte(buf [8]byte, f float64) []byte { + binary.BigEndian.PutUint64(buf[:], math.Float64bits(f)) + return buf[:] } // createMetadataChecksum calculates the checksum of the metadata of a result. // Only the column names and types are included in the checksum. Any transaction // metadata is not included. -func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sppb.ResultSetMetadata) (*[32]byte, error) { - buffer.Reset() +func (it *checksumRowIterator) createMetadataChecksum(metadata *sppb.ResultSetMetadata) (hash.Hash, error) { + digest := sha256.New() for _, field := range metadata.RowType.Fields { - err := enc.Encode(field) - if err != nil { - return nil, err - } + digest.Write(intToByte(it.int32Buf, len(field.Name))) + digest.Write([]byte(field.Name)) + digest.Write(intToByte(it.int32Buf, int(field.Type.Code.Number()))) } - res := sha256.Sum256(buffer.Bytes()) - return &res, nil + return digest, nil } // retry implements retriableStatement.retry for queries. It will execute the @@ -167,8 +193,6 @@ func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sp // initial iterator was also returned by the new iterator, and that the errors // were returned by the same row index. func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - buffer := &bytes.Buffer{} - enc := gob.NewEncoder(buffer) retryIt := tx.QueryWithOptions(ctx, it.stmt, it.options) // If the original iterator had been stopped, we should also always stop the // new iterator. @@ -193,12 +217,13 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS // Iterate over the new result set as many times as we iterated over the initial // result set. The checksums of the two should be equal. Also, the new result set // should return any error on the same index as the original. - var newChecksum *[32]byte + // var newChecksum *[32]byte var checksumErr error + newHash := sha256.New() for n := int64(0); n < it.nc; n++ { row, err := retryIt.Next() if n == 0 && (err == nil || err == iterator.Done) { - newChecksum, checksumErr = createMetadataChecksum(enc, buffer, retryIt.Metadata) + newHash, checksumErr = it.createMetadataChecksum(retryIt.Metadata) if checksumErr != nil { return failRetry(checksumErr) } @@ -211,14 +236,14 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS } if errorsEqualForRetry(err, it.err) && n == it.errIndex { // Check that the checksums are also equal. - if !checksumsEqual(newChecksum, it.checksum) { + if !checksumsEqual(newHash, it.hash) { return failRetry(ErrAbortedDueToConcurrentModification) } return replaceIt(nil) } return failRetry(ErrAbortedDueToConcurrentModification) } - newChecksum, err = updateChecksum(enc, buffer, newChecksum, row) + err = it.updateChecksum(newHash, row) if err != nil { return failRetry(err) } @@ -230,7 +255,7 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS if it.err != nil { return failRetry(ErrAbortedDueToConcurrentModification) } - if !checksumsEqual(newChecksum, it.checksum) { + if !checksumsEqual(newHash, it.hash) { return failRetry(ErrAbortedDueToConcurrentModification) } // Everything seems to be equal, replace the underlying iterator and return @@ -238,8 +263,13 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS return replaceIt(nil) } -func checksumsEqual(c1, c2 *[32]byte) bool { - return (reflect.ValueOf(c1).IsNil() && reflect.ValueOf(c2).IsNil()) || *c1 == *c2 +func checksumsEqual(h1, h2 hash.Hash) bool { + if h1 == nil || h2 == nil { + return h1 == h2 + } + c1 := h1.Sum(nil) + c2 := h2.Sum(nil) + return bytes.Equal(c1, c2) } func (it *checksumRowIterator) Stop() { diff --git a/checksum_row_iterator_test.go b/checksum_row_iterator_test.go index cf31f36d..7548a0ed 100644 --- a/checksum_row_iterator_test.go +++ b/checksum_row_iterator_test.go @@ -16,23 +16,20 @@ package spannerdriver import ( "bytes" - "encoding/gob" + "crypto/sha256" + "fmt" "math/big" "testing" "time" "cloud.google.com/go/civil" "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/protobuf/types/known/structpb" ) func TestUpdateChecksum(t *testing.T) { - buffer1 := &bytes.Buffer{} - enc1 := gob.NewEncoder(buffer1) - buffer2 := &bytes.Buffer{} - enc2 := gob.NewEncoder(buffer2) - buffer3 := &bytes.Buffer{} - enc3 := gob.NewEncoder(buffer3) - row1, err := spanner.NewRow( []string{ "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", @@ -110,59 +107,70 @@ func TestUpdateChecksum(t *testing.T) { if err != nil { t.Fatalf("could not create row 3: %v", err) } - initial1 := new([32]byte) - checksum1, err := updateChecksum(enc1, buffer1, initial1, row1) + + it := &checksumRowIterator{} + hash1 := sha256.New() + err = it.updateChecksum(hash1, row1) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } - initial2 := new([32]byte) - checksum2, err := updateChecksum(enc2, buffer2, initial2, row2) + checksum1 := hash1.Sum(nil) + + hash2 := sha256.New() + err = it.updateChecksum(hash2, row2) if err != nil { t.Fatalf("could not calculate checksum 2: %v", err) } - initial3 := new([32]byte) - checksum3, err := updateChecksum(enc3, buffer3, initial3, row3) + checksum2 := hash2.Sum(nil) + + hash3 := sha256.New() + err = it.updateChecksum(hash3, row3) if err != nil { t.Fatalf("could not calculate checksum 3: %v", err) } // row1 and row2 are different, so the checksums should be different. - if *checksum1 == *checksum2 { + checksum3 := hash3.Sum(nil) + + if bytes.Equal(checksum1, checksum2) { t.Fatalf("checksum1 should not be equal to checksum2") } // row1 and row3 are equal, and should return the same checksum. - if *checksum1 != *checksum3 { + if !bytes.Equal(checksum1, checksum3) { t.Fatalf("checksum1 should be equal to checksum3") } // Updating checksums 1 and 3 with the data from row 2 should also produce // the same checksum. - checksum1_2, err := updateChecksum(enc1, buffer1, checksum1, row2) + err = it.updateChecksum(hash1, row2) if err != nil { t.Fatalf("could not calculate checksum 1_2: %v", err) } - checksum3_2, err := updateChecksum(enc3, buffer3, checksum3, row2) + checksum1_2 := hash1.Sum(nil) + + err = it.updateChecksum(hash3, row2) if err != nil { - t.Fatalf("could not calculate checksum 1_2: %v", err) + t.Fatalf("could not calculate checksum 3_2: %v", err) } - if *checksum1_2 != *checksum3_2 { + checksum3_2 := hash3.Sum(nil) + + if !bytes.Equal(checksum1_2, checksum3_2) { t.Fatalf("checksum1_2 should be equal to checksum3_2") } // The combination of row 3 and 2 will produce a different checksum than the // combination 2 and 3, because they are in a different order. - checksum2_3, err := updateChecksum(enc2, buffer2, checksum2, row3) + err = it.updateChecksum(hash2, row3) if err != nil { t.Fatalf("could not calculate checksum 2_3: %v", err) } - if *checksum2_3 == *checksum3_2 { + checksum2_3 := hash2.Sum(nil) + + if bytes.Equal(checksum2_3, checksum3_2) { t.Fatalf("checksum2_3 should not be equal to checksum3_2") } } func TestUpdateChecksumForNullValues(t *testing.T) { - buffer := &bytes.Buffer{} - enc := gob.NewEncoder(buffer) - row, err := spanner.NewRow( []string{ "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", @@ -180,26 +188,158 @@ func TestUpdateChecksumForNullValues(t *testing.T) { if err != nil { t.Fatalf("could not create row: %v", err) } - initial := new([32]byte) + it := &checksumRowIterator{} + hash1 := sha256.New() + initial := hash1.Sum(nil) // Create the initial checksum. - checksum, err := updateChecksum(enc, buffer, initial, row) + err = it.updateChecksum(hash1, row) if err != nil { t.Fatalf("could not calculate checksum 1: %v", err) } + checksum1 := hash1.Sum(nil) // The calculated checksum should not be equal to the initial value, even though it only // contains null values. - if *checksum == *initial { + if bytes.Equal(initial, checksum1) { t.Fatalf("checksum value should not be equal to the initial value") } // Calculating the same checksum again should yield the same result. - buffer2 := &bytes.Buffer{} - enc2 := gob.NewEncoder(buffer2) - initial2 := new([32]byte) - checksum2, err := updateChecksum(enc2, buffer2, initial2, row) + hash2 := sha256.New() + err = it.updateChecksum(hash2, row) if err != nil { t.Fatalf("failed to update checksum: %v", err) } - if *checksum != *checksum2 { + checksum2 := hash2.Sum(nil) + if !bytes.Equal(checksum1, checksum2) { t.Fatalf("recalculated checksum does not match the initial calculation") } } + +func BenchmarkChecksumRowIterator(b *testing.B) { + row1, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(1), 3.14, numeric("6.626"), "test", []byte("testbytes"), civil.Date{Year: 2021, Month: 8, Day: 5}, + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "value", "other-key": ["value1", "value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1", "test2"}, [][]byte{[]byte("testbytes1"), []byte("testbytes1")}, + []civil.Date{{Year: 2021, Month: 8, Day: 5}, {Year: 2021, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2021, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "value1", "other-key1": ["value1", "value2"]}`), + nullJson(true, `"key2": "value2", "other-key2": ["value1", "value2"]}`), + }, + }, + ) + row2, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(2), 6.626, numeric("3.14"), "test2", []byte("testbytes2"), civil.Date{Year: 2020, Month: 8, Day: 5}, + time.Date(2020, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "other-value", "other-key": ["other-value1", "other-value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1_", "test2_"}, [][]byte{[]byte("testbytes1_"), []byte("testbytes1_")}, + []civil.Date{{Year: 2020, Month: 8, Day: 5}, {Year: 2020, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2020, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2020, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "other-value1", "other-key1": ["other-value1", "other-value2"]}`), + nullJson(true, `"key2": "other-value2", "other-key2": ["other-value1", "other-value2"]}`), + }, + }, + ) + row3, _ := spanner.NewRow( + []string{ + "ColBool", "ColInt64", "ColFloat64", "ColNumeric", "ColString", "ColBytes", "ColDate", "ColTimestamp", "ColJson", + "ArrBool", "ArrInt64", "ArrFloat64", "ArrNumeric", "ArrString", "ArrBytes", "ArrDate", "ArrTimestamp", "ArrJson", + }, + []interface{}{ + true, int64(1), 3.14, numeric("6.626"), "test", []byte("testbytes"), civil.Date{Year: 2021, Month: 8, Day: 5}, + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + nullJson(true, `"key": "value", "other-key": ["value1", "value2"]}`), + []bool{true, false}, []int64{1, 2}, []float64{3.14, 6.626}, []big.Rat{numeric("3.14"), numeric("6.626")}, + []string{"test1", "test2"}, [][]byte{[]byte("testbytes1"), []byte("testbytes1")}, + []civil.Date{{Year: 2021, Month: 8, Day: 5}, {Year: 2021, Month: 8, Day: 6}}, + []time.Time{ + time.Date(2021, 8, 5, 13, 19, 23, 123456789, time.UTC), + time.Date(2021, 8, 6, 13, 19, 23, 123456789, time.UTC), + }, + []spanner.NullJSON{ + nullJson(true, `"key1": "value1", "other-key1": ["value1", "value2"]}`), + nullJson(true, `"key2": "value2", "other-key2": ["value1", "value2"]}`), + }, + }, + ) + + for b.Loop() { + it := &checksumRowIterator{} + hash := sha256.New() + if err := it.updateChecksum(hash, row1); err != nil { + b.Fatal(err) + } + if err := it.updateChecksum(hash, row2); err != nil { + b.Fatal(err) + } + if err := it.updateChecksum(hash, row3); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkChecksumRowIteratorRandom(b *testing.B) { + for _, numRows := range []int{1, 10, 100, 1000, 10000} { + resultSet := testutil.CreateRandomResultSet(numRows) + columnNames := make([]string, len(resultSet.Metadata.RowType.Fields)) + for i := range columnNames { + columnNames[i] = resultSet.Metadata.RowType.Fields[i].Name + } + var err error + rows := make([]*spanner.Row, numRows) + for row, values := range resultSet.Rows { + columnValues := convertListValuesToColumnValues(resultSet.Metadata, values) + c := make([]interface{}, len(columnValues)) + for i := range columnValues { + c[i] = columnValues[i] + } + rows[row], err = spanner.NewRow(columnNames, c) + if err != nil { + b.Fatal(err) + } + } + + b.Run(fmt.Sprintf("num-rows-%d", numRows), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + it := &checksumRowIterator{} + hash, err := it.createMetadataChecksum(resultSet.Metadata) + if err != nil { + b.Fatal(err) + } + for _, row := range rows { + if err := it.updateChecksum(hash, row); err != nil { + b.Fatal(err) + } + } + } + }) + } +} + +func convertListValuesToColumnValues(metadata *spannerpb.ResultSetMetadata, values *structpb.ListValue) []spanner.GenericColumnValue { + res := make([]spanner.GenericColumnValue, len(values.Values)) + for i := range values.Values { + res[i] = spanner.GenericColumnValue{Value: values.Values[i], Type: metadata.RowType.Fields[i].Type} + } + return res +} diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index fc0a9fbb..f2120609 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -5636,19 +5636,19 @@ func nullUuid(valid bool, v string) spanner.NullUUID { return spanner.NullUUID{Valid: true, UUID: uuid.MustParse(v)} } -func setupTestDBConnection(t *testing.T) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnection(t testing.TB) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParams(t, "") } -func setupTestDBConnectionWithDialect(t *testing.T, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithDialect(t testing.TB, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParamsAndDialect(t, "", dialect) } -func setupTestDBConnectionWithParams(t *testing.T, params string) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithParams(t testing.TB, params string) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { return setupTestDBConnectionWithParamsAndDialect(t, params, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) } -func setupTestDBConnectionWithParamsAndDialect(t *testing.T, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithParamsAndDialect(t testing.TB, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServerWithDialect(t, dialect) db, err := sql.Open( "spanner", @@ -5663,7 +5663,7 @@ func setupTestDBConnectionWithParamsAndDialect(t *testing.T, params string, dial } } -func setupTestDBConnectionWithConfigurator(t *testing.T, params string, configurator func(config *spanner.ClientConfig, opts *[]option.ClientOption)) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithConfigurator(t testing.TB, params string, configurator func(config *spanner.ClientConfig, opts *[]option.ClientOption)) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServer(t) dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;%s", server.Address, params) config, err := ExtractConnectorConfig(dsn) @@ -5684,7 +5684,7 @@ func setupTestDBConnectionWithConfigurator(t *testing.T, params string, configur } } -func setupTestDBConnectionWithConnectorConfig(t *testing.T, config ConnectorConfig) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { +func setupTestDBConnectionWithConnectorConfig(t testing.TB, config ConnectorConfig) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) { server, _, serverTeardown := setupMockedTestServer(t) config.Host = server.Address if config.Params == nil { @@ -5703,23 +5703,23 @@ func setupTestDBConnectionWithConnectorConfig(t *testing.T, config ConnectorConf } } -func setupMockedTestServer(t *testing.T) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServer(t testing.TB) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfig(t, spanner.ClientConfig{}) } -func setupMockedTestServerWithDialect(t *testing.T, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithDialect(t testing.TB, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptionsAndDialect(t, spanner.ClientConfig{}, []option.ClientOption{}, dialect) } -func setupMockedTestServerWithConfig(t *testing.T, config spanner.ClientConfig) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfig(t testing.TB, config spanner.ClientConfig) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{}) } -func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config spanner.ClientConfig, clientOptions []option.ClientOption) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfigAndClientOptions(t testing.TB, config spanner.ClientConfig, clientOptions []option.ClientOption) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { return setupMockedTestServerWithConfigAndClientOptionsAndDialect(t, config, clientOptions, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) } -func setupMockedTestServerWithConfigAndClientOptionsAndDialect(t *testing.T, config spanner.ClientConfig, clientOptions []option.ClientOption, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { +func setupMockedTestServerWithConfigAndClientOptionsAndDialect(t testing.TB, config spanner.ClientConfig, clientOptions []option.ClientOption, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, client *spanner.Client, teardown func()) { server, opts, serverTeardown := testutil.NewMockedSpannerInMemTestServer(t) server.SetupSelectDialectResult(dialect) diff --git a/testutil/mocked_inmem_server.go b/testutil/mocked_inmem_server.go index 1bb70452..eda7ef6b 100644 --- a/testutil/mocked_inmem_server.go +++ b/testutil/mocked_inmem_server.go @@ -15,17 +15,21 @@ package testutil import ( + crypto "crypto/rand" "encoding/base64" "fmt" "math" + "math/rand" "net" "strconv" "testing" + "time" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" "cloud.google.com/go/spanner/apiv1/spannerpb" pb "cloud.google.com/go/spanner/testdata/protos" + "github.com/google/uuid" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -86,14 +90,14 @@ type MockedSpannerInMemTestServer struct { // NewMockedSpannerInMemTestServer creates a MockedSpannerInMemTestServer at // localhost with a random port and returns client options that can be used // to connect to it. -func NewMockedSpannerInMemTestServer(t *testing.T) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { +func NewMockedSpannerInMemTestServer(t testing.TB) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { return NewMockedSpannerInMemTestServerWithAddr(t, "localhost:0") } // NewMockedSpannerInMemTestServerWithAddr creates a MockedSpannerInMemTestServer // at a given listening address and returns client options that can be used // to connect to it. -func NewMockedSpannerInMemTestServerWithAddr(t *testing.T, addr string) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { +func NewMockedSpannerInMemTestServerWithAddr(t testing.TB, addr string) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { mockedServer = &MockedSpannerInMemTestServer{} opts = mockedServer.setupMockedServerWithAddr(t, addr) return mockedServer, opts, func() { @@ -104,7 +108,7 @@ func NewMockedSpannerInMemTestServerWithAddr(t *testing.T, addr string) (mockedS } } -func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, addr string) []option.ClientOption { +func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t testing.TB, addr string) []option.ClientOption { s.TestSpanner = NewInMemSpannerServer() s.TestInstanceAdmin = NewInMemInstanceAdminServer() s.TestDatabaseAdmin = NewInMemDatabaseAdminServer() @@ -124,7 +128,9 @@ func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, a if err != nil { t.Fatal(err) } - go s.server.Serve(lis) + go func() { + _ = s.server.Serve(lis) + }() s.Address = lis.Addr().String() opts := []option.ClientOption{ @@ -137,23 +143,23 @@ func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, a func (s *MockedSpannerInMemTestServer) SetupSelectDialectResult(dialect databasepb.DatabaseDialect) { result := &StatementResult{Type: StatementResultResultSet, ResultSet: CreateSelectDialectResultSet(dialect)} - s.TestSpanner.PutStatementResult(selectDialect, result) + _ = s.TestSpanner.PutStatementResult(selectDialect, result) } func (s *MockedSpannerInMemTestServer) setupSelect1Result() { result := &StatementResult{Type: StatementResultResultSet, ResultSet: CreateSelect1ResultSet()} - s.TestSpanner.PutStatementResult("SELECT 1", result) + _ = s.TestSpanner.PutStatementResult("SELECT 1", result) } func (s *MockedSpannerInMemTestServer) setupFooResults() { resultSet := CreateSingleColumnInt64ResultSet(selectFooFromBarResults, "FOO") result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} - s.TestSpanner.PutStatementResult(SelectFooFromBar, result) - s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ + _ = s.TestSpanner.PutStatementResult(SelectFooFromBar, result) + _ = s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ Type: StatementResultUpdateCount, UpdateCount: UpdateBarSetFooRowCount, }) - s.TestSpanner.PutStatementResult(UpdateSingersSetLastName, &StatementResult{ + _ = s.TestSpanner.PutStatementResult(UpdateSingersSetLastName, &StatementResult{ Type: StatementResultUpdateCount, UpdateCount: UpdateSingersSetLastNameRowCount, }) @@ -171,7 +177,7 @@ func (s *MockedSpannerInMemTestServer) setupSingersResults() { Rows: rows, } result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} - s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) + _ = s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) } // CreateSingleRowSingersResult creates a result set containing a single row of @@ -236,7 +242,7 @@ func createSingersRow(idx int64) *structpb.ListValue { } } -func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb.ResultSet { +func CreateResultSetMetadataWithAllTypes() *spannerpb.ResultSetMetadata { index := 0 fields := make([]*spannerpb.StructType_Field, 26) fields[index] = &spannerpb.StructType_Field{ @@ -410,9 +416,14 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb rowType := &spannerpb.StructType{ Fields: fields, } - metadata := &spannerpb.ResultSetMetadata{ + return &spannerpb.ResultSetMetadata{ RowType: rowType, } +} + +func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb.ResultSet { + metadata := CreateResultSetMetadataWithAllTypes() + fields := metadata.RowType.Fields rows := make([]*structpb.ListValue, 1) rowValue := make([]*structpb.Value, len(fields)) if nullValues { @@ -436,7 +447,7 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb Genre: &singer2ProtoEnum, } - index = 0 + index := 0 rowValue[index] = &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: true}} index++ rowValue[index] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "test"}} @@ -576,6 +587,122 @@ func CreateResultSetWithAllTypes(nullValues, nullValuesInArrays bool) *spannerpb } } +func CreateRandomResultSet(numRows int) *spannerpb.ResultSet { + metadata := CreateResultSetMetadataWithAllTypes() + fields := metadata.RowType.Fields + rows := make([]*structpb.ListValue, numRows) + + for i := 0; i < numRows; i++ { + rowValue := make([]*structpb.Value, len(fields)) + for col := range fields { + rowValue[col] = randomValue(fields[col].Type) + } + rows[i] = &structpb.ListValue{Values: rowValue} + } + return &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } +} + +var nullValue *structpb.Value + +func init() { + nullValue = &structpb.Value{Kind: &structpb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}} +} + +func randomValue(t *spannerpb.Type) *structpb.Value { + if rand.Intn(10) == 5 { + return nullValue + } + switch t.Code { + case spannerpb.TypeCode_BOOL: + return randomBoolValue() + case spannerpb.TypeCode_BYTES: + return randomBytesValue() + case spannerpb.TypeCode_DATE: + return randomDateValue() + case spannerpb.TypeCode_FLOAT32: + return randomFloat32Value() + case spannerpb.TypeCode_FLOAT64: + return randomFloat64Value() + case spannerpb.TypeCode_INT64: + return randomInt64Value() + case spannerpb.TypeCode_JSON: + return randomJsonValue() + case spannerpb.TypeCode_NUMERIC: + return randomNumericValue() + case spannerpb.TypeCode_STRING: + return randomStringValue() + case spannerpb.TypeCode_TIMESTAMP: + return randomTimestampValue() + case spannerpb.TypeCode_UUID: + return randomUuidValue() + case spannerpb.TypeCode_ARRAY: + numElements := rand.Intn(10) + value := &structpb.Value{Kind: &structpb.Value_ListValue{ListValue: &structpb.ListValue{Values: make([]*structpb.Value, numElements)}}} + for i := range numElements { + value.GetListValue().Values[i] = randomValue(t.ArrayElementType) + } + return value + } + return nullValue +} + +func randomBoolValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: rand.Intn(2) == 1}} +} + +func randomString() string { + b := make([]byte, rand.Intn(1024)) + _, _ = crypto.Read(b) + return base64.StdEncoding.EncodeToString(b) +} + +func randomBytesValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: randomString()}} +} + +func randomDateValue() *structpb.Value { + year := rand.Intn(2100) + 1 + month := rand.Intn(12) + 1 + day := rand.Intn(28) + 1 + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%04d-%02d-%02d", year, month, day)}} +} + +func randomFloat32Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(rand.Float32())}} +} + +func randomFloat64Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: rand.Float64()}} +} + +func randomJsonValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf(`{"key": "%s"}`, randomString())}} +} + +func randomInt64Value() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%d", rand.Int63())}} +} + +func randomNumericValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("%d.%d", rand.Intn(10000000), rand.Intn(1000))}} +} + +func randomStringValue() *structpb.Value { + return randomBytesValue() +} + +func randomTimestampValue() *structpb.Value { + t := time.UnixMilli(time.Now().UnixMilli() + int64(rand.Intn(1000000))) + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: t.Format(time.RFC3339)}} +} + +func randomUuidValue() *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: uuid.New().String()}} +} + func nullValueOrAlt(nullValue bool, alt *structpb.Value) *structpb.Value { if nullValue { return &structpb.Value{Kind: &structpb.Value_NullValue{}} diff --git a/transaction.go b/transaction.go index 041c84a3..c256d20b 100644 --- a/transaction.go +++ b/transaction.go @@ -15,10 +15,9 @@ package spannerdriver import ( - "bytes" "context" + "crypto/sha256" "database/sql/driver" - "encoding/gob" "fmt" "log/slog" "math/rand" @@ -499,6 +498,9 @@ func (tx *readWriteTransaction) runWithRetry(ctx context.Context, f func(ctx con if err == nil { err = f(ctx) } + if err == nil { + return + } if err == ErrAbortedDueToConcurrentModification { tx.logger.Log(ctx, LevelNotice, "transaction retry failed due to a concurrent modification") return @@ -625,7 +627,6 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen // If retries are enabled, we need to use a row iterator that will keep // track of a running checksum of all the results that we see. - buffer := &bytes.Buffer{} it := &checksumRowIterator{ RowIterator: tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), ctx: ctx, @@ -633,8 +634,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen stmt: stmt, stmtType: stmtType, options: execOptions.QueryOptions, - buffer: buffer, - enc: gob.NewEncoder(buffer), + hash: sha256.New(), } tx.statements = append(tx.statements, it) return it, nil diff --git a/transaction_test.go b/transaction_test.go index ddd6d05a..7634ae21 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -357,3 +357,63 @@ func TestTransactionTimeoutSecondStatement(t *testing.T) { t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w) } } + +func BenchmarkReadWriteTransaction(b *testing.B) { + db, server, teardown := setupTestDBConnection(b) + defer teardown() + ctx := context.Background() + query := "select * from random_table" + + for _, numRows := range []int{1, 10, 100, 1000, 10000} { + resultSet := testutil.CreateRandomResultSet(numRows) + _ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + }) + + b.Run(fmt.Sprintf("num-rows-%d", numRows), func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + b.Fatal(err) + } + //if _, err := tx.ExecContext(ctx, "set local retry_aborts_internally = false"); err != nil { + // b.Fatal(err) + //} + if _, err := tx.ExecContext(ctx, "set local transaction_tag = 'my_tag'"); err != nil { + b.Fatal(err) + } + rows, err := tx.QueryContext(ctx, query) + if err != nil { + b.Fatal(err) + } + for rows.Next() { + // Just iterate through the results + } + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + for range 10 { + if _, err = tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + b.Fatal(err) + } + } + if _, err := tx.ExecContext(ctx, "start batch dml"); err != nil { + b.Fatal(err) + } + for range 10 { + if _, err = tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + b.Fatal(err) + } + } + if _, err := tx.ExecContext(ctx, "run batch"); err != nil { + b.Fatal(err) + } + if err := tx.Commit(); err != nil { + b.Fatal(err) + } + } + }) + } +}