Skip to content

Commit

Permalink
first part of regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Oct 3, 2023
1 parent 874ffe2 commit 4149a5e
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
89 changes: 89 additions & 0 deletions tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
n_epochs: 1000
seed: 0

data:
directory: models/
experiment: apax

data_path: <PATH>

n_train: 1000
n_valid: 100

batch_size: 8
valid_batch_size: 100

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}
shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV

model:
n_basis: 7
n_radial: 5
nn: [512, 512]

r_max: 6.5
r_min: 0.5

calc_stress: false
use_zbl: true

b_init: normal
descriptor_dtype: fp32
readout_dtype: fp32
scale_shift_dtype: fp32

metrics:
- name: energy
reductions:
- mae
- name: forces
reductions:
- mae
- mse
# - name: stress
# reductions:
# - mae
# - mse

loss:
- loss_type: structures
name: energy
weight: 1.0
- loss_type: structures
name: forces
weight: 8.0
- loss_type: cosine_sim
name: forces
weight: 0.1
# - loss_type: structures
# name: stress
# weight: 1.0

optimizer:
opt_name: adam
opt_kwargs: {}
emb_lr: 0.003
nn_lr: 0.002
scale_lr: 0.0005
shift_lr: 0.025
zbl_lr: 0.001
transition_begin: 0

callbacks:
- name: csv

checkpoints:
ckpt_interval: 1
# The options below are used for transfer learning
# base_model_checkpoint: null
# reset_layers: []

progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false

maximize_l2_cache: false
50 changes: 50 additions & 0 deletions tests/regression_tests/test_apax_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import pathlib
import yaml
from apax.train.run import run
import urllib
import zipfile
import os

TEST_PATH = pathlib.Path(__file__).parent.resolve()
MD22_STACHYOSE_URL = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"

@pytest.mark.slow
def test_regression_model_training(get_tmp_path):
temp_path = get_tmp_path

#load dataset and safe it in temp_dir
data_path = temp_path / "data"
os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(MD22_STACHYOSE_URL, data_path / "md22_stachyose.zip")

with zipfile.ZipFile(data_path / "md22_stachyose.zip", "r") as zip_ref:
zip_ref.extractall(data_path)

input_file_path = (data_path / "md22_stachyose.xyz").as_posix()
output_file_path = (data_path / "md22_stachyose_mod.xyz").as_posix()
target_string = 'Energy'
replacement_string = 'energy'

replace_string_in_file(input_file_path, output_file_path, target_string, replacement_string)

#read and adjust config
confg_path = TEST_PATH / "apax_config.yaml"
with open(confg_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

config_dict["data"]["directory"] = temp_path.as_posix()
config_dict["data"]["data_path"] = output_file_path
config_dict["data"]["energy_unit"] = "kcal/mol"

run(config_dict)


assert False

def replace_string_in_file(input_file_path, output_file_path, target_string, replacement_string):
with open(input_file_path, 'r') as input_file, open(output_file_path, 'w') as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)

0 comments on commit 4149a5e

Please sign in to comment.