Skip to content

Commit

Permalink
Merge pull request #445 from MannLabs/fix_bug_in_multistep_config
Browse files Browse the repository at this point in the history
fix bug in multistep config output dir
  • Loading branch information
GeorgWa authored Jan 24, 2025
2 parents 5f89002 + 1bbbaad commit a80b6c8
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 21 deletions.
35 changes: 19 additions & 16 deletions alphadia/search_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(
self._config = self._init_config(
config, cli_config, extra_config, output_folder
)
self._config.to_yaml(os.path.join(output_folder, "frozen_config.yaml"))

logger.setLevel(logging.getLevelName(self._config["general"]["log_level"]))

self.raw_path_list = self._config[ConfigKeys.RAW_PATHS]
Expand All @@ -86,27 +88,18 @@ def _init_config(
) -> Config:
"""Initialize the config with default values and update with user defined values."""

default_config_path = os.path.join(
os.path.dirname(__file__), "constants", "default.yaml"
)
logger.info(f"loading config from {default_config_path}")
config = Config()
config.from_yaml(default_config_path)
config = SearchStep._load_default_config()

config_updates = []

if user_config:
logger.info("loading additional config provided via CLI")
# load update config from dict
if isinstance(user_config, dict):
user_config_update = Config(user_config, name=USER_DEFINED)
config_updates.append(user_config_update)
elif isinstance(user_config, Config):
if isinstance(user_config, Config):
config_updates.append(user_config)
else:
raise ValueError(
"'config' parameter must be of type 'dict' or 'Config'"
)
user_config_update = Config(user_config, name=USER_DEFINED)
config_updates.append(user_config_update)

if cli_config:
logger.info("loading additional config provided via CLI parameters")
Expand All @@ -117,16 +110,26 @@ def _init_config(
if extra_config:
extra_config_update = Config(extra_config, name=MULTISTEP_SEARCH)
# need to overwrite user-defined output folder here to have correct value in config dump
extra_config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder
extra_config_update[ConfigKeys.OUTPUT_DIRECTORY] = output_folder
config_updates.append(extra_config_update)

config.update(config_updates, do_print=True)
if config_updates:
config.update(config_updates, do_print=True)

if config.get(ConfigKeys.OUTPUT_DIRECTORY, None) is None:
config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder

config.to_yaml(os.path.join(output_folder, "frozen_config.yaml"))
return config

@staticmethod
def _load_default_config():
"""Load default config from file."""
default_config_path = os.path.join(
os.path.dirname(__file__), "constants", "default.yaml"
)
logger.info(f"loading config from {default_config_path}")
config = Config()
config.from_yaml(default_config_path)
return config

@property
Expand Down
3 changes: 3 additions & 0 deletions alphadia/workflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __setitem__(self, key, item):
def __delitem__(self, key):
raise NotImplementedError("Use update() to update the config.")

def copy(self):
raise NotImplementedError("Use deepcopy() to copy the config.")

def update(self, configs: list["Config"], do_print: bool = False):
"""
Updates the config with one or more other config objects.
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ version = {attr = "alphadia.__version__"}
[project.scripts]
alphadia = "alphadia.cli:run"

[tool.ruff]
extend-exclude = ["tests"]


[tool.ruff.lint]
select = [
Expand Down
127 changes: 125 additions & 2 deletions tests/unit_tests/test_search_step.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
import tempfile
from unittest import skip
from copy import deepcopy
from unittest.mock import MagicMock, patch

import pytest
from alphabase.constants import _const
from alphabase.constants.modification import MOD_DF

from alphadia import search_step
from alphadia.search_step import SearchStep
from alphadia.test_data_downloader import DataShareDownloader
from alphadia.workflow.config import Config


@pytest.mark.slow()
Expand Down Expand Up @@ -79,7 +82,6 @@ def test_library_loading():
assert len(step.spectral_library.fragment_df) > 0


@skip # TODO activate again (this test works after making custom_modifications a list in the next PR)
def test_custom_modifications():
temp_directory = tempfile.gettempdir()

Expand All @@ -94,3 +96,124 @@ def test_custom_modifications():

step = search_step.SearchStep(temp_directory, config=config) # noqa F841
assert "ThisModDoesNotExists@K" in MOD_DF["mod_name"].values


@patch("alphadia.search_step.SearchStep._load_default_config")
def test_initializes_with_default_config(mock_load_default_config):
"""Test that the config is initialized with default values."""
config = Config(
{"key1": "value1", "key2": "value2"}, "default"
) # not using a mock here as working with the real object is much simpler
mock_load_default_config.return_value = deepcopy(
config
) # copy required here as we want to compare changes to a mutable object below

# when
result = SearchStep._init_config(None, None, None, "/output")

mock_load_default_config.assert_called_once()
assert result == config | {"output_directory": "/output"}


@patch("alphadia.search_step.SearchStep._load_default_config")
def test_updates_with_user_config_object(mock_load_default_config):
"""Test that the config is updated with user config object."""
config = Config({"key1": "value1", "key2": "value2"})
mock_load_default_config.return_value = deepcopy(config)

user_config = Config({"key2": "value2b"})
# when
result = SearchStep._init_config(user_config, None, None, "/output")

assert result == {
"key1": "value1",
"key2": "value2b",
"output_directory": "/output",
}


@patch("alphadia.search_step.SearchStep._load_default_config")
def test_updates_with_user_and_cli_and_extra_config_dicts(
mock_load_default_config,
):
"""Test that the config is updated with user, cli and extra config dicts."""
config = Config(
{
"key1": "value1",
"key2": "value2",
"key3": "value3",
"key4": "value4",
"output_directory": None,
}
)
mock_load_default_config.return_value = deepcopy(config)

user_config = {"key2": "value2b"}
cli_config = {"key3": "value3b"}
extra_config = {"key4": "value4b"}
# when
result = SearchStep._init_config(user_config, cli_config, extra_config, "/output")

mock_load_default_config.assert_called_once()

assert result == {
"key1": "value1",
"key2": "value2b",
"key3": "value3b",
"key4": "value4b",
"output_directory": "/output",
}


@patch("alphadia.search_step.SearchStep._load_default_config")
def test_updates_with_cli_config_no_overwrite_output_path(
mock_load_default_config,
):
"""Test that the output directory is not overwritten if provided by config."""
config = Config({"key1": "value1", "output_directory": None})
mock_load_default_config.return_value = deepcopy(config)

user_config = {"key1": "value1b", "output_directory": "/output"}
# when
result = SearchStep._init_config(user_config, None, None, "/another_output")

mock_load_default_config.assert_called_once()

assert result == {"key1": "value1b", "output_directory": "/output"}


@patch("alphadia.search_step.SearchStep._load_default_config")
def test_updates_with_extra_config_overwrite_output_path(
mock_load_default_config,
):
"""Test that the output directory is overwritten by extra_config."""
config = Config({"key1": "value1", "output_directory": "/default_output"})
mock_load_default_config.return_value = deepcopy(config)

extra_config = {"key1": "value1b"}
# when
result = SearchStep._init_config(None, None, extra_config, "/extra_output")

mock_load_default_config.assert_called_once()

assert result == {"key1": "value1b", "output_directory": "/extra_output"}


@pytest.mark.parametrize(
("config1", "config2", "config3"),
[
("not_dict_nor_config_object", None, None),
(None, "not_dict_nor_config_object", None),
(None, None, "not_dict_nor_config_object"),
],
)
@patch("alphadia.search_step.SearchStep._load_default_config")
def test_raises_value_error_for_invalid_config(
mock_load_default_config, config1, config2, config3
):
"""Test that a TypeError is raised if the config is not a dict or Config object."""
mock_load_default_config.return_value = MagicMock(spec=Config)

with pytest.raises(TypeError, match="'str' object is not a mapping"):
# when
SearchStep._init_config(config1, config2, config3, "/output")

0 comments on commit a80b6c8

Please sign in to comment.