From 92fdfe81f260b230753717339719002c89578e1d Mon Sep 17 00:00:00 2001 From: ajpotts Date: Wed, 7 Aug 2024 12:10:01 -0400 Subject: [PATCH] Closes #3526: refactor argsortMsg to remove registerND annotation (#3602) Co-authored-by: Amanda Potts --- arkouda/sorting.py | 29 +++++-- src/ArgSortMsg.chpl | 140 ++++++++++++-------------------- src/registry/Commands.chpl | 26 ++++++ tests/deprecated/string_test.py | 4 +- 4 files changed, 102 insertions(+), 97 deletions(-) diff --git a/arkouda/sorting.py b/arkouda/sorting.py index 18eaf57ec9..7cd2f15bb1 100644 --- a/arkouda/sorting.py +++ b/arkouda/sorting.py @@ -71,15 +71,26 @@ def argsort( return zeros(0, dtype=pda.dtype) if isinstance(pda, pdarray) and pda.dtype == bigint: return coargsort(pda.bigint_to_uint_arrays(), algorithm) - repMsg = generic_msg( - cmd=f"argsort{pda.ndim}D", - args={ - "name": pda.entry.name if isinstance(pda, Strings) else pda.name, - "algoName": algorithm.name, - "objType": pda.objType, - "axis": axis, - }, - ) + + if isinstance(pda, Strings): + repMsg = generic_msg( + cmd="argsortStrings", + args={ + "name": pda.entry.name, + "algoName": algorithm.name, + }, + ) + else: + repMsg = generic_msg( + cmd=f"argsort<{pda.dtype.name},1>", + args={ + "name": pda.name, + "algoName": algorithm.name, + "objType": pda.objType, + "axis": axis, + }, + ) + return create_pdarray(cast(str, repMsg)) diff --git a/src/ArgSortMsg.chpl b/src/ArgSortMsg.chpl index 417f1ba544..d952fd54e6 100644 --- a/src/ArgSortMsg.chpl +++ b/src/ArgSortMsg.chpl @@ -6,6 +6,7 @@ module ArgSortMsg { use ServerConfig; + use MsgProcessing; use CTypes; use Time; @@ -30,6 +31,7 @@ module ArgSortMsg use ServerErrors; use Logging; use Message; + use BigInteger; private config const logLevel = ServerConfig.logLevel; private config const logChannel = ServerConfig.logChannel; @@ -62,6 +64,24 @@ module ArgSortMsg }; config const defaultSortAlgorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD; + proc getSortingAlgoritm(algoName:string) throws{ + var algorithm = defaultSortAlgorithm; + if algoName != "" { + try { + return algoName: SortingAlgorithm; + } catch { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algoName), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } + return algorithm; + } + // proc DefaultComparator.keyPart(x: _tuple, i:int) where !isHomogeneousTuple(x) && // (isInt(x(0)) || isUint(x(0)) || isReal(x(0))) { @@ -404,95 +424,41 @@ module ArgSortMsg } /* argsort takes pdarray and returns an index vector iv which sorts the array */ - @arkouda.registerND - proc argsortMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { - param pn = Reflection.getRoutineName(); - var repMsg: string; // response message - const name = msgArgs.getValueOf("name"); - const algoName = msgArgs.getValueOf("algoName"); - const axis = msgArgs.get("axis").getIntValue(); - var algorithm: SortingAlgorithm = defaultSortAlgorithm; - - if algoName != "" { - try { - algorithm = algoName: SortingAlgorithm; - } catch { - throw getErrorWithContext( - msg="Unrecognized sorting algorithm: %s".format(algoName), - lineNumber=getLineNumber(), - routineName=getRoutineName(), - moduleName=getModuleName(), - errorClass="NotImplementedError" - ); - } - } - // get next symbol name - const ivname = st.nextName(); - - asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "cmd: %s name: %s ivname: %s".format(cmd, name, ivname)); - - var objtype = msgArgs.getValueOf("objType").toUpper(): ObjType; - select objtype { - when ObjType.PDARRAY { - var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st); - // check and throw if over memory limit - overMemLimit(radixSortLSD_memEst(gEnt.size, gEnt.itemsize)); - - select (gEnt.dtype) { - when (DType.Int64) { - var e = toSymEntry(gEnt,int, nd); - var iv = argsortDefault(e.a, algorithm=algorithm, axis); - st.addEntry(ivname, createSymEntry(iv)); - } - when (DType.UInt64) { - var e = toSymEntry(gEnt,uint, nd); - var iv = argsortDefault(e.a, algorithm=algorithm, axis); - st.addEntry(ivname, createSymEntry(iv)); - } - when (DType.Float64) { - var e = toSymEntry(gEnt, real, nd); - var iv = argsortDefault(e.a, axis=axis); - st.addEntry(ivname, createSymEntry(iv)); - } - when (DType.Bool) { - var e = toSymEntry(gEnt,bool, nd); - var int_ea = makeDistArray(e.a:int); - var iv = argsortDefault(int_ea, algorithm=algorithm, axis); - st.addEntry(ivname, createSymEntry(iv)); - } - otherwise { - var errorMsg = notImplementedError(pn,gEnt.dtype); - asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } - } - when ObjType.STRINGS { - if nd != 1 { - const errorMsg = "argsort only supports 1D strings"; - asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - var strings = getSegString(name, st); - // check and throw if over memory limit - overMemLimit((8 * strings.size * 8) - + (2 * here.maxTaskPar * numLocales * 2**16 * 8)); - var iv = strings.argsort(); - st.addEntry(ivname, createSymEntry(iv)); - } - otherwise { - var errorMsg = notImplementedError(pn, objtype: string); - asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + @arkouda.instantiateAndRegister + proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype != BigInteger.bigint) && (array_dtype != uint(8)) + { + const name = msgArgs["name"], + algoName = msgArgs["algoName"].toScalar(string), + algorithm = getSortingAlgoritm(algoName), + axis = msgArgs["axis"].toScalar(int), + symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd), + vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype); + + const iv = argsortDefault(vals, algorithm=algorithm, axis); + return st.insert(new shared SymEntry(iv)); + } + + proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws + where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8)) + { + return MsgTuple.error("argsort does not support the %s dtype".format(array_dtype:string)); + } - repMsg = "created " + st.attrib(ivname); - asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); + proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { + const name = msgArgs["name"].toScalar(string), + strings = getSegString(name, st), + algoName = msgArgs["algoName"].toScalar(string), + algorithm = getSortingAlgoritm(algoName); + + // check and throw if over memory limit + overMemLimit((8 * strings.size * 8) + + (2 * here.maxTaskPar * numLocales * 2**16 * 8)); + const iv = strings.argsort(); + return st.insert(new shared SymEntry(iv)); } use CommandMap; + registerFunction("argsortStrings", argsortStrings, getModuleName()); registerFunction("coargsort", coargsortMsg, getModuleName()); -} +} \ No newline at end of file diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index 15f98df535..b6ce450ad9 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -24,6 +24,32 @@ param regConfig = """ } """; +import ArgSortMsg; + +proc ark_argsort_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=int, array_nd=1); +registerFunction('argsort', ark_argsort_int_1, 'ArgSortMsg', 428); + +proc ark_argsort_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=uint, array_nd=1); +registerFunction('argsort', ark_argsort_uint_1, 'ArgSortMsg', 428); + +proc ark_argsort_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1); +registerFunction('argsort', ark_argsort_uint8_1, 'ArgSortMsg', 428); + +proc ark_argsort_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=real, array_nd=1); +registerFunction('argsort', ark_argsort_real_1, 'ArgSortMsg', 428); + +proc ark_argsort_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=bool, array_nd=1); +registerFunction('argsort', ark_argsort_bool_1, 'ArgSortMsg', 428); + +proc ark_argsort_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); +registerFunction('argsort', ark_argsort_bigint_1, 'ArgSortMsg', 428); + import CastMsg; proc ark_cast_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do diff --git a/tests/deprecated/string_test.py b/tests/deprecated/string_test.py index 14801e304d..79f142cca4 100644 --- a/tests/deprecated/string_test.py +++ b/tests/deprecated/string_test.py @@ -885,7 +885,9 @@ def test_get_fixes(self): self.assertListEqual(["c", "d", "i"], p.to_list()) def test_encoding(self): - idna_strings = ak.array(["Bücher.example", "ドメイン.テスト", "домен.испытание", "Königsgäßchen"]) + idna_strings = ak.array( + ["Bücher.example", "ドメイン.テスト", "домен.испытание", "Königsgäßchen"] + ) expected = ak.array( [ "xn--bcher-kva.example",