diff --git a/go/mysql/collations/integration/main_test.go b/go/mysql/collations/integration/main_test.go index f0d8d4dfdfa..23c6f8d2716 100644 --- a/go/mysql/collations/integration/main_test.go +++ b/go/mysql/collations/integration/main_test.go @@ -47,7 +47,7 @@ func mysqlconn(t *testing.T) *mysql.Conn { if err != nil { t.Fatal(err) } - if !strings.HasPrefix(conn.ServerVersion, "8.0.") { + if !strings.HasPrefix(conn.ServerVersion, "8.") { conn.Close() t.Skipf("collation integration tests are only supported in MySQL 8.0+") } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 105ee3f4721..d62a164b55a 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -528,7 +528,7 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll asm.emit(func(env *ExpressionEnv) int { end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { - if env.vm.stack[sp].(*evalInt64).i != 0 { + if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation) goto done } @@ -782,8 +782,8 @@ func (asm *assembler) Convert_bB(offset int) { var f float64 if arg != nil { f, _ = fastparse.ParseFloat64(arg.(*evalBytes).string()) + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0) } - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0) return 1 }, "CONV VARBINARY(SP-%d), BOOL", offset) } @@ -791,7 +791,9 @@ func (asm *assembler) Convert_bB(offset int) { func (asm *assembler) Convert_TB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalTemporal).isZero()) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalTemporal).isZero()) + } return 1 }, "CONV SQLTYPES(SP-%d), BOOL", offset) } @@ -839,7 +841,9 @@ func (asm *assembler) Convert_Tj(offset int) { func (asm *assembler) Convert_dB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalDecimal).dec.IsZero()) + } return 1 }, "CONV DECIMAL(SP-%d), BOOL", offset) } @@ -859,7 +863,9 @@ func (asm *assembler) Convert_dbit(offset int) { func (asm *assembler) Convert_fB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalFloat).f != 0.0) + } return 1 }, "CONV FLOAT64(SP-%d), BOOL", offset) } @@ -906,7 +912,9 @@ func (asm *assembler) Convert_Tf(offset int) { func (asm *assembler) Convert_iB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalInt64).i != 0) + } return 1 }, "CONV INT64(SP-%d), BOOL", offset) } @@ -986,7 +994,9 @@ func (asm *assembler) Convert_Nj(offset int) { func (asm *assembler) Convert_uB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalUint64).u != 0) + } return 1 }, "CONV UINT64(SP-%d), BOOL", offset) } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index efcf0036acb..06dd679920e 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -455,6 +455,38 @@ func TestCompilerSingle(t *testing.T) { expression: `WEIGHT_STRING('foobar' as char(3))`, result: `VARBINARY("\x1c\xe5\x1d\xdd\x1d\xdd")`, }, + { + expression: `cast(null * 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null + 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null - 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null / 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null % 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `1 AND NULL * 1`, + result: `NULL`, + }, + { + expression: `case 0 when NULL then 1 else 0 end`, + result: `INT64(0)`, + }, + { + expression: `case when null is null then 23 else null end`, + result: `INT64(23)`, + }, } for _, tc := range testCases { diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 50326c9eb3c..3622a270e9a 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -191,7 +191,7 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sumtype, Col: collationNumeric}, nil + return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithSub) eval(left, right eval) (eval, error) { @@ -274,7 +274,7 @@ func (op *opArithSub) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: subtype, Col: collationNumeric}, nil + return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithMul) eval(left, right eval) (eval, error) { @@ -334,7 +334,7 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: multype, Col: collationNumeric}, nil + return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithDiv) eval(left, right eval) (eval, error) { @@ -604,5 +604,9 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip) - return ctype{Type: neg, Col: collationNumeric}, nil + return ctype{Type: neg, Flag: nullableFlags(arg.Flag), Col: collationNumeric}, nil +} + +func nullableFlags(flag typeFlag) typeFlag { + return flag & (flagNull | flagNullable) } diff --git a/go/vt/vtgate/evalengine/expr_bit.go b/go/vt/vtgate/evalengine/expr_bit.go index 9da632eed30..0c6cf56367d 100644 --- a/go/vt/vtgate/evalengine/expr_bit.go +++ b/go/vt/vtgate/evalengine/expr_bit.go @@ -298,7 +298,7 @@ func (expr *BitwiseExpr) compileBinary(c *compiler, asm_ins_bb, asm_ins_uu func( asm_ins_uu() c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { @@ -327,8 +327,8 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } - _ = c.compileToBitwiseUint64(lt, 2) - _ = c.compileToUint64(rt, 1) + lt = c.compileToBitwiseUint64(lt, 2) + rt = c.compileToUint64(rt, 1) if i < 0 { c.asm.BitShiftLeft_uu() @@ -337,7 +337,7 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (expr *BitwiseExpr) compile(c *compiler) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 9172f8abc3c..98d70ae0f42 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -91,9 +91,9 @@ func (bv *BindVariable) typeof(env *ExpressionEnv, _ []*querypb.Field) (sqltypes case sqltypes.Null: return sqltypes.Null, flagNull | flagNullable case sqltypes.HexNum, sqltypes.HexVal: - return sqltypes.VarBinary, flagHex + return sqltypes.VarBinary, flagHex | flagNullable default: - return tt, 0 + return tt, flagNullable } } diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index e7490370a1b..c98802818a8 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -373,11 +373,13 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) { swapped := false var skip2 *jump + nullable := true switch expr.Op.(type) { case compareNullSafeEQ: skip2 = c.asm.jumpFrom() c.asm.Cmp_nullsafe(skip2) + nullable = false default: skip2 = c.compileNullCheck1r(rt) } @@ -415,6 +417,9 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) { } cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean} + if nullable { + cmptype.Flag |= nullableFlags(lt.Flag | rt.Flag) + } switch expr.Op.(type) { case compareEQ: @@ -554,16 +559,17 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) { rhs := expr.Right.(TupleExpr) + var rt ctype if table := expr.compileTable(lhs, rhs); table != nil { c.asm.In_table(expr.Negate, table) } else { - _, err := rhs.compile(c) + rt, err = rhs.compile(c) if err != nil { return ctype{}, err } c.asm.In_slow(expr.Negate) } - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil } func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool { diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 189b68e4136..3b2cd4a6ae5 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -254,7 +254,7 @@ func (expr *NotExpr) compile(c *compiler) (ctype, error) { c.asm.Not_i() } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) { @@ -331,7 +331,7 @@ func (expr *LogicalExpr) compile(c *compiler) (ctype, error) { c.asm.LogicalRight(expr.opname) c.asm.jumpDestination(jump) - return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: ((lt.Flag | rt.Flag) & flagNullable) | flagIsBoolean, Col: collationNumeric}, nil } func intervalCompare(n, val eval) (int, bool, error) { @@ -629,7 +629,11 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } } - ct := ctype{Type: ta.result(), Col: ca.result()} + var f typeFlag + if ta.nullable { + f |= flagNullable + } + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_base64.go b/go/vt/vtgate/evalengine/fn_base64.go index 0e27052641f..6beb7b0209c 100644 --- a/go/vt/vtgate/evalengine/fn_base64.go +++ b/go/vt/vtgate/evalengine/fn_base64.go @@ -119,7 +119,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) { c.asm.Fn_TO_BASE64(t, col) c.asm.jumpDestination(skip) - return ctype{Type: t, Col: col}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil } func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) { @@ -172,5 +172,5 @@ func (call *builtinFromBase64) compile(c *compiler) (ctype, error) { c.asm.Fn_FROM_BASE64(t) c.asm.jumpDestination(skip) - return ctype{Type: t, Col: collationBinary}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: collationBinary}, nil } diff --git a/go/vt/vtgate/evalengine/fn_bit.go b/go/vt/vtgate/evalengine/fn_bit.go index 5a89ff41276..2aba7fa2549 100644 --- a/go/vt/vtgate/evalengine/fn_bit.go +++ b/go/vt/vtgate/evalengine/fn_bit.go @@ -68,11 +68,11 @@ func (expr *builtinBitCount) compile(c *compiler) (ctype, error) { if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { c.asm.BitCount_b() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil } _ = c.compileToBitwiseUint64(ct, 1) c.asm.BitCount_u() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil } diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index ee4f61cb596..ea61d4524a8 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -327,7 +327,9 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, env := collations.Local() var ca collationAggregation + var f typeFlag for _, arg := range args { + f |= nullableFlags(arg.Flag) if err := ca.add(env, arg.Col); err != nil { return ctype{}, err } @@ -335,15 +337,17 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, tc := ca.result() c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, tc) - return ctype{Type: sqltypes.VarChar, Col: tc}, nil + return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil } func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, error) { + var f typeFlag for i, tt := range args { + f |= nullableFlags(tt.Flag) c.compileToDecimal(tt, len(args)-i) } c.asm.Fn_MULTICMP_d(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric}, nil } func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { @@ -355,6 +359,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { text int binary int args []ctype + nullable bool ) /* @@ -374,6 +379,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { args = append(args, tt) + nullable = nullable || tt.nullable() switch tt.Type { case sqltypes.Int64: signed++ @@ -387,19 +393,25 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { text++ case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + case sqltypes.Null: + nullable = true default: return ctype{}, c.unsupported(call) } } + var f typeFlag + if nullable { + f |= flagNullable + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: f, Col: collationNumeric}, nil } if unsigned == len(args) { c.asm.Fn_MULTICMP_u(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: f, Col: collationNumeric}, nil } return call.compile_d(c, args) } @@ -408,14 +420,14 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { return call.compile_c(c, args) } c.asm.Fn_MULTICMP_b(len(args), call.cmp < 0) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + return ctype{Type: sqltypes.VarBinary, Flag: f, Col: collationBinary}, nil } else { if floats > 0 { for i, tt := range args { c.compileToFloat(tt, len(args)-i) } c.asm.Fn_MULTICMP_f(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Flag: f, Col: collationNumeric}, nil } if decimals > 0 { return call.compile_d(c, args) @@ -448,9 +460,14 @@ type typeAggregation struct { geometry uint16 blob uint16 total uint16 + + nullable bool } func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { + if f&flagNullable != 0 { + ta.nullable = true + } switch tt { case sqltypes.Float32, sqltypes.Float64: ta.double++ diff --git a/go/vt/vtgate/evalengine/fn_crypto.go b/go/vt/vtgate/evalengine/fn_crypto.go index 8b3eeac2f99..491fa7bcf68 100644 --- a/go/vt/vtgate/evalengine/fn_crypto.go +++ b/go/vt/vtgate/evalengine/fn_crypto.go @@ -74,7 +74,7 @@ func (call *builtinMD5) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_MD5(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinSHA1 struct { @@ -121,7 +121,7 @@ func (call *builtinSHA1) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_SHA1(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinSHA2 struct { @@ -205,7 +205,7 @@ func (call *builtinSHA2) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_SHA2(col) c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag | flagNullable}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinRandomBytes struct { @@ -265,5 +265,5 @@ func (call *builtinRandomBytes) compile(c *compiler) (ctype, error) { c.asm.Fn_RandomBytes() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary, Flag: arg.Flag | flagNullable}, nil + return ctype{Type: sqltypes.VarBinary, Col: collationBinary, Flag: nullableFlags(arg.Flag) | flagNullable}, nil } diff --git a/go/vt/vtgate/evalengine/fn_hex.go b/go/vt/vtgate/evalengine/fn_hex.go index 0045bfd6688..ca84921e049 100644 --- a/go/vt/vtgate/evalengine/fn_hex.go +++ b/go/vt/vtgate/evalengine/fn_hex.go @@ -88,7 +88,7 @@ func (call *builtinHex) compile(c *compiler) (ctype, error) { c.asm.jumpDestination(skip) - return ctype{Type: t, Col: col}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil } type builtinUnhex struct { diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index 7c7c6a67f8d..ed7f478f32b 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -431,7 +431,7 @@ func (call *builtinJSONContainsPath) compile(c *compiler) (ctype, error) { } c.asm.Fn_JSON_CONTAINS_PATH(match, paths) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | flagNullable}, nil } type jsonMatch int8 diff --git a/go/vt/vtgate/evalengine/fn_misc.go b/go/vt/vtgate/evalengine/fn_misc.go index 04770c387af..966017edbc7 100644 --- a/go/vt/vtgate/evalengine/fn_misc.go +++ b/go/vt/vtgate/evalengine/fn_misc.go @@ -319,7 +319,7 @@ func (call *builtinIsIPV4) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_IPV4() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV4Compat) eval(env *ExpressionEnv) (eval, error) { @@ -354,7 +354,7 @@ func (call *builtinIsIPV4Compat) compile(c *compiler) (ctype, error) { c.asm.SetBool(1, false) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV4Mapped) eval(env *ExpressionEnv) (eval, error) { @@ -389,7 +389,7 @@ func (call *builtinIsIPV4Mapped) compile(c *compiler) (ctype, error) { c.asm.SetBool(1, false) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV6) eval(env *ExpressionEnv) (eval, error) { @@ -425,7 +425,7 @@ func (call *builtinIsIPV6) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_IPV6() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func errIncorrectUUID(in []byte, f string) error { @@ -502,7 +502,7 @@ func (call *builtinBinToUUID) compile(c *compiler) (ctype, error) { } col := defaultCoercionCollation(call.collate) - ct := ctype{Type: sqltypes.VarChar, Flag: arg.Flag, Col: col} + ct := ctype{Type: sqltypes.VarChar, Flag: nullableFlags(arg.Flag), Col: col} if len(call.Arguments) == 1 { c.asm.Fn_BIN_TO_UUID0(col) @@ -561,7 +561,7 @@ func (call *builtinIsUUID) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_UUID() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinUUID) eval(env *ExpressionEnv) (eval, error) { @@ -636,7 +636,7 @@ func (call *builtinUUIDToBin) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } - ct := ctype{Type: sqltypes.VarBinary, Flag: arg.Flag, Col: collationBinary} + ct := ctype{Type: sqltypes.VarBinary, Flag: nullableFlags(arg.Flag), Col: collationBinary} if len(call.Arguments) == 1 { c.asm.Fn_UUID_TO_BIN0() diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 3802fbd5630..a69a6e35a86 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -185,7 +185,7 @@ func (expr *builtinAbs) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) - convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: arg.Flag} + convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: nullableFlags(arg.Flag)} switch arg.Type { case sqltypes.Int64: c.asm.Fn_ABS_i() @@ -357,7 +357,7 @@ func (expr *builtinAtan2) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_ATAN2() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } func (call *builtinAtan2) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) { @@ -643,7 +643,7 @@ func (expr *builtinLog) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_LOG() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } type builtinLog10 struct { @@ -758,7 +758,7 @@ func (expr *builtinPow) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_POW() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag | flagNullable}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } type builtinSign struct { @@ -843,7 +843,7 @@ func (expr *builtinSign) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(arg.Flag)}, nil } type builtinSqrt struct { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index b34618b00d2..ef3d9037dd2 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -271,7 +271,7 @@ func (call *builtinASCII) compile(c *compiler) (ctype, error) { c.asm.Fn_ASCII() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: str.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(str.Flag)}, nil } func charOrd(b []byte, coll collations.ID) int64 { @@ -331,7 +331,7 @@ func (call *builtinOrd) compile(c *compiler) (ctype, error) { c.asm.Fn_ORD(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: str.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(str.Flag)}, nil } // maxRepeatLength is the maximum number of times a string can be repeated. @@ -784,7 +784,7 @@ func (call builtinPad) compile(c *compiler) (ctype, error) { c.asm.Fn_RPAD(col) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } func strcmpCollate(left, right []byte, col collations.ID) int64 { @@ -886,7 +886,7 @@ func (expr *builtinStrcmp) compile(c *compiler) (ctype, error) { c.asm.Strcmp(mcol) c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(lt.Flag | rt.Flag)}, nil } func (call builtinTrim) eval(env *ExpressionEnv) (eval, error) { @@ -975,7 +975,7 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) { c.asm.Fn_TRIM1(col) } c.asm.jumpDestination(skip1) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: nullableFlags(str.Flag), Col: col}, nil } pat, err := call.Arguments[1].compile(c) @@ -1006,7 +1006,7 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } type builtinConcat struct { diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index 245318529c3..6172f0cd260 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -92,6 +92,8 @@ var inputConversions = []string{ `'foobar'`, `_utf8 'foobar'`, `''`, `_binary 'foobar'`, `0x0`, `0x1`, `0xff`, `X'00'`, `X'01'`, `X'ff'`, "NULL", "true", "false", + "NULL * 1", "1 * NULL", "NULL * NULL", "NULL / 1", "1 / NULL", "NULL / NULL", + "NULL + 1", "1 + NULL", "NULL + NULL", "NULL - 1", "1 - NULL", "NULL - NULL", "0xFF666F6F626172FF", "0x666F6F626172FF", "0xFF666F6F626172", "9223372036854775807", "-9223372036854775808", "18446744073709551615", "18446744073709540000e0", diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 829cd844fd1..10440fc4af5 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -753,7 +753,7 @@ func TestSelectLastInsertId(t *testing.T) { result, err := executorExec(ctx, executor, session, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "last_insert_id()", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "last_insert_id()", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewUint64(52), @@ -785,20 +785,20 @@ func TestSelectSystemVariables(t *testing.T) { result, err := executorExec(ctx, executor, session, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@client_found_rows", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@skip_query_plan_cache", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@sql_select_limit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@transaction_mode", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@workload", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@read_after_write_gtid", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@read_after_write_timeout", Type: sqltypes.Float64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@session_track_gtids", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@ddl_strategy", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@migration_context", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@socket", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, - {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@client_found_rows", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@skip_query_plan_cache", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@sql_select_limit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@transaction_mode", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@workload", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@read_after_write_gtid", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@read_after_write_timeout", Type: sqltypes.Float64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@session_track_gtids", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@ddl_strategy", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@migration_context", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@socket", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, + {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ // the following are the uninitialised session values @@ -841,9 +841,9 @@ func TestSelectInitializedVitessAwareVariable(t *testing.T) { result, err := executorExec(ctx, executor, session, sql, nil) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), @@ -882,7 +882,7 @@ func TestSelectUserDefinedVariable(t *testing.T) { require.NoError(t, err) wantResult = &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "@foo", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewVarChar("bar"), @@ -908,7 +908,7 @@ func TestFoundRows(t *testing.T) { result, err := executorExec(ctx, executor, session, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "found_rows()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "found_rows()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), @@ -941,7 +941,7 @@ func testRowCount(t *testing.T, ctx context.Context, executor *Executor, session result, err := executorExec(ctx, executor, session, "select row_count()", map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "row_count()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "row_count()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(wantRowCount), @@ -962,7 +962,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { result1 := []*sqltypes.Result{{ Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, InsertID: 0, Rows: [][]sqltypes.Value{{ @@ -976,7 +976,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt32(52), @@ -1044,7 +1044,7 @@ func TestLastInsertIDInSubQueryExpression(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "x", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "x", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewUint64(12345), @@ -1075,7 +1075,7 @@ func TestSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "database()", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewVarChar("TestExecutor@primary"),