Skip to content

Commit

Permalink
evalengine: Proper support for bit literals (#14374)
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <vmg@strn.cat>
  • Loading branch information
vmg authored Oct 26, 2023
1 parent 2f56827 commit bc8df35
Show file tree
Hide file tree
Showing 32 changed files with 318 additions and 116 deletions.
6 changes: 5 additions & 1 deletion go/mysql/collations/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ func Default() ID {
}

func DefaultCollationForType(t sqltypes.Type) ID {
return CollationForType(t, Default())
}

func CollationForType(t sqltypes.Type, fallback ID) ID {
switch {
case sqltypes.IsText(t):
return Default()
return fallback
case t == sqltypes.TypeJSON:
return CollationUtf8mb4ID
default:
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func TestBitVals(t *testing.T) {

mcmp.AssertMatches(`select b'1001', 0x9, B'010011011010'`, `[[VARBINARY("\t") VARBINARY("\t") VARBINARY("\x04\xda")]]`)
mcmp.AssertMatches(`select b'1001', 0x9, B'010011011010' from t1`, `[[VARBINARY("\t") VARBINARY("\t") VARBINARY("\x04\xda")]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010'`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[UINT64(10) UINT64(11) UINT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010' from t1`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[UINT64(10) UINT64(11) UINT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010'`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[INT64(10) UINT64(11) INT64(1245)]]`)
mcmp.AssertMatchesNoCompare(`select 1 + b'1001', 2 + 0x9, 3 + B'010011011010' from t1`, `[[INT64(10) UINT64(11) INT64(1245)]]`, `[[INT64(10) UINT64(11) INT64(1245)]]`)
}

func TestHexVals(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestNormalizeAllFields(t *testing.T) {
defer conn.Close()

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
qr := utils.Exec(t, conn, selectQuery)
Expand Down
4 changes: 1 addition & 3 deletions go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,12 +1314,10 @@ func (node *Literal) Format(buf *TrackedBuffer) {
switch node.Type {
case StrVal:
sqltypes.MakeTrusted(sqltypes.VarBinary, node.Bytes()).EncodeSQL(buf)
case IntVal, FloatVal, DecimalVal, HexNum:
case IntVal, FloatVal, DecimalVal, HexNum, BitNum:
buf.astPrintf(node, "%#s", node.Val)
case HexVal:
buf.astPrintf(node, "X'%#s'", node.Val)
case BitVal:
buf.astPrintf(node, "B'%#s'", node.Val)
case DateVal:
buf.astPrintf(node, "date'%#s'", node.Val)
case TimeVal:
Expand Down
6 changes: 1 addition & 5 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ const (
FloatVal
HexNum
HexVal
BitVal
BitNum
DateVal
TimeVal
TimestampVal
Expand Down Expand Up @@ -515,9 +515,9 @@ func NewHexLiteral(in string) *Literal {
return &Literal{Type: HexVal, Val: in}
}

// NewBitLiteral builds a new BitVal containing a bit literal.
// NewBitLiteral builds a new BitNum containing a bit literal.
func NewBitLiteral(in string) *Literal {
return &Literal{Type: BitVal, Val: in}
return &Literal{Type: BitNum, Val: in}
}

// NewDateLiteral builds a new Date.
Expand Down Expand Up @@ -583,8 +583,8 @@ func (node *Literal) SQLType() sqltypes.Type {
return sqltypes.HexNum
case HexVal:
return sqltypes.HexVal
case BitVal:
return sqltypes.HexNum
case BitNum:
return sqltypes.BitNum
case DateVal:
return sqltypes.Date
case TimeVal:
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func LiteralToValue(lit *Literal) (sqltypes.Value, error) {
return parseHexLiteral(b[1:])
case HexVal:
return parseHexLiteral(lit.Bytes())
case BitVal:
case BitNum:
return parseBitLiteral(lit.Bytes())
case DateVal:
d, ok := datetime.ParseDate(lit.Val)
Expand Down
20 changes: 5 additions & 15 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ package sqlparser

import (
"bytes"
"math/big"

"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/hex"
"vitess.io/vitess/go/sqltypes"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -365,19 +363,11 @@ func SQLToBindvar(node SQLNode) *querypb.BindVariable {
buf = append(buf, bytes.ToUpper(node.Bytes())...)
buf = append(buf, '\'')
v, err = sqltypes.NewValue(sqltypes.HexVal, buf)
case BitVal:
// Convert bit value to hex number in parameterized query format
var i big.Int
_, ok := i.SetString(string(node.Bytes()), 2)
if !ok {
return nil
}

buf := i.Bytes()
out := make([]byte, 0, (len(buf)*2)+2)
out = append(out, '0', 'x')
out = append(out, hex.EncodeBytes(buf)...)
v, err = sqltypes.NewValue(sqltypes.HexNum, out)
case BitNum:
out := make([]byte, 0, len(node.Bytes())+2)
out = append(out, '0', 'b')
out = append(out, node.Bytes()[2:]...)
v, err = sqltypes.NewValue(sqltypes.BitNum, out)
case DateVal:
v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes())
case TimeVal:
Expand Down
24 changes: 12 additions & 12 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,23 @@ func TestNormalize(t *testing.T) {
}, {
// Bin values work fine
in: "select * from t where foo = b'11'",
outstmt: "select * from t where foo = :foo /* HEXNUM */",
outstmt: "select * from t where foo = :foo /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.HexNumBindVariable([]byte("0x03")),
"foo": sqltypes.BitNumBindVariable([]byte("0b11")),
},
}, {
// Large bin values work fine
in: "select * from t where foo = b'11101010100101010010101010101010101010101000100100100100100101001101010101010101000001'",
outstmt: "select * from t where foo = :foo /* HEXNUM */",
outstmt: "select * from t where foo = :foo /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.HexNumBindVariable([]byte("0x3AA54AAAAAA24925355541")),
"foo": sqltypes.BitNumBindVariable([]byte("0b11101010100101010010101010101010101010101000100100100100100101001101010101010101000001")),
},
}, {
// Bin value does not convert for DMLs
in: "update a set v1 = b'11'",
outstmt: "update a set v1 = :v1 /* HEXNUM */",
outstmt: "update a set v1 = :v1 /* BITNUM */",
outbv: map[string]*querypb.BindVariable{
"v1": sqltypes.HexNumBindVariable([]byte("0x03")),
"v1": sqltypes.BitNumBindVariable([]byte("0b11")),
},
}, {
// ORDER BY column_position
Expand Down Expand Up @@ -308,14 +308,14 @@ func TestNormalize(t *testing.T) {
"bv3": sqltypes.Int64BindVariable(3),
},
}, {
// BitVal should also be normalized
// BitNum should also be normalized
in: `select b'1', 0b01, b'1010', 0b1111111`,
outstmt: `select :bv1 /* HEXNUM */, :bv2 /* HEXNUM */, :bv3 /* HEXNUM */, :bv4 /* HEXNUM */ from dual`,
outstmt: `select :bv1 /* BITNUM */, :bv2 /* BITNUM */, :bv3 /* BITNUM */, :bv4 /* BITNUM */ from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.HexNumBindVariable([]byte("0x01")),
"bv2": sqltypes.HexNumBindVariable([]byte("0x01")),
"bv3": sqltypes.HexNumBindVariable([]byte("0x0A")),
"bv4": sqltypes.HexNumBindVariable([]byte("0x7F")),
"bv1": sqltypes.BitNumBindVariable([]byte("0b1")),
"bv2": sqltypes.BitNumBindVariable([]byte("0b01")),
"bv3": sqltypes.BitNumBindVariable([]byte("0b1010")),
"bv4": sqltypes.BitNumBindVariable([]byte("0b1111111")),
},
}, {
// DateVal should also be normalized
Expand Down
9 changes: 5 additions & 4 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ var (
output: "select `name`, numbers from (select * from users) as x(`name`, numbers)",
}, {
input: "select 0b010, 0b0111, b'0111', b'011'",
output: "select B'010', B'0111', B'0111', B'011' from dual",
output: "select 0b010, 0b0111, 0b0111, 0b011 from dual",
}, {
input: "select 0x010, 0x0111, x'0111'",
output: "select 0x010, 0x0111, X'0111' from dual",
Expand Down Expand Up @@ -1120,9 +1120,10 @@ var (
input: "select /* hex caps */ X'F0a1' from t",
}, {
input: "select /* bit literal */ b'0101' from t",
output: "select /* bit literal */ B'0101' from t",
output: "select /* bit literal */ 0b0101 from t",
}, {
input: "select /* bit literal caps */ B'010011011010' from t",
input: "select /* bit literal caps */ B'010011011010' from t",
output: "select /* bit literal caps */ 0b010011011010 from t",
}, {
input: "select /* 0x */ 0xf0 from t",
}, {
Expand Down Expand Up @@ -5004,7 +5005,7 @@ func TestCreateTable(t *testing.T) {
` + "`" + `s3` + "`" + ` varchar default null,
s4 timestamp default current_timestamp(),
s41 timestamp default now(),
s5 bit(1) default B'0'
s5 bit(1) default 0b0
)`,
}, {
// test non_reserved word in column name
Expand Down
8 changes: 4 additions & 4 deletions go/vt/sqlparser/sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions go/vt/sqlparser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -1689,27 +1689,27 @@ text_literal
}
| BITNUM
{
$$ = NewBitLiteral($1[2:])
$$ = NewBitLiteral($1)
}
| BIT_LITERAL
{
$$ = NewBitLiteral($1)
$$ = NewBitLiteral("0b" + $1)
}
| VALUE_ARG
{
$$ = parseBindVariable(yylex, $1[1:])
}
| underscore_charsets BIT_LITERAL %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2)}
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral("0b" + $2)}
}
| underscore_charsets HEXNUM %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewHexNumLiteral($2)}
}
| underscore_charsets BITNUM %prec UNARY
{
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2[2:])}
$$ = &IntroducerExpr{CharacterSet: $1, Expr: NewBitLiteral($2)}
}
| underscore_charsets HEX %prec UNARY
{
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/testdata/select_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14654,7 +14654,7 @@ INPUT
select hex(_utf8mb4 B'001111111111');
END
OUTPUT
select hex(_utf8mb4 B'001111111111') from dual
select hex(_utf8mb4 0b001111111111) from dual
END
INPUT
select NULLIF(1,NULL), NULLIF(1.0, NULL), NULLIF("test", NULL);
Expand Down Expand Up @@ -18968,7 +18968,7 @@ INPUT
select hex(_utf8 B'001111111111');
END
OUTPUT
select hex(_utf8mb3 B'001111111111') from dual
select hex(_utf8mb3 0b001111111111) from dual
END
INPUT
select right('hello', -18446744073709551615);
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ func inferColTypeFromExpr(node sqlparser.Expr, tableColumnMap map[sqlparser.Iden
fallthrough
case sqlparser.HexVal:
fallthrough
case sqlparser.BitVal:
case sqlparser.BitNum:
colTypes = append(colTypes, querypb.Type_INT32)
case sqlparser.StrVal:
colTypes = append(colTypes, querypb.Type_VARCHAR)
Expand Down
13 changes: 9 additions & 4 deletions go/vt/vtgate/evalengine/api_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"vitess.io/vitess/go/mysql/fastparse"
"vitess.io/vitess/go/mysql/hex"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// NullExpr is just what you are lead to believe
Expand Down Expand Up @@ -156,11 +158,14 @@ func parseHexNumber(val []byte) ([]byte, error) {
return parseHexLiteral(val[1:])
}

func parseBitLiteral(val []byte) ([]byte, error) {
func parseBitNum(val []byte) ([]byte, error) {
if val[0] != '0' || val[1] != 'b' {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (missing 0b prefix)", val)
}
var i big.Int
_, ok := i.SetString(string(val), 2)
_, ok := i.SetString(hack.String(val)[2:], 2)
if !ok {
panic("malformed bit literal from parser")
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "malformed Bit literal: %q (not base 2)", val)
}
return i.Bytes(), nil
}
Expand All @@ -186,7 +191,7 @@ func NewLiteralBinaryFromHexNum(val []byte) (*Literal, error) {
}

func NewLiteralBinaryFromBit(val []byte) (*Literal, error) {
raw, err := parseBitLiteral(val)
raw, err := parseBitNum(val)
if err != nil {
return nil, err
}
Expand Down
15 changes: 15 additions & 0 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ func AggregateTypes(types []sqltypes.Type) sqltypes.Type {
return typeAgg.result()
}

func (ta *typeAggregation) addEval(e eval) {
var t sqltypes.Type
var f typeFlag
switch e := e.(type) {
case nil:
t = sqltypes.Null
case *evalBytes:
t = sqltypes.Type(e.tt)
f = e.flag
default:
t = e.SQLType()
}
ta.add(t, f)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
switch tt {
case sqltypes.Float32, sqltypes.Float64:
Expand Down
Loading

0 comments on commit bc8df35

Please sign in to comment.