Skip to content

Commit

Permalink
Enum value parsing: do not parse by whitespace (#15493)
Browse files Browse the repository at this point in the history
Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
shlomi-noach and dbussink authored Mar 17, 2024
1 parent 2ee5946 commit 51debbd
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 21 deletions.
52 changes: 52 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ var (

// ErrIncompatibleTypeCast indicates a casting problem
ErrIncompatibleTypeCast = errors.New("Cannot convert value to desired type")

ErrInvalidEncodedString = errors.New("invalid SQL encoded string")
)

const (
Expand Down Expand Up @@ -861,6 +863,56 @@ var encodeRef = map[byte]byte{
'\\': '\\',
}

// BufDecodeStringSQL decodes the string into a strings.Builder
func BufDecodeStringSQL(buf *strings.Builder, val string) error {
if len(val) < 2 || val[0] != '\'' || val[len(val)-1] != '\'' {
return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
}
in := hack.StringBytes(val[1 : len(val)-1])
idx := 0
for {
if idx >= len(in) {
return nil
}
ch := in[idx]
if ch == '\'' {
idx++
if idx >= len(in) {
return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
}
if in[idx] != '\'' {
return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
}
buf.WriteByte(ch)
idx++
continue
}
if ch == '\\' {
idx++
if idx >= len(in) {
return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
}
decoded := SQLDecodeMap[in[idx]]
if decoded == DontEscape {
return fmt.Errorf("%s: %w", val, ErrInvalidEncodedString)
}
buf.WriteByte(decoded)
idx++
continue
}

buf.WriteByte(ch)
idx++
}
}

// DecodeStringSQL encodes the string as a SQL string.
func DecodeStringSQL(val string) (string, error) {
var buf strings.Builder
err := BufDecodeStringSQL(&buf, val)
return buf.String(), err
}

func init() {
for i := range SQLEncodeMap {
SQLEncodeMap[i] = DontEscape
Expand Down
57 changes: 57 additions & 0 deletions go/sqltypes/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,60 @@ func TestHexAndBitToBytes(t *testing.T) {
})
}
}

func TestEncodeStringSQL(t *testing.T) {
testcases := []struct {
in string
out string
}{
{
in: "",
out: "''",
},
{
in: "\x00'\"\b\n\r\t\x1A\\",
out: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'",
},
}
for _, tcase := range testcases {
out := EncodeStringSQL(tcase.in)
assert.Equal(t, tcase.out, out)
}
}

func TestDecodeStringSQL(t *testing.T) {
testcases := []struct {
in string
out string
err string
}{
{
in: "",
err: ": invalid SQL encoded string",
}, {
in: "''",
err: "",
},
{
in: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'",
out: "\x00'\"\b\n\r\t\x1A\\",
},
{
in: "'light ''green\\r\\n, \\nfoo'",
out: "light 'green\r\n, \nfoo",
},
{
in: "'foo \\\\ % _bar'",
out: "foo \\ % _bar",
},
}
for _, tcase := range testcases {
out, err := DecodeStringSQL(tcase.in)
if tcase.err != "" {
assert.EqualError(t, err, tcase.err)
} else {
require.NoError(t, err)
assert.Equal(t, tcase.out, out)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
change e e enum('red', 'light green', 'blue', 'orange', 'yellow') collate 'utf8_bin' null default null
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
drop table if exists onlineddl_test;
create table onlineddl_test (
id int auto_increment,
i int not null,
e enum('red', 'light green', 'blue', 'orange') null default null collate 'utf8_bin',
primary key(id)
) auto_increment=1;

drop event if exists onlineddl_test;
delimiter ;;
create event onlineddl_test
on schedule every 1 second
starts current_timestamp
ends current_timestamp + interval 60 second
on completion not preserve
enable
do
begin
insert into onlineddl_test values (null, 11, 'red');
insert into onlineddl_test values (null, 13, 'light green');
insert into onlineddl_test values (null, 17, 'blue');
set @last_insert_id := last_insert_id();
update onlineddl_test set e='orange' where id = @last_insert_id;
insert into onlineddl_test values (null, 23, null);
set @last_insert_id := last_insert_id();
update onlineddl_test set i=i+1, e=null where id = @last_insert_id;
end ;;
69 changes: 54 additions & 15 deletions go/vt/schema/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"strconv"
"strings"

"vitess.io/vitess/go/textutil"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
)

Expand Down Expand Up @@ -113,22 +113,61 @@ func ParseSetValues(setColumnType string) string {
// returns the (unquoted) text values
// Expected input: `'x-small','small','medium','large','x-large'`
// Unexpected input: `enum('x-small','small','medium','large','x-large')`
func parseEnumOrSetTokens(enumOrSetValues string) (tokens []string) {
if submatch := enumValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 {
// input should not contain `enum(...)` column definition, just the comma delimited list
return tokens
}
if submatch := setValuesRegexp.FindStringSubmatch(enumOrSetValues); len(submatch) > 0 {
// input should not contain `enum(...)` column definition, just the comma delimited list
return tokens
}
tokens = textutil.SplitDelimitedList(enumOrSetValues)
for i := range tokens {
if strings.HasPrefix(tokens[i], `'`) && strings.HasSuffix(tokens[i], `'`) {
tokens[i] = strings.Trim(tokens[i], `'`)
func parseEnumOrSetTokens(enumOrSetValues string) []string {
// We need to track both the start of the current value and current
// position, since there might be quoted quotes inside the value
// which we need to handle.
start := 0
pos := 1
var tokens []string
for {
// If the input does not start with a quote, it's not a valid enum/set definition
if enumOrSetValues[start] != '\'' {
return nil
}
i := strings.IndexByte(enumOrSetValues[pos:], '\'')
// If there's no closing quote, we have invalid input
if i < 0 {
return nil
}
// We're at the end here of the last quoted value,
// so we add the last token and return them.
if i == len(enumOrSetValues[pos:])-1 {
tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start:])
if err != nil {
return nil
}
tokens = append(tokens, tok)
return tokens
}
// MySQL double quotes things as escape value, so if we see another
// single quote, we skip the character and remove it from the input.
if enumOrSetValues[pos+i+1] == '\'' {
pos = pos + i + 2
continue
}
// Next value needs to be a comma as a separator, otherwise
// the data is invalid so we return nil.
if enumOrSetValues[pos+i+1] != ',' {
return nil
}
// If we're at the end of the input here, it's invalid
// since we have a trailing comma which is not what MySQL
// returns.
if pos+i+1 == len(enumOrSetValues) {
return nil
}

tok, err := sqltypes.DecodeStringSQL(enumOrSetValues[start : pos+i+1])
if err != nil {
return nil
}

tokens = append(tokens, tok)
// We add 2 to the position to skip the closing quote & comma
start = pos + i + 2
pos = start + 1
}
return tokens
}

// ParseEnumOrSetTokensMap parses the comma delimited part of an enum column definition
Expand Down
25 changes: 25 additions & 0 deletions go/vt/schema/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ func TestParseEnumValues(t *testing.T) {
assert.Equal(t, input, enumValues)
}
}

{
inputs := []string{
``,
`abc`,
`func('x small','small','medium','large','x large')`,
`set('x small','small','medium','large','x large')`,
}
for _, input := range inputs {
enumValues := ParseEnumValues(input)
assert.Equal(t, input, enumValues)
}
}
}

func TestParseSetValues(t *testing.T) {
Expand Down Expand Up @@ -125,6 +138,18 @@ func TestParseEnumTokens(t *testing.T) {
expect := []string{"x-small", "small", "medium", "large", "x-large"}
assert.Equal(t, expect, enumTokens)
}
{
input := `'x small','small','medium','large','x large'`
enumTokens := parseEnumOrSetTokens(input)
expect := []string{"x small", "small", "medium", "large", "x large"}
assert.Equal(t, expect, enumTokens)
}
{
input := `'with '' quote','and \n newline'`
enumTokens := parseEnumOrSetTokens(input)
expect := []string{"with ' quote", "and \n newline"}
assert.Equal(t, expect, enumTokens)
}
{
input := `enum('x-small','small','medium','large','x-large')`
enumTokens := parseEnumOrSetTokens(input)
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vttablet/onlineddl/vrepl/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,35 +333,35 @@ func TestGetExpandedColumnNames(t *testing.T) {
"expand enum",
Column{
Type: EnumColumnType,
EnumValues: "'a', 'b'",
EnumValues: "'a','b'",
},
Column{
Type: EnumColumnType,
EnumValues: "'a', 'x'",
EnumValues: "'a','x'",
},
true,
},
{
"expand enum",
Column{
Type: EnumColumnType,
EnumValues: "'a', 'b'",
EnumValues: "'a','b'",
},
Column{
Type: EnumColumnType,
EnumValues: "'a', 'b', 'c'",
EnumValues: "'a','b','c'",
},
true,
},
{
"reduce enum",
Column{
Type: EnumColumnType,
EnumValues: "'a', 'b', 'c'",
EnumValues: "'a','b','c'",
},
Column{
Type: EnumColumnType,
EnumValues: "'a', 'b'",
EnumValues: "'a','b'",
},
false,
},
Expand Down

0 comments on commit 51debbd

Please sign in to comment.