Skip to content

Commit

Permalink
Better approach for testing the AnalyticBeam plugin code
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Sep 18, 2024
1 parent abc4714 commit 867af69
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 14 deletions.
51 changes: 51 additions & 0 deletions src/pyuvdata/data/test_analytic_beam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2024 Radio Astronomy Software Group
# Licensed under the 2-clause BSD License
"""Define an AnalyticBeam subclass for testing the AnalyticBeam plugin code."""

from dataclasses import InitVar, dataclass, field
from typing import Literal

import numpy.typing as npt

from ..uvbeam.analytic_beam import AnalyticBeam


@dataclass(kw_only=True)
class AnalyticTest(AnalyticBeam):
"""A test class to support testing the AnalyticBeam plugin code."""

radius: float
feed_array: npt.NDArray[str] | None = field(default=None, repr=False, compare=False)
x_orientation: Literal["east", "north"] = field(
default="east", repr=False, compare=False
)

include_cross_pols: InitVar[bool] = True

basis_vector_type = "az_za"

def _efield_eval(
self,
*,
az_array: npt.NDArray[float],
za_array: npt.NDArray[float],
freq_array: npt.NDArray[float],
) -> npt.NDArray[float]:
"""Evaluate the efield at the given coordinates."""
data_array = self._get_empty_data_array(az_array.size, freq_array.size)

Check warning on line 35 in src/pyuvdata/data/test_analytic_beam.py

View check run for this annotation

Codecov / codecov/patch

src/pyuvdata/data/test_analytic_beam.py#L35

Added line #L35 was not covered by tests

return data_array

Check warning on line 37 in src/pyuvdata/data/test_analytic_beam.py

View check run for this annotation

Codecov / codecov/patch

src/pyuvdata/data/test_analytic_beam.py#L37

Added line #L37 was not covered by tests

def _power_eval(
self,
*,
az_array: npt.NDArray[float],
za_array: npt.NDArray[float],
freq_array: npt.NDArray[float],
) -> npt.NDArray[float]:
"""Evaluate the power at the given coordinates."""
data_array = self._get_empty_data_array(

Check warning on line 47 in src/pyuvdata/data/test_analytic_beam.py

View check run for this annotation

Codecov / codecov/patch

src/pyuvdata/data/test_analytic_beam.py#L47

Added line #L47 was not covered by tests
az_array.size, freq_array.size, beam_type="power"
)

return data_array

Check warning on line 51 in src/pyuvdata/data/test_analytic_beam.py

View check run for this annotation

Codecov / codecov/patch

src/pyuvdata/data/test_analytic_beam.py#L51

Added line #L51 was not covered by tests
10 changes: 3 additions & 7 deletions src/pyuvdata/uvbeam/analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,9 @@ def _analytic_beam_constructor(loader, node):

if class_name in AnalyticBeam.__types__:
beam_class = AnalyticBeam.__types__[class_name]
else:
if len(class_parts) == 1:
# no module specified, assume pyuvdata
module = importlib.import_module("pyuvdata")
else:
module = (".").join(class_parts[:-1])
module = importlib.import_module(module)
elif len(class_parts) > 1:
module = (".").join(class_parts[:-1])
module = importlib.import_module(module)
beam_class = getattr(module, class_name)

if class_name not in AnalyticBeam.__types__:
Expand Down
40 changes: 33 additions & 7 deletions tests/uvbeam/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,7 @@ def test_to_uvbeam_errors():
],
],
)
# the ordering of the following parameterize is critical, don't change it.
# once the AnalyticBeam.__types__ is cleared you can't get it back.
@pytest.mark.parametrize("clear_plugin_list", [False, True])
def test_yaml_constructor(input_yaml, beam, clear_plugin_list):
if clear_plugin_list:
AnalyticBeam.__types__ = {}

def test_yaml_constructor(input_yaml, beam):
beam_from_yaml = yaml.safe_load(input_yaml)["beam"]

assert beam_from_yaml == beam
Expand All @@ -442,6 +436,21 @@ def test_yaml_constructor(input_yaml, beam, clear_plugin_list):
assert new_beam_from_yaml == beam_from_yaml


def test_yaml_constructor_new():
input_yaml = """
beam: !AnalyticBeam
class: pyuvdata.data.test_analytic_beam.AnalyticTest
radius: 10
"""
beam_from_yaml = yaml.safe_load(input_yaml)["beam"]

from pyuvdata.data.test_analytic_beam import AnalyticTest

beam = AnalyticTest(radius=10)

assert beam_from_yaml == beam


def test_yaml_constructor_errors():
input_yaml = """
beam: !AnalyticBeam
Expand All @@ -452,3 +461,20 @@ def test_yaml_constructor_errors():
ValueError, match="yaml entries for AnalyticBeam must specify a class"
):
yaml.safe_load(input_yaml)["beam"]

input_yaml = """
beam: !AnalyticBeam
class: FakeBeam
diameter: 10
"""

with pytest.raises(
NameError,
match=re.escape(
"FakeBeam is not a known AnalyticBeam. Available options are: "
f"{list(AnalyticBeam.__types__.keys())}. If it is a custom beam, "
"either ensure the module is imported, or specify the beam with "
"dot-pathed modules included (i.e. `my_module.MyAnalyticBeam`)"
),
):
yaml.safe_load(input_yaml)["beam"]

0 comments on commit 867af69

Please sign in to comment.