diff --git a/src/pyuvdata/uvbeam/analytic_beam.py b/src/pyuvdata/uvbeam/analytic_beam.py index 4a426eb58..e41bbb874 100644 --- a/src/pyuvdata/uvbeam/analytic_beam.py +++ b/src/pyuvdata/uvbeam/analytic_beam.py @@ -5,12 +5,15 @@ from __future__ import annotations +import dataclasses +import importlib from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Literal import numpy as np import numpy.typing as npt +import yaml from astropy.constants import c as speed_of_light from scipy.special import j1 @@ -18,7 +21,8 @@ from ..docstrings import combine_docstrings from .uvbeam import UVBeam, _convert_feeds_to_pols -__all__ = ["AnalyticBeam", "AiryBeam", "GaussianBeam", "ShortDipoleBeam", "UniformBeam"] +analytic_beam_classes = ["AiryBeam", "GaussianBeam", "ShortDipoleBeam", "UniformBeam"] +__all__ = ["AnalyticBeam"] + analytic_beam_classes @dataclass @@ -67,7 +71,7 @@ def basis_vector_type(self): def __init__( self, *, - feed_array: npt.NDArray[np.str] | None = None, + feed_array: npt.NDArray[str] | None = None, include_cross_pols: bool = True, x_orientation: Literal["east", "north"] = "east", ): @@ -98,9 +102,9 @@ def __init__( def _check_eval_inputs( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], ): """Check the inputs for the eval methods.""" if az_array.ndim > 1 or za_array.ndim > 1 or freq_array.ndim > 1: @@ -113,7 +117,7 @@ def _check_eval_inputs( def _get_empty_data_array( self, npts: int, nfreqs: int, beam_type: str = "efield" - ) -> npt.NDArray[np.float]: + ) -> npt.NDArray[float]: """Get the empty data to fill in the eval methods.""" if beam_type == "efield": return np.zeros((self.Naxes_vec, self.Nfeeds, nfreqs, npts), dtype=complex) @@ -129,10 +133,10 @@ def _get_empty_data_array( def _efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """ Evaluate the efield at the given coordinates. @@ -161,10 +165,10 @@ def _efield_eval( def efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """ Evaluate the efield at the given coordinates. @@ -198,10 +202,10 @@ def efield_eval( def _power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """ Evaluate the power at the given coordinates. @@ -230,10 +234,10 @@ def _power_eval( def power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """ Evaluate the power at the given coordinates. @@ -274,15 +278,15 @@ def power_eval( @combine_docstrings(UVBeam.new) def to_uvbeam( self, - freq_array: npt.NDArray[np.float], + freq_array: npt.NDArray[float], beam_type: Literal["efield", "power"] = "efield", pixel_coordinate_system: ( Literal["az_za", "orthoslant_zenith", "healpix"] | None ) = None, - axis1_array: npt.NDArray[np.float] | None = None, - axis2_array: npt.NDArray[np.float] | None = None, + axis1_array: npt.NDArray[float] | None = None, + axis2_array: npt.NDArray[float] | None = None, nside: int | None = None, - healpix_pixel_array: npt.NDArray[np.int] | None = None, + healpix_pixel_array: npt.NDArray[int] | None = None, ordering: Literal["ring", "nested"] | None = None, ): """Generate a UVBeam object from an AnalyticBeam object. @@ -401,7 +405,78 @@ def to_uvbeam( return uvb -def diameter_to_sigma(diameter: float, freq_array: npt.NDArray[np.float]) -> float: +def analytic_beam_constructor(loader, node): + """ + Define a yaml constructor for analytic beams. + + The yaml must specify a "class" field with an importable class and any + required inputs to that class's constructor. + + Parameters + ---------- + loader: yaml.Loader + An instance of a yaml Loader object. + node: yaml.Node + A yaml node object. + + Returns + ------- + beam + An instance of an AnalyticBeam subclass. + + """ + values = loader.construct_mapping(node) + if "class" not in values: + raise ValueError("yaml entries for AnalyticBeam must specify a class") + class_parts = (values.pop("class")).split(".") + class_name = class_parts[-1] + if len(class_parts) == 1: + # no module specified, assume pyuvdata + module = importlib.import_module("pyuvdata") + else: + module = (".").join(class_parts[:-1]) + module = importlib.import_module(module) + beam_class = getattr(module, class_name) + + beam = beam_class(**values) + + return beam + + +yaml.add_constructor("!AnalyticBeam", analytic_beam_constructor, Loader=yaml.SafeLoader) + + +def analytic_beam_representer(dumper, beam): + """ + Define a yaml representer for analytic beams. + + Parameters + ---------- + dumper: yaml.Dumper + An instance of a yaml Loader object. + beam: AnalyticBeam subclass + An analytic beam object. + + Returns + ------- + str + The yaml representation of the analytic beam. + + """ + mapping = { + "class": beam.__module__ + "." + beam.__class__.__name__, + **dataclasses.asdict(beam), + } + + return dumper.represent_mapping("!AnalyticBeam", mapping) + + +yaml.add_multi_representer( + AnalyticBeam, analytic_beam_representer, Dumper=yaml.SafeDumper +) + + +def diameter_to_sigma(diameter: float, freq_array: npt.NDArray[float]) -> float: """ Find the sigma that gives a beam width similar to an Airy disk. @@ -480,7 +555,7 @@ def __init__( diameter: float | None = None, spectral_index: float = 0.0, reference_frequency: float = None, - feed_array: npt.NDArray[np.str] | None = None, + feed_array: npt.NDArray[str] | None = None, include_cross_pols: bool = True, ): if (diameter is None and sigma is None) or ( @@ -510,7 +585,7 @@ def __init__( super().__init__(feed_array=feed_array, include_cross_pols=include_cross_pols) - def get_sigmas(self, freq_array: npt.NDArray[np.float]) -> npt.NDArray[np.float]: + def get_sigmas(self, freq_array: npt.NDArray[float]) -> npt.NDArray[float]: """ Get the sigmas for the gaussian beam using the diameter (if defined). @@ -538,10 +613,10 @@ def get_sigmas(self, freq_array: npt.NDArray[np.float]) -> npt.NDArray[np.float] def _efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the efield at the given coordinates.""" sigmas = self.get_sigmas(freq_array) @@ -562,10 +637,10 @@ def _efield_eval( def _power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the power at the given coordinates.""" sigmas = self.get_sigmas(freq_array) @@ -608,7 +683,7 @@ def __init__( self, diameter: float, *, - feed_array: npt.NDArray[np.str] | None = None, + feed_array: npt.NDArray[str] | None = None, include_cross_pols: bool = True, ): super().__init__(feed_array=feed_array, include_cross_pols=include_cross_pols) @@ -618,10 +693,10 @@ def __init__( def _efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the efield at the given coordinates.""" data_array = self._get_empty_data_array(az_array.size, freq_array.size) @@ -647,10 +722,10 @@ def _efield_eval( def _power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the power at the given coordinates.""" data_array = self._get_empty_data_array( az_array.size, freq_array.size, beam_type="power" @@ -711,10 +786,10 @@ def __init__( def _efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the efield at the given coordinates.""" data_array = self._get_empty_data_array(az_array.size, freq_array.size) @@ -733,10 +808,10 @@ def _efield_eval( def _power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the power at the given coordinates.""" data_array = self._get_empty_data_array( az_array.size, freq_array.size, beam_type="power" @@ -778,7 +853,7 @@ class UniformBeam(AnalyticBeam): def __init__( self, *, - feed_array: npt.NDArray[np.str] | None = None, + feed_array: npt.NDArray[str] | None = None, include_cross_pols: bool = True, ): super().__init__(feed_array=feed_array, include_cross_pols=include_cross_pols) @@ -786,10 +861,10 @@ def __init__( def _efield_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the efield at the given coordinates.""" data_array = self._get_empty_data_array(az_array.size, freq_array.size) @@ -804,10 +879,10 @@ def _efield_eval( def _power_eval( self, *, - az_array: npt.NDArray[np.float], - za_array: npt.NDArray[np.float], - freq_array: npt.NDArray[np.float], - ) -> npt.NDArray[np.float]: + az_array: npt.NDArray[float], + za_array: npt.NDArray[float], + freq_array: npt.NDArray[float], + ) -> npt.NDArray[float]: """Evaluate the power at the given coordinates.""" data_array = self._get_empty_data_array( az_array.size, freq_array.size, beam_type="power" diff --git a/src/pyuvdata/uvbeam/uvbeam.py b/src/pyuvdata/uvbeam/uvbeam.py index 93f533288..8495f1dbe 100644 --- a/src/pyuvdata/uvbeam/uvbeam.py +++ b/src/pyuvdata/uvbeam/uvbeam.py @@ -4,6 +4,7 @@ """Primary container for radio telescope antenna beams.""" import copy +import importlib import os import warnings @@ -4471,3 +4472,92 @@ def write_beamfits(self, filename, **kwargs): beamfits_obj = self._convert_to_filetype("beamfits") beamfits_obj.write_beamfits(filename, **kwargs) del beamfits_obj + + +def uvbeam_constructor(loader, node): + """ + Define a yaml constructor for UVBeam objects. + + The yaml must specify a "filename" field pointing to the UVBeam readable file + and any desired arguments to the UVBeam.from_file method. + + Parameters + ---------- + loader: yaml.Loader + An instance of a yaml Loader object. + node: yaml.Node + A yaml node object. + + Returns + ------- + UVBeam + An instance of a UVBeam. + + """ + values = loader.construct_mapping(node) + if "filename" not in values: + raise ValueError("yaml entries for UVBeam must specify a filename.") + if "path_variable" in values: + path_parts = (values.pop("path_variable")).split(".") + var_name = path_parts[-1] + if len(path_parts) == 1: + # no module specified, assume pyuvdata + module = importlib.import_module("pyuvdata") + else: + module = (".").join(path_parts[:-1]) + module = importlib.import_module(module) + path_var = getattr(module, var_name) + values["filename"] = os.path.join(path_var, values["filename"]) + + beam = UVBeam.from_file(**values) + + return beam + + +yaml.add_constructor("!UVBeam", uvbeam_constructor, Loader=yaml.SafeLoader) + + +def uvbeam_beam_representer(dumper, beam): + """ + Define a yaml representer for UVbeams. + + Note: since all the possible selects cannot be extracted from the object, + the object generated from this yaml may not be an exact match for the object + in memory. Also note that the filename parameter must not be None and must + point to an existing file. It's likely that the user will need to update + the filename parameter to include the full path. + + Parameters + ---------- + dumper: yaml.Dumper + An instance of a yaml Loader object. + beam: UVBeam + A UVbeam object, which must have a filename defined on it. + + Returns + ------- + str + The yaml representation of the UVbeam. + + """ + print(beam.filename) + print(isinstance(beam.filename, str)) + if beam.filename is None: + raise ValueError( + "beam must have a filename defined to be able to represent it in a yaml." + ) + elif not isinstance(beam.filename, str): + raise ValueError( + "beam.filename must be a string to be able to represent it in a yaml." + ) + elif not os.path.exists(beam.filename): + raise ValueError( + "beam.filename must be an existing file to be able to represent it " + "in a yaml." + ) + mapping = {"filename": beam.filename} + + return dumper.represent_mapping("!UVBeam", mapping) + + +yaml.add_multi_representer(UVBeam, uvbeam_beam_representer, Dumper=yaml.SafeDumper) diff --git a/tests/uvbeam/test_analytic_beam.py b/tests/uvbeam/test_analytic_beam.py index 019e9feb5..f48068a84 100644 --- a/tests/uvbeam/test_analytic_beam.py +++ b/tests/uvbeam/test_analytic_beam.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import yaml from astropy.constants import c as speed_of_light from scipy.special import j1 @@ -385,3 +386,52 @@ def test_to_uvbeam_errors(): axis2_array=np.deg2rad(np.linspace(0, 90, 10)), pixel_coordinate_system="foo", ) + + +@pytest.mark.parametrize( + ["input_yaml", "beam"], + [ + [ + """ + beam: !AnalyticBeam + class: pyuvdata.UniformBeam + """, + UniformBeam(), + ], + [ + """ + beam: !AnalyticBeam + class: AiryBeam + diameter: 10 + """, + AiryBeam(diameter=10), + ], + [ + """ + beam: !AnalyticBeam + class: pyuvdata.uvbeam.analytic_beam.ShortDipoleBeam + """, + ShortDipoleBeam(), + ], + [ + """ + beam: !AnalyticBeam + class: GaussianBeam + reference_frequency: 120000000. + spectral_index: -1.5 + sigma: 0.26 + """, + GaussianBeam(sigma=0.26, spectral_index=-1.5, reference_frequency=120e6), + ], + ], +) +def test_yaml_constructor(input_yaml, beam): + beam_from_yaml = yaml.safe_load(input_yaml)["beam"] + + assert beam_from_yaml == beam + + output_yaml = yaml.safe_dump({"beam": beam}) + + new_beam_from_yaml = yaml.safe_load(output_yaml)["beam"] + + assert new_beam_from_yaml == beam_from_yaml diff --git a/tests/uvbeam/test_uvbeam.py b/tests/uvbeam/test_uvbeam.py index d44670111..49e6f1b2d 100644 --- a/tests/uvbeam/test_uvbeam.py +++ b/tests/uvbeam/test_uvbeam.py @@ -11,6 +11,7 @@ import numpy as np import pytest +import yaml from astropy import units from astropy.io import fits @@ -2967,3 +2968,63 @@ def test_from_file(filename): assert uvb.check() assert uvb2.check() assert uvb == uvb2 + + +@pytest.mark.parametrize("filename", [cst_yaml_file, mwa_beam_file, casa_beamfits]) +def test_yaml_constructor(filename): + input_yaml = f""" + beam: !UVBeam + filename: {filename} + run_check: False + """ + + beam_from_yaml = yaml.safe_load(input_yaml)["beam"] + + # don't run checks because of casa_beamfits, we'll do that later + uvb = UVBeam.from_file(filename, run_check=False) + # hera casa beam is missing some parameters but we just want to check + # that reading is going okay + if filename == casa_beamfits: + # fill in missing parameters + for _uvb in [uvb, beam_from_yaml]: + _uvb.data_normalization = "peak" + _uvb.feed_name = "casa_ideal" + _uvb.feed_version = "v0" + _uvb.model_name = "casa_airy" + _uvb.model_version = "v0" + + # this file is actually in an orthoslant projection RA/DEC at zenith at a + # particular time. + # For now pretend it's in a zenith orthoslant projection + _uvb.pixel_coordinate_system = "orthoslant_zenith" + # double check the files are valid + assert uvb.check() + assert beam_from_yaml.check() + assert uvb == beam_from_yaml + + if isinstance(uvb.filename, list): + err_msg = "beam.filename must be a string to be able to represent it in a yaml." + elif not os.path.exists(uvb.filename): + err_msg = ( + "beam.filename must be an existing file to be able to represent it " + "in a yaml." + ) + with pytest.raises(ValueError, match=err_msg): + output_yaml = yaml.safe_dump({"beam": uvb}) + + uvb.filename = filename + output_yaml = yaml.safe_dump({"beam": uvb}) + try: + new_beam_from_yaml = yaml.safe_load(output_yaml)["beam"] + + assert new_beam_from_yaml == beam_from_yaml + except ValueError: + # get here for the ill-defined casa beam. just test the null filename case + + uvb.filename = None + with pytest.raises( + ValueError, + match="beam must have a filename defined to be able to represent it " + "in a yaml.", + ): + output_yaml = yaml.safe_dump({"beam": uvb})