Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mGGA implementation to new API and add unrestricted case #84

Merged
merged 13 commits into from
Nov 20, 2023
4 changes: 4 additions & 0 deletions include/gauxc/xc_integrator/integrator_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class XCIntegratorFactory {
std::shared_ptr<functional_type> func,
std::shared_ptr<LoadBalancer> lb ) {

// Early check for MGGAs and Device
if( ex_ == ExecutionSpace::Device && func->is_mgga() )
GAUXC_GENERIC_EXCEPTION("DEVICE IS NOT READY FOR MGGA");

// Create Local Work Driver
auto lwd = LocalWorkDriverFactory::make_local_work_driver( ex_,
lwd_kernel_, local_work_settings_ );
Expand Down
121 changes: 117 additions & 4 deletions src/xc_integrator/local_work_driver/host/local_host_work_driver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ void LocalHostWorkDriver::eval_collocation_hessian( size_t npts, size_t nshells,

}

// Collocation 3rd
void LocalHostWorkDriver::eval_collocation_der3( size_t npts, size_t nshells, size_t nbe,
wavefunction91 marked this conversation as resolved.
Show resolved Hide resolved
const double* pts, const BasisSet<double>& basis, const int32_t* shell_list,
double* basis_eval, double* dbasis_x_eval, double* dbasis_y_eval,
double* dbasis_z_eval, double* d2basis_xx_eval, double* d2basis_xy_eval,
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval, double* d3basis_xxx_eval, double* d3basis_xxy_eval,
double* d3basis_xxz_eval, double* d3basis_xyy_eval, double* d3basis_xyz_eval,
double* d3basis_xzz_eval, double* d3basis_yyy_eval, double* d3basis_yyz_eval,
double* d3basis_yzz_eval, double* d3basis_zzz_eval) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_collocation_der3(npts, nshells, nbe, pts, basis, shell_list, basis_eval,
dbasis_x_eval, dbasis_y_eval, dbasis_z_eval, d2basis_xx_eval, d2basis_xy_eval,
d2basis_xz_eval, d2basis_yy_eval, d2basis_yz_eval, d2basis_zz_eval,
d3basis_xxx_eval, d3basis_xxy_eval, d3basis_xxz_eval, d3basis_xyy_eval,
d3basis_xyz_eval, d3basis_xzz_eval, d3basis_yyy_eval, d3basis_yyz_eval,
d3basis_yzz_eval, d3basis_zzz_eval);

}


// X matrix (fac * P * B)
void LocalHostWorkDriver::eval_xmat( size_t npts, size_t nbf, size_t nbe,
Expand All @@ -89,7 +110,6 @@ void LocalHostWorkDriver::eval_xmat( size_t npts, size_t nbf, size_t nbe,

}


void LocalHostWorkDriver::eval_exx_fmat( size_t npts, size_t nbf, size_t nbe_bra,
size_t nbe_ket, const submat_map_t& submat_map_bra,
const submat_map_t& submat_map_ket, const double* P, size_t ldp,
Expand Down Expand Up @@ -176,6 +196,40 @@ void LocalHostWorkDriver::eval_uvvar_gga_uks( size_t npts, size_t nbe,

}

// U/VVar MGGA(density, grad, gamma, tau, lapl)
void LocalHostWorkDriver::eval_uvvar_mgga_rks( size_t npts, size_t nbe,
const double* basis_eval, const double* dbasis_x_eval,
const double* dbasis_y_eval, const double* dbasis_z_eval, const double* lbasis_eval,
const double* X, size_t ldx, const double* mmat_x, const double* mmat_y, const double* mmat_z,
size_t ldm, double* den_eval, double* dden_x_eval, double* dden_y_eval,
double* dden_z_eval, double* gamma, double* tau, double* lapl ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_uvvar_mgga_rks(npts, nbe, basis_eval, dbasis_x_eval, dbasis_y_eval,
dbasis_z_eval, lbasis_eval, X, ldx, mmat_x, mmat_y, mmat_z, ldm, den_eval, dden_x_eval, dden_y_eval, dden_z_eval,
gamma, tau, lapl);

}


// U/VVar MGGA(density, grad, gamma, tau, lapl)
void LocalHostWorkDriver::eval_uvvar_mgga_uks( size_t npts, size_t nbe,
const double* basis_eval, const double* dbasis_x_eval,
const double* dbasis_y_eval, const double* dbasis_z_eval, const double* lbasis_eval,
const double* Xs, size_t ldxs, const double* Xz, size_t ldxz,
const double* mmat_xs, const double* mmat_ys, const double* mmat_zs, size_t ldms,
const double* mmat_xz, const double* mmat_yz, const double* mmat_zz, size_t ldmz,
double* den_eval, double* dden_x_eval, double* dden_y_eval,
double* dden_z_eval, double* gamma, double* tau, double* lapl ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_uvvar_mgga_uks(npts, nbe, basis_eval, dbasis_x_eval, dbasis_y_eval,
dbasis_z_eval, lbasis_eval, Xs, ldxs, Xz, ldxz, mmat_xs, mmat_ys, mmat_zs, ldms,
mmat_xz, mmat_yz, mmat_zz, ldmz, den_eval, dden_x_eval, dden_y_eval, dden_z_eval,
gamma, tau, lapl);

}

// Eval Z Matrix LDA VXC
void LocalHostWorkDriver::eval_zmat_lda_vxc_rks( size_t npts, size_t nbe,
const double* vrho, const double* basis_eval, double* Z, size_t ldz ) {
Expand Down Expand Up @@ -210,7 +264,6 @@ void LocalHostWorkDriver::eval_zmat_gga_vxc_rks( size_t npts, size_t nbe,

}


void LocalHostWorkDriver::eval_zmat_gga_vxc_uks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval,
Expand All @@ -225,14 +278,74 @@ void LocalHostWorkDriver::eval_zmat_gga_vxc_uks( size_t npts, size_t nbe,

}


// Eval Z Matrix MGGA VXC
void LocalHostWorkDriver::eval_zmat_mgga_vxc_rks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* vlapl,
const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval, const double* dden_x_eval,
const double* dden_y_eval, const double* dden_z_eval, double* Z, size_t ldz ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_zmat_mgga_vxc_rks(npts, nbe, vrho, vgamma, vlapl, basis_eval, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, lbasis_eval, dden_x_eval, dden_y_eval, dden_z_eval,
Z, ldz);

}


// Eval Z Matrix MGGA VXC
void LocalHostWorkDriver::eval_zmat_mgga_vxc_uks( size_t npts, size_t nbe,
const double* vrho, const double* vgamma, const double* vlapl,
const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval, const double* dden_x_eval,
const double* dden_y_eval, const double* dden_z_eval, double* Zs, size_t ldzs,
double* Zz, size_t ldzz) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_zmat_mgga_vxc_uks(npts, nbe, vrho, vgamma, vlapl, basis_eval, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, lbasis_eval, dden_x_eval, dden_y_eval, dden_z_eval,
Zs, ldzs, Zz, ldzz);

}


// Eval M Matrix MGGA VXC
void LocalHostWorkDriver::eval_mmat_mgga_vxc_rks( size_t npts, size_t nbe,
const double* vtau, const double* vlapl,
const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_x, double* mmat_y, double* mmat_z, size_t ldm ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_mmat_mgga_vxc_rks(npts, nbe, vtau, vlapl, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, mmat_x, mmat_y, mmat_z, ldm);

}


// Eval M Matrix MGGA VXC
void LocalHostWorkDriver::eval_mmat_mgga_vxc_uks( size_t npts, size_t nbe,
const double* vtau, const double* vlapl,
const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_xs, double* mmat_ys, double* mmat_zs, size_t ldms,
double* mmat_xz, double* mmat_yz, double* mmat_zz, size_t ldmz ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->eval_mmat_mgga_vxc_uks(npts, nbe, vtau, vlapl, dbasis_x_eval,
dbasis_y_eval, dbasis_z_eval, mmat_xs, mmat_ys, mmat_zs, ldms, mmat_xz, mmat_yz,
mmat_zz, ldmz );

}

// Increment VXC by Z
void LocalHostWorkDriver::inc_vxc( size_t npts, size_t nbf, size_t nbe,
const double* basis_eval, const submat_map_t& submat_map, const double* Z,
size_t ldz, double* VXC, size_t ldvxc, double* scr ) {

throw_if_invalid_pimpl(pimpl_);
pimpl_->inc_vxc(npts, nbf, nbe, basis_eval, submat_map, Z, ldz, VXC, ldvxc,
scr);
pimpl_->inc_vxc(npts, nbf, nbe, basis_eval, submat_map, Z, ldz, VXC, ldvxc, scr);

}

Expand Down
135 changes: 133 additions & 2 deletions src/xc_integrator/local_work_driver/host/local_host_work_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,46 @@ class LocalHostWorkDriver : public LocalWorkDriver {
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval );

/** Evaluation the collocation matrix + gradient + hessian + 3rd derivatives
*
* @param[in] npts Same as `eval_collocation`
* @param[in] nshells Same as `eval_collocation`
* @param[in] nbe Same as `eval_collocation`
* @param[in] pts Same as `eval_collocation`
* @param[in] basis Same as `eval_collocation`
* @param[in] shell_list Same as `eval_collocation`
*
* @param[out] basis_eval Same as `eval_collocation`
* @param[out] dbasis_x_eval Same as `eval_collocation_gradient`
* @param[out] dbasis_y_eval Same as `eval_collocation_gradient`
* @param[out] dbasis_z_eval Same as `eval_collocation_gradient`
* @param[out] d2basis_xx_eval Derivative of `basis_eval` wrt x+x (same dimensions)
* @param[out] d2basis_xy_eval Derivative of `basis_eval` wrt x+y (same dimensions)
* @param[out] d2basis_xz_eval Derivative of `basis_eval` wrt x+z (same dimensions)
* @param[out] d2basis_yy_eval Derivative of `basis_eval` wrt y+y (same dimensions)
* @param[out] d2basis_yz_eval Derivative of `basis_eval` wrt y+z (same dimensions)
* @param[out] d2basis_zz_eval Derivative of `basis_eval` wrt z+z (same dimensions)
* @param[out] d3basis_xxx_eval Derivative of `basis_eval` wrt x+x+x (same dimensions)
* @param[out] d3basis_xxy_eval Derivative of `basis_eval` wrt x+x+y (same dimensions)
* @param[out] d3basis_xxz_eval Derivative of `basis_eval` wrt x+x+z (same dimensions)
* @param[out] d3basis_xyy_eval Derivative of `basis_eval` wrt x+y+y (same dimensions)
* @param[out] d3basis_xyz_eval Derivative of `basis_eval` wrt x+y+z (same dimensions)
* @param[out] d3basis_xzz_eval Derivative of `basis_eval` wrt x+z+z (same dimensions)
* @param[out] d3basis_yyy_eval Derivative of `basis_eval` wrt y+y+y (same dimensions)
* @param[out] d3basis_yyz_eval Derivative of `basis_eval` wrt y+y+z (same dimensions)
* @param[out] d3basis_yzz_eval Derivative of `basis_eval` wrt y+z+z (same dimensions)
* @param[out] d3basis_zzz_eval Derivative of `basis_eval` wrt z+z+z (same dimensions)
*/
void eval_collocation_der3( size_t npts, size_t nshells, size_t nbe,
const double* pts, const BasisSet<double>& basis, const int32_t* shell_list,
double* basis_eval, double* dbasis_x_eval, double* dbasis_y_eval,
double* dbasis_z_eval, double* d2basis_xx_eval, double* d2basis_xy_eval,
double* d2basis_xz_eval, double* d2basis_yy_eval, double* d2basis_yz_eval,
double* d2basis_zz_eval, double* d3basis_xxx_eval, double* d3basis_xxy_eval,
double* d3basis_xxz_eval, double* d3basis_xyy_eval, double* d3basis_xyz_eval,
double* d3basis_xzz_eval, double* d3basis_yyy_eval, double* d3basis_yyz_eval,
double* d3basis_yzz_eval, double* d3basis_zzz_eval);

/** Evaluate the compressed "X" matrix = fac * P * B
*
* @param[in] npts The number of points in the collocation matrix
Expand Down Expand Up @@ -238,6 +278,49 @@ class LocalHostWorkDriver : public LocalWorkDriver {
const double* Xz, size_t ldxz, double* den_eval,
double* dden_x_eval, double* dden_y_eval, double* dden_z_eval, double* gamma );

/** Evaluate the U and V variavles for RKS MGGA
*
* U = rho + gradient + tau + lapl
* V = rho + gamma + tau + lapl
*
* @param[in] npts Same as `eval_uvvar_lda`
* @param[in] nbe Same as `eval_uvvar_lda`
* @param[in] basis_eval Same as `eval_uvvar_lda`
* @param[in] dbasis_x_eval Derivative of `basis_eval` wrt x (same dims)
* @param[in] dbasis_y_eval Derivative of `basis_eval` wrt y (same dims)
* @param[in] dbasis_z_eval Derivative of `basis_eval` wrt z (same dims)
* @param[in] lbasis_eval Laplacian of `basis_eval` (same dims)
* @param[in] X Same as `eval_uvvar_lda`
* @param[in] ldx Same as `eval_uvvar_lda`
* @param[in] mmat_x
* @param[in] mmat_y
* @param[in] mmat_z
* @param[in] ldm
* @param[out] den_eval Same as `eval_uvvar_lda`
* @param[out] dden_x_eval Derivative of `den_eval` wrt x (npts)
* @param[out] dden_y_eval Derivative of `den_eval` wrt y (npts)
* @param[out] dden_z_eval Derivative of `den_eval` wrt z (npts)
* @param[out] gamma |grad rho|^2 (npts)
* @param[out] tau
* @param[out] lapl
*
*/
void eval_uvvar_mgga_rks( size_t npts, size_t nbe, const double* basis_eval,
const double* dbasis_x_eavl, const double* dbasis_y_eval,
const double* dbasis_z_eval, const double* lbasis_eval,
const double* X, size_t ldx, const double* mmat_x,
const double* mmat_y, const double* mmat_z, size_t ldm, double* den_eval,
double* dden_x_eval, double* dden_y_eval, double* dden_z_eval, double* gamma,
double* tau, double* lapl);
void eval_uvvar_mgga_uks( size_t npts, size_t nbe, const double* basis_eval,
const double* dbasis_x_eavl, const double* dbasis_y_eval,
const double* dbasis_z_eval, const double* lbasis_eval,
const double* Xs, size_t ldxs, const double* Xz, size_t ldxz,
const double* mmat_xs, const double* mmat_ys, const double* mmat_zs, size_t ldms,
const double* mmat_xz, const double* mmat_yz, const double* mmat_zz, size_t ldmz,
double* den_eval, double* dden_x_eval, double* dden_y_eval, double* dden_z_eval,
double* gamma, double* tau, double* lapl);

/** Evaluate the VXC Z Matrix for RKS LDA
*
* Z(mu,i) = 0.5 * vrho(i) * B(mu, i)
Expand Down Expand Up @@ -294,9 +377,57 @@ class LocalHostWorkDriver : public LocalWorkDriver {
double* Zs, size_t ldzs, double* Zz, size_t ldzz );


/** Evaluate the VXC Z Matrix for RKS MGGA
*
* Z(mu,i) = 0.5 * vrho(i) * B(mu, i) +
* 2.0 * vgamma(i) * (grad B(mu,i)) . (grad rho(i)) +
* 0.5 * vlapl(i) * lapl B(mu, i)
*
* TODO: Need to add an API for UKS/GKS
*
* @param[in] npts Same as `eval_zmat_lda_vxc`
* @param[in] nbe Same as `eval_zmat_lda_vxc`
* @param[in] vrho Same as `eval_zmat_lda_vxc`
* @param[in] vgamma Derivative of the XC functional wrt gamma scaled by quad weights (npts)
* @param[in] basis_eval Same as `eval_zmat_lda_vxc`
* @param[in] dbasis_x_eval Derivative of `basis_eval` wrt x (same dims)
* @param[in] dbasis_y_eval Derivative of `basis_eval` wrt y (same dims)
* @param[in] dbasis_z_eval Derivative of `basis_eval` wrt z (same dims)
* @param[in] lbasis_eval Laplacian of `basis_eval` (same dims)
* @param[in] dden_x_eval Derivative of rho wrt x (npts)
* @param[in] dden_y_eval Derivative of rho wrt y (npts)
* @param[in] dden_z_eval Derivative of rho wrt z (npts)
* @param[out] Z Same as `eval_zmat_lda_vxc`
* @param[in] ldz Same as `eval_zmat_lda_vxc`
*
*/
void eval_zmat_mgga_vxc_rks( size_t npts, size_t nbe, const double* vrho,
const double* vgamma, const double* vlapl, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval,
const double* dden_x_eval, const double* dden_y_eval, const double* dden_z_eval,
double* Z, size_t ldz );
void eval_zmat_mgga_vxc_uks( size_t npts, size_t nbe, const double* vrho,
const double* vgamma, const double* vlapl, const double* basis_eval,
const double* dbasis_x_eval, const double* dbasis_y_eval, const double* dbasis_z_eval,
const double* lbasis_eval,
const double* dden_x_eval, const double* dden_y_eval, const double* dden_z_eval,
double* Zs, size_t ldzs, double* Zz, size_t ldzz );
void eval_mmat_mgga_vxc_rks( size_t npts, size_t nbe, const double* vtau,
const double* vlapl, const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_x, double* mmat_y, double* mmat_z,
size_t ldm);
void eval_mmat_mgga_vxc_uks( size_t npts, size_t nbe, const double* vtau,
const double* vlapl, const double* dbasis_x_eval, const double* dbasis_y_eval,
const double* dbasis_z_eval, double* mmat_xs, double* mmat_ys, double* mmat_zs,
size_t ldms, double* mmat_xz, double* mmat_yz, double* mmat_zz, size_t ldmz);



/** Increment VXC integrand given Z / Collocation (RKS LDA+GGA)
*
* VXC += Z**H * B + h.c.
* VXC += M**H . dB + h.c.
*
* Only updates lower triangle
*
Expand All @@ -313,8 +444,8 @@ class LocalHostWorkDriver : public LocalWorkDriver {
*
*/
void inc_vxc( size_t npts, size_t nbf, size_t nbe, const double* basis_eval,
const submat_map_t& submat_map, const double* Z, size_t ldz, double* VXC,
size_t ldvxc, double* scr );
const submat_map_t& submat_map, const double* Z, size_t ldz,
double* VXC, size_t ldvxc, double* scr );

private:

Expand Down
Loading