Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-151: Refactor Fields Module #152

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion heracles/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def transform(
msg = f"unknown field name: {k}"
raise ValueError(msg)

out[k, i] = field.mapper_or_error.transform(m)
out[k, i] = field.mapper.transform(m)

if progress:
subtask.remove()
Expand Down
2 changes: 1 addition & 1 deletion heracles/catalog/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def where(self, selection, visibility=None):

@property
def page_size(self):
"""number of rows per page (default: 100_000)"""
"""number of rows per page (default: 1_000_000)"""
return self._page_size

@page_size.setter
Expand Down
150 changes: 60 additions & 90 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,100 +59,44 @@ class Field(metaclass=ABCMeta):

# every field subclass has a static spin weight attribute, which can be
# overwritten by the class (or even an individual instance)
__spin: int | None = None
_spin: int | None = None

# definition of required and optional columns
__ncol: tuple[int, int]

def __init_subclass__(cls, *, spin: int | None = None) -> None:
"""Initialise spin weight of field subclasses."""
super().__init_subclass__()
if spin is not None:
cls.__spin = spin
uses = cls.uses
if uses is None:
uses = ()
elif isinstance(uses, str):
uses = (uses,)
ncol = len(uses)
nopt = 0
for u in uses[::-1]:
if u.startswith("[") and u.endswith("]"):
nopt += 1
else:
break
cls.__ncol = (ncol - nopt, ncol)

def __init__(
self,
mapper: Mapper | None,
*columns: str,
weight: str | None = None,
mask: str | None = None,
) -> None:
"""Initialise the field."""
super().__init__()
self.__mapper = mapper
self.__columns = self._init_columns(*columns) if columns else None
self.__columns = columns if columns else None
self.__weight = weight if weight else None
self.__mask = mask

@classmethod
def _init_columns(cls, *columns: str) -> Columns:
"""Initialise the given set of columns for a specific field
subclass."""
nmin, nmax = cls.__ncol
if not nmin <= len(columns) <= nmax:
uses = cls.uses
if uses is None:
uses = ()
if isinstance(uses, str):
uses = (uses,)
count = f"{nmin}"
if nmax != nmin:
count += f" to {nmax}"
msg = f"field of type '{cls.__name__}' accepts {count} columns"
if uses:
msg += " (" + ", ".join(uses) + ")"
msg += f", received {len(columns)}"
raise ValueError(msg)
return columns + (None,) * (nmax - len(columns))
self._spin = 0

@property
def mapper(self) -> Mapper | None:
"""Return the mapper used by this field."""
return self.__mapper

@property
def mapper_or_error(self) -> Mapper:
"""Return the mapper used by this field, or raise a :class:`ValueError`
if not set."""
if self.__mapper is None:
msg = "no mapper for field"
raise ValueError(msg)
return self.__mapper
def weight(self) -> str | None:
"""Return the mapper used by this field."""
return self.__weight

@property
def columns(self) -> Columns | None:
"""Return the catalogue columns used by this field."""
return self.__columns

@property
def columns_or_error(self) -> Columns:
"""Return the catalogue columns used by this field, or raise a
:class:`ValueError` if not set."""
if self.__columns is None:
msg = "no columns for field"
raise ValueError(msg)
return self.__columns

@property
def spin(self) -> int:
"""Spin weight of field."""
spin = self.__spin
if spin is None:
clsname = self.__class__.__name__
msg = f"field of type '{clsname}' has undefined spin weight"
raise ValueError(msg)
return spin
return self._spin

@property
def mask(self) -> str | None:
Expand All @@ -169,6 +113,17 @@ async def __call__(
"""Implementation for mapping a catalogue."""
...

def CheckColumns(self, *expected):
if self.columns is None:
msg = "No columns defined!"
raise ValueError(msg)
if len(expected) != len(self.columns):
error = "Column error. Expected " + str(len(expected)) + " columns"
error += (
" with a format " + str(expected) + ". Received " + str(self.columns)
)
raise ValueError(error)


async def _pages(
catalog: Catalog,
Expand All @@ -190,26 +145,25 @@ async def _pages(
await coroutines.sleep()


class Positions(Field, spin=0):
class Positions(Field):
"""Field of positions in a catalogue.

Can produce both overdensity maps and number count maps, depending
on the ``overdensity`` property.

"""

uses = "longitude", "latitude"

def __init__(
self,
mapper: Mapper | None,
*columns: str,
weight: str | None,
overdensity: bool = True,
nbar: float | None = None,
mask: str | None = None,
) -> None:
"""Create a position field."""
super().__init__(mapper, *columns, mask=mask)
super().__init__(mapper, *columns, weight=weight, mask=mask)
self.__overdensity = overdensity
self.__nbar = nbar

Expand Down Expand Up @@ -242,10 +196,13 @@ async def __call__(
raise ValueError(msg)

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get catalogue column definition
col = self.columns_or_error
col = self.columns
self.CheckColumns("longitude", "latitude")

# if(len(col)!=2):

# position map
pos = mapper.create(spin=self.spin)
Expand All @@ -259,7 +216,7 @@ async def __call__(
lon, lat = page.get(*col)
w = np.ones(page.size)

mapper.map_values(lon, lat, pos, w)
self.mapper.map_values(lon, lat, pos, w)

ngal += page.size

Expand Down Expand Up @@ -307,11 +264,9 @@ async def __call__(
return pos


class ScalarField(Field, spin=0):
class ScalarField(Field):
"""Field of real scalar values in a catalogue."""

uses = "longitude", "latitude", "value", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -321,11 +276,13 @@ async def __call__(
"""Map real values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the column definition of the catalogue
*col, wcol = self.columns_or_error
col = self.columns
self.CheckColumns("longitude", "latitude", "value")

wcol = self.weight
# scalar field map
val = mapper.create(spin=self.spin)

Expand Down Expand Up @@ -373,16 +330,14 @@ async def __call__(
return val


class ComplexField(Field, spin=0):
class ComplexField(Field):
"""Field of complex values in a catalogue.

The :class:`ComplexField` class has zero spin weight, while
subclasses such as :class:`Spin2Field` have non-zero spin weight.

"""

uses = "longitude", "latitude", "real", "imag", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -392,10 +347,14 @@ async def __call__(
"""Map complex values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the column definition of the catalogue
*col, wcol = self.columns_or_error
col = self.columns

self.CheckColumns("longitude", "latitude", "real", "imag")

wcol = self.weight

# complex map with real and imaginary part
val = mapper.create(2, spin=self.spin)
Expand Down Expand Up @@ -443,7 +402,7 @@ async def __call__(
return val


class Visibility(Field, spin=0):
class Visibility(Field):
"""Copy visibility map from catalogue at given resolution."""

async def __call__(
Expand All @@ -455,7 +414,7 @@ async def __call__(
"""Create a visibility map from the given catalogue."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# make sure that catalogue has a visibility
visibility = catalog.visibility
Expand All @@ -479,11 +438,9 @@ async def __call__(
return out


class Weights(Field, spin=0):
class Weights(Field):
"""Field of weight values from a catalogue."""

uses = "longitude", "latitude", "[weight]"

async def __call__(
self,
catalog: Catalog,
Expand All @@ -493,10 +450,12 @@ async def __call__(
"""Map catalogue weights."""

# get mapper
mapper = self.mapper_or_error
mapper = self.mapper

# get the columns for this field
*col, wcol = self.columns_or_error
col = self.columns
self.CheckColumns("longitude", "latitude")
wcol = self.weight

# weight map
wht = mapper.create(spin=self.spin)
Expand Down Expand Up @@ -543,9 +502,20 @@ async def __call__(
return wht


class Spin2Field(ComplexField, spin=2):
class Spin2Field(ComplexField):
"""Spin-2 complex field."""

def __init__(
self,
mapper: Mapper | None,
*columns: str,
weight: str | None,
mask: str | None = None,
) -> None:
"""Initialise the field."""
super().__init__(mapper, *columns, weight=weight, mask=mask)
self._spin = 2


Shears = Spin2Field
Ellipticities = Spin2Field
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"coroutines",
"fitsio",
"healpy",
"matplotlib",
"numba",
"numpy",
]
Expand Down
Loading
Loading