Skip to content

Commit

Permalink
Fixes #2644, #2645: Fix uint cast to str/float and add str cast to str (
Browse files Browse the repository at this point in the history
#2745)

This PR (fixes #2644 and fixes #2645) fixes uint casting to str and float and adds casting from str to str

Co-authored-by: Pierce Hayes <pierce314159@users.noreply.github.com>
  • Loading branch information
stress-tess and Pierce Hayes committed Sep 6, 2023
1 parent 6044082 commit 5d7f91f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 17 deletions.
5 changes: 4 additions & 1 deletion PROTO_tests/tests/numeric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ def test_seeded_rng_general(self, prob_size):
def test_cast(self, prob_size, cast_to):
arrays = {
ak.int64: ak.randint(-(2**48), 2**48, prob_size),
ak.uint64: ak.randint(0, 2**48, prob_size, dtype=ak.uint64),
ak.float64: ak.randint(0, 1, prob_size, dtype=ak.float64),
ak.bool: ak.randint(0, 2, prob_size, dtype=ak.bool),
ak.str_: ak.cast(ak.randint(0, 2**48, prob_size), "str"),
}

for t1, orig in arrays.items():
if t1 == ak.float64 and cast_to == ak.bigint:
if (t1 == ak.float64 and cast_to == ak.bigint) or (t1 == ak.str_ and cast_to == ak.bool):
# we don't support casting a float to a bigint
# we do support str to bool, but it's expected to contain "true/false" not numerics
continue
other = ak.cast(orig, cast_to)
assert orig.size == other.size
Expand Down
26 changes: 12 additions & 14 deletions PROTO_tests/tests/pdarray_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,18 @@
class TestPdarrayCreation:
@pytest.mark.parametrize("dtype", DTYPES)
def test_array_creation(self, dtype):
# TODO - remove the 'if' below (to make everything that follows unconditional) after #2645 is complete
if dtype != str:
fixed_size = 100
for pda in [
ak.array(ak.ones(fixed_size, int), dtype),
ak.array(np.ones(fixed_size), dtype),
ak.array(list(range(fixed_size)), dtype=dtype),
ak.array((range(fixed_size)), dtype),
ak.array(deque(range(fixed_size)), dtype),
ak.array([f"{i}" for i in range(fixed_size)], dtype=dtype),
]:
assert isinstance(pda, ak.pdarray if dtype != str else ak.Strings)
assert len(pda) == fixed_size
assert dtype == pda.dtype
fixed_size = 100
for pda in [
ak.array(ak.ones(fixed_size, int), dtype),
ak.array(np.ones(fixed_size), dtype),
ak.array(list(range(fixed_size)), dtype=dtype),
ak.array((range(fixed_size)), dtype),
ak.array(deque(range(fixed_size)), dtype),
ak.array([f"{i}" for i in range(fixed_size)], dtype=dtype),
]:
assert isinstance(pda, ak.pdarray if dtype != str else ak.Strings)
assert len(pda) == fixed_size
assert dtype == pda.dtype

@pytest.mark.parametrize("size", pytest.prob_size)
def test_large_array_creation(self, size):
Expand Down
11 changes: 9 additions & 2 deletions src/CastMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ module CastMsg {
when (DType.UInt64, "uint64") {
return new MsgTuple(castGenSymEntry(gse, st, uint, uint), MsgType.NORMAL);
}
when (DType.UInt64, "float") {
when (DType.UInt64, "float64") {
return new MsgTuple(castGenSymEntry(gse, st, uint, real), MsgType.NORMAL);
}
when (DType.UInt64, "bool") {
return new MsgTuple(castGenSymEntry(gse, st, uint, bool), MsgType.NORMAL);
}
when (DType.UInt64, "string") {
when (DType.UInt64, "str") {
return new MsgTuple(castGenSymEntryToString(gse, st, uint), MsgType.NORMAL);
}
when (DType.UInt64, "bigint") {
Expand Down Expand Up @@ -168,6 +168,13 @@ module CastMsg {
when "bigint" {
return new MsgTuple(castStringToBigInt(strings, st, errors), MsgType.NORMAL);
}
when "str" {
const oname = st.nextName();
const vname = st.nextName();
var offsets = st.addEntry(oname, createSymEntry(strings.offsets.a));
var values = st.addEntry(vname, createSymEntry(strings.values.a));
return new MsgTuple("created " + st.attrib(oname) + "+created " + st.attrib(vname), MsgType.NORMAL);
}
otherwise {
var errorMsg = notImplementedError(pn,"str",":",targetDtype);
castLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down

0 comments on commit 5d7f91f

Please sign in to comment.