diff --git a/astromodels/core/units.py b/astromodels/core/units.py index 4bd5d99b..711f6b81 100644 --- a/astromodels/core/units.py +++ b/astromodels/core/units.py @@ -6,17 +6,11 @@ import astropy.units as u +from astromodels.utils.configuration import astromodels_config from astromodels.utils.pretty_list import dict_to_list # This module keeps the configuration of the units used in astromodels -# Pre-defined values - -_ENERGY = u.keV -_TIME = u.s -_ANGLE = u.deg -_AREA = u.cm**2 - class UnknownUnit(Exception): pass @@ -63,13 +57,13 @@ def __init__( ): if energy_unit is None: - energy_unit = _ENERGY + energy_unit = u.Unit(astromodels_config.units.energy) if time_unit is None: - time_unit = _TIME + time_unit = u.Unit(astromodels_config.units.time) if angle_unit is None: - angle_unit = _ANGLE + angle_unit = u.Unit(astromodels_config.units.angle) if area_unit is None: - area_unit = _AREA + area_unit = u.Unit(astromodels_config.units.area) self._units = collections.OrderedDict() @@ -189,5 +183,12 @@ def __call__(self, *args, **kwds): # Create the factory to be used in the program +def set_units(key: str, value: u.Unit): + """ + Update the units used + """ + if getattr(get_units(), key) != value: + setattr(get_units(), key, value) + get_units = _AstromodelsUnitsFactory() diff --git a/astromodels/tests/test_units.py b/astromodels/tests/test_units.py new file mode 100644 index 00000000..9cb10ef2 --- /dev/null +++ b/astromodels/tests/test_units.py @@ -0,0 +1,37 @@ +from copy import deepcopy +import astropy.units as u +from omegaconf import OmegaConf + +from astromodels.core.units import get_units, set_units +from astromodels.utils.configuration import astromodels_config +from astromodels.utils.file_utils import get_path_of_user_config + + +def test_config_unit_same(): + for user_config_file in get_path_of_user_config().glob("*.yml"): + user_conf = OmegaConf.load(user_config_file) + if "units" in user_conf.keys(): + for k in user_conf["units"].keys(): + assert u.Unit(user_conf["units"][k]) == getattr(get_units(), k) + assert astromodels_config["units"][k] == user_conf["units"][k] + for k, v in astromodels_config["units"].items(): + assert u.Unit(v) == getattr(get_units(), k) + + +def test_set_units(): + mapping = {"energy": u.TeV, "area": u.m**2, "time": u.h, "angle": u.rad} + ref = deepcopy(get_units()) + not_changed = [] + for k, v in mapping.items(): + if getattr(ref, k) != v: + set_units(k, v) + else: + not_changed.append(k) + for k, v in mapping.items(): + if k not in not_changed: + assert getattr(get_units(), k) == v + assert getattr(ref, k) != getattr(get_units(), k) + + # reset it to the original + for k, v in mapping.items(): + set_units(k, getattr(ref, k)) diff --git a/astromodels/utils/config_structure.py b/astromodels/utils/config_structure.py index 09a4a895..4bfb1c20 100644 --- a/astromodels/utils/config_structure.py +++ b/astromodels/utils/config_structure.py @@ -27,6 +27,14 @@ class Logging: message_style: str = "bold grey78" +@dataclass +class Units: + energy: str = "keV" + time: str = "s" + angle: str = "deg" + area: str = "cm2" + + class AbsTables(Enum): WILM = "WILM" ASPL = "ASPL" @@ -60,3 +68,4 @@ class Config: logging: Logging = field(default_factory=Logging) absorption_models: AbsorptionModels = field(default_factory=AbsorptionModels) modeling: Modeling = field(default_factory=Modeling) + units: Units = field(default_factory=Units)