Skip to content

Commit

Permalink
Store Decimal precision and size while normalising (#15785)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
GuptaManan100 and dbussink authored Apr 25, 2024
1 parent 7ca2b81 commit 8c146a9
Show file tree
Hide file tree
Showing 16 changed files with 249 additions and 15 deletions.
10 changes: 10 additions & 0 deletions go/mysql/datetime/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package datetime

import (
"strings"
"time"
)

Expand Down Expand Up @@ -287,3 +288,12 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l
const (
durationPerDay = 24 * time.Hour
)

// SizeAndScaleFromString
func SizeFromString(s string) int32 {
idx := strings.LastIndex(s, ".")
if idx == -1 {
return 0
}
return int32(len(s[idx+1:]))
}
77 changes: 77 additions & 0 deletions go/mysql/datetime/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
Copyright 2024 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package datetime

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSizeFromString(t *testing.T) {
testcases := []struct {
value string
sizeExpected int32
}{
{
value: "2020-01-01 00:00:00",
sizeExpected: 0,
},
{
value: "2020-01-01 00:00:00.1",
sizeExpected: 1,
},
{
value: "2020-01-01 00:00:00.12",
sizeExpected: 2,
},
{
value: "2020-01-01 00:00:00.123",
sizeExpected: 3,
},
{
value: "2020-01-01 00:00:00.123456",
sizeExpected: 6,
},
{
value: "00:00:00",
sizeExpected: 0,
},
{
value: "00:00:00.1",
sizeExpected: 1,
},
{
value: "00:00:00.12",
sizeExpected: 2,
},
{
value: "00:00:00.123",
sizeExpected: 3,
},
{
value: "00:00:00.123456",
sizeExpected: 6,
},
}
for _, testcase := range testcases {
t.Run(testcase.value, func(t *testing.T) {
siz := SizeFromString(testcase.value)
assert.EqualValues(t, testcase.sizeExpected, siz)
})
}
}
45 changes: 45 additions & 0 deletions go/mysql/decimal/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,52 @@ func TestDecimal_Cmp1(t *testing.T) {
a := New(123, 3)
b := New(-1234, 2)
assert.Equal(t, 1, a.Cmp(b))
}

func TestSizeAndScaleFromString(t *testing.T) {
testcases := []struct {
value string
sizeExpected int32
scaleExpected int32
}{
{
value: "0.00003",
sizeExpected: 6,
scaleExpected: 5,
},
{
value: "-0.00003",
sizeExpected: 6,
scaleExpected: 5,
},
{
value: "12.00003",
sizeExpected: 7,
scaleExpected: 5,
},
{
value: "-12.00003",
sizeExpected: 7,
scaleExpected: 5,
},
{
value: "1000003",
sizeExpected: 7,
scaleExpected: 0,
},
{
value: "-1000003",
sizeExpected: 7,
scaleExpected: 0,
},
}
for _, testcase := range testcases {
t.Run(testcase.value, func(t *testing.T) {
siz, scale := SizeAndScaleFromString(testcase.value)
assert.EqualValues(t, testcase.sizeExpected, siz)
assert.EqualValues(t, testcase.scaleExpected, scale)
})
}
}

func TestDecimal_Cmp2(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions go/mysql/decimal/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"math"
"math/big"
"math/bits"
"strings"

"vitess.io/vitess/go/mysql/fastparse"
)
Expand Down Expand Up @@ -71,6 +72,20 @@ func parseDecimal64(s []byte) (Decimal, error) {
}, nil
}

// SizeAndScaleFromString gets the size and scale for the decimal value without needing to parse it.
func SizeAndScaleFromString(s string) (int32, int32) {
switch s[0] {
case '+', '-':
s = s[1:]
}
totalLen := len(s)
idx := strings.Index(s, ".")
if idx == -1 {
return int32(totalLen), 0
}
return int32(totalLen - 1), int32(totalLen - 1 - idx)
}

func NewFromMySQL(s []byte) (Decimal, error) {
var original = s
var neg bool
Expand Down
6 changes: 3 additions & 3 deletions go/test/endtoend/vtgate/queries/normalize/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ func TestNormalizeAllFields(t *testing.T) {
defer conn.Close()

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
vtgateVersion, err := cluster.GetMajorVersion("vtgate")
require.NoError(t, err)
if vtgateVersion < 19 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)`
if vtgateVersion < 20 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
}
selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
Expand Down
22 changes: 22 additions & 0 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ order by
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
{
name: "Q14 without case",
query: `select 100.00 * sum(l_extendedprice * (1 - l_discount)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
{
name: "Q14",
query: `select 100.00 * sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
Expand Down
5 changes: 3 additions & 2 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2317,8 +2317,9 @@ type (

// Argument represents bindvariable expression
Argument struct {
Name string
Type sqltypes.Type
Name string
Type sqltypes.Type
Size, Scale int32
}

// NullVal represents a NULL value.
Expand Down
2 changes: 2 additions & 0 deletions go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,15 @@ func (node *Argument) Format(buf *TrackedBuffer) {
// For bind variables that are statically typed, emit their type as an adjacent comment.
// This comment will be ignored by older versions of Vitess (and by MySQL) but will provide
// type safety when using the query as a cache key.
buf.astPrintf(node, " /* %s */", node.Type.String())
buf.astPrintf(node, " /* %s", node.Type.String())
if node.Size != 0 || node.Scale != 0 {
buf.astPrintf(node, "(%d", node.Size)
if node.Scale != 0 {
buf.astPrintf(node, ",%d", node.Scale)
}
buf.WriteString(")")
}
buf.WriteString(" */")
}
}

Expand Down
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"strconv"
"strings"

"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/decimal"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -562,6 +564,20 @@ func NewTypedArgument(in string, t sqltypes.Type) *Argument {
return &Argument{Name: in, Type: t}
}

func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) {
arg := &Argument{Name: in, Type: lit.SQLType()}
switch arg.Type {
case sqltypes.Decimal:
siz, scale := decimal.SizeAndScaleFromString(lit.Val)
arg.Scale = scale
arg.Size = siz
case sqltypes.Datetime, sqltypes.Time:
siz := datetime.SizeFromString(lit.Val)
arg.Size = siz
}
return arg, nil
}

// NewListArg builds a new ListArg.
func NewListArg(in string) ListArg {
return ListArg(in)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 18 additions & 3 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
}

// Modify the AST node to a bindvar.
cursor.Replace(NewTypedArgument(bvname, node.SQLType()))
arg, err := NewTypedArgumentFromLiteral(bvname, node)
if err != nil {
nz.err = err
return
}
cursor.Replace(arg)
}

// convertLiteral converts an Literal without the dedup.
Expand All @@ -224,7 +229,12 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {

bvname := nz.reserved.nextUnusedVar()
nz.bindVars[bvname] = bval
cursor.Replace(NewTypedArgument(bvname, node.SQLType()))
arg, err := NewTypedArgumentFromLiteral(bvname, node)
if err != nil {
nz.err = err
return
}
cursor.Replace(arg)
}

// convertComparison attempts to convert IN clauses to
Expand Down Expand Up @@ -268,7 +278,12 @@ func (nz *normalizer) parameterize(left, right Expr) Expr {
return nil
}
bvname := nz.decideBindVarName(lit, col, bval)
return NewTypedArgument(bvname, lit.SQLType())
arg, err := NewTypedArgumentFromLiteral(bvname, lit)
if err != nil {
nz.err = err
return nil
}
return arg
}

func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string {
Expand Down
18 changes: 16 additions & 2 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,28 @@ func TestNormalize(t *testing.T) {
}, {
// float val
in: "select * from t where foobar = 1.2",
outstmt: "select * from t where foobar = :foobar /* DECIMAL */",
outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.DecimalBindVariable("1.2"),
},
}, {
// datetime val
in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* TIME(6) */",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")),
},
}, {
// multiple vals
in: "select * from t where foo = 1.2 and bar = 2",
outstmt: "select * from t where foo = :foo /* DECIMAL */ and bar = :bar /* INT64 */",
outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.DecimalBindVariable("1.2"),
"bar": sqltypes.Int64BindVariable(2),
Expand Down
Loading

0 comments on commit 8c146a9

Please sign in to comment.