Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #3720: Update SetMsg to use the new message framework #3774

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions arkouda/array_api/set_functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from .array_object import Array

from typing import NamedTuple, cast

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarrayclass import create_pdarray, create_pdarrays

from .array_object import Array


class UniqueAllResult(NamedTuple):
Expand Down Expand Up @@ -33,21 +33,21 @@ def unique_all(x: Array, /) -> UniqueAllResult:
- the inverse indices that reconstruct `x` from the unique values
- the counts of each unique value
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueAll{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueAll<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueAllResult(
values=arrays[0],
indices=arrays[1],
inverse_indices=arrays[2],
counts=arrays[3],
values=Array._new(arrays[0]),
indices=Array._new(arrays[1]),
inverse_indices=Array._new(arrays[2]),
counts=Array._new(arrays[3]),
)


Expand All @@ -57,19 +57,19 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
- the unique values in `x`
- the counts of each unique value
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueCounts{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueCounts<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueCountsResult(
values=arrays[0],
counts=arrays[1],
values=Array._new(arrays[0]),
counts=Array._new(arrays[1]),
)


Expand All @@ -79,19 +79,19 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
- the unique values in `x`
- the inverse indices that reconstruct `x` from the unique values
"""
resp = cast(
str,
generic_msg(
cmd=f"uniqueInverse{x.ndim}D",
args={"name": x._array},
),
arrays = create_pdarrays(
cast(
str,
generic_msg(
cmd=f"uniqueInverse<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
)

arrays = [Array._new(create_pdarray(r)) for r in resp.split("+")]

return UniqueInverseResult(
values=arrays[0],
inverse_indices=arrays[1],
values=Array._new(arrays[0]),
inverse_indices=Array._new(arrays[1]),
)


Expand All @@ -104,7 +104,7 @@ def unique_values(x: Array, /) -> Array:
cast(
str,
generic_msg(
cmd=f"uniqueValues{x.ndim}D",
cmd=f"uniqueValues<{x.dtype},{x.ndim}>",
args={"name": x._array},
),
)
Expand Down
211 changes: 72 additions & 139 deletions src/SetMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,155 +11,88 @@ module SetMsg {
use RadixSortLSD;
use Unique;
use Reflection;
use BigInteger;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
const sLogger = new Logger(logLevel, logChannel);
@arkouda.instantiateAndRegister
proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype != BigInteger.bigint) && (array_dtype != uint(8))
{
const name = msgArgs["name"],
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

@arkouda.registerND
proc uniqueValuesMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
const name = msgArgs.getValueOf("name"),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const eUnique = uniqueFromSorted(eSorted, needCounts=false);

st.addEntry(rname, createSymEntry(eUnique));

const repMsg = "created " + st.attrib(rname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
const eSorted = radixSortLSD_keys(eFlat);
const eUnique = uniqueFromSorted(eSorted, needCounts=false);

return st.insert(new shared SymEntry(eUnique));
}

proc uniqueValues(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_values does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueCountsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
uname = st.nextName(),
cname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const (eUnique, eCounts) = uniqueFromSorted(eSorted);

st.addEntry(uname, createSymEntry(eUnique));
st.addEntry(cname, createSymEntry(eCounts));

const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(cname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const eSorted = radixSortLSD_keys(eFlat);
const (eUnique, eCounts) = uniqueFromSorted(eSorted);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(eCounts)),
]);
}

proc uniqueCounts(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_counts does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueInverseMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
uname = st.nextName(),
iname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, _, inv) = uniqueSortWithInverse(eFlat);
st.addEntry(uname, createSymEntry(eUnique));
st.addEntry(iname, createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape)));

const repMsg = "created " + st.attrib(uname) + "+created " + st.attrib(iname);
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, _, inv) = uniqueSortWithInverse(eFlat);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))),
]);
}

proc uniqueInverse(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_inverse does not support the %s dtype".format(array_dtype:string));
}

@arkouda.registerND
proc uniqueAllMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.instantiateAndRegister
proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws {
const name = msgArgs.getValueOf("name"),
rnames = for 0..<4 do st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc getUniqueVals(type t): MsgTuple throws {
const eIn = toSymEntry(gEnt, t, nd),
eFlat = if nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true);
st.addEntry(rnames[0], createSymEntry(eUnique));
st.addEntry(rnames[1], createSymEntry(eIndices));
st.addEntry(rnames[2], createSymEntry(if nd == 1 then inv else unflatten(inv, eIn.a.shape)));
st.addEntry(rnames[3], createSymEntry(eCounts));

const repMsg = try! "+".join([rn in rnames] "created " + st.attrib(rn));
sLogger.info(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return getUniqueVals(int);
// when DType.UInt8 do return getUniqueVals(uint(8));
when DType.UInt64 do return getUniqueVals(uint);
when DType.Float64 do return getUniqueVals(real);
when DType.Bool do return getUniqueVals(bool);
otherwise {
var errorMsg = notImplementedError(getRoutineName(),gEnt.dtype);
sLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
eIn = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
eFlat = if array_nd == 1 then eIn.a else flatten(eIn.a);

const (eUnique, eCounts, inv, eIndices) = uniqueSortWithInverse(eFlat, needIndices=true);

return MsgTuple.fromResponses([
st.insert(new shared SymEntry(eUnique)),
st.insert(new shared SymEntry(eIndices)),
st.insert(new shared SymEntry(if array_nd == 1 then inv else unflatten(inv, eIn.a.shape))),
st.insert(new shared SymEntry(eCounts)),
]);
}

proc uniqueAll(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == BigInteger.bigint) || (array_dtype == uint(8))
{
return MsgTuple.error("unique_all does not support the %s dtype".format(array_dtype:string));
}
}
Loading
Loading