Skip to content

Commit

Permalink
regression test overhaul
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Oct 5, 2023
1 parent 5519d7c commit ade74a5
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 57 deletions.
34 changes: 33 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

import os
import urllib
import zipfile
import numpy as np
import pytest
from ase import Atoms
Expand Down Expand Up @@ -74,3 +76,33 @@ def example_atoms(num_data: int, pbc: bool, calc_results: List[str]) -> Atoms:
def get_tmp_path(tmp_path_factory):
test_path = tmp_path_factory.mktemp("apax_tests")
return test_path


@pytest.fixture(scope="session")
def get_md22_stachyose(get_tmp_path):
url = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"
data_path = get_tmp_path / "data"
file_path = data_path / "md22_stachyose.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)

return file_path


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path
16 changes: 8 additions & 8 deletions tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
n_epochs: 1000
n_epochs: 200
seed: 0

data:
Expand All @@ -10,7 +10,7 @@ data:
n_train: 1000
n_valid: 100

batch_size: 8
batch_size: 32
valid_batch_size: 100

shift_method: "per_element_regression_shift"
Expand Down Expand Up @@ -66,10 +66,10 @@ loss:
optimizer:
opt_name: adam
opt_kwargs: {}
emb_lr: 0.003
nn_lr: 0.002
scale_lr: 0.0005
shift_lr: 0.025
emb_lr: 0.02
nn_lr: 0.03
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0

Expand All @@ -83,7 +83,7 @@ checkpoints:
# reset_layers: []

progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false
disable_epoch_pbar: true
disable_nl_pbar: true

maximize_l2_cache: false
82 changes: 34 additions & 48 deletions tests/regression_tests/test_apax_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,82 +2,68 @@
import pathlib
import urllib
import zipfile
import uuid

import numpy as np
import pandas as pd
import pytest
import yaml

from apax.train.run import run
from apax.data.statistics import scale_method_list, shift_method_list

TEST_PATH = pathlib.Path(__file__).parent.resolve()
MD22_STACHYOSE_URL = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"


def download_and_extract_data(data_path, filename, url, file_format):
file_path = data_path / filename

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

return file_path.with_suffix(f".{file_format}")


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path


def load_config_and_run_training(
config_path, file_path, working_dir, energy_unit="eV", pos_unit="Ang"
config_path, **config_kwargs
):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

config_dict["data"]["directory"] = working_dir.as_posix()
config_dict["data"]["data_path"] = file_path.as_posix()
config_dict["data"]["energy_unit"] = energy_unit
config_dict["data"]["pos_unit"] = pos_unit
for pydentic_model_key, config_mods in config_kwargs.items():
for h_param_key, value in config_mods.items():
config_dict[pydentic_model_key][h_param_key] = value

run(config_dict)


def load_csv(filename):
data = np.loadtxt(filename, delimiter=',', skiprows=1) # Skip the header row

with open(filename, 'r') as file:
header = file.readline().strip().split(',')

return
data_dict = {header[i]: data[:, i].tolist() for i in range(len(header))}

return data_dict


@pytest.mark.slow
def test_regression_model_training(get_tmp_path):
def test_regression_model_training(get_md22_stachyose, get_tmp_path):
config_path = TEST_PATH / "apax_config.yaml"
working_dir = get_tmp_path
data_path = working_dir / "data"
filename = "md22_stachyose.zip"
working_dir = get_tmp_path / str(uuid.uuid4())
file_path = get_md22_stachyose

file_path = download_and_extract_data(data_path, filename, MD22_STACHYOSE_URL, "xyz")

file_path = modify_xyz_file(
file_path, target_string="Energy", replacement_string="energy"
)
data_config_mods = {
"directory": working_dir.as_posix(),
"data_path": file_path.as_posix(),
"energy_unit": "kcal/mol",
}

load_config_and_run_training(config_path, file_path, working_dir, "kcal/mol")
load_config_and_run_training(config_path, data=data_config_mods)

current_metrics = pd.read_csv(working_dir / "test/log.csv")
current_metrics = load_csv(working_dir / "test/log.csv")

comparison_metrics = {
"val_energy_mae": 0.2048215700433502,
"val_forces_mae": 0.054957914591049,
"val_forces_mse": 0.0056583952479869,
"val_loss": 0.1395589689994847,
"val_energy_mae": 0.24696787788040334,
"val_forces_mae": 0.09672525137916232,
"val_forces_mse": 0.017160819058234304,
"val_loss": 0.45499257304743396,
}
for key in comparison_metrics.keys():
print(np.array(current_metrics[key])[-1])

for key in comparison_metrics.keys():
assert (
abs((np.array(current_metrics[key])[-1] / comparison_metrics[key]) - 1) < 1e-4
)
abs((np.array(current_metrics[key])[-1] / comparison_metrics[key]) - 1) < 1e-3
)

0 comments on commit ade74a5

Please sign in to comment.