From e2722c0e6d21001b087da3660787fe56c2ad2858 Mon Sep 17 00:00:00 2001 From: Ben McDonald <46734217+bmcdonald3@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:34:37 -0700 Subject: [PATCH 1/4] Switch byte calculation to batches (#3770) --- src/parquet/ReadParquet.cpp | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/parquet/ReadParquet.cpp b/src/parquet/ReadParquet.cpp index 5fee7edfab..2ab2958ca9 100644 --- a/src/parquet/ReadParquet.cpp +++ b/src/parquet/ReadParquet.cpp @@ -440,19 +440,32 @@ int64_t cpp_getStringColumnNumBytes(const char* filename, const char* colname, v static_cast(column_reader.get()); int64_t numRead = 0; - while (ba_reader->HasNext() && numRead < numElems) { - parquet::ByteArray value; - (void)ba_reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read); - if(values_read > 0) { - offsets[i] = value.len + 1; - byteSize += value.len + 1; - numRead += values_read; - } else { - offsets[i] = 1; - byteSize+=1; - numRead+=1; + + int totalProcessed = 0; + std::vector values(batchSize); + while (ba_reader->HasNext() && totalProcessed < numElems) { + if((numElems - totalProcessed) < batchSize) // adjust batchSize if needed + batchSize = numElems - totalProcessed; + std::vector definition_levels(batchSize,-1); + (void)ba_reader->ReadBatch(batchSize, definition_levels.data(), nullptr, values.data(), &values_read); + totalProcessed += values_read; + int j = 0; + int numProcessed = 0; + while(j < batchSize) { + if(definition_levels[j] == 1 || definition_levels[j] == 3) { + offsets[i] = values[numProcessed].len + 1; + byteSize += values[numProcessed].len + 1; + numProcessed++; + i++; + } else if(definition_levels[j] == 0) { + offsets[i] = 1; + byteSize+=1; + i++; + } else { + j = batchSize; // exit condition + } + j++; } - i++; } } return byteSize; From 8fc7bf86001f6eee2ecff72e21ef5d6893213c00 Mon Sep 17 00:00:00 2001 From: Ben McDonald <46734217+bmcdonald3@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:35:02 -0700 Subject: [PATCH 2/4] Add batch string read (#3768) --- src/ParquetMsg.chpl | 4 ++-- src/parquet/ReadParquet.cpp | 40 +++++++++++++++++++++++++------------ src/parquet/ReadParquet.h | 4 ++-- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/ParquetMsg.chpl b/src/ParquetMsg.chpl index 160160bedc..5803dc827c 100644 --- a/src/ParquetMsg.chpl +++ b/src/ParquetMsg.chpl @@ -172,7 +172,7 @@ module ParquetMsg { } proc readStrFilesByName(ref A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string) throws { - extern proc c_readStrColumnByName(filename, arr_chpl, colname, batchSize, errMsg): int; + extern proc c_readStrColumnByName(filename, arr_chpl, colname, numElems, batchSize, errMsg): int; var (subdoms, length) = getSubdomains(sizes); coforall loc in A.targetLocales() do on loc { @@ -188,7 +188,7 @@ module ParquetMsg { var col: [filedom] t; if c_readStrColumnByName(filename.localize().c_str(), c_ptrTo(col), - dsetname.localize().c_str(), + dsetname.localize().c_str(), filedom.size, batchSize, c_ptrTo(pqErr.errMsg)) == ARROWERROR { pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName()); } diff --git a/src/parquet/ReadParquet.cpp b/src/parquet/ReadParquet.cpp index 2ab2958ca9..244276d1b1 100644 --- a/src/parquet/ReadParquet.cpp +++ b/src/parquet/ReadParquet.cpp @@ -98,7 +98,7 @@ int64_t readColumnIrregularBitWidth(void* chpl_arr, int64_t startIdx, std::share return i; } -int cpp_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t batchSize, char** errMsg) { +int cpp_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t batchSize, char** errMsg) { try { int64_t ty = cpp_getType(filename, colname, errMsg); @@ -131,23 +131,37 @@ int cpp_readStrColumnByName(const char* filename, void* chpl_arr, const char* co column_reader = row_group_reader->Column(idx); if(ty == ARROWSTRING) { - int16_t definition_level; // nullable type and only reading single records in batch auto chpl_ptr = (unsigned char*)chpl_arr; parquet::ByteArrayReader* reader = static_cast(column_reader.get()); - while (reader->HasNext()) { - parquet::ByteArray value; - (void)reader->ReadBatch(1, &definition_level, nullptr, &value, &values_read); - // if values_read is 0, that means that it was a null value - if(values_read > 0) { - for(int j = 0; j < value.len; j++) { - chpl_ptr[i] = value.ptr[j]; + int totalProcessed = 0; + std::vector values(batchSize); + while (reader->HasNext() && totalProcessed < numElems) { + std::vector definition_levels(batchSize,-1); + if((numElems - totalProcessed) < batchSize) // adjust batchSize if needed + batchSize = numElems - totalProcessed; + + (void)reader->ReadBatch(batchSize, definition_levels.data(), nullptr, values.data(), &values_read); + totalProcessed += values_read; + int j = 0; + int numProcessed = 0; + while(j < batchSize) { + if(definition_levels[j] == 1) { + for(int k = 0; k < values[numProcessed].len; k++) { + chpl_ptr[i] = values[numProcessed].ptr[k]; + i++; + } + i++; // skip one space so the strings are null terminated with a 0 + numProcessed++; + } else if(definition_levels[j] == 0) { i++; + } else { + j = batchSize; // exit loop, not read } + j++; } - i++; // skip one space so the strings are null terminated with a 0 - } + } } } return 0; @@ -744,8 +758,8 @@ int64_t cpp_getListColumnSize(const char* filename, const char* colname, void* c } extern "C" { - int c_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t batchSize, char** errMsg) { - return cpp_readStrColumnByName(filename, chpl_arr, colname, batchSize, errMsg); + int c_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t batchSize, char** errMsg) { + return cpp_readStrColumnByName(filename, chpl_arr, colname, numElems, batchSize, errMsg); } int c_readColumnByName(const char* filename, void* chpl_arr, bool* where_null_chpl, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, int64_t byteLength, bool hasNonFloatNulls, char** errMsg) { diff --git a/src/parquet/ReadParquet.h b/src/parquet/ReadParquet.h index 6a90ee07e3..c449204bf9 100644 --- a/src/parquet/ReadParquet.h +++ b/src/parquet/ReadParquet.h @@ -15,9 +15,9 @@ #include extern "C" { #endif - int c_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t batchSize, char** errMsg); + int c_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t batchSize, char** errMsg); - int cpp_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t batchSize, char** errMsg); + int cpp_readStrColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t batchSize, char** errMsg); int c_readColumnByName(const char* filename, void* chpl_arr, bool* where_null_chpl, const char* colname, int64_t numElems, int64_t startIdx, From e00f10b11480daee4dff26a885fada58f190576f Mon Sep 17 00:00:00 2001 From: tess <48131946+stress-tess@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:02:57 -0400 Subject: [PATCH 3/4] Closes #3720: Update `SetMsg` to use the new message framework (#3774) This PR (Closes #3720) updates `SetMsg` to use the new message framework and modifies the test to run with less than 3 dims Co-authored-by: Tess Hayes --- arkouda/array_api/set_functions.py | 72 +++++----- src/SetMsg.chpl | 211 ++++++++++------------------- src/registry/Commands.chpl | 98 ++++++++++++++ tests/array_api/set_functions.py | 19 ++- 4 files changed, 218 insertions(+), 182 deletions(-) diff --git a/arkouda/array_api/set_functions.py b/arkouda/array_api/set_functions.py index 34034846f3..b826ef2077 100644 --- a/arkouda/array_api/set_functions.py +++ b/arkouda/array_api/set_functions.py @@ -1,11 +1,11 @@ from __future__ import annotations -from .array_object import Array - from typing import NamedTuple, cast from arkouda.client import generic_msg -from arkouda.pdarrayclass import create_pdarray +from arkouda.pdarrayclass import create_pdarray, create_pdarrays + +from .array_object import Array class UniqueAllResult(NamedTuple): @@ -33,21 +33,21 @@ def unique_all(x: Array, /) -> UniqueAllResult: - the inverse indices that reconstruct `x` from the unique values - the counts of each unique value """ - resp = cast( - str, - generic_msg( - cmd=f"uniqueAll{x.ndim}D", - args={"name": x._array}, - ), + arrays = create_pdarrays( + cast( + str, + generic_msg( + cmd=f"uniqueAll<{x.dtype},{x.ndim}>", + args={"name": x._array}, + ), + ) ) - arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")] - return UniqueAllResult( - values=arrays[0], - indices=arrays[1], - inverse_indices=arrays[2], - counts=arrays[3], + values=Array._new(arrays[0]), + indices=Array._new(arrays[1]), + inverse_indices=Array._new(arrays[2]), + counts=Array._new(arrays[3]), ) @@ -57,19 +57,19 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: - the unique values in `x` - the counts of each unique value """ - resp = cast( - str, - generic_msg( - cmd=f"uniqueCounts{x.ndim}D", - args={"name": x._array}, - ), + arrays = create_pdarrays( + cast( + str, + generic_msg( + cmd=f"uniqueCounts<{x.dtype},{x.ndim}>", + args={"name": x._array}, + ), + ) ) - arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")] - return UniqueCountsResult( - values=arrays[0], - counts=arrays[1], + values=Array._new(arrays[0]), + counts=Array._new(arrays[1]), ) @@ -79,19 +79,19 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: - the unique values in `x` - the inverse indices that reconstruct `x` from the unique values """ - resp = cast( - str, - generic_msg( - cmd=f"uniqueInverse{x.ndim}D", - args={"name": x._array}, - ), + arrays = create_pdarrays( + cast( + str, + generic_msg( + cmd=f"uniqueInverse<{x.dtype},{x.ndim}>", + args={"name": x._array}, + ), + ) ) - arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")] - return UniqueInverseResult( - values=arrays[0], - inverse_indices=arrays[1], + values=Array._new(arrays[0]), + inverse_indices=Array._new(arrays[1]), ) @@ -104,7 +104,7 @@ def unique_values(x: Array, /) -> Array: cast( str, generic_msg( - cmd=f"uniqueValues{x.ndim}D", + cmd=f"uniqueValues<{x.dtype},{x.ndim}>", args={"name": x._array}, ), ) diff --git a/src/SetMsg.chpl b/src/SetMsg.chpl index ab293075e6..9ca939e6d4 100644 --- a/src/SetMsg.chpl +++ b/src/SetMsg.chpl @@ -11,155 +11,88 @@ module SetMsg { use RadixSortLSD; use Unique; use Reflection; + use BigInteger; - private config const logLevel = ServerConfig.logLevel; - private config const logChannel = ServerConfig.logChannel; - const sLogger = new Logger(logLevel, logChannel); + @arkouda.instantiateAndRegister + proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype != BigInteger.bigint) && (array_dtype != uint(8)) + { + const name = msgArgs["name"], + eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), + eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a); - @arkouda.registerND - proc uniqueValuesMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - param pn = Reflection.getRoutineName(); - const name = msgArgs.getValueOf("name"), - rname = st.nextName(); - - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st); - - proc getUniqueVals(type t): MsgTuple throws { - const eIn = toSymEntry(gEnt, t, nd), - eFlat = if nd == 1 then eIn.a else flatten(eIn.a); - - const eSorted = radixSortLSD_keys(eFlat); - const eUnique = uniqueFromSorted(eSorted, needCounts=false); - - st.addEntry(rname, createSymEntry(eUnique)); - - const repMsg = "created " + st.attrib(rname); - sLogger.info(getModuleName(),pn,getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - - select gEnt.dtype { - when DType.Int64 do return getUniqueVals(int); - // when DType.UInt8 do return getUniqueVals(uint(8)); - when DType.UInt64 do return getUniqueVals(uint); - when DType.Float64 do return getUniqueVals(real); - when DType.Bool do return getUniqueVals(bool); - otherwise { - var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype); - sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + const eSorted = radixSortLSD_keys(eFlat); + const eUnique = uniqueFromSorted(eSorted, needCounts=false); + + return st.insert(new shared SymEntry(eUnique)); + } + + proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) + { + return MsgTuple.error("unique_values does not support the %s dtype".format(array_dtype:string)); } - @arkouda.registerND - proc uniqueCountsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - param pn = Reflection.getRoutineName(); + @arkouda.instantiateAndRegister + proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws { const name = msgArgs.getValueOf("name"), - uname = st.nextName(), - cname = st.nextName(); - - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st); - - proc getUniqueVals(type t): MsgTuple throws { - const eIn = toSymEntry(gEnt, t, nd), - eFlat = if nd == 1 then eIn.a else flatten(eIn.a); - - const eSorted = radixSortLSD_keys(eFlat); - const (eUnique, eCounts) = uniqueFromSorted(eSorted); - - st.addEntry(uname, createSymEntry(eUnique)); - st.addEntry(cname, createSymEntry(eCounts)); - - const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(cname); - sLogger.info(getModuleName(),pn,getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - - select gEnt.dtype { - when DType.Int64 do return getUniqueVals(int); - // when DType.UInt8 do return getUniqueVals(uint(8)); - when DType.UInt64 do return getUniqueVals(uint); - when DType.Float64 do return getUniqueVals(real); - when DType.Bool do return getUniqueVals(bool); - otherwise { - var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype); - sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), + eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a); + + const eSorted = radixSortLSD_keys(eFlat); + const (eUnique, eCounts) = uniqueFromSorted(eSorted); + + return MsgTuple.fromResponses([ + st.insert(new shared SymEntry(eUnique)), + st.insert(new shared SymEntry(eCounts)), + ]); + } + + proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) + { + return MsgTuple.error("unique_counts does not support the %s dtype".format(array_dtype:string)); } - @arkouda.registerND - proc uniqueInverseMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - param pn = Reflection.getRoutineName(); + @arkouda.instantiateAndRegister + proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws { const name = msgArgs.getValueOf("name"), - uname = st.nextName(), - iname = st.nextName(); - - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st); - - proc getUniqueVals(type t): MsgTuple throws { - const eIn = toSymEntry(gEnt, t, nd), - eFlat = if nd == 1 then eIn.a else flatten(eIn.a); - - const (eUnique, _, inv) = uniqueSortWithInverse(eFlat); - st.addEntry(uname, createSymEntry(eUnique)); - st.addEntry(iname, createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape))); - - const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(iname); - sLogger.info(getModuleName(),pn,getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - - select gEnt.dtype { - when DType.Int64 do return getUniqueVals(int); - // when DType.UInt8 do return getUniqueVals(uint(8)); - when DType.UInt64 do return getUniqueVals(uint); - when DType.Float64 do return getUniqueVals(real); - when DType.Bool do return getUniqueVals(bool); - otherwise { - var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype); - sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), + eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a); + + const (eUnique, _, inv) = uniqueSortWithInverse(eFlat); + + return MsgTuple.fromResponses([ + st.insert(new shared SymEntry(eUnique)), + st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))), + ]); + } + + proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) + { + return MsgTuple.error("unique_inverse does not support the %s dtype".format(array_dtype:string)); } - @arkouda.registerND - proc uniqueAllMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - param pn = Reflection.getRoutineName(); + @arkouda.instantiateAndRegister + proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws { const name = msgArgs.getValueOf("name"), - rnames = for 0..<4 do st.nextName(); - - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st); - - proc getUniqueVals(type t): MsgTuple throws { - const eIn = toSymEntry(gEnt, t, nd), - eFlat = if nd == 1 then eIn.a else flatten(eIn.a); - - const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true); - st.addEntry(rnames[0], createSymEntry(eUnique)); - st.addEntry(rnames[1], createSymEntry(eIndices)); - st.addEntry(rnames[2], createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape))); - st.addEntry(rnames[3], createSymEntry(eCounts)); - - const repMsg = try! "+".join([rn in rnames] "created " + st.attrib(rn)); - sLogger.info(getModuleName(),pn,getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - - select gEnt.dtype { - when DType.Int64 do return getUniqueVals(int); - // when DType.UInt8 do return getUniqueVals(uint(8)); - when DType.UInt64 do return getUniqueVals(uint); - when DType.Float64 do return getUniqueVals(real); - when DType.Bool do return getUniqueVals(bool); - otherwise { - var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype); - sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), + eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a); + + const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true); + + return MsgTuple.fromResponses([ + st.insert(new shared SymEntry(eUnique)), + st.insert(new shared SymEntry(eIndices)), + st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))), + st.insert(new shared SymEntry(eCounts)), + ]); + } + + proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) + { + return MsgTuple.error("unique_all does not support the %s dtype".format(array_dtype:string)); } } diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index 931074e4b5..b4d63e1498 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -1760,6 +1760,104 @@ proc ark_nonzero_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrow return ReductionMsg.nonzero(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); registerFunction('nonzero', ark_nonzero_bigint_1, 'ReductionMsg', 325); +import SetMsg; + +proc ark_uniqueValues_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_int_1, 'SetMsg', 17); + +proc ark_uniqueValues_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_uint_1, 'SetMsg', 17); + +proc ark_uniqueValues_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_uint8_1, 'SetMsg', 17); + +proc ark_uniqueValues_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_real_1, 'SetMsg', 17); + +proc ark_uniqueValues_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_bool_1, 'SetMsg', 17); + +proc ark_uniqueValues_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueValues(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('uniqueValues', ark_uniqueValues_bigint_1, 'SetMsg', 17); + +proc ark_uniqueCounts_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_int_1, 'SetMsg', 37); + +proc ark_uniqueCounts_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_uint_1, 'SetMsg', 37); + +proc ark_uniqueCounts_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_uint8_1, 'SetMsg', 37); + +proc ark_uniqueCounts_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_real_1, 'SetMsg', 37); + +proc ark_uniqueCounts_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_bool_1, 'SetMsg', 37); + +proc ark_uniqueCounts_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueCounts(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('uniqueCounts', ark_uniqueCounts_bigint_1, 'SetMsg', 37); + +proc ark_uniqueInverse_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_int_1, 'SetMsg', 58); + +proc ark_uniqueInverse_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_uint_1, 'SetMsg', 58); + +proc ark_uniqueInverse_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_uint8_1, 'SetMsg', 58); + +proc ark_uniqueInverse_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_real_1, 'SetMsg', 58); + +proc ark_uniqueInverse_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_bool_1, 'SetMsg', 58); + +proc ark_uniqueInverse_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueInverse(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('uniqueInverse', ark_uniqueInverse_bigint_1, 'SetMsg', 58); + +proc ark_uniqueAll_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_int_1, 'SetMsg', 78); + +proc ark_uniqueAll_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_uint_1, 'SetMsg', 78); + +proc ark_uniqueAll_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_uint8_1, 'SetMsg', 78); + +proc ark_uniqueAll_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_real_1, 'SetMsg', 78); + +proc ark_uniqueAll_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_bool_1, 'SetMsg', 78); + +proc ark_uniqueAll_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return SetMsg.uniqueAll(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('uniqueAll', ark_uniqueAll_bigint_1, 'SetMsg', 78); + import StatsMsg; proc ark_reg_mean_generic(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype_0, param array_nd_0: int): MsgTuple throws { diff --git a/tests/array_api/set_functions.py b/tests/array_api/set_functions.py index df1c54c67e..60916897cf 100644 --- a/tests/array_api/set_functions.py +++ b/tests/array_api/set_functions.py @@ -1,5 +1,3 @@ -import json - import numpy as np import pytest @@ -10,7 +8,16 @@ s = SEED -def randArr(shape): +def ret_shapes(): + shapes = [1000] + if pytest.max_rank > 1: + shapes.append((20, 50)) + if pytest.max_rank > 2: + shapes.append((2, 10, 50)) + return shapes + + +def rand_arr(shape): global s s += 2 return xp.asarray(ak.randint(0, 100, shape, dtype=ak.int64, seed=s)) @@ -18,11 +25,9 @@ def randArr(shape): class TestSetFunction: - @pytest.mark.skip_if_max_rank_less_than(3) def test_set_functions(self): - - for shape in [(1000), (20, 50), (2, 10, 50)]: - r = randArr(shape) + for shape in ret_shapes(): + r = rand_arr(shape) ua = xp.unique_all(r) uc = xp.unique_counts(r) From 8d994e65f72c35a6d48225d9aa4940090d60788f Mon Sep 17 00:00:00 2001 From: ajpotts Date: Tue, 17 Sep 2024 08:24:17 -0400 Subject: [PATCH 4/4] Closes #3771: register_commands.py to handle generic scalar type (#3772) Co-authored-by: Amanda Potts --- registration-config.json | 10 ++++ src/registry/Commands.chpl | 10 ++++ src/registry/register_commands.py | 90 ++++++++++++++++--------------- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/registration-config.json b/registration-config.json index 44afced3bf..44a28371fa 100644 --- a/registration-config.json +++ b/registration-config.json @@ -10,6 +10,16 @@ "bool", "bigint" ] + }, + "scalar": { + "dtype": [ + "int", + "uint", + "uint(8)", + "real", + "bool", + "bigint" + ] } } } diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index b4d63e1498..97a32a8f26 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -19,6 +19,16 @@ param regConfig = """ "bool", "bigint" ] + }, + "scalar": { + "dtype": [ + "int", + "uint", + "uint(8)", + "real", + "bool", + "bigint" + ] } } } diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index 6834d9f3c5..103291bd2c 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -1,7 +1,8 @@ -import chapel -import sys -import json import itertools +import json +import sys + +import chapel DEFAULT_MODS = ["MsgProcessing", "GenSymIO"] @@ -210,6 +211,7 @@ def info_tuple(formal): gen_formals.append(formal_info) else: con_formals.append(formal_info) + return con_formals, gen_formals @@ -220,9 +222,7 @@ def clean_stamp_name(name): return name.translate(str.maketrans("[](),=", "______")) -def stamp_generic_command( - generic_proc_name, prefix, module_name, formals, line_num, is_user_proc -): +def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc): """ Create code to stamp out and register a generic command using a generic procedure, and a set values for its generic formals. @@ -295,9 +295,7 @@ def parse_param_class_value(value): if isinstance(value, list): for v in value: if not isinstance(v, (int, float, str)): - raise ValueError( - f"Invalid parameter value type ({type(v)}) in list '{value}'" - ) + raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'") return value elif isinstance(value, int): return [ @@ -313,9 +311,7 @@ def parse_param_class_value(value): if isinstance(vals, list): return vals else: - raise ValueError( - f"Could not create a list of parameter values from '{value}'" - ) + raise ValueError(f"Could not create a list of parameter values from '{value}'") else: raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'") @@ -353,9 +349,7 @@ def generic_permutations(config, gen_formals): + "please check the 'parameter_classes' field in the configuration file" ) - to_permute[formal_name] = parse_param_class_value( - config["parameter_classes"][pclass][pname] - ) + to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname]) return permutations(to_permute) @@ -446,6 +440,28 @@ def unpack_scalar_arg(arg_name, arg_type): return f"\tvar {arg_name} = {ARGS_FORMAL_NAME}['{arg_name}'].toScalar({arg_type});" +def unpack_scalar_arg_with_generic(arg_name, array_count): + """ + Generate the code to unpack a scalar argument + + 'scalar_count' is used to generate unique names when + a procedure has multiple array-symbol formals + + Example: + ``` + var x = msgArgs['x'].toScalar(scalar_dtype_0); + ``` + + Returns the chapel code, and the specifications for the + 'dtype' and type-constructor arguments + """ + dtype_arg_name = "scalar_dtype_" + str(array_count) + return ( + unpack_scalar_arg(arg_name, dtype_arg_name), + [(dtype_arg_name, "type", None, None)], + ) + + def unpack_tuple_arg(arg_name, tuple_size, scalar_type): """ Generate the code to unpack a tuple argument @@ -492,8 +508,7 @@ def gen_signature(user_proc_name, generic_args=None): if generic_args: name = "ark_reg_" + user_proc_name + "_generic" arg_strings = [ - f"{kind} {name}: {ft}" if ft else f"{kind} {name}" - for name, kind, ft, _ in generic_args + f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args ] proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}" else: @@ -511,11 +526,13 @@ def gen_arg_unpacking(formals): unpack_lines = [] generic_args = [] array_arg_counter = 0 + scalar_arg_counter = 0 array_domain_queries = {} array_dtype_queries = {} for fname, fintent, ftype, finfo in formals: + if ftype in chapel_scalar_types: unpack_lines.append(unpack_scalar_arg(fname, ftype)) elif ftype == "": @@ -556,12 +573,14 @@ def gen_arg_unpacking(formals): unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype)) else: if ftype in array_dtype_queries.keys(): - unpack_lines.append( - unpack_scalar_arg(fname, array_dtype_queries[ftype]) - ) + + unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype])) else: # TODO: fully handle generic user-defined types - unpack_lines.append(unpack_user_symbol(fname, ftype)) + code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter) + unpack_lines.append(code) + generic_args += scalar_args + scalar_arg_counter += 1 return ("\n".join(unpack_lines), generic_args) @@ -652,14 +671,10 @@ def gen_command_proc(name, return_type, formals, mod_name): arg_unpack, command_formals = gen_arg_unpacking(formals) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) - fn_call, result_name = gen_user_function_call( - name, [f[0] for f in formals], mod_name, return_type - ) + fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type) # get the names of the array-elt-type queries in the formals - array_etype_queries = [ - f[3][1] for f in formals if (f[2] == "" and f[3] is not None) - ] + array_etype_queries = [f[3][1] for f in formals if (f[2] == "" and f[3] is not None)] # assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference # or if it is a `SymEntry` type-constructor call @@ -678,30 +693,22 @@ def gen_command_proc(name, return_type, formals, mod_name): ) ) returns_array = ( - return_type - and isinstance(return_type, chapel.BracketLoop) - and return_type.is_maybe_array_type() + return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type() ) if returns_array: - symbol_creation, result_name = gen_symbol_creation( - ARRAY_ENTRY_CLASS_NAME, result_name - ) + symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name) else: symbol_creation = "" response = gen_response(result_name, returns_symbol or returns_array) - command_proc = "\n".join( - [signature, arg_unpack, fn_call, symbol_creation, response, "}"] - ) + command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"]) return (command_proc, cmd_name, is_generic_command, command_formals) -def stamp_out_command( - config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc -): +def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc): """ Yield instantiations of a generic command with using the values from the configuration file @@ -723,9 +730,7 @@ def stamp_out_command( formal_perms = generic_permutations(config, formals) for fp in formal_perms: - stamp = stamp_generic_command( - name, cmd_prefix, mod_name, fp, line_num, is_user_proc - ) + stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc) yield stamp @@ -782,6 +787,7 @@ def register_commands(config, source_files): (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc( name, fn.return_type(), con_formals, mod_name ) + file_stamps.append(cmd_proc) count += 1