Skip to content

Commit

Permalink
Closes #3656: Refactor LinalgMsg.transpose to use registerCommand (#3693
Browse files Browse the repository at this point in the history
)

* Closes #3656: Refactor LinalgMsg.transpose to use registerCommand

* remove doTranspose

* fix bug

---------

Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts authored Aug 28, 2024
1 parent 78633dd commit c36172a
Showing 1 changed file with 17 additions and 36 deletions.
53 changes: 17 additions & 36 deletions src/LinalgMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -285,45 +285,26 @@ module LinalgMsg {

// Transpose an array.

@arkouda.instantiateAndRegister
proc transpose(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd >= 2 {

const name = msgArgs.getValueOf("array");

linalgLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"cmd: %s dtype: %s".format(cmd,type2str(array_dtype)));

// get the input array, copy its shape

var eIn = st[name]: borrowed SymEntry(array_dtype, array_nd),
outShape = eIn.tupShape;

// switch the indices of the output shape

@arkouda.registerCommand
proc transpose(array: [?d] ?t): [] t throws
where d.rank >= 2 {
var outShape = array.shape;
outShape[outShape.size-2] <=> outShape[outShape.size-1];

// create the output array

var eOut = createSymEntry((...outShape), array_dtype);
doTranspose(eIn.a, eOut.a); // do the transpose

return st.insert(eOut);
}

proc transpose(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd < 2 {
return MsgTuple.error("Matrix transpose with arrays of dimension < 2 is not supported");
}

// TODO: performance improvements. Should use tiling to keep data local

proc doTranspose(ref A: [?D], ref B) {
forall idx in D {
var ret = makeDistArray((...outShape), t);

// // TODO: performance improvements. Should use tiling to keep data local
forall idx in d {
var bIdx = idx;
bIdx[D.rank-1] <=> bIdx[D.rank-2]; // bIdx is now the reverse of idx
B[bIdx] = A[idx]; // making B the transpose of A
bIdx[d.rank-1] <=> bIdx[d.rank-2]; // bIdx is now the reverse of idx
ret[bIdx] = array[idx]; // making B the transpose of A
}

return ret;
}

proc transpose(array: [?d] ?t): [d] t throws
where d.rank < 2 {
throw new Error("Matrix transpose with arrays of dimension < 2 is not supported");
}

/*
Expand Down

0 comments on commit c36172a

Please sign in to comment.