
Commit

Merge pull request #371 from apax-hub/new_config
Updates the configuration files and introduces new default values.
Fixes minor bugs.
Tetracarbonylnickel authored Nov 22, 2024
2 parents a98d873 + a650472 commit 1782291
Showing 20 changed files with 687 additions and 386 deletions.
4 changes: 2 additions & 2 deletions apax/cli/templates/md_config_minimal.yaml
@@ -11,9 +11,9 @@ ensemble:
tau: 100

duration: <DURATION> # fs
n_inner: 100 # compiled inner steps
n_inner: 500 # compiled inner steps
sampling_rate: 10 # dump interval
buffer_size: 100
buffer_size: 2500
dr_threshold: 0.5 # Neighborlist skin
extra_capacity: 0

53 changes: 30 additions & 23 deletions apax/cli/templates/train_config_full.yaml
@@ -1,7 +1,6 @@
n_epochs: <NUMBER OF EPOCHS>
seed: 1
patience: null
n_jitted_steps: 1
data_parallel: True
weight_average: null

@@ -24,7 +23,7 @@ data:
n_train: 1000
n_valid: 100

batch_size: 32
batch_size: 4
valid_batch_size: 100

shift_method: "per_element_regression_shift"
@@ -39,30 +38,37 @@ data:
model:
name: gmnn
basis:
name: gaussian
n_basis: 7
r_max: 6.0
r_min: 0.5
name: bessel
n_basis: 16
r_max: 5.0

ensemble: null
# if you would like to train model ensembles, this can be achieved with
# the following example.
# if you would like to use empirical repulsion corrections,
# this can be achieved with the following example.
# empirical_corrections:
# - name: exponential
# r_max: 1.5

# if you would like to train model ensembles, this can be
# achieved with the following example.
# Hint: the loss type has to be changed to a probabilistic loss like nll or crps
# ensemble:
# kind: full
# kind: shallow
# n_members: N

n_radial: 5
n_contr: 8
nn: [512, 512]
nn: [256, 256]

calc_stress: true
calc_stress: false

w_init: normal
w_init: lecun
b_init: zeros
descriptor_dtype: fp64
descriptor_dtype: fp32
readout_dtype: fp32
scale_shift_dtype: fp32
scale_shift_dtype: fp64
emb_init: uniform
use_ntk: false

loss:
- name: energy
@@ -86,20 +92,21 @@ metrics:
optimizer:
name: adam
kwargs: {}
emb_lr: 0.03
nn_lr: 0.03
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
emb_lr: 0.001
nn_lr: 0.001
scale_lr: 0.0001
shift_lr: 0.003
zbl_lr: 0.0001
schedule:
name: linear
transition_begin: 0
end_value: 1e-6
name: cyclic_cosine
period: 40
decay_factor: 0.93

callbacks:
- name: csv

checkpoints:
ckpt_interval: 1
ckpt_interval: 500
# The options below are used for transfer learning
base_model_checkpoint: null
reset_layers: []
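For orientation, the commented hints in this template translate into roughly the following uncommented block. This is only a sketch: the n_members value is illustrative, and selecting a probabilistic loss via a loss_type key is an assumption about the loss schema rather than something shown in this diff.

model:
  ensemble:
    kind: shallow
    n_members: 4              # illustrative ensemble size
  empirical_corrections:
    - name: exponential
      r_max: 1.5

loss:
  - name: energy
    loss_type: nll            # assumed key for switching to a probabilistic loss (nll or crps)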
2 changes: 1 addition & 1 deletion apax/cli/templates/train_config_minimal.yaml
@@ -8,7 +8,7 @@ data:

n_train: 1000
n_valid: 100
batch_size: 32
batch_size: 4
valid_batch_size: 100

metrics:
8 changes: 4 additions & 4 deletions apax/config/lr_config.py
@@ -33,13 +33,13 @@ class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"):
Parameters
----------
period: int = 20
period: int = 40
Length of a cycle in epochs.
decay_factor: NonNegativeFloat = 1.0
decay_factor: NonNegativeFloat = 0.93
Factor by which to decrease the LR after each cycle.
1.0 means no decrease.
"""

name: Literal["cyclic_cosine"]
period: int = 20
decay_factor: NonNegativeFloat = 1.0
period: int = 40
decay_factor: NonNegativeFloat = 0.93
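In the training YAML, the new schedule defaults correspond to an optimizer block like the sketch below, mirroring the full template above; leaving out period or decay_factor falls back to the defaults of 40 and 0.93.

optimizer:
  name: adam
  schedule:
    name: cyclic_cosine
    period: 40          # length of one cosine cycle in epochs
    decay_factor: 0.93  # factor applied to the LR after each cycle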
8 changes: 4 additions & 4 deletions apax/config/md_config.py
@@ -229,12 +229,12 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
| Time step in fs.
duration : float, required
| Total simulation time in fs.
n_inner : int, default = 100
n_inner : int, default = 500
| Number of compiled simulation steps (i.e. number of iterations of the
| `jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate : int, default = 10
| Interval between saving frames.
buffer_size : int, default = 100
buffer_size : int, default = 2500
| Number of collected frames to be dumped at once.
dr_threshold : float, default = 0.5
| Skin of the neighborlist.
@@ -273,9 +273,9 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
)

duration: PositiveFloat
n_inner: PositiveInt = 100
n_inner: PositiveInt = 500
sampling_rate: PositiveInt = 10
buffer_size: PositiveInt = 100
buffer_size: PositiveInt = 2500
dr_threshold: PositiveFloat = 0.5
extra_capacity: NonNegativeInt = 0

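As a rough illustration of how the new MD defaults interact (assuming sampling_rate counts inner steps between saved frames; the exact dump behaviour is defined by the implementation, not by this sketch):

duration: <DURATION>  # fs
n_inner: 500          # compiled steps per jax.lax.fori_loop iteration
sampling_rate: 10     # one frame every 10 steps -> 50 frames per iteration
buffer_size: 2500     # buffer flushed to disk roughly every 2500 / 50 = 50 iterations
dr_threshold: 0.5     # neighborlist skin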
31 changes: 18 additions & 13 deletions apax/config/model_config.py
@@ -35,15 +35,15 @@ class BesselBasisConfig(BaseModel, extra="forbid"):
Parameters
----------
n_basis : PositiveInt, default = 7
n_basis : PositiveInt, default = 16
Number of uncontracted basis functions.
r_max : PositiveFloat, default = 6.0
r_max : PositiveFloat, default = 5.0
Cutoff radius of the descriptor.
"""

name: Literal["bessel"] = "bessel"
n_basis: PositiveInt = 7
r_max: PositiveFloat = 6.0
n_basis: PositiveInt = 16
r_max: PositiveFloat = 5.0


BasisConfig = Union[GaussianBasisConfig, BesselBasisConfig]
@@ -84,6 +84,11 @@ class ShallowEnsembleConfig(BaseModel, extra="forbid"):
If set to an integer, the Jacobian of ensemble energies with respect to positions will be
computed in chunks of that size. This trades some performance for the ability to use
relatively large ensemble sizes.
Hint
----
The loss type has to be changed to a probabilistic loss like 'nll' or 'crps'.
"""

kind: Literal["shallow"] = "shallow"
@@ -101,12 +106,12 @@ class Correction(BaseModel, extra="forbid"):

class ZBLRepulsion(Correction, extra="forbid"):
name: Literal["zbl"]
r_max: NonNegativeFloat = 2.0
r_max: NonNegativeFloat = 1.5


class ExponentialRepulsion(Correction, extra="forbid"):
name: Literal["exponential"]
r_max: NonNegativeFloat = 2.0
r_max: NonNegativeFloat = 1.5


EmpiricalCorrection = Union[ZBLRepulsion, ExponentialRepulsion]
@@ -120,31 +125,31 @@ class BaseModelConfig(BaseModel, extra="forbid"):
----------
basis : BasisConfig, default = BesselBasisConfig()
Configuration for primitive basis functions.
nn : List[PositiveInt], default = [512, 512]
nn : List[PositiveInt], default = [256, 256]
Number of hidden layers and units in those layers.
w_init : Literal["normal", "lecun"], default = "normal"
w_init : Literal["normal", "lecun"], default = "lecun"
Initialization scheme for the neural network weights.
b_init : Literal["normal", "zeros"], default = "normal"
b_init : Literal["normal", "zeros"], default = "zeros"
Initialization scheme for the neural network biases.
use_ntk : bool, default = True
use_ntk : bool, default = False
Whether or not to use NTK parametrization.
ensemble : Optional[EnsembleConfig], default = None
What kind of model ensemble to use (optional).
use_zbl : bool, default = False
Whether to include the ZBL correction.
calc_stress : bool, default = False
Whether to calculate stress during model evaluation.
descriptor_dtype : Literal["fp32", "fp64"], default = "fp64"
descriptor_dtype : Literal["fp32", "fp64"], default = "fp32"
Data type for descriptor calculations.
readout_dtype : Literal["fp32", "fp64"], default = "fp32"
Data type for readout calculations.
scale_shift_dtype : Literal["fp32", "fp64"], default = "fp32"
scale_shift_dtype : Literal["fp32", "fp64"], default = "fp64"
Data type for scale and shift parameters.
"""

basis: BasisConfig = Field(BesselBasisConfig(name="bessel"), discriminator="name")

nn: List[PositiveInt] = [128, 128]
nn: List[PositiveInt] = [256, 256]
w_init: Literal["normal", "lecun"] = "lecun"
b_init: Literal["normal", "zeros"] = "zeros"
use_ntk: bool = False
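Collected in one place, the new model defaults documented above correspond to a YAML model block along these lines (a sketch assembled from the template and the docstring defaults; any omitted field simply keeps its default):

model:
  name: gmnn
  basis:
    name: bessel
    n_basis: 16
    r_max: 5.0
  nn: [256, 256]
  w_init: lecun
  b_init: zeros
  use_ntk: false
  calc_stress: false
  descriptor_dtype: fp32
  readout_dtype: fp32
  scale_shift_dtype: fp64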
31 changes: 19 additions & 12 deletions apax/config/train_config.py
@@ -118,7 +118,10 @@ class DataConfig(BaseModel, extra="forbid"):
| dict of property name, shape (ragged or fixed) pairs. Currently unused.
energy_regularisation :
| Magnitude of the regularization in the per-element energy regression.
pos_unit : str, default = "Ang"
| Unit of length.
energy_unit : str, default = "eV"
| Unit of energy.
"""

directory: str
@@ -210,16 +213,20 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
----------
name : str, default = "adam"
Name of the optimizer. Can be any `optax` optimizer.
emb_lr : NonNegativeFloat, default = 0.02
emb_lr : NonNegativeFloat, default = 0.001
Learning rate of the elemental embedding contraction coefficients.
nn_lr : NonNegativeFloat, default = 0.03
nn_lr : NonNegativeFloat, default = 0.001
Learning rate of the neural network parameters.
scale_lr : NonNegativeFloat, default = 0.001
scale_lr : NonNegativeFloat, default = 0.0001
Learning rate of the elemental output scaling factors.
shift_lr : NonNegativeFloat, default = 0.05
shift_lr : NonNegativeFloat, default = 0.003
Learning rate of the elemental output shifts.
zbl_lr : NonNegativeFloat, default = 0.001
zbl_lr : NonNegativeFloat, default = 0.0001
Learning rate of the ZBL correction parameters.
rep_scale_lr : NonNegativeFloat, default = 0.001
Learning rate of the length scale of the exponential repulsion potential.
rep_prefactor_lr : NonNegativeFloat, default = 0.0001
Learning rate of the strength of the exponential repulsion potential.
gradient_clipping: NonNegativeFloat, default = 1000.0
Per-element gradient clipping value.
The default is so high that clipping is effectively disabled.
@@ -230,11 +237,11 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
"""

name: str = "adam"
emb_lr: NonNegativeFloat = 0.02
nn_lr: NonNegativeFloat = 0.03
scale_lr: NonNegativeFloat = 0.001
shift_lr: NonNegativeFloat = 0.05
zbl_lr: NonNegativeFloat = 0.001
emb_lr: NonNegativeFloat = 0.001
nn_lr: NonNegativeFloat = 0.001
scale_lr: NonNegativeFloat = 0.0001
shift_lr: NonNegativeFloat = 0.003
zbl_lr: NonNegativeFloat = 0.0001
rep_scale_lr: NonNegativeFloat = 0.001
rep_prefactor_lr: NonNegativeFloat = 0.0001

@@ -362,7 +369,7 @@ class CheckpointConfig(BaseModel, extra="forbid"):
reset_layers: List of layer names for which the parameters will be reinitialized.
"""

ckpt_interval: PositiveInt = 1
ckpt_interval: PositiveInt = 500
base_model_checkpoint: Optional[str] = None
reset_layers: List[str] = []

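The new optimizer and checkpoint defaults above amount to the following YAML sketch; the rep_* learning rates are presumably only relevant when an exponential repulsion correction is configured.

optimizer:
  name: adam
  emb_lr: 0.001
  nn_lr: 0.001
  scale_lr: 0.0001
  shift_lr: 0.003
  zbl_lr: 0.0001
  rep_scale_lr: 0.001       # length scale of the exponential repulsion
  rep_prefactor_lr: 0.0001  # strength (prefactor) of the exponential repulsion

checkpoints:
  ckpt_interval: 500        # interval between checkpoints, raised from 1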
4 changes: 1 addition & 3 deletions apax/md/simulate.py
@@ -17,7 +17,7 @@
from apax.config import Config, MDConfig, parse_config
from apax.config.md_config import Integrator
from apax.md.ase_calc import make_ensemble, maybe_vmap
from apax.md.constraints import Constraint, ConstraintBase, FixAtoms
from apax.md.constraints import Constraint, ConstraintBase
from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks
from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint
from apax.md.md_checkpoint import load_md_state
@@ -255,8 +255,6 @@ def run_sim(
dynamics_checks,
)

constraints = [FixAtoms(indices=[6, 8])]

apply_constraints = create_constraint_function(
constraints,
state,
3 changes: 3 additions & 0 deletions apax/nodes/optimizer/__init__.py
@@ -0,0 +1,3 @@
from apax.optimizer.get_optimizer import get_opt

__all__ = ["get_opt"]