Skip to content

Commit 949487e

Browse files
authored
Merge branch 'main' into pre-commit-ci-update-config
2 parents 24080cb + 49feb18 commit 949487e

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

apax/md/function_transformations.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@
44
import jax.numpy as jnp
55

66

7-
@dataclasses.dataclass
8-
class FunctionTransformation:
9-
def apply(self, model):
10-
raise NotImplementedError
11-
12-
137
def make_biased_energy_force_fn(bias_fn):
148
def biased_energy_force_fn(positions, Z, idx, box, offsets):
159
bias_and_grad_fn = jax.value_and_grad(bias_fn, has_aux=True)
@@ -29,7 +23,8 @@ def biased_energy_force_fn(positions, Z, idx, box, offsets):
2923
return biased_energy_force_fn
3024

3125

32-
class UncertaintyDrivenDynamics(FunctionTransformation):
26+
@dataclasses.dataclass
27+
class UncertaintyDrivenDynamics:
3328
"""
3429
UDD requires an uncertainty aware model.
3530
It drives the dynamics towards higher uncertainty regions
@@ -67,7 +62,8 @@ def udd_energy(positions, Z, idx, box, offsets):
6762
return udd_energy_force
6863

6964

70-
class GaussianAcceleratedMolecularDynamics(FunctionTransformation):
65+
@dataclasses.dataclass
66+
class GaussianAcceleratedMolecularDynamics:
7167
"""
7268
Applies a boost potential to the system that pulls it towards a target energy.
7369
https://pubs.acs.org/doi/10.1021/acs.jctc.5b00436

0 commit comments

Comments
 (0)