Skip to content

Commit

Permalink
Move subclass check into __init_subclass__
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Sep 17, 2024
1 parent 1327ad0 commit abc4714
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
12 changes: 6 additions & 6 deletions src/pyuvdata/uvbeam/analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ class AnalyticBeam(ABC):

def __init_subclass__(cls):
"""Initialize any imported subclass."""
if cls.basis_vector_type not in cls._basis_vec_dict:
raise ValueError(
f"basis_vector_type for {cls.__name__} is {cls.basis_vector_type}, "
f"must be one of {list(cls._basis_vec_dict.keys())}"
)

cls.__types__[cls.__name__] = cls

@property
Expand All @@ -86,12 +92,6 @@ def __post_init__(self, include_cross_pols):
for the power beam.
"""
if self.basis_vector_type not in self._basis_vec_dict:
raise ValueError(
f"basis_vector_type is {self.basis_vector_type}, must be one of "
f"{list(self._basis_vec_dict.keys())}"
)

if self.feed_array is not None:
for feed in self.feed_array:
allowed_feeds = ["n", "e", "x", "y", "r", "l"]
Expand Down
49 changes: 25 additions & 24 deletions tests/uvbeam/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,33 +316,34 @@ def test_beamerrs(beam_kwargs, err_msg):


def test_bad_basis_vector_type():
class BadBeam(AnalyticBeam):
basis_vector_type = "healpix"
name = "bad beam"

def _efield_eval(
self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray
):
"""Evaluate the efield at the given coordinates."""
data_array = self._get_empty_data_array(az_array, za_array, freq_array)
data_array = data_array + 1.0 / np.sqrt(2.0)
return data_array

def _power_eval(
self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray
):
"""Evaluate the efield at the given coordinates."""
data_array = self._get_empty_data_array(
az_array, za_array, freq_array, beam_type="power"
)
data_array = data_array + 1.0
return data_array

with pytest.raises(
ValueError,
match=re.escape("basis_vector_type is healpix, must be one of ['az_za']"),
match=re.escape(
"basis_vector_type for BadBeam is healpix, must be one of ['az_za']"
),
):
BadBeam()

class BadBeam(AnalyticBeam):
basis_vector_type = "healpix"
name = "bad beam"

def _efield_eval(
self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray
):
"""Evaluate the efield at the given coordinates."""
data_array = self._get_empty_data_array(az_array, za_array, freq_array)
data_array = data_array + 1.0 / np.sqrt(2.0)
return data_array

def _power_eval(
self, az_array: np.ndarray, za_array: np.ndarray, freq_array: np.ndarray
):
"""Evaluate the efield at the given coordinates."""
data_array = self._get_empty_data_array(
az_array, za_array, freq_array, beam_type="power"
)
data_array = data_array + 1.0
return data_array


def test_to_uvbeam_errors():
Expand Down

0 comments on commit abc4714

Please sign in to comment.