Skip to content

Commit

Permalink
fixup! API: make bias work for HEALPix and spherical convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Jan 16, 2024
1 parent 851762b commit be89908
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 97 deletions.
10 changes: 9 additions & 1 deletion heracles/maps/_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _mapw(ipix, maps, values, weight):
maps[k][i] += w * values[k][j]


class Healpix(Mapper, kernel="healpix", dirty=True):
class Healpix(Mapper, kernel="healpix"):
"""
Mapper for HEALPix maps. HEALPix maps have a resolution parameter,
available as the *nside* property.
Expand Down Expand Up @@ -262,3 +262,11 @@ def kl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
raise ValueError(msg)

return pw

def bl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Return the biasing kernel for HEALPix.
"""
kl = self.kl(lmax, spin)
where = np.arange(lmax + 1) >= abs(spin)
return np.divide(1.0, kl, where=where, out=np.zeros(lmax + 1))
20 changes: 9 additions & 11 deletions heracles/maps/_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Self
Expand Down Expand Up @@ -70,15 +72,13 @@ class Mapper(metaclass=ABCMeta):
"""

__kernel: str
__dirty: bool

def __init_subclass__(cls, /, kernel: str, dirty: bool = False, **kwargs) -> None:
def __init_subclass__(cls, /, kernel: str, **kwargs) -> None:
"""
Initialise mapper subclasses with a *kernel* parameter.
"""
super().__init_subclass__(**kwargs)
cls.__kernel = kernel
cls.__dirty = dirty
_KERNELS[kernel] = cls

@classmethod
Expand All @@ -103,14 +103,6 @@ def kernel(self) -> str:
"""
return self.__kernel

@property
def dirty(self) -> bool:
"""
Return whether this mapper leaves its convolution kernel
imprinted after a spherical harmonic transform.
"""
return self.__dirty

@property
def metadata(self) -> Mapping[str, Any]:
"""
Expand Down Expand Up @@ -167,3 +159,9 @@ def kl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Return the convolution kernel in harmonic space.
"""

def bl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Return the biasing kernel in harmonic space.
"""
return np.where(np.arange(lmax + 1) < abs(spin), 0.0, 1.0)
28 changes: 28 additions & 0 deletions heracles/maps/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING, Any

import coroutines
import numpy as np

from heracles.core import TocDict, multi_value_getter, toc_match

Expand Down Expand Up @@ -149,10 +150,34 @@ def map_catalogs(
return out


def deconvolved(
mapper: Mapper,
alm: NDArray[Any] | tuple[NDArray[Any], ...],
*,
inplace: bool = False,
) -> NDArray[Any] | tuple[NDArray[Any], ...]:
"""
Divide *alm* by the spherical convolution kernel of *mapper*.
"""
from healpy import Alm, almxfl

if isinstance(alm, tuple):
result = tuple(deconvolved(mapper, a, inplace=inplace) for a in alm)
return result if not inplace else alm

lmax = Alm.getlmax(alm.shape[-1])
spin = (alm.dtype.metadata or {}).get("spin", 0)
kl = mapper.kl(lmax=lmax, spin=spin)
where = np.arange(lmax + 1) >= abs(spin)
fl = np.divide(1, kl, where=where, out=np.ones(lmax + 1))
return almxfl(alm, fl, inplace=inplace)


def transform_maps(
maps: Mapping[tuple[Any, Any], NDArray],
*,
lmax: int | Mapping[Any, int] | None = None,
deconvolve: bool = True,
out: MutableMapping[tuple[Any, Any], NDArray] | None = None,
progress: bool = False,
**kwargs,
Expand Down Expand Up @@ -197,6 +222,9 @@ def transform_maps(

alms = mapper.transform(m, _lmax)

if deconvolve:
deconvolved(mapper, alms, inplace=True)

if isinstance(alms, tuple):
out[f"{k}_E", i] = alms[0]
out[f"{k}_B", i] = alms[1]
Expand Down
153 changes: 69 additions & 84 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,7 @@
import healpy as hp
import numpy as np

from .core import (
TocDict,
items_with_suffix,
multi_value_getter,
toc_match,
update_metadata,
)
from .core import TocDict, items_with_suffix, toc_match, update_metadata
from .maps import mapper_from_dict

if TYPE_CHECKING:
Expand All @@ -45,7 +39,6 @@
from numpy.typing import ArrayLike, NDArray

from .fields import Field
from .maps import Mapper
from .progress import Progress

# type alias for the keys of two-point data
Expand All @@ -54,6 +47,56 @@
logger = logging.getLogger(__name__)


def _debias_cl(
cl: NDArray[Any],
bias: float | None = None,
*,
inplace: bool = False,
) -> NDArray[Any]:
"""
Remove additive bias from angular power spectrum.
"""

md = cl.dtype.metadata or {}

if not inplace:
cl = cl.copy()
update_metadata(cl, **md)

lmax = len(cl) - 1

# spins of the spectrum
spin1, spin2 = md.get("spin_1", 0), md.get("spin_2", 0)

# use explicit bias, if given, or bias value from metadata
b = md.get("bias", 0.0) if bias is None else bias

# store the bias value used here in metadata
update_metadata(cl, bias=b)

# apply biasing kernel from mappers
try:
mapper = mapper_from_dict(items_with_suffix(md, "_1"))
except ValueError:
pass
else:
b = b * mapper.bl(lmax=lmax, spin=spin1)
try:
mapper = mapper_from_dict(items_with_suffix(md, "_2"))
except ValueError:
pass
else:
b = b * mapper.bl(lmax=lmax, spin=spin2)

# remove bias
if cl.dtype.names is None:
cl -= b
else:
cl["CL"] -= b

return cl


def angular_power_spectra(
alms,
alms2=None,
Expand Down Expand Up @@ -147,20 +190,16 @@ def angular_power_spectra(
md["bias"] = bias
if bcor is not None:
md["bcor"] = bcor
update_metadata(cl, **md)

# debias cl if asked to
if debias and bias is not None:
# minimum l for correction
_lmin = max(abs(md.get("spin_1", 0)), abs(md.get("spin_2", 0)))
cl[_lmin:] -= bias
_debias_cl(cl, bias, inplace=True)

# if bins are given, apply the binning
if bins is not None:
cl = bin2pt(cl, bins, "CL", weights=weights)

# write metadata for this spectrum
update_metadata(cl, **md)

# add cl to the set
cls[k1, k2, i1, i2] = cl

Expand Down Expand Up @@ -191,31 +230,7 @@ def debias_cls(cls, bias=None, *, inplace=False):
logger.info("debiasing %s x %s cl for bins %s, %s", *key)

cl = cls[key]
md = cl.dtype.metadata or {}

if not inplace:
cl = cl.copy()
update_metadata(cl, **md)

# minimum l for correction
lmin = max(abs(md.get("spin_1", 0)), abs(md.get("spin_2", 0)))

# get bias from explicit dict, if given, or metadata
if bias is None:
b = md.get("bias", 0.0)
else:
b = bias.get(key, 0.0)

# remove bias
if cl.dtype.names is None:
cl[lmin:] -= b
else:
cl["CL"][lmin:] -= b

# write noise bias to corrected cl
update_metadata(cl, bias=b)

# store debiased cl in output set
_debias_cl(cl, bias and bias.get(key), inplace=inplace)
out[key] = cl

logger.info(
Expand All @@ -229,9 +244,8 @@ def debias_cls(cls, bias=None, *, inplace=False):


def mixing_matrices(
cls: Mapping[TwoPointKey, NDArray[Any]],
fields: Mapping[Any, Field],
mapper: Mapper | Mapping[Any, Mapper] | None = None,
cls: Mapping[TwoPointKey, NDArray[Any]],
*,
l1max: int | None = None,
l2max: int | None = None,
Expand All @@ -245,9 +259,6 @@ def mixing_matrices(

from convolvecl import mixmat, mixmat_eb

# getter for mapper value or dict
mappergetter = multi_value_getter(mapper) if mapper is not None else None

# output dictionary if not provided
if out is None:
out = TocDict()
Expand Down Expand Up @@ -284,29 +295,10 @@ def mixing_matrices(
except KeyError:
continue

# metadata for weight spectrum
md = cl.dtype.metadata or {}

# deal with structured cl arrays
if cl.dtype.names is not None:
cl = cl["CL"]

# deconvolve the kernels of the first and second map
try:
_mapper = mapper_from_dict(items_with_suffix(md, "_1"))
except ValueError:
pass
else:
if _mapper.dirty:
cl = cl / _mapper.kl(len(cl) - 1)
try:
_mapper = mapper_from_dict(items_with_suffix(md, "_2"))
except ValueError:
pass
else:
if _mapper.dirty:
cl = cl / _mapper.kl(len(cl) - 1)

# compute mixing matrices for all fields of this mask combination
for f1, f2 in product(fields1, fields2):
# check if this combination has been done already
Expand All @@ -326,9 +318,6 @@ def mixing_matrices(
# get spins of fields
spin1, spin2 = fields1[f1].spin, fields2[f2].spin

# mixing matrices to be added
tba = {}

# if any spin is zero, then there is no E/B decomposition
if spin1 == 0 or spin2 == 0:
mm = mixmat(
Expand All @@ -338,9 +327,11 @@ def mixing_matrices(
l3max=l3max,
spin=(spin1, spin2),
)
if bins is not None:
mm = bin2pt(mm, bins, "MM", weights=weights)
name1 = f1 if spin1 == 0 else f"{f1}_E"
name2 = f2 if spin2 == 0 else f"{f2}_E"
tba[name1, name2, i1, i2] = mm
out[name1, name2, i1, i2] = mm
del mm
else:
# E/B decomposition for mixing matrix
Expand All @@ -351,24 +342,14 @@ def mixing_matrices(
l3max=l3max,
spin=(spin1, spin2),
)
tba[f"{f1}_E", f"{f2}_E", i1, i2] = mm_ee
tba[f"{f1}_B", f"{f2}_B", i1, i2] = mm_bb
tba[f"{f1}_E", f"{f2}_B", i1, i2] = mm_eb
del mm_ee, mm_bb, mm_eb

# post-process mixing matrices before adding them to out
for key, mm in tba.items():
if mappergetter is not None:
if (_mapper := mappergetter((f1, i1))).dirty:
mm *= _mapper.kl(len(mm) - 1, spin1)[:, None]
if (_mapper := mappergetter((f2, i2))).dirty:
mm *= _mapper.kl(len(mm) - 1, spin2)[:, None]
if bins is not None:
mm = bin2pt(mm, bins, "MM", weights=weights)
update_metadata(mm, spin_1=spin1, spin_2=spin2)
out[key] = mm
del mm
del tba
mm_ee = bin2pt(mm_ee, bins, "MM", weights=weights)
mm_bb = bin2pt(mm_bb, bins, "MM", weights=weights)
mm_eb = bin2pt(mm_eb, bins, "MM", weights=weights)
out[f"{f1}_E", f"{f2}_E", i1, i2] = mm_ee
out[f"{f1}_B", f"{f2}_B", i1, i2] = mm_bb
out[f"{f1}_E", f"{f2}_B", i1, i2] = mm_eb
del mm_ee, mm_bb, mm_eb

if prog is not None:
subtask.remove()
Expand Down Expand Up @@ -444,6 +425,10 @@ def norm(a, b):
# add weights
binned["W"] = wb

# copy metadata, if there is any
if arr.dtype.metadata is not None:
update_metadata(binned, **arr.dtype.metadata)

# all done
return binned

Expand Down
3 changes: 2 additions & 1 deletion tests/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def test_map_catalogs_match():


@unittest.mock.patch.dict("heracles.maps._mapper._KERNELS", clear=True)
def test_transform_maps(rng):
@unittest.mock.patch("heracles.maps._mapping.deconvolved")
def test_transform_maps(mock_deconvolved, rng):
import numpy as np

from heracles.maps import transform_maps
Expand Down

0 comments on commit be89908

Please sign in to comment.