Skip to content

Commit 0079cc4

Browse files
authored
backport upstream vitessio/12007 (#75)
1 parent 3cfa60f commit 0079cc4

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

go/vt/vitessdriver/rows.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ limitations under the License.
1717
package vitessdriver
1818

1919
import (
20+
"database/sql"
2021
"database/sql/driver"
2122
"io"
23+
"reflect"
24+
"time"
2225

2326
"vitess.io/vitess/go/sqltypes"
27+
"vitess.io/vitess/go/vt/proto/query"
2428
)
2529

2630
// rows creates a database/sql/driver compliant Row iterator
@@ -58,3 +62,60 @@ func (ri *rows) Next(dest []driver.Value) error {
5862
ri.index++
5963
return nil
6064
}
65+
66+
var (
67+
typeInt8 = reflect.TypeOf(int8(0))
68+
typeUint8 = reflect.TypeOf(uint8(0))
69+
typeInt16 = reflect.TypeOf(int16(0))
70+
typeUint16 = reflect.TypeOf(uint16(0))
71+
typeInt32 = reflect.TypeOf(int32(0))
72+
typeUint32 = reflect.TypeOf(uint32(0))
73+
typeInt64 = reflect.TypeOf(int64(0))
74+
typeUint64 = reflect.TypeOf(uint64(0))
75+
typeFloat32 = reflect.TypeOf(float32(0))
76+
typeFloat64 = reflect.TypeOf(float64(0))
77+
typeRawBytes = reflect.TypeOf(sql.RawBytes{})
78+
typeTime = reflect.TypeOf(time.Time{})
79+
typeUnknown = reflect.TypeOf(new(interface{}))
80+
)
81+
82+
// Implements the RowsColumnTypeScanType interface
83+
func (ri *rows) ColumnTypeScanType(index int) reflect.Type {
84+
field := ri.qr.Fields[index]
85+
switch field.GetType() {
86+
case query.Type_INT8:
87+
return typeInt8
88+
case query.Type_UINT8:
89+
return typeUint8
90+
case query.Type_INT16, query.Type_YEAR:
91+
return typeInt16
92+
case query.Type_UINT16:
93+
return typeUint16
94+
case query.Type_INT24:
95+
return typeInt32
96+
case query.Type_UINT24: // no 24 bit type, using 32 instead
97+
return typeUint32
98+
case query.Type_INT32:
99+
return typeInt32
100+
case query.Type_UINT32:
101+
return typeUint32
102+
case query.Type_INT64:
103+
return typeInt64
104+
case query.Type_UINT64:
105+
return typeUint64
106+
case query.Type_FLOAT32:
107+
return typeFloat32
108+
case query.Type_FLOAT64:
109+
return typeFloat64
110+
case query.Type_TIMESTAMP, query.Type_DECIMAL, query.Type_VARCHAR, query.Type_TEXT,
111+
query.Type_BLOB, query.Type_VARBINARY, query.Type_CHAR, query.Type_BINARY, query.Type_BIT,
112+
query.Type_ENUM, query.Type_SET, query.Type_TUPLE, query.Type_GEOMETRY, query.Type_JSON,
113+
query.Type_HEXNUM, query.Type_HEXVAL:
114+
115+
return typeRawBytes
116+
case query.Type_DATE, query.Type_TIME, query.Type_DATETIME:
117+
return typeTime
118+
default:
119+
return typeUnknown
120+
}
121+
}

go/vt/vitessdriver/rows_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ package vitessdriver
1818

1919
import (
2020
"database/sql/driver"
21+
"fmt"
2122
"io"
2223
"reflect"
2324
"testing"
2425

26+
"github.com/stretchr/testify/assert"
2527
"github.com/stretchr/testify/require"
2628

2729
"vitess.io/vitess/go/sqltypes"
@@ -135,3 +137,92 @@ func TestRows(t *testing.T) {
135137

136138
_ = ri.Close()
137139
}
140+
141+
// Test that the ColumnTypeScanType function returns the correct reflection type for each
142+
// sql type. The sql type in turn comes from a table column's type.
143+
func TestColumnTypeScanType(t *testing.T) {
144+
var r = sqltypes.Result{
145+
Fields: []*querypb.Field{
146+
{
147+
Name: "field1",
148+
Type: sqltypes.Int8,
149+
},
150+
{
151+
Name: "field2",
152+
Type: sqltypes.Uint8,
153+
},
154+
{
155+
Name: "field3",
156+
Type: sqltypes.Int16,
157+
},
158+
{
159+
Name: "field4",
160+
Type: sqltypes.Uint16,
161+
},
162+
{
163+
Name: "field5",
164+
Type: sqltypes.Int24,
165+
},
166+
{
167+
Name: "field6",
168+
Type: sqltypes.Uint24,
169+
},
170+
{
171+
Name: "field7",
172+
Type: sqltypes.Int32,
173+
},
174+
{
175+
Name: "field8",
176+
Type: sqltypes.Uint32,
177+
},
178+
{
179+
Name: "field9",
180+
Type: sqltypes.Int64,
181+
},
182+
{
183+
Name: "field10",
184+
Type: sqltypes.Uint64,
185+
},
186+
{
187+
Name: "field11",
188+
Type: sqltypes.Float32,
189+
},
190+
{
191+
Name: "field12",
192+
Type: sqltypes.Float64,
193+
},
194+
{
195+
Name: "field13",
196+
Type: sqltypes.VarBinary,
197+
},
198+
{
199+
Name: "field14",
200+
Type: sqltypes.Datetime,
201+
},
202+
},
203+
}
204+
205+
ri := newRows(&r, &converter{}).(driver.RowsColumnTypeScanType)
206+
defer ri.Close()
207+
208+
wantTypes := []reflect.Type{
209+
typeInt8,
210+
typeUint8,
211+
typeInt16,
212+
typeUint16,
213+
typeInt32,
214+
typeUint32,
215+
typeInt32,
216+
typeUint32,
217+
typeInt64,
218+
typeUint64,
219+
typeFloat32,
220+
typeFloat64,
221+
typeRawBytes,
222+
typeTime,
223+
}
224+
225+
for i := 0; i < len(wantTypes); i++ {
226+
assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i]))
227+
}
228+
}

0 commit comments

Comments
 (0)