diff --git a/heracles/healpy.py b/heracles/healpy.py index 9f3759c..3efa709 100644 --- a/heracles/healpy.py +++ b/heracles/healpy.py @@ -178,7 +178,7 @@ def transform(self, data: NDArray[Any]) -> NDArray[Any]: msg = f"spin-{spin} maps not yet supported" raise NotImplementedError(msg) - alms = hp.map2alm( + alm = hp.map2alm( data, lmax=self.__lmax, pol=pol, @@ -189,16 +189,16 @@ def transform(self, data: NDArray[Any]) -> NDArray[Any]: if pw is not None: fl = np.ones(self.__lmax + 1) fl[abs(spin) :] /= pw[abs(spin) :] - for alm in alms: - hp.almxfl(alm, fl, inplace=True) + for i in np.ndindex(*alm.shape[:-1]): + alm[i] = hp.almxfl(alm[i], fl) del fl if spin != 0: - alms = alms[1:].copy() + alm = alm[1:].copy() - update_metadata(alms, **md) + update_metadata(alm, **md) - return alms + return alm def resample(self, data: NDArray[Any]) -> NDArray[Any]: """ diff --git a/tests/test_healpy.py b/tests/test_healpy.py index 77a6095..117662a 100644 --- a/tests/test_healpy.py +++ b/tests/test_healpy.py @@ -11,8 +11,10 @@ else: HAVE_HEALPY = True +skipif_not_healpy = pytest.mark.skipif(not HAVE_HEALPY, reason="test requires healpy") -@pytest.mark.skipif(not HAVE_HEALPY, reason="test requires healpy") + +@skipif_not_healpy def test_healpix_maps(rng): from heracles.healpy import HealpixMapper from heracles.mapper import Mapper @@ -76,7 +78,7 @@ def test_healpix_maps(rng): npt.assert_array_equal(m, expected) -@pytest.mark.skipif(not HAVE_HEALPY, reason="test requires healpy") +@skipif_not_healpy @unittest.mock.patch("healpy.map2alm") def test_healpix_transform(mock_map2alm, rng): from heracles.core import update_metadata @@ -112,3 +114,50 @@ def test_healpix_transform(mock_map2alm, rng): assert alms.dtype.metadata["spin"] == 2 assert alms.dtype.metadata["b"] == 2 assert alms.dtype.metadata["nside"] == nside + + +@skipif_not_healpy +@unittest.mock.patch("healpy.map2alm") +def test_healpix_deconvolve(mock_map2alm): + from heracles.core import update_metadata + from heracles.healpy import HealpixMapper + + nside = 32 + npix = 12 * nside**2 + + lmax = 48 + nlm = (lmax + 1) * (lmax + 2) // 2 + + pw0, pw2 = hp.pixwin(nside, lmax=lmax, pol=True) + pw2[:2] = 1.0 + + mapper = HealpixMapper(nside, lmax, deconvolve=True) + + # single scalar map + data = np.zeros(npix) + update_metadata(data, spin=0) + + mock_map2alm.return_value = np.ones(nlm, dtype=complex) + + alm = mapper.transform(data) + + assert alm.shape == (nlm,) + stop = 0 + for m in range(lmax + 1): + start, stop = stop, stop + lmax - m + 1 + npt.assert_array_equal(alm[start:stop], 1.0 / pw0[m:]) + + # polarisation map + data = np.zeros((2, npix)) + update_metadata(data, spin=2) + + mock_map2alm.return_value = np.ones((3, nlm), dtype=complex) + + alm = mapper.transform(data) + + assert alm.shape == (2, nlm) + stop = 0 + for m in range(lmax + 1): + start, stop = stop, stop + lmax - m + 1 + npt.assert_array_equal(alm[0, start:stop], 1.0 / pw2[m:]) + npt.assert_array_equal(alm[1, start:stop], 1.0 / pw2[m:])