Skip to content

Commit

Permalink
Closes #2099: Bug in left and right shift by >=64 bits for int/uint (#…
Browse files Browse the repository at this point in the history
…2107)

* first set of changes

fix typecasting for unit case

initialize array variables

zipping wrong size

adding testing def

adding testing

fix testing typos

* fix testing error

all test cases for singletons

* random space fix

* fixing array tests

* CHanging lowest value to 62 to ensure a non 0 values for left and right shifting cases

adding right shift implementation and testing function, also corrected the size of the maxbit arrays

* fixing arrow direction

* array fix and test condensing

* black fix

---------

Co-authored-by: jaketrookman <jaketrookman@users.noreply.github.com>
  • Loading branch information
jaketrookman and jaketrookman authored Feb 7, 2023
1 parent 96fc3c9 commit cefaa93
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 12 deletions.
52 changes: 40 additions & 12 deletions src/BinOp.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,16 @@ module BinOp
[(ei,li,ri) in zip(ea,la,ra)] ei = if ri != 0 then li%ri else 0;
}
when "<<" {
e.a = l.a << r.a;
ref ea = e.a;
ref la = l.a;
ref ra = r.a;
[(ei,li,ri) in zip(ea,la,ra)] if ri < 64 then ei = li << ri;
}
when ">>" {
e.a = l.a >> r.a;
ref ea = e.a;
ref la = l.a;
ref ra = r.a;
[(ei,li,ri) in zip(ea,la,ra)] if ri < 64 then ei = li >> ri;
}
when "<<<" {
e.a = rotl(l.a, r.a);
Expand Down Expand Up @@ -242,10 +248,16 @@ module BinOp
(e.etype == uint && r.etype == int) {
select op {
when ">>" {
e.a = l.a >> r.a;
ref ea = e.a;
ref la = l.a;
ref ra = r.a;
[(ei,li,ri) in zip(ea,la,ra)] if ri < 64 then ei = li >> ri;
}
when "<<" {
e.a = l.a << r.a;
ref ea = e.a;
ref la = l.a;
ref ra = r.a;
[(ei,li,ri) in zip(ea,la,ra)] if ri < 64 then ei = li << ri;
}
when ">>>" {
e.a = rotr(l.a, r.a);
Expand Down Expand Up @@ -540,10 +552,14 @@ module BinOp
[(ei,li) in zip(ea,la)] ei = if val != 0 then li%val else 0;
}
when "<<" {
e.a = l.a << val;
if val < 64 {
e.a = l.a << val;
}
}
when ">>" {
e.a = l.a >> val;
if val < 64 {
e.a = l.a >> val;
}
}
when "<<<" {
e.a = rotl(l.a, val);
Expand Down Expand Up @@ -591,10 +607,14 @@ module BinOp
(e.etype == uint && val.type == int) {
select op {
when ">>" {
e.a = l.a >> val:l.etype;
if val < 64 {
e.a = l.a >> val:l.etype;
}
}
when "<<" {
e.a = l.a << val:l.etype;
if val < 64 {
e.a = l.a << val:l.etype;
}
}
when ">>>" {
e.a = rotr(l.a, val:l.etype);
Expand Down Expand Up @@ -857,10 +877,14 @@ module BinOp
[(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val%ri else 0;
}
when "<<" {
e.a = val << r.a;
ref ea = e.a;
ref ra = r.a;
[(ei,ri) in zip(ea,ra)] if ri < 64 then ei = val << ri;
}
when ">>" {
e.a = val >> r.a;
ref ea = e.a;
ref ra = r.a;
[(ei,ri) in zip(ea,ra)] if ri < 64 then ei = val >> ri;
}
when "<<<" {
e.a = rotl(val, r.a);
Expand Down Expand Up @@ -911,10 +935,14 @@ module BinOp
} else if (val.type == int && r.etype == uint) {
select op {
when ">>" {
e.a = val:uint >> r.a:uint;
ref ea = e.a;
ref ra = r.a;
[(ei,ri) in zip(ea,ra)] if ri:uint < 64 then ei = val:uint >> ri:uint;
}
when "<<" {
e.a = val:uint << r.a:uint;
ref ea = e.a;
ref ra = r.a;
[(ei,ri) in zip(ea,ra)] if ri:uint < 64 then ei = val:uint << ri:uint;
}
when ">>>" {
e.a = rotr(val:uint, r.a:uint);
Expand Down
50 changes: 50 additions & 0 deletions tests/operator_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,56 @@ def test_float_uint_binops(self):
self.assertTrue(np.allclose((ak_float**aku).to_ndarray(), np_float**npu, equal_nan=True))
self.assertTrue(np.allclose((aku**ak_float).to_ndarray(), npu**np_float, equal_nan=True))

def test_shift_binop(self):
# This tests for a bug when left shifting by a value >=64 bits for int/uint, Issue #2099
# Max bit value
maxbits = 2**63 - 1

# Value arrays
ak_uint = ak.array([maxbits, maxbits, maxbits, maxbits], dtype=ak.uint64)
np_uint = np.array([maxbits, maxbits, maxbits, maxbits], dtype=np.uint64)
ak_int = ak.array([maxbits, maxbits, maxbits, maxbits], dtype=ak.int64)
np_int = np.array([maxbits, maxbits, maxbits, maxbits], dtype=np.int64)

# Shifting value arrays
ak_uint_array = ak.array([62, 63, 64, 65], dtype=ak.uint64)
np_uint_array = np.array([62, 63, 64, 65], dtype=np.uint64)
ak_int_array = ak.array([62, 63, 64, 65], dtype=ak.int64)
np_int_array = np.array([62, 63, 64, 65], dtype=np.int64)

# Binopvs case
for i in range(62, 66):
# Left shift
self.assertTrue(np.allclose((ak_uint << i).to_ndarray(), np_uint << i))
self.assertTrue(np.allclose((ak_int << i).to_ndarray(), np_int << i))
# Right shift
self.assertTrue(np.allclose((ak_uint >> i).to_ndarray(), np_uint >> i))
self.assertTrue(np.allclose((ak_int >> i).to_ndarray(), np_int >> i))

# Binopsv case
# Left Shift
self.assertListEqual((maxbits << ak_uint_array).to_list(), (maxbits << np_uint_array).tolist())
self.assertListEqual((maxbits << ak_int_array).to_list(), (maxbits << np_int_array).tolist())
# Right Shift
self.assertListEqual((maxbits >> ak_uint_array).to_list(), (maxbits >> np_uint_array).tolist())
self.assertListEqual((maxbits >> ak_int_array).to_list(), (maxbits >> np_int_array).tolist())

# Binopvv case, Same type
# Left Shift
self.assertListEqual((ak_uint << ak_uint_array).to_list(), (np_uint << np_uint_array).tolist())
self.assertListEqual((ak_int << ak_int_array).to_list(), (np_int << np_int_array).tolist())
# Right Shift
self.assertListEqual((ak_uint >> ak_uint_array).to_list(), (np_uint >> np_uint_array).tolist())
self.assertListEqual((ak_int >> ak_int_array).to_list(), (np_int >> np_int_array).tolist())

# Binopvv case, Mixed type
# Left Shift
self.assertListEqual((ak_uint << ak_int_array).to_list(), (np_uint << np_uint_array).tolist())
self.assertListEqual((ak_int << ak_uint_array).to_list(), (np_int << np_int_array).tolist())
# Right shift
self.assertListEqual((ak_uint >> ak_int_array).to_list(), (np_uint >> np_uint_array).tolist())
self.assertListEqual((ak_int >> ak_uint_array).to_list(), (np_int >> np_int_array).tolist())

def test_concatenate_type_preservation(self):
# Test that concatenate preserves special pdarray types (IPv4, Datetime, BitVector, ...)
from arkouda.util import generic_concat as akuconcat
Expand Down

0 comments on commit cefaa93

Please sign in to comment.