Skip to content

Commit

Permalink
Merge pull request #82 from abouteiller/update/new-recursive
Browse files Browse the repository at this point in the history
Update for the new Parsec recursive API
  • Loading branch information
abouteiller authored Sep 12, 2023
2 parents d165b7e + 39da035 commit 749c912
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 27 deletions.
31 changes: 27 additions & 4 deletions src/zgeqrf.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ extern "C" %{

#include "parsec/data_dist/matrix/subtile.h"
#include "parsec/recursive.h"
static void zgeqrt_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void zunmqr_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void ztsqrt_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void ztsmqr_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);

#if defined(DPLASMA_HAVE_CUDA)
#include "cores/dplasma_zcores.h"
Expand Down Expand Up @@ -150,7 +154,7 @@ BODY [type=RECURSIVE]

/* recursive call */
parsec_recursivecall((parsec_task_t*)this_task,
parsec_zgeqrt, dplasma_zgeqrfr_geqrt_Destruct,
parsec_zgeqrt, zgeqrt_recursive_cb,
2, small_descA, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -246,7 +250,7 @@ BODY [type=RECURSIVE]

/* recursive call */
parsec_recursivecall((parsec_task_t*)this_task,
parsec_zunmqr_panel, dplasma_zgeqrfr_unmqr_Destruct,
parsec_zunmqr_panel, zunmqr_recursive_cb,
3, small_descA, small_descC, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -374,7 +378,7 @@ BODY [type=RECURSIVE]

/* recursive call */
parsec_recursivecall((parsec_task_t*)this_task,
parsec_ztsqrt, dplasma_zgeqrfr_tsqrt_Destruct,
parsec_ztsqrt, ztsqrt_recursive_cb,
3, small_descA1, small_descA2, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -542,7 +546,7 @@ BODY [type=RECURSIVE]

/* recursive call */
parsec_recursivecall((parsec_task_t*)this_task,
parsec_ztsmqr, dplasma_zgeqrfr_tsmqr_Destruct,
parsec_ztsmqr, ztsmqr_recursive_cb,
4, small_descA1, small_descA2, small_descV, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -580,3 +584,22 @@ BODY

}
END

extern "C" %{
static void zgeqrt_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgeqrfr_geqrt_Destruct(tp);
}
static void zunmqr_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgeqrfr_unmqr_Destruct(tp);
}
static void ztsqrt_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgeqrfr_tsqrt_Destruct(tp);
}
static void ztsmqr_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgeqrfr_tsmqr_Destruct(tp);
}
%}
39 changes: 27 additions & 12 deletions src/zpotrf_L.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ extern "C" %{

#include "parsec/data_dist/matrix/subtile.h"
#include "parsec/recursive.h"
static void zpotrf_L_update_INFO(parsec_taskpool_t* _tp, const parsec_recursive_callback_t* data);
static void zpotrf_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void zgemm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void zherk_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void ztrsm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);

/* Define the different shapes this JDF is using */
#define DEFAULT 0
Expand Down Expand Up @@ -123,10 +126,10 @@ BODY [type=RECURSIVE]
smallnb, smallnb, 0, 0, tempkm, tempkm );
small_descT->mat = T;

parsec_zpotrf = dplasma_zpotrf_New(uplo, (parsec_tiled_matrix_t *)small_descT, &info );
parsec_zpotrf = dplasma_zpotrf_New(uplo, (parsec_tiled_matrix_t *)small_descT, (int*)&info );

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zpotrf, zpotrf_L_update_INFO,
parsec_zpotrf, zpotrf_recursive_cb,
1, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -267,7 +270,7 @@ BODY [type=RECURSIVE]
(parsec_tiled_matrix_t *)small_descC );

parsec_recursivecall((parsec_task_t*)this_task,
parsec_ztrsm, dplasma_ztrsm_Destruct,
parsec_ztrsm, ztrsm_recursive_cb,
2, small_descT, small_descC );

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -371,7 +374,7 @@ BODY [type=RECURSIVE]
(double)1.0, (parsec_tiled_matrix_t*) small_descT);

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zherk, dplasma_zherk_Destruct,
parsec_zherk, zherk_recursive_cb,
2, small_descA, small_descT);
return PARSEC_HOOK_RETURN_ASYNC;
}
Expand Down Expand Up @@ -451,7 +454,7 @@ BODY [type=CUDA

cublasStatus_t status;
dplasma_cuda_handles_t *handles;

assert( ldam_A <= descA->mb );
assert( ldan_B <= descA->mb );
assert( ldam_C <= descA->mb );
Expand Down Expand Up @@ -502,7 +505,7 @@ BODY [type=RECURSIVE]
(parsec_tiled_matrix_t *)small_descC);

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zgemm, dplasma_zgemm_Destruct,
parsec_zgemm, zgemm_recursive_cb,
3, small_descA, small_descB, small_descC );

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -566,16 +569,28 @@ static int64_t zgemm_time_estimate(const parsec_task_t *task, parsec_device_modu
* As we are handling the diagonal tiles recursively, we have to scale the INFO in case of errors
* to reflect the position in the global matrix and not on the current tile.
*/
static void zpotrf_L_update_INFO(parsec_taskpool_t* _tp, const parsec_recursive_callback_t* data)
static void zpotrf_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data)
{
parsec_zpotrf_L_taskpool_t* tp = (parsec_zpotrf_L_taskpool_t*)_tp;
parsec_zpotrf_L_taskpool_t* tppo = (parsec_zpotrf_L_taskpool_t*)tp;
__parsec_zpotrf_L_potrf_zpotrf_task_t* task = (__parsec_zpotrf_L_potrf_zpotrf_task_t*)data->task;

if( (0 < task->locals.info.value) && (0 == *tp->_g_INFO) ) {
if( (0 < task->locals.info.value) && (0 == *tppo->_g_INFO) ) {
/* we need to scale the INFO according to the parent taskpool tile size */
*tp->_g_INFO = task->locals.info.value + task->locals.k.value * ((parsec_zpotrf_L_taskpool_t*)task->taskpool)->_g_descA->nb;
*tppo->_g_INFO = task->locals.info.value + task->locals.k.value * ((parsec_zpotrf_L_taskpool_t*)task->taskpool)->_g_descA->nb;
}
dplasma_zpotrf_Destruct(_tp);
dplasma_zpotrf_Destruct(tp);
}
static void zgemm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgemm_Destruct(tp);
}
static void zherk_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zherk_Destruct(tp);
}
static void ztrsm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_ztrsm_Destruct(tp);
}

%}
Expand Down
37 changes: 26 additions & 11 deletions src/zpotrf_U.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ extern "C" %{

#include "parsec/data_dist/matrix/subtile.h"
#include "parsec/recursive.h"
static void zpotrf_U_update_INFO(parsec_taskpool_t* _tp, const parsec_recursive_callback_t* data);
static void zpotrf_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void zgemm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void zherk_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);
static void ztrsm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data);

/* Define the different shapes this JDF is using */
#define DEFAULT 0
Expand Down Expand Up @@ -122,10 +125,10 @@ BODY [type=RECURSIVE]
smallnb, smallnb, 0, 0, tempkn, tempkn );
small_descT->mat = T;

parsec_zpotrf = dplasma_zpotrf_New(uplo, (parsec_tiled_matrix_t *)small_descT, &info );
parsec_zpotrf = dplasma_zpotrf_New(uplo, (parsec_tiled_matrix_t *)small_descT, (int*)&info );

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zpotrf, zpotrf_U_update_INFO,
parsec_zpotrf, zpotrf_recursive_cb,
1, small_descT);

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -268,7 +271,7 @@ BODY [type=RECURSIVE]
(parsec_tiled_matrix_t *)small_descC );

parsec_recursivecall((parsec_task_t*)this_task,
parsec_ztrsm, dplasma_ztrsm_Destruct,
parsec_ztrsm, ztrsm_recursive_cb,
2, small_descT, small_descC );

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -374,7 +377,7 @@ BODY [type=RECURSIVE]
(double)1.0, (parsec_tiled_matrix_t*) small_descT);

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zherk, dplasma_zherk_Destruct,
parsec_zherk, zherk_recursive_cb,
2, small_descA, small_descT);
return PARSEC_HOOK_RETURN_ASYNC;
}
Expand Down Expand Up @@ -514,7 +517,7 @@ BODY [type=RECURSIVE]
(parsec_tiled_matrix_t *)small_descC);

parsec_recursivecall((parsec_task_t*)this_task,
parsec_zgemm, dplasma_zgemm_Destruct,
parsec_zgemm, zgemm_recursive_cb,
3, small_descA, small_descB, small_descC );

return PARSEC_HOOK_RETURN_ASYNC;
Expand Down Expand Up @@ -578,16 +581,28 @@ static int64_t zgemm_time_estimate(const parsec_task_t *task, parsec_device_modu
* As we are handling the diagonal tiles recursively, we have to scale the INFO in case of errors
* to reflect the position in the global matrix and not on the current tile.
*/
static void zpotrf_U_update_INFO(parsec_taskpool_t* _tp, const parsec_recursive_callback_t* data)
static void zpotrf_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data)
{
parsec_zpotrf_U_taskpool_t* tp = (parsec_zpotrf_U_taskpool_t*)_tp;
parsec_zpotrf_U_taskpool_t* tppo = (parsec_zpotrf_U_taskpool_t*)tp;
__parsec_zpotrf_U_potrf_zpotrf_task_t* task = (__parsec_zpotrf_U_potrf_zpotrf_task_t*)data->task;

if( (0 < task->locals.info.value) && (0 == *tp->_g_INFO) ) {
if( (0 < task->locals.info.value) && (0 == *tppo->_g_INFO) ) {
/* we need to scale the INFO according to the parent taskpool tile size */
*tp->_g_INFO = task->locals.info.value + task->locals.k.value * ((parsec_zpotrf_U_taskpool_t*)task->taskpool)->_g_descA->nb;
*tppo->_g_INFO = task->locals.info.value + task->locals.k.value * ((parsec_zpotrf_U_taskpool_t*)task->taskpool)->_g_descA->nb;
}
dplasma_zpotrf_Destruct(_tp);
dplasma_zpotrf_Destruct(tp);
}
static void zgemm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zgemm_Destruct(tp);
}
static void zherk_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_zherk_Destruct(tp);
}
static void ztrsm_recursive_cb(parsec_taskpool_t* tp, const parsec_recursive_callback_t* data) {
(void)data;
dplasma_ztrsm_Destruct(tp);
}

%}
Expand Down

0 comments on commit 749c912

Please sign in to comment.