Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/fix/act flags #947

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
69a337d
linting to debug eventually
iparask Aug 20, 2024
e2f259b
Fix for flags and downsampling. This will allow the ML/depth1 mapmake…
chervias Sep 6, 2024
7b3c00b
setting flags input
iparask Sep 9, 2024
ee19ad0
configurable parsing for flags
iparask Sep 9, 2024
6a452f6
Merge branch 'master' into gp/fix/act_flags
iparask Sep 9, 2024
bd0468d
reverting linting change
iparask Sep 9, 2024
7c44e53
adding missing call
iparask Sep 9, 2024
51b0e3b
moving glitch flags path to init
iparask Sep 23, 2024
815dabc
fix:syntax error
iparask Sep 24, 2024
1519f85
Merge branch 'master' into gp/fix/act_flags
iparask Sep 24, 2024
0a48177
glitch flags in MLmapmaker plus some utils changes
iparask Sep 24, 2024
5a23955
wip: comment addresing
iparask Sep 30, 2024
e0af0f3
wip: docstring in numpy style and unit tests
Oct 1, 2024
9ea8366
wip: fix ctime
iparask Oct 3, 2024
964481d
removing unused imports
iparask Nov 13, 2024
d4d6659
Merge branch 'master' into gp/fix/act_flags
iparask Nov 13, 2024
5f8121f
Merge branch 'master' into gp/fix/act_flags
iparask Nov 15, 2024
98449b7
Merge branch 'master' into gp/fix/act_flags
iparask Dec 12, 2024
09f2b5d
Merge branch 'gp/fix/act_flags' of https://github.com/simonsobs/sotod…
iparask Dec 13, 2024
ae0aabb
redistributing tiled map for multipass
iparask Dec 19, 2024
8097ac4
Updating with Sigurd's suggestions
iparask Jan 8, 2025
176108c
Merge branch 'master' into gp/fix/act_flags
iparask Jan 8, 2025
7a3473c
wip: removing if tiled check
iparask Jan 8, 2025
3e5b0c5
Merge branch 'gp/fix/act_flags' of https://github.com/simonsobs/sotod…
iparask Jan 8, 2025
f9a6a45
wip: fixing autoformat
iparask Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sotodlib/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import logging
import numpy as np

from typing import Union, Dict, Tuple, List
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to context.py are just noise; please revert.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will. I put them when I was trying to understand what was happening.


from . import metadata
from .util import tag_substr
from .axisman import AxisManager, OffsetAxis, AxisInterface

logger = logging.getLogger(__name__)


class Context(odict):
# Sets of special handlers may be registered in this class variable, then
# requested by name in the context.yaml key "context_hooks".
Expand Down Expand Up @@ -322,7 +325,8 @@ def get_meta(self,
check=False,
ignore_missing=False,
on_missing=None,
det_info_scan=False):
det_info_scan=False
):
"""Load supporting metadata for an observation and return it in an
AxisManager.

Expand Down
2 changes: 1 addition & 1 deletion sotodlib/core/g3_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from spt3g import core
from so3g.spt3g import core
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an import bug.



class DataG3Module(object):
Expand Down
63 changes: 39 additions & 24 deletions sotodlib/mapmaking/ml_mapmaker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import numpy as np
from pixell import enmap, utils, tilemap, bunch
import h5py
import so3g
from typing import Optional
from pixell import bunch, enmap, tilemap
from pixell import utils as putils

from .. import coords
from .utilities import *
from .pointing_matrix import *
from .pointing_matrix import PmatCut
from .utilities import (MultiZipper, get_flags_from_path, recentering_to_quat_lonlat,
evaluate_recentering, TileMapZipper, MapZipper,
safe_invert_div, unarr, ArrayZipper)
from .noise_model import NmatUncorr


class MLMapmaker:
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False):
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False, glitch_flags:str = "flags.glitch_flags"):
"""Initialize a Maximum Likelihood Mapmaker.
Arguments:
* signals: List of Signal-objects representing the models that will be solved
Expand All @@ -26,6 +34,7 @@ def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False
self.data = []
self.dof = MultiZipper()
self.ready = False
self.glitch_flags_path = glitch_flags

def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None):
# Prepare our tod
Expand All @@ -36,7 +45,7 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# the noise model, if available
if signal_estimate is not None: tod -= signal_estimate
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# Allow the user to override the noise model on a per-obs level
if noise_model is None: noise_model = self.noise_model
# Build the noise model from the obs unless a fully
Expand All @@ -55,12 +64,12 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# The signal estimate might not be desloped, so
# adding it back can reintroduce a slope. Fix that here.
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# And apply it to the tod
tod = nmat.apply(tod)
# Add the observation to each of our signals
for signal in self.signals:
signal.add_obs(id, obs, nmat, tod)
signal.add_obs(id, obs, nmat, tod, glitch_flags=self.glitch_flags_path)
# Save what we need about this observation
self.data.append(bunch.Bunch(id=id, ndet=obs.dets.count, nsamp=len(ctime),
dets=obs.dets.vals, nmat=nmat))
Expand Down Expand Up @@ -119,7 +128,7 @@ def solve(self, maxiter=500, maxerr=1e-6, x0=None):
self.prepare()
rhs = self.dof.zip(*[signal.rhs for signal in self.signals])
if x0 is not None: x0 = self.dof.zip(*x0)
solver = utils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
solver = putils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
while solver.i < maxiter and solver.err > maxerr:
solver.step()
yield bunch.Bunch(i=solver.i, err=solver.err, x=self.dof.unzip(solver.x))
Expand All @@ -146,7 +155,7 @@ def transeval(self, id, obs, other, x, tod=None):

class Signal:
"""This class represents a thing we want to solve for, e.g. the sky, ground, cut samples, etc."""
def __init__(self, name, ofmt, output, ext):
def __init__(self, name, ofmt, output, ext, glitch_flags: str = "flags.glitch_flags"):
mhasself marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize a Signal. It probably doesn't make sense to construct a generic signal
directly, though. Use one of the subclasses.
Arguments:
Expand All @@ -161,7 +170,8 @@ def __init__(self, name, ofmt, output, ext):
self.ext = ext
self.dof = None
self.ready = False
def add_obs(self, id, obs, nmat, Nd): pass
self.glitch_flags = glitch_flags
def add_obs(self, id, obs, nmat, Nd, glitch_flags:Optional[str]): pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question.

Can it be replaced with a generic **kwargs, to imply "yes sometimes a subclass will accept and process other random args."

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can

def prepare(self): self.ready = True
def forward (self, id, tod, x): pass
def backward(self, id, tod, x): pass
Expand All @@ -176,12 +186,12 @@ class SignalMap(Signal):
"""Signal describing a non-distributed sky map."""
def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", output=True,
ext="fits", dtype=np.float32, sys=None, recenter=None, tile_shape=(500,500), tiled=False,
interpol=None):
interpol=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal describing a sky map in the coordinate system given by "sys", which defaults
to equatorial coordinates. If tiled==True, then this will be a distributed map with
the given tile_shape, otherwise it will be a plain enmap. interpol controls the
pointing matrix interpolation mode. See so3g's Projectionist docstring for details."""
Signal.__init__(self, name, ofmt, output, ext)
Signal.__init__(self, name, ofmt, output, ext, glitch_flags)
mhasself marked this conversation as resolved.
Show resolved Hide resolved
self.comm = comm
self.comps = comps
self.sys = sys
Expand All @@ -202,15 +212,16 @@ def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", out
self.div = enmap.zeros((ncomp,ncomp)+shape, wcs, dtype=dtype)
self.hits= enmap.zeros( shape, wcs, dtype=dtype)

def add_obs(self, id, obs, nmat, Nd, pmap=None):
def add_obs(self, id, obs, nmat, Nd, pmap=None, glitch_flags: Optional[str] = None):
"""Add and process an observation, building the pointing matrix
and our part of the RHS. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data.
"""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
ctime = obs.timestamps
pcut = PmatCut(obs.flags.glitch_flags) # could pass this in, but fast to construct
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags)) # could pass this in, but fast to construct
if pmap is None:
# Build the local geometry and pointing matrix for this observation
if self.recenter:
Expand Down Expand Up @@ -261,9 +272,9 @@ def prepare(self):
self.dof = TileMapZipper(self.rhs.geometry, dtype=self.dtype, comm=self.comm)
else:
if self.comm is not None:
self.rhs = utils.allreduce(self.rhs, self.comm)
self.div = utils.allreduce(self.div, self.comm)
self.hits = utils.allreduce(self.hits, self.comm)
self.rhs = putils.allreduce(self.rhs, self.comm)
self.div = putils.allreduce(self.div, self.comm)
self.hits = putils.allreduce(self.hits, self.comm)
self.dof = MapZipper(*self.rhs.geometry, dtype=self.dtype)
self.idiv = safe_invert_div(self.div)
self.ready = True
Expand Down Expand Up @@ -300,7 +311,7 @@ def from_work(self, map):
return tilemap.redistribute(map, self.comm, self.rhs.geometry.active)
else:
if self.comm is None: return map
else: return utils.allreduce(map, self.comm)
else: return putils.allreduce(map, self.comm)

def write(self, prefix, tag, m):
if not self.output: return
Expand Down Expand Up @@ -347,6 +358,7 @@ def transeval(self, id, obs, other, map, tod):
# Currently we don't support any actual translation, but could handle
# resolution changes in the future (probably not useful though)
self._checkcompat(other)
ctime = obs.timestamp
# Build the local geometry and pointing matrix for this observation
if self.recenter:
rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter,
Expand All @@ -361,9 +373,9 @@ def transeval(self, id, obs, other, map, tod):

class SignalCut(Signal):
def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
output=False, cut_type=None):
output=False, cut_type=None, glitch_flags:str ="flags.glitch_flags"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer these default to None, as they do in SignalMap. Then convert None -> flags.glitch_flags, to handle the default.

When you have cascading kwargs, e.g.:

  SignalMap(..., glitch_flags=None)
     -> Signal(glitch_flags=glitch_flags)
          ->  self.glitch_flags = glitch_flags

it tends to defeat the use of default value declarations for optional parameters. My practice has been to prefer this format for setting default args:

    class Signal:
        def __init__(self, ..., glitch_flags=None):
            if glitch_flags is None:
                glitch_flags = 'flags.glitch_flags'

Then subclasses, or whatever, can pass in their default value of None and it propagates cleanly all the way down to whatever base function is comfortable with setting the default value.

However... this isn't necessary relevant if we take glitch_flags out of the base class Signal. In that case you won't have cascading. I still think it's a good practice, in code that might become cascade, to use "None" to represent "the default value specified by some lower level". (Though this makes it harder to use "None" to represent "disable this feature" -- I would normally use False to signal that in cases where it is needed.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, I prefer having the actual default value over a check and set in the code. It makes it very easy know the default value, and it avoids any question as to what is happening. I agree the base class needs to have a default as None. Using kargs will resolve that and make it much cleaner for sure.

"""Signal for handling the ML solution for the values of the cut samples."""
Signal.__init__(self, name, ofmt, output, ext="hdf")
Signal.__init__(self, name, ofmt, output, ext="hdf", glitch_flags=glitch_flags)
self.comm = comm
self.data = {}
self.dtype = dtype
Expand All @@ -372,12 +384,14 @@ def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
self.rhs = []
self.div = []

def add_obs(self, id, obs, nmat, Nd):
def add_obs(self, id, obs, nmat, Nd, glitch_flags: Optional[str] = None):
"""Add and process an observation. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data."""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
pcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)

gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# Build our RHS
obs_rhs = np.zeros(pcut.njunk, self.dtype)
pcut.backward(Nd, obs_rhs)
Expand Down Expand Up @@ -441,15 +455,16 @@ def translate(self, other, junk):
so3g.translate_cuts(odata.pcut.cuts, sdata.pcut.cuts, sdata.pcut.model, sdata.pcut.params, junk[odata.i1:odata.i2], res[sdata.i1:sdata.i2])
return res

def transeval(self, id, obs, other, junk, tod):
def transeval(self, id, obs, other, junk, tod, glitch_flags: Optional[str] = None):
"""Translate data junk from SignalCut other to the current SignalCut,
and then evaluate it for the given observation, returning a tod.
This is used when building a signal-free tod for the noise model
in multipass mapmaking."""
self._checkcompat(other)
# We have to make a pointing matrix from scratch because add_obs
# won't have been called yet at this point
spcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
spcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# We do have one for other though, since that will be the output
# from the previous round of multiplass mapmaking.
odata = other.data[id]
Expand Down
70 changes: 59 additions & 11 deletions sotodlib/mapmaking/utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Union, Optional

import numpy as np
from pixell import enmap, utils, fft, tilemap, resample
import so3g
from pixell import enmap, fft, resample, tilemap, utils

from .. import coords, core, tod_ops

from .. import core
from .. import tod_ops
from .. import coords

def deslope_el(tod, el, srate, inplace=False):
if not inplace: tod = tod.copy()
Expand Down Expand Up @@ -136,7 +137,6 @@ def safe_invert_div(div, lim=1e-2, lim0=np.finfo(np.float32).tiny**0.5):
return idiv



def measure_cov(d, nmax=10000):
d = d[:,::max(1,d.shape[1]//nmax)]
n,m = d.shape
Expand Down Expand Up @@ -339,6 +339,7 @@ def evaluate_recentering(info, ctime, geom=None, site=None, weather="typical"):
"""Evaluate the quaternion that performs the coordinate recentering specified in
info, which can be obtained from parse_recentering."""
import ephem

# Get the coordinates of the from, to and up points. This was a bit involved...
def to_cel(lonlat, sys, ctime=None, site=None, weather=None):
# Convert lonlat from sys to celestial coorinates. Maybe polish and put elswhere
Expand Down Expand Up @@ -370,6 +371,7 @@ def recentering_to_quat_lonlat(p1, p2, pu):
"""Return the quaternion that represents the rotation that takes point p1
to p2, with the up direction pointing towards the point pu, all given as lonlat pairs"""
from so3g.proj import quat

# 1. First rotate our point to the north pole: Ry(-(90-dec1))Rz(-ra1)
# 2. Apply the same rotation to the up point.
# 3. We want the up point to be upwards, so rotate it to ra = 180°: Rz(pi-rau2)
Expand Down Expand Up @@ -439,8 +441,48 @@ def rangemat_sum(rangemat):
res[i] = np.sum(ra[:,1]-ra[:,0])
return res

def find_usable_detectors(obs, maxcut=0.1):
ncut = rangemat_sum(obs.flags.glitch_flags)
def flags_in_path(
aman: core.AxisManager, rpath: str, sep: str = "."
) -> bool:
"""
This function allows to pull data from an AxisManager based on a path.
mhasself marked this conversation as resolved.
Show resolved Hide resolved
Parameters:
mhasself marked this conversation as resolved.
Show resolved Hide resolved
- aman: An Axis Manager object
- path: a string with a recursive path to extract data. The path is separated via a sep.
For example 'flags.glitch_flags'
- sep: separator. Defaults to `.`
"""

rpath = rpath.split(sep=sep)
flags = aman.copy()
mhasself marked this conversation as resolved.
Show resolved Hide resolved
while rpath and flags is not None:
path = rpath.pop()
flags = flags[path]
mhasself marked this conversation as resolved.
Show resolved Hide resolved

return flags is not None


def get_flags_from_path(
aman: core.AxisManager, rpath: str, sep: str = "."
) -> Union[so3g.proj.RangesMatrix, Any]:
"""
This function allows to pull data from an AxisManager based on a path.
Parameters:
- aman: An Axis Manager object
- path: a string with a recursive path to extract data. The path is separated via a sep.
For example 'flags.glitch_flags'
- sep: separator. Defaults to `.`
"""

flags = aman.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary copy.

for path in rpath.split(sep=sep):
flags = flags[path]

return flags


def find_usable_detectors(obs, maxcut=0.1, glitch_flags: str = "flags.glitch_flags"):
ncut = rangemat_sum(get_flags_from_path(obs, glitch_flags))
good = ncut < obs.samps.count * maxcut
return obs.dets.vals[good]

Expand Down Expand Up @@ -499,7 +541,7 @@ def downsample_obs(obs, down):
if isinstance(val, core.AxisManager):
res.wrap(key, val)
else:
axdesc = [(k,v) for k,v in enumerate(axes) if v is not None]
axdesc = [(k, v) for k, v in enumerate(axes) if v is not None]
res.wrap(key, val, axdesc)
# The normal sample stuff
res.wrap("timestamps", obs.timestamps[::down], [(0, "samps")])
Expand All @@ -511,16 +553,22 @@ def downsample_obs(obs, down):

# The cuts
# obs.flags will contain all types of flags. We should query it for glitch_flags and source_flags
cut_keys = ["glitch_flags"]
cut_keys = []
if flags_in_path(obs, "glitch_flags"):
cut_keys.append("glitch_flags")
elif flags_in_path(obs, "flags.glitch_flags"):
cut_keys.append("flags.glitch_flags")

if "source_flags" in obs.flags:
if flags_in_path(obs, "source_flags"):
cut_keys.append("source_flags")
elif flags_in_path(obs, "flags.source_flags"):
cut_keys.append("flags.source_flags")

# We need to add a res.flags FlagManager to res
res = res.wrap('flags', core.FlagManager.for_tod(res))

for key in cut_keys:
res.flags.wrap(key, downsample_cut(getattr(obs.flags, key), down), [(0,"dets"),(1,"samps")])
res.flags.wrap(key, downsample_cut(get_flags_from_path(obs, key), down), [(0,"dets"),(1,"samps")])

# Not sure how to deal with flags. Some sort of or-binning operation? But it
# doesn't matter anyway
Expand Down