Skip to content

Commit

Permalink
Closes #2830 Implement division and floor division for int64 and uint…
Browse files Browse the repository at this point in the history
…64 dtypes. (#2847)

* binops and operatorMsg changes to allow division and floor division for int and uint combinations. Also adds testing

* fixing run errors

* fixing uint-uint addition bug

---------

Co-authored-by: jaketrookman <jaketrookman@users.noreply.github.com>
  • Loading branch information
jaketrookman and jaketrookman authored Nov 15, 2023
1 parent f359644 commit aa3afc4
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 26 deletions.
33 changes: 32 additions & 1 deletion PROTO_tests/tests/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_concatenation(self, dtype):

def test_max_bits_concatenation(self):
# reproducer for issue #2802
concatenated = ak.concatenate([ak.arange(5, max_bits=3), ak.arange(2**200 - 1, 2**200 + 4)])
concatenated = ak.concatenate([ak.arange(5, max_bits=3), ak.arange(2 ** 200 - 1, 2 ** 200 + 4)])
assert concatenated.max_bits == 3
assert [0, 1, 2, 3, 4, 7, 0, 1, 2, 3] == concatenated.to_list()

Expand All @@ -240,6 +240,37 @@ def test_uint_bool_binops(self):
ak_bool = ak_uint % 2 == 0
assert (ak_uint + ak_bool).to_list() == (ak.arange(10) + ak_bool).to_list()

def test_int_uint_binops(self):
np_int = np.arange(-5, 5)
ak_int = ak.array(np_int)

np_uint = np.arange(2**64 - 10, 2**64, dtype=np.uint64)
ak_uint = ak.array(np_uint)

# Vector-Vector Case (Division and Floor Division)
assert np.allclose((ak_uint / ak_uint).to_ndarray(), np_uint / np_uint, equal_nan=True)
assert np.allclose((ak_int / ak_uint).to_ndarray(), np_int / np_uint, equal_nan=True)
assert np.allclose((ak_uint / ak_int).to_ndarray(), np_uint / np_int, equal_nan=True)
assert np.allclose((ak_uint // ak_uint).to_ndarray(), np_uint // np_uint, equal_nan=True)
assert np.allclose((ak_int // ak_uint).to_ndarray(), np_int // np_uint, equal_nan=True)
assert np.allclose((ak_uint // ak_int).to_ndarray(), np_uint // np_int, equal_nan=True)

# Scalar-Vector Case (Division and Floor Division)
assert np.allclose((ak_uint[0] / ak_uint).to_ndarray(), np_uint[0] / np_uint, equal_nan=True)
assert np.allclose((ak_int[0] / ak_uint).to_ndarray(), np_int[0] / np_uint, equal_nan=True)
assert np.allclose((ak_uint[0] / ak_int).to_ndarray(), np_uint[0] / np_int, equal_nan=True)
assert np.allclose((ak_uint[0] // ak_uint).to_ndarray(), np_uint[0] // np_uint, equal_nan=True)
assert np.allclose((ak_int[0] // ak_uint).to_ndarray(), np_int[0] // np_uint, equal_nan=True)
assert np.allclose((ak_uint[0] // ak_int).to_ndarray(), np_uint[0] // np_int, equal_nan=True)

# Vector-Scalar Case (Division and Floor Division)
assert np.allclose((ak_uint / ak_uint[0]).to_ndarray(), np_uint / np_uint[0], equal_nan=True)
assert np.allclose((ak_int / ak_uint[0]).to_ndarray(), np_int / np_uint[0], equal_nan=True)
assert np.allclose((ak_uint / ak_int[0]).to_ndarray(), np_uint / np_int[0], equal_nan=True)
assert np.allclose((ak_uint // ak_uint[0]).to_ndarray(), np_uint // np_uint[0], equal_nan=True)
assert np.allclose((ak_int // ak_uint[0]).to_ndarray(), np_int // np_uint[0], equal_nan=True)
assert np.allclose((ak_uint // ak_int[0]).to_ndarray(), np_uint // np_int[0], equal_nan=True)

def test_float_uint_binops(self):
# Test fix for issue #1620
np_uint = make_np_arrays(10, "uint64")
Expand Down
25 changes: 25 additions & 0 deletions src/BinOp.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,15 @@ module BinOp
when "-" {
e.a = l.a:real - r.a:real;
}
when "/" { // truediv
e.a = l.a:real / r.a:real;
}
when "//" { // floordiv
ref ea = e.a;
var la = l.a:real;
var ra = r.a:real;
[(ei,li,ri) in zip(ea,la,ra)] ei = floorDivisionHelper(li, ri);
}
otherwise {
var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -717,6 +726,14 @@ module BinOp
when "-" {
e.a = l.a: real - val: real;
}
when "/" { // truediv
e.a = l.a: real / val: real;
}
when "//" { // floordiv
ref ea = e.a;
var la = l.a;
[(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li, val:real);
}
otherwise {
var errorMsg = notImplementedError(pn,l.dtype,op,dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -1075,6 +1092,14 @@ module BinOp
when "-" {
e.a = val:real - r.a:real;
}
when "/" { // truediv
e.a = val:real / r.a:real;
}
when "//" { // floordiv
ref ea = e.a;
var ra = r.a;
[(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(val:real, ri);
}
otherwise {
var errorMsg = notImplementedError(pn,dtype,op,r.dtype);
omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down
82 changes: 58 additions & 24 deletions src/OperatorMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ module OperatorMsg
boolOps.add("==");
boolOps.add("!=");

var realOps: set(string);
realOps.add("+");
realOps.add("-");
realOps.add("/");
realOps.add("//");

select (left.dtype, right.dtype) {
when (DType.Int64, DType.Int64) {
var l = toSymEntry(left,int);
Expand Down Expand Up @@ -211,22 +217,27 @@ module OperatorMsg
var e = st.addEntry(rname, l.size, bool);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
var e = st.addEntry(rname, l.size, uint);
return doBinOpvv(l, r, e, op, rname, pn, st);
if op == "/"{
var e = st.addEntry(rname, l.size, real);
return doBinOpvv(l, r, e, op, rname, pn, st);
} else {
var e = st.addEntry(rname, l.size, uint);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
}
when (DType.UInt64, DType.Int64) {
var l = toSymEntry(left,uint);
var r = toSymEntry(right,int);
if boolOps.contains(op) {
var e = st.addEntry(rname, l.size, bool);
return doBinOpvv(l, r, e, op, rname, pn, st);
return doBinOpvv(l, r , e, op, rname, pn, st);
}
// + and - both result in real outputs to match NumPy
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, l.size, real);
return doBinOpvv(l, r, e, op, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, l.size, uint);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
Expand All @@ -238,11 +249,12 @@ module OperatorMsg
var e = st.addEntry(rname, l.size, bool);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, l.size, real);
return doBinOpvv(l, r, e, op, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, l.size, int);
return doBinOpvv(l, r, e, op, rname, pn, st);
}
Expand Down Expand Up @@ -399,6 +411,12 @@ module OperatorMsg
boolOps.add("==");
boolOps.add("!=");

var realOps: set(string);
realOps.add("+");
realOps.add("-");
realOps.add("/");
realOps.add("//");

select (left.dtype, dtype) {
when (DType.Int64, DType.Int64) {
var l = toSymEntry(left,int);
Expand Down Expand Up @@ -545,8 +563,13 @@ module OperatorMsg
var e = st.addEntry(rname, l.size, bool);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
var e = st.addEntry(rname, l.size, uint);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
if op == "/"{
var e = st.addEntry(rname, l.size, real);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
} else {
var e = st.addEntry(rname, l.size, uint);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
}
when (DType.UInt64, DType.Int64) {
var l = toSymEntry(left,uint);
Expand All @@ -555,12 +578,12 @@ module OperatorMsg
var e = st.addEntry(rname, l.size, bool);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
// + and - both result in real outputs to match NumPy
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, l.size, real);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, l.size, uint);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
Expand All @@ -572,12 +595,12 @@ module OperatorMsg
var e = st.addEntry(rname, l.size, bool);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
// + and - both result in real outputs to match NumPy
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, l.size, real);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, l.size, int);
return doBinOpvs(l, val, e, op, dtype, rname, pn, st);
}
Expand Down Expand Up @@ -734,6 +757,12 @@ module OperatorMsg
boolOps.add("==");
boolOps.add("!=");

var realOps: set(string);
realOps.add("+");
realOps.add("-");
realOps.add("/");
realOps.add("//");

select (dtype, right.dtype) {
when (DType.Int64, DType.Int64) {
var val = value.getIntValue();
Expand Down Expand Up @@ -880,8 +909,13 @@ module OperatorMsg
var e = st.addEntry(rname, r.size, bool);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
var e = st.addEntry(rname, r.size, uint);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
if op == "/"{
var e = st.addEntry(rname, r.size, real);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
} else {
var e = st.addEntry(rname, r.size, uint);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
}
when (DType.UInt64, DType.Int64) {
var val = value.getUIntValue();
Expand All @@ -890,12 +924,12 @@ module OperatorMsg
var e = st.addEntry(rname, r.size, bool);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
// + and - both result in real outputs to match NumPy
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, r.size, real);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, r.size, uint);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
Expand All @@ -907,12 +941,12 @@ module OperatorMsg
var e = st.addEntry(rname, r.size, bool);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
// + and - both result in real outputs to match NumPy
if op == "+" || op == "-" {
// +, -, /, // both result in real outputs to match NumPy
if realOps.contains(op) {
var e = st.addEntry(rname, r.size, real);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
} else {
// isn't + or -, so we can use LHS to determine type
// isn't +, -, /, // so we can use LHS to determine type
var e = st.addEntry(rname, r.size, int);
return doBinOpsv(val, r, e, op, dtype, rname, pn, st);
}
Expand Down
60 changes: 59 additions & 1 deletion tests/operator_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,12 @@ def testConcatenate(self):
)

pdaOne = ak.arange(5, max_bits=3)
pdaTwo = ak.arange(2**200 - 1, 2**200 + 4)
pdaTwo = ak.arange(2 ** 200 - 1, 2 ** 200 + 4)
concatenated = ak.concatenate([pdaOne, pdaTwo])
self.assertEqual(concatenated.max_bits, 3)
self.assertListEqual([0, 1, 2, 3, 4, 7, 0, 1, 2, 3], concatenated.to_list())


def test_invert(self):
ak_uint = ak.arange(10, dtype=ak.uint64)
inverted = ~ak_uint
Expand All @@ -312,6 +313,63 @@ def test_uint_bool_binops(self):
ak_bool = ak_uint % 2 == 0
self.assertListEqual((ak_uint + ak_bool).to_list(), (ak.arange(10) + ak_bool).to_list())

def test_int_uint_binops(self):
np_int = np.arange(-5, 5)
ak_int = ak.array(np_int)

np_uint = np.arange(2**64 - 10, 2**64, dtype=np.uint64)
ak_uint = ak.array(np_uint)

# Vector-Vector Case (Division and Floor Division)
self.assertTrue(np.allclose((ak_uint / ak_uint).to_ndarray(), np_uint / np_uint, equal_nan=True))
self.assertTrue(np.allclose((ak_int / ak_uint).to_ndarray(), np_int / np_uint, equal_nan=True))
self.assertTrue(np.allclose((ak_uint / ak_int).to_ndarray(), np_uint / np_int, equal_nan=True))
self.assertTrue(
np.allclose((ak_uint // ak_uint).to_ndarray(), np_uint // np_uint, equal_nan=True)
)
self.assertTrue(np.allclose((ak_int // ak_uint).to_ndarray(), np_int // np_uint, equal_nan=True))
self.assertTrue(np.allclose((ak_uint // ak_int).to_ndarray(), np_uint // np_int, equal_nan=True))

# Scalar-Vector Case (Division and Floor Division)
self.assertTrue(
np.allclose((ak_uint[0] / ak_uint).to_ndarray(), np_uint[0] / np_uint, equal_nan=True)
)
self.assertTrue(
np.allclose((ak_int[0] / ak_uint).to_ndarray(), np_int[0] / np_uint, equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint[0] / ak_int).to_ndarray(), np_uint[0] / np_int, equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint[0] // ak_uint).to_ndarray(), np_uint[0] // np_uint, equal_nan=True)
)
self.assertTrue(
np.allclose((ak_int[0] // ak_uint).to_ndarray(), np_int[0] // np_uint, equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint[0] // ak_int).to_ndarray(), np_uint[0] // np_int, equal_nan=True)
)

# Vector-Scalar Case (Division and Floor Division)
self.assertTrue(
np.allclose((ak_uint / ak_uint[0]).to_ndarray(), np_uint / np_uint[0], equal_nan=True)
)
self.assertTrue(
np.allclose((ak_int / ak_int[0]).to_ndarray(), np_int / np_int[0], equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint / ak_uint[0]).to_ndarray(), np_uint / np_uint[0], equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint // ak_uint[0]).to_ndarray(), np_uint // np_uint[0], equal_nan=True)
)
self.assertTrue(
np.allclose((ak_int // ak_uint[0]).to_ndarray(), np_int // np_uint[0], equal_nan=True)
)
self.assertTrue(
np.allclose((ak_uint // ak_int[0]).to_ndarray(), np_uint // np_int[0], equal_nan=True)
)

def test_float_uint_binops(self):
# Test fix for issue #1620
ak_uint = ak.array([5], dtype=ak.uint64)
Expand Down

0 comments on commit aa3afc4

Please sign in to comment.