Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Oct 5, 2023
1 parent ade74a5 commit ee0b36e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 25 deletions.
17 changes: 9 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List
import os
import urllib
import zipfile
from typing import List

import numpy as np
import pytest
from ase import Atoms
Expand Down Expand Up @@ -89,14 +90,14 @@ def get_md22_stachyose(get_tmp_path):

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)

return file_path


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

Expand All @@ -105,4 +106,4 @@ def modify_xyz_file(file_path, target_string, replacement_string):
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path
return new_file_path
24 changes: 8 additions & 16 deletions tests/regression_tests/test_apax_training.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import os
import pathlib
import urllib
import zipfile
import uuid

import numpy as np
import pandas as pd
import pytest
import yaml

from apax.train.run import run
from apax.data.statistics import scale_method_list, shift_method_list

TEST_PATH = pathlib.Path(__file__).parent.resolve()

def load_config_and_run_training(
config_path, **config_kwargs
):

def load_config_and_run_training(config_path, **config_kwargs):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

Expand All @@ -25,19 +19,19 @@ def load_config_and_run_training(
config_dict[pydentic_model_key][h_param_key] = value

run(config_dict)


def load_csv(filename):
data = np.loadtxt(filename, delimiter=',', skiprows=1) # Skip the header row
data = np.loadtxt(filename, delimiter=",", skiprows=1) # Skip the header row

with open(filename, 'r') as file:
header = file.readline().strip().split(',')
with open(filename, "r") as file:
header = file.readline().strip().split(",")

data_dict = {header[i]: data[:, i].tolist() for i in range(len(header))}

return data_dict


@pytest.mark.slow
def test_regression_model_training(get_md22_stachyose, get_tmp_path):
config_path = TEST_PATH / "apax_config.yaml"
Expand All @@ -60,10 +54,8 @@ def test_regression_model_training(get_md22_stachyose, get_tmp_path):
"val_forces_mse": 0.017160819058234304,
"val_loss": 0.45499257304743396,
}
for key in comparison_metrics.keys():
print(np.array(current_metrics[key])[-1])

for key in comparison_metrics.keys():
assert (
abs((np.array(current_metrics[key])[-1] / comparison_metrics[key]) - 1) < 1e-3
)
)
2 changes: 1 addition & 1 deletion tests/unit_tests/data/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

from apax.data.statistics import PerElementRegressionShift
from apax.data.input_pipeline import create_dict_dataset
from apax.data.statistics import PerElementRegressionShift


def test_energy_per_element():
Expand Down

0 comments on commit ee0b36e

Please sign in to comment.