Skip to content

Commit

Permalink
add AD support for ROHF (#11)
Browse files Browse the repository at this point in the history
* add ROHF
* update codecov
* add environment.yml
* update README
  • Loading branch information
fishjojo authored Mar 11, 2023
1 parent 5128e15 commit a5d4e7e
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: test
run: ./.github/workflows/run_test.sh
- name: Upload coverage to codecov
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v3
with:
token: ${{secrets.CODECOV_TOKEN}}
files: ./pyscfad/coverage.xml
6 changes: 3 additions & 3 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ valid-metaclass-classmethod-first-arg=mcs

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
overgeneral-exceptions=builtins.StandardError,
builtins.Exception,
builtins.BaseException

18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ pip install 'pyscf-properties @ git+https://github.com/fishjojo/properties.git@a
pip install pyscfad
```

* To install the development version of pyscfad, use the following command instead:
---
* To install the development version, use the following command instead:
```
pip install git+https://github.com/fishjojo/pyscfad.git
```
---

* To Install from source and install dependencies manually
* The dependencies can be installed via a predefined conda environment
```
conda env create -f environment.yml
conda activate pyscfad_env
```

* Alternatively, the dependencies can be installed from source
```
pip install numpy scipy h5py
pip install jax jaxlib jaxopt
Expand All @@ -51,7 +57,7 @@ export PYTHONPATH=$HOME/pyscf:$PYTHONPATH
```

---
* Running pyscfad inside a docker container:
* One can also run PySCFAD inside a docker container:
```
docker pull fishjojo/pyscfad:latest
docker run -rm -t -i fishjojo/pyscfad:latest /bin/bash
Expand All @@ -60,7 +66,9 @@ docker run -rm -t -i fishjojo/pyscfad:latest /bin/bash
Running examples
----------------

* Add the following lines to the PySCF configure file ($HOME/.pyscf\_conf.py)
* In order to perform AD calculations,
the following lines need to be added to
the PySCF configure file($HOME/.pyscf\_conf.py)
```
pyscfad = True
pyscf_numpy_backend = 'jax'
Expand Down
11 changes: 11 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
comment: false
coverage:
status:
project:
default:
threshold: 1%
target: 80%
patch:
default:
threshold: 1%
informational: true
14 changes: 14 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: pyscfad_env
channels:
- defaults
dependencies:
- python=3.9
- pip=23.0
- pip:
- numpy>=1.17
- scipy
- jax>=0.1.65,<0.3.14
- jaxlib>=0.1.65,<0.3.14
- jaxopt>=0.2
- -e git+https://github.com/fishjojo/pyscf.git@ad#egg=pyscf
- -e git+https://github.com/fishjojo/properties.git@ad#egg=pyscf-properties
34 changes: 34 additions & 0 deletions examples/scf/03-rohf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pyscf
from pyscfad import gto, scf

"""
Analytic nuclear gradient for ROHF computed by auto-differentiation
Reference results from PySCF:
converged SCF energy = -75.578154312784
--------------- ROHF gradients ---------------
x y z
0 O 0.0000000000 -0.0000000000 -0.0023904882
1 H 0.0000000000 -0.0432752607 0.0011952441
2 H -0.0000000000 0.0432752607 0.0011952441
----------------------------------------------
"""

mol = gto.Mole()
mol.atom = '''
O 0.000000 0.000000 0.117790
H 0.000000 0.755453 -0.471161
H 0.000000 -0.755453 -0.471161'''
mol.basis = '631g'
mol.charge = 1
mol.spin = 1 # = 2S = spin_up - spin_down
mol.verbose = 5
mol.build()

mf = scf.ROHF(mol)
mf.kernel()

jac = mf.energy_grad()
print(f'Nuclaer gradient:\n{jac.coords}')
print(f'Gradient wrt basis exponents:\n{jac.exp}')
print(f'Gradient wrt basis contraction coefficients:\n{jac.ctr_coeff}')
10 changes: 5 additions & 5 deletions pyscfad/gto/mole.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def __init__(self, **kwargs):
self.exp = None
self.ctr_coeff = None
self.r0 = None
gto.Mole.__init__(self, **kwargs)
super().__init__(**kwargs)

def atom_coords(self, unit='Bohr'):
if self.coords is None:
return gto.Mole.atom_coords(self, unit)
return super().atom_coords(unit)
else:
if unit[:3].upper() == 'ANG':
return self.coords * param.BOHR
Expand All @@ -81,7 +81,7 @@ def build(self, *args, **kwargs):
trace_ctr_coeff = kwargs.pop('trace_ctr_coeff', True)
trace_r0 = kwargs.pop('trace_r0', False)

gto.Mole.build(self, *args, **kwargs)
super().build(*args, **kwargs)

if trace_coords:
self.coords = np.asarray(self.atom_coords())
Expand All @@ -99,8 +99,8 @@ def intor(self, intor, comp=None, hermi=0, aosym='s1', out=None,
shls_slice=None, grids=None):
if (self.coords is None and self.exp is None
and self.ctr_coeff is None and self.r0 is None):
return gto.Mole.intor(self, intor, comp=comp, hermi=hermi,
aosym=aosym, out=out, shls_slice=shls_slice)
return super().intor(intor, comp=comp, hermi=hermi,
aosym=aosym, out=out, shls_slice=shls_slice)
else:
return moleintor.getints(self, intor, shls_slice,
comp, hermi, aosym, out=None)
4 changes: 4 additions & 0 deletions pyscfad/scf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from pyscfad.scf import hf
from pyscfad.scf import uhf
from pyscfad.scf import rohf

def RHF(mol, **kwargs):
return hf.RHF(mol, **kwargs)

def UHF(mol, **kwargs):
return uhf.UHF(mol, **kwargs)

def ROHF(mol, **kwargs):
return rohf.ROHF(mol, **kwargs)
12 changes: 11 additions & 1 deletion pyscfad/scf/hf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import partial
from functools import partial, reduce
import numpy
from jax import jacrev, jacfwd
from jaxopt import linear_solve
Expand Down Expand Up @@ -138,6 +138,8 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None,
fock = mf.get_fock(h1e, s1e, vhf, dm)
mo_energy, mo_coeff = mf.eig(fock, s1e)
mo_occ = mf.get_occ(stop_grad(mo_energy), stop_grad(mo_coeff))
# hack for ROHF
mo_energy = getattr(mo_energy, 'mo_energy', mo_energy)
return scf_conv, e_tot, mo_energy, mo_coeff, mo_occ

if isinstance(mf.diis, lib.diis.DIIS):
Expand Down Expand Up @@ -204,6 +206,9 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None,
if dump_chk:
mf.dump_chk(locals())

# hack for ROHF
mo_energy = getattr(mo_energy, 'mo_energy', mo_energy)

log.timer('scf_cycle', *cput0)
del log
# A post-processing hook before return
Expand Down Expand Up @@ -235,6 +240,11 @@ def _dot_eri_dm_nosymm(eri, dm, with_j, with_k):
vk = vk.reshape(dm.shape)
return vj, vk

def level_shift(s, d, f, factor):
dm_vir = s - reduce(np.dot, (s, d, s))
return f + dm_vir * factor


@util.pytree_node(Traced_Attributes, num_args=1)
class SCF(pyscf_hf.SCF):
'''
Expand Down
Loading

0 comments on commit a5d4e7e

Please sign in to comment.