diff --git a/tests/regression_tests/apax_config.yaml b/tests/regression_tests/apax_config.yaml new file mode 100644 index 00000000..3ed1f789 --- /dev/null +++ b/tests/regression_tests/apax_config.yaml @@ -0,0 +1,89 @@ +n_epochs: 1000 +seed: 0 + +data: + directory: models/ + experiment: apax + + data_path: + + n_train: 1000 + n_valid: 100 + + batch_size: 8 + valid_batch_size: 100 + + shift_method: "per_element_regression_shift" + shift_options: {"energy_regularisation": 1.0} + shuffle_buffer_size: 1000 + + pos_unit: Ang + energy_unit: eV + +model: + n_basis: 7 + n_radial: 5 + nn: [512, 512] + + r_max: 6.5 + r_min: 0.5 + + calc_stress: false + use_zbl: true + + b_init: normal + descriptor_dtype: fp32 + readout_dtype: fp32 + scale_shift_dtype: fp32 + +metrics: + - name: energy + reductions: + - mae + - name: forces + reductions: + - mae + - mse + # - name: stress + # reductions: + # - mae + # - mse + +loss: + - loss_type: structures + name: energy + weight: 1.0 + - loss_type: structures + name: forces + weight: 8.0 + - loss_type: cosine_sim + name: forces + weight: 0.1 + # - loss_type: structures + # name: stress + # weight: 1.0 + +optimizer: + opt_name: adam + opt_kwargs: {} + emb_lr: 0.003 + nn_lr: 0.002 + scale_lr: 0.0005 + shift_lr: 0.025 + zbl_lr: 0.001 + transition_begin: 0 + +callbacks: +- name: csv + +checkpoints: + ckpt_interval: 1 + # The options below are used for transfer learning + # base_model_checkpoint: null + # reset_layers: [] + +progress_bar: + disable_epoch_pbar: false + disable_nl_pbar: false + +maximize_l2_cache: false diff --git a/tests/regression_tests/test_apax_training.py b/tests/regression_tests/test_apax_training.py new file mode 100644 index 00000000..e1bde72c --- /dev/null +++ b/tests/regression_tests/test_apax_training.py @@ -0,0 +1,50 @@ +import pytest +import pathlib +import yaml +from apax.train.run import run +import urllib +import zipfile +import os + +TEST_PATH = pathlib.Path(__file__).parent.resolve() +MD22_STACHYOSE_URL = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip" + +@pytest.mark.slow +def test_regression_model_training(get_tmp_path): + temp_path = get_tmp_path + + #load dataset and safe it in temp_dir + data_path = temp_path / "data" + os.makedirs(data_path, exist_ok=True) + urllib.request.urlretrieve(MD22_STACHYOSE_URL, data_path / "md22_stachyose.zip") + + with zipfile.ZipFile(data_path / "md22_stachyose.zip", "r") as zip_ref: + zip_ref.extractall(data_path) + + input_file_path = (data_path / "md22_stachyose.xyz").as_posix() + output_file_path = (data_path / "md22_stachyose_mod.xyz").as_posix() + target_string = 'Energy' + replacement_string = 'energy' + + replace_string_in_file(input_file_path, output_file_path, target_string, replacement_string) + + #read and adjust config + confg_path = TEST_PATH / "apax_config.yaml" + with open(confg_path.as_posix(), "r") as stream: + config_dict = yaml.safe_load(stream) + + config_dict["data"]["directory"] = temp_path.as_posix() + config_dict["data"]["data_path"] = output_file_path + config_dict["data"]["energy_unit"] = "kcal/mol" + + run(config_dict) + + + assert False + +def replace_string_in_file(input_file_path, output_file_path, target_string, replacement_string): + with open(input_file_path, 'r') as input_file, open(output_file_path, 'w') as output_file: + for line in input_file: + # Replace all occurrences of the target string with the replacement string + modified_line = line.replace(target_string, replacement_string) + output_file.write(modified_line) \ No newline at end of file