Skip to content

Commit 74278f1

Browse files
vitess-bot[bot]vmg
andauthored
[release-19.0] bugfix: wrong field type returned for SUM (#15192) (#15206)
Signed-off-by: Andres Taylor <andres@planetscale.com> Signed-off-by: Vicent Marti <vmg@strn.cat> Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com> Co-authored-by: Vicent Marti <vmg@strn.cat>
1 parent 4015988 commit 74278f1

File tree

7 files changed

+117
-73
lines changed

7 files changed

+117
-73
lines changed

go/test/endtoend/utils/mysql.go

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ import (
2727

2828
"github.com/stretchr/testify/assert"
2929

30+
"vitess.io/vitess/go/mysql"
3031
"vitess.io/vitess/go/mysql/collations"
31-
3232
"vitess.io/vitess/go/sqltypes"
33-
"vitess.io/vitess/go/vt/dbconfigs"
34-
"vitess.io/vitess/go/vt/sqlparser"
35-
36-
"vitess.io/vitess/go/mysql"
3733
"vitess.io/vitess/go/test/endtoend/cluster"
34+
"vitess.io/vitess/go/vt/dbconfigs"
3835
"vitess.io/vitess/go/vt/mysqlctl"
36+
querypb "vitess.io/vitess/go/vt/proto/query"
37+
"vitess.io/vitess/go/vt/sqlparser"
3938
)
4039

4140
const mysqlShutdownTimeout = 1 * time.Minute
@@ -160,7 +159,9 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error {
160159
return nil
161160
}
162161

163-
func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) error {
162+
func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error {
163+
t.Helper()
164+
164165
if vtQr == nil && mysqlQr == nil {
165166
return nil
166167
}
@@ -173,28 +174,29 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
173174
return errors.New("MySQL result is 'nil' while Vitess' is not.\n")
174175
}
175176

176-
var errStr string
177-
if compareColumns {
178-
vtColCount := len(vtQr.Fields)
179-
myColCount := len(mysqlQr.Fields)
180-
if vtColCount > 0 && myColCount > 0 {
181-
if vtColCount != myColCount {
182-
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
183-
errStr += fmt.Sprintf("column count does not match: %d vs %d\n", vtColCount, myColCount)
184-
}
185-
186-
var vtCols []string
187-
var myCols []string
188-
for i, vtField := range vtQr.Fields {
189-
vtCols = append(vtCols, vtField.Name)
190-
myCols = append(myCols, mysqlQr.Fields[i].Name)
191-
}
192-
if !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
193-
errStr += "column names do not match - the expected values are what mysql produced\n"
194-
errStr += fmt.Sprintf("Not equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
195-
}
177+
vtColCount := len(vtQr.Fields)
178+
myColCount := len(mysqlQr.Fields)
179+
180+
if vtColCount != myColCount {
181+
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
182+
}
183+
184+
if vtColCount > 0 {
185+
var vtCols []string
186+
var myCols []string
187+
for i, vtField := range vtQr.Fields {
188+
myField := mysqlQr.Fields[i]
189+
checkFields(t, myField.Name, vtField, myField)
190+
191+
vtCols = append(vtCols, vtField.Name)
192+
myCols = append(myCols, myField.Name)
193+
}
194+
195+
if compareColumnNames && !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
196+
t.Errorf("column names do not match - the expected values are what mysql produced\nNot equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
196197
}
197198
}
199+
198200
stmt, err := sqlparser.NewTestParser().Parse(query)
199201
if err != nil {
200202
t.Error(err)
@@ -209,7 +211,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
209211
return nil
210212
}
211213

212-
errStr += "Query (" + query + ") results mismatched.\nVitess Results:\n"
214+
errStr := "Query (" + query + ") results mismatched.\nVitess Results:\n"
213215
for _, row := range vtQr.Rows {
214216
errStr += fmt.Sprintf("%s\n", row)
215217
}
@@ -229,6 +231,20 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
229231
return errors.New(errStr)
230232
}
231233

234+
func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) {
235+
t.Helper()
236+
if vtField.Type != myField.Type {
237+
t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String())
238+
}
239+
240+
// starting in Vitess 20, decimal types are properly sized in their field information
241+
if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal {
242+
if vtField.Decimals != myField.Decimals {
243+
t.Errorf("for column %s field decimals count do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Decimals, vtField.Decimals)
244+
}
245+
}
246+
}
247+
232248
func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) {
233249
if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil {
234250
return

go/vt/vtgate/engine/aggregations.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ func (ap *AggregateParams) String() string {
9191

9292
func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
9393
if ap.OrigOpcode != AggregateUnassigned {
94-
return ap.OrigOpcode.Type(inputType)
94+
return ap.OrigOpcode.SQLType(inputType)
9595
}
96-
return ap.Opcode.Type(inputType)
96+
return ap.Opcode.SQLType(inputType)
9797
}
9898

9999
type aggregator interface {

go/vt/vtgate/engine/opcode/constants.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ package opcode
1919
import (
2020
"fmt"
2121

22+
"vitess.io/vitess/go/mysql/collations"
2223
"vitess.io/vitess/go/sqltypes"
2324
querypb "vitess.io/vitess/go/vt/proto/query"
25+
"vitess.io/vitess/go/vt/vtgate/evalengine"
2426
)
2527

2628
// PulloutOpcode is a number representing the opcode
@@ -138,7 +140,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) {
138140
}
139141

140142
// Type returns the opcode return sql type, and a bool telling is we are sure about this type or not
141-
func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
143+
func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
142144
switch code {
143145
case AggregateUnassigned:
144146
return sqltypes.Null
@@ -169,6 +171,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
169171
}
170172
}
171173

174+
func (code AggregateOpcode) Nullable() bool {
175+
switch code {
176+
case AggregateCount, AggregateCountStar:
177+
return false
178+
default:
179+
return true
180+
}
181+
}
182+
183+
func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type {
184+
sqltype := code.SQLType(t.Type())
185+
collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset())
186+
nullable := code.Nullable()
187+
size := t.Size()
188+
189+
scale := t.Scale()
190+
if code == AggregateAvg {
191+
scale += 4
192+
}
193+
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
194+
}
195+
172196
func (code AggregateOpcode) NeedsComparableValues() bool {
173197
switch code {
174198
case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax:

go/vt/vtgate/engine/opcode/constants_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030
func TestCheckAllAggrOpCodes(t *testing.T) {
3131
// This test is just checking that we never reach the panic when using Type() on valid opcodes
3232
for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ {
33-
i.Type(sqltypes.Null)
33+
i.SQLType(sqltypes.Null)
3434
}
3535
}
3636

@@ -56,7 +56,7 @@ func TestType(t *testing.T) {
5656

5757
for _, tc := range tt {
5858
t.Run(tc.opcode.String()+"_"+tc.typ.String(), func(t *testing.T) {
59-
out := tc.opcode.Type(tc.typ)
59+
out := tc.opcode.SQLType(tc.typ)
6060
assert.Equal(t, tc.out, out)
6161
})
6262
}
@@ -70,7 +70,7 @@ func TestType_Panic(t *testing.T) {
7070
assert.Contains(t, errMsg, "ERROR", "Expected panic message containing 'ERROR'")
7171
}
7272
}()
73-
AggregateOpcode(999).Type(sqltypes.VarChar)
73+
AggregateOpcode(999).SQLType(sqltypes.VarChar)
7474
}
7575

7676
func TestNeedsListArg(t *testing.T) {

go/vt/vtgate/engine/projection.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,12 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query
158158
fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG)
159159
}
160160
fields = append(fields, &querypb.Field{
161-
Name: col,
162-
Type: typ.Type(),
163-
Charset: uint32(typ.Collation()),
164-
Flags: fl,
161+
Name: col,
162+
Type: typ.Type(),
163+
Charset: uint32(typ.Collation()),
164+
ColumnLength: uint32(typ.Size()),
165+
Decimals: uint32(typ.Scale()),
166+
Flags: fl,
165167
})
166168
}
167169
return fields, nil

0 commit comments

Comments
 (0)