From ee0b36eb3cfeed2e7feb63eacb63433b09608411 Mon Sep 17 00:00:00 2001 From: Tetracarbonylnickel Date: Thu, 5 Oct 2023 21:06:02 +0200 Subject: [PATCH] linting --- tests/conftest.py | 17 +++++++------- tests/regression_tests/test_apax_training.py | 24 +++++++------------- tests/unit_tests/data/test_statistics.py | 2 +- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4356acd7..0d5a08bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ -from typing import List import os import urllib import zipfile +from typing import List + import numpy as np import pytest from ase import Atoms @@ -89,14 +90,14 @@ def get_md22_stachyose(get_tmp_path): with zipfile.ZipFile(file_path, "r") as zip_ref: zip_ref.extractall(data_path) - + file_path = modify_xyz_file( - file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy" - ) - + file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy" + ) + return file_path - - + + def modify_xyz_file(file_path, target_string, replacement_string): new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix) @@ -105,4 +106,4 @@ def modify_xyz_file(file_path, target_string, replacement_string): # Replace all occurrences of the target string with the replacement string modified_line = line.replace(target_string, replacement_string) output_file.write(modified_line) - return new_file_path \ No newline at end of file + return new_file_path diff --git a/tests/regression_tests/test_apax_training.py b/tests/regression_tests/test_apax_training.py index 7091b600..339cb28a 100644 --- a/tests/regression_tests/test_apax_training.py +++ b/tests/regression_tests/test_apax_training.py @@ -1,22 +1,16 @@ -import os import pathlib -import urllib -import zipfile import uuid import numpy as np -import pandas as pd import pytest import yaml from apax.train.run import run -from apax.data.statistics import scale_method_list, shift_method_list TEST_PATH = pathlib.Path(__file__).parent.resolve() -def load_config_and_run_training( - config_path, **config_kwargs -): + +def load_config_and_run_training(config_path, **config_kwargs): with open(config_path.as_posix(), "r") as stream: config_dict = yaml.safe_load(stream) @@ -25,19 +19,19 @@ def load_config_and_run_training( config_dict[pydentic_model_key][h_param_key] = value run(config_dict) - + def load_csv(filename): - data = np.loadtxt(filename, delimiter=',', skiprows=1) # Skip the header row + data = np.loadtxt(filename, delimiter=",", skiprows=1) # Skip the header row - with open(filename, 'r') as file: - header = file.readline().strip().split(',') + with open(filename, "r") as file: + header = file.readline().strip().split(",") data_dict = {header[i]: data[:, i].tolist() for i in range(len(header))} return data_dict - + @pytest.mark.slow def test_regression_model_training(get_md22_stachyose, get_tmp_path): config_path = TEST_PATH / "apax_config.yaml" @@ -60,10 +54,8 @@ def test_regression_model_training(get_md22_stachyose, get_tmp_path): "val_forces_mse": 0.017160819058234304, "val_loss": 0.45499257304743396, } - for key in comparison_metrics.keys(): - print(np.array(current_metrics[key])[-1]) for key in comparison_metrics.keys(): assert ( abs((np.array(current_metrics[key])[-1] / comparison_metrics[key]) - 1) < 1e-3 - ) \ No newline at end of file + ) diff --git a/tests/unit_tests/data/test_statistics.py b/tests/unit_tests/data/test_statistics.py index 6b3d4243..e2e698ee 100644 --- a/tests/unit_tests/data/test_statistics.py +++ b/tests/unit_tests/data/test_statistics.py @@ -2,8 +2,8 @@ from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator -from apax.data.statistics import PerElementRegressionShift from apax.data.input_pipeline import create_dict_dataset +from apax.data.statistics import PerElementRegressionShift def test_energy_per_element():