Skip to content

Commit

Permalink
Merge pull request #84 from dmejiar/master
Browse files Browse the repository at this point in the history
Update mGGA implementation to new API and add unrestricted case
  • Loading branch information
wavefunction91 authored Nov 20, 2023
2 parents 48deaa8 + 326532a commit d4956d1
Show file tree
Hide file tree
Showing 16 changed files with 1,183 additions and 37 deletions.
4 changes: 4 additions & 0 deletions include/gauxc/xc_integrator/integrator_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class XCIntegratorFactory {
std::shared_ptr<functional_type> func,
std::shared_ptr<LoadBalancer> lb ) {

// Early check for MGGAs and Device
if( ex_ == ExecutionSpace::Device && func->is_mgga() )
GAUXC_GENERIC_EXCEPTION("DEVICE IS NOT READY FOR MGGA");

// Create Local Work Driver
auto lwd = LocalWorkDriverFactory::make_local_work_driver( ex_,
lwd_kernel_, local_work_settings_ );
Expand Down
121 changes: 117 additions & 4 deletions src/xc_integrator/local_work_driver/host/local_host_work_driver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ void LocalHostWorkDriver::eval_collocation_hessian( size_t npts, size_t nshells,

}

// Collocation 3rd
void LocalHostWorkDriver::eval_collocation_der3( size_t npts, size_t nshells, size_t nbe,
const double* pts, const BasisSet<double>& basis, const int32_t* shell_list,
double* basis_eval, double* dbasis_x_eval, double* dbasis_y_eval,
double* dbasis_z_eval, double* d2basis_xx_eval, double* d2basis_xy_eval,
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval, double* d3basis_xxx_eval, double* d3basis_xxy_eval,
double* d3basis_xxz_eval, double* d3basis_xyy_eval, double* d3basis_xyz_eval,
double* d3basis_xzz_eval, double* d3basis_yyy_eval, double* d3basis_yyz_eval,
double* d3basis_yzz_eval, double* d3basis_zzz_eval) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_collocation_der3(npts, nshells, nbe, pts, basis, shell_list, basis_eval,
dbasis_x_eval, dbasis_y_eval, dbasis_z_eval, d2basis_xx_eval, d2basis_xy_eval,
d2basis_xz_eval, d2basis_yy_eval, d2basis_yz_eval, d2basis_zz_eval,
d3basis_xxx_eval, d3basis_xxy_eval, d3basis_xxz_eval, d3basis_xyy_eval,
d3basis_xyz_eval, d3basis_xzz_eval, d3basis_yyy_eval, d3basis_yyz_eval,
d3basis_yzz_eval, d3basis_zzz_eval);

}


// X matrix (fac * P * B)
void LocalHostWorkDriver::eval_xmat( size_t npts, size_t nbf, size_t nbe,
Expand All @@ -89,7 +110,6 @@ void LocalHostWorkDriver::eval_xmat( size_t npts, size_t nbf, size_t nbe,

}


void LocalHostWorkDriver::eval_exx_fmat( size_t npts, size_t nbf, size_t nbe_bra,
size_t nbe_ket, const submat_map_t& submat_map_bra,
const submat_map_t& submat_map_ket, const double* P, size_t ldp,
Expand Down Expand Up @@ -176,6 +196,40 @@ void LocalHostWorkDriver::eval_uvvar_gga_uks( size_t npts, size_t nbe,

}

// U/VVar MGGA(density, grad, gamma, tau, lapl)
void LocalHostWorkDriver::eval_uvvar_mgga_rks( size_t npts, size_t nbe,
const double* basis_eval, const double* dbasis_x_eval,
const double* dbasis_y_eval, const double* dbasis_z_eval, const double* lbasis_eval,
const double* X, size_t ldx, const double* mmat_x, const double* mmat_y, const double* mmat_z,
size_t ldm, double* den_eval, double* dden_x_eval, double* dden_y_eval,
double* dden_z_eval, double* gamma, double* tau, double* lapl ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_uvvar_mgga_rks(npts, nbe, basis_eval, dbasis_x_eval, dbasis_y_eval,
dbasis_z_eval, lbasis_eval, X, ldx, mmat_x, mmat_y, mmat_z, ldm, den_eval, dden_x_eval, dden_y_eval, dden_z_eval,
gamma, tau, lapl);

}


// U/VVar MGGA(density, grad, gamma, tau, lapl)
void LocalHostWorkDriver::eval_uvvar_mgga_uks( size_t npts, size_t nbe,
const double* basis_eval, const double* dbasis_x_eval,
const double* dbasis_y_eval, const double* dbasis_z_eval, const double* lbasis_eval,
const double* Xs, size_t ldxs, const double* Xz, size_t ldxz,
const double* mmat_xs, const double* mmat_ys, const double* mmat_zs, size_t ldms,
const double* mmat_xz, const double* mmat_yz, const double* mmat_zz, size_t ldmz,
double* den_eval, double* dden_x_eval, double* dden_y_eval,
double* dden_z_eval, double* gamma, double* tau, double* lapl ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_uvvar_mgga_uks(npts, nbe, basis_eval, dbasis_x_eval, dbasis_y_eval,
dbasis_z_eval, lbasis_eval, Xs, ldxs, Xz, ldxz, mmat_xs, mmat_ys, mmat_zs, ldms,
mmat_xz, mmat_yz, mmat_zz, ldmz, den_eval, dden_x_eval, dden_y_eval, dden_z_eval,
gamma, tau, lapl);

}

// Eval Z Matrix LDA VXC
void LocalHostWorkDriver::eval_zmat_lda_vxc_rks( size_t npts, size_t nbe,
const double* vrho, const double* basis_eval, double* Z, size_t ldz ) {
Expand Down Expand Up @@ -210,7 +264,6 @@ void LocalHostWorkDriver::eval_zmat_gga_vxc_rks( size_t npts, size_t nbe,

}


void LocalHostWorkDriver::eval_zmat_gga_vxc_uks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval,
Expand All @@ -225,14 +278,74 @@ void LocalHostWorkDriver::eval_zmat_gga_vxc_uks( size_t npts, size_t nbe,

}


// Eval Z Matrix MGGA VXC
void LocalHostWorkDriver::eval_zmat_mgga_vxc_rks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* vlapl,
const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval, const double* dden_x_eval,
const double* dden_y_eval, const double* dden_z_eval, double* Z, size_t ldz ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_zmat_mgga_vxc_rks(npts, nbe, vrho, vgamma, vlapl, basis_eval, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, lbasis_eval, dden_x_eval, dden_y_eval, dden_z_eval,
Z, ldz);

}


// Eval Z Matrix MGGA VXC
void LocalHostWorkDriver::eval_zmat_mgga_vxc_uks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* vlapl,
const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval, const double* dden_x_eval,
const double* dden_y_eval, const double* dden_z_eval, double* Zs, size_t ldzs,
double* Zz, size_t ldzz) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_zmat_mgga_vxc_uks(npts, nbe, vrho, vgamma, vlapl, basis_eval, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, lbasis_eval, dden_x_eval, dden_y_eval, dden_z_eval,
Zs, ldzs, Zz, ldzz);

}


// Eval M Matrix MGGA VXC
void LocalHostWorkDriver::eval_mmat_mgga_vxc_rks( size_t npts, size_t nbe,
const double* vtau, const double* vlapl,
const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_x, double* mmat_y, double* mmat_z, size_t ldm ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_mmat_mgga_vxc_rks(npts, nbe, vtau, vlapl, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, mmat_x, mmat_y, mmat_z, ldm);

}


// Eval M Matrix MGGA VXC
void LocalHostWorkDriver::eval_mmat_mgga_vxc_uks( size_t npts, size_t nbe,
const double* vtau, const double* vlapl,
const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_xs, double* mmat_ys, double* mmat_zs, size_t ldms,
double* mmat_xz, double* mmat_yz, double* mmat_zz, size_t ldmz ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_mmat_mgga_vxc_uks(npts, nbe, vtau, vlapl, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, mmat_xs, mmat_ys, mmat_zs, ldms, mmat_xz, mmat_yz,
mmat_zz, ldmz );

}

// Increment VXC by Z
void LocalHostWorkDriver::inc_vxc( size_t npts, size_t nbf, size_t nbe,
const double* basis_eval, const submat_map_t& submat_map, const double* Z,
size_t ldz, double* VXC, size_t ldvxc, double* scr ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->inc_vxc(npts, nbf, nbe, basis_eval, submat_map, Z, ldz, VXC, ldvxc,
scr);
pimpl_->inc_vxc(npts, nbf, nbe, basis_eval, submat_map, Z, ldz, VXC, ldvxc, scr);

}

Expand Down
135 changes: 133 additions & 2 deletions src/xc_integrator/local_work_driver/host/local_host_work_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,46 @@ class LocalHostWorkDriver : public LocalWorkDriver {
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval );

/** Evaluation the collocation matrix + gradient + hessian + 3rd derivatives
*
* @param[in] npts Same as `eval_collocation`
* @param[in] nshells Same as `eval_collocation`
* @param[in] nbe Same as `eval_collocation`
* @param[in] pts Same as `eval_collocation`
* @param[in] basis Same as `eval_collocation`
* @param[in] shell_list Same as `eval_collocation`
*
* @param[out] basis_eval Same as `eval_collocation`
* @param[out] dbasis_x_eval Same as `eval_collocation_gradient`
* @param[out] dbasis_y_eval Same as `eval_collocation_gradient`
* @param[out] dbasis_z_eval Same as `eval_collocation_gradient`
* @param[out] d2basis_xx_eval Derivative of `basis_eval` wrt x+x (same dimensions)
* @param[out] d2basis_xy_eval Derivative of `basis_eval` wrt x+y (same dimensions)
* @param[out] d2basis_xz_eval Derivative of `basis_eval` wrt x+z (same dimensions)
* @param[out] d2basis_yy_eval Derivative of `basis_eval` wrt y+y (same dimensions)
* @param[out] d2basis_yz_eval Derivative of `basis_eval` wrt y+z (same dimensions)
* @param[out] d2basis_zz_eval Derivative of `basis_eval` wrt z+z (same dimensions)
* @param[out] d3basis_xxx_eval Derivative of `basis_eval` wrt x+x+x (same dimensions)
* @param[out] d3basis_xxy_eval Derivative of `basis_eval` wrt x+x+y (same dimensions)
* @param[out] d3basis_xxz_eval Derivative of `basis_eval` wrt x+x+z (same dimensions)
* @param[out] d3basis_xyy_eval Derivative of `basis_eval` wrt x+y+y (same dimensions)
* @param[out] d3basis_xyz_eval Derivative of `basis_eval` wrt x+y+z (same dimensions)
* @param[out] d3basis_xzz_eval Derivative of `basis_eval` wrt x+z+z (same dimensions)
* @param[out] d3basis_yyy_eval Derivative of `basis_eval` wrt y+y+y (same dimensions)
* @param[out] d3basis_yyz_eval Derivative of `basis_eval` wrt y+y+z (same dimensions)
* @param[out] d3basis_yzz_eval Derivative of `basis_eval` wrt y+z+z (same dimensions)
* @param[out] d3basis_zzz_eval Derivative of `basis_eval` wrt z+z+z (same dimensions)
*/
void eval_collocation_der3( size_t npts, size_t nshells, size_t nbe,
const double* pts, const BasisSet<double>& basis, const int32_t* shell_list,
double* basis_eval, double* dbasis_x_eval, double* dbasis_y_eval,
double* dbasis_z_eval, double* d2basis_xx_eval, double* d2basis_xy_eval,
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval, double* d3basis_xxx_eval, double* d3basis_xxy_eval,
double* d3basis_xxz_eval, double* d3basis_xyy_eval, double* d3basis_xyz_eval,
double* d3basis_xzz_eval, double* d3basis_yyy_eval, double* d3basis_yyz_eval,
double* d3basis_yzz_eval, double* d3basis_zzz_eval);

/** Evaluate the compressed "X" matrix = fac * P * B
*
* @param[in] npts The number of points in the collocation matrix
Expand Down Expand Up @@ -238,6 +278,49 @@ class LocalHostWorkDriver : public LocalWorkDriver {
const double* Xz, size_t ldxz, double* den_eval,
double* dden_x_eval, double* dden_y_eval, double* dden_z_eval, double* gamma );

/** Evaluate the U and V variavles for RKS MGGA
*
* U = rho + gradient + tau + lapl
* V = rho + gamma + tau + lapl
*
* @param[in] npts Same as `eval_uvvar_lda`
* @param[in] nbe Same as `eval_uvvar_lda`
* @param[in] basis_eval Same as `eval_uvvar_lda`
* @param[in] dbasis_x_eval Derivative of `basis_eval` wrt x (same dims)
* @param[in] dbasis_y_eval Derivative of `basis_eval` wrt y (same dims)
* @param[in] dbasis_z_eval Derivative of `basis_eval` wrt z (same dims)
* @param[in] lbasis_eval Laplacian of `basis_eval` (same dims)
* @param[in] X Same as `eval_uvvar_lda`
* @param[in] ldx Same as `eval_uvvar_lda`
* @param[in] mmat_x
* @param[in] mmat_y
* @param[in] mmat_z
* @param[in] ldm
* @param[out] den_eval Same as `eval_uvvar_lda`
* @param[out] dden_x_eval Derivative of `den_eval` wrt x (npts)
* @param[out] dden_y_eval Derivative of `den_eval` wrt y (npts)
* @param[out] dden_z_eval Derivative of `den_eval` wrt z (npts)
* @param[out] gamma |grad rho|^2 (npts)
* @param[out] tau
* @param[out] lapl
*
*/
void eval_uvvar_mgga_rks( size_t npts, size_t nbe, const double* basis_eval,
const double* dbasis_x_eavl, const double* dbasis_y_eval,
const double* dbasis_z_eval, const double* lbasis_eval,
const double* X, size_t ldx, const double* mmat_x,
const double* mmat_y, const double* mmat_z, size_t ldm, double* den_eval,
double* dden_x_eval, double* dden_y_eval, double* dden_z_eval, double* gamma,
double* tau, double* lapl);
void eval_uvvar_mgga_uks( size_t npts, size_t nbe, const double* basis_eval,
const double* dbasis_x_eavl, const double* dbasis_y_eval,
const double* dbasis_z_eval, const double* lbasis_eval,
const double* Xs, size_t ldxs, const double* Xz, size_t ldxz,
const double* mmat_xs, const double* mmat_ys, const double* mmat_zs, size_t ldms,
const double* mmat_xz, const double* mmat_yz, const double* mmat_zz, size_t ldmz,
double* den_eval, double* dden_x_eval, double* dden_y_eval, double* dden_z_eval,
double* gamma, double* tau, double* lapl);

/** Evaluate the VXC Z Matrix for RKS LDA
*
* Z(mu,i) = 0.5 * vrho(i) * B(mu, i)
Expand Down Expand Up @@ -294,9 +377,57 @@ class LocalHostWorkDriver : public LocalWorkDriver {
double* Zs, size_t ldzs, double* Zz, size_t ldzz );


/** Evaluate the VXC Z Matrix for RKS MGGA
*
* Z(mu,i) = 0.5 * vrho(i) * B(mu, i) +
* 2.0 * vgamma(i) * (grad B(mu,i)) . (grad rho(i)) +
* 0.5 * vlapl(i) * lapl B(mu, i)
*
* TODO: Need to add an API for UKS/GKS
*
* @param[in] npts Same as `eval_zmat_lda_vxc`
* @param[in] nbe Same as `eval_zmat_lda_vxc`
* @param[in] vrho Same as `eval_zmat_lda_vxc`
* @param[in] vgamma Derivative of the XC functional wrt gamma scaled by quad weights (npts)
* @param[in] basis_eval Same as `eval_zmat_lda_vxc`
* @param[in] dbasis_x_eval Derivative of `basis_eval` wrt x (same dims)
* @param[in] dbasis_y_eval Derivative of `basis_eval` wrt y (same dims)
* @param[in] dbasis_z_eval Derivative of `basis_eval` wrt z (same dims)
* @param[in] lbasis_eval Laplacian of `basis_eval` (same dims)
* @param[in] dden_x_eval Derivative of rho wrt x (npts)
* @param[in] dden_y_eval Derivative of rho wrt y (npts)
* @param[in] dden_z_eval Derivative of rho wrt z (npts)
* @param[out] Z Same as `eval_zmat_lda_vxc`
* @param[in] ldz Same as `eval_zmat_lda_vxc`
*
*/
void eval_zmat_mgga_vxc_rks( size_t npts, size_t nbe, const double* vrho,
const double* vgamma, const double* vlapl, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval,
const double* dden_x_eval, const double* dden_y_eval, const double* dden_z_eval,
double* Z, size_t ldz );
void eval_zmat_mgga_vxc_uks( size_t npts, size_t nbe, const double* vrho,
const double* vgamma, const double* vlapl, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval,
const double* dden_x_eval, const double* dden_y_eval, const double* dden_z_eval,
double* Zs, size_t ldzs, double* Zz, size_t ldzz );
void eval_mmat_mgga_vxc_rks( size_t npts, size_t nbe, const double* vtau,
const double* vlapl, const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_x, double* mmat_y, double* mmat_z,
size_t ldm);
void eval_mmat_mgga_vxc_uks( size_t npts, size_t nbe, const double* vtau,
const double* vlapl, const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_xs, double* mmat_ys, double* mmat_zs,
size_t ldms, double* mmat_xz, double* mmat_yz, double* mmat_zz, size_t ldmz);



/** Increment VXC integrand given Z / Collocation (RKS LDA+GGA)
*
* VXC += Z**H * B + h.c.
* VXC += M**H . dB + h.c.
*
* Only updates lower triangle
*
Expand All @@ -313,8 +444,8 @@ class LocalHostWorkDriver : public LocalWorkDriver {
*
*/
void inc_vxc( size_t npts, size_t nbf, size_t nbe, const double* basis_eval,
const submat_map_t& submat_map, const double* Z, size_t ldz, double* VXC,
size_t ldvxc, double* scr );
const submat_map_t& submat_map, const double* Z, size_t ldz,
double* VXC, size_t ldvxc, double* scr );

private:

Expand Down
Loading

0 comments on commit d4956d1

Please sign in to comment.