From 84e4be0df96a7419479138423ed4fc1d06062282 Mon Sep 17 00:00:00 2001
From: Nicolas Tessore <n.tessore@ucl.ac.uk>
Date: Thu, 30 Nov 2023 21:54:37 +0000
Subject: [PATCH] refactor read_mask() into read_vmap()

---
 heracles/io.py   |  35 ++++--------
 tests/test_io.py | 145 ++++++++++++++---------------------------------
 2 files changed, 55 insertions(+), 125 deletions(-)

diff --git a/heracles/io.py b/heracles/io.py
index 9eff0d0..c661ff6 100644
--- a/heracles/io.py
+++ b/heracles/io.py
@@ -24,6 +24,7 @@
 from functools import partial
 from pathlib import Path
 from types import MappingProxyType
+from warnings import warn
 from weakref import WeakValueDictionary
 
 import fitsio
@@ -196,33 +197,19 @@ def _read_twopoint(fits, ext):
     return arr
 
 
-def read_mask(mask_name, nside=None, field=0, extra_mask_name=None):
+def read_vmap(filename, nside=None, field=0):
     """read visibility map from a HEALPix map file"""
-    mask = hp.read_map(mask_name, field=field)
+    vmap = hp.read_map(filename, field=field, dtype=float)
 
     # set unseen pixels to zero
-    unseen = np.where(mask == hp.UNSEEN)
-    mask[unseen] = 0
-
-    nside_mask = hp.get_nside(mask)
-
-    if nside is not None:
-        # mask is provided at a different resolution
-        if nside_mask < nside:
-            print("WARNING: Nside of mask < Nside of requested maps")
-        if nside_mask != nside:
-            mask = hp.ud_grade(mask, nside)
-            nside_mask = nside
-
-    # apply extra mask if given
-    if extra_mask_name is not None:
-        extra_mask = hp.read_map(extra_mask_name)
-        nside_extra = hp.get_nside(extra_mask)
-        if nside_extra != nside_mask:
-            extra_mask = hp.ud_grade(extra_mask, nside_mask)
-        mask *= extra_mask
-
-    return mask
+    vmap[vmap == hp.UNSEEN] = 0
+
+    if nside is not None and nside != hp.get_nside(vmap):
+        # vmap is provided at a different resolution
+        warn(f"{filename}: changing NSIDE to {nside}")
+        vmap = hp.ud_grade(vmap, nside)
+
+    return vmap
 
 
 def write_maps(
diff --git a/tests/test_io.py b/tests/test_io.py
index 3e128cf..968d32b 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1,7 +1,5 @@
 import pytest
 
-NFIELDS_TEST = 4
-
 
 @pytest.fixture
 def zbins():
@@ -79,26 +77,28 @@ def datadir(tmp_path_factory):
 
 
 @pytest.fixture(scope="session")
-def mock_mask_fields(nside, rng):
+def mock_vmap_fields(nside, rng):
     import healpy as hp
     import numpy as np
 
+    nfields = 4
+
     npix = hp.nside2npix(nside)
-    maps = rng.random(npix * NFIELDS_TEST).reshape((npix, NFIELDS_TEST))
+    maps = rng.random(npix * nfields).reshape((npix, nfields))
     pixels = np.unique(rng.integers(0, npix, size=npix // 3))
-    maskpix = np.delete(np.arange(0, npix), pixels)
-    for i in range(NFIELDS_TEST):
-        maps[:, i][maskpix] = 0
+    vmappix = np.delete(np.arange(0, npix), pixels)
+    for i in range(nfields):
+        maps[:, i][vmappix] = 0
     return [maps, pixels]
 
 
 @pytest.fixture(scope="session")
-def mock_writemask_partial(mock_mask_fields, nside, datadir):
+def mock_vmap_partial(mock_vmap_fields, nside, datadir):
     import fitsio
 
-    filename = str(datadir / "mask_partial.fits")
+    filename = str(datadir / "vmap_partial.fits")
 
-    maps, pixels = mock_mask_fields
+    maps, pixels = mock_vmap_fields
 
     fits = fitsio.FITS(filename, "rw")
     fits.write(data=None)
@@ -124,12 +124,12 @@ def mock_writemask_partial(mock_mask_fields, nside, datadir):
 
 
 @pytest.fixture(scope="session")
-def mock_writemask_full(mock_mask_fields, nside, datadir):
+def mock_vmap(mock_vmap_fields, nside, datadir):
     import fitsio
 
-    filename = str(datadir / "mask_full.fits")
+    filename = str(datadir / "vmap.fits")
 
-    maps, _ = mock_mask_fields
+    maps, _ = mock_vmap_fields
 
     fits = fitsio.FITS(filename, "rw")
     fits.write(data=None)
@@ -148,44 +148,6 @@ def mock_writemask_full(mock_mask_fields, nside, datadir):
     return filename
 
 
-@pytest.fixture(scope="session")
-def mock_mask_extra(nside, rng):
-    import healpy as hp
-    import numpy as np
-
-    npix = hp.nside2npix(nside)
-    maps = rng.random(npix)
-    pixels = np.unique(rng.integers(0, npix, size=npix // 3))
-    maskpix = np.delete(np.arange(0, npix), pixels)
-    maps[maskpix] = 0
-    return [maps, pixels]
-
-
-@pytest.fixture(scope="session")
-def mock_writemask_extra(mock_mask_extra, nside, datadir):
-    import fitsio
-
-    filename = str(datadir / "mask_extra.fits")
-
-    maps, _ = mock_mask_extra
-
-    fits = fitsio.FITS(filename, "rw")
-    fits.write(data=None)
-    fits.write_table(
-        [maps],
-        names=["WEIGHT"],
-        header={
-            "NSIDE": nside,
-            "ORDERING": "RING",
-            "INDXSCHM": "IMPLICIT",
-            "OBJECT": "FULLSKY",
-        },
-    )
-    fits.close()
-
-    return filename
-
-
 def test_write_read_maps(rng, tmp_path):
     import healpy as hp
     import numpy as np
@@ -318,71 +280,52 @@ def test_write_read_cov(mock_cls, tmp_path):
         np.testing.assert_array_equal(cov_[key], cov[key])
 
 
-def test_read_mask_partial(mock_mask_fields, mock_writemask_partial, nside):
+def test_read_vmap_partial(mock_vmap_fields, mock_vmap_partial, nside):
     import healpy as hp
 
-    from heracles.io import read_mask
+    from heracles.io import read_vmap
 
-    maps = mock_mask_fields[0]
+    maps = mock_vmap_fields[0]
 
-    mask = read_mask(mock_writemask_partial, nside=nside)
-    assert (mask == maps[:, 0]).all()
+    vmap = read_vmap(mock_vmap_partial, nside=nside)
+    assert (vmap == maps[:, 0]).all()
 
-    ibin = 2
-    mask = read_mask(mock_writemask_partial, nside=nside, field=ibin)
-    assert (mask == maps[:, ibin]).all()
+    field = 2
+    vmap = read_vmap(mock_vmap_partial, nside=nside, field=field)
+    assert (vmap == maps[:, field]).all()
 
-    ibin = 3
-    mask = read_mask(mock_writemask_partial, nside=nside // 2, field=ibin)
-    maskud = hp.pixelfunc.ud_grade(maps[:, ibin], nside // 2)
-    assert (mask == maskud).all()
+    field = 3
+    with pytest.warns():
+        vmap = read_vmap(mock_vmap_partial, nside=nside // 2, field=field)
+    vmapud = hp.pixelfunc.ud_grade(maps[:, field], nside // 2)
+    assert (vmap == vmapud).all()
 
 
-def test_read_mask_full(mock_mask_fields, mock_writemask_full, nside):
+def test_read_vmap(mock_vmap_fields, mock_vmap, nside):
     import healpy as hp
 
-    from heracles.io import read_mask
-
-    maps = mock_mask_fields[0]
+    from heracles.io import read_vmap
 
-    mask = read_mask(mock_writemask_full, nside=nside)
-    assert (mask == maps[:, 0]).all()
+    maps = mock_vmap_fields[0]
 
-    ibin = 2
-    mask = read_mask(mock_writemask_full, nside=nside, field=ibin)
-    assert (mask == maps[:, ibin]).all()
+    vmap = read_vmap(mock_vmap, nside=nside)
+    assert (vmap == maps[:, 0]).all()
 
-    ibin = 3
-    mask = read_mask(mock_writemask_full, nside=nside // 2, field=ibin)
-    maskud = hp.pixelfunc.ud_grade(maps[:, ibin], nside // 2)
-    assert (mask == maskud).all()
+    fields = 2
+    vmap = read_vmap(mock_vmap, nside=nside, field=fields)
+    assert (vmap == maps[:, fields]).all()
 
-    ibin = 3
-    mask = read_mask(mock_writemask_full, nside=nside * 2, field=ibin)
-    maskud = hp.pixelfunc.ud_grade(maps[:, ibin], nside * 2)
-    assert (mask == maskud).all()
+    fields = 3
+    with pytest.warns():
+        vmap = read_vmap(mock_vmap, nside=nside // 2, field=fields)
+    vmapud = hp.pixelfunc.ud_grade(maps[:, fields], nside // 2)
+    assert (vmap == vmapud).all()
 
-
-def test_read_mask_extra(
-    mock_mask_fields,
-    mock_mask_extra,
-    mock_writemask_full,
-    nside,
-    mock_writemask_extra,
-):
-    from heracles.io import read_mask
-
-    maps = mock_mask_fields[0]
-    maps_extra = mock_mask_extra[0]
-
-    ibin = 2
-    mask = read_mask(
-        mock_writemask_full,
-        nside=nside,
-        field=ibin,
-        extra_mask_name=mock_writemask_extra,
-    )
-    assert (mask == maps[:, ibin] * maps_extra[:]).all()
+    fields = 3
+    with pytest.warns():
+        vmap = read_vmap(mock_vmap, nside=nside * 2, field=fields)
+    vmapud = hp.pixelfunc.ud_grade(maps[:, fields], nside * 2)
+    assert (vmap == vmapud).all()
 
 
 def test_tocfits(tmp_path):