Skip to content

Commit

Permalink
Fix Scale and length handling in CASE and JOIN bind variables (#15787)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
GuptaManan100 and dbussink authored Apr 25, 2024
1 parent 716fc12 commit 7ca2b81
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 26 deletions.
8 changes: 8 additions & 0 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,14 @@ type Value struct {
n NumberType
}

func (v *Value) Size() int32 {
return 0
}

func (v *Value) Scale() int32 {
return 0
}

func (v *Value) MarshalDate() string {
if d, ok := v.Date(); ok {
return d.ToStdTime(time.Local).Format("2006-01-02")
Expand Down
17 changes: 15 additions & 2 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package union
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"github.com/stretchr/testify/require"
)

func start(t *testing.T) (utils.MySQLCompare, func()) {
Expand Down Expand Up @@ -161,6 +161,19 @@ group by
order by
value desc;`,
},
{
name: "Q14 without decimal literal",
query: `select sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
}

for _, testcase := range testcases {
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"bytes"
"context"
"fmt"
"strings"
Expand Down Expand Up @@ -61,7 +62,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
result := &sqltypes.Result{}
if len(lresult.Rows) == 0 && wantfields {
for k, col := range jn.Vars {
joinVars[k] = bindvarForType(lresult.Fields[col].Type)
joinVars[k] = bindvarForType(lresult.Fields[col])
}
rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
if err != nil {
Expand Down Expand Up @@ -95,19 +96,21 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func bindvarForType(t querypb.Type) *querypb.BindVariable {
func bindvarForType(field *querypb.Field) *querypb.BindVariable {
bv := &querypb.BindVariable{
Type: t,
Type: field.Type,
Value: nil,
}
switch t {
switch field.Type {
case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16,
querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64:
bv.Value = []byte("0")
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
bv.Value = []byte("0e0")
case querypb.Type_DECIMAL:
bv.Value = []byte("0.0")
size := max(1, int(field.ColumnLength-field.Decimals))
scale := max(1, int(field.Decimals))
bv.Value = append(append(bytes.Repeat([]byte{'0'}, size), byte('.')), bytes.Repeat([]byte{'0'}, scale)...)
default:
return sqltypes.NullBindVariable
}
Expand Down
20 changes: 14 additions & 6 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type typeAggregation struct {
blob uint16
total uint16

nullable bool
nullable bool
scale, size int32
}

type TypeAggregator struct {
Expand All @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error {
return nil
}

ta.types.addNullable(typ.typ, typ.nullable)
ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale)
if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil {
return err
}
Expand Down Expand Up @@ -95,20 +96,25 @@ func (ta *typeAggregation) empty() bool {
func (ta *typeAggregation) addEval(e eval) {
var t sqltypes.Type
var f typeFlag
var size, scale int32
switch e := e.(type) {
case nil:
t = sqltypes.Null
ta.nullable = true
case *evalBytes:
t = sqltypes.Type(e.tt)
f = e.flag
size = e.Size()
scale = e.Scale()
default:
t = e.SQLType()
size = e.Size()
scale = e.Scale()
}
ta.add(t, f)
ta.add(t, f, size, scale)
}

func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) {
var flag typeFlag
if typ == sqltypes.HexVal || typ == sqltypes.HexNum {
typ = sqltypes.Binary
Expand All @@ -117,13 +123,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
if nullable {
flag |= flagNullable
}
ta.add(typ, flag)
ta.add(typ, flag, size, scale)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) {
if f&flagNullable != 0 {
ta.nullable = true
}
ta.size = max(ta.size, size)
ta.scale = max(ta.scale, scale)
switch tt {
case sqltypes.Float32, sqltypes.Float64:
ta.double++
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() {
}, "CMPFLAG NE [NULL]")
}

func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) {
func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, size, scale int32, cc collations.TypedCollation, allowZeroDate bool) {
elseOffset := 0
if hasElse {
elseOffset = 1
Expand All @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, size, scale, cc.Collation, env.now, allowZeroDate)
goto done
}
}
if elseOffset != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, size, scale, cc.Collation, env.now, allowZeroDate)
} else {
env.vm.stack[env.vm.sp-stackDepth] = nil
}
Expand Down
30 changes: 30 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) {
values []sqltypes.Value
result string
collation collations.ID
typeWanted evalengine.Type
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -675,6 +677,28 @@ func TestCompilerSingle(t *testing.T) {
expression: `1 * unix_timestamp(time('1.0000'))`,
result: `DECIMAL(1698098401.0000)`,
},
{
expression: `(case
when 'PROMOTION' like 'PROMO%'
then 0.01
else 0
end) * 0.01`,
result: `DECIMAL(0.0001)`,
typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil),
},
{
expression: `case when true then 0.02 else 1.000 end`,
result: `DECIMAL(0.02)`,
},
{
expression: `case
when false
then timestamp'2023-10-24 12:00:00.123456'
else timestamp'2023-10-24 12:00:00'
end`,
result: `DATETIME("2023-10-24 12:00:00.000000")`,
typeWanted: evalengine.NewTypeEx(sqltypes.Datetime, collations.CollationBinaryID, false, 6, 0, nil),
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -715,6 +739,12 @@ func TestCompilerSingle(t *testing.T) {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

if tc.typeWanted.Type() != sqltypes.Unknown {
typ, err := env.TypeOf(converted)
require.NoError(t, err)
require.True(t, tc.typeWanted.Equal(&typ))
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
res, err := env.Evaluate(converted)
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool {
type eval interface {
ToRawBytes() []byte
SQLType() sqltypes.Type
Size() int32
Scale() int32
}

type hashable interface {
Expand Down Expand Up @@ -170,7 +172,7 @@ func evalIsTruthy(e eval) boolean {
}
}

func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) {
func evalCoerce(e eval, typ sqltypes.Type, size, scale int32, col collations.ID, now time.Time, allowZero bool) (eval, error) {
if e == nil {
return nil, nil
}
Expand All @@ -181,7 +183,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all
// if we have an explicit VARCHAR coercion, always force it so the collation is replaced in the target
return evalToVarchar(e, col, false)
}
if e.SQLType() == typ {
if e.SQLType() == typ && e.Size() == size && e.Scale() == scale {
// nothing to be done here
return e, nil
}
Expand All @@ -204,9 +206,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all
case sqltypes.Date:
return evalToDate(e, now, allowZero), nil
case sqltypes.Datetime, sqltypes.Timestamp:
return evalToDateTime(e, -1, now, allowZero), nil
return evalToDateTime(e, int(size), now, allowZero), nil
case sqltypes.Time:
return evalToTime(e, -1), nil
return evalToTime(e, int(size)), nil
default:
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String())
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type {
return sqltypes.Type(e.tt)
}

func (e *evalBytes) Size() int32 {
return 0
}

func (e *evalBytes) Scale() int32 {
return 0
}

func (e *evalBytes) ToRawBytes() []byte {
return e.bytes
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type {
return sqltypes.Enum
}

func (e *evalEnum) Size() int32 {
return 0
}

func (e *evalEnum) Scale() int32 {
return 0
}

func valueIdx(values *EnumSetValues, value string) int {
if values == nil {
return -1
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/eval_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type {
return sqltypes.Int64
}

func (e *evalInt64) Size() int32 {
return 0
}

func (e *evalInt64) Scale() int32 {
return 0
}

func (e *evalInt64) ToRawBytes() []byte {
return strconv.AppendInt(nil, e.i, 10)
}
Expand Down Expand Up @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type {
return sqltypes.Uint64
}

func (e *evalUint64) Size() int32 {
return 0
}

func (e *evalUint64) Scale() int32 {
return 0
}

func (e *evalUint64) ToRawBytes() []byte {
return strconv.AppendUint(nil, e.u, 10)
}
Expand Down Expand Up @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type {
return sqltypes.Float64
}

func (e *evalFloat) Size() int32 {
return 0
}

func (e *evalFloat) Scale() int32 {
return 0
}

func (e *evalFloat) ToRawBytes() []byte {
return format.FormatFloat(e.f)
}
Expand Down Expand Up @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type {
return sqltypes.Decimal
}

func (e *evalDecimal) Size() int32 {
return e.length
}

func (e *evalDecimal) Scale() int32 {
return -e.dec.Exponent()
}

func (e *evalDecimal) ToRawBytes() []byte {
return e.dec.FormatMySQL(e.length)
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type {
return sqltypes.Set
}

func (e *evalSet) Size() int32 {
return 0
}

func (e *evalSet) Scale() int32 {
return 0
}

func evalSetBits(values *EnumSetValues, value string) uint64 {
if values != nil && len(*values) > 64 {
// This never would happen as MySQL limits SET
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type {
return e.t
}

func (e *evalTemporal) Size() int32 {
return int32(e.prec)
}

func (e *evalTemporal) Scale() int32 {
return 0
}

func (e *evalTemporal) toInt64() int64 {
switch e.SQLType() {
case sqltypes.Date:
Expand Down
Loading

0 comments on commit 7ca2b81

Please sign in to comment.