diff --git a/go/mysql/collations/integration/main_test.go b/go/mysql/collations/integration/main_test.go index f0d8d4dfdfa..23c6f8d2716 100644 --- a/go/mysql/collations/integration/main_test.go +++ b/go/mysql/collations/integration/main_test.go @@ -47,7 +47,7 @@ func mysqlconn(t *testing.T) *mysql.Conn { if err != nil { t.Fatal(err) } - if !strings.HasPrefix(conn.ServerVersion, "8.0.") { + if !strings.HasPrefix(conn.ServerVersion, "8.") { conn.Close() t.Skipf("collation integration tests are only supported in MySQL 8.0+") } diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index cb6e71cc4fa..6b97104d1ed 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -526,7 +526,7 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll asm.emit(func(env *ExpressionEnv) int { end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { - if env.vm.stack[sp].(*evalInt64).i != 0 { + if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation) goto done } @@ -780,8 +780,8 @@ func (asm *assembler) Convert_bB(offset int) { var f float64 if arg != nil { f, _ = fastparse.ParseFloat64(arg.(*evalBytes).string()) + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0) } - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(f != 0.0) return 1 }, "CONV VARBINARY(SP-%d), BOOL", offset) } @@ -789,7 +789,9 @@ func (asm *assembler) Convert_bB(offset int) { func (asm *assembler) Convert_TB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalTemporal).isZero()) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalTemporal).isZero()) + } return 1 }, "CONV SQLTYPES(SP-%d), BOOL", offset) } @@ -837,7 +839,9 @@ func (asm *assembler) Convert_Tj(offset int) { func (asm *assembler) Convert_dB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && !arg.(*evalDecimal).dec.IsZero()) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(!arg.(*evalDecimal).dec.IsZero()) + } return 1 }, "CONV DECIMAL(SP-%d), BOOL", offset) } @@ -857,7 +861,9 @@ func (asm *assembler) Convert_dbit(offset int) { func (asm *assembler) Convert_fB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalFloat).f != 0.0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalFloat).f != 0.0) + } return 1 }, "CONV FLOAT64(SP-%d), BOOL", offset) } @@ -904,7 +910,9 @@ func (asm *assembler) Convert_Tf(offset int) { func (asm *assembler) Convert_iB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalInt64).i != 0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalInt64).i != 0) + } return 1 }, "CONV INT64(SP-%d), BOOL", offset) } @@ -984,7 +992,9 @@ func (asm *assembler) Convert_Nj(offset int) { func (asm *assembler) Convert_uB(offset int) { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-offset] - env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg != nil && arg.(*evalUint64).u != 0) + if arg != nil { + env.vm.stack[env.vm.sp-offset] = env.vm.arena.newEvalBool(arg.(*evalUint64).u != 0) + } return 1 }, "CONV UINT64(SP-%d), BOOL", offset) } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 9cc90d5a18f..4c21de2dde6 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -444,6 +444,38 @@ func TestCompilerSingle(t *testing.T) { expression: `INTERVAL(0, 0, 0, -1, NULL, NULL, 1)`, result: `INT64(5)`, }, + { + expression: `cast(null * 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null + 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null - 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null / 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `cast(null % 1 as CHAR)`, + result: `NULL`, + }, + { + expression: `1 AND NULL * 1`, + result: `NULL`, + }, + { + expression: `case 0 when NULL then 1 else 0 end`, + result: `INT64(0)`, + }, + { + expression: `case when null is null then 23 else null end`, + result: `INT64(23)`, + }, } for _, tc := range testCases { diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 50326c9eb3c..3622a270e9a 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -191,7 +191,7 @@ func (op *opArithAdd) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sumtype, Col: collationNumeric}, nil + return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithSub) eval(left, right eval) (eval, error) { @@ -274,7 +274,7 @@ func (op *opArithSub) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: subtype, Col: collationNumeric}, nil + return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithMul) eval(left, right eval) (eval, error) { @@ -334,7 +334,7 @@ func (op *opArithMul) compile(c *compiler, left, right Expr) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: multype, Col: collationNumeric}, nil + return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (op *opArithDiv) eval(left, right eval) (eval, error) { @@ -604,5 +604,9 @@ func (expr *NegateExpr) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip) - return ctype{Type: neg, Col: collationNumeric}, nil + return ctype{Type: neg, Flag: nullableFlags(arg.Flag), Col: collationNumeric}, nil +} + +func nullableFlags(flag typeFlag) typeFlag { + return flag & (flagNull | flagNullable) } diff --git a/go/vt/vtgate/evalengine/expr_bit.go b/go/vt/vtgate/evalengine/expr_bit.go index 9da632eed30..0c6cf56367d 100644 --- a/go/vt/vtgate/evalengine/expr_bit.go +++ b/go/vt/vtgate/evalengine/expr_bit.go @@ -298,7 +298,7 @@ func (expr *BitwiseExpr) compileBinary(c *compiler, asm_ins_bb, asm_ins_uu func( asm_ins_uu() c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { @@ -327,8 +327,8 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil } - _ = c.compileToBitwiseUint64(lt, 2) - _ = c.compileToUint64(rt, 1) + lt = c.compileToBitwiseUint64(lt, 2) + rt = c.compileToUint64(rt, 1) if i < 0 { c.asm.BitShiftLeft_uu() @@ -337,7 +337,7 @@ func (expr *BitwiseExpr) compileShift(c *compiler, i int) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil } func (expr *BitwiseExpr) compile(c *compiler) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 387438c4310..e2c1503b754 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -99,9 +99,9 @@ func (bv *BindVariable) typeof(env *ExpressionEnv, _ []*querypb.Field) (sqltypes case sqltypes.Null: return sqltypes.Null, flagNull | flagNullable case sqltypes.HexNum, sqltypes.HexVal: - return sqltypes.VarBinary, flagHex + return sqltypes.VarBinary, flagHex | flagNullable default: - return tt, 0 + return tt, flagNullable } } diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index cef7493e026..f5329e8973c 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -350,11 +350,13 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) { swapped := false var skip2 *jump + nullable := true switch expr.Op.(type) { case compareNullSafeEQ: skip2 = c.asm.jumpFrom() c.asm.Cmp_nullsafe(skip2) + nullable = false default: skip2 = c.compileNullCheck1r(rt) } @@ -392,6 +394,9 @@ func (expr *ComparisonExpr) compile(c *compiler) (ctype, error) { } cmptype := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean} + if nullable { + cmptype.Flag |= nullableFlags(lt.Flag | rt.Flag) + } switch expr.Op.(type) { case compareEQ: @@ -530,17 +535,17 @@ func (expr *InExpr) compile(c *compiler) (ctype, error) { } rhs := expr.Right.(TupleExpr) - + var rt ctype if table := expr.compileTable(lhs, rhs); table != nil { c.asm.In_table(expr.Negate, table) } else { - _, err := rhs.compile(c) + rt, err = rhs.compile(c) if err != nil { return ctype{}, err } c.asm.In_slow(expr.Negate) } - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | (nullableFlags(lhs.Flag) | (rt.Flag & flagNullable))}, nil } func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool { diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 189b68e4136..3b2cd4a6ae5 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -254,7 +254,7 @@ func (expr *NotExpr) compile(c *compiler) (ctype, error) { c.asm.Not_i() } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (l *LogicalExpr) eval(env *ExpressionEnv) (eval, error) { @@ -331,7 +331,7 @@ func (expr *LogicalExpr) compile(c *compiler) (ctype, error) { c.asm.LogicalRight(expr.opname) c.asm.jumpDestination(jump) - return ctype{Type: sqltypes.Int64, Flag: flagNullable | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: ((lt.Flag | rt.Flag) & flagNullable) | flagIsBoolean, Col: collationNumeric}, nil } func intervalCompare(n, val eval) (int, bool, error) { @@ -629,7 +629,11 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } } - ct := ctype{Type: ta.result(), Col: ca.result()} + var f typeFlag + if ta.nullable { + f |= flagNullable + } + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_base64.go b/go/vt/vtgate/evalengine/fn_base64.go index 0e27052641f..6beb7b0209c 100644 --- a/go/vt/vtgate/evalengine/fn_base64.go +++ b/go/vt/vtgate/evalengine/fn_base64.go @@ -119,7 +119,7 @@ func (call *builtinToBase64) compile(c *compiler) (ctype, error) { c.asm.Fn_TO_BASE64(t, col) c.asm.jumpDestination(skip) - return ctype{Type: t, Col: col}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil } func (call *builtinFromBase64) eval(env *ExpressionEnv) (eval, error) { @@ -172,5 +172,5 @@ func (call *builtinFromBase64) compile(c *compiler) (ctype, error) { c.asm.Fn_FROM_BASE64(t) c.asm.jumpDestination(skip) - return ctype{Type: t, Col: collationBinary}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: collationBinary}, nil } diff --git a/go/vt/vtgate/evalengine/fn_bit.go b/go/vt/vtgate/evalengine/fn_bit.go index 5a89ff41276..2aba7fa2549 100644 --- a/go/vt/vtgate/evalengine/fn_bit.go +++ b/go/vt/vtgate/evalengine/fn_bit.go @@ -68,11 +68,11 @@ func (expr *builtinBitCount) compile(c *compiler) (ctype, error) { if ct.Type == sqltypes.VarBinary && !ct.isHexOrBitLiteral() { c.asm.BitCount_b() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil } _ = c.compileToBitwiseUint64(ct, 1) c.asm.BitCount_u() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationBinary}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(ct.Flag), Col: collationBinary}, nil } diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index ef8d7f4e3c8..bfd69c7252d 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -326,7 +326,9 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, env := collations.Local() var ca collationAggregation + var f typeFlag for _, arg := range args { + f |= nullableFlags(arg.Flag) if err := ca.add(env, arg.Col); err != nil { return ctype{}, err } @@ -334,15 +336,17 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, tc := ca.result() c.asm.Fn_MULTICMP_c(len(args), call.cmp < 0, tc) - return ctype{Type: sqltypes.VarChar, Col: tc}, nil + return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil } func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, error) { + var f typeFlag for i, tt := range args { + f |= nullableFlags(tt.Flag) c.compileToDecimal(tt, len(args)-i) } c.asm.Fn_MULTICMP_d(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric}, nil } func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { @@ -354,6 +358,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { text int binary int args []ctype + nullable bool ) /* @@ -373,6 +378,7 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { args = append(args, tt) + nullable = nullable || tt.nullable() switch tt.Type { case sqltypes.Int64: signed++ @@ -386,19 +392,25 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { text++ case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + case sqltypes.Null: + nullable = true default: return ctype{}, c.unsupported(call) } } + var f typeFlag + if nullable { + f |= flagNullable + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: f, Col: collationNumeric}, nil } if unsigned == len(args) { c.asm.Fn_MULTICMP_u(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Uint64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Uint64, Flag: f, Col: collationNumeric}, nil } return call.compile_d(c, args) } @@ -407,14 +419,14 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { return call.compile_c(c, args) } c.asm.Fn_MULTICMP_b(len(args), call.cmp < 0) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + return ctype{Type: sqltypes.VarBinary, Flag: f, Col: collationBinary}, nil } else { if floats > 0 { for i, tt := range args { c.compileToFloat(tt, len(args)-i) } c.asm.Fn_MULTICMP_f(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Float64, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Float64, Flag: f, Col: collationNumeric}, nil } if decimals > 0 { return call.compile_d(c, args) @@ -447,9 +459,13 @@ type typeAggregation struct { geometry uint16 blob uint16 total uint16 + nullable bool } func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { + if f&flagNullable != 0 { + ta.nullable = true + } switch tt { case sqltypes.Float32, sqltypes.Float64: ta.double++ diff --git a/go/vt/vtgate/evalengine/fn_crypto.go b/go/vt/vtgate/evalengine/fn_crypto.go index 8b3eeac2f99..491fa7bcf68 100644 --- a/go/vt/vtgate/evalengine/fn_crypto.go +++ b/go/vt/vtgate/evalengine/fn_crypto.go @@ -74,7 +74,7 @@ func (call *builtinMD5) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_MD5(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinSHA1 struct { @@ -121,7 +121,7 @@ func (call *builtinSHA1) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_SHA1(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinSHA2 struct { @@ -205,7 +205,7 @@ func (call *builtinSHA2) compile(c *compiler) (ctype, error) { col := defaultCoercionCollation(c.cfg.Collation) c.asm.Fn_SHA2(col) c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.VarChar, Col: col, Flag: str.Flag | flagNullable}, nil + return ctype{Type: sqltypes.VarChar, Col: col, Flag: nullableFlags(str.Flag)}, nil } type builtinRandomBytes struct { @@ -265,5 +265,5 @@ func (call *builtinRandomBytes) compile(c *compiler) (ctype, error) { c.asm.Fn_RandomBytes() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary, Flag: arg.Flag | flagNullable}, nil + return ctype{Type: sqltypes.VarBinary, Col: collationBinary, Flag: nullableFlags(arg.Flag) | flagNullable}, nil } diff --git a/go/vt/vtgate/evalengine/fn_hex.go b/go/vt/vtgate/evalengine/fn_hex.go index 3aa395f74cc..a6e189bdf6f 100644 --- a/go/vt/vtgate/evalengine/fn_hex.go +++ b/go/vt/vtgate/evalengine/fn_hex.go @@ -88,7 +88,7 @@ func (call *builtinHex) compile(c *compiler) (ctype, error) { c.asm.jumpDestination(skip) - return ctype{Type: t, Col: col}, nil + return ctype{Type: t, Flag: nullableFlags(str.Flag), Col: col}, nil } type builtinUnhex struct { diff --git a/go/vt/vtgate/evalengine/fn_json.go b/go/vt/vtgate/evalengine/fn_json.go index 31f568a3df5..ff80964800f 100644 --- a/go/vt/vtgate/evalengine/fn_json.go +++ b/go/vt/vtgate/evalengine/fn_json.go @@ -431,7 +431,7 @@ func (call *builtinJSONContainsPath) compile(c *compiler) (ctype, error) { } c.asm.Fn_JSON_CONTAINS_PATH(match, paths) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | flagNullable}, nil } type jsonMatch int8 diff --git a/go/vt/vtgate/evalengine/fn_misc.go b/go/vt/vtgate/evalengine/fn_misc.go index 04770c387af..966017edbc7 100644 --- a/go/vt/vtgate/evalengine/fn_misc.go +++ b/go/vt/vtgate/evalengine/fn_misc.go @@ -319,7 +319,7 @@ func (call *builtinIsIPV4) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_IPV4() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV4Compat) eval(env *ExpressionEnv) (eval, error) { @@ -354,7 +354,7 @@ func (call *builtinIsIPV4Compat) compile(c *compiler) (ctype, error) { c.asm.SetBool(1, false) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV4Mapped) eval(env *ExpressionEnv) (eval, error) { @@ -389,7 +389,7 @@ func (call *builtinIsIPV4Mapped) compile(c *compiler) (ctype, error) { c.asm.SetBool(1, false) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinIsIPV6) eval(env *ExpressionEnv) (eval, error) { @@ -425,7 +425,7 @@ func (call *builtinIsIPV6) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_IPV6() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func errIncorrectUUID(in []byte, f string) error { @@ -502,7 +502,7 @@ func (call *builtinBinToUUID) compile(c *compiler) (ctype, error) { } col := defaultCoercionCollation(call.collate) - ct := ctype{Type: sqltypes.VarChar, Flag: arg.Flag, Col: col} + ct := ctype{Type: sqltypes.VarChar, Flag: nullableFlags(arg.Flag), Col: col} if len(call.Arguments) == 1 { c.asm.Fn_BIN_TO_UUID0(col) @@ -561,7 +561,7 @@ func (call *builtinIsUUID) compile(c *compiler) (ctype, error) { c.asm.Fn_IS_UUID() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Flag: arg.Flag | flagIsBoolean, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Int64, Flag: nullableFlags(arg.Flag) | flagIsBoolean, Col: collationNumeric}, nil } func (call *builtinUUID) eval(env *ExpressionEnv) (eval, error) { @@ -636,7 +636,7 @@ func (call *builtinUUIDToBin) compile(c *compiler) (ctype, error) { c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } - ct := ctype{Type: sqltypes.VarBinary, Flag: arg.Flag, Col: collationBinary} + ct := ctype{Type: sqltypes.VarBinary, Flag: nullableFlags(arg.Flag), Col: collationBinary} if len(call.Arguments) == 1 { c.asm.Fn_UUID_TO_BIN0() diff --git a/go/vt/vtgate/evalengine/fn_numeric.go b/go/vt/vtgate/evalengine/fn_numeric.go index 3802fbd5630..a69a6e35a86 100644 --- a/go/vt/vtgate/evalengine/fn_numeric.go +++ b/go/vt/vtgate/evalengine/fn_numeric.go @@ -185,7 +185,7 @@ func (expr *builtinAbs) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) - convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: arg.Flag} + convt := ctype{Type: arg.Type, Col: collationNumeric, Flag: nullableFlags(arg.Flag)} switch arg.Type { case sqltypes.Int64: c.asm.Fn_ABS_i() @@ -357,7 +357,7 @@ func (expr *builtinAtan2) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_ATAN2() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } func (call *builtinAtan2) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) { @@ -643,7 +643,7 @@ func (expr *builtinLog) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_LOG() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } type builtinLog10 struct { @@ -758,7 +758,7 @@ func (expr *builtinPow) compile(c *compiler) (ctype, error) { c.compileToFloat(arg2, 1) c.asm.Fn_POW() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: arg1.Flag | arg2.Flag | flagNullable}, nil + return ctype{Type: sqltypes.Float64, Col: collationNumeric, Flag: nullableFlags(arg1.Flag | arg2.Flag)}, nil } type builtinSign struct { @@ -843,7 +843,7 @@ func (expr *builtinSign) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(arg.Flag)}, nil } type builtinSqrt struct { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 7146ac03b68..bbc54a65480 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -270,7 +270,7 @@ func (call *builtinASCII) compile(c *compiler) (ctype, error) { c.asm.Fn_ASCII() c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: str.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(str.Flag)}, nil } func charOrd(b []byte, coll collations.ID) int64 { @@ -330,7 +330,7 @@ func (call *builtinOrd) compile(c *compiler) (ctype, error) { c.asm.Fn_ORD(col) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: str.Flag}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(str.Flag)}, nil } // maxRepeatLength is the maximum number of times a string can be repeated. @@ -721,7 +721,7 @@ func (call builtinPad) compile(c *compiler) (ctype, error) { c.asm.Fn_RPAD(col) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } func strcmpCollate(left, right []byte, col collations.ID) int64 { @@ -823,7 +823,7 @@ func (expr *builtinStrcmp) compile(c *compiler) (ctype, error) { c.asm.Strcmp(mcol) c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil + return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: nullableFlags(lt.Flag | rt.Flag)}, nil } func (call builtinTrim) eval(env *ExpressionEnv) (eval, error) { @@ -912,7 +912,7 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) { c.asm.Fn_TRIM1(col) } c.asm.jumpDestination(skip1) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: nullableFlags(str.Flag), Col: col}, nil } pat, err := call.Arguments[1].compile(c) @@ -943,7 +943,7 @@ func (call builtinTrim) compile(c *compiler) (ctype, error) { } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sqltypes.VarChar, Col: col}, nil + return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil } type builtinConcat struct { diff --git a/go/vt/vtgate/evalengine/testcases/inputs.go b/go/vt/vtgate/evalengine/testcases/inputs.go index badf7bc9dc8..4e3a1525172 100644 --- a/go/vt/vtgate/evalengine/testcases/inputs.go +++ b/go/vt/vtgate/evalengine/testcases/inputs.go @@ -92,6 +92,8 @@ var inputConversions = []string{ `'foobar'`, `_utf8 'foobar'`, `''`, `_binary 'foobar'`, `0x0`, `0x1`, `0xff`, `X'00'`, `X'01'`, `X'ff'`, "NULL", "true", "false", + "NULL * 1", "1 * NULL", "NULL * NULL", "NULL / 1", "1 / NULL", "NULL / NULL", + "NULL + 1", "1 + NULL", "NULL + NULL", "NULL - 1", "1 - NULL", "NULL - NULL", "0xFF666F6F626172FF", "0x666F6F626172FF", "0xFF666F6F626172", "9223372036854775807", "-9223372036854775808", "18446744073709551615", "18446744073709540000e0", diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index c7cd12a79e4..5dc54699d9f 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -738,7 +738,7 @@ func TestSelectLastInsertId(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "last_insert_id()", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "last_insert_id()", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewUint64(52), @@ -823,9 +823,9 @@ func TestSelectInitializedVitessAwareVariable(t *testing.T) { result, err := executorExec(executor, sql, nil) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, - {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@autocommit", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@enable_system_settings", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, + {Name: "@@query_timeout", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), @@ -861,7 +861,7 @@ func TestSelectUserDefinedVariable(t *testing.T) { require.NoError(t, err) wantResult = &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "@foo", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewVarChar("bar"), @@ -884,7 +884,7 @@ func TestFoundRows(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "found_rows()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "found_rows()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(1), @@ -914,7 +914,7 @@ func testRowCount(t *testing.T, executor *Executor, wantRowCount int64) { result, err := executorExec(executor, "select row_count()", map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "row_count()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "row_count()", Type: sqltypes.Int64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(wantRowCount), @@ -931,7 +931,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { result1 := []*sqltypes.Result{{ Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, InsertID: 0, Rows: [][]sqltypes.Value{{ @@ -945,7 +945,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG)}, + {Name: "id", Type: sqltypes.Int32, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt32(52), @@ -1008,7 +1008,7 @@ func TestLastInsertIDInSubQueryExpression(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "x", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "x", Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_UNSIGNED_FLAG)}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewUint64(12345), @@ -1036,7 +1036,7 @@ func TestSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.VarChar, Charset: uint32(collations.Default()), Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "database()", Type: sqltypes.VarChar, Charset: uint32(collations.Default())}, }, Rows: [][]sqltypes.Value{{ sqltypes.NewVarChar("TestExecutor@primary"),