From 91f4c5b71feec9fae8af1d883f62438ebfa2565d Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Mon, 5 Feb 2024 01:56:59 +0530 Subject: [PATCH] evalEngine: Implement INSTR Signed-off-by: Noble Mittal --- go/vt/vtgate/evalengine/cached_size.go | 12 +++ go/vt/vtgate/evalengine/compiler_asm.go | 12 +++ go/vt/vtgate/evalengine/fn_string.go | 77 ++++++++++++++++++++ go/vt/vtgate/evalengine/testcases/cases.go | 36 +++++++++ go/vt/vtgate/evalengine/translate_builtin.go | 5 ++ 5 files changed, 142 insertions(+) diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 8c22ff4ecd9..a79957282d3 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -931,6 +931,18 @@ func (cached *builtinInetNtoa) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinInstr) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinIsIPV4) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index de856bb4333..0eed6bdf1a6 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -1418,6 +1418,18 @@ func (asm *assembler) Mod_dd() { }, "MOD DECIMAL(SP-2), DECIMAL(SP-1)") } +func (asm *assembler) Fn_INSTR() { + asm.adjustStack(-1) + + asm.emit(func(env *ExpressionEnv) int { + str := env.vm.stack[env.vm.sp-2].(*evalBytes) + substr := env.vm.stack[env.vm.sp-1].(*evalBytes) + + env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalInt64(instrIndex(str, substr)) + return 1 + }, "FN INSTR VARCHAR(SP-2) VARCHAR(SP-1)") +} + func (asm *assembler) Fn_ASCII() { asm.emit(func(env *ExpressionEnv) int { arg := env.vm.stack[env.vm.sp-1].(*evalBytes) diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index f69b8db1e72..2620c878da9 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -43,6 +43,11 @@ type ( CallExpr } + builtinInstr struct { + CallExpr + collate collations.ID + } + builtinASCII struct { CallExpr } @@ -99,6 +104,7 @@ type ( var _ IR = (*builtinChangeCase)(nil) var _ IR = (*builtinCharLength)(nil) var _ IR = (*builtinLength)(nil) +var _ IR = (*builtinInstr)(nil) var _ IR = (*builtinASCII)(nil) var _ IR = (*builtinOrd)(nil) var _ IR = (*builtinBitLength)(nil) @@ -199,6 +205,77 @@ func (call *builtinLength) compile(c *compiler) (ctype, error) { return c.compileFn_length(call.Arguments[0], c.asm.Fn_LENGTH) } +func instrIndex(str *evalBytes, substr *evalBytes) int64 { + // Case sensitive if one of the strings is binary string + if !str.isBinary() && !substr.isBinary() { + str.bytes = bytes.ToLower(str.bytes) + substr.bytes = bytes.ToLower(substr.bytes) + } + + pos := bytes.Index(str.bytes, substr.bytes) + 1 + return int64(pos) +} + +func (call *builtinInstr) eval(env *ExpressionEnv) (eval, error) { + arg1, arg2, err := call.arg2(env) + if err != nil { + return nil, err + } + if arg1 == nil || arg2 == nil { + return nil, nil + } + + str, ok := arg1.(*evalBytes) + if !ok { + str, err = evalToVarchar(arg1, call.collate, true) + if err != nil { + return nil, err + } + } + + substr, ok := arg2.(*evalBytes) + if !ok { + substr, err = evalToVarchar(arg2, call.collate, true) + if err != nil { + return nil, err + } + } + + return newEvalInt64(instrIndex(str, substr)), nil +} + +func (call *builtinInstr) compile(c *compiler) (ctype, error) { + arg1, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + skip1 := c.compileNullCheck1(arg1) + + switch { + case arg1.isTextual(): + default: + c.asm.Convert_xce(1, sqltypes.VarChar, call.collate) + } + + arg2, err := call.Arguments[1].compile(c) + if err != nil { + return ctype{}, err + } + + skip2 := c.compileNullCheck1(arg2) + + switch { + case arg2.isTextual(): + default: + c.asm.Convert_xce(1, sqltypes.VarChar, call.collate) + } + + c.asm.Fn_INSTR() + c.asm.jumpDestination(skip1, skip2) + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + func (call *builtinBitLength) eval(env *ExpressionEnv) (eval, error) { arg, err := call.arg1(env) if err != nil { diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index f9036c1afca..6b72052483c 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -67,6 +67,7 @@ var Cases = []TestCase{ {Run: FnUpper}, {Run: FnCharLength}, {Run: FnLength}, + {Run: FnInstr}, {Run: FnBitLength}, {Run: FnAscii}, {Run: FnOrd}, @@ -1339,6 +1340,41 @@ func FnLength(yield Query) { } } +func FnInstr(yield Query) { + for _, str := range inputStrings { + for _, substr := range inputStrings { + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) + } + } + + cases := []struct { + str string + substr string + }{ + {"'ACABAB'", "'AB'"}, + {"'ABABAB'", "'AB'"}, + {"'ABABAB'", "'ab'"}, + {"'ABABAB'", "'ba'"}, + {"'CBDASD'", "'ab'"}, + {"'ABABAB'", "''"}, + {"'ABABAB'", ""}, + {"1233", "23"}, + {"0x616162", "0x6162"}, + {"0x616162", "0x4141"}, + {"'AAB'", "123"}, + {"123", "'ABC'"}, + {"_binary'FOOBAR'", "'AR'"}, + {"_binary'FOOBAR'", "BINARY 'AR'"}, + {"BINARY 'FOOBAR'", "'ar'"}, + {"'foobarbar'", "'bar'"}, + {"'xbar'", "'foobar'"}, + } + + for _, tc := range cases { + yield(fmt.Sprintf("INSTR(%s, %s)", tc.str, tc.substr), nil) + } +} + func FnBitLength(yield Query) { for _, str := range inputStrings { yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil) diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index b44c17f6f9b..2ed800c4a0a 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -285,6 +285,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinLength{CallExpr: call}, nil + case "instr": + if len(args) != 2 { + return nil, argError(method) + } + return &builtinInstr{CallExpr: call, collate: ast.cfg.Collation}, nil case "bit_length": if len(args) != 1 { return nil, argError(method)