Skip to content

Commit bdab2bf

Browse files
Fix nullability checks in evalengine (#14556)
Signed-off-by: Manan Gupta <manan@planetscale.com> Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
1 parent 49a1154 commit bdab2bf

21 files changed

+166
-85
lines changed

go/mysql/collations/integration/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func mysqlconn(t *testing.T) *mysql.Conn {
4747
if err != nil {
4848
t.Fatal(err)
4949
}
50-
if !strings.HasPrefix(conn.ServerVersion, "8.0.") {
50+
if !strings.HasPrefix(conn.ServerVersion, "8.") {
5151
conn.Close()
5252
t.Skipf("collation integration tests are only supported in MySQL 8.0+")
5353
}

go/vt/vtgate/evalengine/api_type_aggregation.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ type typeAggregation struct {
4242
geometry uint16
4343
blob uint16
4444
total uint16
45+
46+
nullable bool
4547
}
4648

4749
func AggregateTypes(types []sqltypes.Type) sqltypes.Type {
@@ -63,6 +65,7 @@ func (ta *typeAggregation) addEval(e eval) {
6365
switch e := e.(type) {
6466
case nil:
6567
t = sqltypes.Null
68+
ta.nullable = true
6669
case *evalBytes:
6770
t = sqltypes.Type(e.tt)
6871
f = e.flag
@@ -73,6 +76,9 @@ func (ta *typeAggregation) addEval(e eval) {
7376
}
7477

7578
func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
79+
if f&flagNullable != 0 {
80+
ta.nullable = true
81+
}
7682
switch tt {
7783
case sqltypes.Float32, sqltypes.Float64:
7884
ta.double++

go/vt/vtgate/evalengine/compiler_asm.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
528528
asm.emit(func(env *ExpressionEnv) int {
529529
end := env.vm.sp - elseOffset
530530
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
531-
if env.vm.stack[sp].(*evalInt64).i != 0 {
531+
if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 {
532532
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now)
533533
goto done
534534
}
@@ -782,16 +782,18 @@ func (asm *assembler) Convert_bB(offset int) {
782782
var f float64
783783
if arg != nil {
784784
f, _ = fastparse.ParseFloat64(arg.(*evalBytes).string())
785+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
785786
}
786-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0)
787787
return 1
788788
}, "CONV VARBINARY(SP-%d), BOOL", offset)
789789
}
790790

791791
func (asm *assembler) Convert_TB(offset int) {
792792
asm.emit(func(env *ExpressionEnv) int {
793793
arg := env.vm.stack[env.vm.sp-offset]
794-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalTemporal).isZero())
794+
if arg != nil {
795+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalTemporal).isZero())
796+
}
795797
return 1
796798
}, "CONV SQLTYPES(SP-%d), BOOL", offset)
797799
}
@@ -839,7 +841,9 @@ func (asm *assembler) Convert_Tj(offset int) {
839841
func (asm *assembler) Convert_dB(offset int) {
840842
asm.emit(func(env *ExpressionEnv) int {
841843
arg := env.vm.stack[env.vm.sp-offset]
842-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero())
844+
if arg != nil {
845+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalDecimal).dec.IsZero())
846+
}
843847
return 1
844848
}, "CONV DECIMAL(SP-%d), BOOL", offset)
845849
}
@@ -859,7 +863,9 @@ func (asm *assembler) Convert_dbit(offset int) {
859863
func (asm *assembler) Convert_fB(offset int) {
860864
asm.emit(func(env *ExpressionEnv) int {
861865
arg := env.vm.stack[env.vm.sp-offset]
862-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0)
866+
if arg != nil {
867+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalFloat).f != 0.0)
868+
}
863869
return 1
864870
}, "CONV FLOAT64(SP-%d), BOOL", offset)
865871
}
@@ -917,7 +923,9 @@ func (asm *assembler) Convert_Tf(offset int) {
917923
func (asm *assembler) Convert_iB(offset int) {
918924
asm.emit(func(env *ExpressionEnv) int {
919925
arg := env.vm.stack[env.vm.sp-offset]
920-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0)
926+
if arg != nil {
927+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalInt64).i != 0)
928+
}
921929
return 1
922930
}, "CONV INT64(SP-%d), BOOL", offset)
923931
}
@@ -997,7 +1005,9 @@ func (asm *assembler) Convert_Nj(offset int) {
9971005
func (asm *assembler) Convert_uB(offset int) {
9981006
asm.emit(func(env *ExpressionEnv) int {
9991007
arg := env.vm.stack[env.vm.sp-offset]
1000-
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0)
1008+
if arg != nil {
1009+
env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalUint64).u != 0)
1010+
}
10011011
return 1
10021012
}, "CONV UINT64(SP-%d), BOOL", offset)
10031013
}

go/vt/vtgate/evalengine/compiler_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,38 @@ func TestCompilerSingle(t *testing.T) {
532532
expression: `UNIX_TIMESTAMP('20000101103458.111111') + 1`,
533533
result: `DECIMAL(946719299.111111)`,
534534
},
535+
{
536+
expression: `cast(null * 1 as CHAR)`,
537+
result: `NULL`,
538+
},
539+
{
540+
expression: `cast(null + 1 as CHAR)`,
541+
result: `NULL`,
542+
},
543+
{
544+
expression: `cast(null - 1 as CHAR)`,
545+
result: `NULL`,
546+
},
547+
{
548+
expression: `cast(null / 1 as CHAR)`,
549+
result: `NULL`,
550+
},
551+
{
552+
expression: `cast(null % 1 as CHAR)`,
553+
result: `NULL`,
554+
},
555+
{
556+
expression: `1 AND NULL * 1`,
557+
result: `NULL`,
558+
},
559+
{
560+
expression: `case 0 when NULL then 1 else 0 end`,
561+
result: `INT64(0)`,
562+
},
563+
{
564+
expression: `case when null is null then 23 else null end`,
565+
result: `INT64(23)`,
566+
},
535567
}
536568

537569
tz, _ := time.LoadLocation("Europe/Madrid")

go/vt/vtgate/evalengine/expr_arithmetic.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) {
127127
}
128128

129129
c.asm.jumpDestination(skip1, skip2)
130-
return ctype{Type: sumtype, Col: collationNumeric}, nil
130+
return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
131131
}
132132

133133
func (op *opArithSub) eval(left, right eval) (eval, error) {
@@ -210,7 +210,7 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) {
210210
}
211211

212212
c.asm.jumpDestination(skip1, skip2)
213-
return ctype{Type: subtype, Col: collationNumeric}, nil
213+
return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
214214
}
215215

216216
func (op *opArithMul) eval(left, right eval) (eval, error) {
@@ -270,7 +270,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) {
270270
}
271271

272272
c.asm.jumpDestination(skip1, skip2)
273-
return ctype{Type: multype, Col: collationNumeric}, nil
273+
return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
274274
}
275275

276276
func (op *opArithDiv) eval(left, right eval) (eval, error) {
@@ -525,9 +525,13 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) {
525525
c.asm.jumpDestination(skip)
526526
return ctype{
527527
Type: neg,
528-
Flag: arg.Flag & (flagNull | flagNullable),
528+
Flag: nullableFlags(arg.Flag),
529529
Size: arg.Size,
530530
Scale: arg.Scale,
531531
Col: collationNumeric,
532532
}, nil
533533
}
534+
535+
func nullableFlags(flag typeFlag) typeFlag {
536+
return flag & (flagNull | flagNullable)
537+
}

go/vt/vtgate/evalengine/expr_bit.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func (expr *BitwiseExpr) compileBinary(c *compiler, asm_ins_bb, asm_ins_uu func(
270270

271271
asm_ins_uu()
272272
c.asm.jumpDestination(skip1, skip2)
273-
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
273+
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
274274
}
275275

276276
func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
@@ -299,8 +299,8 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
299299
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil
300300
}
301301

302-
_ = c.compileToBitwiseUint64(lt, 2)
303-
_ = c.compileToUint64(rt, 1)
302+
lt = c.compileToBitwiseUint64(lt, 2)
303+
rt = c.compileToUint64(rt, 1)
304304

305305
if i < 0 {
306306
c.asm.BitShiftLeft_uu()
@@ -309,7 +309,7 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) {
309309
}
310310

311311
c.asm.jumpDestination(skip1, skip2)
312-
return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil
312+
return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil
313313
}
314314

315315
func (expr *BitwiseExpr) compile(c *compiler) (ctype, error) {

go/vt/vtgate/evalengine/expr_bvar.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) {
106106
case sqltypes.Null:
107107
return ctype{Type: sqltypes.Null, Flag: flagNull | flagNullable, Col: collationNull}, nil
108108
case sqltypes.HexNum, sqltypes.HexVal:
109-
return ctype{Type: sqltypes.VarBinary, Flag: flagHex, Col: collationNumeric}, nil
109+
return ctype{Type: sqltypes.VarBinary, Flag: flagHex | flagNullable, Col: collationNumeric}, nil
110110
case sqltypes.BitNum:
111-
return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil
111+
return ctype{Type: sqltypes.VarBinary, Flag: flagBit | flagNullable, Col: collationNumeric}, nil
112112
default:
113-
return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
113+
return ctype{Type: tt, Flag: flagNullable, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
114114
}
115115
}
116116

go/vt/vtgate/evalengine/expr_compare.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,13 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {
365365

366366
swapped := false
367367
var skip2 *jump
368+
nullable := true
368369

369370
switch expr.Op.(type) {
370371
case compareNullSafeEQ:
371372
skip2 = c.asm.jumpFrom()
372373
c.asm.Cmp_nullsafe(skip2)
374+
nullable = false
373375
default:
374376
skip2 = c.compileNullCheck1r(rt)
375377
}
@@ -407,6 +409,9 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) {
407409
}
408410

409411
cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}
412+
if nullable {
413+
cmptype.Flag |= nullableFlags(lt.Flag | rt.Flag)
414+
}
410415

411416
switch expr.Op.(type) {
412417
case compareEQ:
@@ -540,16 +545,18 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) {
540545

541546
switch rhs := expr.Right.(type) {
542547
case TupleExpr:
548+
var rt ctype
543549
if table := expr.compileTable(lhs, rhs); table != nil {
544550
c.asm.In_table(expr.Negate, table)
545551
} else {
546-
_, err := rhs.compile(c)
552+
rt, err = rhs.compile(c)
547553
if err != nil {
548554
return ctype{}, err
549555
}
550556
c.asm.In_slow(expr.Negate)
551557
}
552-
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil
558+
559+
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil
553560
case *BindVariable:
554561
return ctype{}, c.unsupported(expr)
555562
default:

go/vt/vtgate/evalengine/expr_logical.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ func (expr *NotExpr) compile(c *compiler) (ctype, error) {
379379
c.asm.Not_i()
380380
}
381381
c.asm.jumpDestination(skip)
382-
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
382+
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil
383383
}
384384

385385
func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) {
@@ -450,7 +450,7 @@ func (expr *LogicalExpr) compile(c *compiler) (ctype, error) {
450450

451451
expr.op.compileRight(c)
452452
c.asm.jumpDestination(jump)
453-
return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil
453+
return ctype{Type: sqltypes.Int64, Flag: ((lt.Flag | rt.Flag) & flagNullable) | flagIsBoolean, Col: collationNumeric}, nil
454454
}
455455

456456
func intervalCompare(n, val eval) (int, bool, error) {
@@ -711,7 +711,11 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
711711
}
712712
}
713713

714-
ct := ctype{Type: ta.result(), Col: ca.result()}
714+
var f typeFlag
715+
if ta.nullable {
716+
f |= flagNullable
717+
}
718+
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
715719
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col)
716720
return ct, nil
717721
}

go/vt/vtgate/evalengine/expr_tuple.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ func (tuple TupleExpr) FormatFast(buf *sqlparser.TrackedBuffer) {
6666
}
6767

6868
func (tuple TupleExpr) typeof(*ExpressionEnv) (ctype, error) {
69-
return ctype{Type: sqltypes.Tuple}, nil
69+
return ctype{Type: sqltypes.Tuple, Col: collationBinary}, nil
7070
}

go/vt/vtgate/evalengine/fn_base64.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) {
110110
c.asm.Fn_TO_BASE64(t, col)
111111
c.asm.jumpDestination(skip)
112112

113-
return ctype{Type: t, Col: col}, nil
113+
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil
114114
}
115115

116116
func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) {
@@ -155,5 +155,5 @@ func (call *builtinFromBase64) compile(c *compiler) (ctype, error) {
155155
c.asm.Fn_FROM_BASE64(t)
156156
c.asm.jumpDestination(skip)
157157

158-
return ctype{Type: t, Col: collationBinary}, nil
158+
return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: collationBinary}, nil
159159
}

go/vt/vtgate/evalengine/fn_bit.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ func (expr *builtinBitCount) compile(c *compiler) (ctype, error) {
6161
if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() {
6262
c.asm.BitCount_b()
6363
c.asm.jumpDestination(skip)
64-
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
64+
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
6565
}
6666

6767
_ = c.compileToBitwiseUint64(ct, 1)
6868
c.asm.BitCount_u()
6969
c.asm.jumpDestination(skip)
70-
return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil
70+
return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil
7171
}

0 commit comments

Comments
 (0)