Skip to content

Commit

Permalink
Closes Bears-R-Us#3631: refactor-utilMsg-to-remove-registerND-annotat…
Browse files Browse the repository at this point in the history
…ion (Bears-R-Us#3681)

Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts authored Aug 21, 2024
1 parent 2efe473 commit db1d851
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 177 deletions.
10 changes: 5 additions & 5 deletions arkouda/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def clip(a: Array, a_min, a_max, /) -> Array:
return Array._new(
create_pdarray(
generic_msg(
cmd=f"clip{a.ndim}D",
cmd=f"clip<{a.dtype},{a.ndim}>",
args={
"name": a._array,
"x": a._array,
"min": a_min,
"max": a_max,
},
Expand Down Expand Up @@ -111,9 +111,9 @@ def diff(a: Array, /, n: int = 1, axis: int = -1, prepend=None, append=None) ->
return Array._new(
create_pdarray(
generic_msg(
cmd=f"diff{a.ndim}D",
cmd=f"diff<{a.dtype},{a.ndim}>",
args={
"name": a_._array,
"x": a_._array,
"n": n,
"axis": axis,
},
Expand Down Expand Up @@ -176,7 +176,7 @@ def pad(
return Array._new(
create_pdarray(
generic_msg(
cmd=f"pad{array.ndim}D",
cmd=f"pad<{array.dtype},{array.ndim}>",
args={
"name": array._array,
"padWidthBefore": tuple(pad_widths_b),
Expand Down
265 changes: 101 additions & 164 deletions src/UtilMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ module UtilMsg {
use Logging;
use Message;
use AryUtil;
use List;
use BigInteger;

use MultiTypeSymEntry;
use MultiTypeSymbolTable;
Expand All @@ -21,127 +23,81 @@ module UtilMsg {
see: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
*/
@arkouda.registerND
proc clipMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();
@arkouda.registerCommand()
proc clip(const ref x: [?d] ?t, min: real, max: real): [] t throws
where (t == int) || (t == real) || (t == uint(8)) || (t == uint(64)) {

const name = msgArgs.getValueOf("name"),
min = msgArgs.get("min"),
max = msgArgs.get("max"),
rname = st.nextName();
var y = makeDistArray(d, t);

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);
const minVal = min: t,
maxVal = max: t;

proc doClip(type t): MsgTuple throws {
const minVal = min.getScalarValue(t),
maxVal = max.getScalarValue(t);

const e = toSymEntry(gEnt, t, nd);
var c = st.addEntry(rname, (...e.tupShape), t);

forall i in e.a.domain {
if e.a[i] < minVal then
c.a[i] = minVal;
else if e.a[i] > maxVal then
c.a[i] = maxVal;
forall i in d {
if x[i] < minVal then
y[i] = minVal;
else if x[i] > maxVal then
y[i] = maxVal;
else
c.a[i] = e.a[i];
y[i] = x[i];
}
return y;
}

const repMsg = "created " + st.attrib(rname);
uLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select gEnt.dtype {
when DType.Int64 do return doClip(int);
when DType.UInt8 do return doClip(uint(8));
when DType.UInt64 do return doClip(uint);
when DType.Float64 do return doClip(real);
when DType.Bool do return doClip(bool);
otherwise {
const errorMsg = notImplementedError(pn,gEnt.dtype);
uLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
// Fallback overload of clip for unsupported element types.
// Its where clause holds only when t is none of int, real, uint(8) or
// uint(64). It never builds a result array: it always throws an Error
// whose message names the offending element type.
proc clip(const ref x: [?d] ?t, min: real, max: real): [d] t throws
where (t != int) && (t != real) && (t != uint(8)) && (t != uint(64)){
throw new Error("clip does not support dtype %s".format(t:string));
}

/*
Compute the n'th order discrete difference along a given axis
see: https://numpy.org/doc/stable/reference/generated/numpy.diff.html
*/
@arkouda.registerND
proc diffMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();

const name = msgArgs.getValueOf("name"),
n = msgArgs.get("n").getIntValue(),
axis = msgArgs.get("axis").getPositiveIntValue(nd),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc doDiff(type t): MsgTuple throws {
const e = toSymEntry(gEnt, t, nd);

if n == 1 {
// 1st order difference: requires no temporary storage
const outDom = subDomain(e.tupShape, axis, 1);
var d = st.addEntry(rname, (...outDom.shape), t);

forall axisSliceIdx in domOffAxis(e.a.domain, axis) {
const slice = domOnAxis(outDom, tuplify(axisSliceIdx), axis);
for i in slice {
var idxp = tuplify(i);
idxp[axis] += 1;
d.a[i] = e.a[idxp] - e.a[i];
}

@arkouda.registerCommand()
proc diff(x: [?d] ?t, n: int, axis: int): [] t throws
where (t == real) || (t == int) || (t == uint(8)) || (t == uint(64)){

const outDom = subDomain(x.shape, axis, n);
if n == 1 {
// 1st order difference: requires no temporary storage
var y = makeDistArray(outDom, t);
for axisSliceIdx in domOffAxis(d, axis) {
const slice = domOnAxis(outDom, tuplify(axisSliceIdx), axis);
for i in slice {
var idxp = tuplify(i);
idxp[axis] += 1;
y[i] = x[idxp] - x[i];
}
} else {
// n'th order difference: requires 2 temporary arrays
var d1 = makeDistArray(e.a);

{
var d2 = makeDistArray(e.a.domain, e.a.eltType);
for m in 1..n {
d1 <=> d2;
const diffSubDom = subDomain(e.tupShape, axis, m);

forall axisSliceIdx in domOffAxis(e.a.domain, axis) {
const slice = domOnAxis(diffSubDom, tuplify(axisSliceIdx), axis);

for i in slice {
var idxp = tuplify(i);
idxp[axis] += 1;
d1[i] = d2[idxp] - d2[i];
}
}
return y;
} else {
// n'th order difference: requires 2 temporary arrays
var d1 = makeDistArray(x);
{
var d2 = makeDistArray(d, t);
for m in 1..n {
d1 <=> d2;
const diffSubDom = subDomain(x.shape, axis, m);

forall axisSliceIdx in domOffAxis(d, axis) {
const slice = domOnAxis(diffSubDom, tuplify(axisSliceIdx), axis);

for i in slice {
var idxp = tuplify(i);
idxp[axis] += 1;
d1[i] = d2[idxp] - d2[i];
}
}
} // d2 deinit here

const outDom = subDomain(e.tupShape, axis, n),
d = createSymEntry(d1[outDom]);
st.addEntry(rname, d);
}

const repMsg = "created " + st.attrib(rname);
uLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}
} // d2 deinit here
return d1[outDom];
}
}

select gEnt.dtype {
when DType.Int64 do return doDiff(int);
when DType.UInt8 do return doDiff(uint(8));
when DType.UInt64 do return doDiff(uint);
when DType.Float64 do return doDiff(real);
otherwise {
const errorMsg = notImplementedError(pn,gEnt.dtype);
uLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
// Fallback overload of diff for unsupported element types.
// Its where clause holds only when t is none of real, int, uint(8) or
// uint(64). It performs no computation: it always throws an Error whose
// message names the offending element type.
proc diff(x: [?d] ?t, n: int, axis: int): [d] t throws
where (t != real) && (t != int) && (t != uint(8)) && (t != uint(64)){
throw new Error("diff does not support dtype %s".format(t:string));
}

// helper to create a domain that's 'n' elements smaller in the 'axis' dimension
Expand All @@ -165,71 +121,52 @@ module UtilMsg {
Implements the 'constant' mode of numpy.pad: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
*/
@arkouda.registerND
proc padMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws {
param pn = Reflection.getRoutineName();

const name = msgArgs.getValueOf("name"),
padWidthBefore = msgArgs.get("padWidthBefore").getTuple(nd),
padWidthAfter = msgArgs.get("padWidthAfter").getTuple(nd),
padValsBefore = msgArgs.get("padValsBefore"),
padValsAfter = msgArgs.get("padValsAfter"),
rname = st.nextName();

var gEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(name, st);

proc doPad(type t): MsgTuple throws {
const e = toSymEntry(gEnt, t, nd);

const pvb = padValsBefore.toScalarArray(t, nd),
pva = padValsAfter.toScalarArray(t, nd);

// compute the padded shape
var outShape: nd*int;
for i in 0..<nd do outShape[i] = padWidthBefore[i] + e.tupShape[i] + padWidthAfter[i];

var p = st.addEntry(rname, (...outShape), t);

// copy the original array into the padded array
const dOffset = e.a.domain.translate(padWidthBefore);
p.a[dOffset] = e.a;

// pad one dimension at a time, starting with dimension 0; each pass writes full-extent slabs, so later dimensions overwrite earlier ones in the corners (i.e., dimension 1 overwrites dimension 0, etc.)
for rank in 0..<nd {
var beforeSlice, afterSlice: nd*range;
for i in 0..<nd {
// TODO: compute the exact slice for each pad-section so these assignments can be done
// in parallel and to avoid accessing the corners of the array unnecessarily (which
// could result in additional comm for large pad widths)
if i == rank {
beforeSlice[i] = 0..<padWidthBefore[i];
afterSlice[i] = (outShape[i]-padWidthAfter[i])..<outShape[i];
} else {
beforeSlice[i] = 0..<outShape[i];
afterSlice[i] = 0..<outShape[i];
}
@arkouda.instantiateAndRegister
proc pad(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype == int) || (array_dtype == uint(8)) || (array_dtype == uint(64)) || (array_dtype == real) || (array_dtype == bool) {

const e = st[msgArgs["name"]]: SymEntry(array_dtype, array_nd),
padWidthBefore = msgArgs["padWidthBefore"].toScalarTuple(int, array_nd),
padWidthAfter = msgArgs["padWidthAfter"].toScalarTuple(int, array_nd),
padValsBefore = msgArgs["padValsBefore"].toScalarArray(array_dtype, array_nd),
padValsAfter = msgArgs["padValsAfter"].toScalarArray(array_dtype, array_nd);

// compute the padded shape
var outShape: array_nd*int;
for i in 0..<array_nd do outShape[i] = padWidthBefore[i] + e.tupShape[i] + padWidthAfter[i];

var paddedArray = makeDistArray((...outShape), array_dtype);

// copy the original array into the padded array
const dOffset = e.a.domain.translate(padWidthBefore);
paddedArray[dOffset] = e.a;

// pad one dimension at a time, starting with dimension 0; each pass writes full-extent slabs, so later dimensions overwrite earlier ones in the corners (i.e., dimension 1 overwrites dimension 0, etc.)
for rank in 0..<array_nd {
var beforeSlice, afterSlice: array_nd*range;
for i in 0..<array_nd {
// TODO: compute the exact slice for each pad-section so these assignments can be done
// in parallel and to avoid accessing the corners of the array unnecessarily (which
// could result in additional comm for large pad widths)
if i == rank {
beforeSlice[i] = 0..<padWidthBefore[i];
afterSlice[i] = (outShape[i]-padWidthAfter[i])..<outShape[i];
} else {
beforeSlice[i] = 0..<outShape[i];
afterSlice[i] = 0..<outShape[i];
}

p.a[(...beforeSlice)] = pvb[rank];
p.a[(...afterSlice)] = pva[rank];
}

const repMsg = "created " + st.attrib(rname);
uLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
paddedArray[(...beforeSlice)] = padValsBefore[rank];
paddedArray[(...afterSlice)] = padValsAfter[rank];
}

select gEnt.dtype {
when DType.Int64 do return doPad(int);
when DType.UInt8 do return doPad(uint(8));
when DType.UInt64 do return doPad(uint);
when DType.Float64 do return doPad(real);
when DType.Bool do return doPad(bool);
otherwise {
const errorMsg = notImplementedError(pn,gEnt.dtype);
uLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
}
return st.insert(new shared SymEntry(paddedArray));
}

// Fallback overload of pad for unsupported element types.
// Its where clause holds only when array_dtype is none of int, uint(8),
// uint(64), real or bool. It does not read msgArgs or touch the symbol
// table: it always throws an Error whose message names the offending
// element type.
proc pad(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where (array_dtype != int) && (array_dtype != uint(8)) && (array_dtype != uint(64)) && (array_dtype != real) && (array_dtype != bool) {
throw new Error("pad does not support dtype %s".format(array_dtype:string));
}

}
Loading

0 comments on commit db1d851

Please sign in to comment.