Skip to content

Commit

Permalink
Better support for floats (#102)
Browse files Browse the repository at this point in the history
Fixed protocols for Less then and Greater then.

This was causing an issue where mixed floats/ints in columns would not
sort correctly.
  • Loading branch information
scudette authored Jun 5, 2024
1 parent 520bec2 commit e681962
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 17 deletions.
42 changes: 33 additions & 9 deletions protocols/protocol_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ func (self AddDispatcher) Copy() AddDispatcher {
append([]AddProtocol{}, self.impl...)}
}

// Adding protocol

// LHS RHS
// int int -> lhs + rhs
// int float -> float(lhs) + rhs
// float int -> lhs + float(rhs)
// float float -> lhs + rhs

// We dont handle any other additions with ints here.
func intAdd(lhs int64, b types.Any) (types.Any, bool) {
switch b.(type) {
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, _ := utils.ToInt64(b)
return lhs + rhs, true

case float64, float32:
rhs, _ := utils.ToFloat(b)
return float64(lhs) + rhs, true
}

// We dont handle any other additions here
return &types.Null{}, false
}

func (self AddDispatcher) Add(scope types.Scope, a types.Any, b types.Any) types.Any {
a = maybeReduce(a)
b = maybeReduce(b)
Expand All @@ -44,19 +68,19 @@ func (self AddDispatcher) Add(scope types.Scope, a types.Any, b types.Any) types
case types.Null, *types.Null, nil:
return &types.Null{}

case float64:
b_float, ok := utils.ToFloat(b)
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
lhs, ok := utils.ToInt64(t)
if ok {
return t + b_float
res, ok := intAdd(lhs, b)
if ok {
return res
}
}
}

// Maybe its an integer.
a_int, ok := utils.ToInt64(a)
if ok {
b_int, ok := utils.ToInt64(b)
case float64:
b_float, ok := utils.ToFloat(b)
if ok {
return a_int + b_int
return t + b_float
}
}

Expand Down
6 changes: 6 additions & 0 deletions protocols/protocol_eq.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ func (self EqDispatcher) Eq(scope types.Scope, a types.Any, b types.Any) bool {
return t == rhs
}

case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
lhs, ok := utils.ToInt64(a)
if ok {
return intEq(lhs, b)
}

case bool:
rhs, ok := b.(bool)
if ok {
Expand Down
64 changes: 60 additions & 4 deletions protocols/protocol_gt.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,39 @@ func (self GtDispatcher) Copy() GtDispatcher {
append([]GtProtocol{}, self.impl...)}
}

func intGt(lhs int64, b types.Any) bool {
switch b.(type) {
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, _ := utils.ToInt64(b)
return lhs > rhs
case float64, float32:
rhs, _ := utils.ToFloat(b)
return float64(lhs) > rhs
}
return false
}

func (self GtDispatcher) Gt(scope types.Scope, a types.Any, b types.Any) bool {
a = maybeReduce(a)
b = maybeReduce(b)

switch t := a.(type) {
case types.Null, *types.Null, nil:
return false

case string:
rhs, ok := b.(string)
if ok {
return t > rhs
}

// If it is integer like, coerce to int.
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
lhs, ok := utils.ToInt64(t)
if ok {
return intGt(lhs, b)
}

case float64:
rhs, ok := utils.ToFloat(b)
if ok {
Expand All @@ -52,11 +74,45 @@ func (self GtDispatcher) Gt(scope types.Scope, a types.Any, b types.Any) bool {
}
}

lhs, ok := utils.ToInt64(a)
if ok {
rhs, ok := utils.ToInt64(b)
switch t := b.(type) {
case types.Null, *types.Null, nil:
return false

case string:
lhs, ok := a.(string)
if ok {
return lhs > t
}

// If it is integer like, coerce to int.
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, ok := utils.ToInt64(t)
if ok {
if intLt(rhs, a) {
return false
}
if intEq(rhs, a) {
return false
}
return true
}

case float64:
lhs, ok := utils.ToFloat(a)
if ok {
return lhs > t
}

case time.Time:
lhs, ok := toTime(a)
if ok {
return t.Before(*lhs)
}

case *time.Time:
lhs, ok := toTime(a)
if ok {
return lhs > rhs
return t.Before(*lhs)
}
}

Expand Down
76 changes: 72 additions & 4 deletions protocols/protocol_lt.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@ func (self LtDispatcher) Copy() LtDispatcher {
append([]LtProtocol{}, self.impl...)}
}

// Comparison table
// LHS RHS -> Promoted
// int int -> lhs < rhs
// int float -> float(lhs) < rhs
// float int -> lhs < float(rhs)
// float float -> lhs < lhs

func intLt(lhs int64, b types.Any) bool {
switch b.(type) {
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, _ := utils.ToInt64(b)
return lhs < rhs
case float64, float32:
rhs, _ := utils.ToFloat(b)
return float64(lhs) < rhs
}
return false
}

func intEq(lhs int64, b types.Any) bool {
switch t := b.(type) {
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, _ := utils.ToInt64(b)
return lhs == rhs
case float64, float32:
rhs, _ := utils.ToFloat(b)
return float64(lhs) == rhs
case bool:
return lhs != 0 == t
}
return false
}

func (self LtDispatcher) Lt(scope types.Scope, a types.Any, b types.Any) bool {
a = maybeReduce(a)
b = maybeReduce(b)
Expand All @@ -36,6 +69,13 @@ func (self LtDispatcher) Lt(scope types.Scope, a types.Any, b types.Any) bool {
return t < rhs
}

// If it is integer like, coerce to int.
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
lhs, ok := utils.ToInt64(t)
if ok {
return intLt(lhs, b)
}

case float64:
rhs, ok := utils.ToFloat(b)
if ok {
Expand All @@ -55,11 +95,39 @@ func (self LtDispatcher) Lt(scope types.Scope, a types.Any, b types.Any) bool {
}
}

lhs, ok := utils.ToInt64(a)
if ok {
rhs, ok := utils.ToInt64(b)
switch t := b.(type) {
case types.Null, *types.Null, nil:
return false

// If it is integer like, coerce to int.
case int, int8, int16, int32, int64, uint8, uint16, uint32, uint64:
rhs, ok := utils.ToInt64(t)
if ok {
if intGt(rhs, a) {
return false
}
if intEq(rhs, a) {
return false
}
return true
}

case float64:
lhs, ok := utils.ToFloat(a)
if ok {
return lhs < t
}

case time.Time:
lhs, ok := toTime(a)
if ok {
return t.After(*lhs)
}

case *time.Time:
lhs, ok := toTime(a)
if ok {
return lhs < rhs
return t.After(*lhs)
}
}

Expand Down
18 changes: 18 additions & 0 deletions vfilter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ var execTestsSerialization = []execTest{
// Comparing int to float
{"1 = 1.0", true},
{"1.0 = 1", true},
{"1.1 = 1", false},
{"1 = 1.1", false},
{"1 = 'foo'", false},

// Floats do not compare with integers properly.
Expand All @@ -136,6 +138,21 @@ var execTestsSerialization = []execTest{
{"2 < 1", false},
{"2 < 1.5", false},

// Floats
{"2.1 < three_int64", true},
{"2.1 < 2.5", true},
{"3.5 < three_int64", false},

{"three_int64 < 3.6", true},
{"three_int64 < 2.1", false},

{"2.1 > three_int64", false},
{"2.1 > 2.5", false},
{"3.5 > three_int64", true},

{"three_int64 > 3.6", false},
{"three_int64 > 2.1", true},

// Non matching types
{"2 > 'hello'", false},
{"2 < 'hello'", false},
Expand Down Expand Up @@ -390,6 +407,7 @@ func (self SetEnvFunction) Info(scope types.Scope, type_map *TypeMap) *FunctionI
func makeScope() types.Scope {
env := ordereddict.NewDict().
Set("const_foo", 1).
Set("three_int64", int64(3)).
Set("my_list_obj", ordereddict.NewDict().
Set("my_list", []interface{}{
1, 2, 3,
Expand Down

0 comments on commit e681962

Please sign in to comment.