diff --git a/PROTO_tests/tests/numeric_test.py b/PROTO_tests/tests/numeric_test.py index 6bca490267..05ef9d1406 100644 --- a/PROTO_tests/tests/numeric_test.py +++ b/PROTO_tests/tests/numeric_test.py @@ -88,13 +88,16 @@ def test_seeded_rng_general(self, prob_size): def test_cast(self, prob_size, cast_to): arrays = { ak.int64: ak.randint(-(2**48), 2**48, prob_size), + ak.uint64: ak.randint(0, 2**48, prob_size, dtype=ak.uint64), ak.float64: ak.randint(0, 1, prob_size, dtype=ak.float64), ak.bool: ak.randint(0, 2, prob_size, dtype=ak.bool), + ak.str_: ak.cast(ak.randint(0, 2**48, prob_size), "str"), } for t1, orig in arrays.items(): - if t1 == ak.float64 and cast_to == ak.bigint: + if (t1 == ak.float64 and cast_to == ak.bigint) or (t1 == ak.str_ and cast_to == ak.bool): # we don't support casting a float to a bigint + # we do support str to bool, but it's expected to contain "true/false" not numerics continue other = ak.cast(orig, cast_to) assert orig.size == other.size diff --git a/PROTO_tests/tests/pdarray_creation_test.py b/PROTO_tests/tests/pdarray_creation_test.py index f145e60aa2..d14db2e532 100644 --- a/PROTO_tests/tests/pdarray_creation_test.py +++ b/PROTO_tests/tests/pdarray_creation_test.py @@ -29,20 +29,18 @@ class TestPdarrayCreation: @pytest.mark.parametrize("dtype", DTYPES) def test_array_creation(self, dtype): - # TODO - remove the 'if' below (to make everything that follows unconditional) after #2645 is complete - if dtype != str: - fixed_size = 100 - for pda in [ - ak.array(ak.ones(fixed_size, int), dtype), - ak.array(np.ones(fixed_size), dtype), - ak.array(list(range(fixed_size)), dtype=dtype), - ak.array((range(fixed_size)), dtype), - ak.array(deque(range(fixed_size)), dtype), - ak.array([f"{i}" for i in range(fixed_size)], dtype=dtype), - ]: - assert isinstance(pda, ak.pdarray if dtype != str else ak.Strings) - assert len(pda) == fixed_size - assert dtype == pda.dtype + fixed_size = 100 + for pda in [ + ak.array(ak.ones(fixed_size, int), dtype), + ak.array(np.ones(fixed_size), dtype), + ak.array(list(range(fixed_size)), dtype=dtype), + ak.array((range(fixed_size)), dtype), + ak.array(deque(range(fixed_size)), dtype), + ak.array([f"{i}" for i in range(fixed_size)], dtype=dtype), + ]: + assert isinstance(pda, ak.pdarray if dtype != str else ak.Strings) + assert len(pda) == fixed_size + assert dtype == pda.dtype @pytest.mark.parametrize("size", pytest.prob_size) def test_large_array_creation(self, size): diff --git a/src/CastMsg.chpl b/src/CastMsg.chpl index a06d6a433f..3cf0fec81d 100644 --- a/src/CastMsg.chpl +++ b/src/CastMsg.chpl @@ -76,13 +76,13 @@ module CastMsg { when (DType.UInt64, "uint64") { return new MsgTuple(castGenSymEntry(gse, st, uint, uint), MsgType.NORMAL); } - when (DType.UInt64, "float") { + when (DType.UInt64, "float64") { return new MsgTuple(castGenSymEntry(gse, st, uint, real), MsgType.NORMAL); } when (DType.UInt64, "bool") { return new MsgTuple(castGenSymEntry(gse, st, uint, bool), MsgType.NORMAL); } - when (DType.UInt64, "string") { + when (DType.UInt64, "str") { return new MsgTuple(castGenSymEntryToString(gse, st, uint), MsgType.NORMAL); } when (DType.UInt64, "bigint") { @@ -168,6 +168,13 @@ module CastMsg { when "bigint" { return new MsgTuple(castStringToBigInt(strings, st, errors), MsgType.NORMAL); } + when "str" { + const oname = st.nextName(); + const vname = st.nextName(); + var offsets = st.addEntry(oname, createSymEntry(strings.offsets.a)); + var values = st.addEntry(vname, createSymEntry(strings.values.a)); + return new MsgTuple("created " + st.attrib(oname) + "+created " + st.attrib(vname), MsgType.NORMAL); + } otherwise { var errorMsg = notImplementedError(pn,"str",":",targetDtype); castLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);