diff --git a/heracles/_mapping.py b/heracles/_mapping.py index b3797176..349e1dd3 100644 --- a/heracles/_mapping.py +++ b/heracles/_mapping.py @@ -177,7 +177,7 @@ def transform( msg = f"unknown field name: {k}" raise ValueError(msg) - out[k, i] = field.mapper_or_error.transform(m) + out[k, i] = field.mapper.transform(m) if progress: subtask.remove() diff --git a/heracles/catalog/base.py b/heracles/catalog/base.py index 635fa2d4..74b69ef8 100644 --- a/heracles/catalog/base.py +++ b/heracles/catalog/base.py @@ -440,7 +440,7 @@ def where(self, selection, visibility=None): @property def page_size(self): - """number of rows per page (default: 100_000)""" + """number of rows per page (default: 1_000_000)""" return self._page_size @page_size.setter diff --git a/heracles/fields.py b/heracles/fields.py index 83141730..1e3dfb86 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -59,62 +59,25 @@ class Field(metaclass=ABCMeta): # every field subclass has a static spin weight attribute, which can be # overwritten by the class (or even an individual instance) - __spin: int | None = None + _spin: int | None = None # definition of required and optional columns __ncol: tuple[int, int] - def __init_subclass__(cls, *, spin: int | None = None) -> None: - """Initialise spin weight of field subclasses.""" - super().__init_subclass__() - if spin is not None: - cls.__spin = spin - uses = cls.uses - if uses is None: - uses = () - elif isinstance(uses, str): - uses = (uses,) - ncol = len(uses) - nopt = 0 - for u in uses[::-1]: - if u.startswith("[") and u.endswith("]"): - nopt += 1 - else: - break - cls.__ncol = (ncol - nopt, ncol) - def __init__( self, mapper: Mapper | None, *columns: str, + weight: str | None = None, mask: str | None = None, ) -> None: """Initialise the field.""" super().__init__() self.__mapper = mapper - self.__columns = self._init_columns(*columns) if columns else None + self.__columns = columns if columns else None + self.__weight = weight if weight else None self.__mask = mask - - @classmethod - def _init_columns(cls, *columns: str) -> Columns: - """Initialise the given set of columns for a specific field - subclass.""" - nmin, nmax = cls.__ncol - if not nmin <= len(columns) <= nmax: - uses = cls.uses - if uses is None: - uses = () - if isinstance(uses, str): - uses = (uses,) - count = f"{nmin}" - if nmax != nmin: - count += f" to {nmax}" - msg = f"field of type '{cls.__name__}' accepts {count} columns" - if uses: - msg += " (" + ", ".join(uses) + ")" - msg += f", received {len(columns)}" - raise ValueError(msg) - return columns + (None,) * (nmax - len(columns)) + self._spin = 0 @property def mapper(self) -> Mapper | None: @@ -122,37 +85,18 @@ def mapper(self) -> Mapper | None: return self.__mapper @property - def mapper_or_error(self) -> Mapper: - """Return the mapper used by this field, or raise a :class:`ValueError` - if not set.""" - if self.__mapper is None: - msg = "no mapper for field" - raise ValueError(msg) - return self.__mapper + def weight(self) -> str | None: + """Return the mapper used by this field.""" + return self.__weight @property def columns(self) -> Columns | None: """Return the catalogue columns used by this field.""" return self.__columns - @property - def columns_or_error(self) -> Columns: - """Return the catalogue columns used by this field, or raise a - :class:`ValueError` if not set.""" - if self.__columns is None: - msg = "no columns for field" - raise ValueError(msg) - return self.__columns - @property def spin(self) -> int: - """Spin weight of field.""" - spin = self.__spin - if spin is None: - clsname = self.__class__.__name__ - msg = f"field of type '{clsname}' has undefined spin weight" - raise ValueError(msg) - return spin + return self._spin @property def mask(self) -> str | None: @@ -169,6 +113,17 @@ async def __call__( """Implementation for mapping a catalogue.""" ... + def CheckColumns(self, *expected): + if self.columns is None: + msg = "No columns defined!" + raise ValueError(msg) + if len(expected) != len(self.columns): + error = "Column error. Expected " + str(len(expected)) + " columns" + error += ( + " with a format " + str(expected) + ". Received " + str(self.columns) + ) + raise ValueError(error) + async def _pages( catalog: Catalog, @@ -190,7 +145,7 @@ async def _pages( await coroutines.sleep() -class Positions(Field, spin=0): +class Positions(Field): """Field of positions in a catalogue. Can produce both overdensity maps and number count maps, depending @@ -198,18 +153,17 @@ class Positions(Field, spin=0): """ - uses = "longitude", "latitude" - def __init__( self, mapper: Mapper | None, *columns: str, + weight: str | None, overdensity: bool = True, nbar: float | None = None, mask: str | None = None, ) -> None: """Create a position field.""" - super().__init__(mapper, *columns, mask=mask) + super().__init__(mapper, *columns, weight=weight, mask=mask) self.__overdensity = overdensity self.__nbar = nbar @@ -242,10 +196,13 @@ async def __call__( raise ValueError(msg) # get mapper - mapper = self.mapper_or_error + mapper = self.mapper # get catalogue column definition - col = self.columns_or_error + col = self.columns + self.CheckColumns("longitude", "latitude") + + # if(len(col)!=2): # position map pos = mapper.create(spin=self.spin) @@ -259,7 +216,7 @@ async def __call__( lon, lat = page.get(*col) w = np.ones(page.size) - mapper.map_values(lon, lat, pos, w) + self.mapper.map_values(lon, lat, pos, w) ngal += page.size @@ -307,11 +264,9 @@ async def __call__( return pos -class ScalarField(Field, spin=0): +class ScalarField(Field): """Field of real scalar values in a catalogue.""" - uses = "longitude", "latitude", "value", "[weight]" - async def __call__( self, catalog: Catalog, @@ -321,11 +276,13 @@ async def __call__( """Map real values from catalogue to HEALPix map.""" # get mapper - mapper = self.mapper_or_error + mapper = self.mapper # get the column definition of the catalogue - *col, wcol = self.columns_or_error + col = self.columns + self.CheckColumns("longitude", "latitude", "value") + wcol = self.weight # scalar field map val = mapper.create(spin=self.spin) @@ -373,7 +330,7 @@ async def __call__( return val -class ComplexField(Field, spin=0): +class ComplexField(Field): """Field of complex values in a catalogue. The :class:`ComplexField` class has zero spin weight, while @@ -381,8 +338,6 @@ class ComplexField(Field, spin=0): """ - uses = "longitude", "latitude", "real", "imag", "[weight]" - async def __call__( self, catalog: Catalog, @@ -392,10 +347,14 @@ async def __call__( """Map complex values from catalogue to HEALPix map.""" # get mapper - mapper = self.mapper_or_error + mapper = self.mapper # get the column definition of the catalogue - *col, wcol = self.columns_or_error + col = self.columns + + self.CheckColumns("longitude", "latitude", "real", "imag") + + wcol = self.weight # complex map with real and imaginary part val = mapper.create(2, spin=self.spin) @@ -443,7 +402,7 @@ async def __call__( return val -class Visibility(Field, spin=0): +class Visibility(Field): """Copy visibility map from catalogue at given resolution.""" async def __call__( @@ -455,7 +414,7 @@ async def __call__( """Create a visibility map from the given catalogue.""" # get mapper - mapper = self.mapper_or_error + mapper = self.mapper # make sure that catalogue has a visibility visibility = catalog.visibility @@ -479,11 +438,9 @@ async def __call__( return out -class Weights(Field, spin=0): +class Weights(Field): """Field of weight values from a catalogue.""" - uses = "longitude", "latitude", "[weight]" - async def __call__( self, catalog: Catalog, @@ -493,10 +450,12 @@ async def __call__( """Map catalogue weights.""" # get mapper - mapper = self.mapper_or_error + mapper = self.mapper # get the columns for this field - *col, wcol = self.columns_or_error + col = self.columns + self.CheckColumns("longitude", "latitude") + wcol = self.weight # weight map wht = mapper.create(spin=self.spin) @@ -543,9 +502,20 @@ async def __call__( return wht -class Spin2Field(ComplexField, spin=2): +class Spin2Field(ComplexField): """Spin-2 complex field.""" + def __init__( + self, + mapper: Mapper | None, + *columns: str, + weight: str | None, + mask: str | None = None, + ) -> None: + """Initialise the field.""" + super().__init__(mapper, *columns, weight=weight, mask=mask) + self._spin = 2 + Shears = Spin2Field Ellipticities = Spin2Field diff --git a/pyproject.toml b/pyproject.toml index 3d7fca9f..8a6a6d0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "coroutines", "fitsio", "healpy", + "matplotlib", "numba", "numpy", ] diff --git a/tests/test_fields.py b/tests/test_fields.py index fab4ccc3..2e1e0d79 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -78,52 +78,39 @@ def catalog(page): def test_field_abc(): - from unittest.mock import Mock - - from heracles.fields import Columns, Field + from heracles.fields import Field with pytest.raises(TypeError): Field() class SpinLessField(Field): - def _init_columns(self, *columns: str) -> Columns: - return columns - async def __call__(self): pass - f = SpinLessField(None) - - with pytest.raises(ValueError, match="undefined spin weight"): - f.spin - - class TestField(Field, spin=0): - uses = "lon", "lat", "[weight]" + f = SpinLessField(None, weight=None) + assert f.spin == 0 + class TestField(Field): async def __call__(self): pass - f = TestField(None) + f = TestField(None, weight=None) assert f.mapper is None assert f.columns is None assert f.spin == 0 - with pytest.raises(ValueError): - f.mapper_or_error + with pytest.raises(ValueError, match="No columns defined"): + f.CheckColumns(None) with pytest.raises(ValueError): - f.columns_or_error - - mapper = Mock() - - with pytest.raises(ValueError, match="accepts 2 to 3 columns"): - TestField(mapper, "lon") + f = TestField(mapper, "lon", weight=None) + f.CheckColumns("lon", "lat") - f = TestField(mapper, "lon", "lat", mask="W") + f = TestField(mapper, "lon", "lat", weight=None, mask="W") assert f.mapper is mapper - assert f.columns == ("lon", "lat", None) + assert f.columns == ("lon", "lat") assert f.mask == "W" @@ -143,7 +130,7 @@ def test_visibility(nside, vmap): mapper_out = HealpixMapper(nside_out) - f = Visibility(mapper_out) + f = Visibility(mapper_out, weight=None) with pytest.warns(UserWarning) if nside != nside_out else nullcontext(): result = coroutines.run(f(catalog)) @@ -165,7 +152,7 @@ def test_visibility(nside, vmap): # test missing visibility map catalog = Mock() catalog.visibility = None - f = Visibility(mapper) + f = Visibility(mapper, weight=None) with pytest.raises(ValueError, match="no visibility"): coroutines.run(f(catalog)) @@ -179,7 +166,7 @@ def test_positions(mapper, catalog, vmap): # normal mode: compute overdensity maps with metadata - f = Positions(mapper, "ra", "dec") + f = Positions(mapper, "ra", "dec", weight=None) # test some default settings assert f.spin == 0 @@ -211,7 +198,7 @@ def test_positions(mapper, catalog, vmap): # compute number count map - f = Positions(mapper, "ra", "dec", overdensity=False) + f = Positions(mapper, "ra", "dec", weight=None, overdensity=False) m = coroutines.run(f(catalog)) assert m.shape == (npix,) @@ -234,7 +221,7 @@ def test_positions(mapper, catalog, vmap): catalog.fsky = vmap.mean() nbar /= catalog.fsky - f = Positions(mapper, "ra", "dec") + f = Positions(mapper, "ra", "dec", weight=None) m = coroutines.run(f(catalog)) assert m.shape == (12 * mapper.nside**2,) @@ -252,7 +239,7 @@ def test_positions(mapper, catalog, vmap): # compute number count map with visibility map - f = Positions(mapper, "ra", "dec", overdensity=False) + f = Positions(mapper, "ra", "dec", weight=None, overdensity=False) m = coroutines.run(f(catalog)) assert m.shape == (12 * mapper.nside**2,) @@ -270,7 +257,7 @@ def test_positions(mapper, catalog, vmap): # compute overdensity maps with given (incorrect) nbar - f = Positions(mapper, "ra", "dec", nbar=2 * nbar) + f = Positions(mapper, "ra", "dec", weight=None, nbar=2 * nbar) with pytest.warns(UserWarning, match="mean density"): m = coroutines.run(f(catalog)) @@ -283,7 +270,7 @@ def test_scalar_field(mapper, catalog): npix = 12 * mapper.nside**2 - f = ScalarField(mapper, "ra", "dec", "g1", "w") + f = ScalarField(mapper, "ra", "dec", "g1", weight="w") m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] @@ -313,7 +300,7 @@ def test_complex_field(mapper, catalog): npix = 12 * mapper.nside**2 - f = Spin2Field(mapper, "ra", "dec", "g1", "g2", "w") + f = Spin2Field(mapper, "ra", "dec", "g1", "g2", weight="w") m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] @@ -325,7 +312,7 @@ def test_complex_field(mapper, catalog): bias = (4 * np.pi / npix / npix) * v2 / 2 assert m.shape == (2, npix) - assert m.dtype.metadata == { + testdata = { "catalog": catalog.label, "spin": 2, "wbar": pytest.approx(wbar), @@ -336,6 +323,20 @@ def test_complex_field(mapper, catalog): "deconv": mapper.deconvolve, "bias": pytest.approx(bias / wbar**2), } + print(testdata) + print(m.dtype.metadata) + assert m.dtype.metadata == { + "catalog": catalog.label, + "spin": 2, + "wbar": pytest.approx(wbar), + "geometry": "healpix", + "kernel": "healpix", + "nside": mapper.nside, + "lmax": mapper.lmax, + "deconv": mapper.deconvolve, + "bias": pytest.approx(bias / wbar**2, abs=1e-6), + } + np.testing.assert_array_almost_equal(m, 0) @@ -344,7 +345,7 @@ def test_weights(mapper, catalog): npix = 12 * mapper.nside**2 - f = Weights(mapper, "ra", "dec", "w") + f = Weights(mapper, "ra", "dec", weight="w") m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 71e20fac..bbfea623 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -63,5 +63,5 @@ def test_transform(rng): assert len(alms) == 2 assert alms.keys() == {("X", 0), ("Y", 1)} - assert alms["X", 0] is x.mapper_or_error.transform.return_value - assert alms["Y", 1] is y.mapper_or_error.transform.return_value + assert alms["X", 0] is x.mapper.transform.return_value + assert alms["Y", 1] is y.mapper.transform.return_value