Skip to content

Commit

Permalink
Closes #3526: refactor argsortMsg to remove registerND annotation (#3602)
Browse files Browse the repository at this point in the history

Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts committed Aug 7, 2024
1 parent 9853fa0 commit 92fdfe8
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 97 deletions.
29 changes: 20 additions & 9 deletions arkouda/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,26 @@ def argsort(
return zeros(0, dtype=pda.dtype)
if isinstance(pda, pdarray) and pda.dtype == bigint:
return coargsort(pda.bigint_to_uint_arrays(), algorithm)
repMsg = generic_msg(
cmd=f"argsort{pda.ndim}D",
args={
"name": pda.entry.name if isinstance(pda, Strings) else pda.name,
"algoName": algorithm.name,
"objType": pda.objType,
"axis": axis,
},
)

if isinstance(pda, Strings):
repMsg = generic_msg(
cmd="argsortStrings",
args={
"name": pda.entry.name,
"algoName": algorithm.name,
},
)
else:
repMsg = generic_msg(
cmd=f"argsort<{pda.dtype.name},1>",
args={
"name": pda.name,
"algoName": algorithm.name,
"objType": pda.objType,
"axis": axis,
},
)

return create_pdarray(cast(str, repMsg))


Expand Down
140 changes: 53 additions & 87 deletions src/ArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module ArgSortMsg
{
use ServerConfig;

use MsgProcessing;
use CTypes;

use Time;
Expand All @@ -30,6 +31,7 @@ module ArgSortMsg
use ServerErrors;
use Logging;
use Message;
use BigInteger;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
Expand Down Expand Up @@ -62,6 +64,24 @@ module ArgSortMsg
};
config const defaultSortAlgorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD;

/*
  Resolve a sorting-algorithm name to a SortingAlgorithm value.

  An empty name selects defaultSortAlgorithm. Any other name is cast to the
  SortingAlgorithm enum; if that cast throws, the failure is replaced by an
  error built with getErrorWithContext (errorClass "NotImplementedError").
*/
proc getSortingAlgoritm(algoName:string) throws{
  // No name supplied: fall back to the configured default.
  if algoName == "" then
    return defaultSortAlgorithm;

  var chosen: SortingAlgorithm;
  try {
    chosen = algoName: SortingAlgorithm;
  } catch {
    // The name does not correspond to any SortingAlgorithm member.
    throw getErrorWithContext(
      msg="Unrecognized sorting algorithm: %s".format(algoName),
      lineNumber=getLineNumber(),
      routineName=getRoutineName(),
      moduleName=getModuleName(),
      errorClass="NotImplementedError"
    );
  }
  return chosen;
}

// proc DefaultComparator.keyPart(x: _tuple, i:int) where !isHomogeneousTuple(x) &&
// (isInt(x(0)) || isUint(x(0)) || isReal(x(0))) {

Expand Down Expand Up @@ -404,95 +424,41 @@ module ArgSortMsg
}

/* argsort takes pdarray and returns an index vector iv which sorts the array */
@arkouda.registerND
proc argsortMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
var repMsg: string; // response message
const name = msgArgs.getValueOf("name");
const algoName = msgArgs.getValueOf("algoName");
const axis = msgArgs.get("axis").getIntValue();
var algorithm: SortingAlgorithm = defaultSortAlgorithm;

if algoName != "" {
try {
algorithm = algoName: SortingAlgorithm;
} catch {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algoName),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
// get next symbol name
const ivname = st.nextName();

asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"cmd: %s name: %s ivname: %s".format(cmd, name, ivname));

var objtype = msgArgs.getValueOf("objType").toUpper(): ObjType;
select objtype {
when ObjType.PDARRAY {
var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);
// check and throw if over memory limit
overMemLimit(radixSortLSD_memEst(gEnt.size, gEnt.itemsize));

select (gEnt.dtype) {
when (DType.Int64) {
var e = toSymEntry(gEnt,int, nd);
var iv = argsortDefault(e.a, algorithm=algorithm, axis);
st.addEntry(ivname, createSymEntry(iv));
}
when (DType.UInt64) {
var e = toSymEntry(gEnt,uint, nd);
var iv = argsortDefault(e.a, algorithm=algorithm, axis);
st.addEntry(ivname, createSymEntry(iv));
}
when (DType.Float64) {
var e = toSymEntry(gEnt, real, nd);
var iv = argsortDefault(e.a, axis=axis);
st.addEntry(ivname, createSymEntry(iv));
}
when (DType.Bool) {
var e = toSymEntry(gEnt,bool, nd);
var int_ea = makeDistArray(e.a:int);
var iv = argsortDefault(int_ea, algorithm=algorithm, axis);
st.addEntry(ivname, createSymEntry(iv));
}
otherwise {
var errorMsg = notImplementedError(pn,gEnt.dtype);
asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
}
when ObjType.STRINGS {
if nd != 1 {
const errorMsg = "argsort only supports 1D strings";
asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
var strings = getSegString(name, st);
// check and throw if over memory limit
overMemLimit((8 * strings.size * 8)
+ (2 * here.maxTaskPar * numLocales * 2**16 * 8));
var iv = strings.argsort();
st.addEntry(ivname, createSymEntry(iv));
}
otherwise {
var errorMsg = notImplementedError(pn, objtype: string);
asLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
/*
  Generic argsort command handler: computes an index vector via
  argsortDefault for the array named in the message and stores it in the
  symbol table.

  Message arguments read from msgArgs:
    "name"     - symbol-table key of the input array
    "algoName" - sorting algorithm name; resolved by getSortingAlgoritm
                 (empty selects the default, unrecognized names throw)
    "axis"     - integer axis forwarded to argsortDefault

  Type/param arguments:
    array_dtype - element type of the stored array
    array_nd    - rank used when casting the symbol-table entry

  The where clause restricts this overload to element types other than
  bigint and uint(8).

  Returns the MsgTuple produced by st.insert for the new index-vector entry.

  NOTE(review): no overMemLimit check is made here before sorting - confirm
  whether argsortDefault guards memory use itself.
*/
@arkouda.instantiateAndRegister
proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype != BigInteger.bigint) && (array_dtype != uint(8))
{
// NOTE(review): `name` is not referenced again below; the entry lookup
// re-reads msgArgs["name"] directly.
const name = msgArgs["name"],
algoName = msgArgs["algoName"].toScalar(string),
algorithm = getSortingAlgoritm(algoName),
axis = msgArgs["axis"].toScalar(int),
// Fetch the entry and cast it to the concrete SymEntry type for this instantiation.
symEntry = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
// bool arrays are cast to int before sorting; other dtypes keep their element type.
vals = if (array_dtype == bool) then (symEntry.a:int) else (symEntry.a: array_dtype);

const iv = argsortDefault(vals, algorithm=algorithm, axis);
// Wrap the index vector in a new SymEntry and hand it to the symbol table.
return st.insert(new shared SymEntry(iv));
}

/*
  argsort overload selected when array_dtype is bigint or uint(8).
  It performs no sorting and returns an error MsgTuple naming the dtype.
*/
proc argsort(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
  where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
  const dtypeName = array_dtype:string;
  const errorMsg = "argsort does not support the %s dtype".format(dtypeName);
  return MsgTuple.error(errorMsg);
}

repMsg = "created " + st.attrib(ivname);
asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
/*
  argsort command handler for Strings objects.

  Reads "name" (key passed to getSegString) and "algoName" from msgArgs,
  runs the segmented string's own argsort, and stores the resulting index
  vector in the symbol table. Returns the MsgTuple produced by st.insert.
*/
proc argsortStrings(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  const name = msgArgs["name"].toScalar(string);
  const strings = getSegString(name, st);
  const algoName = msgArgs["algoName"].toScalar(string);
  // The algorithm is resolved (so an unrecognized name still throws) but is
  // not passed to strings.argsort() below.
  const algorithm = getSortingAlgoritm(algoName);

  // check and throw if over memory limit
  const sizeTerm = 8 * strings.size * 8;
  const taskTerm = 2 * here.maxTaskPar * numLocales * 2**16 * 8;
  overMemLimit(sizeTerm + taskTerm);

  const perm = strings.argsort();
  return st.insert(new shared SymEntry(perm));
}

use CommandMap;
registerFunction("argsortStrings", argsortStrings, getModuleName());
registerFunction("coargsort", coargsortMsg, getModuleName());
}
}
26 changes: 26 additions & 0 deletions src/registry/Commands.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ param regConfig = """
}
""";

import ArgSortMsg;

// Concrete handler for 'argsort<int64,1>': forwards to the generic
// ArgSortMsg.argsort with array_dtype=int and array_nd=1.
proc ark_argsort_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=int, array_nd=1);
}
registerFunction('argsort<int64,1>', ark_argsort_int_1, 'ArgSortMsg', 428);

// Concrete handler for 'argsort<uint64,1>': forwards to the generic
// ArgSortMsg.argsort with array_dtype=uint and array_nd=1.
proc ark_argsort_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=uint, array_nd=1);
}
registerFunction('argsort<uint64,1>', ark_argsort_uint_1, 'ArgSortMsg', 428);

// Concrete handler for 'argsort<uint8,1>': forwards to ArgSortMsg.argsort
// with array_dtype=uint(8) and array_nd=1. The where clauses in ArgSortMsg
// route uint(8) to the overload that returns an error MsgTuple.
proc ark_argsort_uint8_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=uint(8), array_nd=1);
}
registerFunction('argsort<uint8,1>', ark_argsort_uint8_1, 'ArgSortMsg', 428);

// Concrete handler for 'argsort<float64,1>': forwards to the generic
// ArgSortMsg.argsort with array_dtype=real and array_nd=1.
proc ark_argsort_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=real, array_nd=1);
}
registerFunction('argsort<float64,1>', ark_argsort_real_1, 'ArgSortMsg', 428);

// Concrete handler for 'argsort<bool,1>': forwards to the generic
// ArgSortMsg.argsort with array_dtype=bool and array_nd=1.
proc ark_argsort_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=bool, array_nd=1);
}
registerFunction('argsort<bool,1>', ark_argsort_bool_1, 'ArgSortMsg', 428);

// Concrete handler for 'argsort<bigint,1>': forwards to ArgSortMsg.argsort
// with array_dtype=bigint and array_nd=1.
// NOTE(review): presumably this resolves to the overload that rejects
// BigInteger.bigint with an error MsgTuple - confirm `bigint` here is that type.
proc ark_argsort_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
  return ArgSortMsg.argsort(cmd, msgArgs, st, array_dtype=bigint, array_nd=1);
}
registerFunction('argsort<bigint,1>', ark_argsort_bigint_1, 'ArgSortMsg', 428);

import CastMsg;

proc ark_cast_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do
Expand Down
4 changes: 3 additions & 1 deletion tests/deprecated/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,9 @@ def test_get_fixes(self):
self.assertListEqual(["c", "d", "i"], p.to_list())

def test_encoding(self):
idna_strings = ak.array(["Bücher.example", "ドメイン.テスト", "домен.испытание", "Königsgäßchen"])
idna_strings = ak.array(
["Bücher.example", "ドメイン.テスト", "домен.испытание", "Königsgäßchen"]
)
expected = ak.array(
[
"xn--bcher-kva.example",
Expand Down

0 comments on commit 92fdfe8

Please sign in to comment.