Skip to content

Commit

Permalink
Set tolerances in tests as much as possible
Browse files Browse the repository at this point in the history
Also use `np.testing.assert_allclose` as much as possible
  • Loading branch information
bhazelton committed Oct 24, 2024
1 parent ac85512 commit 532938e
Show file tree
Hide file tree
Showing 26 changed files with 1,513 additions and 418 deletions.
10 changes: 5 additions & 5 deletions tests/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_airy_beam_values(az_za_deg_grid):
for feed in range(2):
expected_data[pol, feed, :, :] = airy_values / np.sqrt(2.0)

np.testing.assert_allclose(beam_vals, expected_data)
np.testing.assert_allclose(beam_vals, expected_data, atol=1e-15, rtol=0)

assert beam.__repr__() == f"AiryBeam(diameter={diameter_m})"

Expand Down Expand Up @@ -96,7 +96,7 @@ def test_achromatic_gaussian_beam(az_za_deg_grid, sigma_type):
for feed in range(2):
expected_data[pol, feed, :, :] = gaussian_vals / np.sqrt(2.0)

np.testing.assert_allclose(beam_vals, expected_data)
np.testing.assert_allclose(beam_vals, expected_data, atol=1e-15, rtol=0)

assert (
beam.__repr__() == f"GaussianBeam(sigma={sigma_use.__repr__()}, "
Expand Down Expand Up @@ -197,7 +197,7 @@ def test_short_dipole_beam(az_za_deg_grid):
expected_data[1, 0] = np.cos(za_vals) * np.cos(az_vals)
expected_data[1, 1] = np.cos(za_vals) * np.sin(az_vals)

np.testing.assert_allclose(efield_vals, expected_data)
np.testing.assert_allclose(efield_vals, expected_data, atol=1e-15, rtol=0)

power_vals = beam.power_eval(az_array=az_vals, za_array=za_vals, freq_array=freqs)
expected_data = np.zeros((1, 4, n_freqs, nsrcs), dtype=float)
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_uniform_beam(az_za_deg_grid, feed_array, x_orientation):
beam_vals = beam.efield_eval(az_array=az_vals, za_array=za_vals, freq_array=freqs)

expected_data = np.ones((2, 2, n_freqs, nsrcs), dtype=float) / np.sqrt(2.0)
np.testing.assert_allclose(beam_vals, expected_data)
np.testing.assert_allclose(beam_vals, expected_data, atol=1e-15, rtol=0)

assert beam.__repr__() == "UniformBeam()"

Expand Down Expand Up @@ -532,7 +532,7 @@ def test_yaml_constructor_new(az_za_deg_grid):
az_array=az_vals, za_array=za_vals, freq_array=freqs
)

np.testing.assert_allclose(from_efield_eval, from_power_eval)
np.testing.assert_allclose(from_efield_eval, from_power_eval, atol=1e-15, rtol=0)


def test_yaml_constructor_errors():
Expand Down
10 changes: 6 additions & 4 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError:
hasmoon = False

from pyuvdata import parameter as uvp
from pyuvdata import parameter as uvp, utils
from pyuvdata.parameter import allowed_location_types
from pyuvdata.uvbase import UVBase

Expand Down Expand Up @@ -585,10 +585,12 @@ def test_location_xyz_latlonalt_match(frame, selenoid):
height=ref_latlonalt_moon[2] * units.m,
),
)
np.testing.assert_allclose(latlonalt_val, param1.lat_lon_alt())
np.testing.assert_allclose(
latlonalt_val, param1.lat_lon_alt(), rtol=0, atol=utils.RADIAN_TOL
)

param2 = uvp.LocationParameter(name="p2", value=loc_detic)
np.testing.assert_allclose(xyz_val, param2.xyz())
np.testing.assert_allclose(xyz_val, param2.xyz(), rtol=0, atol=1e-3)

param5 = uvp.LocationParameter(name="p2", value=wrong_obj)
param5.set_lat_lon_alt(latlonalt_val, ellipsoid=selenoid)
Expand All @@ -605,7 +607,7 @@ def test_location_xyz_latlonalt_match(frame, selenoid):
)
param3.set_lat_lon_alt_degrees(latlonalt_deg_val)

np.testing.assert_allclose(xyz_val, param3.xyz())
np.testing.assert_allclose(xyz_val, param3.xyz(), rtol=0, atol=1e-3)


def test_location_acceptability():
Expand Down
17 changes: 12 additions & 5 deletions tests/test_telescopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ def test_old_attr_names():
"Telescope.location instead (which contains an astropy "
"EarthLocation object). This will become an error in version 3.2.",
):
np.testing.assert_allclose(mwa_tel.telescope_location, mwa_tel._location.xyz())
np.testing.assert_allclose(
mwa_tel.telescope_location,
mwa_tel._location.xyz(),
rtol=mwa_tel._location.tols[0],
atol=mwa_tel._location.tols[1],
)

with check_warnings(
DeprecationWarning,
Expand Down Expand Up @@ -290,6 +295,8 @@ def test_old_known_tel_dict_keys():
np.testing.assert_allclose(
KNOWN_TELESCOPES["HERA"]["center_xyz"],
Quantity(hera_tel.location.geocentric).to_value("m"),
rtol=hera_tel._location.tols[0],
atol=hera_tel._location.tols[1],
)
with check_warnings(DeprecationWarning, match=warn_msg):
assert KNOWN_TELESCOPES["HERA"]["citation"] == hera_tel.citation
Expand Down Expand Up @@ -318,7 +325,7 @@ def test_hera_loc():

telescope_obj = Telescope.from_known_telescopes("HERA")

assert np.allclose(
np.testing.assert_allclose(
telescope_obj._location.xyz(),
hera_data.telescope._location.xyz(),
rtol=hera_data.telescope._location.tols[0],
Expand All @@ -342,7 +349,7 @@ def test_alternate_antenna_inputs():
antenna_positions=antpos_array, antenna_numbers=antnum, antenna_names=antname
)

assert np.allclose(pos, pos2)
np.testing.assert_allclose(pos, pos2, rtol=0, atol=1e-3)
assert np.all(names == names2)
assert np.all(nums == nums2)

Expand All @@ -352,7 +359,7 @@ def test_alternate_antenna_inputs():
"002": np.array([0, 0, 2]),
}
pos, names, nums = get_antenna_params(antenna_positions=antpos_dict)
assert np.allclose(pos, pos2)
np.testing.assert_allclose(pos, pos2, rtol=0, atol=1e-3)
assert np.all(names == names2)
assert np.all(nums == nums2)

Expand Down Expand Up @@ -480,7 +487,7 @@ def test_get_enu_antpos():
# no center, no pick data ants
antpos = tel.get_enu_antpos()
assert antpos.shape == (tel.Nants, 3)
assert np.isclose(antpos[0, 0], 19.340211050751535)
assert np.isclose(antpos[0, 0], 19.340211050751535, rtol=0, atol=1e-3)


def test_ignore_param_updates_error():
Expand Down
45 changes: 37 additions & 8 deletions tests/test_uvbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astropy.coordinates import Distance, EarthLocation, Latitude, Longitude, SkyCoord
from astropy.time import Time

from pyuvdata import Telescope, parameter as uvp
from pyuvdata import Telescope, parameter as uvp, utils
from pyuvdata.testing import check_warnings
from pyuvdata.uvbase import UVBase, _warning, old_telescope_metadata_attrs

Expand Down Expand Up @@ -533,11 +533,17 @@ def test_getattr_old_telescope():
):
param_val = getattr(test_obj, param)
tel_param_val = getattr(test_obj.telescope, tel_param)
tel_param_obj = getattr(test_obj.telescope, "_" + tel_param)
if not isinstance(param_val, np.ndarray):
assert param_val == tel_param_val
else:
if not isinstance(param_val.flat[0], str):
np.testing.assert_allclose(param_val, tel_param_val)
np.testing.assert_allclose(
param_val,
tel_param_val,
rtol=tel_param_obj.tols[0],
atol=tel_param_obj.tols[1],
)
else:
assert param_val.tolist() == tel_param_val.tolist()
elif param == "telescope_location":
Expand All @@ -550,7 +556,12 @@ def test_getattr_old_telescope():
"object. This will become an error in version 3.2.",
):
param_val = getattr(test_obj, param)
np.testing.assert_allclose(param_val, test_obj.telescope._location.xyz())
np.testing.assert_allclose(
param_val,
test_obj.telescope._location.xyz(),
rtol=tel_param_obj.tols[0],
atol=tel_param_obj.tols[1],
)
elif param == "telescope_location_lat_lon_alt":
with check_warnings(
DeprecationWarning,
Expand All @@ -562,7 +573,10 @@ def test_getattr_old_telescope():
):
param_val = getattr(test_obj, param)
np.testing.assert_allclose(
param_val, test_obj.telescope._location.lat_lon_alt()
param_val,
test_obj.telescope._location.lat_lon_alt(),
rtol=0,
atol=utils.RADIAN_TOL,
)
elif param == "telescope_location_lat_lon_alt_degrees":
with check_warnings(
Expand All @@ -575,7 +589,10 @@ def test_getattr_old_telescope():
):
param_val = getattr(test_obj, param)
np.testing.assert_allclose(
param_val, test_obj.telescope._location.lat_lon_alt_degrees()
param_val,
test_obj.telescope._location.lat_lon_alt_degrees(),
rtol=0,
atol=np.rad2deg(utils.RADIAN_TOL),
)


Expand All @@ -587,6 +604,7 @@ def test_setattr_old_telescope():
for param, tel_param in old_telescope_metadata_attrs.items():
if tel_param is not None:
tel_val = getattr(new_telescope, tel_param)
tel_param_obj = getattr(new_telescope, "_" + tel_param)
with check_warnings(
DeprecationWarning,
match=f"The UVData.{param} attribute now just points to the "
Expand All @@ -609,7 +627,12 @@ def test_setattr_old_telescope():
assert param_val == tel_val
else:
if not isinstance(param_val.flat[0], str):
np.testing.assert_allclose(param_val, tel_val)
np.testing.assert_allclose(
param_val,
tel_val,
rtol=tel_param_obj.tols[0],
atol=tel_param_obj.tols[1],
)
else:
assert param_val.tolist() == tel_val.tolist()
elif param == "telescope_location":
Expand All @@ -636,7 +659,10 @@ def test_setattr_old_telescope():
):
test_obj.telescope_location_lat_lon_alt = tel_val
np.testing.assert_allclose(
new_telescope._location.xyz(), test_obj.telescope._location.xyz()
new_telescope._location.xyz(),
test_obj.telescope._location.xyz(),
rtol=tel_param_obj.tols[0],
atol=tel_param_obj.tols[1],
)
elif param == "telescope_location_lat_lon_alt_degrees":
tel_val = new_telescope._location.lat_lon_alt_degrees()
Expand All @@ -650,7 +676,10 @@ def test_setattr_old_telescope():
):
test_obj.telescope_location_lat_lon_alt_degrees = tel_val
np.testing.assert_allclose(
new_telescope._location.xyz(), test_obj.telescope._location.xyz()
new_telescope._location.xyz(),
test_obj.telescope._location.xyz(),
rtol=tel_param_obj.tols[0],
atol=tel_param_obj.tols[1],
)


Expand Down
44 changes: 23 additions & 21 deletions tests/utils/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,9 @@ def test_lla_xyz_lla_roundtrip():
lons *= np.pi / 180.0
xyz = utils.XYZ_from_LatLonAlt(lats, lons, alts)
lats_new, lons_new, alts_new = utils.LatLonAlt_from_XYZ(xyz)
np.testing.assert_allclose(lats_new, lats)
np.testing.assert_allclose(lons_new, lons)
np.testing.assert_allclose(alts_new, alts)
np.testing.assert_allclose(lats_new, lats, rtol=0, atol=utils.RADIAN_TOL)
np.testing.assert_allclose(lons_new, lons, rtol=0, atol=utils.RADIAN_TOL)
np.testing.assert_allclose(alts_new, alts, rtol=0, atol=1e-3)


def test_xyz_from_latlonalt(enu_ecef_info):
Expand All @@ -439,7 +439,9 @@ def test_enu_from_ecef(enu_ecef_info):
enu = utils.ENU_from_ECEF(
xyz, latitude=center_lat, longitude=center_lon, altitude=center_alt
)
np.testing.assert_allclose(np.stack((east, north, up), axis=1), enu, atol=1e-3)
np.testing.assert_allclose(
np.stack((east, north, up), axis=1), enu, rtol=0, atol=1e-3
)

enu2 = utils.ENU_from_ECEF(
xyz,
Expand All @@ -449,7 +451,7 @@ def test_enu_from_ecef(enu_ecef_info):
height=center_alt * units.m,
),
)
np.testing.assert_allclose(enu, enu2)
np.testing.assert_allclose(enu, enu2, rtol=0, atol=1e-3)


@pytest.mark.skipif(not hasmoon, reason="lunarsky not installed")
Expand All @@ -470,7 +472,9 @@ def test_enu_from_mcmf(enu_mcmf_info, selenoid):
ellipsoid=selenoid,
)

np.testing.assert_allclose(np.stack((east, north, up), axis=1), enu, atol=1e-3)
np.testing.assert_allclose(
np.stack((east, north, up), axis=1), enu, rtol=0, atol=1e-3
)

enu2 = utils.ENU_from_ECEF(
xyz,
Expand All @@ -481,7 +485,7 @@ def test_enu_from_mcmf(enu_mcmf_info, selenoid):
ellipsoid=selenoid,
),
)
np.testing.assert_allclose(enu, enu2, atol=1e-3)
np.testing.assert_allclose(enu, enu2, rtol=0, atol=1e-3)


def test_invalid_frame():
Expand Down Expand Up @@ -612,10 +616,10 @@ def test_ecef_from_enu_roundtrip(enu_ecef_info, enu_mcmf_info, frame, selenoid):
frame=frame,
ellipsoid=selenoid,
)
np.testing.assert_allclose(xyz, xyz_from_enu, atol=1e-3)
np.testing.assert_allclose(xyz, xyz_from_enu, rtol=0, atol=1e-3)

xyz_from_enu2 = utils.ECEF_from_ENU(enu, center_loc=loc_obj)
np.testing.assert_allclose(xyz_from_enu, xyz_from_enu2, atol=1e-3)
np.testing.assert_allclose(xyz_from_enu, xyz_from_enu2, rtol=0, atol=1e-3)

if selenoid == "SPHERE":
enu = utils.ENU_from_ECEF(
Expand All @@ -633,7 +637,7 @@ def test_ecef_from_enu_roundtrip(enu_ecef_info, enu_mcmf_info, frame, selenoid):
altitude=center_alt,
frame=frame,
)
np.testing.assert_allclose(xyz, xyz_from_enu, atol=1e-3)
np.testing.assert_allclose(xyz, xyz_from_enu, rtol=0, atol=1e-3)


@pytest.mark.parametrize("shape_type", ["transpose", "Nblts,2", "Nblts,1"])
Expand Down Expand Up @@ -673,7 +677,7 @@ def test_ecef_from_enu_single(enu_ecef_info):
)

np.testing.assert_allclose(
np.array((east[0], north[0], up[0])), enu_single, atol=1e-3
np.array((east[0], north[0], up[0])), enu_single, rtol=0, atol=1e-3
)


Expand All @@ -692,13 +696,13 @@ def test_ecef_from_enu_single_roundtrip(enu_ecef_info):
xyz[0, :], latitude=center_lat, longitude=center_lon, altitude=center_alt
)
np.testing.assert_allclose(
np.array((east[0], north[0], up[0])), enu[0, :], atol=1e-3
np.array((east[0], north[0], up[0])), enu[0, :], rtol=0, atol=1e-3
)

xyz_from_enu = utils.ECEF_from_ENU(
enu_single, latitude=center_lat, longitude=center_lon, altitude=center_alt
)
np.testing.assert_allclose(xyz[0, :], xyz_from_enu, atol=1e-3)
np.testing.assert_allclose(xyz[0, :], xyz_from_enu, rtol=0, atol=1e-3)


def test_mwa_ecef_conversion():
Expand Down Expand Up @@ -743,11 +747,11 @@ def test_mwa_ecef_conversion():

enu = utils.ENU_from_ECEF(ecef_xyz, latitude=lat, longitude=lon, altitude=alt)

np.testing.assert_allclose(enu, enh)
np.testing.assert_allclose(enu, enh, rtol=0, atol=1e-3)

# test other direction of ECEF rotation
rot_xyz = utils.rotECEF_from_ECEF(new_xyz, lon)
np.testing.assert_allclose(rot_xyz.T, xyz)
np.testing.assert_allclose(rot_xyz.T, xyz, rtol=0, atol=1e-3)


def test_hpx_latlon_az_za():
Expand All @@ -768,10 +772,8 @@ def test_hpx_latlon_az_za():
za_mesh, az_mesh
)

print(np.min(calc_lat), np.max(calc_lat))
print(np.min(lat_mesh), np.max(lat_mesh))
np.testing.assert_allclose(calc_lat, lat_mesh)
np.testing.assert_allclose(calc_lon, lon_mesh)
np.testing.assert_allclose(calc_lat, lat_mesh, rtol=0, atol=utils.RADIAN_TOL)
np.testing.assert_allclose(calc_lon, lon_mesh, rtol=0, atol=utils.RADIAN_TOL)

with pytest.raises(
ValueError, match="shapes of hpx_lat and hpx_lon values must match."
Expand All @@ -782,8 +784,8 @@ def test_hpx_latlon_az_za():
lat_mesh, lon_mesh
)

np.testing.assert_allclose(calc_za, za_mesh)
np.testing.assert_allclose(calc_az, az_mesh)
np.testing.assert_allclose(calc_za, za_mesh, rtol=0, atol=utils.RADIAN_TOL)
np.testing.assert_allclose(calc_az, az_mesh, rtol=0, atol=utils.RADIAN_TOL)


@pytest.mark.parametrize("err_state", ["err", "warn", "none"])
Expand Down
Loading

0 comments on commit 532938e

Please sign in to comment.