From e4bf8fb2e9e25dc3380fd9f7c2c2f75c6512da5f Mon Sep 17 00:00:00 2001 From: drculhane Date: Wed, 18 Sep 2024 12:40:30 -0400 Subject: [PATCH] Rebased --- src/EfuncMsg.chpl | 346 +++++++++++++++++++++++++++--------- tests/numpy/numeric_test.py | 142 ++++++--------- tests/where_test.py | 17 +- 3 files changed, 325 insertions(+), 180 deletions(-) diff --git a/src/EfuncMsg.chpl b/src/EfuncMsg.chpl index 8d9f3c8cf3..9e2f6bdcd4 100644 --- a/src/EfuncMsg.chpl +++ b/src/EfuncMsg.chpl @@ -17,6 +17,8 @@ module EfuncMsg use UniqueMsg; use AryUtil; + use CommAggregation; + private config const logLevel = ServerConfig.logLevel; private config const logChannel = ServerConfig.logChannel; const eLogger = new Logger(logLevel, logChannel); @@ -848,114 +850,218 @@ module EfuncMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ + + // Presently, only two functions are implemented in efunc3vvMsg: where and putmask. + // This makes it an excellent candidate for rewriting to move both where and putmask + // into functions using the new interface. + @arkouda.registerND proc efunc3vvMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string; // response message - // split request into fields + var repMsg: string; var rname = st.nextName(); var efunc = msgArgs.getValueOf("func"); var g1: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("condition"), st); var g2: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("a"), st); var g3: borrowed GenSymEntry = getGenericTypedArrayEntry(msgArgs.getValueOf("b"), st); - if !((g1.shape == g2.shape) && (g2.shape == g3.shape)) { + if g1.shape != g2.shape { // both where and putmask require condition's shape to match 1st data input's var errorMsg = "shape mismatch in arguments to "+pn; eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); return new MsgTuple(errorMsg, MsgType.ERROR); } - select (g1.dtype, g2.dtype, g3.dtype) { - when (DType.Bool, DType.Int64, DType.Int64) { - var e1 = toSymEntry(g1, bool, nd); - var e2 = toSymEntry(g2, int, nd); - var e3 = toSymEntry(g3, int, nd); - select efunc { - when "where" { - var a = where_helper(e1.a, e2.a, e3.a, 0); - st.addEntry(rname, new shared SymEntry(a)); + if g1.dtype != DType.Bool { + var errorMsg = "condition must be of type Bool in "+pn ; + eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); + } + var e1 = toSymEntry(g1, bool, nd); + select efunc { + when "where" { + if g2.shape != g3.shape { // where requires all inputs to be of same shape + var errorMsg = "shape mismatch in arguments to "+pn; + eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); + } + select (g2.dtype,g3.dtype) { // where allows any combination of numerical types as inputs + when (DType.Float64, DType.Float64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, real, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Float64, DType.Int64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, int, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Float64, DType.UInt64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, uint, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Float64, DType.Bool) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, bool, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Int64, DType.Float64) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, real, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Int64, DType.Int64) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, int, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Int64, DType.UInt64) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, uint, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Int64, DType.Bool) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, bool, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.UInt64, DType.Float64) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, real, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.UInt64, DType.Int64) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, int, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.UInt64, DType.UInt64) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, uint, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.UInt64, DType.Bool) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, bool, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Bool, DType.Float64) { + var e2 = toSymEntry(g2, bool, nd); + var e3 = toSymEntry(g3, real, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Bool, DType.Int64) { + var e2 = toSymEntry(g2, bool, nd); + var e3 = toSymEntry(g3, int, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Bool, DType.UInt64) { + var e2 = toSymEntry(g2, bool, nd); + var e3 = toSymEntry(g3, uint, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + when (DType.Bool, DType.Bool) { + var e2 = toSymEntry(g2, bool, nd); + var e3 = toSymEntry(g3, bool, nd); + var a = where_helper_3vv(e1.a, e2.a, e3.a); + st.addEntry(rname, new shared SymEntry(a)); + } + otherwise { + var errorMsg = "arg types incompatible with where in "+pn ; + eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); + } + } // end of "select" + repMsg = "created " + st.attrib(rname); + eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } // end of when where + + when "putmask" { // putmask only requires original values and mask to be ofsame shape, + select (g2.dtype,g3.dtype) { // and allows all of the data type combinations herein. + when (DType.Float64, DType.Float64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, real, nd); + putmask_helper(e1.a,e2.a,e3.a); } - otherwise { - var errorMsg = notImplementedError(pn,efunc,g1.dtype, - g2.dtype,g3.dtype); - eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } - } - when (DType.Bool, DType.UInt64, DType.UInt64) { - var e1 = toSymEntry(g1, bool, nd); - var e2 = toSymEntry(g2, uint, nd); - var e3 = toSymEntry(g3, uint, nd); - select efunc { - when "where" { - var a = where_helper(e1.a, e2.a, e3.a, 0); - st.addEntry(rname, new shared SymEntry(a)); + when (DType.Float64, DType.Int64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, int, nd); + putmask_helper(e1.a,e2.a,e3.a); } - otherwise { - var errorMsg = notImplementedError(pn,efunc,g1.dtype, - g2.dtype,g3.dtype); - eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } - } - when (DType.Bool, DType.Float64, DType.Float64) { - var e1 = toSymEntry(g1, bool, nd); - var e2 = toSymEntry(g2, real, nd); - var e3 = toSymEntry(g3, real, nd); - select efunc { - when "where" { - var a = where_helper(e1.a, e2.a, e3.a, 0); - st.addEntry(rname, new shared SymEntry(a)); + when (DType.Float64, DType.UInt64) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, uint, nd); + putmask_helper(e1.a,e2.a,e3.a); } - otherwise { - var errorMsg = notImplementedError(pn,efunc,g1.dtype, - g2.dtype,g3.dtype); - eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + when (DType.Float64, DType.Bool) { + var e2 = toSymEntry(g2, real, nd); + var e3 = toSymEntry(g3, bool, nd); + putmask_helper(e1.a,e2.a,e3.a); } - } - } - when (DType.Bool, DType.Bool, DType.Bool) { - var e1 = toSymEntry(g1, bool, nd); - var e2 = toSymEntry(g2, bool, nd); - var e3 = toSymEntry(g3, bool, nd); - select efunc { - when "where" { - var a = where_helper(e1.a, e2.a, e3.a, 0); - st.addEntry(rname, new shared SymEntry(a)); + when (DType.Int64, DType.Int64) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, int, nd); + putmask_helper(e1.a,e2.a,e3.a); + } + when (DType.Int64, DType.UInt64) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, uint, nd); + putmask_helper(e1.a,e2.a,e3.a); + } + when (DType.Int64, DType.Bool) { + var e2 = toSymEntry(g2, int, nd); + var e3 = toSymEntry(g3, bool, nd); + putmask_helper(e1.a,e2.a,e3.a); + } + when (DType.UInt64, DType.UInt64) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, uint, nd); + putmask_helper(e1.a,e2.a,e3.a); + } + when (DType.UInt64, DType.Bool) { + var e2 = toSymEntry(g2, uint, nd); + var e3 = toSymEntry(g3, bool, nd); + putmask_helper(e1.a,e2.a,e3.a); + } + when (DType.Bool, DType.Bool) { + var e2 = toSymEntry(g2, bool, nd); + var e3 = toSymEntry(g3, bool, nd); + putmask_helper(e1.a,e2.a,e3.a); } otherwise { - var errorMsg = notImplementedError(pn,efunc,g1.dtype, - g2.dtype,g3.dtype); - eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + var errorMsg = "arg types incompatible with putmask in "+pn ; + eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); return new MsgTuple(errorMsg, MsgType.ERROR); - } - } - } - otherwise { - var errorMsg = notImplementedError(pn,efunc,g1.dtype,g2.dtype,g3.dtype); - eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } // end of select + return MsgTuple.success(); // putmask does not return a new pdarray; it writes in place + } // end of when putmask + otherwise { // neither where nor putmask + var errorMsg = notImplementedError(pn,efunc,g1.dtype,g2.dtype,g3.dtype); + eLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); } - } + } // end of select efunc repMsg = "created " + st.attrib(rname); eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); return new MsgTuple(repMsg, MsgType.NORMAL); - } + } // end of efunc3vvMsg - /* - vector = efunc(vector, vector, scalar) - - :arg reqMsg: request containing (cmd,efunc,name1,name2,dtype,value) - :type reqMsg: string - - :arg st: SymTab to act on - :type st: borrowed SymTab - - :returns: (MsgTuple) - :throws: `UndefinedSymbolError(name)` - */ @arkouda.registerND proc efunc3vsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { param pn = Reflection.getRoutineName(); @@ -1283,6 +1389,26 @@ module EfuncMsg return new MsgTuple(repMsg, MsgType.NORMAL); } + // The 3vv version of where_helper has been pulled out as a separate function, + // to ease the transition to matching the data types as allowed by numpy. + proc where_helper_3vv(cond:[?D] bool, A:[D] ?ta, B:[D] ?tb) : [D] throws { + type resultType = compute_result_type(ta,tb) ; + var C = makeDistArray(D, resultType); + forall (ch, a, b, c) in zip(cond, A, B, C) { + c = if ch then a:resultType else b:resultType; + } + return C; + } + + // This proc gives the return type for ak.where, given the input types. + // cribbed from LinalgMsg.chpl, with uint(8) changed to uint(64) + proc compute_result_type (type t1, type t2) type { + if t1 == real || t2 == real then return real; + if t1 == int || t2 == int then return int; + if t1 == uint(64) || t2 == uint(64) then return uint(64); + return bool; + } + /* The 'where' function takes a boolean array and two other arguments A and B, and returns an array with A where the boolean is true and B where it is false. A and B can be vectors or scalars. @@ -1302,6 +1428,7 @@ module EfuncMsg :arg kind: :type kind: param */ + proc where_helper(cond:[?D] bool, A:[D] ?t, B:[D] t, param kind):[D] t throws where (kind == 0) { var C = makeDistArray(D, t); forall (ch, a, b, c) in zip(cond, A, B, C) { @@ -1375,4 +1502,57 @@ module EfuncMsg } return C; } +//} +//} + proc putmask_helper (mask : [?D1] bool, ref A : [D1] ?ta, Values : [?D2] ?tv) { + +// Note, added casting to the output, to accomodate the various combinations of +// types that are allowed by np.putmask. Invalid pairings of types ta, tv must be +// caught and flagged by the calling function. This proc does no type checking. + + if A.size == Values.size { // then their distributions match + forall element in A.domain { + if mask[element] then A[element] = Values[element]:ta ; + } + + } else if A.size < Values.size { // then prune Values so the distributions match + var Pruned_Values = Values[0..A.size-1] ; + forall element in A.domain { + if mask[element] then A[element] = Pruned_Values[element]:ta ; + } + +// For the remaining cases, we know Values.size < A.size + + } else { + var A_subset_size = (A.size/numLocales):int ; + if A.size%numLocales == 0 then A_subset_size += 1 ; + +// If Values.size < A_subset_size, then we may as well make a copy of Values on each locale + + if Values.size < A_subset_size { + coforall loc in Locales do on loc { + var local_Values : [0..Values.size-1] tv = Values ; + // var local_Values = Values ; // per Tess's suggestion + forall element in A.localSubdomain() { + if mask[element] then A[element] = (local_Values[element%local_Values.size]):ta ; + } + } + +// Finally, we have the case where Values.size > A_subset_size, which means Values is distributed +// but probably not to our liking. Use an aggregator to get the right values in the right places. + + } else { + coforall loc in Locales do on loc { + var local_Values : [0..A.localSubdomain().size-1] tv; + forall (element, item) in zip (A.localSubdomain(),local_Values) with (var agg = newSrcAggregator(tv)) { + agg.copy(item,Values[element%Values.size]) ; + } + forall (element, item) in zip (A.localSubdomain(),local_Values) { + if mask[element] then A[element] = item:ta ; + } + } + } + } + return ; + } } diff --git a/tests/numpy/numeric_test.py b/tests/numpy/numeric_test.py index b2735aea76..7829ba8b02 100644 --- a/tests/numpy/numeric_test.py +++ b/tests/numpy/numeric_test.py @@ -18,6 +18,17 @@ YES_NO = [True, False] VOWELS_AND_SUCH = ["a", "e", "i", "o", "u", "AB", 47, 2, 3.14159] +ALLOWED_PUTMASK_PAIRS = [ + (ak.float64, ak.float64), + (ak.float64, ak.int64), + (ak.float64, ak.uint64), + (ak.float64, ak.bool_), + (ak.int64, ak.int64), + (ak.int64, ak.bool_), + (ak.uint64, ak.uint64), + (ak.uint64, ak.bool_), + (ak.bool_, ak.bool_), +] # There are many ways to create a vector of alternating values. # This is a fairly fast and fairly straightforward approach. @@ -533,8 +544,6 @@ def test_value_counts(self, num_type): assert ak.array([100]) == result[1] def test_value_counts_error(self): - pda = ak.linspace(1, 10, 10) - with pytest.raises(TypeError): ak.value_counts([0]) @@ -801,53 +810,63 @@ def test_clip(self, prob_size): @pytest.mark.parametrize("prob_size", pytest.prob_size) def test_putmask(self, prob_size): - for data_type in INT_FLOAT: + for d1, d2 in ALLOWED_PUTMASK_PAIRS: # three things to test: values same size as data - nda = np.random.randint(0, 10, prob_size).astype(data_type) - result = nda.copy() - np.putmask(result, result > 5, result**2) + nda = np.random.randint(0, 10, prob_size).astype(d1) pda = ak.array(nda) - ak.putmask(pda, pda > 5, pda**2) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) + nda2 = (nda**2).astype(d2) + pda2 = ak.array(nda2) + hold_that_thought = nda.copy() + np.putmask(nda, nda > 5, nda2) + ak.putmask(pda, pda > 5, pda2) + assert np.allclose(nda, pda.to_ndarray()) - # values shorter than data + # values potentially much shorter than data - result = nda.copy() + nda = hold_that_thought.copy() pda = ak.array(nda) - values = np.arange(3).astype(data_type) - np.putmask(result, result > 5, values) - ak.putmask(pda, pda > 5, ak.array(values)) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) + npvalues = np.arange(3).astype(d2) + akvalues = ak.array(npvalues) + np.putmask(nda, nda > 5, npvalues) + ak.putmask(pda, pda > 5, akvalues) + assert np.allclose(nda, pda.to_ndarray()) + + # values shorter than data, but likely not to fit on one locale in a multi-locale test + + nda = hold_that_thought.copy() + pda = ak.array(nda) + npvalues = np.arange(prob_size // 2 + 1).astype(d2) + akvalues = ak.array(npvalues) + np.putmask(nda, nda > 5, npvalues) + ak.putmask(pda, pda > 5, akvalues) + assert np.allclose(nda, pda.to_ndarray()) # values longer than data - result = nda.copy() + nda = hold_that_thought.copy() pda = ak.array(nda) - values = np.arange(prob_size + 1).astype(data_type) - np.putmask(result, result > 5, values) - ak.putmask(pda, pda > 5, ak.array(values)) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) + npvalues = np.arange(prob_size + 1000).astype(d2) + akvalues = ak.array(npvalues) + np.putmask(nda, nda > 5, npvalues) + ak.putmask(pda, pda > 5, akvalues) + assert np.allclose(nda, pda.to_ndarray()) - # finally try to raise the error + # finally try to raise errors - pda = ak.random.randint(0, 10, 10).astype(ak.float64) - values = np.arange(10) - with pytest.raises(TypeError): - ak.putmask(pda, pda > 3, values) + pda = ak.random.randint(0, 10, 10).astype(ak.float64) + mask = ak.array([True]) # wrong size error + values = ak.arange(10).astype(ak.float64) + with pytest.raises(RuntimeError): + ak.putmask(pda, mask, values) + + for d2, d1 in ALLOWED_PUTMASK_PAIRS: + if d1 != d2: # wrong types error + pda = ak.arange(0, 10, prob_size).astype(d1) + pda2 = (10 - pda).astype(d2) + with pytest.raises(RuntimeError): + ak.putmask(pda, pda > 5, pda2) # In the tests below, the rationale for using size = math.sqrt(prob_size) is that # the resulting matrices are on the order of size*size. @@ -1159,54 +1178,3 @@ def test_clip(self, prob_size): assert np.allclose( np.clip(nd_arry, lo, None), ak.clip(ak_arry, aklo, None).to_ndarray() ) - - @pytest.mark.parametrize("prob_size", pytest.prob_size) - def test_putmask(self, prob_size): - - for data_type in INT_FLOAT: - - # three things to test: values same size as data - - nda = np.random.randint(0, 10, prob_size).astype(data_type) - result = nda.copy() - np.putmask(result, result > 5, result**2) - pda = ak.array(nda) - ak.putmask(pda, pda > 5, pda**2) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) - - # values shorter than data - - result = nda.copy() - pda = ak.array(nda) - values = np.arange(3).astype(data_type) - np.putmask(result, result > 5, values) - ak.putmask(pda, pda > 5, ak.array(values)) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) - - # values longer than data - - result = nda.copy() - pda = ak.array(nda) - values = np.arange(prob_size + 1).astype(data_type) - np.putmask(result, result > 5, values) - ak.putmask(pda, pda > 5, ak.array(values)) - assert ( - np.all(result == pda.to_ndarray()) - if data_type == ak.int64 - else np.allclose(result, pda.to_ndarray()) - ) - - # finally try to raise the error - - pda = ak.random.randint(0, 10, 10).astype(ak.float64) - values = np.arange(10) - with pytest.raises(TypeError): - ak.putmask(pda, pda > 3, values) diff --git a/tests/where_test.py b/tests/where_test.py index 89dbaec979..ba856e0c77 100644 --- a/tests/where_test.py +++ b/tests/where_test.py @@ -9,16 +9,19 @@ warnings.simplefilter("always", UserWarning) +# TODO: Parametrize test_where class TestWhere: @pytest.mark.parametrize("size", pytest.prob_size) def test_where(self, size): npA = { + "uint64": np.random.randint(0, 10, size), "int64": np.random.randint(0, 10, size), "float64": np.random.randn(size), "bool": np.random.randint(0, 2, size, dtype="bool"), } akA = {k: ak.array(v) for k, v in npA.items()} npB = { + "uint64": np.random.randint(0, 10, size), "int64": np.random.randint(10, 20, size), "float64": np.random.randn(size) + 10, "bool": np.random.randint(0, 2, size, dtype="bool"), @@ -26,18 +29,12 @@ def test_where(self, size): akB = {k: ak.array(v) for k, v in npB.items()} npCond = np.random.randint(0, 2, size, dtype="bool") akCond = ak.array(npCond) - scA = {"int64": 42, "float64": 2.71828, "bool": True} - scB = {"int64": -1, "float64": 3.14159, "bool": False} dtypes = set(npA.keys()) - for dtype in dtypes: - for (ak1, ak2), (np1, np2) in zip( - product((akA, scA), (akB, scB)), - product((npA, scA), (npB, scB)), - ): - akres = ak.where(akCond, ak1[dtype], ak2[dtype]).to_ndarray() - npres = np.where(npCond, np1[dtype], np2[dtype]) - assert np.allclose(akres, npres, equal_nan=True) + for (dtype1,dtype2) in zip(dtypes,dtypes): + akres = ak.where(akCond, akA[dtype1], akB[dtype2]).to_ndarray() + npres = np.where(npCond, npA[dtype1], npB[dtype2]) + assert np.allclose(akres, npres, equal_nan=True) def test_error_handling(self): with pytest.raises(TypeError):