diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 678f4499f45..f01497d7bdf 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -272,8 +272,8 @@ func (mcmp *MySQLCompare) ExecAndIgnore(query string) (*sqltypes.Result, error) return mcmp.VtConn.ExecuteFetch(query, 1000, true) } -func (mcmp *MySQLCompare) Run(query string, f func(mcmp *MySQLCompare)) { - mcmp.AsT().Run(query, func(t *testing.T) { +func (mcmp *MySQLCompare) Run(name string, f func(mcmp *MySQLCompare)) { + mcmp.AsT().Run(name, func(t *testing.T) { inner := &MySQLCompare{ t: t, MySQLConn: mcmp.MySQLConn, diff --git a/go/test/endtoend/vtgate/queries/tpch/main_test.go b/go/test/endtoend/vtgate/queries/tpch/main_test.go new file mode 100644 index 00000000000..103adb336ab --- /dev/null +++ b/go/test/endtoend/vtgate/queries/tpch/main_test.go @@ -0,0 +1,89 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package union + +import ( + _ "embed" + "flag" + "fmt" + "os" + "testing" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/test/endtoend/utils" +) + +var ( + clusterInstance *cluster.LocalProcessCluster + vtParams mysql.ConnParams + mysqlParams mysql.ConnParams + keyspaceName = "ks" + cell = "zone-1" + + //go:embed schema.sql + schemaSQL string + + //go:embed vschema.json + vschema string +) + +func TestMain(m *testing.M) { + defer cluster.PanicHandler(nil) + flag.Parse() + + exitCode := func() int { + clusterInstance = cluster.NewCluster(cell, "localhost") + defer clusterInstance.Teardown() + + // Start topo server + err := clusterInstance.StartTopo() + if err != nil { + return 1 + } + + // Start keyspace + keyspace := &cluster.Keyspace{ + Name: keyspaceName, + SchemaSQL: schemaSQL, + VSchema: vschema, + } + err = clusterInstance.StartKeyspace(*keyspace, []string{"-80", "80-"}, 0, false) + if err != nil { + return 1 + } + + // Start vtgate + err = clusterInstance.StartVtgate() + if err != nil { + return 1 + } + + vtParams = clusterInstance.GetVTParams(keyspaceName) + + // create mysql instance and connection parameters + conn, closer, err := utils.NewMySQL(clusterInstance, keyspaceName, schemaSQL) + if err != nil { + fmt.Println(err) + return 1 + } + defer closer() + mysqlParams = conn + return m.Run() + }() + os.Exit(exitCode) +} diff --git a/go/test/endtoend/vtgate/queries/tpch/schema.sql b/go/test/endtoend/vtgate/queries/tpch/schema.sql new file mode 100644 index 00000000000..44af337938f --- /dev/null +++ b/go/test/endtoend/vtgate/queries/tpch/schema.sql @@ -0,0 +1,291 @@ +CREATE TABLE IF NOT EXISTS nation +( + N_NATIONKEY + INTEGER + NOT + NULL, + N_NAME + CHAR +( + 25 +) NOT NULL, + N_REGIONKEY INTEGER NOT NULL, + N_COMMENT VARCHAR +( + 152 +), + PRIMARY KEY +( + N_NATIONKEY +)); + +CREATE TABLE IF NOT EXISTS region +( + R_REGIONKEY + INTEGER + NOT + NULL, + R_NAME + CHAR +( + 25 +) NOT NULL, + R_COMMENT VARCHAR +( + 152 +), + PRIMARY KEY +( + R_REGIONKEY +)); + +CREATE TABLE IF NOT EXISTS part +( + P_PARTKEY + INTEGER + NOT + NULL, + P_NAME + VARCHAR +( + 55 +) NOT NULL, + P_MFGR CHAR +( + 25 +) NOT NULL, + P_BRAND CHAR +( + 10 +) NOT NULL, + P_TYPE VARCHAR +( + 25 +) NOT NULL, + P_SIZE INTEGER NOT NULL, + P_CONTAINER CHAR +( + 10 +) NOT NULL, + P_RETAILPRICE DECIMAL +( + 15, + 2 +) NOT NULL, + P_COMMENT VARCHAR +( + 23 +) NOT NULL, + PRIMARY KEY +( + P_PARTKEY +)); + +CREATE TABLE IF NOT EXISTS supplier +( + S_SUPPKEY + INTEGER + NOT + NULL, + S_NAME + CHAR +( + 25 +) NOT NULL, + S_ADDRESS VARCHAR +( + 40 +) NOT NULL, + S_NATIONKEY INTEGER NOT NULL, + S_PHONE CHAR +( + 15 +) NOT NULL, + S_ACCTBAL DECIMAL +( + 15, + 2 +) NOT NULL, + S_COMMENT VARCHAR +( + 101 +) NOT NULL, + PRIMARY KEY +( + S_SUPPKEY +)); + +CREATE TABLE IF NOT EXISTS partsupp +( + PS_PARTKEY + INTEGER + NOT + NULL, + PS_SUPPKEY + INTEGER + NOT + NULL, + PS_AVAILQTY + INTEGER + NOT + NULL, + PS_SUPPLYCOST + DECIMAL +( + 15, + 2 +) NOT NULL, + PS_COMMENT VARCHAR +( + 199 +) NOT NULL, + PRIMARY KEY +( + PS_PARTKEY, + PS_SUPPKEY +)); + +CREATE TABLE IF NOT EXISTS customer +( + C_CUSTKEY + INTEGER + NOT + NULL, + C_NAME + VARCHAR +( + 25 +) NOT NULL, + C_ADDRESS VARCHAR +( + 40 +) NOT NULL, + C_NATIONKEY INTEGER NOT NULL, + C_PHONE CHAR +( + 15 +) NOT NULL, + C_ACCTBAL DECIMAL +( + 15, + 2 +) NOT NULL, + C_MKTSEGMENT CHAR +( + 10 +) NOT NULL, + C_COMMENT VARCHAR +( + 117 +) NOT NULL, + PRIMARY KEY +( + C_CUSTKEY +)); + +CREATE TABLE IF NOT EXISTS orders +( + O_ORDERKEY + INTEGER + NOT + NULL, + O_CUSTKEY + INTEGER + NOT + NULL, + O_ORDERSTATUS + CHAR +( + 1 +) NOT NULL, + O_TOTALPRICE DECIMAL +( + 15, + 2 +) NOT NULL, + O_ORDERDATE DATE NOT NULL, + O_ORDERPRIORITY CHAR +( + 15 +) NOT NULL, + O_CLERK CHAR +( + 15 +) NOT NULL, + O_SHIPPRIORITY INTEGER NOT NULL, + O_COMMENT VARCHAR +( + 79 +) NOT NULL, + PRIMARY KEY +( + O_ORDERKEY +)); + +CREATE TABLE IF NOT EXISTS lineitem +( + L_ORDERKEY + INTEGER + NOT + NULL, + L_PARTKEY + INTEGER + NOT + NULL, + L_SUPPKEY + INTEGER + NOT + NULL, + L_LINENUMBER + INTEGER + NOT + NULL, + L_QUANTITY + DECIMAL +( + 15, + 2 +) NOT NULL, + L_EXTENDEDPRICE DECIMAL +( + 15, + 2 +) NOT NULL, + L_DISCOUNT DECIMAL +( + 15, + 2 +) NOT NULL, + L_TAX DECIMAL +( + 15, + 2 +) NOT NULL, + L_RETURNFLAG CHAR +( + 1 +) NOT NULL, + L_LINESTATUS CHAR +( + 1 +) NOT NULL, + L_SHIPDATE DATE NOT NULL, + L_COMMITDATE DATE NOT NULL, + L_RECEIPTDATE DATE NOT NULL, + L_SHIPINSTRUCT CHAR +( + 25 +) NOT NULL, + L_SHIPMODE CHAR +( + 10 +) NOT NULL, + L_COMMENT VARCHAR +( + 44 +) NOT NULL, + PRIMARY KEY +( + L_ORDERKEY, + L_LINENUMBER +)); diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go new file mode 100644 index 00000000000..b1dd4ef1e98 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -0,0 +1,141 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package union + +import ( + "testing" + + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/test/endtoend/utils" + + "github.com/stretchr/testify/require" +) + +func start(t *testing.T) (utils.MySQLCompare, func()) { + mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams) + require.NoError(t, err) + + deleteAll := func() { + _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") + + tables := []string{"nation", "region", "part", "supplier", "partsupp", "customer", "orders", "lineitem"} + for _, table := range tables { + _, _ = mcmp.ExecAndIgnore("delete from " + table) + } + } + + deleteAll() + + return mcmp, func() { + deleteAll() + mcmp.Close() + cluster.PanicHandler(t) + } +} + +func TestTPCHQueries(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + err := utils.WaitForColumn(t, clusterInstance.VtgateProcess, keyspaceName, "region", `R_COMMENT`) + require.NoError(t, err) + + insertQueries := []string{ + `INSERT INTO region (R_REGIONKEY, R_NAME, R_COMMENT) VALUES + (1, 'ASIA', 'Eastern Asia'), + (2, 'MIDDLE EAST', 'Rich cultural heritage');`, + `INSERT INTO nation (N_NATIONKEY, N_NAME, N_REGIONKEY, N_COMMENT) VALUES + (1, 'China', 1, 'Large population'), + (2, 'India', 1, 'Large variety of cultures'), + (3, 'Nation A', 2, 'Historic sites'), + (4, 'Nation B', 2, 'Beautiful landscapes');`, + `INSERT INTO supplier (S_SUPPKEY, S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT) VALUES + (1, 'Supplier A', '123 Square', 1, '86-123-4567', 5000.00, 'High quality steel'), + (2, 'Supplier B', '456 Ganges St', 2, '91-789-4561', 5500.00, 'Efficient production'), + (3, 'Supplier 1', 'Supplier Address 1', 3, '91-789-4562', 3000.00, 'Supplier Comment 1'), + (4, 'Supplier 2', 'Supplier Address 2', 2, '91-789-4563', 4000.00, 'Supplier Comment 2');`, + `INSERT INTO part (P_PARTKEY, P_NAME, P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT) VALUES + (100, 'Part 100', 'MFGR A', 'Brand X', 'BOLT STEEL', 30, 'SM BOX', 45.00, 'High strength'), + (101, 'Part 101', 'MFGR B', 'Brand Y', 'NUT STEEL', 30, 'LG BOX', 30.00, 'Rust resistant');`, + `INSERT INTO partsupp (PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT) VALUES + (100, 1, 500, 10.00, 'Deliveries on time'), + (101, 2, 300, 9.00, 'Back orders possible'), + (100, 2, 600, 8.50, 'Bulk discounts available');`, + `INSERT INTO customer (C_CUSTKEY, C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT) VALUES + (1, 'Customer A', '1234 Drive Lane', 1, '123-456-7890', 1000.00, 'AUTOMOBILE', 'Frequent orders'), + (2, 'Customer B', '5678 Park Ave', 2, '234-567-8901', 2000.00, 'AUTOMOBILE', 'Large orders'), + (3, 'Customer 1', 'Address 1', 1, 'Phone 1', 1000.00, 'Segment 1', 'Comment 1'), + (4, 'Customer 2', 'Address 2', 2, 'Phone 2', 2000.00, 'Segment 2', 'Comment 2');`, + `INSERT INTO orders (O_ORDERKEY, O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK, O_SHIPPRIORITY, O_COMMENT) VALUES + (100, 1, 'O', 15000.00, '1995-03-10', '1-URGENT', 'Clerk#0001', 1, 'N/A'), + (101, 2, 'O', 25000.00, '1995-03-05', '2-HIGH', 'Clerk#0002', 2, 'N/A'), + (1, 3, 'O', 10000.00, '1994-01-10', 'Priority 1', 'Clerk 1', 1, 'Order Comment 1'), + (2, 4, 'O', 20000.00, '1994-06-15', 'Priority 2', 'Clerk 2', 1, 'Order Comment 2');`, + `INSERT INTO lineitem (L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT) VALUES + (100, 200, 300, 1, 10, 5000.00, 0.05, 0.10, 'N', 'O', '1995-03-15', '1995-03-14', '1995-03-16', 'DELIVER IN PERSON', 'TRUCK', 'Urgent delivery'), + (100, 201, 301, 2, 20, 10000.00, 0.10, 0.10, 'R', 'F', '1995-03-17', '1995-03-15', '1995-03-18', 'NONE', 'MAIL', 'Handle with care'), + (101, 202, 302, 1, 30, 15000.00, 0.00, 0.10, 'A', 'F', '1995-03-20', '1995-03-18', '1995-03-21', 'TAKE BACK RETURN', 'SHIP', 'Standard delivery'), + (101, 203, 303, 2, 40, 10000.00, 0.20, 0.10, 'N', 'O', '1995-03-22', '1995-03-20', '1995-03-23', 'DELIVER IN PERSON', 'RAIL', 'Expedite'), + (1, 101, 1, 1, 5, 5000.00, 0.1, 0.05, 'N', 'O', '1994-01-12', '1994-01-11', '1994-01-13', 'Deliver in person','TRUCK', 'Lineitem Comment 1'), + (2, 102, 2, 1, 3, 15000.00, 0.2, 0.05, 'R', 'F', '1994-06-17', '1994-06-15', '1994-06-18', 'Leave at front door','AIR', 'Lineitem Comment 2'), + (11, 100, 2, 1, 30, 10000.00, 0.05, 0.07, 'A', 'F', '1998-07-21', '1998-07-22', '1998-07-23', 'DELIVER IN PERSON', 'TRUCK', 'N/A'), + (12, 101, 3, 1, 50, 15000.00, 0.10, 0.08, 'N', 'O', '1998-08-10', '1998-08-11', '1998-08-12', 'NONE', 'AIR', 'N/A'), + (13, 102, 4, 1, 70, 21000.00, 0.02, 0.04, 'R', 'F', '1998-06-30', '1998-07-01', '1998-07-02', 'TAKE BACK RETURN', 'MAIL', 'N/A'), + (14, 103, 5, 1, 90, 30000.00, 0.15, 0.10, 'A', 'O', '1998-05-15', '1998-05-16', '1998-05-17', 'DELIVER IN PERSON', 'RAIL', 'N/A'), + (15, 104, 2, 1, 45, 45000.00, 0.20, 0.15, 'N', 'F', '1998-07-15', '1998-07-16', '1998-07-17', 'NONE', 'SHIP', 'N/A');`, + } + + for _, query := range insertQueries { + mcmp.Exec(query) + } + + testcases := []struct { + name string + query string + }{ + { + name: "Q1", + query: `select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + l_shipdate <= date_sub('1998-12-01', interval 108 day) +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus;`, + }, + } + + for _, testcase := range testcases { + mcmp.Run(testcase.name, func(mcmp *utils.MySQLCompare) { + mcmp.Exec(testcase.query) + }) + } +} diff --git a/go/test/endtoend/vtgate/queries/tpch/vschema.json b/go/test/endtoend/vtgate/queries/tpch/vschema.json new file mode 100644 index 00000000000..8cdf236e4e1 --- /dev/null +++ b/go/test/endtoend/vtgate/queries/tpch/vschema.json @@ -0,0 +1,121 @@ +{ + "sharded": true, + "foreignKeyMode": "unspecified", + "vindexes": { + "hash": { + "type": "hash" + } + }, + "tables": { + "basic": { + "name": "basic", + "column_vindexes": [ + { + "columns": [ + "a" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "customer": { + "name": "customer", + "column_vindexes": [ + { + "columns": [ + "C_CUSTKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "lineitem": { + "name": "lineitem", + "column_vindexes": [ + { + "columns": [ + "L_ORDERKEY", + "L_LINENUMBER" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "nation": { + "name": "nation", + "column_vindexes": [ + { + "columns": [ + "N_NATIONKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "orders": { + "name": "orders", + "column_vindexes": [ + { + "columns": [ + "O_ORDERKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "part": { + "name": "part", + "column_vindexes": [ + { + "columns": [ + "P_PARTKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "partsupp": { + "name": "partsupp", + "column_vindexes": [ + { + "columns": [ + "PS_PARTKEY", + "PS_SUPPKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "region": { + "name": "region", + "column_vindexes": [ + { + "columns": [ + "R_REGIONKEY" + ], + "type": "hash", + "name": "hash" + } + ] + }, + "supplier": { + "name": "supplier", + "column_vindexes": [ + { + "columns": [ + "S_SUPPKEY" + ], + "type": "hash", + "name": "hash" + } + ] + } + } +} \ No newline at end of file diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index c0b628b1aa8..21d13119804 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -198,7 +198,7 @@ func (c *compiler) compileToNumeric(ct ctype, offset int, fallback sqltypes.Type return ctype{Type: sqltypes.Int64, Flag: ct.Flag, Col: collationNumeric} } c.asm.Convert_Td(offset) - return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Size: ct.Size} + return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Size: ct.Size + decimalSizeBase, Scale: ct.Size} } c.asm.Convert_Tf(offset) return ctype{Type: sqltypes.Float64, Flag: ct.Flag, Col: collationNumeric} @@ -281,15 +281,21 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { if sqltypes.IsDecimal(ct.Type) { return ct } + var scale int32 + var size int32 switch ct.Type { case sqltypes.Int64: c.asm.Convert_id(offset) case sqltypes.Uint64: c.asm.Convert_ud(offset) + case sqltypes.Datetime, sqltypes.Time: + scale = ct.Size + size = ct.Size + decimalSizeBase + fallthrough default: c.asm.Convert_xd(offset, 0, 0) } - return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric} + return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Scale: scale, Size: size} } func (c *compiler) compileToDate(doct ctype, offset int) ctype { diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index a9ecd8f977e..3d5283db415 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -671,6 +671,10 @@ func TestCompilerSingle(t *testing.T) { expression: `1 * unix_timestamp(from_unixtime(time '31:34:58.123'))`, result: `DECIMAL(313458.123)`, }, + { + expression: `1 * unix_timestamp(time('1.0000'))`, + result: `DECIMAL(1698098401.0000)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 5d82d279d69..026892fc0ac 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -116,6 +116,7 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { } c.asm.Add_dd() ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) case sqltypes.Float64: if swap { @@ -170,6 +171,7 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(lt, 2) c.asm.Sub_dd() ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Uint64: @@ -188,6 +190,7 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(lt, 2) c.asm.Sub_dd() ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Float64: @@ -204,6 +207,7 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(rt, 1) c.asm.Sub_dd() ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) } } @@ -269,6 +273,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { } c.asm.Mul_dd() ct.Type = sqltypes.Decimal + ct.Size = lt.Size + rt.Size ct.Scale = lt.Scale + rt.Scale } @@ -309,6 +314,7 @@ func (op *opArithDiv) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() + ct.Size = lt.Size + divPrecisionIncrement ct.Scale = lt.Scale + divPrecisionIncrement } c.asm.jumpDestination(skip1, skip2) @@ -438,6 +444,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() @@ -455,6 +462,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Size = max(lt.Size, rt.Size) ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index 06e135c317c..8663370f819 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -74,9 +74,11 @@ func (c *Column) typeof(env *ExpressionEnv) (ctype, error) { } return ctype{ - Type: field.Type, - Col: typedCoercionCollation(field.Type, collations.ID(field.Charset)), - Flag: f, + Type: field.Type, + Col: typedCoercionCollation(field.Type, collations.ID(field.Charset)), + Flag: f, + Size: int32(field.ColumnLength), + Scale: int32(field.Decimals), }, nil } if c.Offset < len(env.Row) { diff --git a/go/vt/vtgate/evalengine/expr_env_test.go b/go/vt/vtgate/evalengine/expr_env_test.go new file mode 100644 index 00000000000..f75cc6f1376 --- /dev/null +++ b/go/vt/vtgate/evalengine/expr_env_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtenv" +) + +// TestExpressionEnvTypeOf tests the functionality of the TypeOf method on ExpressionEnv +func TestExpressionEnvTypeOf(t *testing.T) { + sumCol := &Column{ + Type: sqltypes.Unknown, + Offset: 0, + Original: &sqlparser.Sum{ + Arg: sqlparser.NewColName("l_discount"), + }, + dynamicTypeOffset: 0, + } + countCol := &Column{ + Type: sqltypes.Unknown, + Offset: 1, + Original: &sqlparser.Count{ + Args: sqlparser.Exprs{ + sqlparser.NewColName("l_discount"), + }, + }, + dynamicTypeOffset: 1, + } + + tests := []struct { + name string + env *ExpressionEnv + expr Expr + wantedScale int32 + wantedType sqltypes.Type + }{ + { + name: "Decimal divided by integer", + env: &ExpressionEnv{ + Fields: []*querypb.Field{ + { + Name: "avg_disc", + Type: querypb.Type_DECIMAL, + ColumnLength: 39, + Decimals: 2, + }, + { + Name: "count(l_discount)", + Type: querypb.Type_INT64, + ColumnLength: 21, + }, + }, + sqlmode: 3, + }, + expr: &UntypedExpr{ + env: vtenv.NewTestEnv(), + mu: sync.Mutex{}, + collation: 255, + typed: nil, + needTypes: []typedIR{sumCol, countCol}, + ir: &ArithmeticExpr{ + Op: &opArithDiv{}, + BinaryExpr: BinaryExpr{ + Left: sumCol, + Right: countCol, + }, + }, + }, + wantedScale: 6, + wantedType: sqltypes.Decimal, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.env.TypeOf(tt.expr) + require.NoError(t, err) + require.EqualValues(t, tt.wantedType, got.Type()) + require.EqualValues(t, tt.wantedScale, got.Scale()) + }) + } +} diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1303ac7614d..c102f5e5ef5 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -300,12 +300,16 @@ func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, error) { var f typeFlag + var size int32 + var scale int32 for i, tt := range args { f |= nullableFlags(tt.Flag) + size = max(size, tt.Size) + scale = max(scale, tt.Scale) c.compileToDecimal(tt, len(args)-i) } c.asm.Fn_MULTICMP_d(len(args), call.cmp < 0) - return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric}, nil + return ctype{Type: sqltypes.Decimal, Flag: f, Col: collationNumeric, Size: size, Scale: scale}, nil } func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index cbc1613f5fe..fe8b7d3770f 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -31,6 +31,12 @@ var SystemTime = time.Now const maxTimePrec = datetime.DefaultPrecision +// The length of a datetime converted to a numerical value is always 14 characters, +// see for example "20240404102732". We also have a `.` since we know it's a decimal +// and then additionally the number of decimals behind the dot. So total is always +// the input datetime size + 15. +const decimalSizeBase = 15 + type ( builtinNow struct { CallExpr @@ -428,7 +434,10 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { switch n.Type { case sqltypes.Datetime, sqltypes.Date: prec = n.Size - case sqltypes.Decimal, sqltypes.Time: + case sqltypes.Decimal: + prec = n.Scale + c.asm.Convert_xDT(3, -1, false) + case sqltypes.Time: prec = n.Size c.asm.Convert_xDT(3, -1, false) case sqltypes.VarChar, sqltypes.VarBinary: @@ -1520,13 +1529,28 @@ func (call *builtinTime) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) + var prec int32 switch arg.Type { case sqltypes.Time: + case sqltypes.Datetime, sqltypes.Date: + prec = arg.Size + c.asm.Convert_xT(1, -1) + case sqltypes.Decimal: + prec = arg.Scale + c.asm.Convert_xT(1, -1) + case sqltypes.VarChar, sqltypes.VarBinary: + if lit, ok := call.Arguments[0].(*Literal); ok && !arg.isHexOrBitLiteral() { + if t := evalToTime(lit.inner, -1); t != nil { + prec = int32(t.prec) + } + } + c.asm.Convert_xT(1, -1) default: + prec = maxTimePrec c.asm.Convert_xT(1, -1) } c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.Time, Col: collationBinary, Flag: arg.Flag | flagNullable}, nil + return ctype{Type: sqltypes.Time, Col: collationBinary, Flag: arg.Flag | flagNullable, Size: prec}, nil } func dateTimeUnixTimestamp(env *ExpressionEnv, date eval) evalNumeric { @@ -1612,7 +1636,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { if arg.Size == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } - return ctype{Type: sqltypes.Decimal, Size: arg.Size, Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Decimal, Size: decimalSizeBase + arg.Size, Scale: arg.Size, Col: collationNumeric, Flag: arg.Flag}, nil case sqltypes.Date, sqltypes.Int64, sqltypes.Uint64: return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil case sqltypes.VarChar, sqltypes.VarBinary: @@ -1624,12 +1648,12 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { if dt.prec == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } - return ctype{Type: sqltypes.Decimal, Size: int32(dt.prec), Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Decimal, Size: decimalSizeBase + int32(dt.prec), Scale: int32(dt.prec), Col: collationNumeric, Flag: arg.Flag}, nil } } fallthrough default: - return ctype{Type: sqltypes.Decimal, Size: maxTimePrec, Col: collationNumeric, Flag: arg.Flag}, nil + return ctype{Type: sqltypes.Decimal, Size: decimalSizeBase + maxTimePrec, Scale: maxTimePrec, Col: collationNumeric, Flag: arg.Flag}, nil } } diff --git a/test/config.json b/test/config.json index 0324ee54ef1..26efdee1f36 100644 --- a/test/config.json +++ b/test/config.json @@ -581,6 +581,15 @@ "RetryMax": 2, "Tags": ["upgrade_downgrade_query_serving_queries"] }, + "vtgate_queries_tpch": { + "File": "unused.go", + "Args": ["vitess.io/vitess/go/test/endtoend/vtgate/queries/tpch", "-timeout", "20m"], + "Command": [], + "Manual": false, + "Shard": "vtgate_queries", + "RetryMax": 2, + "Tags": ["upgrade_downgrade_query_serving_queries"] + }, "vtgate_queries_subquery": { "File": "unused.go", "Args": ["vitess.io/vitess/go/test/endtoend/vtgate/queries/subquery", "-timeout", "20m"],