From 4d97507b2f113581babe14d8e6347f3414e09fc7 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 18 Dec 2024 12:36:30 +0100 Subject: [PATCH] feat: handle last_insert_id with arguments in the evalengine Signed-off-by: Andres Taylor --- .../endtoend/vtgate/queries/misc/misc_test.go | 3 +- go/vt/vtgate/engine/fake_vcursor_test.go | 3 + go/vt/vtgate/engine/primitive.go | 2 + go/vt/vtgate/evalengine/compiler_asm.go | 21 +++++ go/vt/vtgate/evalengine/compiler_test.go | 93 +++++++++++++++++++ go/vt/vtgate/evalengine/expr_env.go | 2 + go/vt/vtgate/evalengine/fn_misc.go | 40 ++++++++ .../evalengine/integration/comparison_test.go | 4 + go/vt/vtgate/evalengine/translate_builtin.go | 5 + go/vt/vtgate/executorcontext/vcursor_impl.go | 6 ++ 10 files changed, 178 insertions(+), 1 deletion(-) diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 62b859ab1d3..e8701fe88bd 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -163,6 +163,7 @@ func TestSetAndGetLastInsertID(t *testing.T) { "update t1 set id2 = last_insert_id(%d) where id1 = 2", "update t1 set id2 = 88 where id1 = last_insert_id(%d)", "delete from t1 where id1 = last_insert_id(%d)", + "select id2, last_insert_id(count(*)) from t1 where %d group by id2", } for _, workload := range []string{"olap", "oltp"} { @@ -175,7 +176,7 @@ func TestSetAndGetLastInsertID(t *testing.T) { require.NoError(t, err) } - // Insert a row for UPDATE tests + // Insert a few rows for UPDATE tests mcmp.Exec("insert into t1 (id1, id2) values (1, 10)") for _, query := range queries { diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index f27ca380876..72422193ee8 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -893,6 +893,9 @@ func (t *loggingVCursor) RecordMirrorStats(sourceExecTime, targetExecTime time.D } } +func (t *loggingVCursor) SetLastInsertID(id uint64) {} +func (t *noopVCursor) SetLastInsertID(id uint64) {} + func (t *noopVCursor) VExplainLogging() {} func (t *noopVCursor) DisableLogging() {} func (t *noopVCursor) GetVExplainLogs() []ExecuteEntry { diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index e6fa102581e..7734dd81a6b 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -147,6 +147,8 @@ type ( // RecordMirrorStats is used to record stats about a mirror query. RecordMirrorStats(time.Duration, time.Duration, error) + + SetLastInsertID(uint64) } // SessionActions gives primitives ability to interact with the session state diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index dfb1a30bffc..9099f8711c0 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -5138,3 +5138,24 @@ func (asm *assembler) Introduce(offset int, t sqltypes.Type, col collations.Type return 1 }, "INTRODUCE (SP-1)") } + +func (asm *assembler) Fn_LAST_INSERT_ID() { + asm.emit(func(env *ExpressionEnv) int { + arg := env.vm.stack[env.vm.sp-1].(*evalUint64) + env.VCursor().SetLastInsertID(arg.u) + return 1 + }, "FN LAST_INSERT_ID UINT64(SP-1)") +} + +func (asm *assembler) Fn_LAST_INSERT_ID_NULL() { + asm.emit(func(env *ExpressionEnv) int { + env.VCursor().SetLastInsertID(0) + return 1 + }, "FN LAST_INSERT_ID NULL") +} + +func (asm *assembler) addJump(end *jump) { + asm.emit(func(env *ExpressionEnv) int { + return end.offset() + }, "JUMP") +} diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index cb9b99e7776..ea3a7a603ce 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -883,6 +883,99 @@ func TestBindVarLiteral(t *testing.T) { } } +type testVcursor struct { + lastInsertID *uint64 + env *vtenv.Environment +} + +func (t *testVcursor) TimeZone() *time.Location { + return time.UTC +} + +func (t *testVcursor) GetKeyspace() string { + return "apa" +} + +func (t *testVcursor) SQLMode() string { + return "oltp" +} + +func (t *testVcursor) Environment() *vtenv.Environment { + return t.env +} + +func (t *testVcursor) SetLastInsertID(id uint64) { + t.lastInsertID = &id +} + +var _ evalengine.VCursor = (*testVcursor)(nil) + +func TestLastInsertID(t *testing.T) { + var testCases = []struct { + expression string + result uint64 + missing bool + }{ + { + expression: `last_insert_id(1)`, + result: 1, + }, { + expression: `12`, + missing: true, + }, { + expression: `last_insert_id(666)`, + result: 666, + }, { + expression: `last_insert_id(null)`, + result: 0, + }, + } + + venv := vtenv.NewTestEnv() + for _, tc := range testCases { + t.Run(tc.expression, func(t *testing.T) { + expr, err := venv.Parser().ParseExpr(tc.expression) + require.NoError(t, err) + + cfg := &evalengine.Config{ + Collation: collations.CollationUtf8mb4ID, + NoConstantFolding: true, + NoCompilation: false, + Environment: venv, + } + t.Run("eval", func(t *testing.T) { + cfg.NoCompilation = true + runTest(t, expr, cfg, tc) + }) + t.Run("compiled", func(t *testing.T) { + cfg.NoCompilation = false + runTest(t, expr, cfg, tc) + }) + }) + } +} + +func runTest(t *testing.T, expr sqlparser.Expr, cfg *evalengine.Config, tc struct { + expression string + result uint64 + missing bool +}) { + converted, err := evalengine.Translate(expr, cfg) + require.NoError(t, err) + + vc := &testVcursor{env: vtenv.NewTestEnv()} + env := evalengine.NewExpressionEnv(context.Background(), nil, vc) + + _, err = env.Evaluate(converted) + require.NoError(t, err) + if tc.missing { + require.Nil(t, vc.lastInsertID) + } else { + require.NotNil(t, vc.lastInsertID) + require.Equal(t, tc.result, *vc.lastInsertID) + } +} + func TestCompilerNonConstant(t *testing.T) { var testCases = []struct { expression string diff --git a/go/vt/vtgate/evalengine/expr_env.go b/go/vt/vtgate/evalengine/expr_env.go index 38a65f9b4e0..4a7f9849ab0 100644 --- a/go/vt/vtgate/evalengine/expr_env.go +++ b/go/vt/vtgate/evalengine/expr_env.go @@ -35,6 +35,7 @@ type VCursor interface { GetKeyspace() string SQLMode() string Environment() *vtenv.Environment + SetLastInsertID(id uint64) } type ( @@ -140,6 +141,7 @@ func (e *emptyVCursor) GetKeyspace() string { func (e *emptyVCursor) SQLMode() string { return config.DefaultSQLMode } +func (e *emptyVCursor) SetLastInsertID(_ uint64) {} func NewEmptyVCursor(env *vtenv.Environment, tz *time.Location) VCursor { return &emptyVCursor{env: env, tz: tz} diff --git a/go/vt/vtgate/evalengine/fn_misc.go b/go/vt/vtgate/evalengine/fn_misc.go index 8813b62f823..cb17f8d6560 100644 --- a/go/vt/vtgate/evalengine/fn_misc.go +++ b/go/vt/vtgate/evalengine/fn_misc.go @@ -81,6 +81,10 @@ type ( builtinUUIDToBin struct { CallExpr } + + builtinLastInsertID struct { + CallExpr + } ) var _ IR = (*builtinInetAton)(nil) @@ -95,6 +99,7 @@ var _ IR = (*builtinBinToUUID)(nil) var _ IR = (*builtinIsUUID)(nil) var _ IR = (*builtinUUID)(nil) var _ IR = (*builtinUUIDToBin)(nil) +var _ IR = (*builtinLastInsertID)(nil) func (call *builtinInetAton) eval(env *ExpressionEnv) (eval, error) { arg, err := call.arg1(env) @@ -155,6 +160,7 @@ func (call *builtinInetNtoa) compile(c *compiler) (ctype, error) { c.compileToUint64(arg, 1) col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_INET_NTOA(col) + c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil @@ -194,6 +200,40 @@ func (call *builtinInet6Aton) compile(c *compiler) (ctype, error) { return ctype{Type: sqltypes.VarBinary, Flag: flagNullable, Col: collationBinary}, nil } +func (call *builtinLastInsertID) eval(env *ExpressionEnv) (eval, error) { + arg, err := call.arg1(env) + if err != nil { + return nil, err + } + if arg == nil { + env.VCursor().SetLastInsertID(0) + return nil, err + } + insertID := uint64(evalToInt64(arg).i) + env.VCursor().SetLastInsertID(insertID) + return newEvalUint64(insertID), nil +} + +func (call *builtinLastInsertID) compile(c *compiler) (ctype, error) { + arg, err := call.Arguments[0].compile(c) + if err != nil { + return ctype{}, err + } + + setZero := c.compileNullCheck1(arg) + c.compileToUint64(arg, 1) + c.asm.Fn_LAST_INSERT_ID() + end := c.asm.jumpFrom() + c.asm.addJump(end) + + c.asm.jumpDestination(setZero) + c.asm.Fn_LAST_INSERT_ID_NULL() + + c.asm.jumpDestination(end) + + return ctype{Type: sqltypes.Uint64, Flag: flagNullable, Col: collationNumeric}, nil +} + func printIPv6AsIPv4(addr netip.Addr) (netip.Addr, bool) { b := addr.AsSlice() if len(b) != 16 { diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index ea327601975..d559cb8ab1d 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -209,6 +209,10 @@ type vcursor struct { env *vtenv.Environment } +func (vc *vcursor) SetLastInsertID(id uint64) {} + +var _ evalengine.VCursor = (*vcursor)(nil) + func (vc *vcursor) GetKeyspace() string { return "vttest" } diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 476ee32483b..1f8bd7798aa 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -662,6 +662,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinReplace{CallExpr: call, collate: ast.cfg.Collation}, nil + case "last_insert_id": + if len(args) != 1 { + return nil, argError(method) + } + return &builtinLastInsertID{CallExpr: call}, nil default: return nil, translateExprNotSupported(fn) } diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index 40317f5103a..19480f06648 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -1583,3 +1583,9 @@ func (vc *VCursorImpl) GetContextWithTimeOut(ctx context.Context) (context.Conte func (vc *VCursorImpl) IgnoreMaxMemoryRows() bool { return vc.ignoreMaxMemoryRows } + +func (vc *VCursorImpl) SetLastInsertID(id uint64) { + vc.SafeSession.mu.Lock() + defer vc.SafeSession.mu.Unlock() + vc.SafeSession.LastInsertId = id +}