Skip to content

Commit

Permalink
default jastrow
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Dec 5, 2023
1 parent fab56a9 commit e3900c9
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions qmctorch/wavefunction/slater_jastrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .. import log

from .wf_base import WaveFunction
from .jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from .jastrows.elec_elec.kernels import PadeJastrowKernel
from .jastrows.combine_jastrow import CombineJastrow
from .orbitals.atomic_orbitals import AtomicOrbitals
from .orbitals.atomic_orbitals_backflow import AtomicOrbitalsBackFlow
Expand All @@ -21,7 +23,7 @@ class SlaterJastrow(WaveFunction):
def __init__(
self,
mol,
jastrow=None,
jastrow='default',
backflow=None,
configs="ground_state",
kinetic="jacobi",
Expand All @@ -42,8 +44,8 @@ def __init__(
Args:
mol (Molecule): a QMCTorch molecule object
jastrow_kernel (JastrowKernelBase, optional) : Class that computes the jastrow kernels
backflow_kernel (BackFlowKernelBase, optional) : kernel function of the backflow transformation
jastrow (JastrowKernelBase, optional) : Class that computes the jastrow kernels
backflow (BackFlowKernelBase, optional) : kernel function of the backflow transformation
configs (str, optional): defines the CI configurations to be used. Defaults to 'ground_state'.
- ground_state : only the ground state determinant in the wave function
- single(n,m) : only single excitation with n electrons and m orbitals
Expand Down Expand Up @@ -181,18 +183,28 @@ def init_fc_layer(self):
def init_jastrow(self, jastrow):
"""Init the jastrow factor calculator"""

# if the jastrow is explicitly None we disable the factor
if jastrow is None:
self.jastrow = jastrow
self.use_jastrow = False

# otherwise we use the jastrow provided by the user
else:
self.use_jastrow = True

if isinstance(jastrow, list):
# create a simple Pade Jastrow factor as default
if jastrow == 'default':
self.jastrow = JastrowFactorElectronElectron(self.mol, PadeJastrowKernel)

elif isinstance(jastrow, list):
self.jastrow = CombineJastrow(jastrow)
else:

elif isinstance(jastrow, nn.Module):
self.jastrow = jastrow

else:
raise TypeError('Jastrow factor not supported.')

self.jastrow_type = self.jastrow.__repr__()
if self.cuda:
self.jastrow = self.jastrow.to(self.device)
Expand Down

0 comments on commit e3900c9

Please sign in to comment.