Skip to content

Commit

Permalink
Merge pull request #97 from jamslinger/division-by-zero
Browse files Browse the repository at this point in the history
Don't panic in filter divided_by on division by zero
  • Loading branch information
danog authored Oct 18, 2024
2 parents cc7bda4 + 63cddf1 commit 71b8fa2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
43 changes: 37 additions & 6 deletions filters/standard_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package filters

import (
"encoding/json"
"errors"
"fmt"
"html"
"math"
Expand All @@ -17,6 +18,8 @@ import (
"github.com/osteele/tuesday"
)

var errDivisionByZero = errors.New("division by zero")

// A FilterDictionary holds filters.
type FilterDictionary interface {
AddFilter(string, any)
Expand Down Expand Up @@ -100,14 +103,42 @@ func AddStandardFilters(fd FilterDictionary) { //nolint: gocyclo
fd.AddFilter("times", func(a, b float64) float64 {
return a * b
})
fd.AddFilter("divided_by", func(a float64, b any) any {
fd.AddFilter("divided_by", func(a float64, b any) (any, error) {
divInt := func(a, b int64) (int64, error) {
if b == 0 {
return 0, errDivisionByZero
}
return a / b, nil
}
divFloat := func(a, b float64) (float64, error) {
if b == 0 {
return 0, errDivisionByZero
}
return a / b, nil
}
switch q := b.(type) {
case int, int16, int32, int64:
return int(a) / q.(int)
case float32, float64:
return a / b.(float64)
case int:
return divInt(int64(a), int64(q))
case int8:
return divInt(int64(a), int64(q))
case int16:
return divInt(int64(a), int64(q))
case int32:
return divInt(int64(a), int64(q))
case int64:
return divInt(int64(a), q)
case uint8:
return divInt(int64(a), int64(q))
case uint16:
return divInt(int64(a), int64(q))
case uint32:
return divInt(int64(a), int64(q))
case float32:
return divFloat(a, float64(q))
case float64:
return divFloat(a, q)
default:
return nil
return nil, fmt.Errorf("invalid divisor: '%v'", b)
}
})
fd.AddFilter("round", func(n float64, places func(int) int) float64 {
Expand Down
18 changes: 16 additions & 2 deletions filters/standard_filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ Liquid" | slice: 2, 4`, "quid"},
{`5 | divided_by: 3`, 1},
{`20 | divided_by: 7`, 2},
{`20 | divided_by: 7.0`, 2.857142857142857},
{`20 | divided_by: 's'`, nil},

{`1.2 | round`, 1.0},
{`2.7 | round`, 3.0},
Expand All @@ -197,6 +196,14 @@ Liquid" | slice: 2, 4`, "quid"},
{`"1" | type`, `string`},
}

var filterErrorTests = []struct {
in string
error string
}{
{`20 | divided_by: 's'`, `error applying filter "divided_by" ("invalid divisor: 's'")`},
{`20 | divided_by: 0`, `error applying filter "divided_by" ("division by zero")`},
}

var filterTestBindings = map[string]any{
"empty_array": []any{},
"empty_map": map[string]any{},
Expand Down Expand Up @@ -272,7 +279,14 @@ func TestFilters(t *testing.T) {
t.Run(fmt.Sprintf("%02d", i+1), func(t *testing.T) {
actual, err := expressions.EvaluateString(test.in, context)
require.NoErrorf(t, err, test.in)
require.Equalf(t, test.expected, actual, test.in)
require.EqualValuesf(t, test.expected, actual, test.in)
})
}

for i, test := range filterErrorTests {
t.Run(fmt.Sprintf("%02d", i+len(filterTests)+1), func(t *testing.T) {
_, err := expressions.EvaluateString(test.in, context)
require.EqualErrorf(t, err, test.error, test.in)
})
}
}
Expand Down

0 comments on commit 71b8fa2

Please sign in to comment.