Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .flake8

This file was deleted.

4 changes: 2 additions & 2 deletions erbs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import warnings

from erbs.utils import setup_ase
import jax

from erbs.utils import setup_ase

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
jax.config.update("jax_enable_x64", True)

setup_ase()

warnings.filterwarnings(action="ignore", category=FutureWarning, module=r"jax.*scatter")

2 changes: 1 addition & 1 deletion erbs/ase_main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
This submodule contains the PLUMED ASE calculator from the main branch of the ASE repository.
PyPi projects cannot have GitLab/GitHub dependencies, hence the code for this calculator was copied here.
On the next ASE release, this submodule will be removed.
"""
"""
86 changes: 48 additions & 38 deletions erbs/ase_main/calculators/plumed.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from os.path import exists

import numpy as np
from ase.calculators.calculator import Calculator, all_changes
from ase.io.trajectory import Trajectory
from ase.parallel import broadcast
from ase.parallel import world
import numpy as np
from os.path import exists
from ase.units import fs, mol, kJ, nm
from ase.parallel import broadcast, world
from ase.units import fs, kJ, mol, nm


def restart_from_trajectory(prev_traj, *args, prev_steps=None, atoms=None,
**kwargs):
def restart_from_trajectory(prev_traj, *args, prev_steps=None, atoms=None, **kwargs):
"""This function helps the user to restart a plumed simulation
from a trajectory file.

Expand Down Expand Up @@ -44,10 +43,20 @@ def restart_from_trajectory(prev_traj, *args, prev_steps=None, atoms=None,


class Plumed(Calculator):
implemented_properties = ['energy', 'forces']

def __init__(self, calc, input, timestep, atoms=None, kT=1., log='',
restart=False, use_charge=False, update_charge=False):
implemented_properties = ["energy", "forces"]

def __init__(
self,
calc,
input,
timestep,
atoms=None,
kT=1.0,
log="",
restart=False,
use_charge=False,
update_charge=False,
):
"""
Plumed calculator is used for simulations of enhanced sampling methods
with the open-source code PLUMED (plumed.org).
Expand Down Expand Up @@ -109,8 +118,10 @@ def __init__(self, calc, input, timestep, atoms=None, kT=1., log='',
from plumed import Plumed as pl

if atoms is None:
raise TypeError('plumed calculator has to be defined with the \
object atoms inside.')
raise TypeError(
"plumed calculator has to be defined with the \
object atoms inside."
)

self.istep = 0
Calculator.__init__(self, atoms=atoms)
Expand All @@ -124,7 +135,7 @@ def __init__(self, calc, input, timestep, atoms=None, kT=1., log='',
natoms = len(atoms.get_positions())
self.plumed = pl()

''' Units setup
""" Units setup
warning: inputs and outputs of plumed will still be in
plumed units.

Expand All @@ -133,14 +144,14 @@ def __init__(self, calc, input, timestep, atoms=None, kT=1., log='',
nm to Angstrom
ps to ASE time units
ASE and plumed - charge unit is in e units
ASE and plumed - mass unit is in a.m.u units '''
ASE and plumed - mass unit is in a.m.u units """

ps = 1000 * fs
self.plumed.cmd("setMDEnergyUnits", mol / kJ)
self.plumed.cmd("setMDLengthUnits", 1 / nm)
self.plumed.cmd("setMDTimeUnits", 1 / ps)
self.plumed.cmd("setMDChargeUnits", 1.)
self.plumed.cmd("setMDMassUnits", 1.)
self.plumed.cmd("setMDChargeUnits", 1.0)
self.plumed.cmd("setMDMassUnits", 1.0)

self.plumed.cmd("setNatoms", natoms)
self.plumed.cmd("setMDEngine", "ASE")
Expand All @@ -154,17 +165,17 @@ def __init__(self, calc, input, timestep, atoms=None, kT=1., log='',
self.atoms = atoms

def _get_name(self):
return f'{self.calc.name}+Plumed'
return f"{self.calc.name}+Plumed"

def calculate(self, atoms=None, properties=['energy', 'forces'],
system_changes=all_changes):
def calculate(
self, atoms=None, properties=["energy", "forces"], system_changes=all_changes
):
Calculator.calculate(self, atoms, properties, system_changes)

comp = self.compute_energy_and_forces(self.atoms.get_positions(),
self.istep)
comp = self.compute_energy_and_forces(self.atoms.get_positions(), self.istep)
energy, forces = comp
self.istep += 1
self.results['energy'], self. results['forces'] = energy, forces
self.results["energy"], self.results["forces"] = energy, forces

def compute_energy_and_forces(self, pos, istep):
unbiased_energy = self.calc.get_potential_energy(self.atoms)
Expand All @@ -183,11 +194,10 @@ def compute_bias(self, pos, istep, unbiased_energy):
self.plumed.cmd("setStep", istep)

if self.use_charge:
if 'charges' in self.calc.implemented_properties and \
self.update_charge:
if "charges" in self.calc.implemented_properties and self.update_charge:
charges = self.calc.get_charges(atoms=self.atoms.copy())

elif self.atoms.has('initial_charges') and not self.update_charge:
elif self.atoms.has("initial_charges") and not self.update_charge:
charges = self.atoms.get_initial_charges()

else:
Expand Down Expand Up @@ -215,11 +225,11 @@ def compute_bias(self, pos, istep, unbiased_energy):
return [energy_bias, forces_bias]

def write_plumed_files(self, images):
""" This function computes what is required in
"""This function computes what is required in
plumed input for some trajectory.

The outputs are saved in the typical files of
plumed such as COLVAR, HILLS """
plumed such as COLVAR, HILLS"""
for i, image in enumerate(images):
pos = image.get_positions()
self.compute_energy_and_forces(pos, i)
Expand All @@ -231,25 +241,25 @@ def read_plumed_files(self, file_name=None):
read_files[file_name] = np.loadtxt(file_name, unpack=True)
else:
for line in self.input:
if line.find('FILE') != -1:
ini = line.find('FILE')
end = line.find(' ', ini)
if line.find("FILE") != -1:
ini = line.find("FILE")
end = line.find(" ", ini)
if end == -1:
file_name = line[ini + 5:]
file_name = line[ini + 5 :]
else:
file_name = line[ini + 5:end]
file_name = line[ini + 5 : end]
read_files[file_name] = np.loadtxt(file_name, unpack=True)

if len(read_files) == 0:
if exists('COLVAR'):
read_files['COLVAR'] = np.loadtxt('COLVAR', unpack=True)
if exists('HILLS'):
read_files['HILLS'] = np.loadtxt('HILLS', unpack=True)
if exists("COLVAR"):
read_files["COLVAR"] = np.loadtxt("COLVAR", unpack=True)
if exists("HILLS"):
read_files["HILLS"] = np.loadtxt("HILLS", unpack=True)
assert not len(read_files) == 0, "There are not files for reading"
return read_files

def __enter__(self):
return self

def __exit__(self, *args):
self.plumed.finalize()
self.plumed.finalize()
3 changes: 1 addition & 2 deletions erbs/bias/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from erbs.bias.energy_function_factory import (OPESExploreFactory,
MetaDFactory)
from erbs.bias.energy_function_factory import MetaDFactory, OPESExploreFactory
from erbs.bias.potential import ERBS

__all__ = [
Expand Down
6 changes: 1 addition & 5 deletions erbs/bias/energy_function_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ def __init__(self, T=300, dE=1.2, a=0.3, compression_threshold=0.4) -> None:
self.compression_threshold = compression_threshold

def create(self, cv_fn, dim_reduction_fn):

def energy_fn(positions, Z, neighbor, box, offsets, bias_state):

g_ref = bias_state.g
norm = bias_state.normalisation
cov = bias_state.cov
Expand All @@ -31,6 +29,7 @@ def energy_fn(positions, Z, neighbor, box, offsets, bias_state):

total_bias = jnp.sum(kde_ij)
return total_bias

return energy_fn


Expand All @@ -44,11 +43,8 @@ def __init__(self, T=300, dE=1.2, a=0.3, compression_threshold=0.4) -> None:
self.gamma = self.dE * self.beta
self.compression_threshold = compression_threshold


def create(self, cv_fn, dim_reduction_fn):

def energy_fn(positions, Z, neighbor, box, offsets, bias_state):

g_ref = bias_state.g
norm = bias_state.normalisation
cov = bias_state.cov
Expand Down
47 changes: 23 additions & 24 deletions erbs/bias/kernel.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import jax.numpy as jnp
import numpy as np


def gaussian(g_diff, k, a):
x = jnp.sum((g_diff) ** 2)
U = k * jnp.exp(-x / (a * 2.0))
return U


def diag_gaussian(gdiff, k, cov):
x = jnp.dot(gdiff.T, gdiff / cov)
U = k * jnp.exp(-x/2)
U = k * jnp.exp(-x / 2)
return U


def chunked_sum_of_kernels(X, k, cov, chunk_size: int|None=None):
def chunked_sum_of_kernels(X, k, cov, chunk_size: int | None = None):
if chunk_size is None:
gdiff = X[:, None, :] - X[None, :, :]
cov_broadcast = cov[:, None, :]
x = np.einsum('ijd,ijd->ij', gdiff, gdiff / cov_broadcast)
x = np.einsum("ijd,ijd->ij", gdiff, gdiff / cov_broadcast)
G_skk = k * np.exp(-0.5 * x)

return np.sum(G_skk)
Expand All @@ -40,7 +42,7 @@ def chunked_sum_of_kernels(X, k, cov, chunk_size: int|None=None):
# TODO CHECK CHUNKING FOR COV for chunksize = 1

gdiff = Xi[:, None, :] - Xj[None, :, :]

x = np.einsum("bcj, bcj -> bc", gdiff, gdiff / cov_chunk)
G_skk_chunk = k_chunk * np.exp(-x / 2.0)
G_skk += np.sum(G_skk_chunk)
Expand All @@ -49,28 +51,25 @@ def chunked_sum_of_kernels(X, k, cov, chunk_size: int|None=None):


def global_mc_normalisation(g_ref, height, cov):

G_skk = chunked_sum_of_kernels(g_ref, height, cov)
mc_norm = G_skk / g_ref.shape[0]

return mc_norm



def mc_normalisation(cluster_models, cluster_idxs, g_ref, height, var):

total_n_clusters = 0
elements = sorted(list(cluster_models.keys()))
mc_norm = np.zeros(np.max(cluster_idxs) + 1) # Z in the paper
elements = sorted(cluster_models.keys())
mc_norm = np.zeros(np.max(cluster_idxs) + 1) # Z in the paper

for element in elements:
current_n_clusters = 0
for cluster in range(cluster_models[element].n_clusters):
current_n_clusters += 1
cluster_with_offset = cluster + total_n_clusters
g_filtered = g_ref[cluster_idxs==cluster_with_offset]
height_filtered = height[cluster_idxs==cluster_with_offset]
var_filtered = var[cluster_idxs==cluster_with_offset]
g_filtered = g_ref[cluster_idxs == cluster_with_offset]
height_filtered = height[cluster_idxs == cluster_with_offset]
var_filtered = var[cluster_idxs == cluster_with_offset]

G_skk = chunked_sum_of_kernels(g_filtered, height_filtered, var_filtered)

Expand All @@ -80,27 +79,27 @@ def mc_normalisation(cluster_models, cluster_idxs, g_ref, height, var):
return mc_norm



def distances(P1, p2):
dv = P1 - p2[None, :]
d = np.linalg.norm(dv, axis=1)
return d


def mahalanobis(P1, Var1, p2):
arg = (P1 - p2[None, :])**2 / Var1
arg = (P1 - p2[None, :]) ** 2 / Var1
d = np.sqrt(np.sum(arg, axis=1))
return d


def combine_kernels(h1, h2, p1, p2, var1, var2):
h = h1 + h2
p = (h1 * p1 + h2 * p2) / h
var = (h1 * (var1 + p1**2) + h2 * (var2 + p2**2))/h - p**2
var = (h1 * (var1 + p1**2) + h2 * (var2 + p2**2)) / h - p**2

return p, h, var


def compress(g, cov, h, thresh=0.8):

gc = np.full(g.shape, 1000.0)
covc = np.full(g.shape, 0.0)
hc = np.full((g.shape[0], 1), 0.0)
Expand All @@ -118,7 +117,7 @@ def compress(g, cov, h, thresh=0.8):
h2 = h[ii]
cov2 = cov[ii]

dists = mahalanobis(P1,Cov1, p2)
dists = mahalanobis(P1, Cov1, p2)
dmin = np.min(dists)
idx = np.argmin(dists)

Expand All @@ -137,15 +136,15 @@ def compress(g, cov, h, thresh=0.8):
gc[idx] = pnew
hc[idx] = hnew
covc[idx] = covnew

gc = gc[:n_compressed]
covc = covc[:n_compressed]
hc = hc[:n_compressed]
return gc, covc, hc


def incremental_compress(gc, covc, hc, gnew, covnew, hnew, thresh=0.8):
dists = mahalanobis(gc,covc, gnew)
dists = mahalanobis(gc, covc, gnew)

dmin = np.min(dists)
idx = np.argmin(dists)
Expand All @@ -160,10 +159,10 @@ def incremental_compress(gc, covc, hc, gnew, covnew, hnew, thresh=0.8):
gc[idx] = pnew
hc[idx] = hnew
covc[idx] = covnew

else:
gc = np.append(gc, gnew[None,:], axis=0)
hc = np.append(hc, hnew[None,:], axis=0)
covc = np.append(covc, covnew[None,:], axis=0)
gc = np.append(gc, gnew[None, :], axis=0)
hc = np.append(hc, hnew[None, :], axis=0)
covc = np.append(covc, covnew[None, :], axis=0)

return gc, covc, hc
Loading