Skip to content

Commit

Permalink
API(fields): rename weight functions to masks (#99)
Browse files Browse the repository at this point in the history
Use "mask" instead of "weight" as the field property that designates the
weight function. This makes it less ambiguous wrt. the weights from a
catalogue.

Closes: #98
  • Loading branch information
ntessore authored Jan 5, 2024
1 parent 8b8d188 commit 41cde36
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 54 deletions.
38 changes: 19 additions & 19 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def __init_subclass__(cls, *, spin: int | None = None) -> None:
break
cls.__ncol = (ncol - nopt, ncol)

def __init__(self, *columns: str, weight: str | None = None) -> None:
def __init__(self, *columns: str, mask: str | None = None) -> None:
"""Initialise the field."""
super().__init__()
self.__columns = self._init_columns(*columns) if columns else None
self.__weight = weight
self.__mask = mask
self._metadata: dict[str, Any] = {}
if (spin := self.__spin) is not None:
self._metadata["spin"] = spin
Expand Down Expand Up @@ -144,9 +144,9 @@ def spin(self) -> int:
return spin

@property
def weight(self) -> str | None:
"""Name of the weight function for this field."""
return self.__weight
def mask(self) -> str | None:
"""Name of the mask for this field."""
return self.__mask

@abstractmethod
async def __call__(
Expand Down Expand Up @@ -195,10 +195,10 @@ def __init__(
*columns: str,
overdensity: bool = True,
nbar: float | None = None,
weight: str | None = None,
mask: str | None = None,
) -> None:
"""Create a position field."""
super().__init__(*columns, weight=weight)
super().__init__(*columns, mask=mask)
self.__overdensity = overdensity
self.__nbar = nbar

Expand Down Expand Up @@ -541,7 +541,7 @@ class Spin2Field(ComplexField, spin=2):
Ellipticities = Spin2Field


def weights_for_fields(
def get_masks(
fields: Mapping[str, Field],
*,
comb: int | None = None,
Expand All @@ -550,12 +550,12 @@ def weights_for_fields(
append_eb: bool = False,
) -> Sequence[str] | Sequence[tuple[str, ...]]:
"""
Return the weights for a given set of fields.
Return the masks for a given set of fields.
If *comb* is given, produce combinations of weights for combinations
If *comb* is given, produce combinations of masks for combinations
of a number *comb* of fields.
The fields (not weights) can be filtered using the *include* and
The fields (not masks) can be filtered using the *include* and
*exclude* parameters. If *append_eb* is true, the filter is applied
to field names including the E/B-mode suffix when the spin weight is
non-zero.
Expand All @@ -575,21 +575,21 @@ def _all_str(seq: tuple[str | None, ...]) -> TypeGuard[tuple[str, ...]]:
return not any(item is None for item in seq)

if comb is None:
weights_no_comb: list[str] = []
masks_no_comb: list[str] = []
for key, field in fields.items():
if field.weight is None:
if field.mask is None:
continue
if not any(map(isgood, _key_eb(key))):
continue
weights_no_comb.append(field.weight)
return weights_no_comb
masks_no_comb.append(field.mask)
return masks_no_comb

weights_comb: list[tuple[str, ...]] = []
masks_comb: list[tuple[str, ...]] = []
for keys in combinations_with_replacement(fields, comb):
item = tuple(fields[key].weight for key in keys)
item = tuple(fields[key].mask for key in keys)
if not _all_str(item):
continue
if not any(map(isgood, product(*map(_key_eb, keys)))):
continue
weights_comb.append(item)
return weights_comb
masks_comb.append(item)
return masks_comb
22 changes: 11 additions & 11 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ def mixing_matrices(
if out is None:
out = TocDict()

# inverse mapping of weights to fields
weights: dict[str, dict[Any, Field]] = {}
# inverse mapping of masks to fields
masks: dict[str, dict[Any, Field]] = {}
for key, field in fields.items():
if field.weight is not None:
if field.weight not in weights:
weights[field.weight] = {}
weights[field.weight][key] = field
if field.mask is not None:
if field.mask not in masks:
masks[field.mask] = {}
masks[field.mask][key] = field

# keep track of combinations that have been done already
done = set()
Expand All @@ -236,21 +236,21 @@ def mixing_matrices(
progressbar = nullcontext()

# go through the toc dict of cls and compute mixing matrices
# which mixing matrix is computed depends on the `weights` mapping
# which mixing matrix is computed depends on the `masks` mapping
with progressbar as prog:
for (k1, k2, i1, i2), cl in cls.items():
# if the weights are not named then skip this cl
# if the masks are not named then skip this cl
try:
fields1 = weights[k1]
fields2 = weights[k2]
fields1 = masks[k1]
fields2 = masks[k2]
except KeyError:
continue

# deal with structured cl arrays
if cl.dtype.names is not None:
cl = cl["CL"]

# compute mixing matrices for all fields of this weight combination
# compute mixing matrices for all fields of this mask combination
for f1, f2 in product(fields1, fields2):
# check if this combination has been done already
if (f1, f2, i1, i2) in done or (f2, f1, i2, i1) in done:
Expand Down
42 changes: 21 additions & 21 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ async def __call__(self):
with pytest.raises(ValueError, match="accepts 2 to 3 columns"):
TestField("lon")

f = TestField("lon", "lat", weight="W")
f = TestField("lon", "lat", mask="W")

assert f.columns == ("lon", "lat", None)
assert f.weight == "W"
assert f.mask == "W"


def test_visibility(mapper, vmap):
Expand Down Expand Up @@ -328,41 +328,41 @@ def test_weights(mapper, catalog):
np.testing.assert_array_almost_equal(m, w / wbar)


def test_weights_for_fields():
def test_get_masks():
from unittest.mock import Mock

from heracles.fields import weights_for_fields
from heracles.fields import get_masks

fields = {
"A": Mock(weight="X", spin=0),
"B": Mock(weight="Y", spin=2),
"C": Mock(weight=None),
"A": Mock(mask="X", spin=0),
"B": Mock(mask="Y", spin=2),
"C": Mock(mask=None),
}

weights = weights_for_fields(fields)
masks = get_masks(fields)

assert weights == ["X", "Y"]
assert masks == ["X", "Y"]

weights = weights_for_fields(fields, comb=1)
masks = get_masks(fields, comb=1)

assert weights == [("X",), ("Y",)]
assert masks == [("X",), ("Y",)]

weights = weights_for_fields(fields, comb=2)
masks = get_masks(fields, comb=2)

assert weights == [("X", "X"), ("X", "Y"), ("Y", "Y")]
assert masks == [("X", "X"), ("X", "Y"), ("Y", "Y")]

weights = weights_for_fields(fields, comb=2, include=[("A",)])
masks = get_masks(fields, comb=2, include=[("A",)])

assert weights == [("X", "X"), ("X", "Y")]
assert masks == [("X", "X"), ("X", "Y")]

weights = weights_for_fields(fields, comb=2, exclude=[("A", "B")])
masks = get_masks(fields, comb=2, exclude=[("A", "B")])

assert weights == [("X", "X"), ("Y", "Y")]
assert masks == [("X", "X"), ("Y", "Y")]

weights = weights_for_fields(fields, comb=2, include=[("A", "B")], append_eb=True)
masks = get_masks(fields, comb=2, include=[("A", "B")], append_eb=True)

assert weights == []
assert masks == []

weights = weights_for_fields(fields, comb=2, include=[("A", "B_E")], append_eb=True)
masks = get_masks(fields, comb=2, include=[("A", "B_E")], append_eb=True)

assert weights == [("X", "Y")]
assert masks == [("X", "Y")]
6 changes: 3 additions & 3 deletions tests/test_twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_mixing_matrices(mock, mock_eb, rng):
# this only tests the function logic
# the mixing matrix computation itself is tested elsewhere

# field definition, requires weight function and spin weight
# field definition, requires mask and spin weight

# mixmat_eb returns three values
mock_eb.return_value = (Mock(), Mock(), Mock())
Expand All @@ -146,8 +146,8 @@ def test_mixing_matrices(mock, mock_eb, rng):

# create the mock field information
fields = {
"P": Mock(weight="V", spin=0),
"G": Mock(weight="W", spin=2),
"P": Mock(mask="V", spin=0),
"G": Mock(mask="W", spin=2),
}

# compute pos-pos
Expand Down

0 comments on commit 41cde36

Please sign in to comment.