Skip to content

Commit

Permalink
Merge pull request #114 from apax-hub/dev
Browse files Browse the repository at this point in the history
Accumulated changes from `dev` since Jan. 25.
  • Loading branch information
M-R-Schaefer authored Mar 22, 2023
2 parents 1be0838 + 7dd478c commit 16e2217
Show file tree
Hide file tree
Showing 114 changed files with 3,488 additions and 2,533 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/documentation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: documentation

on:
push:
schedule:
- cron: '14 3 * * 1' # at 03:14 on Monday.

jobs:
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9

- name: Run Poetry Image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: 1.2.2

- name: Install Sphinx Dependencies
run: |
poetry --version
poetry install
- name: Build documentation
run: |
cd docs
poetry run sphinx-build -b html source build
6 changes: 3 additions & 3 deletions .github/workflows/linting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Black Check
uses: psf/black@stable
with:
src: "./gmnn_jax"
src: "./apax"
version: "22.10.0"

isort:
Expand All @@ -29,7 +29,7 @@ jobs:
- name: run isort
run: |
isort --check-only --quiet gmnn_jax
isort --check-only --quiet apax
flake8:
runs-on: ubuntu-latest
Expand All @@ -45,4 +45,4 @@ jobs:
- name: run flake8
run: |
flake8 gmnn_jax --count --show-source --statistics
flake8 apax --count --show-source --statistics
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 GM-NN
Copyright (c) 2022 Moritz Schäfer and Nico Segreto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
93 changes: 65 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,10 @@
## Roadmap

- [x] basic loading of fixed size ASE structures into `tf.data.Dataset`
- [x] basic linear regressor atomic number -> energy
- [x] per-example model + `vmap utiliation`
- [x] loading model parameters from TF GMNN
- [x] basic training loop
- [x] basic metrics
- [x] hooks / tensorboard
- [x] model checkpoints
- [x] restart
- [ ] advanced training loop
- [ ] MLIP metrics
- [x] async checkpoints
- [x] jit compiled metrics
- [x] dataset statistics
- [x] precomputing neighborlists with `jax_md`
- [ ] tests
- [ ] documentation
- [ ] generalize to differently sized molecules
- [x] Optimizer with different lr for different parameter groups
- [x] GMNN energy model with `jax_md`
- [x] force model
- [x] running MD with GMNN
# `apax`: Atomstic learned Potentials in JAX!
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

`apax` is a high-performance, extendable package for training of and inference with atomistic neural networks.
It implements the Gaussian Moment Neural Network model [2, 3].
It is based on [JAX](https://jax.readthedocs.io/en/latest/) and uses [JaxMD](https://github.com/jax-md/jax-md) as a molecular dynamics engine.


## Installation
Expand All @@ -32,16 +15,16 @@ You can install [Poetry](https://python-poetry.org/) via
curl -sSL https://install.python-poetry.org | python3 -
```

Now you can install GMNN in your project by running
Now you can install apax in your project by running

```bash
poetry add git+https://github.com/GM-NN/gmnn-jax.git
poetry add git+https://github.com/apax-hub/apax.git
```

As a developer, you can clone the repository and install it via

```bash
git clone https://github.com/GM-NN/gmnn-jax.git <dest_dir>
git clone https://github.com/apax-hub/apax.git <dest_dir>
cd <dest_dir>
poetry install
```
Expand All @@ -57,4 +40,58 @@ pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

See the [Jax installation instructions](https://github.com/google/jax#installation) for more details.
See the [Jax installation instructions](https://github.com/google/jax#installation) for more details.



## Usage

### Your first apax Model

In order to train a model, you need to run

```python
apax train config.yaml
```

We offer some input file templates to get new users started as quickly as possible.
Simply run the following commands and add the appropriate entries in the marked fields

```python
apax template train # use --full for a template with all input options
```

Please refer to the documentation LINK for a detailed explanation of all parameters.
The documentation can convenienty be accessed by runnning `apax docs`.

## Molecular Dynamics

There are two ways in which `apax` models can be used for molecular dynamics out of the box.
High performance NVT simulations using JaxMD can be started with the CLI by running

```python
apax md config.yaml md_config.yaml
```

A template command for MD input files is provided as well.

The second way is to use the ASE calculator provided in `apax.md`.


## Authors
- Moritz René Schäfer
- Nico Segreto

Under the supervion of Johannes Kästner

## References
* [1] DOI PLACEHOLDER
* [2] V. Zaverkin and J. Kästner, [“Gaussian Moments as Physically Inspired Molecular Descriptors for Accurate and Scalable Machine Learning Potentials,”](https://doi.org/10.1021/acs.jctc.0c00347) J. Chem. Theory Comput. **16**, 5410–5421 (2020).
* [3] V. Zaverkin, D. Holzmüller, I. Steinwart, and J. Kästner, [“Fast and Sample-Efficient Interatomic Neural Network Potentials for Molecules and Materials Based on Gaussian Moments,”](https://pubs.acs.org/doi/10.1021/acs.jctc.1c00527) J. Chem. Theory Comput. **17**, 6658–6670 (2021).


## Contributing

We are happy to receive your issues and pull requests!

Do not hesitate to contact any of the authors above if you have any further questions.
5 changes: 4 additions & 1 deletion gmnn_jax/__init__.py → apax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import warnings

import tensorflow as tf
from jax.config import config as jax_config

tf.config.set_visible_devices([], "GPU")

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
warnings.filterwarnings(action="ignore", category=FutureWarning, module=r"jax.*scatter")
tf.config.experimental.set_visible_devices([], "GPU")
jax_config.update("jax_enable_x64", True)
File renamed without changes.
67 changes: 28 additions & 39 deletions gmnn_jax/cli/gmnn_app.py → apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
from pydantic import ValidationError
from rich.console import Console

from gmnn_jax.cli import templates
from apax.cli import templates

console = Console(highlight=False)
app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]})

app = typer.Typer(
context_settings={"help_option_names": ["-h", "--help"]},
pretty_exceptions_show_locals=False,
)
validate_app = typer.Typer(
pretty_exceptions_show_locals=False,
context_settings={"help_option_names": ["-h", "--help"]},
help="Validate training or MD config files.",
)
template_app = typer.Typer(
pretty_exceptions_show_locals=False,
context_settings={"help_option_names": ["-h", "--help"]},
help="Create configuration file templates.",
)
Expand All @@ -32,9 +38,9 @@ def train(
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
):
"""
Starts the training of a GMNN model with parameters provided by a configuration file.
Starts the training of a model with parameters provided by a configuration file.
"""
from gmnn_jax.train.run import run
from apax.train.run import run

run(train_config_path, log_file, log_level)

Expand All @@ -52,7 +58,7 @@ def md(
Starts performing a molecular dynamics simulation (currently only NHC thermostat)
with paramters provided by a configuration file.
"""
from gmnn_jax.md import run_md
from apax.md import run_md

run_md(train_config_path, md_config_path, log_file, log_level)

Expand All @@ -74,7 +80,7 @@ def eval(
Starts performing the evaluation of the test dataset
with parameters provided by a configuration file.
"""
from gmnn_jax.train.eval import eval_model
from apax.train.eval import eval_model

eval_model(train_config_path, n_data)

Expand All @@ -84,8 +90,8 @@ def docs():
"""
Opens the documentation website in your browser.
"""
console.print("Opening gmnn-jax's docs at https://github.com/GM-NN/gmnn-jax")
typer.launch("https://github.com/GM-NN/gmnn-jax")
console.print("Opening apax's docs at https://github.com/apax-hub/apax")
typer.launch("https://github.com/apax-hub/apax")


@validate_app.command("train")
Expand All @@ -101,7 +107,7 @@ def validate_train_config(
----------
config_path: Path to the training configuration file.
"""
from gmnn_jax.config import Config
from apax.config import Config

with open(config_path, "r") as stream:
user_config = yaml.safe_load(stream)
Expand Down Expand Up @@ -130,7 +136,7 @@ def validate_md_config(
----------
config_path: Path to the molecular dynamics configuration file.
"""
from gmnn_jax.config import MDConfig
from apax.config import MDConfig

with open(config_path, "r") as stream:
user_config = yaml.safe_load(stream)
Expand Down Expand Up @@ -165,12 +171,12 @@ def visualize_model(
----------
config_path: Path to the training configuration file.
"""
from jax_md.partition import space
import jax
from jax_md import space

from gmnn_jax.config import Config
from gmnn_jax.model import get_training_model
from gmnn_jax.utils.data import make_minimal_input
from gmnn_jax.visualize import model_tabular
from apax.config import Config
from apax.model.builder import ModelBuilder
from apax.utils.data import make_minimal_input

with open(config_path, "r") as stream:
user_config = yaml.safe_load(stream)
Expand All @@ -182,16 +188,12 @@ def visualize_model(
console.print("Configuration Invalid!", style="red3")
raise typer.Exit(code=1)

displacement_fn, _ = space.free()
R, Z, idx = make_minimal_input()

gmnn = get_training_model(
n_atoms=2,
n_species=10,
displacement_fn=displacement_fn,
**config.model.get_dict(),
R, Z, idx, box = make_minimal_input()
builder = ModelBuilder(config.model.get_dict(), n_species=10)
model = builder.build_energy_model(
displacement_fn=space.free()[0],
)
model_tabular(gmnn, R, Z, idx)
print(model.tabulate(jax.random.PRNGKey(0), R, Z, idx, box))


@template_app.command("train")
Expand Down Expand Up @@ -235,23 +237,10 @@ def template_md_config():
config.write(template_content)


logo = """
[bold white] /###### /## /## /## /## /## /##[bold turquoise2] /##### /###### /## /##
[bold white] /##__ ##| ### /###| ### | ##| ### | ##[bold turquoise2] |__ ## /##__ ##| ## / ##
[bold white]| ## \__/| #### /####| ####| ##| ####| ##[bold turquoise2] | ##| ## \ ##| ##/ ##/
[bold white]| ## /####| ## ##/## ##| ## ## ##| ## ## ##[bold turquoise2] /###### | ##| ######## \ ####/
[bold white]| ##|_ ##| ## ###| ##| ## ####| ## ####[bold turquoise2]|______/ /## | ##| ##__ ## >## ##
[bold white]| ## \ ##| ##\ # | ##| ##\ ###| ##\ ###[bold turquoise2] | ## | ##| ## | ## /##/\ ##
[bold white]| ######/| ## \/ | ##| ## \ ##| ## \ ##[bold turquoise2] | ######/| ## | ##| ## \ ##
[bold white] \______/ |__/ |__/|__/ \__/|__/ \__/[bold turquoise2] \______/ |__/ |__/|__/ |__/
""" # noqa: E501, W605, W291, E261, E303


def version_callback(value: bool) -> None:
"""Get the installed gmnn-jax version."""
"""Get the installed apax version."""
if value:
console.print(logo)
console.print(f"gmnn-jax {importlib.metadata.version('gmnn-jax')}")
console.print(f"apax {importlib.metadata.version('apax')}")
raise typer.Exit()


Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ seed: 1

data:
model_path: models/
model_name: gmnn
model_name: apax


# Use either data_path for a single dataset file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ n_epochs: <NUMBER OF EPOCHS>

data:
model_path: models
model_name: gmnn
model_name: apax
data_path: <PATH>

n_train: 100
Expand Down
4 changes: 2 additions & 2 deletions gmnn_jax/config/__init__.py → apax/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from gmnn_jax.config.md_config import MDConfig
from gmnn_jax.config.train_config import (
from apax.config.md_config import MDConfig
from apax.config.train_config import (
CallbackConfig,
Config,
DataConfig,
Expand Down
File renamed without changes.
Loading

0 comments on commit 16e2217

Please sign in to comment.