Skip to content

Commit

Permalink
fixed eigh solver for uhf (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxj6000 authored Jan 30, 2024
1 parent cbaa083 commit ab2bbda
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions gpu4pyscf/scf/uhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def spin_square(mo, s=1):
r'''Spin square and multiplicity of UHF determinant
Detailed derivataion please refers to the cpu pyscf.
'''
mo_a, mo_b = mo
nocc_a = mo_a.shape[1]
Expand Down Expand Up @@ -111,7 +111,7 @@ class UHF(uhf.UHF):

DIIS = diis.SCF_DIIS
get_jk = _get_jk
_eigh = hf.RHF._eigh
_eigh = staticmethod(hf.eigh)
scf = kernel = RHF.kernel
get_fock = get_fock
get_hcore = hf.RHF.get_hcore
Expand Down Expand Up @@ -145,12 +145,12 @@ def make_rdm1(self, mo_coeff=None, mo_occ=None, **kwargs):
if mo_occ is None:
mo_occ = self.mo_occ
return make_rdm1(mo_coeff, mo_occ, **kwargs)

def eig(self, fock, s):
e_a, c_a = self._eigh(fock[0], s)
e_b, c_b = self._eigh(fock[1], s)
return cupy.array((e_a,e_b)), cupy.array((c_a,c_b))

def get_veff(self, mol=None, dm=None, dm_last=None, vhf_last=0, hermi=1):
if mol is None: mol = self.mol
if dm is None: dm = self.make_rdm1()
Expand All @@ -166,7 +166,7 @@ def get_veff(self, mol=None, dm=None, dm_last=None, vhf_last=0, hermi=1):
vhf = vj[0] + vj[1] - vk
vhf += vhf_last
return vhf

def spin_square(self, mo_coeff=None, s=None):
if mo_coeff is None:
mo_coeff = (self.mo_coeff[0][:,self.mo_occ[0]>0],
Expand Down

0 comments on commit ab2bbda

Please sign in to comment.