From ae29819e57d7ab12ecf4da69ffd6d4c169a437f5 Mon Sep 17 00:00:00 2001 From: Bryna Hazelton Date: Tue, 17 Sep 2024 16:24:45 -0700 Subject: [PATCH] Move subclass check into `__init_subclass__` --- src/pyuvdata/uvbeam/analytic_beam.py | 12 +++---- tests/uvbeam/test_analytic_beam.py | 49 ++++++++++++++-------------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/pyuvdata/uvbeam/analytic_beam.py b/src/pyuvdata/uvbeam/analytic_beam.py index 8c0299ec5..62c6652c2 100644 --- a/src/pyuvdata/uvbeam/analytic_beam.py +++ b/src/pyuvdata/uvbeam/analytic_beam.py @@ -68,6 +68,12 @@ class AnalyticBeam(ABC): def __init_subclass__(cls): """Initialize any imported subclass.""" + if cls.basis_vector_type not in cls._basis_vec_dict: + raise ValueError( + f"basis_vector_type for {cls.__name__} is {cls.basis_vector_type}, " + f"must be one of {list(cls._basis_vec_dict.keys())}" + ) + cls.__types__[cls.__name__] = cls @property @@ -86,12 +92,6 @@ def __post_init__(self, include_cross_pols): for the power beam. """ - if self.basis_vector_type not in self._basis_vec_dict: - raise ValueError( - f"basis_vector_type is {self.basis_vector_type}, must be one of " - f"{list(self._basis_vec_dict.keys())}" - ) - if self.feed_array is not None: for feed in self.feed_array: allowed_feeds = ["n", "e", "x", "y", "r", "l"] diff --git a/tests/uvbeam/test_analytic_beam.py b/tests/uvbeam/test_analytic_beam.py index 582b8629e..b06964129 100644 --- a/tests/uvbeam/test_analytic_beam.py +++ b/tests/uvbeam/test_analytic_beam.py @@ -316,33 +316,34 @@ def test_beamerrs(beam_kwargs, err_msg): def test_bad_basis_vector_type(): - class BadBeam(AnalyticBeam): - basis_vector_type = "healpix" - name = "bad beam" - - def _efield_eval( - self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray - ): - """Evaluate the efield at the given coordinates.""" - data_array = self._get_empty_data_array(az_array, za_array, freq_array) - data_array = data_array + 1.0 / np.sqrt(2.0) - return data_array - - def _power_eval( - self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray - ): - """Evaluate the efield at the given coordinates.""" - data_array = self._get_empty_data_array( - az_array, za_array, freq_array, beam_type="power" - ) - data_array = data_array + 1.0 - return data_array - with pytest.raises( ValueError, - match=re.escape("basis_vector_type is healpix, must be one of ['az_za']"), + match=re.escape( + "basis_vector_type for BadBeam is healpix, must be one of ['az_za']" + ), ): - BadBeam() + + class BadBeam(AnalyticBeam): + basis_vector_type = "healpix" + name = "bad beam" + + def _efield_eval( + self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray + ): + """Evaluate the efield at the given coordinates.""" + data_array = self._get_empty_data_array(az_array, za_array, freq_array) + data_array = data_array + 1.0 / np.sqrt(2.0) + return data_array + + def _power_eval( + self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray + ): + """Evaluate the efield at the given coordinates.""" + data_array = self._get_empty_data_array( + az_array, za_array, freq_array, beam_type="power" + ) + data_array = data_array + 1.0 + return data_array def test_to_uvbeam_errors():