Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions checksum_row_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/gob"
"reflect"
"encoding/binary"
"hash"
"math"
"sort"

"cloud.google.com/go/spanner"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
Expand All @@ -32,16 +34,6 @@ import (

var errNextAfterSTop = status.Errorf(codes.FailedPrecondition, "Next called after Stop")

// init registers the protobuf types with gob so they can be encoded.
func init() {
gob.Register(structpb.Value_BoolValue{})
gob.Register(structpb.Value_NumberValue{})
gob.Register(structpb.Value_StringValue{})
gob.Register(structpb.Value_NullValue{})
gob.Register(structpb.Value_ListValue{})
gob.Register(structpb.Value_StructValue{})
}

var _ rowIterator = &checksumRowIterator{}

// checksumRowIterator implements rowIterator and keeps track of a running
Expand All @@ -66,12 +58,12 @@ type checksumRowIterator struct {
// the retry has finished.
stopped bool

// checksum contains the current checksum for the results that have been
// hash contains the current hash for the results that have been
// seen. It is calculated as a SHA256 checksum over all rows that so far
// have been returned.
checksum *[32]byte
buffer *bytes.Buffer
enc *gob.Encoder
hash hash.Hash
int32Buf [4]byte
float64Buf [8]byte

// errIndex and err indicate any error and the index in the result set
// where the error occurred.
Expand Down Expand Up @@ -110,7 +102,7 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) {
// checksum of the columns that are included in this result. This is
// also used to detect the possible difference between two empty
// result sets with a different set of columns.
it.checksum, err = createMetadataChecksum(it.enc, it.buffer, it.metadata)
it.hash, err = it.createMetadataChecksum(it.metadata)
if err != nil {
return err
}
Expand All @@ -119,45 +111,79 @@ func (it *checksumRowIterator) Next() (row *spanner.Row, err error) {
return it.err
}
// Update the current checksum.
it.checksum, err = updateChecksum(it.enc, it.buffer, it.checksum, row)
return err
return it.updateChecksum(it.hash, row)
})
return row, err
}

// updateChecksum calculates the following checksum based on a current checksum
// and a new row.
func updateChecksum(enc *gob.Encoder, buffer *bytes.Buffer, currentChecksum *[32]byte, row *spanner.Row) (*[32]byte, error) {
buffer.Reset()
buffer.Write(currentChecksum[:])
func (it *checksumRowIterator) updateChecksum(hash hash.Hash, row *spanner.Row) error {
for i := 0; i < row.Size(); i++ {
var v spanner.GenericColumnValue
err := row.Column(i, &v)
if err != nil {
return nil, err
return err
}
err = enc.Encode(v)
if err != nil {
return nil, err
it.hashValue(v.Value, hash)
}
return nil
}

func (it *checksumRowIterator) hashValue(value *structpb.Value, digest hash.Hash) {
switch value.GetKind().(type) {
case *structpb.Value_StringValue:
digest.Write(intToByte(it.int32Buf, len(value.GetStringValue())))
digest.Write([]byte(value.GetStringValue()))
case *structpb.Value_NullValue:
digest.Write([]byte{0})
case *structpb.Value_NumberValue:
digest.Write(float64ToByte(it.float64Buf, value.GetNumberValue()))
case *structpb.Value_BoolValue:
if value.GetBoolValue() {
digest.Write([]byte{1})
} else {
digest.Write([]byte{0})
}
case *structpb.Value_StructValue:
fields := make([]string, 0, len(value.GetStructValue().Fields))
for field := range value.GetStructValue().Fields {
fields = append(fields, field)
}
sort.Strings(fields)
for _, field := range fields {
digest.Write(intToByte(it.int32Buf, len(field)))
digest.Write([]byte(field))
it.hashValue(value.GetStructValue().Fields[field], digest)
}
case *structpb.Value_ListValue:
for _, v := range value.GetListValue().GetValues() {
it.hashValue(v, digest)
}
}
res := sha256.Sum256(buffer.Bytes())
return &res, nil
}

func intToByte(buf [4]byte, v int) []byte {
binary.BigEndian.PutUint32(buf[:], uint32(v))
return buf[:]
}

func float64ToByte(buf [8]byte, f float64) []byte {
binary.BigEndian.PutUint64(buf[:], math.Float64bits(f))
return buf[:]
}

// createMetadataChecksum calculates the checksum of the metadata of a result.
// Only the column names and types are included in the checksum. Any transaction
// metadata is not included.
func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sppb.ResultSetMetadata) (*[32]byte, error) {
buffer.Reset()
func (it *checksumRowIterator) createMetadataChecksum(metadata *sppb.ResultSetMetadata) (hash.Hash, error) {
digest := sha256.New()
for _, field := range metadata.RowType.Fields {
err := enc.Encode(field)
if err != nil {
return nil, err
}
digest.Write(intToByte(it.int32Buf, len(field.Name)))
digest.Write([]byte(field.Name))
digest.Write(intToByte(it.int32Buf, int(field.Type.Code.Number())))
}
res := sha256.Sum256(buffer.Bytes())
return &res, nil
return digest, nil
}

// retry implements retriableStatement.retry for queries. It will execute the
Expand All @@ -167,8 +193,6 @@ func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sp
// initial iterator was also returned by the new iterator, and that the errors
// were returned by the same row index.
func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error {
buffer := &bytes.Buffer{}
enc := gob.NewEncoder(buffer)
retryIt := tx.QueryWithOptions(ctx, it.stmt, it.options)
// If the original iterator had been stopped, we should also always stop the
// new iterator.
Expand All @@ -193,12 +217,13 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS
// Iterate over the new result set as many times as we iterated over the initial
// result set. The checksums of the two should be equal. Also, the new result set
// should return any error on the same index as the original.
var newChecksum *[32]byte
// var newChecksum *[32]byte
var checksumErr error
newHash := sha256.New()
for n := int64(0); n < it.nc; n++ {
row, err := retryIt.Next()
if n == 0 && (err == nil || err == iterator.Done) {
newChecksum, checksumErr = createMetadataChecksum(enc, buffer, retryIt.Metadata)
newHash, checksumErr = it.createMetadataChecksum(retryIt.Metadata)
if checksumErr != nil {
return failRetry(checksumErr)
}
Expand All @@ -211,14 +236,14 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS
}
if errorsEqualForRetry(err, it.err) && n == it.errIndex {
// Check that the checksums are also equal.
if !checksumsEqual(newChecksum, it.checksum) {
if !checksumsEqual(newHash, it.hash) {
return failRetry(ErrAbortedDueToConcurrentModification)
}
return replaceIt(nil)
}
return failRetry(ErrAbortedDueToConcurrentModification)
}
newChecksum, err = updateChecksum(enc, buffer, newChecksum, row)
err = it.updateChecksum(newHash, row)
if err != nil {
return failRetry(err)
}
Expand All @@ -230,16 +255,21 @@ func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteS
if it.err != nil {
return failRetry(ErrAbortedDueToConcurrentModification)
}
if !checksumsEqual(newChecksum, it.checksum) {
if !checksumsEqual(newHash, it.hash) {
return failRetry(ErrAbortedDueToConcurrentModification)
}
// Everything seems to be equal, replace the underlying iterator and return
// a nil error.
return replaceIt(nil)
}

func checksumsEqual(c1, c2 *[32]byte) bool {
return (reflect.ValueOf(c1).IsNil() && reflect.ValueOf(c2).IsNil()) || *c1 == *c2
func checksumsEqual(h1, h2 hash.Hash) bool {
if h1 == nil || h2 == nil {
return h1 == h2
}
c1 := h1.Sum(nil)
c2 := h2.Sum(nil)
return bytes.Equal(c1, c2)
}

func (it *checksumRowIterator) Stop() {
Expand Down
Loading
Loading