Skip to content

Commit

Permalink
bugfixes and added awareness of job type to memory manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikael Alexander Kovtun committed Dec 4, 2023
1 parent a786f97 commit 4035d97
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 28 deletions.
25 changes: 16 additions & 9 deletions src/xc_integrator/local_work_driver/device/cuda/kernels/uvvars.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,24 @@ void eval_uvvars_lda( size_t ntasks, int32_t nbf_max, int32_t npts_max, integrat
dim3 blocks( util::div_ceil( nbf_max, threads.x ),
util::div_ceil( npts_max, threads.y ),
ntasks );
switch ( enabled_terms.ks_scheme ) {
case RKS:
if( enabled_terms.den ) {
eval_uvars_lda_rks_kernel<<< blocks, threads, 0, stream >>>( ntasks, device_tasks );
break;
case UKS:
eval_uvars_lda_uks_kernel<<< blocks, threads, 0, stream >>>( ntasks, device_tasks );
break;
default:
GAUXC_GENERIC_EXCEPTION( "Unexpected KS scheme when attempting to evaluate UV vars" );
}

else {
switch ( enabled_terms.ks_scheme ) {
case RKS:
eval_uvars_lda_rks_kernel<<< blocks, threads, 0, stream >>>( ntasks, device_tasks );
break;
case UKS:
eval_uvars_lda_uks_kernel<<< blocks, threads, 0, stream >>>( ntasks, device_tasks );
break;
case GKS:
GAUXC_GENERIC_EXCEPTION( "Device GKS NYI!" );
break;
default:
GAUXC_GENERIC_EXCEPTION( "Unexpected KS scheme when attempting to evaluate UV vars" );
}
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ void AoSScheme1Base::eval_uvvar_lda( XCDeviceData* _data, integrator_term_tracke

// Zero density
auto base_stack = data->base_stack;
if (en_terms.ks_scheme == RKS )
if (en_terms.ks_scheme == RKS or en_terms.den )
data->device_backend_->set_zero_async_master_queue( data->total_npts_task_batch, base_stack.den_eval_device, "Den Zero" );


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
// Processes batches in groups that saturate available device memory
integrator_term_tracker enabled_terms;
enabled_terms.exc_grad = true;
enabled_terms.ks_scheme = RKS;
if( func.is_lda() ) enabled_terms.xc_approx = integrator_xc_approx::LDA;
else if( func.is_gga() ) enabled_terms.xc_approx = integrator_xc_approx::GGA;
else GAUXC_GENERIC_EXCEPTION("XC Approx NYI");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
device_data.zero_exc_vxc_integrands(enabled_terms);


if( func.is_gga() and is_uks ) GAUXC_GENERIC_EXCEPTION("UKS GGA NYI");

auto task_it = task_begin;
while( task_it != task_end ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ void ShellBatchedReplicatedXCDeviceIntegrator<ValueType>::
this->timer_.time_op("XCIntegrator.DeviceAlloc",
[&](){ return lwd->create_device_data(rt); });

if(this->func_->is_gga()) GAUXC_GENERIC_EXCEPTION( "GGA+UKS NYI!" );

// Generate incore integrator instance, transfer ownership of LWD
incore_integrator_type incore_integrator( this->func_, this->load_balancer_,
Expand Down
5 changes: 3 additions & 2 deletions src/xc_integrator/xc_data/device/xc_device_aos_data.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,9 @@ void XCDeviceAoSData::pack_and_send(
buffer_adaptor xmat_dz_mem( aos_stack.xmat_dz_device, total_nbe_bfn_npts );

size_t den_vrho_eval_npts = total_npts;
if (terms.ks_scheme == UKS ) // Use den_eval_device to store interleaved density before eval_kern_exc_vxc
den_vrho_eval_npts *= 2;
if (terms.ks_scheme == UKS ) {
// Use den_eval_device to store interleaved density before eval_kern_exc_vxc
den_vrho_eval_npts *= 2; }

buffer_adaptor den_mem ( base_stack.den_eval_device, den_vrho_eval_npts );

Expand Down
35 changes: 23 additions & 12 deletions src/xc_integrator/xc_data/device/xc_device_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct integrator_term_tracker {
}
};

// PRDVL(pred, val): yields `val` when `pred` holds and 0ul otherwise.
// The expansion is fully parenthesized and carries no trailing semicolon so
// it composes safely inside larger expressions, e.g.
// PRDVL(a, x) + PRDVL(b, y). Without the outer parentheses the `+` binds
// into the else-branch of the first conditional.
#define PRDVL(pred,val) ((pred) ? (val) : 0ul)

struct required_term_storage {
bool grid_points = false;
Expand All @@ -78,31 +78,37 @@ struct required_term_storage {


// Reference flags for memory management use
integrator_ks_scheme ref = _UNDEF_SCHEME;
integrator_term_tracker ref_tracker;

// Number of elements reserved for the density on a grid of `npts` points.
//
// Returns 0 when density storage is not requested (grid_den is false) or when
// the reference KS scheme is neither RKS nor UKS, so every path yields a
// well-defined value instead of flowing off the end of the function.
inline size_t grid_den_size(size_t npts){
  if( not grid_den ) return 0ul;

  // Density-only jobs use a single value per grid point
  if( ref_tracker.den ) return npts;

  if( ref_tracker.ks_scheme == RKS ) return npts;

  // Accounts for the interleaved density sent to ExchCXX (hence the * 2)
  if( ref_tracker.ks_scheme == UKS ) return 2 * npts;

  return 0ul;
}
// Number of elements reserved for the density gradient on `npts` points:
// 3 * npts for RKS, 6 * npts for UKS, and 0 when the gradient is not
// requested or the reference scheme is anything else.
inline size_t grid_den_grad_size(size_t npts){
  if( not grid_den_grad ) return 0ul;
  if( ref_tracker.ks_scheme == RKS ) return 3 * npts;
  if( ref_tracker.ks_scheme == UKS ) return 6 * npts;
  return 0ul;
}
// Number of elements reserved for gamma on `npts` points:
// npts for RKS, 3 * npts for UKS, and 0 when gamma is not requested or the
// reference scheme is anything else.
inline size_t grid_gamma_size(size_t npts){
  if( not grid_gamma ) return 0ul;
  if( ref_tracker.ks_scheme == RKS ) return npts;
  if( ref_tracker.ks_scheme == UKS ) return 3 * npts;
  return 0ul;
}
// Number of elements reserved for eps on `npts` points: npts when grid_eps
// is set, 0 otherwise.
inline size_t grid_eps_size(size_t npts){
  if( grid_eps ) return npts;
  return 0ul;
}
// Number of elements reserved for vrho on `npts` points:
// npts for RKS, 2 * npts for UKS, and 0 when vrho is not requested or the
// reference scheme is anything else.
inline size_t grid_vrho_size(size_t npts){
  if( not grid_vrho ) return 0ul;
  if( ref_tracker.ks_scheme == RKS ) return npts;
  if( ref_tracker.ks_scheme == UKS ) return 2 * npts;
  return 0ul;
}
// Number of elements reserved for vgamma on `npts` points:
// npts for RKS, 3 * npts for UKS, and 0 when vgamma is not requested or the
// reference scheme is anything else.
inline size_t grid_vgamma_size(size_t npts){
  if( not grid_vgamma ) return 0ul;
  if( ref_tracker.ks_scheme == RKS ) return npts;
  if( ref_tracker.ks_scheme == UKS ) return 3 * npts;
  return 0ul;
}


Expand Down Expand Up @@ -254,9 +260,14 @@ struct required_term_storage {

// Allocated terms for XC calculations
const bool is_xc = tracker.exc_vxc or tracker.exc_grad;

ref_tracker = tracker;

if(is_xc) {
if( tracker.xc_approx == _UNDEF_APPROX )
GAUXC_GENERIC_EXCEPTION("NO XC APPROX SET");
if( tracker.ks_scheme == _UNDEF_SCHEME )
GAUXC_GENERIC_EXCEPTION("NO KS SCHEME SET");
//const bool is_lda = is_xc and tracker.xc_approx == LDA;
const bool is_gga = is_xc and tracker.xc_approx == GGA;
const bool is_grad = tracker.exc_grad;
Expand Down
9 changes: 7 additions & 2 deletions src/xc_integrator/xc_data/device/xc_device_stack_data.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,14 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(
required_term_storage reqt(terms);
const size_t msz = total_npts_task_batch;
const size_t aln = 256;


// Below is only true if terms.exc_vxc is true
const bool is_rks = terms.ks_scheme == RKS;
const bool is_uks = terms.ks_scheme == UKS;
const bool is_gks = terms.ks_scheme == GKS;
const bool is_gga = reqt.grid_den_grad;

const bool is_den = terms.den;

// Grid Points
if( reqt.grid_points ) {
Expand All @@ -697,6 +701,7 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(

// Grid function evaluations
if( reqt.grid_den ) { // Density
if( is_den ) base_stack.den_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_rks ) base_stack.den_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
if( is_uks ) { base_stack.den_pos_eval_device = mem.aligned_alloc<double>(msz, aln, csl);
base_stack.den_neg_eval_device = mem.aligned_alloc<double>(msz, aln, csl); }
Expand All @@ -715,7 +720,7 @@ XCDeviceStackData::device_buffer_t XCDeviceStackData::allocate_dynamic_stack(
}

if( is_uks and reqt.grid_den ) { // Interleaved density storage
if( not reqt.grid_den_grad ) base_stack.den_eval_device = mem.aligned_alloc<double>(2 * msz, aln, csl);
if( not is_gga ) base_stack.den_eval_device = mem.aligned_alloc<double>(2 * msz, aln, csl);
else base_stack.den_eval_device = mem.aligned_alloc<double>(8 * msz, aln, csl); //GGA
}

Expand Down

0 comments on commit 4035d97

Please sign in to comment.