Skip to content

Commit

Permalink
Small refactor of cuda vvar kernel to support any grid/block dims
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanstocks00 committed Oct 9, 2024
1 parent eeff105 commit 2089af6
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/xc_integrator/local_work_driver/device/cuda/kernels/uvvars.cu
Original file line number Diff line number Diff line change
Expand Up @@ -614,17 +614,23 @@ __global__ void eval_vvar_kern( size_t ntasks,

const auto* den_basis_prod_device = task.zmat;

const int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
const int tid_y = blockIdx.y * blockDim.y + threadIdx.y;

register double den_reg = 0.;

if( tid_x < nbf and tid_y < npts ) {
int start_y = blockIdx.y * blockDim.y + threadIdx.y;

for (int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
tid_x < nbf;
tid_x += blockDim.x * gridDim.x ) {

for (int tid_y = start_y;
tid_y < npts;
tid_y += blockDim.y * gridDim.y ) {

const double* bf_col = basis_eval_device + tid_x*npts;
const double* db_col = den_basis_prod_device + tid_x*npts;
const double* bf_col = basis_eval_device + tid_x*npts;
const double* db_col = den_basis_prod_device + tid_x*npts;

den_reg = bf_col[ tid_y ] * db_col[ tid_y ];
den_reg += bf_col[ tid_y ] * db_col[ tid_y ];
}

}

Expand All @@ -634,8 +640,8 @@ __global__ void eval_vvar_kern( size_t ntasks,
den_reg = cuda::warp_reduce_sum<warp_size>( den_reg );


if( threadIdx.x == 0 and tid_y < npts ) {
atomicAdd( den_eval_device + tid_y, den_reg );
if( threadIdx.x == 0 and start_y < npts ) {
atomicAdd( den_eval_device + start_y, den_reg );
}


Expand Down

0 comments on commit 2089af6

Please sign in to comment.