Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API(fields): turn spin weight into a class property #93

Merged
merged 2 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,13 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n",
"</pre>\n"
],
"text/plain": [
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n",
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n"
]
Expand Down
67 changes: 32 additions & 35 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -54,11 +53,17 @@ class Field(metaclass=ABCMeta):

"""

def __init__(
self,
*columns: str,
spin: int = 0,
) -> None:
# every field subclass has a static spin weight attribute, which can be
# overwritten by the class (or even an individual instance)
__spin: int | None = None

def __init_subclass__(cls, *, spin: int | None = None) -> None:
"""Initialise spin weight of field subclasses."""
super().__init_subclass__()
if spin is not None:
cls.__spin = spin

def __init__(self, *columns: str) -> None:
"""Initialise the field."""
super().__init__()
self.__columns: Columns | None
Expand All @@ -70,9 +75,9 @@ def __init__(
raise TypeError(msg) from None
else:
self.__columns = None
self._metadata: dict[str, Any] = {
"spin": spin,
}
self._metadata: dict[str, Any] = {}
if (spin := self.__spin) is not None:
self._metadata["spin"] = spin

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -103,7 +108,12 @@ def metadata(self) -> Mapping[str, Any]:
@property
def spin(self) -> int:
"""Spin weight of field."""
return self._metadata["spin"]
spin = self.__spin
if spin is None:
clsname = self.__class__.__name__
msg = f"field of type '{clsname}' has undefined spin weight"
raise ValueError(msg)
return spin

@abstractmethod
async def __call__(
Expand Down Expand Up @@ -137,7 +147,7 @@ async def _pages(
await coroutines.sleep()


class Positions(Field):
class Positions(Field, spin=0):
"""Field of positions in a catalogue.

Can produce both overdensity maps and number count maps, depending
Expand All @@ -152,7 +162,7 @@ def __init__(
nbar: float | None = None,
) -> None:
"""Create a position field."""
super().__init__(*columns, spin=0)
super().__init__(*columns)
self.__overdensity = overdensity
self.__nbar = nbar

Expand Down Expand Up @@ -256,13 +266,9 @@ async def __call__(
return pos


class ScalarField(Field):
class ScalarField(Field, spin=0):
"""Field of real scalar values in a catalogue."""

def __init__(self, *columns: str) -> None:
"""Create a scalar field."""
super().__init__(*columns, spin=0)

@staticmethod
def _init_columns(
lon: str,
Expand Down Expand Up @@ -340,18 +346,14 @@ async def __call__(
return val


class ComplexField(Field):
class ComplexField(Field, spin=0):
"""Field of complex values in a catalogue.

Complex fields can have non-zero spin weight, set using the
``spin=`` parameter.
The :class:`ComplexField` class has zero spin weight, while
subclasses such as :class:`Spin2Field` have non-zero spin weight.

"""

def __init__(self, *columns: str, spin: int = 0) -> None:
"""Create a complex field."""
super().__init__(*columns, spin=spin)

@staticmethod
def _init_columns(
lon: str,
Expand Down Expand Up @@ -430,13 +432,9 @@ async def __call__(
return val


class Visibility(Field):
class Visibility(Field, spin=0):
"""Copy visibility map from catalogue at given resolution."""

def __init__(self) -> None:
"""Create a visibility field."""
super().__init__(spin=0)

@staticmethod
def _init_columns() -> Columns:
return ()
Expand Down Expand Up @@ -474,13 +472,9 @@ async def __call__(
return vmap


class Weights(Field):
class Weights(Field, spin=0):
"""Field of weight values from a catalogue."""

def __init__(self, *columns: str) -> None:
"""Create a weight field."""
super().__init__(*columns, spin=0)

@staticmethod
def _init_columns(lon: str, lat: str, weight: str | None = None) -> Columns:
return lon, lat, weight
Expand Down Expand Up @@ -528,6 +522,9 @@ async def __call__(
return wht


Spin2Field = partial(ComplexField, spin=2)
class Spin2Field(ComplexField, spin=2):
"""Spin-2 complex field."""


Shears = Spin2Field
Ellipticities = Spin2Field
18 changes: 15 additions & 3 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,19 @@ def test_field_abc():
with pytest.raises(TypeError):
Field()

class TestField(Field):
class SpinLessField(Field):
def _init_columns(self, *columns: str) -> Columns:
return columns

async def __call__(self):
pass

f = SpinLessField()

with pytest.raises(ValueError, match="undefined spin weight"):
f.spin

class TestField(Field, spin=0):
@staticmethod
def _init_columns(lon, lat, weight=None) -> Columns:
return lon, lat, weight
Expand Down Expand Up @@ -269,11 +281,11 @@ def test_scalar_field(mapper, catalog):


def test_complex_field(mapper, catalog):
from heracles.fields import ComplexField
from heracles.fields import Spin2Field

npix = 12 * mapper.nside**2

f = ComplexField("ra", "dec", "g1", "g2", "w", spin=2)
f = Spin2Field("ra", "dec", "g1", "g2", "w")
m = coroutines.run(f(catalog, mapper))

w = next(iter(catalog))["w"]
Expand Down
Loading