diff --git a/filters/standard_filters.go b/filters/standard_filters.go index 9d03a0a..e367f7d 100644 --- a/filters/standard_filters.go +++ b/filters/standard_filters.go @@ -3,6 +3,7 @@ package filters import ( "encoding/json" + "errors" "fmt" "html" "math" @@ -17,6 +18,8 @@ import ( "github.com/osteele/tuesday" ) +var errDivisionByZero = errors.New("division by zero") + // A FilterDictionary holds filters. type FilterDictionary interface { AddFilter(string, any) @@ -100,14 +103,42 @@ func AddStandardFilters(fd FilterDictionary) { //nolint: gocyclo fd.AddFilter("times", func(a, b float64) float64 { return a * b }) - fd.AddFilter("divided_by", func(a float64, b any) any { + fd.AddFilter("divided_by", func(a float64, b any) (any, error) { + divInt := func(a, b int64) (int64, error) { + if b == 0 { + return 0, errDivisionByZero + } + return a / b, nil + } + divFloat := func(a, b float64) (float64, error) { + if b == 0 { + return 0, errDivisionByZero + } + return a / b, nil + } switch q := b.(type) { - case int, int16, int32, int64: - return int(a) / q.(int) - case float32, float64: - return a / b.(float64) + case int: + return divInt(int64(a), int64(q)) + case int8: + return divInt(int64(a), int64(q)) + case int16: + return divInt(int64(a), int64(q)) + case int32: + return divInt(int64(a), int64(q)) + case int64: + return divInt(int64(a), q) + case uint8: + return divInt(int64(a), int64(q)) + case uint16: + return divInt(int64(a), int64(q)) + case uint32: + return divInt(int64(a), int64(q)) + case float32: + return divFloat(a, float64(q)) + case float64: + return divFloat(a, q) default: - return nil + return nil, fmt.Errorf("invalid divisor: '%v'", b) } }) fd.AddFilter("round", func(n float64, places func(int) int) float64 { diff --git a/filters/standard_filters_test.go b/filters/standard_filters_test.go index 194cf2b..6a05fa1 100644 --- a/filters/standard_filters_test.go +++ b/filters/standard_filters_test.go @@ -184,7 +184,6 @@ Liquid" | slice: 2, 4`, "quid"}, {`5 | divided_by: 3`, 1}, {`20 | divided_by: 7`, 2}, {`20 | divided_by: 7.0`, 2.857142857142857}, - {`20 | divided_by: 's'`, nil}, {`1.2 | round`, 1.0}, {`2.7 | round`, 3.0}, @@ -197,6 +196,14 @@ Liquid" | slice: 2, 4`, "quid"}, {`"1" | type`, `string`}, } +var filterErrorTests = []struct { + in string + error string +}{ + {`20 | divided_by: 's'`, `error applying filter "divided_by" ("invalid divisor: 's'")`}, + {`20 | divided_by: 0`, `error applying filter "divided_by" ("division by zero")`}, +} + var filterTestBindings = map[string]any{ "empty_array": []any{}, "empty_map": map[string]any{}, @@ -272,7 +279,14 @@ func TestFilters(t *testing.T) { t.Run(fmt.Sprintf("%02d", i+1), func(t *testing.T) { actual, err := expressions.EvaluateString(test.in, context) require.NoErrorf(t, err, test.in) - require.Equalf(t, test.expected, actual, test.in) + require.EqualValuesf(t, test.expected, actual, test.in) + }) + } + + for i, test := range filterErrorTests { + t.Run(fmt.Sprintf("%02d", i+len(filterTests)+1), func(t *testing.T) { + _, err := expressions.EvaluateString(test.in, context) + require.EqualErrorf(t, err, test.error, test.in) }) } }