Skip to content

Commit

Permalink
fix vgamma memory allocation bug and implement eval_zmat_gga_uks
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikael Alexander Kovtun committed Dec 5, 2023
1 parent 4035d97 commit e92977a
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 41 deletions.
111 changes: 110 additions & 1 deletion src/xc_integrator/local_work_driver/device/cuda/kernels/zmat_vxc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void zmat_gga_vxc_rks( size_t ntasks,



template<int den_selector>
template<density_id den_selector>
__global__ void zmat_lda_vxc_uks_kernel( size_t ntasks,
XCDeviceTask* tasks_device ) {

Expand All @@ -166,6 +166,7 @@ __global__ void zmat_lda_vxc_uks_kernel( size_t ntasks,

const auto* basis_eval_device = task.bf;


auto* z_matrix_device = task.zmat;

const int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -215,5 +216,113 @@ void zmat_lda_vxc_uks( size_t ntasks,




template<density_id den_selector>
__global__ void zmat_gga_vxc_uks_kernel( size_t ntasks,
XCDeviceTask* tasks_device ) {

const int batch_idx = blockIdx.z;
if( batch_idx >= ntasks ) return;

auto& task = tasks_device[ batch_idx ];
const auto npts = task.npts;
const auto nbf = task.bfn_screening.nbe;

const double* vrho_pos_device = task.vrho_pos;
const double* vrho_neg_device = task.vrho_neg;

const auto* den_pos_x_eval_device = task.dden_posx;
const auto* den_pos_y_eval_device = task.dden_posy;
const auto* den_pos_z_eval_device = task.dden_posz;
const auto* den_neg_x_eval_device = task.dden_negx;
const auto* den_neg_y_eval_device = task.dden_negy;
const auto* den_neg_z_eval_device = task.dden_negz;


const auto* basis_eval_device = task.bf;
const auto* dbasis_x_eval_device = task.dbfx;
const auto* dbasis_y_eval_device = task.dbfy;
const auto* dbasis_z_eval_device = task.dbfz;

auto* z_matrix_device = task.zmat;

const int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
const int tid_y = blockIdx.y * blockDim.y + threadIdx.y;

if( tid_x < npts and tid_y < nbf ) {

const size_t ibfoff = tid_y * npts + tid_x;

const double factp = 0.25 * vrho_pos_device[tid_x];
const double factm = 0.25 * vrho_neg_device[tid_x];

const auto gga_fact_pp = task.vgamma_pp[ tid_x ];
const auto gga_fact_pm = task.vgamma_pm[ tid_x ];
const auto gga_fact_mm = task.vgamma_mm[ tid_x ];

const auto gga_fact_1 = 0.5*(gga_fact_pp + gga_fact_pm + gga_fact_mm);
const auto gga_fact_2 = 0.5*(gga_fact_pp - gga_fact_mm);
const auto gga_fact_3 = 0.5*(gga_fact_pp - gga_fact_pm + gga_fact_mm);

if constexpr ( den_selector == DEN_S ) {
const auto x_fact = gga_fact_1 * den_pos_x_eval_device[ ibfoff ] + gga_fact_2 * den_pos_x_eval_device[ ibfoff ];
const auto y_fact = gga_fact_1 * den_pos_y_eval_device[ ibfoff ] + gga_fact_2 * den_pos_y_eval_device[ ibfoff ];
const auto z_fact = gga_fact_1 * den_pos_z_eval_device[ ibfoff ] + gga_fact_2 * den_pos_z_eval_device[ ibfoff ];

z_matrix_device[ ibfoff ] = x_fact * dbasis_x_eval_device[ ibfoff ]
+ y_fact * dbasis_y_eval_device[ ibfoff ]
+ z_fact * dbasis_z_eval_device[ ibfoff ]
+ (factp + factm) * basis_eval_device[ ibfoff ];

}
if constexpr ( den_selector == DEN_Z ) {
const auto x_fact = gga_fact_3 * den_neg_x_eval_device[ ibfoff ] + gga_fact_2 * den_neg_x_eval_device[ ibfoff ];
const auto y_fact = gga_fact_3 * den_neg_y_eval_device[ ibfoff ] + gga_fact_2 * den_neg_y_eval_device[ ibfoff ];
const auto z_fact = gga_fact_3 * den_neg_z_eval_device[ ibfoff ] + gga_fact_2 * den_neg_z_eval_device[ ibfoff ];

z_matrix_device[ ibfoff ] = x_fact * dbasis_x_eval_device[ ibfoff ]
+ y_fact * dbasis_y_eval_device[ ibfoff ]
+ z_fact * dbasis_z_eval_device[ ibfoff ]
+ (factp - factm) * basis_eval_device[ ibfoff ];
}






}


}



void zmat_gga_vxc_uks( size_t ntasks,
int32_t max_nbf,
int32_t max_npts,
XCDeviceTask* tasks_device,
density_id sel,
device_queue queue ) {

cudaStream_t stream = queue.queue_as<util::cuda_stream>() ;


dim3 threads(cuda::warp_size,cuda::max_warps_per_thread_block,1);
dim3 blocks( util::div_ceil( max_npts, threads.x ),
util::div_ceil( max_nbf, threads.y ),
ntasks );

if ( sel == DEN_S ) zmat_gga_vxc_uks_kernel<DEN_S><<< blocks, threads, 0, stream >>>( ntasks, tasks_device );
else if ( sel == DEN_Z ) zmat_gga_vxc_uks_kernel<DEN_Z><<< blocks, threads, 0, stream >>>( ntasks, tasks_device );

}







}

Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ FWD_TO_PIMPL(eval_zmat_lda_vxc_rks) // Eval Z Matrix LDA VXC
FWD_TO_PIMPL(eval_zmat_gga_vxc_rks) // Eval Z Matrix GGA VXC

FWD_TO_PIMPL_DEN_ID(eval_zmat_lda_vxc_uks) // Eval Z Matrix LDA VXC
FWD_TO_PIMPL(eval_zmat_gga_vxc_uks) // Eval Z Matrix GGA VXC
FWD_TO_PIMPL_DEN_ID(eval_zmat_gga_vxc_uks) // Eval Z Matrix GGA VXC

FWD_TO_PIMPL(eval_zmat_lda_vxc_gks) // Eval Z Matrix LDA VXC
FWD_TO_PIMPL(eval_zmat_gga_vxc_gks) // Eval Z Matrix GGA VXC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class LocalDeviceWorkDriver : public LocalWorkDriver {
void eval_zmat_gga_vxc_rks( XCDeviceData* );

void eval_zmat_lda_vxc_uks( XCDeviceData*, density_id );
void eval_zmat_gga_vxc_uks( XCDeviceData* );
void eval_zmat_gga_vxc_uks( XCDeviceData*, density_id );

void eval_zmat_lda_vxc_gks( XCDeviceData* );
void eval_zmat_gga_vxc_gks( XCDeviceData* );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct LocalDeviceWorkDriverPIMPL {
virtual void eval_zmat_lda_vxc_rks( XCDeviceData* ) = 0;
virtual void eval_zmat_gga_vxc_rks( XCDeviceData* ) = 0;
virtual void eval_zmat_lda_vxc_uks( XCDeviceData*, density_id ) = 0;
virtual void eval_zmat_gga_vxc_uks( XCDeviceData* ) = 0;
virtual void eval_zmat_gga_vxc_uks( XCDeviceData*, density_id ) = 0;
virtual void eval_zmat_lda_vxc_gks( XCDeviceData* ) = 0;
virtual void eval_zmat_gga_vxc_gks( XCDeviceData* ) = 0;
virtual void inc_exc( XCDeviceData* ) = 0;
Expand Down
30 changes: 25 additions & 5 deletions src/xc_integrator/local_work_driver/device/scheme1_base.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,24 @@ void AoSScheme1Base::eval_zmat_lda_vxc_uks( XCDeviceData* _data, density_id den_

}

void AoSScheme1Base::eval_zmat_gga_vxc_uks( XCDeviceData* ){
void AoSScheme1Base::eval_zmat_gga_vxc_uks( XCDeviceData* _data, density_id den_select ){

GAUXC_GENERIC_EXCEPTION("UKS NOT YET IMPLEMENTED FOR DEVICE");
auto* data = dynamic_cast<Data*>(_data);
if( !data ) GAUXC_BAD_LWD_DATA_CAST();

if( not data->device_backend_ ) GAUXC_UNINITIALIZED_DEVICE_BACKEND();

auto& tasks = data->host_device_tasks;
const auto ntasks = tasks.size();
size_t nbe_max = 0, npts_max = 0;
for( auto& task : tasks ) {
nbe_max = std::max( nbe_max, task.bfn_screening.nbe );
npts_max = std::max( npts_max, task.npts );
}

auto aos_stack = data->aos_stack;
zmat_lda_vxc_uks( ntasks, nbe_max, npts_max, aos_stack.device_tasks, den_select,
data->device_backend_->queue() );

}

Expand Down Expand Up @@ -442,8 +457,9 @@ void AoSScheme1Base::inc_nel( XCDeviceData* _data ){

const bool is_RKS = data->allocated_terms.ks_scheme == RKS;
const bool is_UKS = data->allocated_terms.ks_scheme == UKS;
const bool is_den = data->allocated_terms.den;

if( is_RKS )
if( is_RKS or is_den )
gdot( data->device_backend_->master_blas_handle(), data->total_npts_task_batch,
base_stack.weights_device, 1, base_stack.den_eval_device, 1,
static_stack.acc_scr_device, static_stack.nel_device );
Expand Down Expand Up @@ -599,6 +615,7 @@ void AoSScheme1Base::eval_kern_exc_vxc_lda( const functional_type& func,

const bool is_RKS = data->allocated_terms.ks_scheme == RKS;
const bool is_UKS = data->allocated_terms.ks_scheme == UKS;
const bool is_excgrad = data->allocated_terms.exc_grad;

const size_t npts = data->total_npts_task_batch ;

Expand All @@ -617,7 +634,7 @@ void AoSScheme1Base::eval_kern_exc_vxc_lda( const functional_type& func,
base_stack.den_eval_device, base_stack.eps_eval_device,
base_stack.vrho_eval_device, data->device_backend_->queue() );

if( is_RKS ) {
if( is_RKS or is_excgrad ) {
hadamard_product( data->device_backend_->master_blas_handle(), data->total_npts_task_batch, 1,
base_stack.weights_device, 1, base_stack.vrho_eval_device, 1 );
}
Expand Down Expand Up @@ -653,6 +670,7 @@ void AoSScheme1Base::eval_kern_exc_vxc_gga( const functional_type& func,

const bool is_RKS = data->allocated_terms.ks_scheme == RKS;
const bool is_UKS = data->allocated_terms.ks_scheme == UKS;
const bool is_excgrad = data->allocated_terms.exc_grad;

const size_t npts = data->total_npts_task_batch ;

Expand All @@ -676,7 +694,7 @@ void AoSScheme1Base::eval_kern_exc_vxc_gga( const functional_type& func,
base_stack.eps_eval_device, base_stack.vrho_eval_device,
base_stack.vgamma_eval_device, data->device_backend_->queue() );

if( is_RKS ) {
if( is_RKS or is_excgrad ) {
hadamard_product( data->device_backend_->master_blas_handle(), data->total_npts_task_batch, 1,
base_stack.weights_device, 1, base_stack.vrho_eval_device, 1 );
hadamard_product( data->device_backend_->master_blas_handle(), data->total_npts_task_batch, 1,
Expand Down Expand Up @@ -883,6 +901,8 @@ void AoSScheme1Base::inc_vxc( XCDeviceData* _data, density_id den_selector){
case DEN_Z:
vxc_ptr = static_stack.vxc_z_device;
break;
default:
GAUXC_GENERIC_EXCEPTION( "inc_vxc called with invalid density selected" );
}
sym_task_inc_potential( ntasks, aos_stack.device_tasks,
vxc_ptr, nbf, submat_block_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct AoSScheme1Base : public detail::LocalDeviceWorkDriverPIMPL {
void eval_uvvar_lda( XCDeviceData*, integrator_term_tracker ) override final;
void eval_uvvar_gga( XCDeviceData*, integrator_term_tracker ) override final;
void eval_zmat_lda_vxc_uks( XCDeviceData*, density_id ) override final;
void eval_zmat_gga_vxc_uks( XCDeviceData* ) override final;
void eval_zmat_gga_vxc_uks( XCDeviceData*, density_id ) override final;

void eval_zmat_lda_vxc_gks( XCDeviceData* ) override final;
void eval_zmat_gga_vxc_gks( XCDeviceData* ) override final;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,6 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
else if (is_uks) device_data.send_static_data_density_basis( Ps, ldps, Pz, ldpz, basis );
//if (is_gks) device_data.send_static_data_density_basis( Ps, ldps, Pz, ldpz, Px, ldpx, Py, ldpy, basis );

// for debugging
auto* data = dynamic_cast<XCDeviceStackData*>(&device_data);
auto base_stack = data->base_stack;
auto static_stack = data->static_stack;

// Processes batches in groups that saturate available device memory
if( func.is_lda() ) enabled_terms.xc_approx = integrator_xc_approx::LDA;
Expand Down Expand Up @@ -364,15 +360,13 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
}
if (is_uks) {
// Evaluate Scalar Z matrix
//if( func.is_gga() ) lwd->eval_zmat_gga_vxc_uks( &device_data, DEN_S );
if( func.is_gga() ) GAUXC_GENERIC_EXCEPTION("UKS GGA eval_zmat NYI");
if( func.is_gga() ) lwd->eval_zmat_gga_vxc_uks( &device_data, DEN_S );
else lwd->eval_zmat_lda_vxc_uks( &device_data, DEN_S );
// Increment Scalar VXC
lwd->inc_vxc( &device_data, DEN_S );

// Repeat for Z VXC
//if( func.is_gga() ) lwd->eval_zmat_gga_vxc_uks( &device_data, DEN_Z );
if( func.is_gga() ) GAUXC_GENERIC_EXCEPTION("UKS GGA eval_zmat NYI");
if( func.is_gga() ) lwd->eval_zmat_gga_vxc_uks( &device_data, DEN_Z );
else lwd->eval_zmat_lda_vxc_uks( &device_data, DEN_Z );
lwd->inc_vxc( &device_data, DEN_Z );

Expand Down
39 changes: 23 additions & 16 deletions src/xc_integrator/xc_data/device/xc_device_aos_data.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -466,21 +466,19 @@ void XCDeviceAoSData::pack_and_send(
buffer_adaptor xmat_dy_mem( aos_stack.xmat_dy_device, total_nbe_bfn_npts );
buffer_adaptor xmat_dz_mem( aos_stack.xmat_dz_device, total_nbe_bfn_npts );

size_t den_vrho_eval_npts = total_npts;
if (terms.ks_scheme == UKS ) {
// Use den_eval_device to store interleaved density before eval_kern_exc_vxc
den_vrho_eval_npts *= 2; }
int den_fac = (terms.ks_scheme == UKS) ? 2 : 1;
int gamma_fac = (terms.ks_scheme == UKS) ? 3 : 1;

buffer_adaptor den_mem ( base_stack.den_eval_device, den_vrho_eval_npts );

buffer_adaptor dden_x_mem( base_stack.den_x_eval_device, total_npts );
buffer_adaptor dden_y_mem( base_stack.den_y_eval_device, total_npts );
buffer_adaptor dden_z_mem( base_stack.den_z_eval_device, total_npts );

buffer_adaptor eps_mem( base_stack.eps_eval_device, total_npts );
buffer_adaptor gamma_mem( base_stack.gamma_eval_device, total_npts );
buffer_adaptor vrho_mem( base_stack.vrho_eval_device, den_vrho_eval_npts );
buffer_adaptor vgamma_mem( base_stack.vgamma_eval_device, total_npts );
buffer_adaptor den_mem ( base_stack.den_eval_device, total_npts * den_fac );
buffer_adaptor eps_mem ( base_stack.eps_eval_device, total_npts );
buffer_adaptor gamma_mem ( base_stack.gamma_eval_device, total_npts * gamma_fac );
buffer_adaptor vrho_mem ( base_stack.vrho_eval_device, total_npts * den_fac );
buffer_adaptor vgamma_mem ( base_stack.vgamma_eval_device, total_npts * gamma_fac );

// UKS
buffer_adaptor den_pos_mem( base_stack.den_pos_eval_device, total_npts );
Expand All @@ -498,6 +496,9 @@ void XCDeviceAoSData::pack_and_send(
buffer_adaptor gamma_pp_mem( base_stack.gamma_pp_eval_device, total_npts );
buffer_adaptor gamma_pm_mem( base_stack.gamma_pm_eval_device, total_npts );
buffer_adaptor gamma_mm_mem( base_stack.gamma_mm_eval_device, total_npts );
buffer_adaptor vgamma_pp_mem( base_stack.vgamma_pp_eval_device, total_npts );
buffer_adaptor vgamma_pm_mem( base_stack.vgamma_pm_eval_device, total_npts );
buffer_adaptor vgamma_mm_mem( base_stack.vgamma_mm_eval_device, total_npts );

for( auto& task : host_device_tasks ) {
const auto npts = task.npts;
Expand Down Expand Up @@ -591,13 +592,6 @@ void XCDeviceAoSData::pack_and_send(
}
}

// Allocate UKS specific tasks
if(terms.ks_scheme == UKS) {
task.den_pos = den_pos_mem.aligned_alloc<double>( npts, csl);
task.den_neg = den_neg_mem.aligned_alloc<double>( npts, csl);
task.vrho_pos = vrho_pos_mem.aligned_alloc<double>( npts, csl);
task.vrho_neg = vrho_neg_mem.aligned_alloc<double>( npts, csl);
}


task.gamma =
Expand All @@ -610,6 +604,19 @@ void XCDeviceAoSData::pack_and_send(
task.vgamma =
vgamma_mem.aligned_alloc<double>( reqt.grid_vgamma_size(npts), csl);

// Allocate UKS specific tasks
if(terms.ks_scheme == UKS) {
task.den_pos = den_pos_mem.aligned_alloc<double>( npts, csl);
task.den_neg = den_neg_mem.aligned_alloc<double>( npts, csl);
task.vrho_pos = vrho_pos_mem.aligned_alloc<double>( npts, csl);
task.vrho_neg = vrho_neg_mem.aligned_alloc<double>( npts, csl);
if (reqt.grid_vgamma ) {
task.vgamma_pp = vgamma_pp_mem.aligned_alloc<double>( npts, csl);
task.vgamma_pm = vgamma_pm_mem.aligned_alloc<double>( npts, csl);
task.vgamma_mm = vgamma_mm_mem.aligned_alloc<double>( npts, csl);
}
}

// EXX Specific
task.fmat = fmat_mem.aligned_alloc<double>(
reqt.task_fmat_size(nbe_cou,npts), csl);
Expand Down
13 changes: 6 additions & 7 deletions src/xc_integrator/xc_data/device/xc_device_stack_data.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,8 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(
if( reqt.grid_den ) { // Density
if( is_den ) base_stack.den_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_rks ) base_stack.den_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.den_pos_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.den_eval_device = mem.aligned_alloc<double>(2*msz, aln, csl);
base_stack.den_pos_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.den_neg_eval_device = mem.aligned_alloc<double>(msz, aln, csl); }

}
Expand All @@ -719,14 +720,11 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(
base_stack.den_neg_z_eval_device = mem.aligned_alloc<double>(msz, aln, csl); }
}

if( is_uks and reqt.grid_den ) { // Interleaved density storage
if( not is_gga ) base_stack.den_eval_device = mem.aligned_alloc<double>(2 * msz, aln, csl);
else base_stack.den_eval_device = mem.aligned_alloc<double>(8 * msz, aln, csl); //GGA
}

if( reqt.grid_gamma ) { // Gamma
if( is_rks ) base_stack.gamma_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.gamma_pp_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.gamma_eval_device = mem.aligned_alloc<double>(3 * msz, aln, csl);
base_stack.gamma_pp_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.gamma_pm_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.gamma_mm_eval_device = mem.aligned_alloc<double>(msz, aln, csl); }
}
Expand All @@ -739,7 +737,8 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(

if( reqt.grid_vgamma ) { // Vgamma
if( is_rks ) base_stack.vgamma_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.vgamma_pp_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.vgamma_eval_device = mem.aligned_alloc<double>(3*msz, aln, csl);
base_stack.vgamma_pp_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.vgamma_pm_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.vgamma_mm_eval_device = mem.aligned_alloc<double>(msz, aln, csl); }
}
Expand Down
3 changes: 3 additions & 0 deletions src/xc_integrator/xc_data/device/xc_device_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ struct XCDeviceTask {
double* gamma_pp = nullptr;
double* gamma_pm = nullptr;
double* gamma_mm = nullptr;
double* vgamma_pp = nullptr;
double* vgamma_pm = nullptr;
double* vgamma_mm = nullptr;

int32_t iParent = -1;
double dist_nearest = 0.;
Expand Down

0 comments on commit e92977a

Please sign in to comment.