Skip to content

Commit e73017b

Browse files
committed
datetime: obey the evalengine's environment time
Signed-off-by: Vicent Marti <vmg@strn.cat>
1 parent 5af661e commit e73017b

File tree

12 files changed

+84
-77
lines changed

12 files changed

+84
-77
lines changed

go/mysql/datetime/datetime.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ func (t Time) FormatDecimal() decimal.Decimal {
9494
return dec
9595
}
9696

97-
func (t Time) ToDateTime() (out DateTime) {
98-
return NewDateTimeFromStd(t.ToStdTime(time.Local))
97+
func (t Time) ToDateTime(now time.Time) (out DateTime) {
98+
return NewDateTimeFromStd(t.ToStdTime(now))
9999
}
100100

101101
func (t Time) IsZero() bool {
@@ -421,9 +421,9 @@ func (t Time) toStdTime(year int, month time.Month, day int, loc *time.Location)
421421
return time.Date(year, month, day, hours, minutes, secs, nsecs, loc)
422422
}
423423

424-
func (t Time) ToStdTime(loc *time.Location) (out time.Time) {
425-
year, month, day := time.Now().Date()
426-
return t.toStdTime(year, month, day, loc)
424+
func (t Time) ToStdTime(now time.Time) (out time.Time) {
425+
year, month, day := now.Date()
426+
return t.toStdTime(year, month, day, now.Location())
427427
}
428428

429429
func (t Time) AddInterval(itv *Interval, stradd bool) (Time, uint8, bool) {
@@ -444,20 +444,20 @@ func (d Date) ToStdTime(loc *time.Location) (out time.Time) {
444444
return time.Date(d.Year(), time.Month(d.Month()), d.Day(), 0, 0, 0, 0, loc)
445445
}
446446

447-
func (dt DateTime) ToStdTime(loc *time.Location) time.Time {
447+
func (dt DateTime) ToStdTime(now time.Time) time.Time {
448448
zerodate := dt.Date.IsZero()
449449
zerotime := dt.Time.IsZero()
450450

451451
switch {
452452
case zerodate && zerotime:
453453
return time.Time{}
454454
case zerodate:
455-
return dt.Time.ToStdTime(loc)
455+
return dt.Time.ToStdTime(now)
456456
case zerotime:
457-
return dt.Date.ToStdTime(loc)
457+
return dt.Date.ToStdTime(now.Location())
458458
default:
459459
year, month, day := dt.Date.Year(), time.Month(dt.Date.Month()), dt.Date.Day()
460-
return dt.Time.toStdTime(year, month, day, loc)
460+
return dt.Time.toStdTime(year, month, day, now.Location())
461461
}
462462
}
463463

@@ -527,7 +527,10 @@ func (dt DateTime) Compare(dt2 DateTime) int {
527527
// if we're comparing a time to a datetime, we need to normalize them
528528
// both into datetimes; this normalization is not trivial because negative
529529
// times result in a date change, so let the standard library handle this
530-
return dt.ToStdTime(time.Local).Compare(dt2.ToStdTime(time.Local))
530+
531+
// Using the current time is OK here since the comparison is relative
532+
now := time.Now()
533+
return dt.ToStdTime(now).Compare(dt2.ToStdTime(now))
531534
}
532535
if cmp := dt.Date.Compare(dt2.Date); cmp != 0 {
533536
return cmp
@@ -559,9 +562,10 @@ func (dt DateTime) Round(p int) (r DateTime) {
559562
r = dt
560563
if n == 1e9 {
561564
r.Time.nanosecond = 0
562-
return NewDateTimeFromStd(r.ToStdTime(time.Local).Add(time.Second))
565+
r.addInterval(&Interval{timeparts: timeparts{sec: 1}, unit: IntervalSecond})
566+
} else {
567+
r.Time.nanosecond = uint32(n)
563568
}
564-
r.Time.nanosecond = uint32(n)
565569
return r
566570
}
567571

go/mysql/json/parser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ func (v *Value) MarshalDate() string {
678678

679679
func (v *Value) MarshalDateTime() string {
680680
if dt, ok := v.DateTime(); ok {
681-
return dt.ToStdTime(time.Local).Format("2006-01-02 15:04:05.000000")
681+
return dt.ToStdTime(time.Now()).Format("2006-01-02 15:04:05.000000")
682682
}
683683
return ""
684684
}

go/vt/vtgate/evalengine/compiler_asm.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
529529
end := env.vm.sp - elseOffset
530530
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
531531
if env.vm.stack[sp].(*evalInt64).i != 0 {
532-
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation)
532+
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now)
533533
goto done
534534
}
535535
}
536536
if elseOffset != 0 {
537-
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation)
537+
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now)
538538
} else {
539539
env.vm.stack[env.vm.sp-stackDepth] = nil
540540
}
@@ -1110,7 +1110,7 @@ func (asm *assembler) Convert_xD(offset int) {
11101110
// Need to explicitly check here or we otherwise
11111111
// store a nil wrapper in an interface vs. a direct
11121112
// nil.
1113-
d := evalToDate(env.vm.stack[env.vm.sp-offset])
1113+
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
11141114
if d == nil {
11151115
env.vm.stack[env.vm.sp-offset] = nil
11161116
} else {
@@ -1125,7 +1125,7 @@ func (asm *assembler) Convert_xD_nz(offset int) {
11251125
// Need to explicitly check here or we otherwise
11261126
// store a nil wrapper in an interface vs. a direct
11271127
// nil.
1128-
d := evalToDate(env.vm.stack[env.vm.sp-offset])
1128+
d := evalToDate(env.vm.stack[env.vm.sp-offset], env.now)
11291129
if d == nil || d.isZero() {
11301130
env.vm.stack[env.vm.sp-offset] = nil
11311131
} else {
@@ -1140,7 +1140,7 @@ func (asm *assembler) Convert_xDT(offset, prec int) {
11401140
// Need to explicitly check here or we otherwise
11411141
// store a nil wrapper in an interface vs. a direct
11421142
// nil.
1143-
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
1143+
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
11441144
if dt == nil {
11451145
env.vm.stack[env.vm.sp-offset] = nil
11461146
} else {
@@ -1155,7 +1155,7 @@ func (asm *assembler) Convert_xDT_nz(offset, prec int) {
11551155
// Need to explicitly check here or we otherwise
11561156
// store a nil wrapper in an interface vs. a direct
11571157
// nil.
1158-
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec)
1158+
dt := evalToDateTime(env.vm.stack[env.vm.sp-offset], prec, env.now)
11591159
if dt == nil || dt.isZero() {
11601160
env.vm.stack[env.vm.sp-offset] = nil
11611161
} else {
@@ -4252,7 +4252,7 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
42524252
}
42534253

42544254
tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
4255-
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{})
4255+
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
42564256
env.vm.sp--
42574257
return 1
42584258
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
@@ -4274,7 +4274,7 @@ func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col col
42744274
goto baddate
42754275
}
42764276

4277-
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col)
4277+
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, col, env.now)
42784278
env.vm.sp--
42794279
return 1
42804280

go/vt/vtgate/evalengine/compiler_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ func TestCompilerSingle(t *testing.T) {
457457
},
458458
}
459459

460+
tz, _ := time.LoadLocation("Europe/Madrid")
461+
460462
for _, tc := range testCases {
461463
t.Run(tc.expression, func(t *testing.T) {
462464
expr, err := sqlparser.ParseExpr(tc.expression)
@@ -478,6 +480,7 @@ func TestCompilerSingle(t *testing.T) {
478480
}
479481

480482
env := evalengine.EmptyExpressionEnv()
483+
env.SetTime(time.Date(2023, 10, 24, 12, 0, 0, 0, tz))
481484
env.Row = tc.values
482485

483486
expected, err := env.Evaluate(evalengine.Deoptimize(converted))

go/vt/vtgate/evalengine/eval.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package evalengine
1818

1919
import (
2020
"strconv"
21+
"time"
2122
"unicode/utf8"
2223

2324
"vitess.io/vitess/go/hack"
@@ -167,7 +168,7 @@ func evalIsTruthy(e eval) boolean {
167168
}
168169
}
169170

170-
func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
171+
func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time) (eval, error) {
171172
if e == nil {
172173
return nil, nil
173174
}
@@ -199,9 +200,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID) (eval, error) {
199200
case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64:
200201
return evalToInt64(e).toUint64(), nil
201202
case sqltypes.Date:
202-
return evalToDate(e), nil
203+
return evalToDate(e, now), nil
203204
case sqltypes.Datetime, sqltypes.Timestamp:
204-
return evalToDateTime(e, -1), nil
205+
return evalToDateTime(e, -1, now), nil
205206
case sqltypes.Time:
206207
return evalToTime(e, -1), nil
207208
default:
@@ -329,7 +330,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
329330
return nil, err
330331
}
331332
// Separate return here to avoid nil wrapped in interface type
332-
d := evalToDate(e)
333+
d := evalToDate(e, time.Now())
333334
if d == nil {
334335
return nil, nil
335336
}
@@ -340,7 +341,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
340341
return nil, err
341342
}
342343
// Separate return here to avoid nil wrapped in interface type
343-
dt := evalToDateTime(e, -1)
344+
dt := evalToDateTime(e, -1, time.Now())
344345
if dt == nil {
345346
return nil, nil
346347
}

go/vt/vtgate/evalengine/eval_temporal.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package evalengine
22

33
import (
4+
"time"
5+
46
"vitess.io/vitess/go/hack"
57
"vitess.io/vitess/go/mysql/collations"
68
"vitess.io/vitess/go/mysql/datetime"
@@ -92,12 +94,12 @@ func (e *evalTemporal) toJSON() *evalJSON {
9294
}
9395
}
9496

95-
func (e *evalTemporal) toDateTime(l int) *evalTemporal {
97+
func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal {
9698
switch e.SQLType() {
9799
case sqltypes.Datetime, sqltypes.Date:
98100
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)}
99101
case sqltypes.Time:
100-
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(), prec: uint8(l)}
102+
return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)}
101103
default:
102104
panic("unreachable")
103105
}
@@ -118,15 +120,15 @@ func (e *evalTemporal) toTime(l int) *evalTemporal {
118120
}
119121
}
120122

121-
func (e *evalTemporal) toDate() *evalTemporal {
123+
func (e *evalTemporal) toDate(now time.Time) *evalTemporal {
122124
switch e.SQLType() {
123125
case sqltypes.Datetime:
124126
dt := datetime.DateTime{Date: e.dt.Date}
125127
return &evalTemporal{t: sqltypes.Date, dt: dt}
126128
case sqltypes.Date:
127129
return e
128130
case sqltypes.Time:
129-
dt := e.dt.Time.ToDateTime()
131+
dt := e.dt.Time.ToDateTime(now)
130132
dt.Time = datetime.Time{}
131133
return &evalTemporal{t: sqltypes.Date, dt: dt}
132134
default:
@@ -138,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
138140
return e.dt.IsZero()
139141
}
140142

141-
func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation) eval {
143+
func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
142144
var tmp *evalTemporal
143145
var ok bool
144146

@@ -150,7 +152,7 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
150152
tmp = &evalTemporal{t: e.t}
151153
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
152154
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
153-
tmp = e.toDateTime(int(e.prec))
155+
tmp = e.toDateTime(int(e.prec), now)
154156
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
155157
}
156158
if !ok {
@@ -324,10 +326,10 @@ func evalToTime(e eval, l int) *evalTemporal {
324326
return nil
325327
}
326328

327-
func evalToDateTime(e eval, l int) *evalTemporal {
329+
func evalToDateTime(e eval, l int, now time.Time) *evalTemporal {
328330
switch e := e.(type) {
329331
case *evalTemporal:
330-
return e.toDateTime(precision(l, int(e.prec)))
332+
return e.toDateTime(precision(l, int(e.prec)), now)
331333
case *evalBytes:
332334
if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() {
333335
return newEvalDateTime(t, l)
@@ -371,10 +373,10 @@ func evalToDateTime(e eval, l int) *evalTemporal {
371373
return nil
372374
}
373375

374-
func evalToDate(e eval) *evalTemporal {
376+
func evalToDate(e eval, now time.Time) *evalTemporal {
375377
switch e := e.(type) {
376378
case *evalTemporal:
377-
return e.toDate()
379+
return e.toDate(now)
378380
case *evalBytes:
379381
if t, _ := datetime.ParseDate(e.string()); !t.IsZero() {
380382
return newEvalDate(t)

go/vt/vtgate/evalengine/expr_convert.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,12 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
125125
case p > 6:
126126
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
127127
}
128-
if dt := evalToDateTime(e, c.Length); dt != nil {
128+
if dt := evalToDateTime(e, c.Length, env.now); dt != nil {
129129
return dt, nil
130130
}
131131
return nil, nil
132132
case "DATE":
133-
if d := evalToDate(e); d != nil {
133+
if d := evalToDate(e, env.now); d != nil {
134134
return d, nil
135135
}
136136
return nil, nil

go/vt/vtgate/evalengine/expr_env.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ func (env *ExpressionEnv) TypeOf(expr Expr, fields []*querypb.Field) (sqltypes.T
9999
return ty, f, nil
100100
}
101101

102+
func (env *ExpressionEnv) SetTime(now time.Time) {
103+
// This function is called only once by NewExpressionEnv to ensure that all expressions in the same
104+
// ExpressionEnv evaluate NOW() and similar SQL functions to the same value.
105+
env.now = now
106+
if tz := env.currentTimezone(); tz != nil {
107+
env.now = env.now.In(tz)
108+
}
109+
}
110+
102111
// EmptyExpressionEnv returns a new ExpressionEnv with no bind vars or row
103112
func EmptyExpressionEnv() *ExpressionEnv {
104113
return NewExpressionEnv(context.Background(), nil, nil)
@@ -108,14 +117,6 @@ func EmptyExpressionEnv() *ExpressionEnv {
108117
func NewExpressionEnv(ctx context.Context, bindVars map[string]*querypb.BindVariable, vc VCursor) *ExpressionEnv {
109118
env := &ExpressionEnv{BindVars: bindVars, vc: vc}
110119
env.user = callerid.ImmediateCallerIDFromContext(ctx)
111-
112-
// The current time for this ExpressionEnv is set only once, during creation.
113-
// This is to ensure that all expressions in the same ExpressionEnv evaluate NOW()
114-
// and similar SQL functions to the same value.
115-
env.now = time.Now()
116-
117-
if tz := env.currentTimezone(); tz != nil {
118-
env.now = env.now.In(tz)
119-
}
120+
env.SetTime(time.Now())
120121
return env
121122
}

go/vt/vtgate/evalengine/expr_logical.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) {
520520
return nil, nil
521521
}
522522
t, _ := c.typeof(env, nil)
523-
return evalCoerce(result, t, ca.result().Collation)
523+
return evalCoerce(result, t, ca.result().Collation, env.now)
524524
}
525525

526526
func (c *CaseExpr) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {

0 commit comments

Comments
 (0)