diff --git a/alphadia/search_step.py b/alphadia/search_step.py index 4cbcfeb63..329f67821 100644 --- a/alphadia/search_step.py +++ b/alphadia/search_step.py @@ -62,6 +62,8 @@ def __init__( self._config = self._init_config( config, cli_config, extra_config, output_folder ) + self._config.to_yaml(os.path.join(output_folder, "frozen_config.yaml")) + logger.setLevel(logging.getLevelName(self._config["general"]["log_level"])) self.raw_path_list = self._config[ConfigKeys.RAW_PATHS] @@ -86,27 +88,18 @@ def _init_config( ) -> Config: """Initialize the config with default values and update with user defined values.""" - default_config_path = os.path.join( - os.path.dirname(__file__), "constants", "default.yaml" - ) - logger.info(f"loading config from {default_config_path}") - config = Config() - config.from_yaml(default_config_path) + config = SearchStep._load_default_config() config_updates = [] if user_config: logger.info("loading additional config provided via CLI") # load update config from dict - if isinstance(user_config, dict): - user_config_update = Config(user_config, name=USER_DEFINED) - config_updates.append(user_config_update) - elif isinstance(user_config, Config): + if isinstance(user_config, Config): config_updates.append(user_config) else: - raise ValueError( - "'config' parameter must be of type 'dict' or 'Config'" - ) + user_config_update = Config(user_config, name=USER_DEFINED) + config_updates.append(user_config_update) if cli_config: logger.info("loading additional config provided via CLI parameters") @@ -117,16 +110,26 @@ def _init_config( if extra_config: extra_config_update = Config(extra_config, name=MULTISTEP_SEARCH) # need to overwrite user-defined output folder here to have correct value in config dump - extra_config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder + extra_config_update[ConfigKeys.OUTPUT_DIRECTORY] = output_folder config_updates.append(extra_config_update) - config.update(config_updates, do_print=True) + if config_updates: + config.update(config_updates, do_print=True) if config.get(ConfigKeys.OUTPUT_DIRECTORY, None) is None: config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder - config.to_yaml(os.path.join(output_folder, "frozen_config.yaml")) + return config + @staticmethod + def _load_default_config(): + """Load default config from file.""" + default_config_path = os.path.join( + os.path.dirname(__file__), "constants", "default.yaml" + ) + logger.info(f"loading config from {default_config_path}") + config = Config() + config.from_yaml(default_config_path) return config @property diff --git a/alphadia/workflow/config.py b/alphadia/workflow/config.py index c1ce539e0..d345794b7 100644 --- a/alphadia/workflow/config.py +++ b/alphadia/workflow/config.py @@ -59,6 +59,9 @@ def __setitem__(self, key, item): def __delitem__(self, key): raise NotImplementedError("Use update() to update the config.") + def copy(self): + raise NotImplementedError("Use deepcopy() to copy the config.") + def update(self, configs: list["Config"], do_print: bool = False): """ Updates the config with one or more other config objects. diff --git a/pyproject.toml b/pyproject.toml index 4ae4465b5..40c14e569 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,9 +63,6 @@ version = {attr = "alphadia.__version__"} [project.scripts] alphadia = "alphadia.cli:run" -[tool.ruff] -extend-exclude = ["tests"] - [tool.ruff.lint] select = [ diff --git a/tests/unit_tests/test_search_step.py b/tests/unit_tests/test_search_step.py index 61ee0b475..5fe4d8a74 100644 --- a/tests/unit_tests/test_search_step.py +++ b/tests/unit_tests/test_search_step.py @@ -1,13 +1,16 @@ import os import tempfile -from unittest import skip +from copy import deepcopy +from unittest.mock import MagicMock, patch import pytest from alphabase.constants import _const from alphabase.constants.modification import MOD_DF from alphadia import search_step +from alphadia.search_step import SearchStep from alphadia.test_data_downloader import DataShareDownloader +from alphadia.workflow.config import Config @pytest.mark.slow() @@ -79,7 +82,6 @@ def test_library_loading(): assert len(step.spectral_library.fragment_df) > 0 -@skip # TODO activate again (this test works after making custom_modifications a list in the next PR) def test_custom_modifications(): temp_directory = tempfile.gettempdir() @@ -94,3 +96,124 @@ def test_custom_modifications(): step = search_step.SearchStep(temp_directory, config=config) # noqa F841 assert "ThisModDoesNotExists@K" in MOD_DF["mod_name"].values + + +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_initializes_with_default_config(mock_load_default_config): + """Test that the config is initialized with default values.""" + config = Config( + {"key1": "value1", "key2": "value2"}, "default" + ) # not using a mock here as working with the real object is much simpler + mock_load_default_config.return_value = deepcopy( + config + ) # copy required here as we want to compare changes to a mutable object below + + # when + result = SearchStep._init_config(None, None, None, "/output") + + mock_load_default_config.assert_called_once() + assert result == config | {"output_directory": "/output"} + + +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_updates_with_user_config_object(mock_load_default_config): + """Test that the config is updated with user config object.""" + config = Config({"key1": "value1", "key2": "value2"}) + mock_load_default_config.return_value = deepcopy(config) + + user_config = Config({"key2": "value2b"}) + # when + result = SearchStep._init_config(user_config, None, None, "/output") + + assert result == { + "key1": "value1", + "key2": "value2b", + "output_directory": "/output", + } + + +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_updates_with_user_and_cli_and_extra_config_dicts( + mock_load_default_config, +): + """Test that the config is updated with user, cli and extra config dicts.""" + config = Config( + { + "key1": "value1", + "key2": "value2", + "key3": "value3", + "key4": "value4", + "output_directory": None, + } + ) + mock_load_default_config.return_value = deepcopy(config) + + user_config = {"key2": "value2b"} + cli_config = {"key3": "value3b"} + extra_config = {"key4": "value4b"} + # when + result = SearchStep._init_config(user_config, cli_config, extra_config, "/output") + + mock_load_default_config.assert_called_once() + + assert result == { + "key1": "value1", + "key2": "value2b", + "key3": "value3b", + "key4": "value4b", + "output_directory": "/output", + } + + +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_updates_with_cli_config_no_overwrite_output_path( + mock_load_default_config, +): + """Test that the output directory is not overwritten if provided by config.""" + config = Config({"key1": "value1", "output_directory": None}) + mock_load_default_config.return_value = deepcopy(config) + + user_config = {"key1": "value1b", "output_directory": "/output"} + # when + result = SearchStep._init_config(user_config, None, None, "/another_output") + + mock_load_default_config.assert_called_once() + + assert result == {"key1": "value1b", "output_directory": "/output"} + + +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_updates_with_extra_config_overwrite_output_path( + mock_load_default_config, +): + """Test that the output directory is overwritten by extra_config.""" + config = Config({"key1": "value1", "output_directory": "/default_output"}) + mock_load_default_config.return_value = deepcopy(config) + + extra_config = {"key1": "value1b"} + # when + result = SearchStep._init_config(None, None, extra_config, "/extra_output") + + mock_load_default_config.assert_called_once() + + assert result == {"key1": "value1b", "output_directory": "/extra_output"} + + +@pytest.mark.parametrize( + ("config1", "config2", "config3"), + [ + ("not_dict_nor_config_object", None, None), + (None, "not_dict_nor_config_object", None), + (None, None, "not_dict_nor_config_object"), + ], +) +@patch("alphadia.search_step.SearchStep._load_default_config") +def test_raises_value_error_for_invalid_config( + mock_load_default_config, config1, config2, config3 +): + """Test that a TypeError is raised if the config is not a dict or Config object.""" + mock_load_default_config.return_value = MagicMock(spec=Config) + + with pytest.raises(TypeError, match="'str' object is not a mapping"): + # when + SearchStep._init_config(config1, config2, config3, "/output")