Skip to content

Commit

Permalink
Merge pull request #122 from abouteiller/feature/hip-trsm
Browse files Browse the repository at this point in the history
Add HIP TRSM
  • Loading branch information
abouteiller authored Sep 6, 2024
2 parents 17c6c95 + 9aff551 commit 599a680
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 1 deletion.
32 changes: 32 additions & 0 deletions src/ztrsm_LLN.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,n) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, k), CLEAN_NB(descB, n)); %}]
/* Execution space */
Expand Down Expand Up @@ -123,6 +124,37 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempmm = ((m) == (descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempnn = ((n) == (descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int lda = BLKLDD( descA, m );
int ldbk = BLKLDD( descB, k );
int ldb = BLKLDD( descB, m );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
tempmm, tempnn, descB->mb,
&mzone, (hipblasDoubleComplex*)C, lda,
(hipblasDoubleComplex*)D, ldbk,
&lalpha, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
Expand Down
32 changes: 32 additions & 0 deletions src/ztrsm_LLT.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,n) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, k), CLEAN_NB(descB, n)); %}]
/* Execution space */
Expand Down Expand Up @@ -124,6 +125,37 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int tempkm = ((k)==(0)) ? (descB->m-((descB->mt-1)*descB->mb)) : descB->mb;
int lda = BLKLDD( descA, (descB->mt-1)-k );
int ldb = BLKLDD( descB, (descB->mt-1)-k );
int ldbm = BLKLDD( descB, (descB->mt-1)-m );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N,
descB->mb, tempnn, tempkm,
&mzone, (hipblasDoubleComplex*)C, lda,
(hipblasDoubleComplex*)D, ldb,
&lalpha, (hipblasDoubleComplex*)E, ldbm );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
Expand Down
32 changes: 32 additions & 0 deletions src/ztrsm_LUN.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,n) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, k), CLEAN_NB(descB, n)); %}]
/* Execution Space */
Expand Down Expand Up @@ -123,6 +124,37 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int tempkm = ((k)==(0)) ? (descB->m-((descB->mt-1)*descB->mb)) : descB->mb;
int ldam = BLKLDD( descB, (descA.mt-1)-m );
int ldbm = BLKLDD( descB, (descB->mt-1)-m );
int ldb = BLKLDD( descB, (descB->mt-1)-k );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
descB->mb, tempnn, tempkm,
&mzone, (hipblasDoubleComplex*)C, ldam,
(hipblasDoubleComplex*)D, ldb,
&lalpha, (hipblasDoubleComplex*)E, ldbm );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
Expand Down
32 changes: 32 additions & 0 deletions src/ztrsm_LUT.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,n) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, k), CLEAN_NB(descB, n)); %}]
/* Execution space */
Expand Down Expand Up @@ -123,6 +124,37 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempmm = ((m) == (descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempnn = ((n) == (descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int ldak = BLKLDD( descA, k );
int ldbk = BLKLDD( descB, k );
int ldb = BLKLDD( descB, m );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N,
tempmm, tempnn, descB->mb,
&mzone, (hipblasDoubleComplex*)C, ldak,
(hipblasDoubleComplex*)D, ldbk,
&lalpha, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
Expand Down
31 changes: 31 additions & 0 deletions src/ztrsm_RLN.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,m) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, m), CLEAN_NB(descB, k)); %}]
/* Execution space */
Expand Down Expand Up @@ -122,6 +123,36 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempkn = ((k)==(0)) ? (descB->n-((descB->nt-1)*descB->nb)) : descB->nb;
int ldb = BLKLDD( descB, m );
int lda = BLKLDD( descA, (descB->nt-1)-k );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
tempmm, descB->nb, tempkn,
&mzone, (hipblasDoubleComplex*)C, ldb,
(hipblasDoubleComplex*)D, lda,
&lalpha, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
Expand Down
29 changes: 29 additions & 0 deletions src/ztrsm_RLT.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,m) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, m), CLEAN_NB(descB, k)); %}]
/* Execution space */
Expand Down Expand Up @@ -121,6 +122,34 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex zone = { 1., 0.};
hipDoubleComplex cdiv = hipCdiv(make_hipDoubleComplex(-1., 0.), make_hipDoubleComplex(creal(alpha), cimag(alpha)));
hipblasDoubleComplex minvalpha = { hipCreal(cdiv), hipCimag(cdiv) };
#else
dplasma_complex64_t zone = 1.;
dplasma_complex64_t minvalpha = ((dplasma_complex64_t)1.0)/alpha;
#endif

int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int ldb = BLKLDD( descB, m );
int ldan = BLKLDD( descA, n );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, dplasma_hipblas_op(trans),
tempmm, tempnn, descB->mb,
&minvalpha, (hipblasDoubleComplex*)C, ldb,
(hipblasDoubleComplex*)D, ldan,
&zone, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
Expand Down
32 changes: 31 additions & 1 deletion src/ztrsm_RUN.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,m) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, m), CLEAN_NB(descB, k)); %}]
/* Execution space */
Expand Down Expand Up @@ -109,7 +110,6 @@ BODY [type=CUDA]
int ldb = BLKLDD( descB, m );
int lda = BLKLDD( descA, k );


cublasStatus_t status;

cublasSetKernelStream( parsec_body.stream );
Expand All @@ -123,6 +123,36 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex mzone = {-1., 0.};
hipblasDoubleComplex lalpha = {1., 0.};
if(k == 0) {
lalpha.x = creal(alpha); lalpha.y = cimag(alpha);
}
#else
dplasma_complex64_t mzone = -1.;
dplasma_complex64_t lalpha = ((k)==(0)) ? (alpha) : (dplasma_complex64_t)(1.0);
#endif

int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb;
int ldb = BLKLDD( descB, m );
int lda = BLKLDD( descA, k );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N,
tempmm, tempnn, descB->mb,
&mzone, (hipblasDoubleComplex*)C, ldb,
(hipblasDoubleComplex*)D, lda,
&lalpha, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
Expand Down
29 changes: 29 additions & 0 deletions src/ztrsm_RUT.jdf
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ alpha [type = "dplasma_complex64_t"]
descA [type = "const parsec_tiled_matrix_t*"]
descB [type = "parsec_tiled_matrix_t*"]

hip_handles_infokey [type = "int" hidden = on default = "parsec_info_lookup(&parsec_per_stream_infos, \"DPLASMA::HIP::HANDLES\", NULL)" ]

ztrsm(k,m) [ flops = inline_c %{ return FLOPS_ZTRSM(side, CLEAN_MB(descB, m), CLEAN_NB(descB, k)); %}]
/* Execution space */
Expand Down Expand Up @@ -121,6 +122,34 @@ BODY [type=CUDA]
}
END

BODY [type=HIP]
{
#if defined(PRECISION_z) || defined(PRECISION_c)
hipblasDoubleComplex zone = { 1., 0.};
hipDoubleComplex cdiv = hipCdiv(make_hipDoubleComplex(-1., 0.), make_hipDoubleComplex(creal(alpha), cimag(alpha)));
hipblasDoubleComplex minvalpha = { hipCreal(cdiv), hipCimag(cdiv) };
#else
dplasma_complex64_t zone = 1.;
dplasma_complex64_t minvalpha = ((dplasma_complex64_t)1.0)/alpha;
#endif

int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
int tempkn = ((k)==(0)) ? (descB->n-((descB->nt-1)*descB->nb)) : descB->nb;
int ldan = BLKLDD( descA, (descB->nt-1)-n );
int ldb = BLKLDD( descB, m );

hipblasStatus_t status;
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey);
assert(NULL != handles);
status = hipblasZgemm( handles->hipblas_handle, HIPBLAS_OP_N, dplasma_hipblas_op(trans),
tempmm, descB->nb, tempkn,
&minvalpha, (hipblasDoubleComplex*)C, ldb,
(hipblasDoubleComplex*)D, ldan,
&zone, (hipblasDoubleComplex*)E, ldb );
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} );
}
END

BODY
{
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;
Expand Down
2 changes: 2 additions & 0 deletions tools/PrecisionGenerator/subs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@
('float', 'double', 'cuCdivf', 'cuCdiv' ),
('float', 'double', 'hipCdivf', 'hipCdiv' ),
('', '', 'crealf', 'creal' ),
('', '', 'hipCrealf', 'hipCreal' ),
('', '', 'cimagf', 'cimag' ),
('', '', 'hipCimagf', 'hipCimag' ),
('', '', 'conjf', 'conj' ),
('', '', 'cuCfmaf', 'cuCfma' ),
('', '', 'hipCfmaf', 'hipCfma' ),
Expand Down

0 comments on commit 599a680

Please sign in to comment.