Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The macros to convert trans/notrans, etc. were not correct for inline use #121

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 25 additions & 36 deletions src/dplasmaaux_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,32 @@
*/
/* Include the cuBLAS v2 API header unless CUBLAS_H_ is already defined. */
#if !defined(CUBLAS_H_)
#include <cublas_v2.h>
#endif /* !defined(CUBLAS_H_) */

/*
 * Statement-style conversion macro: asserts that `side` is dplasmaRight or
 * dplasmaLeft, then overwrites `side` in place with CUBLAS_SIDE_RIGHT or
 * CUBLAS_SIDE_LEFT.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `side` must be a modifiable lvalue, and `side` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_cublas_side(side) \
assert( (side == dplasmaRight) || (side == dplasmaLeft) ); \
side = (side == dplasmaRight) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;


/*
 * Statement-style conversion macro: asserts that `diag` is dplasmaNonUnit or
 * dplasmaUnit, then overwrites `diag` in place with CUBLAS_DIAG_NON_UNIT or
 * CUBLAS_DIAG_UNIT.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `diag` must be a modifiable lvalue, and `diag` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_cublas_diag(diag) \
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) ); \
diag = (diag == dplasmaNonUnit) ? CUBLAS_DIAG_NON_UNIT : CUBLAS_DIAG_UNIT;

/*
 * Statement-style conversion macro: asserts that `fill` is dplasmaLower or
 * dplasmaUpper, then overwrites `fill` in place with CUBLAS_FILL_MODE_LOWER or
 * CUBLAS_FILL_MODE_UPPER.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `fill` must be a modifiable lvalue, and `fill` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_cublas_fill(fill) \
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) ); \
fill = (fill == dplasmaLower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

/*
 * Variant of the transposition macro used when PRECISION_z or PRECISION_c is
 * defined: asserts that `trans` is dplasmaNoTrans, dplasmaTrans or
 * dplasmaConjTrans, then overwrites `trans` in place with CUBLAS_OP_N,
 * CUBLAS_OP_T or CUBLAS_OP_C. Any other value falls into `default` and
 * becomes CUBLAS_OP_N.
 *
 * The expansion is an assert statement followed by a switch statement, so it
 * cannot be used as an expression (e.g. as a function-call argument), and
 * `trans` must be a modifiable lvalue.
 *
 * NOTE(review): the matching #else/#endif for the #if below is not inside this
 * span, and a static inline function with the same name appears further down;
 * this looks like the removed side of the diff -- confirm.
 */
#if defined(PRECISION_z) || defined(PRECISION_c)
#define dplasma_cublas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) ); \
switch(trans){ \
case dplasmaNoTrans: \
trans = CUBLAS_OP_N; \
break; \
case dplasmaTrans: \
trans = CUBLAS_OP_T; \
break; \
case dplasmaConjTrans: \
trans = CUBLAS_OP_C; \
break; \
default: \
trans = CUBLAS_OP_N; \
break; \
}
#include "dplasma/constants.h"

/**
 * Convert a dplasma side constant into the corresponding cuBLAS side mode.
 *
 * @param side dplasmaRight or dplasmaLeft (checked with assert).
 * @return CUBLAS_SIDE_RIGHT when side is dplasmaRight, CUBLAS_SIDE_LEFT
 *         otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a cuBLAS call and never modifies the caller's variable.
 */
static inline cublasSideMode_t dplasma_cublas_side(int side)
{
    assert( (side == dplasmaRight) || (side == dplasmaLeft) );
    if( side == dplasmaRight ) {
        return CUBLAS_SIDE_RIGHT;
    }
    /* Every value other than dplasmaRight maps to the left side. */
    return CUBLAS_SIDE_LEFT;
}

/**
 * Convert a dplasma diagonal constant into the corresponding cuBLAS diag type.
 *
 * @param diag dplasmaNonUnit or dplasmaUnit (checked with assert).
 * @return CUBLAS_DIAG_NON_UNIT when diag is dplasmaNonUnit, CUBLAS_DIAG_UNIT
 *         otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a cuBLAS call and never modifies the caller's variable.
 */
static inline cublasDiagType_t dplasma_cublas_diag(int diag)
{
    assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) );
    if( diag == dplasmaNonUnit ) {
        return CUBLAS_DIAG_NON_UNIT;
    }
    /* Every value other than dplasmaNonUnit maps to the unit diagonal. */
    return CUBLAS_DIAG_UNIT;
}

/**
 * Convert a dplasma uplo constant into the corresponding cuBLAS fill mode.
 *
 * @param fill dplasmaLower or dplasmaUpper (checked with assert).
 * @return CUBLAS_FILL_MODE_LOWER when fill is dplasmaLower,
 *         CUBLAS_FILL_MODE_UPPER otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a cuBLAS call and never modifies the caller's variable.
 */
static inline cublasFillMode_t dplasma_cublas_fill(int fill)
{
    assert( (fill == dplasmaLower) || (fill == dplasmaUpper) );
    if( fill == dplasmaLower ) {
        return CUBLAS_FILL_MODE_LOWER;
    }
    /* Every value other than dplasmaLower maps to the upper triangle. */
    return CUBLAS_FILL_MODE_UPPER;
}

/*
 * Convert a dplasma transposition constant into a cublasOperation_t.
 *
 * The return statement yields CUBLAS_OP_C for dplasmaConjTrans, CUBLAS_OP_T
 * for dplasmaTrans and CUBLAS_OP_N for every other value. The assert that
 * guards the input depends on the precision: with PRECISION_d or PRECISION_s
 * only dplasmaNoTrans and dplasmaTrans are accepted, otherwise
 * dplasmaConjTrans is accepted as well.
 *
 * NOTE(review): the `#define dplasma_cublas_op(trans)` lines and the first
 * #endif inside this body look like removed lines of the diff interleaved
 * with the new function; as shown here the body has one #if but two #endif.
 * Confirm against the merged file before relying on this text.
 */
static inline cublasOperation_t dplasma_cublas_op(int trans) {
#if defined(PRECISION_d) || defined(PRECISION_s)
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) );
#else
#define dplasma_cublas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) ); \
trans = (trans == dplasmaNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
#endif /* PRECISION_z || PRECISION_c */
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) );
#endif /* PRECISION_d || PRECISION_s */
return (trans == dplasmaConjTrans) ? CUBLAS_OP_C: ((trans == dplasmaTrans) ? CUBLAS_OP_T : CUBLAS_OP_N);
}
#endif /* !defined(CUBLAS_H_) */

extern parsec_info_id_t dplasma_dtd_cuda_infoid;
extern parsec_info_id_t dplasma_dtd_cuda_workspace_infoid;
Expand Down
50 changes: 20 additions & 30 deletions src/dplasmaaux_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,31 @@
#include <hipsolver/hipsolver.h>
#include <rocsolver/rocsolver.h>

/*
 * Statement-style conversion macro: asserts that `side` is dplasmaRight or
 * dplasmaLeft, then overwrites `side` in place with HIPBLAS_SIDE_RIGHT or
 * HIPBLAS_SIDE_LEFT.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `side` must be a modifiable lvalue, and `side` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_hipblas_side(side) \
assert( (side == dplasmaRight) || (side == dplasmaLeft) ); \
side = (side == dplasmaRight) ? HIPBLAS_SIDE_RIGHT : HIPBLAS_SIDE_LEFT;
#include "dplasma/constants.h"

/**
 * Convert a dplasma side constant into the corresponding hipBLAS side mode.
 *
 * @param side dplasmaRight or dplasmaLeft (checked with assert).
 * @return HIPBLAS_SIDE_RIGHT when side is dplasmaRight, HIPBLAS_SIDE_LEFT
 *         otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a hipBLAS call and never modifies the caller's variable.
 */
static inline hipblasSideMode_t dplasma_hipblas_side(int side)
{
    assert( (side == dplasmaRight) || (side == dplasmaLeft) );
    if( side == dplasmaRight ) {
        return HIPBLAS_SIDE_RIGHT;
    }
    /* Every value other than dplasmaRight maps to the left side. */
    return HIPBLAS_SIDE_LEFT;
}

/*
 * Statement-style conversion macro: asserts that `diag` is dplasmaNonUnit or
 * dplasmaUnit, then overwrites `diag` in place with HIPBLAS_DIAG_NON_UNIT or
 * HIPBLAS_DIAG_UNIT.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `diag` must be a modifiable lvalue, and `diag` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_hipblas_diag(diag) \
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) ); \
diag = (diag == dplasmaNonUnit) ? HIPBLAS_DIAG_NON_UNIT : HIPBLAS_DIAG_UNIT;
/**
 * Convert a dplasma diagonal constant into the corresponding hipBLAS diag type.
 *
 * @param diag dplasmaNonUnit or dplasmaUnit (checked with assert).
 * @return HIPBLAS_DIAG_NON_UNIT when diag is dplasmaNonUnit, HIPBLAS_DIAG_UNIT
 *         otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a hipBLAS call and never modifies the caller's variable.
 */
static inline hipblasDiagType_t dplasma_hipblas_diag(int diag)
{
    assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) );
    if( diag == dplasmaNonUnit ) {
        return HIPBLAS_DIAG_NON_UNIT;
    }
    /* Every value other than dplasmaNonUnit maps to the unit diagonal. */
    return HIPBLAS_DIAG_UNIT;
}

/*
 * Statement-style conversion macro: asserts that `fill` is dplasmaLower or
 * dplasmaUpper, then overwrites `fill` in place with HIPBLAS_FILL_MODE_LOWER
 * or HIPBLAS_FILL_MODE_UPPER.
 *
 * The expansion is two statements, not an expression: it cannot be used as a
 * function-call argument, `fill` must be a modifiable lvalue, and `fill` is
 * evaluated several times.
 *
 * NOTE(review): a static inline function with the same name appears further
 * down; this macro looks like the removed side of the diff -- confirm.
 */
#define dplasma_hipblas_fill(fill) \
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) ); \
fill = (fill == dplasmaLower) ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
/**
 * Convert a dplasma uplo constant into the corresponding hipBLAS fill mode.
 *
 * @param fill dplasmaLower or dplasmaUpper (checked with assert).
 * @return HIPBLAS_FILL_MODE_LOWER when fill is dplasmaLower,
 *         HIPBLAS_FILL_MODE_UPPER otherwise.
 *
 * Being a function that returns the converted value, it can be used directly
 * as an argument of a hipBLAS call and never modifies the caller's variable.
 */
static inline hipblasFillMode_t dplasma_hipblas_fill(int fill)
{
    assert( (fill == dplasmaLower) || (fill == dplasmaUpper) );
    if( fill == dplasmaLower ) {
        return HIPBLAS_FILL_MODE_LOWER;
    }
    /* Every value other than dplasmaLower maps to the upper triangle. */
    return HIPBLAS_FILL_MODE_UPPER;
}

/*
 * Variant of the transposition macro used when PRECISION_z or PRECISION_c is
 * defined: asserts that `trans` is dplasmaNoTrans, dplasmaTrans or
 * dplasmaConjTrans, then overwrites `trans` in place with HIPBLAS_OP_N,
 * HIPBLAS_OP_T or HIPBLAS_OP_C. Any other value falls into `default` and
 * becomes HIPBLAS_OP_N.
 *
 * The expansion is an assert statement followed by a switch statement, so it
 * cannot be used as an expression (e.g. as a function-call argument), and
 * `trans` must be a modifiable lvalue.
 *
 * NOTE(review): the matching #else/#endif for the #if below is not inside this
 * span, and a static inline function with the same name appears further down;
 * this looks like the removed side of the diff -- confirm.
 */
#if defined(PRECISION_z) || defined(PRECISION_c)
#define dplasma_hipblas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) ); \
switch(trans){ \
case dplasmaNoTrans: \
trans = HIPBLAS_OP_N; \
break; \
case dplasmaTrans: \
trans = HIPBLAS_OP_T; \
break; \
case dplasmaConjTrans: \
trans = HIPBLAS_OP_C; \
break; \
default: \
trans = HIPBLAS_OP_N; \
break; \
}
/*
 * Convert a dplasma transposition constant into a hipblasOperation_t.
 *
 * The return statement yields HIPBLAS_OP_C for dplasmaConjTrans, HIPBLAS_OP_T
 * for dplasmaTrans and HIPBLAS_OP_N for every other value. The assert that
 * guards the input depends on the precision: with PRECISION_d or PRECISION_s
 * only dplasmaNoTrans and dplasmaTrans are accepted, otherwise
 * dplasmaConjTrans is accepted as well.
 *
 * NOTE(review): the `#define dplasma_hipblas_op(trans)` lines and the first
 * #endif inside this body look like removed lines of the diff interleaved
 * with the new function; as shown here the body has one #if but two #endif.
 * Confirm against the merged file before relying on this text.
 */
static inline hipblasOperation_t dplasma_hipblas_op(int trans) {
#if defined(PRECISION_d) || defined(PRECISION_s)
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) );
#else
#define dplasma_hipblas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) ); \
trans = (trans == dplasmaNoTrans) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
#endif /* PRECISION_z || PRECISION_c */
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) );
#endif /* PRECISION_d || PRECISION_s */
return (trans == dplasmaConjTrans) ? HIPBLAS_OP_C: ((trans == dplasmaTrans) ? HIPBLAS_OP_T : HIPBLAS_OP_N);
}

extern parsec_info_id_t dplasma_dtd_hip_infoid;

Expand Down
5 changes: 1 addition & 4 deletions src/dtd_wrappers/zgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ parsec_core_zgemm_cuda(parsec_device_gpu_module_t* gpu_device,
double betag = beta;
#endif

dplasma_cublas_op(transA);
dplasma_cublas_op(transB);

#if defined(PARSEC_DEBUG_NOISIER)
{
char tmp[MAX_TASK_STRLEN];
Expand All @@ -80,7 +77,7 @@ parsec_core_zgemm_cuda(parsec_device_gpu_module_t* gpu_device,

parsec_cuda_exec_stream_t* cuda_stream = (parsec_cuda_exec_stream_t*)gpu_stream;
cublasSetStream( handles->cublas_handle, cuda_stream->cuda_stream );
status = cublasZgemm(handles->cublas_handle, transA, transB,
status = cublasZgemm(handles->cublas_handle, dplasma_cublas_op(transA), dplasma_cublas_op(transB),
n, m, k,
&alphag, (cuDoubleComplex*)Ag, lda,
(cuDoubleComplex*)Bg, ldb,
Expand Down
5 changes: 1 addition & 4 deletions src/dtd_wrappers/zherk.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ parsec_core_zherk_cuda(parsec_device_gpu_module_t* gpu_device,
Ag = parsec_dtd_get_dev_ptr(this_task, 0);
Cg = parsec_dtd_get_dev_ptr(this_task, 1);

dplasma_cublas_op(trans);
dplasma_cublas_fill(uplo);

handles = parsec_info_get(&gpu_stream->infos, dplasma_dtd_cuda_infoid);

#if defined(PARSEC_DEBUG_NOISIER)
Expand All @@ -68,7 +65,7 @@ parsec_core_zherk_cuda(parsec_device_gpu_module_t* gpu_device,

parsec_cuda_exec_stream_t* cuda_stream = (parsec_cuda_exec_stream_t*)gpu_stream;
cublasSetStream( handles->cublas_handle, cuda_stream->cuda_stream );
status = cublasZherk(handles->cublas_handle, uplo, trans,
status = cublasZherk(handles->cublas_handle, dplasma_cublas_fill(uplo), dplasma_cublas_op(trans),
m, n,
&alpha, (cuDoubleComplex*)Ag, lda,
&beta, (cuDoubleComplex*)Cg, ldc );
Expand Down
Loading