Skip to content

Commit

Permalink
add test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Jul 12, 2024
1 parent ae54798 commit f4add86
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 166 deletions.
5 changes: 1 addition & 4 deletions src/pyuvdata/utils/antenna.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def _select_antenna_helper(
)

for ant in antenna_nums:
if ant in obj_ant_array:
ant_inds = np.append(ant_inds, np.where(obj_ant_array == ant)[0])
else:
raise ValueError(f"Antenna number {ant} is not present in the array")
ant_inds = np.append(ant_inds, np.where(obj_ant_array == ant)[0])

ant_inds = sorted(set(ant_inds))
else:
Expand Down
10 changes: 5 additions & 5 deletions src/pyuvdata/utils/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _extract_bls_pol(
for bl_ind in bls:
if not (bl_ind in baseline_array):
raise ValueError(
f"Baseline number {bl_ind} is not present in the " "baseline_array"
f"Baseline number {bl_ind} is not present in the baseline_array"
)
bls = list(zip(*baseline_to_antnums(bls, Nants_telescope=nants_telescope)))
elif isinstance(bls, tuple) and (len(bls) == 2 or len(bls) == 3):
Expand All @@ -416,17 +416,17 @@ def _extract_bls_pol(
if any(len(item) == 3 for item in bls):
if polarizations is not None:
raise ValueError(
"Cannot provide any length-3 tuples and also specify " "polarizations."
"Cannot provide any length-3 tuples and also specify polarizations."
)

bls_2 = copy.deepcopy(bls)
for bl_i, bl in enumerate(bls):
if len(bl) == 2:
continue
if len(bl) != 3:
raise ValueError("If some bls are 3-tuples, all bls must be 3-tuples.")

if not isinstance(bl[2], str):
raise ValueError(
"The third element in a bl tuple must be a " "polarization string"
"The third element in a bl tuple must be a polarization string"
)

bl_pols = set()
Expand Down
4 changes: 3 additions & 1 deletion src/pyuvdata/uvdata/uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10832,9 +10832,11 @@ def read(

# MWA corr fits can only handle length-two bls tuples, anything
# else needs to be handled via select.
bls_use = copy.deepcopy(bls)
if bls is not None:
if not all(len(item) == 2 for item in bls):
select_bls = bls
bls_use = None

# these aren't supported by partial read, so do it in select
select_ant_str = ant_str
Expand Down Expand Up @@ -10944,7 +10946,7 @@ def read(
filename,
antenna_nums=antenna_nums,
antenna_names=antenna_names,
bls=bls,
bls=bls_use,
frequencies=frequencies,
freq_chans=freq_chans,
times=times,
Expand Down
10 changes: 8 additions & 2 deletions tests/uvdata/test_mwa_corr_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,11 @@ def test_default_corrections(tmp_path):
[{"antenna_nums": [18, 31, 66, 95]}, ""],
[{"antenna_names": [f"Tile{ant:03d}" for ant in [18, 31, 66, 95]]}, ""],
[{"bls": [(48, 34), (96, 11), (22, 87)]}, ""],
[
{"bls": [(48, 34, "xx"), (96, 11, "xx"), (22, 87, "xx")]},
"a select on read keyword is set that is not supported by "
"read_mwa_corr_fits. This select will be done after reading the file.",
],
[
{"ant_str": "48_34,96_11,22_87"},
"a select on read keyword is set that is not supported by "
Expand All @@ -1181,7 +1186,7 @@ def test_partial_read_bl_axis(tmp_path, mwax, select_kwargs, warn_msg):
if warn_msg != "":
warn_msg_list.append(warn_msg)

if mwax and "bls" not in select_kwargs.keys():
if mwax and ("bls" not in select_kwargs.keys() or warn_msg != ""):
# The bls selection has no autos
warn_msg_list.append("Fixing auto-correlations to be be real-only")

Expand Down Expand Up @@ -1322,8 +1327,9 @@ def test_partial_read_freq_axis(tmp_path, mwax, select_kwargs, read_kwargs, nspw
"select_kwargs",
[
{"polarizations": ["xx"]},
{"polarizations": ["xx", "yy"]},
{"polarizations": np.atleast_3d(["xx", "yy"])},
{"polarizations": ["xx", "xy"]},
{"polarizations": [-7, -8]},
],
)
@pytest.mark.parametrize("mwax", [False, True])
Expand Down
224 changes: 70 additions & 154 deletions tests/uvdata/test_uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,120 +1578,53 @@ def sort_bl(p):


@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
def test_select_bls(casa_uvfits):
@pytest.mark.parametrize(
"sel_type", ["antpair", "blnum", "antpairpol", "antpair_npint", "single"]
)
def test_select_bls(casa_uvfits, sel_type):
uv_object = casa_uvfits
old_history = uv_object.history
first_ants = [7, 3, 8, 3, 22, 28, 9]
second_ants = [1, 21, 9, 2, 3, 4, 23]
pols = ["RR", "RR", "RR", "RR", "RR", "RR", "RR"]
new_unique_ants = np.unique(first_ants + second_ants)
ant_pairs_to_keep = list(zip(first_ants, second_ants))
sorted_pairs_to_keep = [sort_bl(p) for p in ant_pairs_to_keep]

blts_select = [
sort_bl((a1, a2)) in sorted_pairs_to_keep
for (a1, a2) in zip(uv_object.ant_1_array, uv_object.ant_2_array)
]
Nblts_selected = np.sum(blts_select)

uv_object2 = uv_object.copy()
uv_object2.select(bls=ant_pairs_to_keep)
sorted_pairs_object2 = [
sort_bl(p) for p in zip(uv_object2.ant_1_array, uv_object2.ant_2_array)
]

assert len(new_unique_ants) == uv_object2.Nants_data
assert Nblts_selected == uv_object2.Nblts
for ant in new_unique_ants:
assert ant in uv_object2.ant_1_array or ant in uv_object2.ant_2_array
for ant in np.unique(
uv_object2.ant_1_array.tolist() + uv_object2.ant_2_array.tolist()
):
assert ant in new_unique_ants
for pair in sorted_pairs_to_keep:
assert pair in sorted_pairs_object2
for pair in sorted_pairs_object2:
assert pair in sorted_pairs_to_keep

assert utils.history._check_histories(
old_history + " Downselected to specific antenna pairs using pyuvdata.",
uv_object2.history,
)

# check using baseline number parameter
uv_object3 = uv_object.copy()
bls_nums_to_keep = [
uv_object.antnums_to_baseline(ant1, ant2) for ant1, ant2 in sorted_pairs_to_keep
]

uv_object3.select(bls=bls_nums_to_keep)
sorted_pairs_object3 = [
sort_bl(p) for p in zip(uv_object3.ant_1_array, uv_object3.ant_2_array)
]

assert len(new_unique_ants) == uv_object3.Nants_data
assert Nblts_selected == uv_object3.Nblts
for ant in new_unique_ants:
assert ant in uv_object3.ant_1_array or ant in uv_object3.ant_2_array
for ant in np.unique(
uv_object3.ant_1_array.tolist() + uv_object3.ant_2_array.tolist()
):
assert ant in new_unique_ants
for pair in sorted_pairs_to_keep:
assert pair in sorted_pairs_object3
for pair in sorted_pairs_object3:
assert pair in sorted_pairs_to_keep

assert utils.history._check_histories(
old_history + " Downselected to specific antenna pairs using pyuvdata.",
uv_object3.history,
)

# check select with polarizations
first_ants = [7, 3, 8, 3, 22, 28, 9]
second_ants = [1, 21, 9, 2, 3, 4, 23]
pols = ["RR", "RR", "RR", "RR", "RR", "RR", "RR"]
new_unique_ants = np.unique(first_ants + second_ants)
bls_to_keep = list(zip(first_ants, second_ants, pols))
sorted_bls_to_keep = [sort_bl(p) for p in bls_to_keep]

blts_select = [
sort_bl((a1, a2, "RR")) in sorted_bls_to_keep
sort_bl((a1, a2)) in sorted_pairs_to_keep
for (a1, a2) in zip(uv_object.ant_1_array, uv_object.ant_2_array)
]
Nblts_selected = np.sum(blts_select)
sel_str = "antenna pairs"

uv_object2 = uv_object.copy()
uv_object2.select(bls=bls_to_keep)
sorted_pairs_object2 = [
sort_bl(p) + ("RR",)
for p in zip(uv_object2.ant_1_array, uv_object2.ant_2_array)
]

assert len(new_unique_ants) == uv_object2.Nants_data
assert Nblts_selected == uv_object2.Nblts
for ant in new_unique_ants:
assert ant in uv_object2.ant_1_array or ant in uv_object2.ant_2_array
for ant in np.unique(
uv_object2.ant_1_array.tolist() + uv_object2.ant_2_array.tolist()
):
assert ant in new_unique_ants
for bl in sorted_bls_to_keep:
assert bl in sorted_pairs_object2
for bl in sorted_pairs_object2:
assert bl in sorted_bls_to_keep

assert utils.history._check_histories(
old_history
+ " Downselected to specific antenna pairs, polarizations using pyuvdata.",
uv_object2.history,
)

# check that you can use numpy integers with out errors:
first_ants = list(map(np.int32, [7, 3, 8, 3, 22, 28, 9]))
second_ants = list(map(np.int32, [1, 21, 9, 2, 3, 4, 23]))
ant_pairs_to_keep = list(zip(first_ants, second_ants))
if sel_type == "antpair":
bls_select = ant_pairs_to_keep
elif sel_type == "antpair_npint":
bls_select = list(
zip(list(map(np.int32, first_ants)), list(map(np.int32, second_ants)))
)
elif sel_type == "blnum":
bls_select = bls_nums_to_keep
elif sel_type == "antpairpol":
bls_select = bls_to_keep
sel_str = "antenna pairs, polarizations"
elif sel_type == "single":
bls_select = (1, 7)
new_unique_ants = [1, 7]
sorted_pairs_to_keep = [(1, 7)]
blts_select = [
sort_bl((a1, a2)) in sorted_pairs_to_keep
for (a1, a2) in zip(uv_object.ant_1_array, uv_object.ant_2_array)
]
Nblts_selected = np.sum(blts_select)

uv_object2 = uv_object.select(bls=ant_pairs_to_keep, inplace=False)
uv_object2 = uv_object.copy()
uv_object2.select(bls=bls_select)
sorted_pairs_object2 = [
sort_bl(p) for p in zip(uv_object2.ant_1_array, uv_object2.ant_2_array)
]
Expand All @@ -1709,70 +1642,53 @@ def test_select_bls(casa_uvfits):
for pair in sorted_pairs_object2:
assert pair in sorted_pairs_to_keep

if sel_type == "antpairpol":
assert uv_object2.Npols == 1
if sel_type == "2_3_tuple":
assert uv_object2.Npols == 1

assert utils.history._check_histories(
old_history + " Downselected to specific antenna pairs using pyuvdata.",
old_history + f" Downselected to specific {sel_str} using pyuvdata.",
uv_object2.history,
)

# check that you can specify a single pair without errors
uv_object2.select(bls=(1, 7))
sorted_pairs_object2 = [
sort_bl(p) for p in zip(uv_object2.ant_1_array, uv_object2.ant_2_array)
]
assert list(set(sorted_pairs_object2)) == [(1, 7)]

# check for errors associated with antenna pairs not included in data and bad inputs
with pytest.raises(
ValueError, match="bls must be a list of tuples of antenna numbers"
):
uv_object.select(bls=list(zip(first_ants, second_ants)) + [1, 7])

with pytest.raises(
ValueError, match="bls must be a list of tuples of antenna numbers"
):
uv_object.select(
bls=[
(
uv_object.telescope.antenna_names[0],
uv_object.telescope.antenna_names[1],
)
]
)

with pytest.raises(
ValueError, match=re.escape("Antenna pair (5, 1) does not have any data")
):
uv_object.select(bls=(5, 1))

with pytest.raises(
ValueError, match=re.escape("Antenna pair (1, 5) does not have any data")
):
uv_object.select(bls=(1, 5))

with pytest.raises(
ValueError, match=re.escape("Antenna pair (27, 27) does not have any data")
):
uv_object.select(bls=(27, 27))

with pytest.raises(
ValueError,
match="Cannot provide any length-3 tuples and also specify polarizations.",
):
uv_object.select(bls=(7, 1, "RR"), polarizations="RR")

with pytest.raises(
ValueError,
match="The third element in a bl tuple must be a polarization string",
):
uv_object.select(bls=(7, 1, 7))

with pytest.raises(
ValueError, match="bls must be a list of tuples of antenna numbers"
):
uv_object.select(bls=[])
@pytest.mark.parametrize(
["sel_kwargs", "err_msg"],
[
[
{"bls": list(zip([7, 3, 8], [1, 21, 9])) + [1, 7]},
"bls must be a list of tuples of antenna numbers",
],
[{"bls": ("foo", "bar")}, "bls must be a list of tuples of antenna numbers"],
[{"bls": (5, 1)}, re.escape("Antenna pair (5, 1) does not have any data")],
[
{"bls": (5, 1, "RR")},
re.escape("Antenna pair (5, 1, 'RR') does not have any data"),
],
[{"bls": (1, 5)}, re.escape("Antenna pair (1, 5) does not have any data")],
[{"bls": (27, 27)}, re.escape("Antenna pair (27, 27) does not have any data")],
[
{"bls": (7, 1, "RR"), "polarizations": "RR"},
"Cannot provide any length-3 tuples and also specify polarizations.",
],
[
{"bls": (7, 1, 7)},
"The third element in a bl tuple must be a polarization string",
],
[
{"bls": [(7, 1, "RR"), (1, 5)]},
"If some bls are 3-tuples, all bls must be 3-tuples.",
],
[{"bls": []}, "bls must be a list of tuples of antenna numbers"],
[{"bls": [100]}, "Baseline number 100 is not present in the baseline_array"],
],
)
def test_select_bls_errors(casa_uvfits, sel_kwargs, err_msg):
uv_object = casa_uvfits

with pytest.raises(ValueError, match="Baseline number 100 is not present in the"):
uv_object.select(bls=[100])
with pytest.raises(ValueError, match=err_msg):
uv_object.select(**sel_kwargs)


@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values")
Expand Down

0 comments on commit f4add86

Please sign in to comment.