From b2ff7c87f094776a87f8f03c262da1b38f114957 Mon Sep 17 00:00:00 2001 From: Tobias Preis Date: Wed, 29 Oct 2025 14:25:48 +0100 Subject: [PATCH 1/4] set units using config --- astromodels/core/units.py | 14 +++++--------- astromodels/tests/test_units.py | 17 +++++++++++++++++ astromodels/utils/config_structure.py | 9 +++++++++ 3 files changed, 31 insertions(+), 9 deletions(-) create mode 100644 astromodels/tests/test_units.py diff --git a/astromodels/core/units.py b/astromodels/core/units.py index 4bd5d99b..167dc4e5 100644 --- a/astromodels/core/units.py +++ b/astromodels/core/units.py @@ -6,17 +6,13 @@ import astropy.units as u +from astromodels.utils.configuration import astromodels_config from astromodels.utils.pretty_list import dict_to_list # This module keeps the configuration of the units used in astromodels # Pre-defined values -_ENERGY = u.keV -_TIME = u.s -_ANGLE = u.deg -_AREA = u.cm**2 - class UnknownUnit(Exception): pass @@ -63,13 +59,13 @@ def __init__( ): if energy_unit is None: - energy_unit = _ENERGY + energy_unit = u.Unit(astromodels_config.units.energy) if time_unit is None: - time_unit = _TIME + time_unit = u.Unit(astromodels_config.units.time) if angle_unit is None: - angle_unit = _ANGLE + angle_unit = u.Unit(astromodels_config.units.angle) if area_unit is None: - area_unit = _AREA + area_unit = u.Unit(astromodels_config.units.area) self._units = collections.OrderedDict() diff --git a/astromodels/tests/test_units.py b/astromodels/tests/test_units.py new file mode 100644 index 00000000..92904f0c --- /dev/null +++ b/astromodels/tests/test_units.py @@ -0,0 +1,17 @@ +import astropy.units as u +from omegaconf import OmegaConf + +from astromodels.core.units import get_units +from astromodels.utils.configuration import astromodels_config +from astromodels.utils.file_utils import get_path_of_user_config + + +def test_config_unit_same(): + for user_config_file in get_path_of_user_config().glob("*.yml"): + user_conf = OmegaConf.load(user_config_file) + if "units" in user_conf.keys(): + for k in user_conf["units"].keys(): + assert u.Unit(user_conf["units"][k]) == getattr(get_units(), k) + assert astromodels_config["units"][k] == user_conf["units"][k] + for k, v in astromodels_config["units"].items(): + assert u.Unit(v) == getattr(get_units(), k) diff --git a/astromodels/utils/config_structure.py b/astromodels/utils/config_structure.py index 09a4a895..4bfb1c20 100644 --- a/astromodels/utils/config_structure.py +++ b/astromodels/utils/config_structure.py @@ -27,6 +27,14 @@ class Logging: message_style: str = "bold grey78" +@dataclass +class Units: + energy: str = "keV" + time: str = "s" + angle: str = "deg" + area: str = "cm2" + + class AbsTables(Enum): WILM = "WILM" ASPL = "ASPL" @@ -60,3 +68,4 @@ class Config: logging: Logging = field(default_factory=Logging) absorption_models: AbsorptionModels = field(default_factory=AbsorptionModels) modeling: Modeling = field(default_factory=Modeling) + units: Units = field(default_factory=Units) From f89d5a6c988d558dd9f38fa19f7b9a57cbe6a4f1 Mon Sep 17 00:00:00 2001 From: Tobias Preis Date: Mon, 10 Nov 2025 07:39:42 +0100 Subject: [PATCH 2/4] unit setter --- astromodels/core/units.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/astromodels/core/units.py b/astromodels/core/units.py index 167dc4e5..711f6b81 100644 --- a/astromodels/core/units.py +++ b/astromodels/core/units.py @@ -11,8 +11,6 @@ # This module keeps the configuration of the units used in astromodels -# Pre-defined values - class UnknownUnit(Exception): pass @@ -185,5 +183,12 @@ def __call__(self, *args, **kwds): # Create the factory to be used in the program +def set_units(key: str, value: u.Unit): + """ + Update the units used + """ + if getattr(get_units(), key) != value: + setattr(get_units(), key, value) + get_units = _AstromodelsUnitsFactory() From d934b6b3714f3ddc52f6dda218071b6ff90ce3d3 Mon Sep 17 00:00:00 2001 From: Tobias Preis Date: Mon, 10 Nov 2025 07:39:46 +0100 Subject: [PATCH 3/4] test units --- astromodels/tests/test_units.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/astromodels/tests/test_units.py b/astromodels/tests/test_units.py index 92904f0c..9daedb91 100644 --- a/astromodels/tests/test_units.py +++ b/astromodels/tests/test_units.py @@ -1,9 +1,12 @@ +from copy import deepcopy import astropy.units as u from omegaconf import OmegaConf -from astromodels.core.units import get_units +from astromodels.core.units import get_units, set_units from astromodels.utils.configuration import astromodels_config from astromodels.utils.file_utils import get_path_of_user_config +from astromodels.functions.functions_1D.powerlaws import Powerlaw +from astromodels.sources.point_source import PointSource def test_config_unit_same(): @@ -15,3 +18,22 @@ def test_config_unit_same(): assert astromodels_config["units"][k] == user_conf["units"][k] for k, v in astromodels_config["units"].items(): assert u.Unit(v) == getattr(get_units(), k) + + +def test_set_units(): + mapping = {"energy": u.TeV, "area": u.m**2, "time": u.h, "angle": u.rad} + ref = deepcopy(get_units()) + not_changed = [] + for k, v in mapping.items(): + if getattr(ref, k) != v: + set_units(k, v) + else: + not_changed.append(k) + for k, v in mapping.items(): + if k not in not_changed: + assert getattr(get_units(), k) == v + assert getattr(ref, k) != getattr(get_units(), k) + + # reset it to the original + for k, v in mapping.items(): + set_units(k, getattr(ref, k)) From e024393d79ec0762300f8997d8dec12f6f19381b Mon Sep 17 00:00:00 2001 From: Tobias Preis Date: Fri, 21 Nov 2025 15:06:33 +0100 Subject: [PATCH 4/4] fix flake8 issues --- astromodels/tests/test_units.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/astromodels/tests/test_units.py b/astromodels/tests/test_units.py index 9daedb91..9cb10ef2 100644 --- a/astromodels/tests/test_units.py +++ b/astromodels/tests/test_units.py @@ -5,8 +5,6 @@ from astromodels.core.units import get_units, set_units from astromodels.utils.configuration import astromodels_config from astromodels.utils.file_utils import get_path_of_user_config -from astromodels.functions.functions_1D.powerlaws import Powerlaw -from astromodels.sources.point_source import PointSource def test_config_unit_same():