diff --git a/go/mysql/datetime/interval.go b/go/mysql/datetime/interval.go index 75e1ce7bb45..22993d5cf55 100644 --- a/go/mysql/datetime/interval.go +++ b/go/mysql/datetime/interval.go @@ -166,6 +166,54 @@ func (itv IntervalType) ToString() string { } } +// ParseIntervalType parses a string into an IntervalType. This is the inverse function of IntervalType.ToString(). +func ParseIntervalType(s string) IntervalType { + switch strings.ToLower(s) { + case "year": + return IntervalYear + case "quarter": + return IntervalQuarter + case "month": + return IntervalMonth + case "week": + return IntervalWeek + case "day": + return IntervalDay + case "hour": + return IntervalHour + case "minute": + return IntervalMinute + case "second": + return IntervalSecond + case "microsecond": + return IntervalMicrosecond + case "year_month": + return IntervalYearMonth + case "day_hour": + return IntervalDayHour + case "day_minute": + return IntervalDayMinute + case "day_second": + return IntervalDaySecond + case "hour_minute": + return IntervalHourMinute + case "hour_second": + return IntervalHourSecond + case "minute_second": + return IntervalMinuteSecond + case "day_microsecond": + return IntervalDayMicrosecond + case "hour_microsecond": + return IntervalHourMicrosecond + case "minute_microsecond": + return IntervalMinuteMicrosecond + case "second_microsecond": + return IntervalSecondMicrosecond + default: + return IntervalNone + } +} + func intervalSetYear(tp *Interval, val int) { tp.year = val } diff --git a/go/mysql/datetime/interval_test.go b/go/mysql/datetime/interval_test.go index 22b4617656b..f3343faf8f9 100644 --- a/go/mysql/datetime/interval_test.go +++ b/go/mysql/datetime/interval_test.go @@ -17,9 +17,12 @@ limitations under the License. package datetime import ( + "math" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/decimal" ) @@ -367,3 +370,47 @@ func TestInRange(t *testing.T) { assert.Equal(t, tc.wantInRange, got) } } + +func TestParseIntervalType(t *testing.T) { + intervals := []IntervalType{ + IntervalMicrosecond, + IntervalSecond, + IntervalMinute, + IntervalHour, + IntervalDay, + IntervalWeek, + IntervalMonth, + IntervalQuarter, + IntervalYear, + IntervalSecondMicrosecond, + IntervalMinuteMicrosecond, + IntervalMinuteSecond, + IntervalHourMicrosecond, + IntervalHourSecond, + IntervalHourMinute, + IntervalDayMicrosecond, + IntervalDaySecond, + IntervalDayMinute, + IntervalDayHour, + IntervalYearMonth, + } + for _, interval := range intervals { + s := interval.ToString() + t.Run(s, func(t *testing.T) { + require.NotEmpty(t, s) + require.NotEqual(t, "[unknown IntervalType]", s) + parsed := ParseIntervalType(s) + assert.NotEqual(t, IntervalNone, parsed) + assert.Equal(t, interval, parsed) + + parsed = ParseIntervalType(strings.ToUpper(s)) + assert.NotEqual(t, IntervalNone, parsed) + assert.Equal(t, interval, parsed) + }) + } + interval := IntervalType(math.MaxUint8) + s := interval.ToString() + assert.Equal(t, "[unknown IntervalType]", s) + parsed := ParseIntervalType(s) + assert.Equal(t, IntervalNone, parsed) +}