From aa4837f92efdc3c25419390d4341ff0d81f785f5 Mon Sep 17 00:00:00 2001 From: Zhiyi Li <86362692+zhiyil1230@users.noreply.github.com> Date: Fri, 13 Sep 2024 01:44:26 +0100 Subject: [PATCH] Finetune script (#13) * save initial files * add one more file * save current progress * tidy up, stress augmentation to fix * remove some redundant code * run tested fine wandb ok * remove ema * modify lr scheduler * some review comments * update version of black and rerun * remove e3nn dependency * Update finetune.py Co-authored-by: Ben Rhodes * Update finetune.py Co-authored-by: Ben Rhodes * Update finetune.py Co-authored-by: Ben Rhodes * Update finetune.py Co-authored-by: Ben Rhodes * save * second review comments * update readme * revert epoch saving * update readme and tidy up wandb reporting * Update README.md Co-authored-by: Ben Rhodes * update readme --------- Co-authored-by: Ben Rhodes --- README.md | 21 +- finetune.py | 344 ++++++++++++++++++ internal/check.py | 11 +- orb_models/dataset/ase_dataset.py | 193 ++++++++++ orb_models/forcefield/atomic_system.py | 9 +- orb_models/forcefield/base.py | 12 +- orb_models/forcefield/calculator.py | 9 +- .../forcefield/featurization_utilities.py | 3 +- orb_models/forcefield/gns.py | 7 +- orb_models/forcefield/graph_regressor.py | 16 +- orb_models/forcefield/nn_util.py | 3 +- orb_models/forcefield/pretrained.py | 6 +- orb_models/forcefield/property_definitions.py | 5 +- orb_models/forcefield/rbf.py | 1 + orb_models/forcefield/reference_energies.py | 1 + orb_models/forcefield/segment_ops.py | 3 +- orb_models/utils.py | 328 +++++++++++++++++ pyproject.toml | 6 +- tests/conftest.py | 3 +- tests/test_base.py | 1 + tests/test_calculator.py | 4 +- tests/test_featurization_utilities.py | 1 + tests/test_segment_ops.py | 3 +- 23 files changed, 936 insertions(+), 54 deletions(-) create mode 100644 finetune.py create mode 100644 orb_models/dataset/ase_dataset.py create mode 100644 orb_models/utils.py diff --git a/README.md b/README.md index d4b9007..706d530 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,8 @@ For more information on the models, please see the [MODELS.md](MODELS.md) file. import ase from ase.build import bulk -from orb_models.forcefield import pretrained -from orb_models.forcefield import atomic_system + +from orb_models.forcefield import atomic_system, pretrained from orb_models.forcefield.base import batch_graphs device = "cpu" # or device="cuda" @@ -66,10 +66,10 @@ atoms = atomic_system.atom_graphs_to_ase_atoms( ```python import ase from ase.build import bulk + from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator - device="cpu" # or device="cuda" orbff = pretrained.orb_v1(device=device) # or choose another model using ORB_PRETRAINED_MODELS[model_name]() calc = ORBCalculator(orbff, device=device) @@ -95,6 +95,21 @@ print("Optimized Energy:", atoms.get_potential_energy()) ``` +### Finetuning +You can finetune the model using your custom dataset. +The dataset should be an [ASE sqlite database](https://wiki.fysik.dtu.dk/ase/ase/db/db.html#module-ase.db.core). +```bash +python finetune.py --dataset=<dataset_name> --data_path=<your_data_path> +``` +After finetuning, checkpoints are saved by default to a `ckpts` folder in the directory from which you ran the finetuning script. + +You can then load the finetuned checkpoint with: +```python +from orb_models.forcefield import pretrained + +model = pretrained.orb_v1(weights_path=<path_to_checkpoint>) +``` + ### Citing We are currently preparing a preprint for publication.
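To make the finetune-then-load workflow described in the README section above concrete, here is a minimal sketch that loads a finetuned checkpoint and evaluates it through the ASE calculator. It assumes `finetune.py` has already been run with its defaults, so the last saved checkpoint is `ckpts/checkpoint_epoch49.ckpt` (with the default `--max_epochs=50` only the final epoch is written); the checkpoint path and the bulk-copper test system are illustrative assumptions, and passing both `weights_path` and `device` to `orb_v1` follows the README examples above.

```python
# Sketch only: load a finetuned orb-v1 checkpoint and run a single-point calculation.
from ase.build import bulk

from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

device = "cpu"  # or "cuda"

# Load the finetuned weights into the orb-v1 architecture (path is an example).
orbff = pretrained.orb_v1(weights_path="ckpts/checkpoint_epoch49.ckpt", device=device)
calc = ORBCalculator(orbff, device=device)

# Evaluate the finetuned model on a simple bulk copper cell.
atoms = bulk("Cu", "fcc", a=3.58, cubic=True)
atoms.calc = calc
print("Energy:", atoms.get_potential_energy())
```

Any structure ASE can build may be substituted for the copper cell; the calculator interface is the same as for the pretrained model.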
diff --git a/finetune.py b/finetune.py new file mode 100644 index 0000000..226a190 --- /dev/null +++ b/finetune.py @@ -0,0 +1,344 @@ +"""Finetuning loop.""" + +import argparse +import logging +import os +from typing import Dict, Optional, Union + +import torch +import tqdm +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler + +import wandb +from orb_models import utils +from orb_models.dataset.ase_dataset import AseSqliteDataset +from orb_models.forcefield import base, pretrained +from wandb import wandb_run + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) + + +def finetune( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + dataloader: DataLoader, + lr_scheduler: Optional[_LRScheduler] = None, + num_steps: Optional[int] = None, + clip_grad: Optional[float] = None, + log_freq: float = 10, + device: torch.device = torch.device("cpu"), + epoch: int = 0, +): + """Train for a fixed number of steps. + + Args: + model: The model to optimize. + optimizer: The optimizer for the model. + dataloader: A PyTorch DataLoader, which may be infinite if num_steps is passed. + lr_scheduler: Optional, a learning rate scheduler for modifying the learning rate. + num_steps: The number of training steps to take. This is required for distributed training, + because controlling parallelism is easier if all processes take exactly the same number of steps ( + this particularly applies when using dynamic batching). + clip_grad: Optional, the gradient clipping threshold. + log_freq: The logging frequency for step metrics. + device: The device to use for training. + epoch: The number of epochs the model has been finetuned. + + Returns: + A dictionary of metrics. + """ + run: Optional[wandb_run.Run] = wandb.run + + if clip_grad is not None: + hook_handles = utils.gradient_clipping(model, clip_grad) + + metrics = utils.ScalarMetricTracker() + + # Set the model to "train" mode. + model.train() + + # Get tqdm for the training batches + batch_generator = iter(dataloader) + num_training_batches: Union[int, float] + if num_steps is not None: + num_training_batches = num_steps + else: + try: + num_training_batches = len(dataloader) + except TypeError: + raise ValueError("Dataloader has no length, you must specify num_steps.") + + batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches) + + i = 0 + batch_iterator = iter(batch_generator_tqdm) + while True: + if num_steps and i == num_steps: + break + + optimizer.zero_grad(set_to_none=True) + + step_metrics = { + "batch_size": 0.0, + "batch_num_edges": 0.0, + "batch_num_nodes": 0.0, + } + + # Reset metrics so that raw values are reported for each step, while still + # averaging over gradient accumulation.
+ if i % log_freq == 0: + metrics.reset() + + batch = next(batch_iterator) + batch = batch.to(device) + step_metrics["batch_size"] += len(batch.n_node) + step_metrics["batch_num_edges"] += batch.n_edge.sum() + step_metrics["batch_num_nodes"] += batch.n_node.sum() + + with torch.cuda.amp.autocast(enabled=False): + batch_outputs = model.loss(batch) + loss = batch_outputs.loss + metrics.update(batch_outputs.log) + if torch.isnan(loss): + raise ValueError("nan loss encountered") + loss.backward() + + optimizer.step() + + if lr_scheduler is not None: + lr_scheduler.step() + + metrics.update(step_metrics) + + if i != 0 and i % log_freq == 0: + metrics_dict = metrics.get_metrics() + if run is not None: + step = (epoch * num_training_batches) + i + if run.sweep_id is not None: + run.log( + {"loss": metrics_dict["loss"]}, + commit=False, + ) + run.log( + {"step": step}, + commit=False, + ) + run.log(utils.prefix_keys(metrics_dict, "finetune_step"), commit=True) + + # Finished a single full step! + i += 1 + + if clip_grad is not None: + for h in hook_handles: + h.remove() + + return metrics.get_metrics() + + +def build_train_loader( + dataset_path: str, + num_workers: int, + batch_size: int, + augmentation: Optional[bool] = True, + target_config: Optional[Dict] = None, + **kwargs, +) -> DataLoader: + """Builds the train dataloader from a config file. + + Args: + dataset_path: Dataset path. + num_workers: The number of workers for each dataset. + batch_size: The batch_size config for each dataset. + augmentation: If rotation augmentation is used. + target_config: The target config. + + Returns: + The train Dataloader. + """ + log_train = "Loading train datasets:\n" + dataset = AseSqliteDataset( + dataset_path, target_config=target_config, augmentation=augmentation, **kwargs + ) + + log_train += f"Total train dataset size: {len(dataset)} samples" + logging.info(log_train) + + sampler = RandomSampler(dataset) + + batch_sampler = BatchSampler( + sampler, + batch_size=batch_size, + drop_last=False, + ) + + train_loader: DataLoader = DataLoader( + dataset, + num_workers=num_workers, + worker_init_fn=utils.worker_init_fn, + collate_fn=base.batch_graphs, + batch_sampler=batch_sampler, + timeout=10 * 60 if num_workers > 0 else 0, + ) + return train_loader + + +def run(args): + """Training Loop. + + Args: + config (DictConfig): Config for training loop. + """ + device = utils.init_device(device_id=args.device_id) + utils.seed_everything(args.random_seed) + + # Make sure to use this flag for matmuls on A100 and H100 GPUs. + torch.set_float32_matmul_precision("high") + + # Instantiate model + model = pretrained.orb_v1(device=device) + for param in model.parameters(): + param.requires_grad = True + model_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logging.info(f"Model has {model_params} trainable parameters.") + + # Move model to correct device. 
+ model.to(device=device) + total_steps = args.max_epochs * args.num_steps + optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model) + + wandb_run = None + # Logger instantiation/configuration + if args.wandb: + logging.info("Instantiating WandbLogger.") + wandb_run = utils.init_wandb_from_config( + dataset=args.dataset, job_type="finetuning", entity=args.wandb_entity + ) + + wandb.define_metric("step") + wandb.define_metric("finetune_step/*", step_metric="step") + + loader_args = dict( + dataset_path=args.data_path, + num_workers=args.num_workers, + batch_size=args.batch_size, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + ) + train_loader = build_train_loader( + **loader_args, + augmentation=True, + ) + logging.info("Starting training!") + + num_steps = args.num_steps + + start_epoch = 0 + + for epoch in range(start_epoch, args.max_epochs): + print(f"Start epoch: {epoch} training...") + finetune( + model=model, + optimizer=optimizer, + dataloader=train_loader, + lr_scheduler=lr_scheduler, + clip_grad=args.gradient_clip_val, + device=device, + num_steps=num_steps, + epoch=epoch, + ) + + # Save checkpoint from last epoch + if epoch == args.max_epochs - 1: + # create ckpts folder if it does not exist + if not os.path.exists(args.checkpoint_path): + os.makedirs(args.checkpoint_path) + torch.save( + model.state_dict(), + os.path.join(args.checkpoint_path, f"checkpoint_epoch{epoch}.ckpt"), + ) + logging.info(f"Checkpoint saved to {args.checkpoint_path}") + + if wandb_run is not None: + wandb_run.finish() + + +def main(): + """Main.""" + parser = argparse.ArgumentParser( + description="Finetune orb model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--random_seed", default=1234, type=int, help="Random seed for finetuning." + ) + parser.add_argument( + "--device_id", default=0, type=int, help="GPU index to use if GPU is available." + ) + parser.add_argument( + "--wandb", + default=True, + action="store_true", + help="If the run is logged to Weights and Biases (requires installation).", + ) + parser.add_argument( + "--wandb_entity", + default="orbitalmaterials", + type=str, + help="Entity to log the run to in Weights and Biases.", + ) + parser.add_argument( + "--dataset", + default="mp-traj", + type=str, + help="Dataset name for wandb run logging.", + ) + parser.add_argument( + "--data_path", + default=os.path.join(os.getcwd(), "datasets/mptraj/finetune.db"), + type=str, + help="Dataset path to an ASE sqlite database (you must convert your data into this format).", + ) + parser.add_argument( + "--num_workers", + default=8, + type=int, + help="Number of cpu workers for the pytorch data loader.", + ) + parser.add_argument( + "--batch_size", default=100, type=int, help="Batch size for finetuning." + ) + parser.add_argument( + "--gradient_clip_val", default=0.5, type=float, help="Gradient clip value." + ) + parser.add_argument( + "--max_epochs", + default=50, + type=int, + help="Maximum number of epochs to finetune.", + ) + parser.add_argument( + "--num_steps", + default=100, + type=int, + help="Num steps of in each epoch.", + ) + parser.add_argument( + "--checkpoint_path", + default=os.path.join(os.getcwd(), "ckpts"), + type=str, + help="Path to save the model checkpoint.", + ) + parser.add_argument( + "--lr", + default=3e-04, + type=float, + help="Learning rate. 
3e-4 is purely a sensible default; you may want to tune this for your problem.", + ) + args = parser.parse_args() + run(args) + + +if __name__ == "__main__": + main() diff --git a/internal/check.py b/internal/check.py index 6754977..0dd313e 100644 --- a/internal/check.py +++ b/internal/check.py @@ -1,14 +1,13 @@ """Integration tests to check compatibility of outputs with internal OM models.""" -import torch -import ase +import argparse -from orb_models.forcefield import pretrained -from orb_models.forcefield import atomic_system -from core.models import load +import ase +import torch from core.dataset import atomic_system as core_atomic_system +from core.models import load -import argparse +from orb_models.forcefield import atomic_system, pretrained def main(model: str, core_model: str): diff --git a/orb_models/dataset/ase_dataset.py b/orb_models/dataset/ase_dataset.py new file mode 100644 index 0000000..66f0a66 --- /dev/null +++ b/orb_models/dataset/ase_dataset.py @@ -0,0 +1,193 @@ +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +import ase +import ase.db +import ase.db.row +import numpy as np +import torch +from ase.stress import voigt_6_to_full_3x3_stress +from torch.utils.data import Dataset + +from orb_models.forcefield import atomic_system, property_definitions +from orb_models.forcefield.base import AtomGraphs +from orb_models.utils import rand_matrix + + +class AseSqliteDataset(Dataset): + """AseSqliteDataset. + + A Pytorch Dataset for reading ASE Sqlite serialized Atoms objects. + + Args: + dataset_path: Local path to read. + system_config: A config for controlling how an atomic system is represented. + target_config: A config for regression/classification targets. + augmentation: If random rotation augmentation is used. + + Returns: + An AseSqliteDataset. + """ + + def __init__( + self, + dataset_path: Union[str, Path], + system_config: Optional[atomic_system.SystemConfig] = None, + target_config: Optional[Dict] = None, + augmentation: Optional[bool] = True, + ): + super().__init__() + self.augmentation = augmentation + self.path = dataset_path + self.db = ase.db.connect(str(self.path), serial=True, type="db") + + self.feature_config = system_config + if target_config is None: + target_config = { + "graph": ["energy", "stress"], + "node": ["forces"], + "edge": [], + } + self.target_config = target_config + + def __getitem__(self, idx) -> AtomGraphs: + """Fetch an item from the db. + + Args: + idx: An index to fetch from the db file and convert to an AtomGraphs. + + Returns: + A AtomGraphs object containing everything the model needs as input, + positions and atom types and other auxillary information, such as + fine tuning targets, or global graph features. + """ + # Sqlite db is 1 indexed. 
+ row = self.db.get(idx + 1) + atoms = row.toatoms() + node_properties = property_definitions.get_property_from_row( + self.target_config["node"], row + ) + graph_property_dict = {} + for target_property in self.target_config["graph"]: + system_properties = property_definitions.get_property_from_row( + target_property, row + ) + graph_property_dict[target_property] = system_properties + extra_targets = { + "node": {"forces": node_properties}, + "edge": {}, + "graph": graph_property_dict, + } + if self.augmentation: + atoms, extra_targets = random_rotations_with_properties(atoms, extra_targets) # type: ignore + + atom_graph = atomic_system.ase_atoms_to_atom_graphs( + atoms, + system_id=idx, + brute_force_knn=False, + device=torch.device("cpu"), + ) + atom_graph = self._add_extra_targets(atom_graph, extra_targets) + + return atom_graph + + def get_atom(self, idx: int) -> ase.Atoms: + """Return the Atoms object for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms() + + def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]: + """Return the Atoms object plus a dict of metadata for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms(), row.data + + def __len__(self) -> int: + """Return the dataset length.""" + return len(self.db) + + def __repr__(self) -> str: + """String representation of class.""" + return f"AseSqliteDataset({self.path=})" + + def _add_extra_targets( + self, + atom_graph: AtomGraphs, + extra_targets: Dict[str, Dict], + ): + """Add extra features and targets to the AtomGraphs object. + + Args: + atom_graph: AtomGraphs object to add extra features and targets to. + extra_targets: Dictionary of extra targets to add. + """ + node_targets = ( + atom_graph.node_targets if atom_graph.node_targets is not None else {} + ) + node_targets = {**node_targets, **extra_targets["node"]} + + edge_targets = ( + atom_graph.edge_targets if atom_graph.edge_targets is not None else {} + ) + edge_targets = {**edge_targets, **extra_targets["edge"]} + + system_targets = ( + atom_graph.system_targets if atom_graph.system_targets is not None else {} + ) + system_targets = {**system_targets, **extra_targets["graph"]} + + return atom_graph._replace( + node_targets=node_targets if node_targets != {} else None, + edge_targets=edge_targets if edge_targets != {} else None, + system_targets=system_targets if system_targets != {} else None, + ) + + +def random_rotations_with_properties( + atoms: ase.Atoms, properties: dict +) -> Tuple[ase.Atoms, dict]: + """Randomly rotate atoms in ase.Atoms object. + + This exists to handle the case where we also need to rotate properties. + Currently we only ever do this for random rotations, but it could be extended. + + Args: + atoms (ase.Atoms): Atoms object to rotate. + properties (dict): Dictionary of properties to rotate. 
+ """ + rand_rotation = rand_matrix(1)[0].numpy() + atoms.positions = atoms.positions @ rand_rotation + if atoms.cell is not None: + atoms.set_cell(atoms.cell.array @ rand_rotation) + + new_node_properties = {} + for key, v in properties["node"].items(): + if tuple(v.shape) == tuple(atoms.positions.shape): + new_node_properties[key] = v @ rand_rotation + else: + new_node_properties[key] = v + properties["node"] = new_node_properties + + if "stress" in properties["graph"]: + # Transformation rule of stress tensor + stress = properties["graph"]["stress"] + full_stress = voigt_6_to_full_3x3_stress(stress) + + # The featurization code adds a batch dimension, so we need to reshape + if full_stress.shape != (3, 3): + full_stress = full_stress.reshape(3, 3) + + transformed = np.dot(np.dot(rand_rotation, full_stress), rand_rotation.T) + # Back to voigt notation, and shape (1, 6) for consistency with batching + properties["graph"]["stress"] = torch.tensor( + [ + transformed[0, 0], + transformed[1, 1], + transformed[2, 2], + transformed[1, 2], + transformed[0, 2], + transformed[0, 1], + ], + dtype=torch.float32, + ).unsqueeze(0) + + return atoms, properties diff --git a/orb_models/forcefield/atomic_system.py b/orb_models/forcefield/atomic_system.py index f8cd017..848e82e 100644 --- a/orb_models/forcefield/atomic_system.py +++ b/orb_models/forcefield/atomic_system.py @@ -1,14 +1,13 @@ -from typing import Optional, List from dataclasses import dataclass +from typing import List, Optional import ase +import torch from ase import constraints from ase.calculators.singlepoint import SinglePointCalculator - from orb_models.forcefield import featurization_utilities from orb_models.forcefield.base import AtomGraphs -import torch @dataclass @@ -96,7 +95,9 @@ def ase_atoms_to_atom_graphs( ), system_id: Optional[int] = None, brute_force_knn: Optional[bool] = None, - device: Optional[torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + device: Optional[torch.device] = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ), ) -> AtomGraphs: """Generate AtomGraphs from an ase.Atoms object. 
diff --git a/orb_models/forcefield/base.py b/orb_models/forcefield/base.py index 0d28135..63b9422 100644 --- a/orb_models/forcefield/base.py +++ b/orb_models/forcefield/base.py @@ -2,16 +2,8 @@ from collections import defaultdict from copy import deepcopy -from typing import ( - Any, - Dict, - Union, - NamedTuple, - Mapping, - Optional, - List, - Sequence, -) +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union + import torch import tree diff --git a/orb_models/forcefield/calculator.py b/orb_models/forcefield/calculator.py index 0ea3b09..d785438 100644 --- a/orb_models/forcefield/calculator.py +++ b/orb_models/forcefield/calculator.py @@ -1,10 +1,9 @@ -from ase.calculators.calculator import Calculator, all_changes from typing import Optional + import torch -from orb_models.forcefield.atomic_system import ( - SystemConfig, - ase_atoms_to_atom_graphs, -) +from ase.calculators.calculator import Calculator, all_changes + +from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs from orb_models.forcefield.graph_regressor import GraphRegressor diff --git a/orb_models/forcefield/featurization_utilities.py b/orb_models/forcefield/featurization_utilities.py index ca3de1b..7ae4e6b 100644 --- a/orb_models/forcefield/featurization_utilities.py +++ b/orb_models/forcefield/featurization_utilities.py @@ -1,6 +1,6 @@ """Featurization utilities for molecular models.""" -from typing import Callable, Tuple, Union, Optional +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -9,7 +9,6 @@ from pynanoflann import KDTree as NanoKDTree from scipy.spatial import KDTree as SciKDTree - DistanceFeaturizer = Callable[[torch.Tensor], torch.Tensor] diff --git a/orb_models/forcefield/gns.py b/orb_models/forcefield/gns.py index bb3cd12..11adcb6 100644 --- a/orb_models/forcefield/gns.py +++ b/orb_models/forcefield/gns.py @@ -2,12 +2,13 @@ from collections import OrderedDict from typing import List, Literal + +import numpy as np import torch from torch import nn -import numpy as np -from orb_models.forcefield import base + +from orb_models.forcefield import base, segment_ops from orb_models.forcefield.nn_util import build_mlp -from orb_models.forcefield import segment_ops _KEY = "feat" diff --git a/orb_models/forcefield/graph_regressor.py b/orb_models/forcefield/graph_regressor.py index 26a93e0..167523e 100644 --- a/orb_models/forcefield/graph_regressor.py +++ b/orb_models/forcefield/graph_regressor.py @@ -1,17 +1,14 @@ -from typing import Literal, Optional, Dict, Tuple, Union +from typing import Dict, Literal, Optional, Tuple, Union + +import numpy import torch import torch.nn as nn -import numpy -from orb_models.forcefield.property_definitions import ( - PROPERTIES, - PropertyDefinition, -) -from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES -from orb_models.forcefield import base +from orb_models.forcefield import base, segment_ops from orb_models.forcefield.gns import _KEY, MoleculeGNS from orb_models.forcefield.nn_util import build_mlp -from orb_models.forcefield import segment_ops +from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition +from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES global HAS_WARNED_FOR_TF32_MATMUL HAS_WARNED_FOR_TF32_MATMUL = False @@ -97,7 +94,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """Normalize by running mean and std.""" if self.training: - # hack: call batch norm, but only to update a running 
mean/std self.bn(x.view(-1, 1)) return (x - self.bn.running_mean) / torch.sqrt(self.bn.running_var) # type: ignore diff --git a/orb_models/forcefield/nn_util.py b/orb_models/forcefield/nn_util.py index 48cde63..6b24cc1 100644 --- a/orb_models/forcefield/nn_util.py +++ b/orb_models/forcefield/nn_util.py @@ -1,8 +1,9 @@ """Shared neural net utility functions.""" +from typing import List, Optional, Type + import torch import torch.nn.functional as F -from typing import List, Optional, Type from torch import nn from torch.utils.checkpoint import checkpoint_sequential diff --git a/orb_models/forcefield/pretrained.py b/orb_models/forcefield/pretrained.py index 28f4c7a..ecb826e 100644 --- a/orb_models/forcefield/pretrained.py +++ b/orb_models/forcefield/pretrained.py @@ -1,15 +1,17 @@ # flake8: noqa: E501 from typing import Union + import torch from cached_path import cached_path + from orb_models.forcefield.featurization_utilities import get_device +from orb_models.forcefield.gns import MoleculeGNS from orb_models.forcefield.graph_regressor import ( EnergyHead, - NodeHead, GraphHead, GraphRegressor, + NodeHead, ) -from orb_models.forcefield.gns import MoleculeGNS from orb_models.forcefield.rbf import ExpNormalSmearing global HAS_MESSAGED_FOR_TF32_MATMUL diff --git a/orb_models/forcefield/property_definitions.py b/orb_models/forcefield/property_definitions.py index 495e720..0fd8bc5 100644 --- a/orb_models/forcefield/property_definitions.py +++ b/orb_models/forcefield/property_definitions.py @@ -1,7 +1,7 @@ """Classes that define prediction targets.""" from dataclasses import dataclass -from typing import Any, Callable, Dict, Literal, Tuple, Union, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import ase.data import ase.db @@ -97,6 +97,7 @@ def energy_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: internally generated datasets. """ extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("energy", 1)], "mp-traj-d3": [("energy", 1), ("data.d3.energy", 1)], "alexandria-d3": [("energy", 1), ("data.d3.energy", 1)], } @@ -125,6 +126,7 @@ def forces_row_fn(row: ase.db.row.AtomsRow, dataset: str): internally generated datasets. 
""" extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("forces", 1)], "mp-traj-d3": [("forces", 1), ("data.d3.forces", 1)], "alexandria-d3": [("forces", 1), ("data.d3.forces", 1)], } @@ -145,6 +147,7 @@ def forces_row_fn(row: ase.db.row.AtomsRow, dataset: str): def stress_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: """Extract stress data.""" extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("stress", 1)], "mp-traj-d3": [("stress", 1), ("data.d3.stress", 1)], "alexandria-d3": [("stress", 1), ("data.d3.stress", 1)], } diff --git a/orb_models/forcefield/rbf.py b/orb_models/forcefield/rbf.py index 61e7b40..266bc3b 100644 --- a/orb_models/forcefield/rbf.py +++ b/orb_models/forcefield/rbf.py @@ -1,4 +1,5 @@ import math + import torch diff --git a/orb_models/forcefield/reference_energies.py b/orb_models/forcefield/reference_energies.py index c3837d3..f3a75c7 100644 --- a/orb_models/forcefield/reference_energies.py +++ b/orb_models/forcefield/reference_energies.py @@ -1,4 +1,5 @@ from typing import NamedTuple + import numpy diff --git a/orb_models/forcefield/segment_ops.py b/orb_models/forcefield/segment_ops.py index 4af5ae9..c25efbc 100644 --- a/orb_models/forcefield/segment_ops.py +++ b/orb_models/forcefield/segment_ops.py @@ -1,6 +1,7 @@ -import torch from typing import Optional +import torch + TORCHINT = [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8] diff --git a/orb_models/utils.py b/orb_models/utils.py new file mode 100644 index 0000000..8b221c2 --- /dev/null +++ b/orb_models/utils.py @@ -0,0 +1,328 @@ +"""Experiment utilities.""" + +import math +import os +import random +import re +from collections import defaultdict +from typing import Dict, List, Mapping, Optional, Tuple, TypeVar + +import numpy as np +import torch + +import wandb +from orb_models.forcefield import base +from wandb import wandb_run + +T = TypeVar("T") + + +def init_device(device_id: Optional[int] = None) -> torch.device: + """Initialize a device. + + Initializes a device based on the device id provided, + if not provided, it will use device_id = 0 if GPU is available. + """ + if not device_id: + device_id = 0 + if torch.cuda.is_available(): + device = f"cuda:{device_id}" + torch.cuda.set_device(device_id) + torch.cuda.empty_cache() + else: + device = "cpu" + return torch.device(device) + + +def worker_init_fn(id: int): + """Set seeds per worker, so augmentations etc are not duplicated across workers. + + Unused id arg is a requirement for the Dataloader interface. + + By default, each worker will have its PyTorch seed set to base_seed + worker_id, + where base_seed is a long generated by main process using its RNG + (thereby, consuming a RNG state mandatorily) or a specified generator. + However, seeds for other libraries may be duplicated upon initializing workers, + causing each worker to return identical random numbers. + + In worker_init_fn, you may access the PyTorch seed set for each worker with either + torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed + other libraries before data loading. 
+ """ + uint64_seed = torch.initial_seed() + ss = np.random.SeedSequence([uint64_seed]) + np.random.seed(ss.generate_state(4)) + random.seed(uint64_seed) + + +def ensure_detached(x: base.Metric) -> base.Metric: + """Ensure that the tensor is detached and on the CPU.""" + if isinstance(x, torch.Tensor): + return x.detach() + return x + + +def to_item(x: base.Metric) -> base.Metric: + """Convert a tensor to a python scalar.""" + if isinstance(x, torch.Tensor): + return x.cpu().item() + return x + + +def prefix_keys( + dict_to_prefix: Dict[str, T], prefix: str, sep: str = "/" +) -> Dict[str, T]: + """Add a prefix to dictionary keys with a seperator.""" + return {f"{prefix}{sep}{k}": v for k, v in dict_to_prefix.items()} + + +def seed_everything(seed: int, rank: int = 0) -> None: + """Set the seed for all pseudo random number generators.""" + random.seed(seed + rank) + np.random.seed(seed + rank) + torch.manual_seed(seed + rank) + + +def init_wandb_from_config(dataset: str, job_type: str, entity: str) -> wandb_run.Run: + """Initialise wandb.""" + wandb.init( # type: ignore + job_type=job_type, + dir=os.path.join(os.getcwd(), "wandb"), + name=f"{dataset}-{job_type}", + project="orb-experiment", + entity=entity, + mode="online", + sync_tensorboard=False, + ) + assert wandb.run is not None + return wandb.run + + +class ScalarMetricTracker: + """Keep track of average scalar metric values.""" + + def __init__(self): + self.reset() + + def reset(self): + """Reset the AverageMetrics.""" + self.sums = defaultdict(float) + self.counts = defaultdict(int) + + def update(self, metrics: Mapping[str, base.Metric]) -> None: + """Update the metric counts with new values.""" + for k, v in metrics.items(): + if isinstance(v, torch.Tensor) and v.nelement() > 1: + continue # only track scalar metrics + if isinstance(v, torch.Tensor) and v.isnan().any(): + continue + self.sums[k] += ensure_detached(v) + self.counts[k] += 1 + + def get_metrics(self): + """Get the metric values, possibly reducing across gpu processes.""" + return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()} + + +def gradient_clipping( + model: torch.nn.Module, clip_value: float +) -> List[torch.utils.hooks.RemovableHandle]: + """Add gradient clipping hooks to a model. + + This is the correct way to implement gradient clipping, because + gradients are clipped as gradients are computed, rather than after + all gradients are computed - this means expoding gradients are less likely, + because they are "caught" earlier. + + Args: + model: The model to add hooks to. + clip_value: The upper and lower threshold to clip the gradients to. + + Returns: + A list of handles to remove the hooks from the parameters. 
+ """ + handles = [] + + def _clip(grad): + if grad is None: + return grad + return grad.clamp(min=-clip_value, max=clip_value) + + for parameter in model.parameters(): + if parameter.requires_grad: + h = parameter.register_hook(lambda grad: _clip(grad)) + handles.append(h) + + return handles + + +def get_optim( + lr: float, total_steps: int, model: torch.nn.Module +) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler]]: + """Configure optimizers, LR schedulers and EMA.""" + + # Initialize parameter groups + params = [] + + # Split parameters based on the regex + for name, param in model.named_parameters(): + if re.search(r"(.*bias|.*layer_norm.*|.*batch_norm.*)", name): + params.append({"params": param, "weight_decay": 0.0}) + else: + params.append({"params": param}) + + # Create the optimizer with the parameter groups + optimizer = torch.optim.Adam(params, lr=lr) + + # Create the learning rate scheduler + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, max_lr=lr, total_steps=total_steps, pct_start=0.05 + ) + + return optimizer, scheduler + + +def rand_angles(*shape, requires_grad=False, dtype=None, device=None): + r"""random rotation angles + + Parameters + ---------- + *shape : int + + Returns + ------- + alpha : `torch.Tensor` + tensor of shape :math:`(\mathrm{shape})` + + beta : `torch.Tensor` + tensor of shape :math:`(\mathrm{shape})` + + gamma : `torch.Tensor` + tensor of shape :math:`(\mathrm{shape})` + """ + alpha, gamma = 2 * math.pi * torch.rand(2, *shape, dtype=dtype, device=device) + beta = torch.rand(shape, dtype=dtype, device=device).mul(2).sub(1).acos() + alpha = alpha.detach().requires_grad_(requires_grad) + beta = beta.detach().requires_grad_(requires_grad) + gamma = gamma.detach().requires_grad_(requires_grad) + return alpha, beta, gamma + + +def matrix_x(angle: torch.Tensor) -> torch.Tensor: + r"""matrix of rotation around X axis + + Parameters + ---------- + angle : `torch.Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `torch.Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = torch.ones_like(angle) + z = torch.zeros_like(angle) + return torch.stack( + [ + torch.stack([o, z, z], dim=-1), + torch.stack([z, c, -s], dim=-1), + torch.stack([z, s, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_y(angle: torch.Tensor) -> torch.Tensor: + r"""matrix of rotation around Y axis + + Parameters + ---------- + angle : `torch.Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `torch.Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = torch.ones_like(angle) + z = torch.zeros_like(angle) + return torch.stack( + [ + torch.stack([c, z, s], dim=-1), + torch.stack([z, o, z], dim=-1), + torch.stack([-s, z, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_z(angle: torch.Tensor) -> torch.Tensor: + r"""matrix of rotation around Z axis + + Parameters + ---------- + angle : `torch.Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `torch.Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = torch.ones_like(angle) + z = torch.zeros_like(angle) + return torch.stack( + [ + torch.stack([c, -s, z], dim=-1), + torch.stack([s, c, z], dim=-1), + torch.stack([z, z, o], dim=-1), + ], + dim=-2, + ) + + +def angles_to_matrix(alpha, beta, gamma): + r"""conversion from angles to matrix + + Parameters + ---------- + alpha : `torch.Tensor` + tensor of shape :math:`(...)` + + beta : 
`torch.Tensor` + tensor of shape :math:`(...)` + + gamma : `torch.Tensor` + tensor of shape :math:`(...)` + + Returns + ------- + `torch.Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) + return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma) + + +def rand_matrix(*shape, requires_grad=False, dtype=None, device=None): + r"""random rotation matrix + + Parameters + ---------- + *shape : int + + Returns + ------- + `torch.Tensor` + tensor of shape :math:`(\mathrm{shape}, 3, 3)` + """ + R = angles_to_matrix(*rand_angles(*shape, dtype=dtype, device=device)) + return R.detach().requires_grad_(requires_grad) diff --git a/pyproject.toml b/pyproject.toml index 4cb21b4..08cde16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,13 @@ classifiers = [ dependencies = [ "cached_path>=1.6.2", - "ase>=3.22.1", - "numpy<2.0.0", + "ase>=3.23.0", + "numpy>=1.26.4, <2.0.0", "scipy>=1.13.1", "torch==2.2.0", "dm-tree>=0.1.8", + "tqdm>=4.66.5", + "wandb>=0.17.7" ] [build-system] diff --git a/tests/conftest.py b/tests/conftest.py index 6217930..6b42349 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ -import pytest from pathlib import Path +import pytest + @pytest.fixture(scope="module") def fixtures_path(request): diff --git a/tests/test_base.py b/tests/test_base.py index f48b7b0..1eb332f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,6 +1,7 @@ # type: ignore import pytest import torch + from orb_models.forcefield import base from orb_models.forcefield.base import refeaturize_atomgraphs diff --git a/tests/test_calculator.py b/tests/test_calculator.py index a94da0c..c904e10 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -1,8 +1,8 @@ +import ase import numpy as np +import pytest import torch import torch.nn as nn -import ase -import pytest from orb_models.forcefield import segment_ops from orb_models.forcefield.calculator import ORBCalculator diff --git a/tests/test_featurization_utilities.py b/tests/test_featurization_utilities.py index 7a33561..b0cc9ca 100644 --- a/tests/test_featurization_utilities.py +++ b/tests/test_featurization_utilities.py @@ -1,6 +1,7 @@ """Tests featurization utilities.""" import functools + import ase import ase.io import ase.neighborlist diff --git a/tests/test_segment_ops.py b/tests/test_segment_ops.py index 16ca66e..d62ce02 100644 --- a/tests/test_segment_ops.py +++ b/tests/test_segment_ops.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from orb_models.forcefield import segment_ops import pytest import torch +from orb_models.forcefield import segment_ops + @pytest.mark.parametrize( ("reduction, dtype"),