Skip to content

Commit

Permalink
feat: return PROTO columns as bytes and integers
Browse files Browse the repository at this point in the history
PROTO columns are returned as []byte. ENUM columnns are returned
as INT64. This enables the use of these columns through the standard
database/sql Scan method.

Fixes #333
  • Loading branch information
olavloite committed Jan 15, 2025
1 parent 8bddc89 commit 44a4ea4
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 19 deletions.
89 changes: 75 additions & 14 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
pb "cloud.google.com/go/spanner/testdata/protos"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/go-sql-spanner/testutil"
Expand Down Expand Up @@ -550,6 +551,8 @@ func TestQueryWithAllTypes(t *testing.T) {
var d civil.Date
var ts time.Time
var j spanner.NullJSON
var p []byte
var e int64
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
Expand All @@ -560,7 +563,9 @@ func TestQueryWithAllTypes(t *testing.T) {
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray)
var pArray [][]byte
var eArray []spanner.NullInt64
err = rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &p, &e, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray, &pArray, &eArray)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -591,10 +596,25 @@ func TestQueryWithAllTypes(t *testing.T) {
if g, w := ts, time.Date(2021, 7, 21, 21, 7, 59, 339911800, time.UTC); g != w {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", g, w)
}
if !runsOnEmulator() {
if g, w := j, nullJson(true, `{"key":"value","other-key":["value1","value2"]}`); !cmp.Equal(g, w) {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", g, w)
}
if g, w := j, nullJson(true, `{"key":"value","other-key":["value1","value2"]}`); !cmp.Equal(g, w) {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", g, w)
}
wantSingerEnumValue := pb.Genre_ROCK
wantSingerProtoMsg := pb.SingerInfo{
SingerId: proto.Int64(1),
BirthDate: proto.String("January"),
Nationality: proto.String("Country1"),
Genre: &wantSingerEnumValue,
}
gotSingerProto := pb.SingerInfo{}
if err := proto.Unmarshal(p, &gotSingerProto); err != nil {
t.Fatalf("failed to unmarshal proto: %v", err)
}
if g, w := &gotSingerProto, &wantSingerProtoMsg; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(pb.SingerInfo{})) {
t.Errorf("row value mismatch for proto\nGot: %v\nWant: %v", g, w)
}
if g, w := pb.Genre(e), wantSingerEnumValue; g != w {
t.Errorf("row value mismatch for enum\nGot: %v\nWant: %v", g, w)
}
if g, w := bArray, []spanner.NullBool{{Valid: true, Bool: true}, {}, {Valid: true, Bool: false}}; !cmp.Equal(g, w) {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", g, w)
Expand Down Expand Up @@ -623,14 +643,39 @@ func TestQueryWithAllTypes(t *testing.T) {
if g, w := tsArray, []spanner.NullTime{{Valid: true, Time: ts1}, {}, {Valid: true, Time: ts2}}; !cmp.Equal(g, w) {
t.Errorf("row value mismatch for timestamp array\nGot: %v\nWant: %v", g, w)
}
if !runsOnEmulator() {
if g, w := jArray, []spanner.NullJSON{
nullJson(true, `{"key1": "value1", "other-key1": ["value1", "value2"]}`),
nullJson(false, ""),
nullJson(true, `{"key2": "value2", "other-key2": ["value1", "value2"]}`),
}; !cmp.Equal(g, w) {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", g, w)
}
if g, w := jArray, []spanner.NullJSON{
nullJson(true, `{"key1": "value1", "other-key1": ["value1", "value2"]}`),
nullJson(false, ""),
nullJson(true, `{"key2": "value2", "other-key2": ["value1", "value2"]}`),
}; !cmp.Equal(g, w) {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", g, w)
}
if g, w := len(pArray), 3; g != w {
t.Errorf("row value length mismatch for proto array\nGot: %v\nWant: %v", g, w)
}
wantSinger2ProtoEnum := pb.Genre_FOLK
wantSinger2ProtoMsg := pb.SingerInfo{
SingerId: proto.Int64(2),
BirthDate: proto.String("February"),
Nationality: proto.String("Country2"),
Genre: &wantSinger2ProtoEnum,
}
gotSingerProto1 := pb.SingerInfo{}
if err := proto.Unmarshal(pArray[0], &gotSingerProto1); err != nil {
t.Fatalf("failed to unmarshal proto: %v", err)
}
gotSingerProto2 := pb.SingerInfo{}
if err := proto.Unmarshal(pArray[2], &gotSingerProto2); err != nil {
t.Fatalf("failed to unmarshal proto: %v", err)
}
if g, w := &gotSingerProto1, &wantSingerProtoMsg; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(pb.SingerInfo{})) {
t.Errorf("row value mismatch for proto\nGot: %v\nWant: %v", g, w)
}
if g, w := pArray[1], []byte(nil); !cmp.Equal(g, w) {
t.Errorf("row value mismatch for proto\nGot: %v\nWant: %v", g, w)
}
if g, w := &gotSingerProto2, &wantSinger2ProtoMsg; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(pb.SingerInfo{})) {
t.Errorf("row value mismatch for proto\nGot: %v\nWant: %v", g, w)
}
}
if rows.Err() != nil {
Expand Down Expand Up @@ -963,6 +1008,8 @@ func TestQueryWithNullParameters(t *testing.T) {
var d spanner.NullDate // There's no equivalent sql type.
var ts sql.NullTime
var j spanner.NullJSON // There's no equivalent sql type.
var p []byte // Proto columns are returned as bytes.
var e sql.NullInt64 // Enum columns are returned as int64.
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
Expand All @@ -973,7 +1020,9 @@ func TestQueryWithNullParameters(t *testing.T) {
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray)
var pArray [][]byte
var eArray []spanner.NullInt64
err = rows.Scan(&b, &s, &bt, &i, &f32, &f, &r, &d, &ts, &j, &p, &e, &bArray, &sArray, &btArray, &iArray, &f32Array, &fArray, &rArray, &dArray, &tsArray, &jArray, &pArray, &eArray)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1007,6 +1056,12 @@ func TestQueryWithNullParameters(t *testing.T) {
if j.Valid {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", j, spanner.NullJSON{})
}
if p != nil {
t.Errorf("row value mismatch for proto\nGot: %v\nWant: %v", p, nil)
}
if e.Valid {
t.Errorf("row value mismatch for enum\nGot: %v\nWant: %v", e, spanner.NullInt64{})
}
if bArray != nil {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", bArray, nil)
}
Expand Down Expand Up @@ -1037,6 +1092,12 @@ func TestQueryWithNullParameters(t *testing.T) {
if jArray != nil {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", jArray, nil)
}
if pArray != nil {
t.Errorf("row value mismatch for proto array\nGot: %v\nWant: %v", pArray, nil)
}
if eArray != nil {
t.Errorf("row value mismatch for enum array\nGot: %v\nWant: %v", eArray, nil)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
Expand Down
8 changes: 4 additions & 4 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (r *rows) Next(dest []driver.Value) error {
return err
}
switch col.Type.Code {
case sppb.TypeCode_INT64:
case sppb.TypeCode_INT64, sppb.TypeCode_ENUM:
var v spanner.NullInt64
if err := col.Decode(&v); err != nil {
return err
Expand Down Expand Up @@ -167,7 +167,7 @@ func (r *rows) Next(dest []driver.Value) error {
// for JSON in the Go sql package. That means that instead of returning
// nil we should return a NullJSON with valid=false.
dest[i] = v
case sppb.TypeCode_BYTES:
case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO:
// The column value is a base64 encoded string.
var v []byte
if err := col.Decode(&v); err != nil {
Expand Down Expand Up @@ -206,7 +206,7 @@ func (r *rows) Next(dest []driver.Value) error {
}
case sppb.TypeCode_ARRAY:
switch col.Type.ArrayElementType.Code {
case sppb.TypeCode_INT64:
case sppb.TypeCode_INT64, sppb.TypeCode_ENUM:
var v []spanner.NullInt64
if err := col.Decode(&v); err != nil {
return err
Expand Down Expand Up @@ -242,7 +242,7 @@ func (r *rows) Next(dest []driver.Value) error {
return err
}
dest[i] = v
case sppb.TypeCode_BYTES:
case sppb.TypeCode_BYTES, sppb.TypeCode_PROTO:
var v [][]byte
if err := col.Decode(&v); err != nil {
return err
Expand Down
76 changes: 75 additions & 1 deletion testutil/mocked_inmem_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ import (
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
pb "cloud.google.com/go/spanner/testdata/protos"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -223,7 +226,7 @@ func createSingersRow(idx int64) *structpb.ListValue {

func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
index := 0
fields := make([]*spannerpb.StructType_Field, 20)
fields := make([]*spannerpb.StructType_Field, 24)
fields[index] = &spannerpb.StructType_Field{
Name: "ColBool",
Type: &spannerpb.Type{Code: spannerpb.TypeCode_BOOL},
Expand Down Expand Up @@ -274,6 +277,16 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
Type: &spannerpb.Type{Code: spannerpb.TypeCode_JSON},
}
index++
fields[index] = &spannerpb.StructType_Field{
Name: "ColProto",
Type: &spannerpb.Type{Code: spannerpb.TypeCode_PROTO},
}
index++
fields[index] = &spannerpb.StructType_Field{
Name: "ColProtoEnum",
Type: &spannerpb.Type{Code: spannerpb.TypeCode_ENUM},
}
index++
fields[index] = &spannerpb.StructType_Field{
Name: "ColBoolArray",
Type: &spannerpb.Type{
Expand Down Expand Up @@ -353,6 +366,22 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
ArrayElementType: &spannerpb.Type{Code: spannerpb.TypeCode_JSON},
},
}
index++
fields[index] = &spannerpb.StructType_Field{
Name: "ColProtoArray",
Type: &spannerpb.Type{
Code: spannerpb.TypeCode_ARRAY,
ArrayElementType: &spannerpb.Type{Code: spannerpb.TypeCode_PROTO},
},
}
index++
fields[index] = &spannerpb.StructType_Field{
Name: "ColProtoEnumArray",
Type: &spannerpb.Type{
Code: spannerpb.TypeCode_ARRAY,
ArrayElementType: &spannerpb.Type{Code: spannerpb.TypeCode_ENUM},
},
}
rowType := &spannerpb.StructType{
Fields: fields,
}
Expand All @@ -366,6 +395,22 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
rowValue[i] = &structpb.Value{Kind: &structpb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}
}
} else {
singerEnumValue := pb.Genre_ROCK
singerProtoMsg := pb.SingerInfo{
SingerId: proto.Int64(1),
BirthDate: proto.String("January"),
Nationality: proto.String("Country1"),
Genre: &singerEnumValue,
}

singer2ProtoEnum := pb.Genre_FOLK
singer2ProtoMsg := pb.SingerInfo{
SingerId: proto.Int64(2),
BirthDate: proto.String("February"),
Nationality: proto.String("Country2"),
Genre: &singer2ProtoEnum,
}

index = 0
rowValue[index] = &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: true}}
index++
Expand All @@ -387,6 +432,10 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
index++
rowValue[index] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: `{"key": "value", "other-key": ["value1", "value2"]}`}}
index++
rowValue[index] = protoMessageProto(&singerProtoMsg)
index++
rowValue[index] = protoEnumProto(&singerEnumValue)
index++
rowValue[index] = &structpb.Value{Kind: &structpb.Value_ListValue{
ListValue: &structpb.ListValue{Values: []*structpb.Value{
{Kind: &structpb.Value_BoolValue{BoolValue: true}},
Expand Down Expand Up @@ -466,6 +515,22 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
{Kind: &structpb.Value_StringValue{StringValue: `{"key2": "value2", "other-key2": ["value1", "value2"]}`}},
}},
}}
index++
rowValue[index] = &structpb.Value{Kind: &structpb.Value_ListValue{
ListValue: &structpb.ListValue{Values: []*structpb.Value{
protoMessageProto(&singerProtoMsg),
{Kind: &structpb.Value_NullValue{}},
protoMessageProto(&singer2ProtoMsg),
}},
}}
index++
rowValue[index] = &structpb.Value{Kind: &structpb.Value_ListValue{
ListValue: &structpb.ListValue{Values: []*structpb.Value{
protoEnumProto(&singerEnumValue),
{Kind: &structpb.Value_NullValue{}},
protoEnumProto(&singer2ProtoEnum),
}},
}}
}
rows[0] = &structpb.ListValue{
Values: rowValue,
Expand All @@ -476,6 +541,15 @@ func CreateResultSetWithAllTypes(nullValues bool) *spannerpb.ResultSet {
}
}

func protoMessageProto(m proto.Message) *structpb.Value {
var b, _ = proto.Marshal(m)
return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}}
}

func protoEnumProto(e protoreflect.Enum) *structpb.Value {
return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(int64(e.Number()), 10)}}
}

func CreateSelect1ResultSet() *spannerpb.ResultSet {
return CreateSingleColumnResultSet([]int64{1}, "")
}
Expand Down

0 comments on commit 44a4ea4

Please sign in to comment.