Skip to content

Commit

Permalink
fix factorial2
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Dec 9, 2023
1 parent 7b9e026 commit a378c00
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/example/single_point/h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# define the wave function
wf = SlaterJastrow(mol, kinetic='jacobi',
configs='ground_state', jastrow=jastrow).gto2sto()
configs='ground_state', jastrow=jastrow) #.gto2sto()

# sampler
sampler = Metropolis(nwalkers=1000, nstep=1000, step_size=0.25,
Expand Down
18 changes: 13 additions & 5 deletions qmctorch/wavefunction/orbitals/norm_orbital.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np

from scipy.special import factorial2

def atomic_orbital_norm(basis):
"""Computes the norm of the atomic orbitals
Expand Down Expand Up @@ -141,16 +141,24 @@ def norm_gaussian_cartesian(a, b, c, exp):
torch.tensor: normalization factor
"""

from scipy.special import factorial2 as f2

pref = torch.as_tensor((2 * exp / np.pi) ** (0.75))
am1 = (2 * a - 1).astype("int")
x = (4 * exp) ** (a / 2) / torch.sqrt(torch.as_tensor(f2(am1)))

bm1 = (2 * b - 1).astype("int")
y = (4 * exp) ** (b / 2) / torch.sqrt(torch.as_tensor(f2(bm1)))

cm1 = (2 * c - 1).astype("int")
z = (4 * exp) ** (c / 2) / torch.sqrt(torch.as_tensor(f2(cm1)))
z = (4 * exp) ** (c / 2) / torch.sqrt(torch.as_tensor(f2(cm1)))

return (pref * x * y * z).type(torch.get_default_dtype())

def f2(x):
"""Returns the f2 of x with f2(x<1) = 1 as implemented in scipy 1.10.
"""
# compute the x!!
out = factorial2(x)

# set all the elements lower than 1 to 1
out[out<1] = 1
return out

0 comments on commit a378c00

Please sign in to comment.