Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions astromodels/core/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@

import astropy.units as u

from astromodels.utils.configuration import astromodels_config
from astromodels.utils.pretty_list import dict_to_list

# This module keeps the configuration of the units used in astromodels

# Pre-defined values

_ENERGY = u.keV
_TIME = u.s
_ANGLE = u.deg
_AREA = u.cm**2


class UnknownUnit(Exception):
pass
Expand Down Expand Up @@ -63,13 +57,13 @@ def __init__(
):

if energy_unit is None:
energy_unit = _ENERGY
energy_unit = u.Unit(astromodels_config.units.energy)
if time_unit is None:
time_unit = _TIME
time_unit = u.Unit(astromodels_config.units.time)
if angle_unit is None:
angle_unit = _ANGLE
angle_unit = u.Unit(astromodels_config.units.angle)
if area_unit is None:
area_unit = _AREA
area_unit = u.Unit(astromodels_config.units.area)

self._units = collections.OrderedDict()

Expand Down Expand Up @@ -189,5 +183,12 @@ def __call__(self, *args, **kwds):


# Create the factory to be used in the program
def set_units(key: str, value: u.Unit):
"""
Update the units used
"""
if getattr(get_units(), key) != value:
setattr(get_units(), key, value)


get_units = _AstromodelsUnitsFactory()
37 changes: 37 additions & 0 deletions astromodels/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from copy import deepcopy
import astropy.units as u
from omegaconf import OmegaConf

from astromodels.core.units import get_units, set_units
from astromodels.utils.configuration import astromodels_config
from astromodels.utils.file_utils import get_path_of_user_config


def test_config_unit_same():
for user_config_file in get_path_of_user_config().glob("*.yml"):
user_conf = OmegaConf.load(user_config_file)
if "units" in user_conf.keys():
for k in user_conf["units"].keys():
assert u.Unit(user_conf["units"][k]) == getattr(get_units(), k)
assert astromodels_config["units"][k] == user_conf["units"][k]
for k, v in astromodels_config["units"].items():
assert u.Unit(v) == getattr(get_units(), k)


def test_set_units():
mapping = {"energy": u.TeV, "area": u.m**2, "time": u.h, "angle": u.rad}
ref = deepcopy(get_units())
not_changed = []
for k, v in mapping.items():
if getattr(ref, k) != v:
set_units(k, v)
else:
not_changed.append(k)
for k, v in mapping.items():
if k not in not_changed:
assert getattr(get_units(), k) == v
assert getattr(ref, k) != getattr(get_units(), k)

# reset it to the original
for k, v in mapping.items():
set_units(k, getattr(ref, k))
9 changes: 9 additions & 0 deletions astromodels/utils/config_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class Logging:
message_style: str = "bold grey78"


@dataclass
class Units:
energy: str = "keV"
time: str = "s"
angle: str = "deg"
area: str = "cm2"


class AbsTables(Enum):
WILM = "WILM"
ASPL = "ASPL"
Expand Down Expand Up @@ -60,3 +68,4 @@ class Config:
logging: Logging = field(default_factory=Logging)
absorption_models: AbsorptionModels = field(default_factory=AbsorptionModels)
modeling: Modeling = field(default_factory=Modeling)
units: Units = field(default_factory=Units)
Loading