Skip to content

Commit

Permalink
Merge pull request deepchem#3596 from shaipranesh2/energy_fucntion
Browse files Browse the repository at this point in the history
adding energy function for ferminet
  • Loading branch information
rbharath authored Oct 6, 2023
2 parents 1e19e7a + 1195ddb commit 039c6a7
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 20 deletions.
15 changes: 15 additions & 0 deletions deepchem/models/tests/test_ferminet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ def test_FerminetMode_pretrain():
mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)
mol.train(nb_epoch=3)
assert mol.loss_value <= torch.tensor(1.0)


@pytest.mark.dqc
def test_FerminetMode_energy():
# Test for the init function of FerminetModel class
H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
# Testing ionic initialization
mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)
mol.train(nb_epoch=50)
mol.model.forward(mol.molecule.x)
energy = mol.model.calculate_electron_electron(
) - mol.model.calculate_electron_nuclear(
) + mol.model.nuclear_nuclear_potential + mol.model.calculate_kinetic_energy(
)
assert energy <= torch.tensor(1.0)
95 changes: 90 additions & 5 deletions deepchem/models/torch_models/ferminet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __init__(self,
Modulelist containing the ferminet electron feature layer
ferminet_layer_envelope: torch.nn.ModuleList
Modulelist containing the ferminet envelope electron feature layer
nuclear_nuclear_potential: torch.Tensor
Torch tensor containing the inter-nuclear potential in the molecular system
"""
super(Ferminet, self).__init__()
if len(n_one) != len(n_two):
Expand All @@ -92,6 +94,8 @@ def __init__(self,
self.ferminet_layer_envelope: torch.nn.ModuleList = torch.nn.ModuleList(
)
self.running_diff: torch.Tensor = torch.zeros(self.batch_size)
self.nuclear_nuclear_potential: torch.Tensor = self.calculate_nuclear_nuclear(
)

self.ferminet_layer.append(
FerminetElectronFeature(self.n_one, self.n_two,
Expand All @@ -118,12 +122,13 @@ def forward(self, input: np.ndarray) -> torch.Tensor:
contains the wavefunction - 'psi' value. It is in the shape (batch_size), where each row corresponds to the solution of one of the batches
"""
# creating one and two electron features
eps = torch.tensor(1e-36)
self.input = torch.from_numpy(input)
self.input.requires_grad = True
self.input = self.input.reshape((self.batch_size, -1, 3))
two_electron_vector = self.input.unsqueeze(1) - self.input.unsqueeze(2)
two_electron_distance = torch.norm(two_electron_vector,
dim=3).unsqueeze(3)
two_electron_distance = torch.linalg.norm(two_electron_vector + eps,
dim=3).unsqueeze(3)
two_electron = torch.cat((two_electron_vector, two_electron_distance),
dim=3)
two_electron = torch.reshape(
Expand All @@ -132,7 +137,7 @@ def forward(self, input: np.ndarray) -> torch.Tensor:

one_electron_vector = self.input.unsqueeze(
1) - self.nucleon_pos.unsqueeze(1)
one_electron_distance = torch.norm(one_electron_vector, dim=3)
one_electron_distance = torch.linalg.norm(one_electron_vector, dim=3)
one_electron = torch.cat(
(one_electron_vector, one_electron_distance.unsqueeze(-1)), dim=3)
one_electron = torch.reshape(one_electron.permute(0, 2, 1, 3),
Expand All @@ -141,9 +146,9 @@ def forward(self, input: np.ndarray) -> torch.Tensor:

one_electron, _ = self.ferminet_layer[0].forward(
one_electron.to(torch.float32), two_electron.to(torch.float32))
psi, self.psi_up, self.psi_down = self.ferminet_layer_envelope[
self.psi, self.psi_up, self.psi_down = self.ferminet_layer_envelope[
0].forward(one_electron, one_electron_vector_permuted)
return psi
return self.psi

def loss(self,
psi_up_mo: List[Optional[np.ndarray]] = [None],
Expand All @@ -169,6 +174,86 @@ def loss(self,
self.psi_up, psi_up_mo_torch.float()) + criterion(
self.psi_down, psi_down_mo_torch.float())

def calculate_nuclear_nuclear(self,) -> torch.Tensor:
"""
Function to calculate where only the nucleus terms are involved and does not change when new electrons are sampled.
atom-atom potential term = Zi*Zj/|Ri-Rj|, where Zi, Zj are the nuclear charges and Ri, Rj are nuclear coordinates
Returns:
--------
A torch tensor of a scalar value containing the nuclear-nuclear potential term (does not change for the molecule system with sampling of electrons)
"""

potential = torch.nan_to_num(
(self.nuclear_charge * 1 /
torch.cdist(self.nucleon_pos.float(), self.nucleon_pos.float()) *
self.nuclear_charge.unsqueeze(1)),
posinf=0.0,
neginf=0.0)
potential = torch.sum(potential) / 2
return potential

def calculate_electron_nuclear(self,) -> torch.Tensor:
"""
Function to calculate the expected electron-nuclear potential term per batch
nuclear-electron potential term = Zi/|Ri-rj|, rj is the electron coordinates, Ri is the nuclear coordinates, Zi is the nuclear charge
Returns:
--------
A torch tensor of a scalar value containing the electron-nuclear potential term.
"""

potential = torch.sum(
(1 / torch.cdist(self.input.float(), self.nucleon_pos.float())) *
self.nuclear_charge) / 2
return (potential / self.batch_size)

def calculate_electron_electron(self,):
"""
Function to calculate the expected electron-nuclear potential term per batch
nuclear-electron potential term = 1/|ri-rj|, ri, rj is the electron coordinates
Returns:
--------
A torch tensor of a scalar value containing the electron-electron potential term.
"""
potential = torch.sum(
torch.nan_to_num(
(1 / torch.cdist(self.input.float(), self.input.float())),
posinf=0.0,
neginf=0.0)) / 2
return (potential / self.batch_size)

def calculate_kinetic_energy(self,):
"""
Function to calculate the expected kinetic energy term per batch
It is calculated via:
\sum_{ri}^{}[(\pdv[]{log|\Psi|}{(ri)})^2 + \pdv[2]{log|\Psi|}{(ri)}]
Returns:
--------
A torch tensor of a scalar value containing the electron-electron potential term.
"""
log_probability = torch.log(torch.abs(self.psi))
jacobian = list(
map(
lambda x: torch.autograd.grad(x, self.input, create_graph=True)[
0], log_probability))
jacobian_square = list(
map(lambda x: torch.sum(torch.pow(x, 2)), jacobian))
jacobian_square_sum = torch.tensor(0.0)
hessian = torch.tensor(0.0)
for i in range(self.batch_size):
jacobian_square_sum = jacobian_square_sum + jacobian_square[i]
for j in range(self.total_electron):
for k in range(3):
hessian = hessian + torch.autograd.grad(
jacobian[i][i][j][k], self.input,
create_graph=True)[0][i][j][k]
kinetic_energy = -1 * 0.5 * (jacobian_square_sum +
hessian) / (self.batch_size)
return kinetic_energy


class FerminetModel(TorchModel):
"""A deep-learning based Variational Monte Carlo method [1]_ for calculating the ab-initio
Expand Down
42 changes: 27 additions & 15 deletions deepchem/models/torch_models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5213,15 +5213,15 @@ def __init__(self, n_one: List[int], n_two: List[int], no_of_atoms: int,
# Initializing the first layer (first layer has different dims than others)
self.v.append(
nn.Linear(8 + 3 * 4 * self.no_of_atoms, self.n_one[0], bias=True))
#filling the weights with 2.5e-7 for faster convergence
self.v[0].weight.data.fill_(2.5e-7)
self.v[0].bias.data.fill_(2.5e-7)
#filling the weights with 1e-3 for faster convergence
self.v[0].weight.data.fill_(1e-3)
self.v[0].bias.data.fill_(1e-3)
self.v[0].weight.data = self.v[0].weight.data
self.v[0].bias.data = self.v[0].bias.data

self.w.append(nn.Linear(4, self.n_two[0], bias=True))
self.w[0].weight.data.fill_(2.5e-7)
self.w[0].bias.data.fill_(2.5e-7)
self.w[0].weight.data.fill_(1e-3)
self.w[0].bias.data.fill_(1e-3)
self.w[0].weight.data = self.w[0].weight.data
self.w[0].bias.data = self.w[0].bias.data

Expand All @@ -5230,18 +5230,27 @@ def __init__(self, n_one: List[int], n_two: List[int], no_of_atoms: int,
nn.Linear(3 * self.n_one[i - 1] + 2 * self.n_two[i - 1],
n_one[i],
bias=True))
self.v[i].weight.data.fill_(2.5e-7)
self.v[i].bias.data.fill_(2.5e-7)
self.v[i].weight.data.fill_(1e-3)
self.v[i].bias.data.fill_(1e-3)
self.v[i].weight.data = self.v[i].weight.data
self.v[i].bias.data = self.v[i].bias.data

self.w.append(nn.Linear(self.n_two[i - 1], self.n_two[i],
bias=True))
self.w[i].weight.data.fill_(2.5e-7)
self.w[i].weight.data.fill_(1e-3)
self.w[i].weight.data = self.w[i].weight.data
self.w[i].bias.data.fill_(2.5e-7)
self.w[i].bias.data.fill_(1e-3)
self.w[i].bias.data = self.w[i].bias.data

self.projection_module = nn.ModuleList()
self.projection_module.append(
nn.Linear(
4 * self.no_of_atoms,
n_one[0],
bias=False,
))
self.projection_module.append(nn.Linear(4, n_two[0], bias=False))

def forward(self, one_electron: torch.Tensor, two_electron: torch.Tensor):
"""
Parameters
Expand Down Expand Up @@ -5282,9 +5291,12 @@ def forward(self, one_electron: torch.Tensor, two_electron: torch.Tensor):
dim=1)
if l == 0 or (self.n_one[l] != self.n_one[l - 1]) or (
self.n_two[l] != self.n_two[l - 1]):
one_electron_tmp[:, i, :] = torch.tanh(self.v[l](f))
one_electron_tmp[:, i, :] = torch.tanh(
self.v[l](f)) + self.projection_module[0](
one_electron[:, i, :])
two_electron_tmp[:, i, :, :] = torch.tanh(self.w[l](
two_electron[:, i, :, :]))
two_electron[:, i, :, :])) + self.projection_module[1](
two_electron[:, i, :, :])
else:
one_electron_tmp[:, i, :] = torch.tanh(
self.v[l](f)) + one_electron[:, i, :]
Expand Down Expand Up @@ -5370,16 +5382,16 @@ def __init__(self, n_one: List[int], n_two: List[int], total_electron: int,
for j in range(self.total_electron):
self.envelope_w.append(
torch.nn.init.uniform(torch.empty(n_one[-1], 1),
b=2.5e-7).squeeze(-1))
b=1e-3).squeeze(-1))
self.envelope_g.append(
torch.nn.init.uniform(torch.empty(1), b=2.5e-7).squeeze(0))
torch.nn.init.uniform(torch.empty(1), b=1e-3).squeeze(0))
for k in range(self.no_of_atoms):
self.sigma.append(
torch.nn.init.uniform(torch.empty(self.no_of_atoms, 1),
b=2.5e-7).squeeze(0))
b=1e-3).squeeze(0))
self.pi.append(
torch.nn.init.uniform(torch.empty(self.no_of_atoms, 1),
b=2.5e-7).squeeze(0))
b=1e-3).squeeze(0))

def forward(self, one_electron: torch.Tensor,
one_electron_vector_permuted: torch.Tensor):
Expand Down

0 comments on commit 039c6a7

Please sign in to comment.