diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58aaff6f9..3e5c531f5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,8 +24,6 @@ jobs: fail-fast: false matrix: include: - - python: "3.8" - pyshort: "38" - python: "3.9" pyshort: "39" - python: "3.11" @@ -62,13 +60,13 @@ jobs: export OMP_NUM_THREADS=2 export OPENBLAS_NUM_THREADS=2 export MPI_DISABLE=1 - python3 setup.py test + python3 -m unittest discover - name: Run MPI Tests run: | export OMP_NUM_THREADS=1 export OPENBLAS_NUM_THREADS=1 - mpirun -np 2 python3 setup.py test + mpirun -n 2 python3 -m unittest discover # FIXME: Re-enable after testing this procedure on a local # apple machine. diff --git a/.gitignore b/.gitignore index 7ecd9dfd7..bb277c3c6 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,6 @@ venv.bak/ # pychram .idea + +# vscode +.vscode \ No newline at end of file diff --git a/docs/axisman.rst b/docs/axisman.rst index e7875361d..b4d497f6d 100644 --- a/docs/axisman.rst +++ b/docs/axisman.rst @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be:: Note the boresight entry is marked with a ``*``, indicating that it's an AxisManager rather than a numpy array. +Data access under an AxisManager is done based on field names. For example:: + + >>> print(dset.boresight.az) + [0. 0. 0. ... 0. 0. 0.] + +Advanced data access is possible by a path like syntax. This is especially useful when +data access is dynamic and the field name is not known in advance. For example:: + + >>> print(dset["boresight.az"]) + [0. 0. 0. ... 0. 0. 0.] + To slice this object, use the restrict() method. First, let's restrict in the 'dets' axis. Since it's an Axis of type LabelAxis, the restriction selector must be a list of strings:: diff --git a/docs/preprocess.rst b/docs/preprocess.rst index 6ce35a95a..1845208f5 100644 --- a/docs/preprocess.rst +++ b/docs/preprocess.rst @@ -252,6 +252,7 @@ Flagging and Products .. autoclass:: sotodlib.preprocess.processes.FlagTurnarounds .. autoclass:: sotodlib.preprocess.processes.DarkDets .. autoclass:: sotodlib.preprocess.processes.SourceFlags +.. autoclass:: sotodlib.preprocess.processes.GetStats HWP Related ::::::::::: diff --git a/docs/site_pipeline.rst b/docs/site_pipeline.rst index 1081991de..189c9c673 100644 --- a/docs/site_pipeline.rst +++ b/docs/site_pipeline.rst @@ -940,16 +940,13 @@ and binned. Every atomic map consist of a ``weights``, ``wmap`` (weighted map), and ``hits`` map, as well as an information file that is used for adding the map to an atomic map database. -Command line arguments -`````````````````````` +Configuration yaml file +```````````````````````` -.. argparse:: - :module: sotodlib.site_pipeline.make_atomic_filterbin_map - :func: get_parser - :prog: make-atomic-filterbin-map +The mapmaker is configured by supplying a yaml file with ``--config_file``. -Config file format -`````````````````` +.. 
autoclass:: sotodlib.site_pipeline.make_atomic_filterbin_map.Cfg + :members: The only mandatory parameters are ``context`` for a context file and ``preprocess_config``, a preprocess database configuration file that will tell the script how to process the diff --git a/setup.py b/setup.py index 0b496c3a6..1d6988dba 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,7 @@ import unittest -from setuptools import find_packages, setup, Extension -from setuptools.command.test import test as TestCommand +from setuptools import find_packages, setup, Extension, Command import versioneer @@ -43,7 +42,7 @@ setup_opts["url"] = "https://github.com/simonsobs/sotodlib" setup_opts["packages"] = find_packages(where=".", exclude="tests") setup_opts["license"] = "MIT" -setup_opts["python_requires"] = ">=3.8.0" +setup_opts["python_requires"] = ">=3.9.0" setup_opts["package_data"] = { "sotodlib": [ "toast/ops/data/*" @@ -84,16 +83,16 @@ # Class to run unit tests -class SOTestCommand(TestCommand): +class SOTestCommand(Command): def __init__(self, *args, **kwargs): super(SOTestCommand, self).__init__(*args, **kwargs) def initialize_options(self): - TestCommand.initialize_options(self) + Command.initialize_options(self) def finalize_options(self): - TestCommand.finalize_options(self) + Command.finalize_options(self) self.test_suite = True def mpi_world(self): diff --git a/sotodlib/coords/demod.py b/sotodlib/coords/demod.py index 67d15f453..6ebeda545 100644 --- a/sotodlib/coords/demod.py +++ b/sotodlib/coords/demod.py @@ -182,7 +182,7 @@ def from_map(tod, signal_map, cuts=None, flip_gamma=True, wrap=False, modulated= return signal_sim -def rotate_demodQU(tod, update_focal_plane=True): +def rotate_demodQU(tod, sign=1, offset=0, update_focal_plane=True): """ Apply detectors' polarization angle calibration to the HWP demodulated Q and U timestreams to place all detectors' Q and U timestreams in a common telescope frame. This updates tod.demodQ @@ -190,12 +190,16 @@ def rotate_demodQU(tod, update_focal_plane=True): Args: tod : an axisManager object - update_focal_plane (bool, optional): Whether to set focal_plane.gamma angles to zero, + update_focal_plane (bool, optional): Whether to set focal_plane.gamma angles to zero, consistent with new coordinate reference. Make this true for polarization mapmaking using make_map. + offset : float, optional + The rotation angle in degrees to apply (default is 0). + sign : int, optional + A sign factor to control the direction of the rotation (default is +1). """ - demodC = ((tod.demodQ + 1j*tod.demodU).T * np.exp(-2j*tod.focal_plane.gamma)).T + demodC = ((tod.demodQ + 1j*tod.demodU).T * np.exp( sign*(-2j*tod.focal_plane.gamma + 1j*np.deg2rad(offset)) )).T tod.demodQ = demodC.real tod.demodU = demodC.imag del demodC diff --git a/sotodlib/coords/helpers.py b/sotodlib/coords/helpers.py index d6246002e..0a7d8bb34 100644 --- a/sotodlib/coords/helpers.py +++ b/sotodlib/coords/helpers.py @@ -353,7 +353,7 @@ def get_footprint(tod, wcs_kernel, dets=None, timestamps=None, boresight=None, def get_focal_plane_cover(tod=None, count=0, focal_plane=None, - xieta=None): + xieta=None, det_weights=None): """Process a bunch of detector positions into a center and radius such that a circle with that center and radius contains all the detectors. Also return detector positions, arranged approximately @@ -366,6 +366,9 @@ def get_focal_plane_cover(tod=None, count=0, focal_plane=None, not passed in xieta (array): (2, n) array (or similar) of xi and eta detector positions. 
+ det_weights (array): If provided, must be same length as the xi + and eta vectors. Only dets with non-zero value for det_weights + will be included in the evaluation of the cover. Returns: xieta0: array[2] with array center, (xi0, eta0). @@ -374,9 +377,22 @@ def get_focal_plane_cover(tod=None, count=0, focal_plane=None, coords of the circular convex hull. Notes: - If count=0, an empty list is returned for xietas. Otherwise, + If count=0, an empty list is returned for xietas. Otherwise, count must be at least 3 so that the shape is not degenerate. + Any xi, eta that are not finite (e.g. nan or inf) are excluded + from the computation. If no detectors remain after the combined + finiteness and det_weights cuts, a ValueError is raised. + + Note that ``det_weights`` can be int, float, or bool + type. Sometimes you might want to only include optical dets in + the result; e.g.:: + + ..., det_weights=(aman.det_info.wafer.type=="OPTC"), ... + + In degenerate cases (all dets are in exactly the same place), a + radius of zero may be returned. + """ if xieta is None: if focal_plane is None: @@ -385,6 +401,17 @@ def get_focal_plane_cover(tod=None, count=0, focal_plane=None, eta = focal_plane.eta else: xi, eta = xieta[:2] + mask = np.isfinite(xi) * np.isfinite(eta) + if det_weights is not None: + mask *= det_weights.astype(bool) + + if not np.any(mask): + raise ValueError('All provided (xi, eta) coords are excluded; ' + 'cannot estimate a focal plane cover.') + + # Restrict to only dets under consideration. + xi, eta = xi[mask], eta[mask] + qs = so3g.proj.quat.rotation_xieta(xi, eta) # Starting guess for center @@ -477,23 +504,23 @@ def as_geom(item): s0 = corner_b - corner_a return tuple(map(int, s0)), w0 -def _confirm_wcs(*maps): - """Insist that all arguments either have the same .wcs, or do not have - a wcs. Each argument should be either an ndmap (with a .wcs - attribute) or an ndarray (without a .wcs attribute). +def _confirm_wcs(*wcss): + """Insist that all arguments are either the same wcs, or None. - Raises a ValueError if more than one argument has a .wcs attribute - and they do not all agree. Returns either the first .wcs - attribute encountered, or None if there aren't any. + Raises a ValueError if more than one argument is a wcs, and not + all wcs agree. Returns either the first valid wcs, or None if + there aren't any. """ wcs_to_use = None - for i, m in enumerate(maps): - if hasattr(m, 'wcs'): - if wcs_to_use is None: - wcs_to_use = m.wcs - elif not wcsutils.equal(wcs_to_use, m.wcs): - raise ValueError('The wcs from %ith item is discordant with prior ones.' % i) + for i, wcs in enumerate(wcss): + if wcs is None: + continue + if wcs_to_use is None: + wcs_to_use = wcs + elif not wcsutils.equal(wcs_to_use, wcs): + raise ValueError( + f'The wcs from {i}th item ({wcs}) is discordant with prior ones ({wcs_to_use})') return wcs_to_use def _invert_weights_map(weights, eigentol=1e-6, kill_partials=True, @@ -565,23 +592,25 @@ def _invert_weights_map(weights, eigentol=1e-6, kill_partials=True, # Reshape the output to match what was passed in. return iw.transpose(1,2,0).reshape(weights.shape) -def _apply_inverse_weights_map(inverse_weights, target): +def _apply_inverse_weights_map(inverse_weights, target, out=None): """Apply a map of matrices to a map of vectors. Assumes inverse_weights.shape = (a, b, ...) and target.shape = (b, ...); the result has shape (a, ...). 
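    For intuition, the per-pixel operation is a matrix-vector product
    contracted over the component axis.  A minimal numpy sketch (shapes
    are hypothetical, not taken from any real map)::

        import numpy as np
        ncomp, ny, nx = 3, 100, 200
        iw = np.zeros((ncomp, ncomp, ny, nx))   # per-pixel matrices
        m = np.zeros((ncomp, ny, nx))           # per-pixel vectors
        # out[a, y, x] = sum_b iw[a, b, y, x] * m[b, y, x]
        out = np.einsum('ab...,b...->a...', iw, m)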
""" - # master had: - #iw = inverse_weights.transpose((2,3,0,1)) - #m = target.transpose((1,2,0)).reshape( - # target.shape[1], target.shape[2], target.shape[0], 1) - #m1 = np.matmul(iw, m) - #return m1.transpose(2,3,0,1).reshape(target.shape) + if out is None: + out = np.empty(inverse_weights.shape[1:], + dtype=target.dtype) + if isinstance(target, enmap.ndmap): + out = enmap.ndmap(out, target.wcs) + # Recall matmul(a, b) operates on the last two axes of (a, b). So + # move axes, and create a second one in target; re-order at end. iw = np.moveaxis(inverse_weights, (0,1), (-2,-1)) t = np.moveaxis(target[:,None], (0,1), (-2,-1)) - m = np.matmul(iw, t) - return np.moveaxis(m, (-2,-1), (0,1))[:,0] + out_moved = np.moveaxis(out[:,None], (0,1), (-2,-1)) + np.matmul(iw, t, out=out_moved) + return out class ScalarLastQuat(np.ndarray): """Wrapper class for numpy arrays carrying quaternions with the ijk1 diff --git a/sotodlib/coords/pmat.py b/sotodlib/coords/pmat.py index e3d4a4145..ca0cbb0ff 100644 --- a/sotodlib/coords/pmat.py +++ b/sotodlib/coords/pmat.py @@ -1,6 +1,5 @@ import so3g.proj import numpy as np -import scipy from pixell import enmap, tilemap from .helpers import _get_csl, _valid_arg, _not_both, _confirm_wcs @@ -239,16 +238,13 @@ def zeros(self, super_shape=None, comps=None): """ if super_shape is None: super_shape = (self._comp_count(comps), ) + proj, _ = self._get_proj_tiles() if self.pix_scheme == 'healpix': - proj, _ = self._get_proj_threads() return proj.zeros(super_shape) elif self.pix_scheme == 'rectpix': if self.tiled: - # Need to fully resolve tiling to get occupied tiles. - proj, _ = self._get_proj_threads() return tilemap.from_tiles(proj.zeros(super_shape), self.geom) else: - proj = self._get_proj() return enmap.ndmap(proj.zeros(super_shape), wcs=self.geom.wcs) def to_map(self, tod=None, dest=None, comps=None, signal=None, @@ -338,22 +334,20 @@ def to_inverse_weights(self, weights_map=None, tod=None, dest=None, weights_map = self.to_weights( tod=tod, comps=comps, signal=signal, det_weights=det_weights, cuts=cuts) - if self.pix_scheme == "healpix": - tile_list = self._get_hp_tile_list(weights_map) - if tile_list is not None: - weights_map = hp_utils.tiled_to_compressed(weights_map, -1) - if dest is None: - dest = np.zeros_like(weights_map) - elif (self.pix_scheme == "rectpix") and (dest is None): - dest = self._enmapify(np.zeros_like(weights_map), - wcs=_confirm_wcs(weights_map)) - - logger.info('to_inverse_weights: calling _invert_weights_map') - dest[:] = helpers._invert_weights_map( - weights_map, eigentol=eigentol, UPLO='U') - - if self.pix_scheme == "healpix" and tile_list is not None: - dest = hp_utils.compressed_to_tiled(dest, tile_list, -1) + weights_map, uf_info = self._flatten_map(weights_map) + + if dest is not None: + dest, uf_info = self._flatten_map(dest, uf_info) + + _dest = helpers._invert_weights_map( + weights_map, eigentol=eigentol, UPLO='U') + if dest is not None: + dest[:] = _dest + else: + dest = _dest + del _dest + + dest = self._unflatten_map(dest, uf_info) return dest def remove_weights(self, signal_map=None, weights_map=None, inverse_weights_map=None, @@ -376,7 +370,8 @@ def remove_weights(self, signal_map=None, weights_map=None, inverse_weights_map= weights_map: the matrix W. Shape should be (n_comp, n_comp, n_row, n_col), but only the upper diagonal in the first two dimensions needs to be populated. If this is None, - then W will be computed by a call to + then W will be computed and inverted via + ``self.to_inverse_weights``. 
""" if inverse_weights_map is None: @@ -384,25 +379,15 @@ def remove_weights(self, signal_map=None, weights_map=None, inverse_weights_map= if signal_map is None: signal_map = self.to_map(**kwargs) - if self.pix_scheme == "healpix": - tile_list = self._get_hp_tile_list(signal_map) - if tile_list is not None: - signal_map = hp_utils.tiled_to_compressed(signal_map, -1) - inverse_weights_map = hp_utils.tiled_to_compressed(inverse_weights_map, -1) - if dest is None: - dest = np.zeros_like(signal_map) - elif self.pix_scheme == "rectpix": - if dest is None: - wcs_to_use = _confirm_wcs(inverse_weights_map, signal_map) - dest = self._enmapify(np.empty(signal_map.shape, signal_map.dtype), - wcs=wcs_to_use) - else: - _confirm_wcs(inverse_weights_map, signal_map, dest) + # Get flat numpy-compatible forms for the maps. + signal_map, uf_info = self._flatten_map(signal_map) + inverse_weights_map, uf_info = self._flatten_map(inverse_weights_map, uf_info) - dest[:] = helpers._apply_inverse_weights_map(inverse_weights_map, signal_map) + if dest is not None: + dest, uf_info = self._flatten_map(dest, uf_info) - if self.pix_scheme == "healpix" and tile_list is not None: - dest = hp_utils.compressed_to_tiled(dest, tile_list, -1) + dest = helpers._apply_inverse_weights_map(inverse_weights_map, signal_map, out=dest) + dest = self._unflatten_map(dest, uf_info) return dest @@ -436,11 +421,9 @@ def from_map(self, signal_map, dest=None, comps=None, wrap=None, """ assert cuts is None # whoops, not implemented. - # _get_proj doesn't set up the tiling info, so - # must call get_proj_threads. This is ugly, since from_map - # doesn't need the threading structures. Can we find a better design? - if self.tiled: proj, _ = self._get_proj_threads() - else: proj = self._get_proj() + # This is not free but it is pretty fast, doesn't do thread + # assignments. + proj, _ = self._get_proj_tiles() if comps is None: comps = self.comps @@ -448,7 +431,13 @@ def from_map(self, signal_map, dest=None, comps=None, wrap=None, if dest is None: dest = np.zeros(tod_shape, np.float32) assert(dest.shape == tod_shape) # P.fp/P.sight and dest argument disagree - proj.from_map(self._prepare_map(signal_map), self._get_asm(), signal=dest, comps=comps) + + if self.tiled and self.pix_scheme == 'rectpix': + # so3g <= 0.1.15 has a dims check on signal_map that fails on the tiled map format. + so3g.proj.wcs._ProjectionistBase.from_map( + proj, self._prepare_map(signal_map), self._get_asm(), signal=dest, comps=comps) + else: + proj.from_map(self._prepare_map(signal_map), self._get_asm(), signal=dest, comps=comps) if wrap is not None: if wrap in tod: @@ -492,6 +481,28 @@ def _get_proj(self): return so3g.proj.Projectionist.for_geom(self.geom.shape, self.geom.wcs, **interpol_kw) + def _get_proj_tiles(self, assign=False): + # Get Projectionist and compute self.active_tiles if it's not + # already known. Return Projectionist with active_tiles set, + # which is suitable for from_map and zeros (though not for + # threaded to_map etc). + proj = self._get_proj() + if not self.tiled or (self.active_tiles is not None and not assign): + return proj, {} + tile_info = proj.get_active_tiles(self._get_asm(), assign=assign) + self.active_tiles = tile_info['active_tiles'] + + if self.pix_scheme == "healpix": + self.geom.nside_tile = proj.nside_tile # Update nside_tile if it was 'auto' + self.geom.ntile = 12*proj.nside_tile**2 + elif self.pix_scheme == 'rectpix': + # Promote geometry to one with the active tiles marked. 
+ self.geom = tilemap.geometry( + self.geom.shape, self.geom.wcs, self.geom.tile_shape, + active=self.active_tiles) + + return self._get_proj(), tile_info + def _get_proj_threads(self, cuts=None): """Return the Projectionist and sample-thread assignment for the present geometry. If the thread assignment has not been @@ -518,19 +529,12 @@ def _get_proj_threads(self, cuts=None): elif self.pix_scheme == "healpix": self.threads = 'tiles' if (self.geom.nside_tile is not None) else 'simple' - if self.tiled and self.active_tiles is None: - logger.info('_get_proj_threads: get_active_tiles') - if isinstance(self.threads, str) and self.threads == 'tiles': - logger.info('_get_proj_threads: assigning using "tiles"') - tile_info = proj.get_active_tiles(self._get_asm(), assign=True) - _tile_threads = wrap_ranges(tile_info['group_ranges']) - else: - tile_info = proj.get_active_tiles(self._get_asm()) - self.active_tiles = tile_info['active_tiles'] - if self.pix_scheme == "healpix": - self.geom.nside_tile = proj.nside_tile # Update nside_tile if it was 'auto' - self.geom.ntile = 12*proj.nside_tile**2 - proj = self._get_proj() # Add active_tiles to proj + need_tiles = (self.active_tiles is None) + need_assign = (self.threads in ['tiles']) + if need_tiles or need_assign: + proj, tile_info = self._get_proj_tiles(need_assign) + if need_assign: + _tile_threads = wrap_ranges(tile_info['group_ranges']) if self.threads is False: return proj, ~cuts @@ -560,34 +564,88 @@ def _get_asm(self): return so3g.proj.Assembly.attach(self.sight, so3g_fp) def _prepare_map(self, map): + """Gently reformat a map in order to send it to so3g.""" if self.tiled and self.pix_scheme == "rectpix": - return map.tiles + return list(map.tiles) else: return map - def _enmapify(self, data, wcs=None): - """Promote a numpy.ndarray to an enmap.ndmap by attaching a wcs. In - sensible cases (e.g. data is an ndarray or ndmap) this will - not cause a copy of the underlying data array. + def _flatten_map(self, map, uf_base=None): + """Get a version of the map that is a numpy array, for passing + to per-pixel math operations. Relies on (self.pix_scheme, + self.tiled) to interpret map. + + This also tries to extract wcs info (if rectpix) from the map, + for inline consistency checking (e.g. so we're not happily + projecting into a map we loaded from disk that has the same + shape but is off by a few pixels from what pmat thinks is the + right footprint). It also looks at active_tiles / tile_list, + and stores that for downstream compatibility checking with + other flattened maps. + + If uf_base is passed in, it should be an unflatten_info dict + (likely from a previous call to _flatten_map). The analysis + here will be checked against it for compatibility and any + missing values (ahem wcs) will be used to augment the + unflatten_info that is returned. - If a wcs is not passed in, then wcs=self.geom[1] is used. + Returns: + array: The map, reformatted as an array (could simply be the + input arg map, or a view of that, or a copy if necessary). + unflatten_info: dict with misc compatibility info. """ - if wcs is None: - wcs = self.geom[1] - return enmap.ndmap(data, wcs=wcs) - - def _get_hp_tile_list(self, tiled_arr=None): - """For healpix maps, get len(nTile) bool array of whether tiles are active. None if un-tiled. - If tiled_arr is not None, get tiling from there. 
Else get from self.geom""" - if isinstance(tiled_arr, list): # Assume we are tiled iff tiled_arr is a list instead of ndarr - tile_list = hp_utils.get_active_tile_list(tiled_arr) - elif self.tiled: - tile_list = np.zeros(12*self.geom.nside_tile**2, dtype='bool') - tile_list[self.active_tiles] = True # The tiling must be initiated already - else: - tile_list = None # Un-tiled - return tile_list + ufinfo = {'pix_scheme': self.pix_scheme, + 'tiled': self.tiled} + wcs = None + crit_dims = 1 + if self.pix_scheme == 'healpix': + if self.tiled: + ufinfo['tile_list'] = [_m is not None for _m in map] + map = hp_utils.tiled_to_compressed(map, -1) + crit_dims = 2 + else: + pass + elif self.pix_scheme == 'rectpix': + if self.tiled: + if isinstance(map, tilemap.TileMap): + wcs = map.geometry.wcs + ufinfo['active_tiles'] = list(map.active) + ufinfo['tile_geom'] = map.geometry.copy(pre=()) + else: + if isinstance(map, enmap.ndmap): + wcs = map.wcs + crit_dims = 2 + ufinfo.update({'wcs': wcs, + 'crit_dims': crit_dims, + 'shape': map.shape}) + if uf_base is not None: + ufinfo['wcs'] = _check_compat(uf_base, ufinfo) + return map, ufinfo + + def _unflatten_map(self, map, uf_info): + """Restore a map to full format, assuming it's currently an + ndarray. Intended as the inverse op to _flatten_map. + Minimize the use of cached self.* here ... rely instead on + uf_info. + + """ + if uf_info['pix_scheme'] == 'healpix': + if uf_info['tiled']: + map = hp_utils.compressed_to_tiled(map, uf_info['tile_list'], -1) + else: + pass + elif uf_info['pix_scheme'] == 'rectpix': + if uf_info['tiled']: + if not isinstance(map, tilemap.TileMap): + g = uf_info['tile_geom'] + g = tilemap.geometry(map.shape[:-1] + g.shape, g.wcs, g.tile_shape, + active=uf_info['active_tiles']) + map = tilemap.TileMap(map, g) + else: + if not isinstance(map, enmap.ndmap) and uf_info['wcs']: + map = enmap.ndmap(map, uf_info['wcs']) + return map class P_PrecompDebug: @@ -672,3 +730,27 @@ def _infer_pix_scheme(geom): else: raise fail_err return pix_scheme + +def _check_compat(*uf_infos): + # Given one or more "uf_info" dicts, as returned by _flatten_map, + # check that the flattened arrays are pixel-correspondent, + # including (for rectpix cases) the wcs. + # + # Raises an error if any of that doesn't pan out. On success, + # returns the agreed-upon wcs (which could be None). + ref_uf = uf_infos[0] + for uf in uf_infos[1:]: + for k in ['pix_scheme', 'tiled', 'active_tiles']: + if ref_uf.get(k) != uf.get(k): + raise ValueError(f"Inconsistent map structures: {uf_infos}") + # Not sure how to handle broadcasting of lefter dims, so focus + # on the pixel dim(s). + dims_to_check = ref_uf['crit_dims'] + if ref_uf['shape'][-dims_to_check:] != uf['shape'][-dims_to_check:]: + raise ValueError(f"Non-broadcastable map shapes: {uf_infos}") + + # And the wcs. 
+ wcss = [uf.get('wcs') for uf in uf_infos] + wcs_to_use = _confirm_wcs(*wcss) + + return wcs_to_use diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 539b632cb..b799e091c 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -349,29 +349,63 @@ def move(self, name, new_name): self._fields[new_name] = self._fields.pop(name) self._assignments[new_name] = self._assignments.pop(name) return self - + def add_axis(self, a): assert isinstance( a, AxisInterface) self._axes[a.name] = a.copy() def __contains__(self, name): - return name in self._fields or name in self._axes + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + return False + return True def __getitem__(self, name): - if name in self._fields: - return self._fields[name] - if name in self._axes: - return self._axes[name] - raise KeyError(name) + + # We want to support options like: + # aman.focal_plane.xi . aman['focal_plane.xi'] + # We will safely assume that a getitem will always have '.' as the separator + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + raise KeyError(attr_name) + return tmp_item def __setitem__(self, name, val): - if name in self._fields: - self._fields[name] = val + + last_pos = name.rfind(".") + val_key = name + tmp_item = self + if last_pos > -1: + val_key = name[last_pos + 1:] + attrs = name[:last_pos] + tmp_item = self[attrs] + + if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager): + raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.") + + if val_key in tmp_item._fields: + tmp_item._fields[val_key] = val else: - raise KeyError(name) + raise KeyError(val_key) def __setattr__(self, name, value): # Assignment to members update those members + # We will assume that a path exists until the last member. + # If any member prior to that does not exist a keyerror is raised. if "_fields" in self.__dict__ and name in self._fields.keys(): self._fields[name] = value else: @@ -381,7 +415,11 @@ def __setattr__(self, name, value): def __getattr__(self, name): # Prevent members from override special class members. 
if name.startswith("__"): raise AttributeError(name) - return self[name] + try: + val = self[name] + except KeyError as ex: + raise AttributeError(name) from ex + return val def __dir__(self): return sorted(tuple(self.__dict__.keys()) + tuple(self.keys())) @@ -514,12 +552,12 @@ def concatenate(items, axis=0, other_fields='exact'): output.wrap(k, new_data[k], axis_map) else: if other_fields == "exact": - ## if every item named k is a scalar + ## if every item named k is a scalar err_msg = (f"The field '{k}' does not share axis '{axis}'; " f"{k} is not identical across all items " f"pass other_fields='drop' or 'first' or else " f"remove this field from the targets.") - + if np.any([np.isscalar(i[k]) for i in items]): if not np.all([np.isscalar(i[k]) for i in items]): raise ValueError(err_msg) @@ -527,14 +565,14 @@ def concatenate(items, axis=0, other_fields='exact'): raise ValueError(err_msg) output.wrap(k, items[0][k], axis_map) continue - + elif not np.all([i[k].shape==items[0][k].shape for i in items]): raise ValueError(err_msg) elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]): raise ValueError(err_msg) - + output.wrap(k, items[0][k].copy(), axis_map) - + elif other_fields == 'fail': raise ValueError( f"The field '{k}' does not share axis '{axis}'; " diff --git a/sotodlib/core/axisman_io.py b/sotodlib/core/axisman_io.py index f31d27874..80dc4a5a5 100644 --- a/sotodlib/core/axisman_io.py +++ b/sotodlib/core/axisman_io.py @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm): if shape[0] == 0: return so3g.proj.RangesMatrix([], child_shape=shape[1:]) # Otherwise non-trivial - count = np.product(shape[:-1]) + count = np.prod(shape[:-1]) start, stride = 0, count // shape[0] for i in range(0, len(ends), stride): _e = ends[i:i+stride] - start diff --git a/sotodlib/core/context.py b/sotodlib/core/context.py index 3dad3d76f..52429d9b1 100644 --- a/sotodlib/core/context.py +++ b/sotodlib/core/context.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) + class Context(odict): # Sets of special handlers may be registered in this class variable, then # requested by name in the context.yaml key "context_hooks". @@ -326,7 +327,8 @@ def get_meta(self, check=False, ignore_missing=False, on_missing=None, - det_info_scan=False): + det_info_scan=False + ): """Load supporting metadata for an observation and return it in an AxisManager. diff --git a/sotodlib/core/flagman.py b/sotodlib/core/flagman.py index 4a513278b..928963d63 100644 --- a/sotodlib/core/flagman.py +++ b/sotodlib/core/flagman.py @@ -79,7 +79,7 @@ def wrap(self, name, data, axis_map=None, **kwargs): else Ranges.zeros_like(x) for Y in data]) axis_map = [(0,self._dets_name),(1,self._samps_name)] - super().wrap(name, data, axis_map, **kwargs) + return super().wrap(name, data, axis_map, **kwargs) def wrap_dets(self, name, data): """Adding flag with just (dets,) axis. @@ -88,7 +88,7 @@ def wrap_dets(self, name, data): if not len(s) == 1 or s[0] != self[self._dets_name].count: raise ValueError("Data of shape {} is cannot be aligned with" "the detector axis".format(s)) - self.wrap(name, data, axis_map=[(0,self._dets_name)]) + return self.wrap(name, data, axis_map=[(0,self._dets_name)]) def wrap_samps(self, name, data): """Adding flag with just (samps,) axis. 
@@ -97,7 +97,7 @@ def wrap_samps(self, name, data): if not len(s) == 1 or s[0] != self[self._samps_name].count: raise ValueError("Data of shape {} is cannot be aligned with" "the samps axis".format(s)) - self.wrap(name, data, axis_map=[(0,self._samps_name)]) + return self.wrap(name, data, axis_map=[(0,self._samps_name)]) def wrap_dets_samps(self, name, data): """Adding flag with (dets, samps) axes. @@ -107,7 +107,7 @@ def wrap_dets_samps(self, name, data): s[1] != self[self._samps_name].count): raise ValueError("Data of shape {} is cannot be aligned with" "the (dets,samps) axss".format(s)) - self.wrap(name, data, axis_map=[(0,self._dets_name), (1,self._samps_name)]) + return self.wrap(name, data, axis_map=[(0,self._dets_name), (1,self._samps_name)]) def copy(self, axes_only=False): diff --git a/sotodlib/core/g3_core.py b/sotodlib/core/g3_core.py index 219a65728..7c9aea49e 100644 --- a/sotodlib/core/g3_core.py +++ b/sotodlib/core/g3_core.py @@ -6,7 +6,7 @@ """ -from spt3g import core +from so3g.spt3g import core class DataG3Module(object): diff --git a/sotodlib/core/metadata/obsfiledb.py b/sotodlib/core/metadata/obsfiledb.py index 1fabca5a3..0c693badd 100644 --- a/sotodlib/core/metadata/obsfiledb.py +++ b/sotodlib/core/metadata/obsfiledb.py @@ -278,13 +278,16 @@ def get_files(self, obs_id, detsets=None, prefix=None): prefix = self.prefix if detsets is None: - detsets = self.get_detsets(obs_id) - - c = self.conn.execute('select detset, name, sample_start, sample_stop ' - 'from files where obs_id=? and detset in (%s) ' - 'order by detset, sample_start' % - ','.join(['?' for _ in detsets]), - (obs_id,) + tuple(detsets)) + c = self.conn.execute('select detset, name, sample_start, sample_stop ' + 'from files where obs_id=? ' + 'order by detset, sample_start', + (obs_id,)) + else: + c = self.conn.execute('select detset, name, sample_start, sample_stop ' + 'from files where obs_id=? and detset in (%s) ' + 'order by detset, sample_start' % + ','.join(['?' for _ in detsets]), + (obs_id,) + tuple(detsets)) output = OrderedDict() for r in c: if not r[0] in output: diff --git a/sotodlib/io/bookbinder.py b/sotodlib/io/bookbinder.py index 41a2704d2..a1bdd1e63 100644 --- a/sotodlib/io/bookbinder.py +++ b/sotodlib/io/bookbinder.py @@ -13,7 +13,12 @@ import logging import sys import shutil +import yaml +import datetime as dt +from zipfile import ZipFile +import sotodlib from sotodlib.site_pipeline.util import init_logger +from .datapkg_utils import walk_files log = logging.getLogger('bookbinder') @@ -28,6 +33,10 @@ class NoScanFrames(Exception): """Exception raised when we try and bind a book but the SMuRF file contains not Scan frames (so no detector data)""" pass +class NoHKFiles(Exception): + """Exception raised when we cannot find any HK data around the book time""" + pass + class NoMountData(Exception): """Exception raised when we cannot find mount data""" pass @@ -68,14 +77,12 @@ def setup_logger(logfile=None): return log - def get_frame_iter(files): """ Returns a continuous iterator over frames for a list of files. """ return itertools.chain(*[core.G3File(f) for f in files]) - def close_writer(writer): """ Closes out a G3FileWriter with an end-processing frame. 
If None is passed, @@ -85,7 +92,6 @@ def close_writer(writer): return writer(core.G3Frame(core.G3FrameType.EndProcessing)) - def next_scan(it): """ Returns the next Scan frame, along with any intermediate frames for an @@ -98,7 +104,6 @@ def next_scan(it): interm_frames.append(frame) return None, interm_frames - class HkDataField: """ Class containing HK Data for a single field. @@ -271,6 +276,13 @@ def __init__(self, files, book_id, hk_fields: Dict, else: self.log = log + if len(self.files) == 0: + if self.require_acu or self.require_hwp: + raise NoHKFiles("No HK files specified for book") + self.log.warning("No HK files found for book") + for fld in ['az', 'el', 'boresight', 'corotator_enc','az_mode', 'hwp_freq']: + setattr(self.hkdata, fld, None) + if self.require_acu and self.hkdata.az is None: self.log.warning("No ACU data specified in hk_fields!") @@ -480,7 +492,6 @@ def add_acu_summary_info(self, frame, t0, t1): if k not in frame: frame[k] = np.nan - class SmurfStreamProcessor: def __init__(self, obs_id, files, book_id, readout_ids, log=None, allow_bad_timing=False): @@ -757,10 +768,10 @@ def bind(self, outdir, times, frame_idxs, file_idxs, pbar=False, ancil=None, if pbar.n >= pbar.total: pbar.close() - class BookBinder: """ - Class for combining smurf and hk L2 data to create books. + Class for combining smurf and hk L2 data to create books containing detector + timestreams. Parameters ---------- @@ -820,12 +831,12 @@ def __init__(self, book, obsdb, filedb, data_root, readout_ids, outdir, hk_field self.data_root = data_root self.hk_root = os.path.join(data_root, 'hk') self.meta_root = os.path.join(data_root, 'smurf') - self.hkfiles = get_hk_files(self.hk_root, - book.start.timestamp(), - book.stop.timestamp()) + self.obsdb = obsdb self.outdir = outdir + assert book.schema==0, "obs/oper books only have schema=0" + self.max_samps_per_frame = max_samps_per_frame self.max_file_size = max_file_size self.ignore_tags = ignore_tags @@ -851,6 +862,24 @@ def __init__(self, book, obsdb, filedb, data_root, readout_ids, outdir, hk_field logfile = os.path.join(outdir, 'Z_bookbinder_log.txt') self.log = setup_logger(logfile) + try: + self.hkfiles = get_hk_files( + self.hk_root, + book.start.timestamp(), + book.stop.timestamp() + ) + except NoHKFiles as e: + if require_hwp or require_acu: + self.log.error( + "HK files are required if we require ACU or HWP data" + ) + raise e + self.log.warning( + "Found no HK files during book time, binding anyway because " + "require_acu and require_hwp are False" + ) + self.hkfiles = [] + self.ancil = AncilProcessor( self.hkfiles, book.bid, @@ -967,6 +996,34 @@ def copy_smurf_files_to_book(self): self.meta_files = meta_files + def write_M_files(self, telescope, tube_config): + # write M_book file + m_book_file = os.path.join(self.outdir, "M_book.yaml") + book_meta = {} + book_meta["book"] = { + "type": self.book.type, + "schema_version": self.book.schema, + "book_id": self.book.bid, + "finalized_at": dt.datetime.utcnow().isoformat(), + } + book_meta["bookbinder"] = { + "codebase": sotodlib.__file__, + "version": sotodlib.__version__, + # leaving this in but KH doesn't know what it's supposed to be for + "context": "unknown", + } + with open(m_book_file, "w") as f: + yaml.dump(book_meta, f) + + mfile = os.path.join(self.outdir, "M_index.yaml") + with open(mfile, "w") as f: + yaml.dump( + self.get_metadata( + telescope=telescope, + tube_config=tube_config, + ), f + ) + def get_metadata(self, telescope=None, tube_config={}): """ Returns metadata dict for the book 
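The new ``write_M_files`` step is intended to run after binding, writing
``M_book.yaml`` and ``M_index.yaml`` into the book directory.  A hedged sketch
of the call sequence (argument values are placeholders)::

    bb = BookBinder(book, obsdb, filedb, data_root, readout_ids, outdir, hk_fields)
    bb.bind()
    bb.write_M_files(telescope="sat1", tube_config={})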
@@ -1091,6 +1148,108 @@ def bind(self, pbar=False): self.log.info("Finished binding data. Exiting.") return True +class TimeCodeBinder: + """Class for building the timecode based books, smurf, stray, and hk books. + These books are built primarily just by copying specified files from level + 2 locations to new locations at level 2. + """ + + def __init__( + self, book, timecode, indir, outdir, file_list=None, + ignore_pattern=None, + ): + self.book = book + self.timecode = timecode + self.indir = indir + self.outdir = outdir + self.file_list = file_list + if ignore_pattern is not None: + self.ignore_pattern = ignore_pattern + else: + self.ignore_pattern = [] + + if book.type == 'smurf' and book.schema > 0: + self.compress_output = True + else: + self.compress_output = False + + def get_metadata(self, telescope=None, tube_config={}): + return { + "book_id": self.book.bid, + # dummy start and stop times + "start_time": float(self.timecode) * 1e5, + "stop_time": (float(self.timecode) + 1) * 1e5, + "telescope": telescope, + "type": self.book.type, + } + + def write_M_files(self, telescope, tube_config): + # write M_book file + + book_meta = {} + book_meta["book"] = { + "type": self.book.type, + "schema_version": self.book.schema, + "book_id": self.book.bid, + "finalized_at": dt.datetime.utcnow().isoformat(), + } + book_meta["bookbinder"] = { + "codebase": sotodlib.__file__, + "version": sotodlib.__version__, + # leaving this in but KH doesn't know what it's supposed to be for + "context": "unknown", + } + if self.compress_output: + with ZipFile(self.outdir, mode='a') as zf: + zf.writestr("M_book.yaml", yaml.dump(book_meta)) + else: + m_book_file = os.path.join(self.outdir, "M_book.yaml") + with open(m_book_file, "w") as f: + yaml.dump(book_meta, f) + + index = self.get_metadata( + telescope=telescope, + tube_config=tube_config, + ) + if self.compress_output: + with ZipFile(self.outdir, mode='a') as zf: + zf.writestr("M_index.yaml", yaml.dump(index)) + else: + mfile = os.path.join(self.outdir, "M_index.yaml") + with open(mfile, "w") as f: + yaml.dump(index, f) + + def bind(self, pbar=False): + if self.compress_output: + if self.file_list is None: + self.file_list = walk_files(self.indir, include_suprsync=True) + ignore = shutil.ignore_patterns(*self.ignore_pattern) + to_ignore = ignore("", self.file_list) + self.file_list = sorted( + [f for f in self.file_list if f not in to_ignore] + ) + with ZipFile(self.outdir, mode='x') as zf: + for f in self.file_list: + relpath = os.path.relpath(f, self.indir) + zf.write(f, arcname=relpath) + elif self.file_list is None: + shutil.copytree( + self.indir, + self.outdir, + ignore=shutil.ignore_patterns( + *self.ignore_pattern, + ), + ) + else: + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + for f in self.file_list: + relpath = os.path.relpath(f, self.indir) + path = os.path.join(self.outdir, relpath) + base, _ = os.path.split(path) + if not os.path.exists(base): + os.makedirs(base) + shutil.copy(f, os.path.join(self.outdir, relpath)) def fill_time_gaps(ts): """ @@ -1132,7 +1291,6 @@ def fill_time_gaps(ts): return new_ts, ~m - _primary_idx_map = {} def get_frame_times(frame, allow_bad_timing=False): """ @@ -1172,7 +1330,6 @@ def get_frame_times(frame, allow_bad_timing=False): ## don't change this error message. 
used in Imprinter CLI raise TimingSystemOff("Timing counters not incrementing") - def split_ts_bits(c): """ Split up 64 bit to 2x32 bit @@ -1183,7 +1340,6 @@ def split_ts_bits(c): b = c & MAXINT return a, b - def counters_to_timestamps(c0, c2): s, ns = split_ts_bits(c2) @@ -1193,7 +1349,6 @@ def counters_to_timestamps(c0, c2): ts = np.round(c2 - (c0 / 480000) ) + c0 / 480000 return ts - def find_ref_idxs(refs, vs): """ Creates a mapping from a list of timestamps (vs) to a list of reference @@ -1251,8 +1406,10 @@ def get_hk_files(hkdir, start, stop, tbuff=10*60): m = (start-tbuff <= file_times) & (file_times < stop+tbuff) if not np.any(m): check = np.where( file_times <= start ) - if len(check) < 1: - raise ValueError("Cannot find HK files we need") + if len(check) < 1 or len(check[0]) < 1: + raise NoHKFiles( + f"Cannot find HK files between {start} and {stop}" + ) fidxs = [check[0][-1]] m[fidxs] = 1 else: @@ -1371,7 +1528,6 @@ def find_frame_splits(ancil, t0=None, t1=None): idxs = locate_scan_events(az.times[msk], az.data[msk], filter_window=100) return az.times[msk][idxs] - def get_smurf_files(obs, meta_path, all_files=False): """ Returns a list of smurf files that should be copied into a book. diff --git a/sotodlib/io/datapkg_completion.py b/sotodlib/io/datapkg_completion.py new file mode 100644 index 000000000..60292bbc7 --- /dev/null +++ b/sotodlib/io/datapkg_completion.py @@ -0,0 +1,715 @@ +import os +import yaml +import logging +import shutil +import numpy as np +import datetime as dt +from sqlalchemy import or_, and_, not_ +from collections import OrderedDict + +from .load_smurf import ( + TimeCodes, + SupRsyncType, + Finalize, + SmurfStatus, + logger as smurf_log +) +from .imprinter import ( + Books, + Imprinter, + BOUND, + UNBOUND, + UPLOADED, + FAILED, + WONT_BIND, + DONE, + SMURF_EXCLUDE_PATTERNS, +) +import sotodlib.io.imprinter_utils as utils +from .imprinter_cli import autofix_failed_books +from .datapkg_utils import walk_files, just_suprsync + +from .bookbinder import log as book_logger + +def combine_loggers(imprint, fname=None): + log_list = [imprint.logger, smurf_log, book_logger] + logger = logging.getLogger("DataPackaging") + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + '%(levelname)s - %(name)s - %(message)s' + ) + # Create a file handler + if fname is not None: + handler = logging.FileHandler(fname) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + [l.addHandler(handler) for l in log_list] + + # Create a stream handler to print logs to the console + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) # You can set the desired log level for console output + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + return logger + +class DataPackaging: + def __init__(self, platform, log_filename=None): + self.platform = platform + self.imprint = Imprinter.for_platform(platform) + self.logger = combine_loggers(self.imprint, fname=log_filename) + self.session = self.imprint.get_session() + if self.imprint.build_det: + self.g3session, self.SMURF = self.imprint.get_g3tsmurf_session(return_archive=True) + else: + self.g3session = None + self.SMURF = None + self.HK = self.imprint.get_g3thk() + + def get_first_timecode_on_disk(self, include_hk=True): + tc = 50000 + if self.imprint.build_det: + tc = min([ + tc, + int(sorted(os.listdir(self.SMURF.meta_path))[0]), + int(sorted(os.listdir(self.SMURF.archive_path))[0]), + ]) + if include_hk: + tc = min([ + tc, 
+ int(sorted(os.listdir(self.HK.hkarchive_path))[0]) + ]) + if tc == 50000: + raise ValueError(f"Found no timecode folders for {self.platform}") + return tc + + def get_first_timecode_in_staged(self, include_hk=True): + q = self.session.query(Books).filter( + Books.status == UPLOADED, + ) + if not include_hk: + q = q.filter(Books.type != 'hk') + first = q.order_by(Books.start).first() + tc = int( first.start.timestamp() // 1e5) + return tc + + def all_files_in_timecode(self, timecode, include_hk=True): + flist = [] + if self.imprint.build_det: + stc = os.path.join(self.SMURF.meta_path, str(timecode)) + flist.extend(walk_files(stc, include_suprsync=True)) + ttc = os.path.join(self.SMURF.archive_path, str(timecode)) + flist.extend(walk_files(ttc, include_suprsync=True)) + if include_hk: + htc = os.path.join(self.HK.hkarchive_path, str(timecode)) + flist.extend(walk_files(htc, include_suprsync=True)) + return flist + + def get_suprsync_files(self, timecode): + if not self.imprint.build_det: + return [] + stc = os.path.join(self.SMURF.meta_path, str(timecode)) + ttc = os.path.join(self.SMURF.archive_path, str(timecode)) + flist = [] + + if not os.path.exists(stc) and not os.path.exists(ttc): + return flist + if os.path.exists(ttc) and 'suprsync' in os.listdir(ttc): + for root, _, files in os.walk(os.path.join(ttc, 'suprsync')): + for name in files: + flist.append(os.path.join(ttc, root, name)) + if os.path.exists(stc) and 'suprsync' in os.listdir(stc): + for root, _, files in os.walk(os.path.join(stc, 'suprsync')): + for name in files: + flist.append(os.path.join(stc, root, name)) + return flist + + def check_hk_registered(self, timecode, complete): + min_ctime = timecode*1e5 + max_ctime = (timecode+1)*1e5 + + self.HK.add_hkfiles( + min_ctime=min_ctime, max_ctime=max_ctime, + show_pb=False, update_last_file=False, + ) + self.imprint.register_hk_books( + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + # check the hk book is registered + book = self.session.query(Books).filter( + Books.bid == f"hk_{timecode}_{self.platform}" + ).one_or_none() + if book is None: + complete[0] = False + complete[1] += f"HK book hk_{timecode}_{self.platform} missing\n" + elif book.status == UNBOUND: + try: + self.imprint.bind_book(book) + except: + self.logger.warning(f"Failed to bind {book.bid}") + if book.status < BOUND: + complete[0] = False + complete[1] += f"Book hk_{timecode}_{self.platform} not bound" + return complete + + def make_timecode_complete( + self, timecode, try_binding_books=True, try_single_obs=True, + include_hk=True, + ): + """ + Carefully go through an entire timecode and check that the data packaging as + complete as it can be. The verification will also try and fix any errors + found in the system. Updating databases, registering books, and binding + books if try_binding_books is True + + Arguments + ---------- + timecode: int + 5-digit ctime to check for completion + try_binding_books: bool + if true, go through and try to bind any newly registered books + try_single_obs: bool + if true, tries to register any missing observations as single wafer + observations if registering as multi-wafer observations fails. 
This + happens sometimes if the stream lengths are very close to the + minimum overlap time + include_hk: bool + if true, also checkes everything related to hk + """ + + complete = [True, ""] + min_ctime = timecode*1e5 + max_ctime = (timecode+1)*1e5 + + if not self.imprint.build_det: + ## no detector data tracked by imprinter + if include_hk: + return self.check_hk_registered(timecode, complete) + else: + self.logger.warning( + f"No detector data built for platform " + f"{self.imprint.daq_node} and not checking HK. Nothing to " + "check for completion" + ) + return complete + + has_smurf, has_timestreams = True, True + stc = os.path.join(self.SMURF.meta_path, str(timecode)) + ttc = os.path.join(self.SMURF.archive_path, str(timecode)) + + if not os.path.exists(stc): + self.logger.debug(f"TC {timecode}: No level 2 smurf folder") + has_smurf = False + if not os.path.exists(ttc): + self.logger.debug(f"TC {timecode}: No level 2 timestream folder") + has_timestreams = False + + if os.path.exists(ttc) and just_suprsync(ttc): + self.logger.info( + f"TC {timecode}: Level 2 timestreams is only suprsync" + ) + has_timestreams = False + if os.path.exists(stc) and just_suprsync(stc): + self.logger.info(f"TC {timecode}: Level 2 smurf is only suprsync") + has_smurf = False + + if not has_smurf and not has_timestreams: + return complete + if not has_smurf and has_timestreams: + self.logger.error(f"TC {timecode}: Has timestreams folder without smurf!") + + overall_final_ctime = self.SMURF.get_final_time( + self.imprint.all_slots, check_control=False + ) + tcode_limit = int(overall_final_ctime//1e5) + if timecode+1 > tcode_limit: + raise ValueError( + f"We cannot check files from {timecode} because finalization time " + f"is {overall_final_ctime}" + ) + + self.logger.info(f"Checking Timecode {timecode} for completion") + ## check for files on disk to be in database + missing_files = self.SMURF.find_missing_files( + timecode, session=self.g3session + ) + if len(missing_files) > 0: + self.logger.warning( + f"{len(missing_files)} files not in G3tSmurf" + ) + self.SMURF.index_metadata( + min_ctime=min_ctime, + max_ctime=max_ctime, + session=self.g3session + ) + self.SMURF.index_archive( + min_ctime=min_ctime, + max_ctime=max_ctime, + show_pb=False, + session=self.g3session + ) + self.SMURF.index_timecodes( + min_ctime=min_ctime, + max_ctime=max_ctime, + session=self.g3session + ) + still_missing = len( + self.SMURF.find_missing_files(timecode, session=self.g3session) + ) + if still_missing>0: + msg = f"{still_missing} file(s) were not able to be added to the " \ + "G3tSmurf database." + self.logger.error(msg) + complete[0] = False + complete[1] += msg+"\n" + else: + self.logger.debug("All files on disk are in G3tSmurf database") + + ## check for level 2 files to be assigned to level 2 observations + missing_obs = self.SMURF.find_missing_files_from_obs( + timecode, session=self.g3session + ) + if len(missing_obs) > 0: + msg = f"{len(missing_obs)} files not assigned lvl2 obs" + no_tags = 0 + for fpath in missing_obs: + if fpath[-6:] != "000.g3": + msg += f"\n{fpath} was not added to a larger observation." \ + " Will be fixed later if possible." + else: + status = SmurfStatus.from_file(fpath) + if len(status.tags)==0: + no_tags += 1 + else: + msg += f"\Trying to add {fpath} to database" + self.SMURF.add_file( + fpath, self.g3session, overwrite=True + ) + if no_tags > 0: + msg += f"\n{no_tags} of the files have no tags, so these should "\ + "not be observations." 
+ self.logger.warning(msg) + + ## if the stray book has already been bound then we cannot add + ## more detector books without causing problems + add_new_detector_books = True + stray_book = self.session.query(Books).filter( + Books.bid == f"stray_{timecode}_{self.platform}", + Books.status >= BOUND, + ).one_or_none() + if stray_book is not None: + add_new_detector_books = False + + ## check for incomplete observations + ## add time to max_ctime to account for observations on the edge + incomplete = self.imprint._find_incomplete( + min_ctime, max_ctime+24*2*3600 + ) + if incomplete.count() > 0: + ic_list = incomplete.all() + """Check if these are actually incomplete, imprinter incomplete checker + includes making sure the stop isn't beyond max ctime. + """ + obs_list = [] + for obs in ic_list: + if obs.stop is None or obs.timestamp <= max_ctime: + obs_list.append(obs) + + ## complete these no matter what for file tracking / deletion + self.logger.warning( + f"Found {len(obs_list)} incomplete observations. Fixing" + ) + for obs in obs_list: + self.logger.debug(f"Updating {obs}") + self.SMURF.update_observation_files( + obs, + self.g3session, + force=True, + ) + + ## make sure all obs / operation books from this period are registered + ## looks like short but overlapping observations are sometimes missed, + ## use `try_single_obs` flag to say if we want to try and clean those up + missing = self.imprint.find_missing_lvl2_obs_from_books( + min_ctime,max_ctime + ) + if add_new_detector_books and len(missing) > 0: + self.logger.info( + f"{len(missing)} lvl2 observations are not registered in books." + " Trying to register them" + ) + ## add time to max_ctime to account for observations on the edge + self.imprint.update_bookdb_from_g3tsmurf( + min_ctime=min_ctime, max_ctime=max_ctime+24*2*3600, + ) + still_missing = self.imprint.find_missing_lvl2_obs_from_books( + min_ctime,max_ctime + ) + if len(still_missing) > 0 and try_single_obs: + self.logger.warning("Trying single stream registration") + self.imprint.update_bookdb_from_g3tsmurf( + min_ctime=min_ctime, max_ctime=max_ctime+24*2*3600, + force_single_stream=True, + ) + still_missing = self.imprint.find_missing_lvl2_obs_from_books( + min_ctime,max_ctime + ) + if len(still_missing) > 0: + msg = f"Level 2 observations {still_missing} could not be " \ + "registered in books" + self.logger.error(msg) + complete[0] = False + complete[1] += msg+"\n" + elif not add_new_detector_books and len(missing)>0: + msg = f"Have level 2 observations missing but cannot add new " \ + f"detector books because {timecode} was already finalized " \ + " and stray exists. These files should be in stray" + self.logger.warning(msg) + + ## at this point, if an obs or oper book is going to be registered it is + if try_binding_books: + books = self.session.query(Books).filter( + Books.status == UNBOUND, + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start <= dt.datetime.utcfromtimestamp(max_ctime), + ).all() + self.logger.info(f"{len(books)} new books to bind") + for book in books: + try: + self.imprint.bind_book(book) + except: + self.logger.warning(f"Failed to bind {book.bid}") + + failed = self.session.query(Books).filter( + Books.status == FAILED, + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start <= dt.datetime.utcfromtimestamp(max_ctime), + ).all() + if len(failed) > 0: + self.logger.info( + f"{len(failed)} books failed to bind. 
trying to autofix" + ) + autofix_failed_books( + self.imprint, + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + + is_final, reason = utils.get_timecode_final(self.imprint, timecode) + if not is_final: + self.logger.info( + f"Timecode {timecode} not counted as final: reason {reason}" + ) + meta_entries = self.g3session.query(TimeCodes).filter( + TimeCodes.timecode == timecode, + TimeCodes.suprsync_type == SupRsyncType.META.value, + ).count() + file_entries = self.g3session.query(TimeCodes).filter( + TimeCodes.timecode == timecode, + TimeCodes.suprsync_type == SupRsyncType.FILES.value, + ).count() + if ( + meta_entries == len(self.imprint.all_slots) and + file_entries == len(self.imprint.all_slots) + ): + self.logger.info( + f"{timecode} was part of the mixed up timecode agent entries" + ) + elif timecode < tcode_limit: + self.logger.info( + f"At least one server was likely off during timecode {timecode}" + ) + self.logger.info( + f"Setting timecode {timecode} to final in SMuRF database" + ) + utils.set_timecode_final(self.imprint, timecode) + + self.imprint.register_timecode_books( + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + + if try_binding_books: + books = self.session.query(Books).filter( + Books.status == UNBOUND, + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start <= dt.datetime.utcfromtimestamp(max_ctime), + ).all() + self.logger.info(f"{len(books)} new to bind") + for book in books: + try: + self.imprint.bind_book(book) + except: + self.logger.warning(f"Failed to bind {book.bid}") + + # check the smurf book is registered + book = self.session.query(Books).filter( + Books.bid == f"smurf_{timecode}_{self.platform}" + ).one_or_none() + if book is None: + complete[0] = False + complete[1] += f"SMuRF book smurf_{timecode}_{self.platform} missing\n" + if include_hk: + complete = self.check_hk_registered(timecode, complete) + + # check if there's a stray book + stray = self.session.query(Books).filter( + Books.bid == f"stray_{timecode}_{self.platform}" + ).one_or_none() + if stray is None and try_binding_books: + # all files should be in obs/oper books + flist = self.imprint.get_files_for_stray_book( + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + if len(flist) > 0: + complete[0] = False + complete[1] += f"Stray book stray_{timecode}_{self.platform} missing\n" + elif stray is None and not try_binding_books: + my_list = self.imprint.get_files_for_stray_book( + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + if len(my_list) > 0: + self.logger.warning( + f"We expect {len(my_list)} books in a stray book but need " + "to bind books to verify" + ) + complete[0] = False + complete[1] += f"Stray book stray_{timecode}_{self.platform} missing\n" + else: + flist = self.imprint.get_files_for_book(stray) + my_list = self.imprint.get_files_for_stray_book( + min_ctime=min_ctime, + max_ctime=max_ctime, + ) + assert np.all( + sorted(flist) == sorted(my_list) + ), "logic error somewhere" + ## check that all books are bound + books = self.session.query(Books).filter( + or_(Books.status == UNBOUND, Books.status == FAILED), + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start <= dt.datetime.utcfromtimestamp(max_ctime), + ).count() + if books != 0: + complete[0] = False + complete[1] += f"Have {books} unbound or failed books in timecode \n" + return complete + + def books_in_timecode( + self, timecode, include_wont_fix=False, include_hk=True + ): + min_ctime = timecode*1e5 + max_ctime = (timecode+1)*1e5 + + q = self.session.query(Books).filter( + 
Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start < dt.datetime.utcfromtimestamp(max_ctime), + ) + if not include_wont_fix: + q = q.filter(Books.status != WONT_BIND) + if not include_hk: + q = q.filter(Books.type != 'hk') + return q.all() + + def file_list_from_database( + self, timecode, deletable, verify_with_librarian, include_hk=True, + ): + file_list = [] + min_ctime = timecode*1e5 + max_ctime = (timecode+1)*1e5 + + q = self.session.query(Books).filter( + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start < dt.datetime.utcfromtimestamp(max_ctime), + ) + not_ready = q.filter( not_(or_( + Books.status == WONT_BIND, Books.status >= UPLOADED) + )).count() + if not_ready > 0: + self.logger.error( + f"There are {not_ready} non-uploaded books in this timecode" + ) + deletable[0] = False + deletable[1] += f"There are {not_ready} non-uploaded books in " \ + "this timecode\n" + if not include_hk: + q = q.filter(Books.type != 'hk') + book_list = q.filter(Books.status >= UPLOADED).all() + self.logger.debug( + f"Found {len(book_list)} books in time code {timecode}" + ) + + for book in book_list: + if book.lvl2_deleted: + continue + if verify_with_librarian: + in_lib = self.imprint.check_book_in_librarian( + book, n_copies=1, raise_on_error=False + ) + if not in_lib: + deletable[0] = False + deletable[1] += f"{book.bid} has not been uploaded to librarain\n" + + flist = self.imprint.get_files_for_book(book) + if isinstance(flist, OrderedDict): + x = [] + for k in flist: + x.extend(flist[k]) + flist=x + file_list.extend(flist) + # add suprsync files + file_list.extend( self.get_suprsync_files(timecode) ) + return file_list, deletable + + def verify_timecode_deletable( + self, timecode, verify_with_librarian=True, include_hk=True, + ): + """ + Checkes that all books in that timecode are uploaded to the librarian + and that there is a copy offsite (if verify_with_librarian=True) + + Steps for checking: + + 1. Walk the file system and build up a list of all files there + 2. Go book by book within timecode and build up the list of level 2 + files that went into it using the databases. Add any files in suprsync + folders into this list since they aren't book bound but we'd like them + to be deleted + 3. Compare the two lists and make sure they're the same. + """ + deletable = [True, ""] + + files_on_disk = self.all_files_in_timecode( + timecode, include_hk=include_hk + ) + if len(files_on_disk) == 0: + return deletable + # these are files that are in the smurf directory but we don't save in the + # smurf books. 
mostly watching out for .dat files + ignore = shutil.ignore_patterns(*SMURF_EXCLUDE_PATTERNS) + ignored_files = ignore("", files_on_disk) + self.logger.debug( + f"Timecode {timecode} has {len(ignored_files)} ignored files" + ) + files_in_database, deletable = self.file_list_from_database( + timecode, deletable, verify_with_librarian, include_hk=include_hk + ) + + missed_files = [] + extra_files = [] + for f in files_on_disk: + if f not in files_in_database and f not in ignored_files: + missed_files.append(f) + for f in files_in_database: + if f not in files_on_disk: + extra_files.append(f) + if len(missed_files) == 0 and len(extra_files) == 0: + self.logger.info(f"Timecode {timecode} has complete coverage") + if len(missed_files)>0: + msg = f"Files on disk but not in database {len(missed_files)}:\n" + for f in missed_files: + msg += f"\t{f}\n" + self.logger.warning(msg) + deletable[0] = False + deletable[1] += msg + if len(extra_files)>0: + msg = f"Files in database but not on disk: {extra_files}" + for f in missed_files: + msg += f"\t{f}\n" + self.logger.error(msg) + deletable[0] = False + deletable[1] += msg + return deletable + + def delete_timecode_level2( + self, timecode, dry_run=True, include_hk=True, + verify_with_librarian=True, + ): + book_list = self.books_in_timecode(timecode, include_hk=include_hk) + books_not_deleted = [] + + for book in book_list: + stat = self.imprint.delete_level2_files( + book, verify_with_librarian=verify_with_librarian, + n_copies_in_lib=2, dry_run=dry_run + ) + if stat > 0: + books_not_deleted.append(book) + + if len(books_not_deleted) > 0: + msg = "Could not delete level 2 for books:\n" + for book in books_not_deleted: + msg += f'\t{book.bid}\n' + self.logger.error(msg) + return False, "" + return True, "" + + + def delete_timecode_staged( + self, timecode, include_hk=True, verify_with_librarian=False, + check_level2=False, + ): + book_list = self.books_in_timecode(timecode, include_hk=include_hk) + books_not_deleted = [] + for book in book_list: + stat = self.imprint.delete_book_staged( + book, check_level2=check_level2, + verify_with_librarian=verify_with_librarian + ) + if stat > 0: + books_not_deleted.append(book) + # cleanup + for tube in self.imprint.tubes: + for btype in ['obs', 'oper']: + path = os.path.join( + self.imprint.output_root, tube, btype, str(timecode) + ) + if os.path.exists(path) and len(os.listdir(path))==0: + os.rmdir(path) + + if len(books_not_deleted) > 0: + msg = "Could not delete stages for books:\n" + for book in books_not_deleted: + msg += f'\t{book.bid}\n' + self.logger.error(msg) + return False, "msg" + return True, "" + + def check_and_delete_timecode( + self, timecode, include_hk=True, verify_with_librarian=True + ): + check = self.make_timecode_complete(timecode, include_hk=include_hk) + if not check[0]: + self.logger.error(f"Timecode {timecode} not complete") + self.logger.error(check[1]) + return check + check = self.verify_timecode_deletable( + timecode, include_hk=include_hk, + verify_with_librarian=False, + ) + if not check[0]: + self.logger.error(f"Timecode {timecode} not ready to delete") + self.logger.error(check[1]) + return check + + check = self.delete_timecode_level2( + timecode, dry_run=False, include_hk=include_hk, + verify_with_librarian=verify_with_librarian, + ) + + if not self.imprint.build_det: + return check + stc = os.path.join(self.SMURF.meta_path, str(timecode)) + ttc = os.path.join(self.SMURF.archive_path, str(timecode)) + + if os.path.exists(stc): + if len(os.listdir(stc)) == 0 or 
just_suprsync(stc): + shutil.rmtree(stc) + if os.path.exists(ttc): + if len(os.listdir(ttc)) == 0 or just_suprsync(ttc): + shutil.rmtree(ttc) + return check diff --git a/sotodlib/io/datapkg_utils.py b/sotodlib/io/datapkg_utils.py index ef26cc264..c28533bc8 100644 --- a/sotodlib/io/datapkg_utils.py +++ b/sotodlib/io/datapkg_utils.py @@ -63,3 +63,35 @@ def get_imprinter_config( platform, env_file=None, env_var="DATAPKG_ENV"): raise ValueError(f"configs not found in tags {tags}") return os.path.join( tags['configs'], platform, 'imprinter.yaml') + +def walk_files(path, include_suprsync=False): + """get a list of the files in a timecode folder, optional flag to ignore + suprsync files + + Arguments + ---------- + path: path to a level 2 timecode folder, either smurf or timestreams + include_suprsync: optional, bool + if true, includes the suprsync files in the returned list + + Returns + -------- + files (list): list of the absolute paths to all files in a timecode folder + """ + if not os.path.exists(path): + return [] + flist = [] + for root, _, files in os.walk(path): + if not include_suprsync and 'suprsync' in root: + continue + for f in files: + flist.append( os.path.join(path, root, f)) + return flist + +def just_suprsync(path): + """check if timecode folder only has suprsync folder in it + """ + flist = os.listdir( path ) + if len(flist) == 1 and flist[0] == "suprsync": + return True + return False diff --git a/sotodlib/io/g3thk_db.py b/sotodlib/io/g3thk_db.py index 743f28a69..08829da0d 100644 --- a/sotodlib/io/g3thk_db.py +++ b/sotodlib/io/g3thk_db.py @@ -13,6 +13,7 @@ import logging from .datapkg_utils import load_configs + logger = logging.getLogger(__name__) Base = declarative_base() @@ -126,7 +127,7 @@ class HKFields(Base): class G3tHk: - def __init__(self, hkarchive_path, db_path=None, echo=False): + def __init__(self, hkarchive_path, iids, db_path=None, echo=False): """ Class to manage a housekeeping data archive @@ -134,6 +135,8 @@ def __init__(self, hkarchive_path, db_path=None, echo=False): ____ hkarchive_path : path Path to the data directory + iids : list + List of agent instance ids db_path : path, optional Path to the sqlite file echo : bool, optional @@ -144,13 +147,15 @@ def __init__(self, hkarchive_path, db_path=None, echo=False): self.hkarchive_path = hkarchive_path self.db_path = db_path + self.iids = iids self.engine = db.create_engine(f"sqlite:///{db_path}", echo=echo) Session.configure(bind=self.engine) self.Session = sessionmaker(bind=self.engine) self.session = Session() Base.metadata.create_all(self.engine) - def load_fields(self, hk_path): + + def load_fields(self, hk_path, iids): """ Load fields from .g3 file and start and end time for each field. 
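The ``iids`` filtering added to ``load_fields`` keeps only the HK fields whose keys mention one of the configured agent instance ids. A minimal sketch of that substring match, assuming field keys of the usual ``observatory.<instance-id>.feeds.<feed>.<field>`` form and using made-up instance ids:

    iids = ["hwp-pid", "bluefors"]
    field_keys = [
        "observatory.hwp-pid.feeds.hwppid.current",
        "observatory.faro.feeds.position.x",
    ]
    # keep a field only when one of the configured instance ids appears in its key
    kept = [key for key in field_keys if any(iid in key for iid in iids)]
    assert kept == ["observatory.hwp-pid.feeds.hwppid.current"]
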
@@ -167,14 +172,15 @@ def load_fields(self, hk_path): # enact HKArchiveScanner hkas = hk.HKArchiveScanner() hkas.process_file(hk_path) - + arc = hkas.finalize() # get fields from .g3 file fields, timelines = arc.get_fields() hkfs = [] for key in fields.keys(): - hkfs.append(key) + if any(iid in key for iid in iids): + hkfs.append(key) starts = [] stops = [] @@ -346,14 +352,14 @@ def add_agents_and_fields(self, path, overwrite=False): .one() ) - db_agents = db_file.agents + # line below may not be needed; is redundant + db_agents = [a for a in db_file.agents if a.instance_id in self.iids] db_fields = db_file.fields - agents = [] - - out = self.load_fields(db_file.path) + out = self.load_fields(db_file.path, self.iids) fields, starts, stops, medians, means, min_vals, max_vals, stds = out + agents = [] for field in fields: agent = field.split(".")[1] agents.append(agent) @@ -531,10 +537,12 @@ def get_last_update(self): .order_by(db.desc(HKFiles.global_start_time)) .first() ) + if len(last_file.agents) == 0: + return last_file.global_start_time return max([a.stop for a in last_file.agents]) @classmethod - def from_configs(cls, configs): + def from_configs(cls, configs, iids=None): """ Create a G3tHK instance from a configs dictionary @@ -545,12 +553,34 @@ def from_configs(cls, configs): if type(configs) == str: configs = load_configs(configs) + if iids is None: + iids = [] + if "finalization" in configs: + servers = configs["finalization"].get("servers", {}) + for server in servers: + for key in server.keys(): + # Append the value (iid) to the iids list + iids.append(server[key]) + else: + logger.debug( + "No finalization information in configuration, agents and " + "fields will not be added." + ) + return cls( - os.path.join(configs["data_prefix"], "hk"), - configs["g3thk_db"] + hkarchive_path = os.path.join(configs["data_prefix"], "hk"), + db_path = configs["g3thk_db"], + iids = iids ) - def delete_file(self, hkfile, dry_run=False, my_logger=None): + def batch_delete_files(self, file_list, dry_run=False, my_logger=None): + for f in file_list: + self.delete_file( + f, dry_run=dry_run, my_logger=my_logger, commit=False + ) + self.session.commit() + + def delete_file(self, hkfile, dry_run=False, my_logger=None, commit=True): """WARNING: Removes actual files from file system. Delete an hkfile instance, its on-disk file, and all associated agents @@ -565,13 +595,13 @@ def delete_file(self, hkfile, dry_run=False, my_logger=None): my_logger = logger # remove field info - my_logger.info(f"removing field entries for {hkfile.path} from database") + my_logger.debug(f"removing field entries for {hkfile.path} from database") if not dry_run: for f in hkfile.fields: self.session.delete(f) # remove agent info - my_logger.info(f"removing agent entries for {hkfile.path} from database") + my_logger.debug(f"removing agent entries for {hkfile.path} from database") if not dry_run: for a in hkfile.agents: self.session.delete(a) @@ -590,4 +620,5 @@ def delete_file(self, hkfile, dry_run=False, my_logger=None): my_logger.info(f"remove {hkfile.path} from database") if not dry_run: self.session.delete(hkfile) - self.session.commit() + if commit: + self.session.commit() diff --git a/sotodlib/io/hkdb.py b/sotodlib/io/hkdb.py index 999648e7f..68ceea631 100644 --- a/sotodlib/io/hkdb.py +++ b/sotodlib/io/hkdb.py @@ -356,12 +356,17 @@ class LoadSpec: End time to load downsample_factor: int Downsample factor for data + hkdb: Optional[HkDb] + HkDb instance to use. If not specified, will create a new one from the + cfg. 
This should be set manually if you are calling ``load_hk`` in a loop + to prevent connection build-up. """ cfg: HkConfig fields: List[str] start: float end: float downsample_factor: int = 1 + hkdb: Optional[HkDb] = None def __post_init__(self): fs = [] @@ -418,7 +423,11 @@ def load_hk(load_spec: Union[LoadSpec, dict], show_pb=False): if isinstance(load_spec, dict): load_spec = LoadSpec(**load_spec) - hkdb = HkDb(load_spec.cfg) + if load_spec.hkdb is not None: + hkdb: HkDb = load_spec.hkdb + else: + hkdb = HkDb(load_spec.cfg) + agent_set = list(set(f.agent for f in load_spec.fields)) file_spec = {} # {path: [offsets]} diff --git a/sotodlib/io/imprinter.py b/sotodlib/io/imprinter.py index 22fb9f84c..0a5a29ecb 100644 --- a/sotodlib/io/imprinter.py +++ b/sotodlib/io/imprinter.py @@ -17,7 +17,7 @@ from spt3g import core import sotodlib -from .bookbinder import BookBinder +from .bookbinder import BookBinder, TimeCodeBinder from .load_smurf import ( G3tSmurf, Observations as G3tObservations, @@ -49,6 +49,8 @@ # tel tube, stream_id, slot mapping VALID_OBSTYPES = ["obs", "oper", "smurf", "hk", "stray", "misc"] +# file patterns excluded from smurf books +SMURF_EXCLUDE_PATTERNS = ["*.dat", "*_mask.txt", "*_freq.txt"] class BookExistsError(Exception): """Exception raised when a book already exists in the database""" @@ -132,6 +134,7 @@ class Books(Base): timing = db.Column(db.Boolean) path = db.Column(db.String) lvl2_deleted = db.Column(db.Boolean, default=False) + schema = db.Column(db.Integer, default=0) def __repr__(self): return f"" @@ -237,10 +240,12 @@ def __init__(self, im_config=None, db_args={}, logger=None, make_db=False): self.config = load_configs(im_config) self.db_path = self.config.get("db_path") - self.daq_node = self.config.get("daq_node") + self.daq_node = self.config.get("daq_node") self.output_root = self.config.get("output_root") self.g3tsmurf_config = self.config.get("g3tsmurf") - + g3tsmurf_cfg = load_configs(self.g3tsmurf_config) + self.lvl2_data_root = g3tsmurf_cfg["data_prefix"] + self.build_hk = self.config.get("build_hk") self.build_det = self.config.get("build_det") @@ -371,6 +376,7 @@ def register_book(self, obsset, bid=None, commit=True, session=None): if bid is None: bid = obsset.get_id() assert obsset.mode is not None + assert obsset.mode in ['obs','oper'] # check whether book exists in the database if self.book_exists(bid, session=session): raise BookExistsError(f"Book {bid} already exists in the database") @@ -406,6 +412,7 @@ def register_book(self, obsset, bid=None, commit=True, session=None): [s for s in obsset.slots if obsset.contains_stream(s)] ), # not worth having a extra table timing=timing_on, + schema=0, ) book.path = self.get_book_path(book) @@ -444,9 +451,6 @@ def register_hk_books(self, min_ctime=None, max_ctime=None, session=None): session = session or self.get_session() if not self.build_hk: return - - g3tsmurf_cfg = load_configs(self.g3tsmurf_config) - lvl2_data_root = g3tsmurf_cfg["data_prefix"] if min_ctime is None: min_ctime = 16000e5 @@ -454,7 +458,7 @@ def register_hk_books(self, min_ctime=None, max_ctime=None, session=None): max_ctime = 5e10 # all ctime dir except the last ctime dir will be considered complete - ctime_dirs = sorted(glob(op.join(lvl2_data_root, "hk", "*"))) + ctime_dirs = sorted(glob(op.join(self.lvl2_data_root, "hk", "*"))) for ctime_dir in ctime_dirs[:-1]: ctime = op.basename(ctime_dir) if int(ctime) < int(min_ctime//1e5): @@ -473,6 +477,7 @@ def register_hk_books(self, min_ctime=None, max_ctime=None, session=None): 
start=dt.datetime.utcfromtimestamp(int(ctime) * 1e5), stop=dt.datetime.utcfromtimestamp((int(ctime) + 1) * 1e5), tel_tube=self.daq_node, + schema=0, ) book.path = self.get_book_path(book) session.add(book) @@ -498,9 +503,8 @@ def register_timecode_books( smurf books are registered whenever all the relevant metadata timecode entries have been found. stray books are registered when metadata and - file timecode entries exist ASSUMING all obs/oper books in that time + file timecode entries exist AND all obs/oper books in that time range have been bound successfully. - """ if not self.build_det: @@ -508,6 +512,10 @@ def register_timecode_books( session = session or self.get_session() g3session, SMURF = self.get_g3tsmurf_session(return_archive=True) + final_time = SMURF.get_final_time( + self.all_slots, min_ctime, max_ctime, check_control=False + ) + final_tc = int(final_time//1e5) servers = SMURF.finalize["servers"] meta_agents = [s["smurf-suprsync"] for s in servers] files_agents = [s["timestream-suprsync"] for s in servers] @@ -523,6 +531,12 @@ def register_timecode_books( tcs = tcs.distinct().all() for (tc,) in tcs: + if tc >= final_tc: + self.logger.info( + f"Not ready to make timecode books for {tc} because final" + f" timecode is {final_tc}" + ) + continue q = g3session.query(TimeCodes).filter( TimeCodes.timecode == tc, ) @@ -556,6 +570,7 @@ def register_timecode_books( tel_tube=self.daq_node, start=book_start, stop=book_stop, + schema=1, ) smurf_book.path = self.get_book_path(smurf_book) session.add(smurf_book) @@ -580,21 +595,26 @@ def register_timecode_books( ) if q.count() > 0: self.logger.info( - f"Not ready to bind {book_id} due to unbound or " + f"Not ready to register {book_id} due to unbound or " "failed obs/oper books." ) continue - stray_book = Books( - bid=book_id, - type="stray", - status=UNBOUND, - tel_tube=self.daq_node, - start=book_start, - stop=book_stop, + + flist = self.get_files_for_stray_book( + min_ctime= tc * 1e5, + max_ctime= (tc + 1) * 1e5 ) - stray_book.path = self.get_book_path(stray_book) - flist = self.get_files_for_book(stray_book) if len(flist) > 0: + stray_book = Books( + bid=book_id, + type="stray", + status=UNBOUND, + tel_tube=self.daq_node, + start=book_start, + stop=book_stop, + schema=0, + ) + stray_book.path = self.get_book_path(stray_book) self.logger.info(f"registering {book_id}") session.add(stray_book) session.commit() @@ -607,8 +627,6 @@ def get_book_abs_path(self, book): return os.path.join(self.output_root, book_path) def get_book_path(self, book): - g3tsmurf_cfg = load_configs(self.g3tsmurf_config) - lvl2_data_root = g3tsmurf_cfg["data_prefix"] if book.type in ["obs", "oper"]: session_id = book.bid.split("_")[1] @@ -617,11 +635,12 @@ def get_book_path(self, book): return os.path.join(odir, book.bid) elif book.type in ["hk", "smurf"]: # get source directory for hk book - root = op.join(lvl2_data_root, book.type) first5 = book.bid.split("_")[1] assert first5.isdigit(), f"first5 of {book.bid} is not a digit" - odir = op.join(book.tel_tube, book.type) - return os.path.join(odir, book.bid) + odir = op.join(book.tel_tube, book.type, book.bid) + if book.type == 'smurf' and book.schema > 0: + return odir + '.zip' + return odir elif book.type in ["stray"]: first5 = book.bid.split("_")[1] assert first5.isdigit(), f"first5 of {book.bid} is not a digit" @@ -641,8 +660,6 @@ def _get_binder_for_book(self, require_acu=True, ): """get the appropriate bookbinder for the book based on its type""" - g3tsmurf_cfg = load_configs(self.g3tsmurf_config) - 
lvl2_data_root = g3tsmurf_cfg["data_prefix"] if book.type in ["obs", "oper"]: book_path = self.get_book_abs_path(book) @@ -660,7 +677,7 @@ def _get_binder_for_book(self, # bind book using bookbinder library bookbinder = BookBinder( - book, obsdb, filedb, lvl2_data_root, readout_ids, book_path, hk_fields, + book, obsdb, filedb, self.lvl2_data_root, readout_ids, book_path, hk_fields, ignore_tags=ignore_tags, ancil_drop_duplicates=ancil_drop_duplicates, allow_bad_timing=allow_bad_timing, @@ -671,86 +688,43 @@ def _get_binder_for_book(self, elif book.type in ["hk", "smurf"]: # get source directory for hk book - root = op.join(lvl2_data_root, book.type) - first5 = book.bid.split("_")[1] - assert first5.isdigit(), f"first5 of {book.bid} is not a digit" - book_path_src = op.join(root, first5) + root = op.join(self.lvl2_data_root, book.type) + timecode = book.bid.split("_")[1] + assert timecode.isdigit(), f"timecode of {book.bid} is not a digit" + book_path_src = op.join(root, timecode) # get target directory for hk book - odir = op.join(self.output_root, book.tel_tube, book.type) + book_path_tgt = self.get_book_abs_path(book) + odir, _ = op.split(book_path_tgt) if not op.exists(odir): os.makedirs(odir) - book_path_tgt = os.path.join(odir, book.bid) - - class _FakeBinder: # dummy class to mimic baseline bookbinder - def __init__(self, indir, outdir): - self.indir = indir - self.outdir = outdir - - def get_metadata(self, telescope=None, tube_config={}): - return { - "book_id": book.bid, - # dummy start and stop times - "start_time": float(first5) * 1e5, - "stop_time": (float(first5) + 1) * 1e5, - "telescope": telescope, - "type": book.type, - } - - def bind(self, pbar=False): - shutil.copytree( - self.indir, - self.outdir, - ignore=shutil.ignore_patterns( - "*.dat", "*_mask.txt", "*_freq.txt" - ), - ) - - return _FakeBinder(book_path_src, book_path_tgt) + + bookbinder = TimeCodeBinder( + book, timecode, book_path_src, book_path_tgt, + ignore_pattern=SMURF_EXCLUDE_PATTERNS, + ) + return bookbinder elif book.type in ["stray"]: flist = self.get_files_for_book(book) # get source directory for stray book - root = op.join(lvl2_data_root, "timestreams") - first5 = book.bid.split("_")[1] - assert first5.isdigit(), f"first5 of {book.bid} is not a digit" - book_path_src = op.join(root, first5) - - # get target directory for hk book - odir = op.join(self.output_root, book.tel_tube, book.type) + root = op.join(self.lvl2_data_root, "timestreams") + timecode = book.bid.split("_")[1] + assert timecode.isdigit(), f"timecode of {book.bid} is not a digit" + book_path_src = op.join(root, timecode) + + # get target directory for book + book_path_tgt = self.get_book_abs_path(book) + odir, _ = op.split(book_path_tgt) if not op.exists(odir): os.makedirs(odir) - book_path_tgt = os.path.join(odir, book.bid) - - class _FakeBinder: # dummy class to mimic baseline bookbinder - def __init__(self, indir, outdir, file_list): - self.indir = indir - self.outdir = outdir - self.file_list = file_list - - def get_metadata(self, telescope=None, tube_config={}): - return { - "book_id": book.bid, - # dummy start and stop times - "start_time": float(first5) * 1e5, - "stop_time": (float(first5) + 1) * 1e5, - "telescope": telescope, - "type": book.type, - } - - def bind(self, pbar=False): - if not os.path.exists(self.outdir): - os.makedirs(self.outdir) - for f in self.file_list: - relpath = os.path.relpath(f, self.indir) - path = os.path.join(self.outdir, relpath) - base, _ = os.path.split(path) - if not os.path.exists(base): - 
os.makedirs(base) - shutil.copy(f, os.path.join(self.outdir, relpath)) - - return _FakeBinder(book_path_src, book_path_tgt, flist) + + bookbinder = TimeCodeBinder( + book, timecode, book_path_src, book_path_tgt, + file_list=flist, + ) + return bookbinder else: raise NotImplementedError( f"binder for book type {book.type} not implemented" @@ -828,38 +802,13 @@ def bind_book( require_hwp=require_hwp, ) binder.bind(pbar=pbar) - - # write M_book file - m_book_file = os.path.join(binder.outdir, "M_book.yaml") - book_meta = {} - book_meta["book"] = { - "type": book.type, - "schema_version": 0, - "book_id": book.bid, - "finalized_at": dt.datetime.utcnow().isoformat(), - } - book_meta["bookbinder"] = { - "codebase": sotodlib.__file__, - "version": sotodlib.__version__, - "context": self.config.get("context", "unknown"), - } - with open(m_book_file, "w") as f: - yaml.dump(book_meta, f) - + # write M_index file if book.type in ['obs', 'oper']: tc = self.tube_configs[book.tel_tube] else: tc = {} - - mfile = os.path.join(binder.outdir, "M_index.yaml") - with open(mfile, "w") as f: - yaml.dump( - binder.get_metadata( - telescope=self.daq_node, - tube_config = tc, - ), f - ) + binder.write_M_files(self.daq_node, tc) if book.type in ['obs', 'oper']: # check that detectors books were written out correctly @@ -944,65 +893,6 @@ def get_books_by_status(self, status, session=None): return session.query(Books).filter( Books.status == status ).order_by(Books.start).all() - - def get_level2_deleteable_books( - self, session=None, cleanup_delay=None, max_time=None - ): - """Get all bound books from database where we need to delete the level2 - data - - Parameters - ---------- - session: BookDB session - cleanup_delay: float - amount of time to delay book deletation relative to g3tsmurf finalization - time in units of days. - max_time: datetime - maxmimum time of book start to search. Overrides cleanup_delay if - earlier - - Returns - ------- - books: list of book objects - """ - raise NotImplementedError("This function hasn't been fixed yet") - if session is None: - session = self.get_session() - if cleanup_delay is None: - cleanup_delay = 0 - - base_filt = and_( - Books.status == BOUND, - Books.lvl2_deleted == False, - or_( ## not implementing smurf deletion just yet - Books.type == "obs", - Books.type == "oper", - Books.type == "stray", - Books.type == "hk", - ), - ) - sources = session.query( - Books.tel_tube - ).filter(base_filt).distinct().all() - - source_filt = [] - for source, in sources: - streams = self.tubes[source].get("slots") - _, SMURF = self.get_g3tsmurf_session(source, return_archive=True) - limit = SMURF.get_final_time(streams, check_control=False) - max_stop = dt.datetime.utcfromtimestamp(limit) - dt.timedelta(days=cleanup_delay) - - source_filt.append( and_(Books.tel_tube == source, Books.stop <= max_stop) ) - - q = session.query(Books).filter( - base_filt, - or_(*source_filt), - ) - - if max_time is not None: - q = q.filter(Books.stop <= max_time) - - return q.all() # some aliases for readability def get_unbound_books(self, session=None): @@ -1033,6 +923,20 @@ def get_bound_books(self, session=None): """ return self.get_books_by_status(BOUND, session) + def get_done_books(self, session=None): + """Get all "done" books from database. Done means staged files are deleted. 
+ + Parameters + ---------- + session: BookDB session + + Returns + ------- + books: list of book objects + + """ + return self.get_books_by_status(DONE, session) + def get_failed_books(self, session=None): """Get all failed books from database @@ -1138,23 +1042,23 @@ def rollback(self, session=None): session = self.get_session() session.rollback() - def _find_incomplete(self, min_ctime, max_ctime, stream_filt=None): + @property + def all_slots(self): + return [x for xs in [ + t.get('slots') for (_,t) in self.tubes.items() + ] for x in xs] + + def _find_incomplete(self, min_ctime, max_ctime, streams=None): """return G3tSmurf session query for incomplete observations """ - if stream_filt is None: - streams = [] - streams.extend( - *[t.get("slots") for (_,t) in self.tubes.items()] - ) - stream_filt = or_( - *[G3tObservations.stream_id == s for s in streams] - ) + if streams is None: + streams = self.all_slots session = self.get_g3tsmurf_session() q = session.query(G3tObservations).filter( G3tObservations.timestamp >= min_ctime, G3tObservations.timestamp <= max_ctime, - stream_filt, + G3tObservations.stream_id.in_(streams), or_( G3tObservations.stop == None, G3tObservations.stop >= dt.datetime.utcfromtimestamp(max_ctime), @@ -1226,9 +1130,6 @@ def update_bookdb_from_g3tsmurf( streams = stream_ids self.logger.debug(f"Looking for observations from stream_ids {streams}") - # restrict to given stream ids (wafers) - stream_filt = or_(*[G3tObservations.stream_id == s for s in streams]) - # check data transfer finalization final_time = SMURF.get_final_time( streams, min_ctime, max_ctime, check_control=True @@ -1238,7 +1139,7 @@ def update_bookdb_from_g3tsmurf( self.logger.debug(f"Searching between {min_ctime} and {max_ctime}") # check for incomplete observations in time range - q_incomplete = self._find_incomplete(min_ctime, max_ctime, stream_filt) + q_incomplete = self._find_incomplete(min_ctime, max_ctime, streams) # if we have incomplete observations in our stream_id list we cannot # bookbind any observations overlapping the incomplete ones. 
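The queries above and in the next hunk now pass the flattened slot list from the new ``all_slots`` property to ``stream_id.in_(...)`` instead of building an ``or_`` of per-stream equality filters. A minimal sketch of the flattening that ``all_slots`` performs, with a made-up ``tubes`` mapping for illustration:

    tubes = {
        "stp1": {"slots": ["ufm_mv5", "ufm_mv6"]},
        "stp2": {"slots": ["ufm_mv7"]},
    }
    # flatten the per-tube slot lists into one list, equivalent to the
    # nested comprehension used in the all_slots property
    all_slots = [slot for _, tube in tubes.items() for slot in tube.get("slots")]
    assert all_slots == ["ufm_mv5", "ufm_mv6", "ufm_mv7"]
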
@@ -1275,7 +1176,7 @@ def update_bookdb_from_g3tsmurf( obs_q = session.query(G3tObservations).filter( G3tObservations.timestamp >= min_ctime, G3tObservations.timestamp < max_ctime, - stream_filt, + G3tObservations.stream_id.in_(streams), G3tObservations.stop < max_stop, not_(G3tObservations.stop == None), G3tObservations.obs_id.not_in(already_registered), @@ -1311,7 +1212,7 @@ def add_to_output(obs_list, mode): # observations from other streams q = obs_q.filter( G3tObservations.stream_id != str_obs.stream_id, - stream_filt, + G3tObservations.stream_id.in_(streams), or_( and_( G3tObservations.start <= str_obs.start, @@ -1437,38 +1338,10 @@ def get_files_for_book(self, book): res[o.obs_id] = sorted([f.name for f in o.files]) return res elif book.type in ["stray"]: - session = self.get_session() - - ## build list of files already in bound books - book_list = session.query(Books).filter( - Books.start >= book.start, - Books.start < book.stop, - or_(Books.type == 'obs', Books.type == 'oper'), - Books.status != WONT_BIND, - ).all() - files_in_books = [] - for b in book_list: - flist = self.get_files_for_book(b) - for k in flist: - files_in_books.extend(flist[k]) - - g3session = self.get_g3tsmurf_session() - tcode = int(book.bid.split("_")[1]) - - files_in_tc = g3session.query(Files).filter( - Files.name.like(f"%/{tcode}/%"), - ).all() - files_in_tc = [f.name for f in files_in_tc] - - files_into_stray = [] - for f in files_in_tc: - if f in files_in_books: - continue - files_into_stray.append(f) - return files_into_stray + return self.get_files_for_stray_book(book) elif book.type == "hk": - HK = self.get_g3thk(book.tel_tube) + HK = self.get_g3thk() flist = ( HK.session.query(HKFiles) .filter( @@ -1478,12 +1351,82 @@ def get_files_for_book(self, book): .all() ) return [f.path for f in flist] + elif book.type == "smurf": + tcode = int(book.bid.split("_")[1]) + basepath = os.path.join( + self.lvl2_data_root, 'smurf', str(tcode) + ) + ignore = shutil.ignore_patterns(*SMURF_EXCLUDE_PATTERNS) + flist = [] + for root, _, files in os.walk(basepath): + to_ignore = ignore('', files) + flist.extend([ + os.path.join(basepath, root, f) + for f in files if f not in to_ignore + ]) + return flist else: raise NotImplementedError( - f"book type {book.type} not understood for" " file search" + f"book type {book.type} not understood for file search" ) + def get_files_for_stray_book( + self, book=None, min_ctime=None, max_ctime=None + ): + """generate list of files that are not in detector books and should + going into stray books. if book is None then we expect both min and max + ctime to be provided + + Arguments + ---------- + book: optional, book instance + min_ctime: optional, minimum ctime value to search + max_ctime: optional, maximum ctime value to search + + Returns + -------- + list of files that should go into a stray book + """ + if book is None: + assert min_ctime is not None and max_ctime is not None + start = dt.datetime.utcfromtimestamp(min_ctime) + stop = dt.datetime.utcfromtimestamp(max_ctime) + + tcode = int(min_ctime//1e5) + if max_ctime > (tcode+1)*1e5: + self.logger.error( + f"Max ctime {max_ctime} is higher than would be expected " + f"for a single stray book with min ctime {min_ctime}. 
only" + " checking the first timecode directory" + ) + else: + assert book.type == 'stray' + start = book.start + stop = book.stop + tcode = int(book.bid.split("_")[1]) + + session = self.get_session() + g3session, SMURF = self.get_g3tsmurf_session(return_archive=True) + path = os.path.join(SMURF.archive_path, str(tcode)) + registered_obs = [ + x[0] for x in session.query(Observations.obs_id).join(Books).filter( + Books.start >= start, + Books.start < stop, + Books.status != WONT_BIND, + ).all()] + db_files = g3session.query(Files).filter( + Files.name.like(f"{path}%") + ).all() + + stray_files = [] + for f in db_files: + if f.obs_id is None or f.obs_id not in registered_obs: + stray_files.append(f.name) + + return stray_files + + def get_readout_ids_for_book(self, book): """ Get all readout IDs for a book @@ -1611,6 +1554,19 @@ def get_g3tsmurf_obs_for_book(self, book): ) return {o.obs_id: o for o in obs} + def _librarian_connect(self): + """ + start connection to librarian + """ + from hera_librarian import LibrarianClient + from hera_librarian.settings import client_settings + conn = client_settings.connections.get( + self.config.get("librarian_conn") + ) + if conn is None: + raise ValueError(f"'librarian_conn' not in imprinter config") + self.librarian = LibrarianClient.from_info(conn) + def upload_book_to_librarian(self, book, session=None, raise_on_error=True): """Upload bound book to the librarian @@ -1626,14 +1582,7 @@ def upload_book_to_librarian(self, book, session=None, raise_on_error=True): if session is None: session = self.get_session() if self.librarian is None: - from hera_librarian import LibrarianClient - from hera_librarian.settings import client_settings - conn = client_settings.connections.get( - self.config.get("librarian_conn") - ) - if conn is None: - raise ValueError(f"'librarian_conn' not in imprinter config") - self.librarian = LibrarianClient.from_info(conn) + self._librarian_connect() assert book.status == BOUND, "cannot upload unbound books" @@ -1655,8 +1604,32 @@ def upload_book_to_librarian(self, book, session=None, raise_on_error=True): return False, e return True, None - - def delete_level2_files(self, book, dry_run=True): + def check_book_in_librarian(self, book, n_copies=1, raise_on_error=True): + """have the librarian validate the books is stored offsite. returns true + if at least n_copies are storied offsite. 
+ """ + if self.librarian is None: + self._librarian_connect() + try: + resp = self.librarian.validate_file(book.path) + in_lib = sum( + [(x.computed_same_checksum) for x in resp] + ) >= n_copies + if not in_lib: + self.logger.info(f"received response from librarian {resp}") + except Exception as e: + if raise_on_error: + raise e + else: + self.logger.warning( + f"Failed to check libraian status for {book.bid}: {e}" + ) + self.logger.warning(traceback.format_exc()) + in_lib = False + return in_lib + + def delete_level2_files(self, book, verify_with_librarian=True, + n_copies_in_lib=2, dry_run=True): """Delete level 2 data from already bound books Parameters @@ -1665,13 +1638,31 @@ def delete_level2_files(self, book, dry_run=True): dry_run: bool if true, just prints plans to self.logger.info """ - if book.status != BOUND: - raise ValueError(f"Book must be bound to delete level 2 files") - + if book.lvl2_deleted: + self.logger.debug( + f"Level 2 for {book.bid} has already been deleted" + ) + return 0 + if book.status < UPLOADED: + self.logger.warning( + f"Book {book.bid} is not uploaded, not deleting level 2" + ) + return 1 + if verify_with_librarian: + in_lib = self.check_book_in_librarian( + book, n_copies=n_copies_in_lib, raise_on_error=False + ) + if not in_lib: + self.logger.warning( + f"Book {book.bid} does not have {n_copies_in_lib} copies" + " will not delete level 2" + ) + return 2 + self.logger.info(f"Removing level 2 files for {book.bid}") if book.type == "obs" or book.type == "oper": session, SMURF = self.get_g3tsmurf_session( - book.tel_tube, return_archive=True + return_archive=True ) odic = self.get_g3tsmurf_obs_for_book(book) @@ -1681,7 +1672,7 @@ def delete_level2_files(self, book, dry_run=True): ) elif book.type == "stray": session, SMURF = self.get_g3tsmurf_session( - book.tel_tube, return_archive=True + return_archive=True ) flist = self.get_files_for_book(book) for f in flist: @@ -1689,12 +1680,25 @@ def delete_level2_files(self, book, dry_run=True): SMURF.delete_file( db_file, session, dry_run=dry_run, my_logger=self.logger ) + elif book.type == "smurf": + tcode = int(book.bid.split("_")[1]) + basepath = os.path.join( + self.lvl2_data_root, 'smurf', str(tcode) + ) + if not dry_run: + shutil.rmtree(basepath) + elif book.type == "hk": - HK = self.get_g3thk(book.tel_tube) + HK = self.get_g3thk() flist = self.get_files_for_book(book) - for f in flist: - hkfile = HK.session.query(HKFiles).filter(HKFiles.path == f).one() - HK.delete_file(hkfile, dry_run=dry_run, my_logger=self.logger) + hkf_list = [ + HK.session.query(HKFiles).filter( + HKFiles.path == f + ).one() for f in flist + ] + HK.batch_delete_files( + hkf_list, dry_run=dry_run, my_logger=self.logger + ) else: raise NotImplementedError( f"Do not know how to delete level 2 files" @@ -1703,8 +1707,10 @@ def delete_level2_files(self, book, dry_run=True): if not dry_run: book.lvl2_deleted = True self.session.commit() + return 0 - def delete_book_files(self, book): + def delete_book_staged(self, book, check_level2=False, + verify_with_librarian=False, n_copies_in_lib=1, override=False): """Delete all files associated with a book Parameters @@ -1712,27 +1718,80 @@ def delete_book_files(self, book): book: Book object """ + if book.status == DONE: + self.logger.debug( + f"Book {book.bid} has already had staged files deleted" + ) + return 0 + if not override: + if book.status < UPLOADED: + self.logger.warning( + "Cannot delete non-uploaded books without override" + ) + return 1 + if check_level2 and not book.lvl2_deleted: + 
self.logger.warning( + f"Level 2 data not deleted for {book.bid}, not deleting " + "staged" + ) + return 2 + if verify_with_librarian: + in_lib = self.check_book_in_librarian( + book, n_copies=n_copies_in_lib, raise_on_error=False + ) + if not in_lib: + self.logger.warning( + f"Book {book.bid} does not have {n_copies_in_lib} copies" + " will not delete staged" + ) + return 3 + # remove all files within the book book_path = self.get_book_abs_path(book) try: - shutil.rmtree( book_path ) + self.logger.info( + f"Removing {book.bid} from staged" + ) + if book.type == 'smurf' and book.schema == 1: + os.remove(book_path) + else: + shutil.rmtree( book_path ) except Exception as e: self.logger.warning(f"Failed to remove {book_path}: {e}") self.logger.error(traceback.format_exc()) + book.status = DONE + self.session.commit() + return 0 + def find_missing_lvl2_obs_from_books( + self, min_ctime, max_ctime + ): + """create a list of level 2 observation IDs that are not registered in + the imprinter database + + Arguments + ---------- + min_ctime: minimum ctime value to search + max_ctime: maximum ctime value to search - def all_bound_until(self): - """report a datetime object to indicate that all books are bound - by this datetime. + Returns + -------- + list of level 2 observation ids not in books """ session = self.get_session() - # sort by start time and find the start time by which - # all books are bound - books = session.query(Books).order_by(Books.start).all() - for book in books: - if book.status < BOUND: - return book.start - return book.start # last book + g3session, SMURF = self.get_g3tsmurf_session(return_archive=True) + registered_obs = [ + x[0] for x in session.query(Observations.obs_id).join(Books).filter( + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + Books.start < dt.datetime.utcfromtimestamp(max_ctime), + ).all()] + missing_obs = g3session.query(G3tObservations).filter( + G3tObservations.timestamp >= min_ctime, + G3tObservations.timestamp < max_ctime, + G3tObservations.stream_id.in_(self.all_slots), + G3tObservations.obs_id.not_in(registered_obs) + ).all() + return missing_obs ##################### # Utility functions # diff --git a/sotodlib/io/imprinter_cli.py b/sotodlib/io/imprinter_cli.py index 50209f1f1..3c8cf5f54 100644 --- a/sotodlib/io/imprinter_cli.py +++ b/sotodlib/io/imprinter_cli.py @@ -13,9 +13,10 @@ import os import argparse +import datetime as dt from typing import Optional -from sotodlib.io.imprinter import Imprinter, Books +from sotodlib.io.imprinter import Imprinter, Books, FAILED import sotodlib.io.imprinter_utils as utils def main(): @@ -110,9 +111,21 @@ def _last_line(book): if len(s) > 0: return s -def autofix_failed_books(imprint:Imprinter, test_mode=False): - fail_list = imprint.get_failed_books() - for book in fail_list: +def autofix_failed_books( + imprint:Imprinter, test_mode=False, min_ctime=None, max_ctime=None, +): + session = imprint.get_session() + failed = session.query(Books).filter(Books.status == FAILED) + if min_ctime is not None: + failed = failed.filter( + Books.start >= dt.datetime.utcfromtimestamp(min_ctime), + ) + if max_ctime is not None: + failed = failed.filter( + Books.start <= dt.datetime.utcfromtimestamp(max_ctime), + ) + failed = failed.all() + for book in failed: print("-----------------------------------------------------") print(f"On book {book.bid}. 
Has error:\n{_last_line(book)}") if 'SECOND-FAIL' in book.message: diff --git a/sotodlib/io/imprinter_utils.py b/sotodlib/io/imprinter_utils.py index 4ec18f80f..1ed0be33a 100644 --- a/sotodlib/io/imprinter_utils.py +++ b/sotodlib/io/imprinter_utils.py @@ -80,7 +80,12 @@ def set_book_rebind(imprint, book, update_level2=False): if op.exists(book_dir): print(f"Removing all files from {book_dir}") - shutil.rmtree(book_dir) + if os.path.isfile(book_dir): + os.remove(book_dir) + elif os.path.isdir(book_dir): + shutil.rmtree(book_dir) + else: + print("How is this not a file or directory") else: print(f"Found no files in {book_dir} to remove") @@ -160,6 +165,21 @@ def block_set_rebind(imprint, update_level2=False): imprint.logger.info(f"Setting book {book.bid} for rebinding") set_book_rebind(imprint, book, update_level2=update_level2) +def block_fix_bad_timing(imprint): + """Run through and try rebinding all books with bad timing""" + failed_books = imprint.get_failed_books() + fix_list = [] + for book in failed_book: + if "TimingSystemOff" in book.message: + fix_list.append(book) + for book in fix_list: + imprint.logger.info(f"Setting book {book.bid} for rebinding") + set_book_rebind(imprint, book) + imprint.logger.info( + f"Binding book {book.bid} while accepting bad timing" + ) + imprint.bind_book(book, allow_bad_timing=True) + def get_timecode_final(imprint, time_code, type='all'): """Check if all required entries in the g3tsmurf database are present for smurf or stray book regisitration. @@ -186,16 +206,17 @@ def get_timecode_final(imprint, time_code, type='all'): g3session, SMURF = imprint.get_g3tsmurf_session(return_archive=True) session = imprint.get_session() + # this is another place I was reminded sqlite does not accept + # numpy int32s or numpy int64s + time_code = int(time_code) + servers = SMURF.finalize["servers"] meta_agents = [s["smurf-suprsync"] for s in servers] files_agents = [s["timestream-suprsync"] for s in servers] - meta_query = or_(*[TimeCodes.agent == a for a in meta_agents]) - files_query = or_(*[TimeCodes.agent == a for a in files_agents]) - tcm = g3session.query(TimeCodes.agent).filter( TimeCodes.timecode==time_code, - meta_query, + TimeCodes.agent.in_(meta_agents), TimeCodes.suprsync_type == SupRsyncType.META.value, ).distinct().all() @@ -209,8 +230,8 @@ def get_timecode_final(imprint, time_code, type='all'): return False, 1 tcf = g3session.query(TimeCodes.agent).filter( - TimeCodes.timecode==time_code, - files_query, + TimeCodes.timecode == time_code, + TimeCodes.agent.in_(files_agents), TimeCodes.suprsync_type == SupRsyncType.FILES.value, ).distinct().all() @@ -244,8 +265,10 @@ def set_timecode_final(imprint, time_code): """ g3session, SMURF = imprint.get_g3tsmurf_session(return_archive=True) - servers = SMURF.finalize["servers"] + # this is another place I was reminded sqlite does not accept + # numpy int32s or numpy int64s + time_code = int(time_code) for server in servers: tcf = g3session.query(TimeCodes).filter( @@ -261,6 +284,7 @@ def set_timecode_final(imprint, time_code): agent=server["timestream-suprsync"], ) g3session.add(tcf) + g3session.commit() tcm = g3session.query(TimeCodes).filter( TimeCodes.timecode==time_code, @@ -274,5 +298,5 @@ def set_timecode_final(imprint, time_code): timecode=time_code, agent=server["smurf-suprsync"], ) - g3session.add(tcm) - g3session.commit() \ No newline at end of file + g3session.add(tcm) + g3session.commit() \ No newline at end of file diff --git a/sotodlib/io/load_book.py b/sotodlib/io/load_book.py index 
89804defb..8a8d29b92 100644 --- a/sotodlib/io/load_book.py +++ b/sotodlib/io/load_book.py @@ -111,7 +111,9 @@ def load_obs_book(db, obs_id, dets=None, prefix=None, samples=None, dets_req.extend([p[1] for p in pairs_req if p[0] == _ds]) del pairs_req - file_map = db.get_files(obs_id, detsets=detsets_req) + # Don't pass a restriction of detsets here, as we need at least + # one result for some downstream processing. + file_map = db.get_files(obs_id) one_group = list(file_map.values())[0] # [('file0', 0, 1000), ('file1', 1000, 2000), ...] # Figure out how many samples we're loading. diff --git a/sotodlib/io/load_smurf.py b/sotodlib/io/load_smurf.py index 0bafb1cf6..b0a8cb415 100644 --- a/sotodlib/io/load_smurf.py +++ b/sotodlib/io/load_smurf.py @@ -20,7 +20,7 @@ from .. import core from . import load as io_load -from .datapkg_utils import load_configs +from .datapkg_utils import load_configs, walk_files, just_suprsync from .g3thk_db import G3tHk, HKFiles, HKAgents, HKFields from .g3thk_utils import pysmurf_monitor_control_list @@ -194,6 +194,7 @@ def __init__( self.meta_path = meta_path self.db_path = db_path self.hk_db_path = hk_db_path + self.HK = None self.finalize = finalize if os.path.exists(self.db_path): @@ -606,13 +607,13 @@ def delete_file(self, db_file, session=None, dry_run=False, my_logger=None): my_logger = logger db_frames = db_file.frames - my_logger.info(f"Deleting frame entries for {db_file.name}") + my_logger.debug(f"Deleting frame entries for {db_file.name}") if not dry_run: [session.delete(frame) for frame in db_frames] if not os.path.exists(db_file.name): my_logger.warning( - f"Database file {db_file.name} appears already" " deleted on disk" + f"Database file {db_file.name} appears already deleted on disk" ) else: my_logger.info(f"Deleting file {db_file.name}") @@ -1155,7 +1156,9 @@ def update_observation_files( logger.debug(f"Setting {obs.obs_id} stop time to {obs.stop}") session.commit() - def delete_observation_files(self, obs, session, dry_run=False, my_logger=None): + def delete_observation_files( + self, obs, session, dry_run=False, my_logger=None + ): """WARNING: Deletes files from the file system Args @@ -1394,20 +1397,33 @@ def index_timecodes(self, session=None, min_ctime=16000e5, max_ctime=None): session.add(tcf) session.commit() - def update_finalization(self, update_time, session=None): - """Update the finalization time rows in the database""" + def get_HK(self): if self.hk_db_path is None: - raise ValueError("HK database path required to update finalization" " time") + raise ValueError("HK database path required") + + if self.HK is None: + iids = [] + for server in self.finalize.get("servers", []): + for key in server.keys(): + # Append the value (iid) to the iids list + iids.append(server[key]) + + self.HK = G3tHk( + os.path.join(os.path.split(self.archive_path)[0], "hk"), + iids = iids, + db_path = self.hk_db_path, + ) + return self.HK + def update_finalization(self, update_time, session=None): + """Update the finalization time rows in the database""" + if session is None: session = self.Session() # look for new rows to add to table self._start_finalization(session) - HK = G3tHk( - os.path.join(os.path.split(self.archive_path)[0], "hk"), - self.hk_db_path, - ) + HK = self.get_HK() agent_list = session.query(Finalize).all() for agent in agent_list: @@ -1449,15 +1465,19 @@ def update_finalization(self, update_time, session=None): session.commit() def get_final_time( - self, stream_ids, start=None, stop=None, check_control=True, session=None + self, 
stream_ids, start=None, stop=None, check_control=True, + session=None ): - """Return the ctime to which database is finalized for a set of stream_ids - between ctimes start and stop. If check_control is True it will use the - pysmurf-monitor entries in the HK database to determine which - pysmurf-monitors were in control of which stream_ids between start and stop. + """Return the ctime to which database is finalized for a set of + stream_ids between ctimes start and stop. If check_control is True it + will use the pysmurf-monitor entries in the HK database to determine + which pysmurf-monitors were in control of which stream_ids between + start and stop. """ if check_control and self.hk_db_path is None: - raise ValueError("HK database path required to update finalization" " time") + raise ValueError( + "HK database path required to update finalization time" + ) if check_control and ((start is None) or (stop is None)): raise ValueError( "start and stop ctimes are required to check which" @@ -1465,10 +1485,8 @@ def get_final_time( ) if session is None: session = self.Session() - HK = G3tHk( - os.path.join(os.path.split(self.archive_path)[0], "hk"), - self.hk_db_path, - ) + + HK = self.get_HK() agent_list = [] if "servers" not in self.finalize: @@ -1730,6 +1748,51 @@ def index_action_observations( if new_session: session.close() + def find_missing_files(self, timecode, session=None): + """create a list of files in the timecode folder that are not in the + g3tsmurf database + + Arguments + ---------- + timecode (int): a level 2 timestreams timecode + + Returns + -------- + missing (list): list of file paths that are not in the g3tsmurf database + """ + if session is None: + session = self.Session() + path = os.path.join(self.archive_path, str(timecode)) + + q = session.query(Files).filter(Files.name.like(f"{path}%")) + db_list = [f.name for f in q.all()] + sys_list = walk_files(path) + missing = [] + for f in sys_list: + if f not in db_list: + missing.append(f) + + return missing + + def find_missing_files_from_obs(self, timecode, session=None): + """create a list of files in the g3tsmurf database that do not have an + assigned level 2 observation ID + + Arguments + ---------- + timecode (int): a level 2 timestreams timecode + + Returns + -------- + missing (list): list of file paths that do not have level 2 observation IDs + """ + if session is None: + session = self.Session() + path = os.path.join(self.archive_path, str(timecode)) + q = session.query(Files).filter(Files.name.like(f"{path}%")) + db_list = q.all() + return [f.name for f in db_list if f.obs_id is None] + def lookup_file(self, filename, fail_ok=False): """Lookup a file's observations details in database. Meant to look and act like core.metadata.obsfiledb.lookup_file. 
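``find_missing_files`` builds the file list two ways, from the G3tSmurf ``Files`` table and from ``walk_files`` on the timecode directory, and reports anything on disk that the database has not indexed. A minimal sketch of that comparison, with hypothetical paths standing in for the query result and the directory walk:

    db_list = ["/data/timestreams/17000/ufm_mv5/1700000123_000.g3"]
    sys_list = [
        "/data/timestreams/17000/ufm_mv5/1700000123_000.g3",
        "/data/timestreams/17000/ufm_mv5/1700000456_000.g3",
    ]
    # files present on disk that the database has never indexed
    missing = [f for f in sys_list if f not in db_list]
    assert missing == ["/data/timestreams/17000/ufm_mv5/1700000456_000.g3"]
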
@@ -1914,7 +1977,6 @@ def load_status(self, time, stream_id=None, show_pb=False): """ return SmurfStatus.from_time(time, self, stream_id=stream_id, show_pb=show_pb) - def dump_DetDb(archive, detdb_file): """ Take a G3tSmurf archive and create a a DetDb of the type used with Context @@ -1951,7 +2013,6 @@ def dump_DetDb(archive, detdb_file): session.close() return my_db - def make_DetDb_single_obs(obsfiledb, obs_id): # find relevant files to get status c = obsfiledb.conn.execute( @@ -2016,18 +2077,15 @@ def make_DetDb_single_obs(obsfiledb, obs_id): detdb.conn.commit() return detdb - def obs_detdb_context_hook(ctx, obs_id, *args, **kwargs): ddb = make_DetDb_single_obs(ctx.obsfiledb, obs_id) ctx.obs_detdb = ddb return ddb - core.Context.hook_sets["obs_detdb_load"] = { "before-use-detdb": obs_detdb_context_hook, } - class SmurfStatus: """ This is a class that attempts to extract essential information from the @@ -2385,7 +2443,6 @@ def smurf_to_readout(self, band, chan): """ return self.mask_inv[band, chan] - def get_channel_mask( ch_list, status, archive=None, obsfiledb=None, ignore_missing=True ): @@ -2568,7 +2625,6 @@ def _get_tuneset_channel_names(status, ch_map, archive): session.close() return ruids - def _get_detset_channel_names(status, ch_map, obsfiledb): """Update channel maps with name from obsfiledb""" # tune file in status @@ -2639,7 +2695,6 @@ def _get_detset_channel_names(status, ch_map, obsfiledb): return ruids - def _get_channel_mapping(status, ch_map): """Generate baseline channel map from status object""" for i, ch in enumerate(ch_map["idx"]): @@ -2656,7 +2711,6 @@ def _get_channel_mapping(status, ch_map): ch_map[i]["channel"] = -1 return ch_map - def get_channel_info( status, mask=None, @@ -2740,7 +2794,6 @@ def get_channel_info( return ch_info - def _get_sample_info(filenames): """Scan through a list of files and count samples. Starts counting from the first file in the list. Used in load_file for sample restiction @@ -2780,7 +2833,6 @@ def _get_sample_info(filenames): start += samps return out - def split_ts_bits(c): """Split up 64 bit to 2x32 bit""" NUM_BITS_PER_INT = 32 @@ -2789,7 +2841,6 @@ def split_ts_bits(c): b = c & MAXINT return a, b - def _get_timestamps(streams, load_type=None, linearize_timestamps=True): """Calculate the timestamp field for loaded data @@ -2841,7 +2892,6 @@ def _get_timestamps(streams, load_type=None, linearize_timestamps=True): return io_load.hstack_into(None, streams["time"]) logger.error("Timing System could not be determined") - def load_file( filename, channels=None, @@ -3099,7 +3149,6 @@ def load_file( return aman - def load_g3tsmurf_obs(db, obs_id, dets=None, samples=None, no_signal=None, **kwargs): """Obsloader function for g3tsmurf data archives. diff --git a/sotodlib/mapmaking/demod_mapmaker.py b/sotodlib/mapmaking/demod_mapmaker.py index 9b893e1c2..8b13d9152 100644 --- a/sotodlib/mapmaking/demod_mapmaker.py +++ b/sotodlib/mapmaking/demod_mapmaker.py @@ -13,7 +13,7 @@ from .. import core from .. 
import coords from .utilities import recentering_to_quat_lonlat, evaluate_recentering, MultiZipper, unarr, safe_invert_div -from .utilities import import_optional, get_flags +from .utilities import import_optional from .noise_model import NmatWhite hp = import_optional('healpy') @@ -291,10 +291,10 @@ def add_obs(self, id, obs, nmat, Nd, pmap=None, split_labels=None): else: rot = None if self.Nsplits == 1: # this is the case with no splits - flagnames = ['glitch_flags'] + cuts = obs.flags.glitch_flags else: - flagnames = ['glitch_flags', split_labels[n_split]] - cuts = get_flags(obs, flagnames) + # remember that the dets or samples you want to keep should be false, hence we negate + cuts = obs.flags.glitch_flags + ~obs.preprocess.split_flags.cuts[split_labels[n_split]] if self.pix_scheme == "rectpix": threads='domdir' geom = self.rhs.geometry @@ -447,8 +447,9 @@ def make_demod_map(context, obslist, noise_model, info, Noise model to pass to DemodMapmaker. info : list Information for the database, will be written as a .hdf file. - preprocess_config : dict - Dictionary with the config yaml file for the preprocess database. + preprocess_config : list of dict + List of dictionaries with the config yaml file for the preprocess database. + If two, then a multilayer preprocessing is to be used. prefix : str Prefix for the output files shape : tuple, optional @@ -492,7 +493,7 @@ def make_demod_map(context, obslist, noise_model, info, List of outputs from preprocess database. To be used in cleanup_mandb. """ from ..preprocess import preprocess_util - context = core.Context(context) + #context = core.Context(context) if L is None: L = preprocess_util.init_logger("Demod filterbin mapmaking") pre = "" if tag is None else tag + " " @@ -508,21 +509,28 @@ def make_demod_map(context, obslist, noise_model, info, errors = [] ; outputs = []; # PENDING: do an allreduce of these. 
# not needed for atomic maps, but needed for # depth-1 maps + if len(preprocess_config)==1: + preproc_init = preprocess_config[0] + preproc_proc = None + else: + preproc_init = preprocess_config[0] + preproc_proc = preprocess_config[1] + for oi in range(len(obslist)): obs_id, detset, band = obslist[oi][:3] name = "%s:%s:%s" % (obs_id, detset, band) - error, output, obs = preprocess_util.preproc_or_load_group(obs_id, - configs=preprocess_config, - dets={'wafer_slot':detset, 'wafer.bandpass':band}, - logger=L, context=context, overwrite=False) - errors.append(error) ; outputs.append(output) ; + error, output_init, output_proc, obs = preprocess_util.preproc_or_load_group(obs_id, + configs_init=preproc_init, + configs_proc=preproc_proc, + dets={'wafer_slot':detset, 'wafer.bandpass':band}, + logger=L, + overwrite=False) + errors.append(error) ; outputs.append((output_init, output_proc)) ; if error not in [None,'load_success']: L.info('tod %s:%s:%s failed in the prepoc database'%(obs_id,detset,band)) continue obs.wrap("weather", np.full(1, "toco")) obs.wrap("site", np.full(1, site)) - obs.flags.wrap('glitch_flags', obs.preprocess.turnaround_flags.turnarounds - + obs.preprocess.jumps_2pi.jump_flag + obs.preprocess.glitches.glitch_flags, ) mapmaker.add_obs(name, obs, split_labels=split_labels) L.info('Done with tod %s:%s:%s'%(obs_id,detset,band)) nobs_kept += 1 diff --git a/sotodlib/mapmaking/ml_mapmaker.py b/sotodlib/mapmaking/ml_mapmaker.py index 198198318..bc80a4a95 100644 --- a/sotodlib/mapmaking/ml_mapmaker.py +++ b/sotodlib/mapmaking/ml_mapmaker.py @@ -1,12 +1,36 @@ +import os + import numpy as np -from pixell import enmap, utils, tilemap, bunch +import h5py +import so3g +from typing import Optional +from pixell import bunch, enmap, tilemap +from pixell import utils as putils from .. import coords -from .utilities import * -from .pointing_matrix import * +from .pointing_matrix import PmatCut +from .utilities import ( + MultiZipper, + recentering_to_quat_lonlat, + evaluate_recentering, + TileMapZipper, + MapZipper, + safe_invert_div, + unarr, + ArrayZipper, +) +from .noise_model import NmatUncorr + class MLMapmaker: - def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False): + def __init__( + self, + signals=[], + noise_model=None, + dtype=np.float32, + verbose=False, + glitch_flags: str = "flags.glitch_flags", + ): """Initialize a Maximum Likelihood Mapmaker. Arguments: * signals: List of Signal-objects representing the models that will be solved @@ -19,26 +43,29 @@ def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False * verbose: Whether to print progress messages. 
Not implemented""" if noise_model is None: noise_model = NmatUncorr() - self.signals = signals - self.dtype = dtype - self.verbose = verbose - self.noise_model = noise_model - self.data = [] - self.dof = MultiZipper() - self.ready = False + self.signals = signals + self.dtype = dtype + self.verbose = verbose + self.noise_model = noise_model + self.data = [] + self.dof = MultiZipper() + self.ready = False + self.glitch_flags_path = glitch_flags def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None): # Prepare our tod - ctime = obs.timestamps - srate = (len(ctime)-1)/(ctime[-1]-ctime[0]) - tod = obs.signal.astype(self.dtype, copy=False) + ctime = obs.timestamps + srate = (len(ctime) - 1) / (ctime[-1] - ctime[0]) + tod = obs.signal.astype(self.dtype, copy=False) # Subtract an existing estimate of the signal before estimating # the noise model, if available - if signal_estimate is not None: tod -= signal_estimate + if signal_estimate is not None: + tod -= signal_estimate if deslope: - utils.deslope(tod, w=5, inplace=True) + putils.deslope(tod, w=5, inplace=True) # Allow the user to override the noise model on a per-obs level - if noise_model is None: noise_model = self.noise_model + if noise_model is None: + noise_model = self.noise_model # Build the noise model from the obs unless a fully # initialized noise model was passed if noise_model.ready: @@ -55,18 +82,26 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None) # The signal estimate might not be desloped, so # adding it back can reintroduce a slope. Fix that here. if deslope: - utils.deslope(tod, w=5, inplace=True) + putils.deslope(tod, w=5, inplace=True) # And apply it to the tod - tod = nmat.apply(tod) + tod = nmat.apply(tod) # Add the observation to each of our signals for signal in self.signals: - signal.add_obs(id, obs, nmat, tod) + signal.add_obs(id, obs, nmat, tod, glitch_flags=self.glitch_flags_path) # Save what we need about this observation - self.data.append(bunch.Bunch(id=id, ndet=obs.dets.count, nsamp=len(ctime), - dets=obs.dets.vals, nmat=nmat)) + self.data.append( + bunch.Bunch( + id=id, + ndet=obs.dets.count, + nsamp=len(ctime), + dets=obs.dets.vals, + nmat=nmat, + ) + ) def prepare(self): - if self.ready: return + if self.ready: + return for signal in self.signals: signal.prepare() self.dof.add(signal.dof) @@ -75,53 +110,92 @@ def prepare(self): def A(self, x): # unzip goes from flat array of all the degrees of freedom to individual maps, cuts etc. 
# to_work makes a scratch copy and does any redistribution needed - #t0 = time() - #t1 = time() - iwork = [signal.to_work(m) for signal,m in zip(self.signals,self.dof.unzip(x))] - #t2 = time(); print(f" A iwork : {t2-t1:8.3f}s", flush=True) - owork = [w*0 for w in iwork] - #t1 = time(); print(f" A owork : {t1-t2:8.3f}s", flush=True) - #t_forward = 0 - #t_apply = 0 - #t_backward = 0 + # t0 = time() + # t1 = time() + iwork = [ + signal.to_work(m) for signal, m in zip(self.signals, self.dof.unzip(x)) + ] + # t2 = time(); print(f" A iwork : {t2-t1:8.3f}s", flush=True) + owork = [w * 0 for w in iwork] + # t1 = time(); print(f" A owork : {t1-t2:8.3f}s", flush=True) + # t_forward = 0 + # t_apply = 0 + # t_backward = 0 for di, data in enumerate(self.data): tod = np.zeros([data.ndet, data.nsamp], self.dtype) - #t1 = time() + # t1 = time() for si, signal in reversed(list(enumerate(self.signals))): signal.forward(data.id, tod, iwork[si]) - #t2 = time() - #t_forward += t2 - t1 + # t2 = time() + # t_forward += t2 - t1 data.nmat.apply(tod) - #t1 = time() - #t_apply += t1 - t2 + # t1 = time() + # t_apply += t1 - t2 for si, signal in enumerate(self.signals): signal.backward(data.id, tod, owork[si]) - #t2 = time() - #t_backward += t2 - t1 - #print(f" A forward : {t_forward:8.3f}s", flush=True) - #print(f" A apply : {t_apply:8.3f}s", flush=True) - #print(f" A backward : {t_backward:8.3f}s", flush=True) - #t1 = time() - result = self.dof.zip(*[signal.from_work(w) for signal,w in zip(self.signals,owork)]) - #t2 = time(); print(f" A zip : {t2-t1:8.3f}s", flush=True) - #print(f" A TOTAL : {t2-t0:8.3f}s", flush=True) + # t2 = time() + # t_backward += t2 - t1 + # print(f" A forward : {t_forward:8.3f}s", flush=True) + # print(f" A apply : {t_apply:8.3f}s", flush=True) + # print(f" A backward : {t_backward:8.3f}s", flush=True) + # t1 = time() + result = self.dof.zip( + *[signal.from_work(w) for signal, w in zip(self.signals, owork)] + ) + # t2 = time(); print(f" A zip : {t2-t1:8.3f}s", flush=True) + # print(f" A TOTAL : {t2-t0:8.3f}s", flush=True) return result def M(self, x): - #t1 = time() + # t1 = time() iwork = self.dof.unzip(x) - #t2 = time(); print(f" M iwork : {t2-t1:8.3f}s", flush=True) - result = self.dof.zip(*[signal.precon(w) for signal, w in zip(self.signals, iwork)]) - #t1 = time(); print(f" M zip : {t1-t2:8.3f}s", flush=True) + # t2 = time(); print(f" M iwork : {t2-t1:8.3f}s", flush=True) + result = self.dof.zip( + *[signal.precon(w) for signal, w in zip(self.signals, iwork)] + ) + # t1 = time(); print(f" M zip : {t1-t2:8.3f}s", flush=True) return result - def solve(self, maxiter=500, maxerr=1e-6, x0=None): + def solve( + self, + maxiter=500, + maxerr=1e-6, + x0=None, + fname_checkpoint=None, + checkpoint_interval=1, + ): self.prepare() - rhs = self.dof.zip(*[signal.rhs for signal in self.signals]) - if x0 is not None: x0 = self.dof.zip(*x0) - solver = utils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0) - while solver.i < maxiter and solver.err > maxerr: - solver.step() + rhs = self.dof.zip(*[signal.rhs for signal in self.signals]) + if x0 is not None: + x0 = self.dof.zip(*x0) + + solver = putils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0) + # If there exists a checkpoint, restore solver state + if fname_checkpoint is None: + checkpoint = False + restart = False + else: + checkpoint = True + outdir = os.path.dirname(fname_checkpoint) + if len(outdir) != 0: + os.makedirs(outdir, exist_ok=True) + if os.path.isfile(fname_checkpoint): + solver.load(fname_checkpoint) + restart = True + else: + 
restart = False + while restart or (solver.i < maxiter and solver.err > maxerr): + if restart: + # When restarting, do not step + restart = False + else: + solver.step() + if checkpoint and solver.i % checkpoint_interval == 0: + # Avoid checkpoint corruption by making a copy of the previous checkpoint + if os.path.isfile(fname_checkpoint): + os.replace(fname_checkpoint, fname_checkpoint + ".old") + # Write a checkpoint + solver.save(fname_checkpoint) yield bunch.Bunch(i=solver.i, err=solver.err, x=self.dof.unzip(solver.x)) def translate(self, other, x): @@ -138,15 +212,20 @@ def transeval(self, id, obs, other, x, tod=None): """Evaluate degrees of freedom x for the given tod after translating it from those used by another, similar mapmaker. This will have the same signals, but possibly with different sample rates etc.""" - if tod is None: tod = np.zeros([obs.dets.count, obs.samps.count], self.dtype) - for si, (ssig, osig, oval), in reversed(list(enumerate(zip(self.signals,other.signals,x)))): + if tod is None: + tod = np.zeros([obs.dets.count, obs.samps.count], self.dtype) + for ( + si, + (ssig, osig, oval), + ) in reversed(list(enumerate(zip(self.signals, other.signals, x)))): ssig.transeval(id, obs, osig, oval, tod=tod) return tod class Signal: """This class represents a thing we want to solve for, e.g. the sky, ground, cut samples, etc.""" - def __init__(self, name, ofmt, output, ext): + + def __init__(self, name, ofmt, output, ext, **kwargs): """Initialize a Signal. It probably doesn't make sense to construct a generic signal directly, though. Use one of the subclasses. Arguments: @@ -154,86 +233,142 @@ def __init__(self, name, ofmt, output, ext): * ofmt: The format used when constructing output file prefix * output: Whether this signal should be part of the output or not. * ext: The extension used for the files. 
+ * **kwargs: additional keyword based parameters, accessible as class parameters """ - self.name = name - self.ofmt = ofmt + self.name = name + self.ofmt = ofmt self.output = output - self.ext = ext - self.dof = None - self.ready = False - def add_obs(self, id, obs, nmat, Nd): pass - def prepare(self): self.ready = True - def forward (self, id, tod, x): pass - def backward(self, id, tod, x): pass - def precon(self, x): return x - def to_work (self, x): return x.copy() - def from_work(self, x): return x - def write (self, prefix, tag, x): pass - def translate(self, other, x): return x - def transeval(self, id, obs, other, x, tod): pass + self.ext = ext + self.dof = None + self.ready = False + self.__dict__.update(kwargs) + + def add_obs(self, id, obs, nmat, Nd, **kwargs): + pass + + def prepare(self): + self.ready = True + + def forward(self, id, tod, x): + pass + + def backward(self, id, tod, x): + pass + + def precon(self, x): + return x + + def to_work(self, x): + return x.copy() + + def from_work(self, x): + return x + + def write(self, prefix, tag, x): + pass + + def translate(self, other, x): + return x + + def transeval(self, id, obs, other, x, tod): + pass + class SignalMap(Signal): """Signal describing a non-distributed sky map.""" - def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", output=True, - ext="fits", dtype=np.float32, sys=None, recenter=None, tile_shape=(500,500), tiled=False, - interpol=None): + + def __init__( + self, + shape, + wcs, + comm, + comps="TQU", + name="sky", + ofmt="{name}", + output=True, + ext="fits", + dtype=np.float32, + sys=None, + recenter=None, + tile_shape=(500, 500), + tiled=False, + interpol=None, + glitch_flags: str = "flags.glitch_flags", + ): """Signal describing a sky map in the coordinate system given by "sys", which defaults to equatorial coordinates. If tiled==True, then this will be a distributed map with the given tile_shape, otherwise it will be a plain enmap. interpol controls the - pointing matrix interpolation mode. See so3g's Projectionist docstring for details.""" - Signal.__init__(self, name, ofmt, output, ext) - self.comm = comm + pointing matrix interpolation mode. See so3g's Projectionist docstring for details. 
+ """ + Signal.__init__(self, name, ofmt, output, ext, glitch_flags=glitch_flags) + self.comm = comm self.comps = comps - self.sys = sys + self.sys = sys self.recenter = recenter self.dtype = dtype self.tiled = tiled self.interpol = interpol - self.data = {} - ncomp = len(comps) - shape = tuple(shape[-2:]) + self.data = {} + ncomp = len(comps) + shape = tuple(shape[-2:]) + if tiled: geo = tilemap.geometry(shape, wcs, tile_shape=tile_shape) - self.rhs = tilemap.zeros(geo.copy(pre=(ncomp,)), dtype=dtype) - self.div = tilemap.zeros(geo.copy(pre=(ncomp,ncomp)), dtype=dtype) - self.hits= tilemap.zeros(geo.copy(pre=()), dtype=dtype) + self.rhs = tilemap.zeros(geo.copy(pre=(ncomp,)), dtype=dtype) + self.div = tilemap.zeros(geo.copy(pre=(ncomp, ncomp)), dtype=dtype) + self.hits = tilemap.zeros(geo.copy(pre=()), dtype=dtype) else: - self.rhs = enmap.zeros((ncomp,) +shape, wcs, dtype=dtype) - self.div = enmap.zeros((ncomp,ncomp)+shape, wcs, dtype=dtype) - self.hits= enmap.zeros( shape, wcs, dtype=dtype) + self.rhs = enmap.zeros((ncomp,) + shape, wcs, dtype=dtype) + self.div = enmap.zeros((ncomp, ncomp) + shape, wcs, dtype=dtype) + self.hits = enmap.zeros(shape, wcs, dtype=dtype) - def add_obs(self, id, obs, nmat, Nd, pmap=None): + def add_obs(self, id, obs, nmat, Nd, pmap=None, glitch_flags: Optional[str] = None): """Add and process an observation, building the pointing matrix and our part of the RHS. "obs" should be an Observation axis manager, nmat a noise model, representing the inverse noise covariance matrix, and Nd the result of applying the noise model to the detector time-ordered data. """ - Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts - ctime = obs.timestamps - pcut = PmatCut(obs.flags.glitch_flags) # could pass this in, but fast to construct + Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts + ctime = obs.timestamps + gflags = glitch_flags if glitch_flags is not None else self.glitch_flags + pcut = PmatCut(obs[gflags]) # could pass this in, but fast to construct if pmap is None: # Build the local geometry and pointing matrix for this observation if self.recenter: - rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter, - ctime=ctime[len(ctime)//2], geom=(self.rhs.shape, self.rhs.wcs), site=unarr(obs.site))) - else: rot = None - pmap = coords.pmat.P.for_tod(obs, comps=self.comps, geom=self.rhs.geometry, - rot=rot, threads="domdir", weather=unarr(obs.weather), site=unarr(obs.site), - interpol=self.interpol) + rot = recentering_to_quat_lonlat( + *evaluate_recentering( + self.recenter, + ctime=ctime[len(ctime) // 2], + geom=(self.rhs.shape, self.rhs.wcs), + site=unarr(obs.site), + ) + ) + else: + rot = None + pmap = coords.pmat.P.for_tod( + obs, + comps=self.comps, + geom=self.rhs.geometry, + rot=rot, + threads="domdir", + weather=unarr(obs.weather), + site=unarr(obs.site), + interpol=self.interpol, + ) # Build the RHS for this observation pcut.clear(Nd) obs_rhs = pmap.zeros() pmap.to_map(dest=obs_rhs, signal=Nd) # Build the per-pixel inverse covmat for this observation - obs_div = pmap.zeros(super_shape=(self.ncomp,self.ncomp)) + obs_div = pmap.zeros(super_shape=(self.ncomp, self.ncomp)) for i in range(self.ncomp): - obs_div[i] = 0 - obs_div[i,i] = 1 - Nd[:] = 0 + obs_div[i] = 0 + obs_div[i, i] = 1 + Nd[:] = 0 pmap.from_map(obs_div[i], dest=Nd) pcut.clear(Nd) Nd = nmat.white(Nd) - obs_div[i] = 0 + obs_div[i] = 0 pmap.to_map(signal=Nd, dest=obs_div[i]) # Build hitcount Nd[:] = 1 @@ -244,7 +379,7 @@ def add_obs(self, 
id, obs, nmat, Nd, pmap=None): # Update our full rhs and div. This works for both plain and distributed maps self.rhs = self.rhs.insert(obs_rhs, op=np.ndarray.__iadd__) self.div = self.div.insert(obs_div, op=np.ndarray.__iadd__) - self.hits= self.hits.insert(obs_hits[0],op=np.ndarray.__iadd__) + self.hits = self.hits.insert(obs_hits[0], op=np.ndarray.__iadd__) # Save the per-obs things we need. Just the pointing matrix in our case. # Nmat and other non-Signal-specific things are handled in the mapmaker itself. self.data[id] = bunch.Bunch(pmap=pmap, obs_geo=obs_rhs.geometry) @@ -252,58 +387,76 @@ def add_obs(self, id, obs, nmat, Nd, pmap=None): def prepare(self): """Called when we're done adding everything. Sets up the map distribution, degrees of freedom and preconditioner.""" - if self.ready: return + if self.ready: + return if self.tiled: self.geo_work = self.rhs.geometry - self.rhs = tilemap.redistribute(self.rhs, self.comm) - self.div = tilemap.redistribute(self.div, self.comm) - self.hits = tilemap.redistribute(self.hits,self.comm) - self.dof = TileMapZipper(self.rhs.geometry, dtype=self.dtype, comm=self.comm) + self.rhs = tilemap.redistribute(self.rhs, self.comm) + self.div = tilemap.redistribute(self.div, self.comm) + self.hits = tilemap.redistribute(self.hits, self.comm) + self.dof = TileMapZipper( + self.rhs.geometry, dtype=self.dtype, comm=self.comm + ) else: if self.comm is not None: - self.rhs = utils.allreduce(self.rhs, self.comm) - self.div = utils.allreduce(self.div, self.comm) - self.hits = utils.allreduce(self.hits, self.comm) - self.dof = MapZipper(*self.rhs.geometry, dtype=self.dtype) - self.idiv = safe_invert_div(self.div) + self.rhs = putils.allreduce(self.rhs, self.comm) + self.div = putils.allreduce(self.div, self.comm) + self.hits = putils.allreduce(self.hits, self.comm) + self.dof = MapZipper(*self.rhs.geometry, dtype=self.dtype) + self.idiv = safe_invert_div(self.div) self.ready = True @property - def ncomp(self): return len(self.comps) + def ncomp(self): + return len(self.comps) def forward(self, id, tod, map, tmul=1, mmul=1): """map2tod operation. For tiled maps, the map should be in work distribution, as returned by unzip. Adds into tod.""" - if id not in self.data: return # Should this really skip silently like this? - if tmul != 1: tod *= tmul - if mmul != 1: map = map*mmul + if id not in self.data: + return # Should this really skip silently like this? + if tmul != 1: + tod *= tmul + if mmul != 1: + map = map * mmul self.data[id].pmap.from_map(dest=tod, signal_map=map, comps=self.comps) def backward(self, id, tod, map, tmul=1, mmul=1): """tod2map operation. For tiled maps, the map should be in work distribution, as returned by unzip. 
Adds into map""" - if id not in self.data: return - if tmul != 1: tod = tod*tmul - if mmul != 1: map *= mmul + if id not in self.data: + return + if tmul != 1: + tod = tod * tmul + if mmul != 1: + map *= mmul self.data[id].pmap.to_map(signal=tod, dest=map, comps=self.comps) def precon(self, map): - if self.tiled: return tilemap.map_mul(self.idiv, map) - else: return enmap.map_mul(self.idiv, map) + if self.tiled: + return tilemap.map_mul(self.idiv, map) + else: + return enmap.map_mul(self.idiv, map) def to_work(self, map): - if self.tiled: return tilemap.redistribute(map, self.comm, self.geo_work.active) - else: return map.copy() + + if self.tiled: + return tilemap.redistribute(map, self.comm, self.geo_work.active) + else: + return map.copy() def from_work(self, map): if self.tiled: return tilemap.redistribute(map, self.comm, self.rhs.geometry.active) else: - if self.comm is None: return map - else: return utils.allreduce(map, self.comm) + if self.comm is None: + return map + else: + return putils.allreduce(map, self.comm) def write(self, prefix, tag, m): - if not self.output: return + if not self.output: + return oname = self.ofmt.format(name=self.name) oname = "%s%s_%s.%s" % (prefix, oname, tag, self.ext) if self.tiled: @@ -325,7 +478,7 @@ def _checkcompat(self, other): raise ValueError("Geometry mismatch") # Tiling is not set up yet by the time transeval is called. # Transeval doesn't need the tiling to match, though - #if other.rhs.ntile != self.rhs.ntile or other.rhs.nactive != self.rhs.nactive: + # if other.rhs.ntile != self.rhs.ntile or other.rhs.nactive != self.rhs.nactive: # raise ValueError("Tiling mismatch") else: if other.rhs.shape != self.rhs.shape: @@ -343,51 +496,91 @@ def transeval(self, id, obs, other, map, tod): """Translate map from SignalMap other to the current SignalMap, and then evaluate it for the given observation, returning a tod. This is used when building a signal-free tod for the noise model - in multipass mapmaking.""" + in multipass mapmaking. This function is not used during the first pass + of the ML mapmaker. It is a bridge logic between passes.""" # Currently we don't support any actual translation, but could handle # resolution changes in the future (probably not useful though) self._checkcompat(other) + ctime = obs.timestamps # Build the local geometry and pointing matrix for this observation if self.recenter: - rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter, - ctime=ctime[len(ctime)//2], geom=(self.rhs.shape, self.rhs.wcs), site=unarr(obs.site))) - else: rot = None - pmap = coords.pmat.P.for_tod(obs, comps=self.comps, geom=self.rhs.geometry, - rot=rot, threads="domdir", weather=unarr(obs.weather), site=unarr(obs.site), - interpol=self.interpol) + rot = recentering_to_quat_lonlat( + *evaluate_recentering( + self.recenter, + ctime=ctime[len(ctime) // 2], + geom=(self.rhs.shape, self.rhs.wcs), + site=unarr(obs.site), + ) + ) + else: + rot = None + pmap = coords.pmat.P.for_tod( + obs, + comps=self.comps, + geom=self.rhs.geometry, + rot=rot, + threads="domdir", + weather=unarr(obs.weather), + site=unarr(obs.site), + interpol=self.interpol, + ) # Build the RHS for this observation - pmap.from_map(dest=tod, signal_map=map, comps=self.comps) + # These lines are not activated during the first pass of mapmaking. + map_work = self.to_work(map) + try: + pmap.from_map(dest=tod, signal_map=map_work, comps=self.comps) + except RuntimeError as e: + raise RuntimeError( + f"""{e}. 
+ Possibly caused by the assumption that exactly the same tiles will be hit each pass, + which can in rare cases break when downsampling by different amounts in different passes + when a tile is just barely hit by a single sample. This can be fixed by adding support + for constructing coords.pmat.P which treats hits to a missing tile as zero instead of + as an error. This also requires minor changes to so3g Projection.cxx. TODO.""" + ) + return tod + class SignalCut(Signal): - def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32, - output=False, cut_type=None): + def __init__( + self, + comm, + name="cut", + ofmt="{name}_{rank:02}", + dtype=np.float32, + output=False, + cut_type=None, + glitch_flags: str = "flags.glitch_flags", + ): """Signal for handling the ML solution for the values of the cut samples.""" - Signal.__init__(self, name, ofmt, output, ext="hdf") - self.comm = comm - self.data = {} + Signal.__init__(self, name, ofmt, output, ext="hdf", glitch_flags=glitch_flags) + self.comm = comm + self.data = {} self.dtype = dtype self.cut_type = cut_type - self.off = 0 - self.rhs = [] - self.div = [] + self.off = 0 + self.rhs = [] + self.div = [] - def add_obs(self, id, obs, nmat, Nd): + def add_obs(self, id, obs, nmat, Nd, glitch_flags: Optional[str] = None): """Add and process an observation. "obs" should be an Observation axis manager, nmat a noise model, representing the inverse noise covariance matrix, - and Nd the result of applying the noise model to the detector time-ordered data.""" - Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts - pcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type) + and Nd the result of applying the noise model to the detector time-ordered data. + """ + Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts + gflags = glitch_flags if glitch_flags is not None else self.glitch_flags + pcut = PmatCut(obs[gflags], model=self.cut_type) # Build our RHS obs_rhs = np.zeros(pcut.njunk, self.dtype) pcut.backward(Nd, obs_rhs) # Build our per-pixel inverse covmat obs_div = np.ones(pcut.njunk, self.dtype) - Nd[:] = 0 + Nd[:] = 0 pcut.forward(Nd, obs_div) - Nd *= nmat.ivar[:,None] + Nd *= nmat.ivar[:, None] pcut.backward(Nd, obs_div) - self.data[id] = bunch.Bunch(pcut=pcut, i1=self.off, i2=self.off+pcut.njunk) + self.data[id] = bunch.Bunch(pcut=pcut, i1=self.off, i2=self.off + pcut.njunk) self.off += pcut.njunk self.rhs.append(obs_rhs) self.div.append(obs_div) @@ -395,27 +588,31 @@ def add_obs(self, id, obs, nmat, Nd): def prepare(self): """Process the added observations, determining our degrees of freedom etc. 
Should be done before calling forward and backward.""" - if self.ready: return + if self.ready: + return self.rhs = np.concatenate(self.rhs) self.div = np.concatenate(self.div) self.dof = ArrayZipper(self.rhs.shape, dtype=self.dtype, comm=self.comm) self.ready = True def forward(self, id, tod, junk): - if id not in self.data: return + if id not in self.data: + return d = self.data[id] - d.pcut.forward(tod, junk[d.i1:d.i2]) + d.pcut.forward(tod, junk[d.i1 : d.i2]) def precon(self, junk): - return junk/self.div + return junk / self.div def backward(self, id, tod, junk): - if id not in self.data: return + if id not in self.data: + return d = self.data[id] - d.pcut.backward(tod, junk[d.i1:d.i2]) + d.pcut.backward(tod, junk[d.i1 : d.i2]) def write(self, prefix, tag, m): - if not self.output: return + if not self.output: + return if self.comm is None: rank = 0 else: @@ -436,12 +633,19 @@ def translate(self, other, junk): self._checkcompat(other) res = np.full(self.off, -1e10, self.dtype) for id in self.data: - sdata = self .data[id] + sdata = self.data[id] odata = other.data[id] - so3g.translate_cuts(odata.pcut.cuts, sdata.pcut.cuts, sdata.pcut.model, sdata.pcut.params, junk[odata.i1:odata.i2], res[sdata.i1:sdata.i2]) + so3g.translate_cuts( + odata.pcut.cuts, + sdata.pcut.cuts, + sdata.pcut.model, + sdata.pcut.params, + junk[odata.i1 : odata.i2], + res[sdata.i1 : sdata.i2], + ) return res - def transeval(self, id, obs, other, junk, tod): + def transeval(self, id, obs, other, junk, tod, glitch_flags: Optional[str] = None): """Translate data junk from SignalCut other to the current SignalCut, and then evaluate it for the given observation, returning a tod. This is used when building a signal-free tod for the noise model @@ -449,14 +653,22 @@ def transeval(self, id, obs, other, junk, tod): self._checkcompat(other) # We have to make a pointing matrix from scratch because add_obs # won't have been called yet at this point - spcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type) + gflags = glitch_flags if glitch_flags is not None else self.glitch_flags + spcut = PmatCut(obs[gflags], model=self.cut_type) # We do have one for other though, since that will be the output # from the previous round of multiplass mapmaking. odata = other.data[id] sjunk = np.zeros(spcut.njunk, junk.dtype) # Translate the cut degrees of freedom. The sample rate could have # changed, for example. - so3g.translate_cuts(odata.pcut.cuts, spcut.cuts, spcut.model, spcut.params, junk[odata.i1:odata.i2], sjunk) + so3g.translate_cuts( + odata.pcut.cuts, + spcut.cuts, + spcut.model, + spcut.params, + junk[odata.i1 : odata.i2], + sjunk, + ) # And project onto the tod spcut.forward(tod, sjunk) return tod diff --git a/sotodlib/mapmaking/utilities.py b/sotodlib/mapmaking/utilities.py index ac417d1dc..f6c714a7d 100644 --- a/sotodlib/mapmaking/utilities.py +++ b/sotodlib/mapmaking/utilities.py @@ -1,11 +1,13 @@ +from typing import Any, Union + +import importlib + import numpy as np -from pixell import enmap, utils, fft, tilemap, resample import so3g -import importlib +from pixell import enmap, fft, resample, tilemap, utils + +from .. import coords, core, tod_ops -from .. import core -from .. import tod_ops -from .. 
import coords def deslope_el(tod, el, srate, inplace=False): if not inplace: tod = tod.copy() @@ -24,7 +26,7 @@ def deslope_el(tod, el, srate, inplace=False): class ArrayZipper: def __init__(self, shape, dtype, comm=None): self.shape = shape - self.ndof = int(np.product(shape)) + self.ndof = int(np.prod(shape)) self.dtype = dtype self.comm = comm @@ -38,7 +40,7 @@ def dot(self, a, b): class MapZipper: def __init__(self, shape, wcs, dtype, comm=None): self.shape, self.wcs = shape, wcs - self.ndof = int(np.product(shape)) + self.ndof = int(np.prod(shape)) self.dtype = dtype self.comm = comm @@ -103,14 +105,14 @@ def inject_map(obs, map, recenter=None, interpol=None): rot = recentering_to_quat_lonlat(*evaluate_recentering(recenter, ctime=ctime[len(ctime)//2], geom=(map.shape, map.wcs), site=unarr(obs.site))) else: rot = None # Set up our pointing matrix for the map - pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=(map.shape, map.wcs), rot=rot, threads="domdir", interpol=self.interpol) + pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=(map.shape, map.wcs), rot=rot, threads="domdir", interpol=interpol) # And perform the actual injection - pmat.from_map(map.extract(shape, wcs), dest=obs.signal) + pmat.from_map(map.extract(map.shape, map.wcs), dest=obs.signal) def safe_invert_div(div, lim=1e-2, lim0=np.finfo(np.float32).tiny**0.5): try: # try setting up a context manager that limits the number of threads - from threadpoolctl import threadpool_limitse + from threadpoolctl import threadpool_limits cm = threadpool_limits(limits=1, user_api="blas") except: # threadpoolctl not available, need a dummy context manager @@ -137,7 +139,6 @@ def safe_invert_div(div, lim=1e-2, lim0=np.finfo(np.float32).tiny**0.5): return idiv - def measure_cov(d, nmax=10000): d = d[:,::max(1,d.shape[1]//nmax)] n,m = d.shape @@ -340,6 +341,7 @@ def evaluate_recentering(info, ctime, geom=None, site=None, weather="typical"): """Evaluate the quaternion that performs the coordinate recentering specified in info, which can be obtained from parse_recentering.""" import ephem + # Get the coordinates of the from, to and up points. This was a bit involved... def to_cel(lonlat, sys, ctime=None, site=None, weather=None): # Convert lonlat from sys to celestial coorinates. Maybe polish and put elswhere @@ -371,6 +373,7 @@ def recentering_to_quat_lonlat(p1, p2, pu): """Return the quaternion that represents the rotation that takes point p1 to p2, with the up direction pointing towards the point pu, all given as lonlat pairs""" from so3g.proj import quat + # 1. First rotate our point to the north pole: Ry(-(90-dec1))Rz(-ra1) # 2. Apply the same rotation to the up point. # 3. 
We want the up point to be upwards, so rotate it to ra = 180°: Rz(pi-rau2) @@ -440,8 +443,8 @@ def rangemat_sum(rangemat): res[i] = np.sum(ra[:,1]-ra[:,0]) return res -def find_usable_detectors(obs, maxcut=0.1): - ncut = rangemat_sum(obs.flags.glitch_flags) +def find_usable_detectors(obs, maxcut=0.1, glitch_flags: str = "flags.glitch_flags"): + ncut = rangemat_sum(obs[glitch_flags]) good = ncut < obs.samps.count * maxcut return obs.dets.vals[good] @@ -500,7 +503,7 @@ def downsample_obs(obs, down): if isinstance(val, core.AxisManager): res.wrap(key, val) else: - axdesc = [(k,v) for k,v in enumerate(axes) if v is not None] + axdesc = [(k, v) for k, v in enumerate(axes) if v is not None] res.wrap(key, val, axdesc) # The normal sample stuff res.wrap("timestamps", obs.timestamps[::down], [(0, "samps")]) @@ -508,48 +511,35 @@ def downsample_obs(obs, down): for key in ["az", "el", "roll"]: bore.wrap(key, getattr(obs.boresight, key)[::down], [(0, "samps")]) res.wrap("boresight", bore) - res.wrap("signal", resample.resample_fft_simple(obs.signal, onsamp), [(0,"dets"),(1,"samps")]) - - # The cuts - # obs.flags will contain all types of flags. We should query it for glitch_flags and source_flags - cut_keys = ["glitch_flags"] - - if "source_flags" in obs.flags: + res.wrap("signal", resample.resample_fft_simple(obs.signal, onsamp), + [(0,"dets"),(1,"samps")]) + + # # The cuts + # # obs.flags will contain all types of flags. We should query it for glitch_flags + # # and source_flags + cut_keys = [] + if "glitch_flags" in obs: + cut_keys.append("glitch_flags") + elif "flags.glitch_flags" in obs: + cut_keys.append("flags.glitch_flags") + + if "source_flags" in obs: cut_keys.append("source_flags") + elif "flags.source_flags" in obs: + cut_keys.append("flags.source_flags") # We need to add a res.flags FlagManager to res res = res.wrap('flags', core.FlagManager.for_tod(res)) for key in cut_keys: - res.flags.wrap(key, downsample_cut(getattr(obs.flags, key), down), [(0,"dets"),(1,"samps")]) + new_key = key.split(".")[-1] + res.flags.wrap(new_key, downsample_cut(obs[key], down), + [(0,"dets"),(1,"samps")]) # Not sure how to deal with flags. Some sort of or-binning operation? But it # doesn't matter anyway return res -def get_flags(obs, flagnames): - """Parse detector-set splits""" - cuts_out = None - if flagnames is None: - return so3g.proj.RangesMatrix.zeros(obs.shape) - det_splits = ['det_left','det_right','det_in','det_out','det_upper','det_lower'] - for flagname in flagnames: - if flagname in det_splits: - cuts = obs.det_flags[flagname] - elif flagname == 'scan_left': - cuts = obs.flags.left_scan - elif flagname == 'scan_right': - cuts = obs.flags.right_scan - else: - cuts = getattr(obs.flags, flagname) # obs.flags.flagname - - ## Add to the output matrix - if cuts_out is None: - cuts_out = cuts - else: - cuts_out += cuts - return cuts_out - def import_optional(module_name): try: module = importlib.import_module(module_name) diff --git a/sotodlib/obs_ops/splits.py b/sotodlib/obs_ops/splits.py index 279ae98af..0e1f41350 100644 --- a/sotodlib/obs_ops/splits.py +++ b/sotodlib/obs_ops/splits.py @@ -64,8 +64,10 @@ def det_splits_relative(aman, det_left_right=False, det_upper_lower=False, det_i def get_split_flags(aman, proc_aman=None, split_cfg=None): ''' - Function returning flags used for null splits consumed by the mapmaking and bundling codes. Fields labeled ``field_name_flag`` contain boolean masks and ``_avg`` are the mean - of the numerical based split flags to be used for observation level splits. 
+ Function returning flags used for null splits consumed by the mapmaking + and bundling codes. Fields labeled ``field_name_flag`` contain boolean + masks and ``_avg`` are the mean of the numerical based split flags to + be used for observation level splits. Arguments --------- @@ -80,55 +82,96 @@ def get_split_flags(aman, proc_aman=None, split_cfg=None): ------- split_aman: AxisManager Axis manager containing splitting flags. + ``cuts`` field is a FlagManager containing the detector and subscan based splits used in the mapmaker. + ``_threshold`` fields contain the threshold used for the split. + Other fields conatain info for obs-level splits. ''' - if proc_aman is None: - try: - proc_aman = aman.preprocess - except: - raise ValueError('proc_aman is None and no preprocess field in aman provide valid preprocess metadata') - - if (not 't2p' in proc_aman) | (not 'hwpss_stats' in proc_aman): - raise ValueError('t2p or hwpss_stats not in proc_aman must run after those steps in the pipeline.') - # Set default set of splits - default_cfg = {'high_gain': 0.115, 'high_noise': 3.5e-5, 'high_tau': 1.5e-3, - 'det_A': 'A', 'pol_angle': 35, 'det_top': 'B', 'high_leakage': 1e-3, - 'high_2f': 1.5e-3, 'right_focal_plane': 0, 'top_focal_plane': 0, - 'central_pixels': 0.071 } + default_cfg = {'high_gain': 0.115, 'high_tau': 1.5e-3, + 'det_A': 'A', 'pol_angle': 35, 'det_top': 'B', 'right_focal_plane': 0, + 'top_focal_plane': 0, 'central_pixels': 0.071 } if split_cfg is None: split_cfg = default_cfg split_aman = AxisManager(aman.dets) + fm = FlagManager.for_tod(aman) # If provided split config doesn't include all of the splits in default for k in default_cfg.keys(): if not k in split_cfg: split_cfg[k] = default_cfg[k] - split_aman.wrap(f'{k}_threshold', split_cfg[k]) - split_aman.wrap('high_gain_flag', aman.det_cal.phase_to_pW > split_cfg['high_gain'], - [(0, 'dets')]) + # Gain split + fm.wrap_dets('high_gain', aman.det_cal.phase_to_pW > split_cfg['high_gain']) + fm.wrap_dets('low_gain', aman.det_cal.phase_to_pW <= split_cfg['high_gain']) split_aman.wrap('gain_avg', np.nanmean(aman.det_cal.phase_to_pW)) - split_aman.wrap('high_noise_flag', proc_aman.noiseQ_fit.fit[:,1] > split_cfg['high_noise'], - [(0, 'dets')]) - split_aman.wrap('noise_avg', np.nanmean(proc_aman.noiseQ_fit.fit[:,1])) - split_aman.wrap('high_tau_flag', aman.det_cal.tau_eff > split_cfg['high_tau'], - [(0, 'dets')]) + # Time constant split + fm.wrap_dets('high_tau', aman.det_cal.tau_eff > split_cfg['high_tau']) + fm.wrap_dets('low_tau', aman.det_cal.tau_eff <= split_cfg['high_tau']) split_aman.wrap('tau_avg', np.nanmean(aman.det_cal.tau_eff)) - split_aman.wrap('det_A_flag', aman.det_info.wafer.pol <= split_cfg['det_A'], - [(0, 'dets')]) - split_aman.wrap('pol_angle_flag', aman.det_info.wafer.angle > split_cfg['pol_angle'], - [(0, 'dets')]) - split_aman.wrap('det_top_flag', aman.det_info.wafer.crossover > split_cfg['det_top'], - [(0, 'dets')]) - split_aman.wrap('high_leakage_flag', np.sqrt(proc_aman.t2p.lamQ**2 + proc_aman.t2p.lamU**2) > split_cfg['high_leakage'], - [(0, 'dets')]) - split_aman.wrap('leakage_avg', np.nanmean(np.sqrt(proc_aman.t2p.lamQ**2 + proc_aman.t2p.lamU**2)), - [(0, 'dets')]) - a2 = aman.det_cal.phase_to_pW*np.sqrt(proc_aman.hwpss_stats.coeffs[:,2]**2 + proc_aman.hwpss_stats.coeffs[:,3]**2) - split_aman.wrap('high_2f_flag', a2 > split_cfg['high_2f'], [(0, 'dets')]) - split_aman.wrap('2f_avg', np.nanmean(a2), [(0, 'dets')]) - split_aman.wrap('right_focal_plane_flag', aman.focal_plane.xi > split_cfg['right_focal_plane'], [(0, 
'dets')]) - split_aman.wrap('top_focal_plane_flag', aman.focal_plane.eta > split_cfg['top_focal_plane'], [(0, 'dets')]) - split_aman.wrap('central_pixels_flag', np.sqrt(aman.focal_plane.xi**2 + aman.focal_plane.eta**2) < split_cfg['central_pixels'], + # detAB split + fm.wrap_dets('det_A', aman.det_info.wafer.pol <= split_cfg['det_A']) + fm.wrap_dets('det_B', aman.det_info.wafer.pol > split_cfg['det_A']) + # def pol split + fm.wrap_dets('high_pol_angle', aman.det_info.wafer.angle > split_cfg['pol_angle']) + fm.wrap_dets('low_pol_angle', aman.det_info.wafer.angle <= split_cfg['pol_angle']) + # det top/bottom split + fm.wrap_dets('det_type_top', aman.det_info.wafer.crossover > split_cfg['det_top']) + fm.wrap_dets('det_type_bottom', aman.det_info.wafer.crossover <= split_cfg['det_top']) + # Right/left focal plane split + fm.wrap_dets('det_right', aman.focal_plane.xi > split_cfg['right_focal_plane']) + fm.wrap_dets('det_left', aman.focal_plane.xi <= split_cfg['right_focal_plane']) + # Top/bottom focal plane split + fm.wrap_dets('det_upper', aman.focal_plane.eta > split_cfg['top_focal_plane']) + fm.wrap_dets('det_lower', aman.focal_plane.eta <= split_cfg['top_focal_plane']) + # Inner/outter pixel split + r = np.sqrt(aman.focal_plane.xi**2 + aman.focal_plane.eta**2) + fm.wrap_dets('det_in', r < split_cfg['central_pixels']) + fm.wrap_dets('det_out', r >= split_cfg['central_pixels']) + + # Preproc dependent splits + if proc_aman is None: + try: + proc_aman = aman.preprocess + except: + print('Preprocess information not present, cannot generate preprocess dependent splits.') + for k in split_cfg.keys(): + split_aman.wrap(f'{k}_threshold', split_cfg[k]) + split_aman.wrap('cuts', fm) + return split_aman + + # This one is a bit funky to be forcing it to be noiseQ_fit, and units matter! 
+ if 'noiseQ_fit' in proc_aman: + # Noise split + if not 'high_noise' in split_cfg: + split_cfg['high_noise'] = 3.5e-5 + fm.wrap_dets('high_noise', proc_aman.noiseQ_fit.fit[:,1] > split_cfg['high_noise']) + fm.wrap_dets('low_noise', proc_aman.noiseQ_fit.fit[:,1] <= split_cfg['high_noise']) + split_aman.wrap('noise_avg', np.nanmean(proc_aman.noiseQ_fit.fit[:,1])) + + if 't2p' in proc_aman: + # T2P Leakage split + if not 'high_leakage' in split_cfg: + split_cfg['high_leakage'] = 1e-3 + fm.wrap_dets('high_leakage', np.sqrt(proc_aman.t2p.lamQ**2 + proc_aman.t2p.lamU**2) > split_cfg['high_leakage']) + fm.wrap_dets('low_leakage', np.sqrt(proc_aman.t2p.lamQ**2 + proc_aman.t2p.lamU**2) <= split_cfg['high_leakage']) + split_aman.wrap('leakage_avg', np.nanmean(np.sqrt(proc_aman.t2p.lamQ**2 + proc_aman.t2p.lamU**2)), [(0, 'dets')]) + if 'hwpss_stats' in proc_aman: + # High 2f amplitude split + if not 'high_2f' in split_cfg: + split_cfg['high_2f'] = 1.5e-3 + a2 = aman.det_cal.phase_to_pW*np.sqrt(proc_aman.hwpss_stats.coeffs[:,2]**2 + proc_aman.hwpss_stats.coeffs[:,3]**2) + fm.wrap_dets('high_2f', a2 > split_cfg['high_2f']) + fm.wrap_dets('low_2f', a2 <= split_cfg['high_2f']) + split_aman.wrap('2f_avg', np.nanmean(a2), [(0, 'dets')]) + # Left/right subscans + if 'turnaround_flags' in proc_aman: + fm.wrap('scan_left', proc_aman.turnaround_flags.left_scan) + fm.wrap('scan_right', proc_aman.turnaround_flags.right_scan) + + for k in split_cfg.keys(): + split_aman.wrap(f'{k}_threshold', split_cfg[k]) + + split_aman.wrap('cuts', fm) + return split_aman diff --git a/sotodlib/preprocess/pcore.py b/sotodlib/preprocess/pcore.py index e6bf5a42e..7f7840f2c 100644 --- a/sotodlib/preprocess/pcore.py +++ b/sotodlib/preprocess/pcore.py @@ -1,10 +1,12 @@ """Base Class and PIPELINE register for the preprocessing pipeline scripts.""" import os +import copy import logging import numpy as np from .. import core from so3g.proj import Ranges, RangesMatrix from scipy.sparse import csr_array +from matplotlib import pyplot as plt class _Preprocess(object): """The base class for Preprocessing modules which defines the required @@ -270,7 +272,7 @@ def _expand(new, full, wrap_valid=True): continue out.wrap_new( k, new._assignments[k], cls=_zeros_cls(v)) oidx=[]; nidx=[] - for a in new._assignments[k]: + for ii, a in enumerate(new._assignments[k]): if a == 'dets': oidx.append(fs_dets) nidx.append(ns_dets) @@ -278,8 +280,19 @@ def _expand(new, full, wrap_valid=True): oidx.append(fs_samps) nidx.append(ns_samps) else: - oidx.append(slice(None)) - nidx.append(slice(None)) + if (ii == 0) and isinstance(out[k], RangesMatrix): # Treat like dets + # _ranges_matrix_match expects oidx[0] and nidx[0] to be list(inds), not slice. + # Unknown axes treated as dets if first entry, else like samps. Added to support (subscans, samps) RangesMatrix. 
+ if a in full._axes: + _, fs, ns = full[a].intersection(new[a], return_slices=True) + else: + fs = range(new[a].count) + ns = range(new[a].count) + oidx.append(fs) + nidx.append(ns) + else: # Treat like samps + oidx.append(slice(None)) + nidx.append(slice(None)) oidx = tuple(oidx) nidx = tuple(nidx) if isinstance(out[k], RangesMatrix): @@ -357,7 +370,7 @@ def __init__(self, modules, plot_dir='./', logger=None, wrap_valid=True): self.logger = logger self.plot_dir = plot_dir self.wrap_valid = wrap_valid - super().__init__( [self._check_item(item) for item in modules]) + super().__init__( [self._check_item(item) for item in copy.deepcopy(modules)]) def _check_item(self, item): if isinstance(item, _Preprocess): @@ -431,8 +444,12 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False): """ if proc_aman is None: - proc_aman = core.AxisManager( aman.dets, aman.samps) - full = core.AxisManager( aman.dets, aman.samps) + if 'preprocess' in aman: + proc_aman = aman.preprocess.copy() + full = aman.preprocess.copy() + else: + proc_aman = core.AxisManager(aman.dets, aman.samps) + full = core.AxisManager( aman.dets, aman.samps) run_calc = True update_plot = False else: @@ -456,6 +473,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False): update_full_aman( proc_aman, full, self.wrap_valid) if update_plot: process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png')) + plt.close() if select: process.select(aman, proc_aman) proc_aman.restrict('dets', aman.dets.vals) diff --git a/sotodlib/preprocess/preprocess_plot.py b/sotodlib/preprocess/preprocess_plot.py index 83f3325d0..3985cfc81 100644 --- a/sotodlib/preprocess/preprocess_plot.py +++ b/sotodlib/preprocess/preprocess_plot.py @@ -390,6 +390,44 @@ def plot_trending_flags(aman, trend_aman, filename='./trending_flags.png'): os.makedirs(head_tail[0], exist_ok=True) plt.savefig(filename) +def plot_signal(aman, signal=None, xx=None, signal_name="signal", x_name="timestamps", plot_ds_factor=50, plot_ds_factor_dets=None, xlim=None, alpha=0.2, yscale='linear', y_unit=None, filename="./signal.png"): + from operator import attrgetter + if plot_ds_factor_dets is None: + plot_ds_factor_dets = plot_ds_factor + if signal is None: + signal = attrgetter(signal_name)(aman) + if xx is None: + xx = attrgetter(x_name)(aman) + yy = signal[::plot_ds_factor_dets, 1::plot_ds_factor].copy() # (dets, samps); (dets, nusamps); (dets, nusamps, subscans) + xx = xx[1::plot_ds_factor].copy() # (samps); (nusamps) + if x_name == "timestamps": + xx -= xx[0] + if yy.ndim > 2: # Flatten subscan axis into dets + yy = yy.swapaxes(1,2).reshape(-1, yy.shape[1]) + + if xlim is not None: + xinds = np.logical_and(xx >= xlim[0], xx <= xlim[1]) + xx = xx[xinds] + yy = yy[:,xinds] + + fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8)) + ax.plot(xx, yy.T, color='k', alpha=0.2) + ax.set_yscale(yscale) + if "freqs" in x_name: + ax.set_xlabel("freq [Hz]") + else: + ax.set_xlabel(f"{x_name} [s]") + y_unit = "" if y_unit is None else f" [{y_unit}]" + ax.set_ylabel(f"{signal_name.replace('.Pxx', '')}{y_unit}") + plt.suptitle(f"{aman.obs_info.obs_id}, dT = {np.ptp(aman.timestamps)/60:.1f} min") + plt.tight_layout() + head_tail = os.path.split(filename) + os.makedirs(head_tail[0], exist_ok=True) + plt.savefig(filename) + +def plot_psd(aman, signal=None, xx=None, signal_name="psd.Pxx", x_name="psd.freqs", plot_ds_factor=4, plot_ds_factor_dets=20, xlim=None, alpha=0.2, yscale='log', y_unit=None, 
filename="./psd.png"): + return plot_signal(aman, signal, xx, signal_name, x_name, plot_ds_factor, plot_ds_factor_dets, xlim, alpha, yscale, y_unit, filename) + def plot_signal_diff(aman, flag_aman, flag_type="glitches", flag_threshold=10, plot_ds_factor=50, filename="./glitch_signal_diff.png"): """ Function for plotting the difference in signal before and after cuts from either glitches or jumps. diff --git a/sotodlib/preprocess/preprocess_util.py b/sotodlib/preprocess/preprocess_util.py index e8262f3ec..c15298ddc 100644 --- a/sotodlib/preprocess/preprocess_util.py +++ b/sotodlib/preprocess/preprocess_util.py @@ -7,6 +7,7 @@ import numpy as np import h5py import traceback +import inspect from .. import core @@ -175,19 +176,25 @@ def get_groups(obs_id, configs, context): groups : list of list of int The list of groups of detectors. """ - group_by = np.atleast_1d(configs['subobs'].get('use', 'detset')) - for i, gb in enumerate(group_by): - if gb.startswith('dets:'): - group_by[i] = gb.split(':',1)[1] - - if (gb == 'detset') and (len(group_by) == 1): - groups = context.obsfiledb.get_detsets(obs_id) - return group_by, [[g] for g in groups] - - det_info = context.get_det_info(obs_id) - rs = det_info.subset(keys=group_by).distinct() - groups = [[b for a,b in r.items()] for r in rs] - return group_by, groups + try: + group_by = np.atleast_1d(configs['subobs'].get('use', 'detset')) + for i, gb in enumerate(group_by): + if gb.startswith('dets:'): + group_by[i] = gb.split(':',1)[1] + + if (gb == 'detset') and (len(group_by) == 1): + groups = context.obsfiledb.get_detsets(obs_id) + return group_by, [[g] for g in groups], None + + det_info = context.get_det_info(obs_id) + rs = det_info.subset(keys=group_by).distinct() + groups = [[b for a,b in r.items()] for r in rs] + return group_by, groups, None + except Exception as e: + error = f'Failed get groups for: {obs_id}' + errmsg = f'{type(e)}: {e}' + tb = ''.join(traceback.format_tb(e.__traceback__)) + return [], [], [error, errmsg, tb] def get_preprocess_db(configs, group_by, logger=None): @@ -265,20 +272,22 @@ def load_preprocess_det_select(obs_id, configs, context=None, Arguments ---------- obs_id: multiple - passed to `context.get_obs` to load AxisManager, see Notes for + Passed to `context.get_obs` to load AxisManager, see Notes for `context.get_obs` configs: string or dictionary - config file or loaded config directory + Config file or loaded config directory + context: core.Context + The Context file to use. dets: dict - dets to restrict on from info in det_info. See context.get_meta. + Dets to restrict on from info in det_info. See context.get_meta. meta: AxisManager Contains supporting metadata to use for loading. Can be pre-restricted in any way. See context.get_meta. - logger : PythonLogger + logger: PythonLogger Optional. Logger object. If None, a new logger is created. """ - + if logger is None: logger = init_logger("preprocess") @@ -290,9 +299,10 @@ def load_preprocess_det_select(obs_id, configs, context=None, pipe[-1].select(meta) return meta + def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, no_signal=None, logger=None): - """ Loads the saved information from the preprocessing pipeline and runs + """Loads the saved information from the preprocessing pipeline and runs the processing section of the pipeline. Assumes preprocess_tod has already been run on the requested observation. 
@@ -300,12 +310,14 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, Arguments ---------- obs_id: multiple - passed to `context.get_obs` to load AxisManager, see Notes for + Passed to `context.get_obs` to load AxisManager, see Notes for `context.get_obs` configs: string or dictionary - config file or loaded config directory + Config file or loaded config directory + context: core.Context + Optional. The Context file to use. dets: dict - dets to restrict on from info in det_info. See context.get_meta. + Dets to restrict on from info in det_info. See context.get_meta. meta: AxisManager Contains supporting metadata to use for loading. Can be pre-restricted in any way. See context.get_meta. @@ -313,17 +325,17 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, If True, signal will be set to None. This is a way to get the axes and pointing info without the (large) TOD blob. Not all loaders may support this. - logger : PythonLogger + logger: PythonLogger Optional. Logger object. If None, a new logger is created. """ - + if logger is None: logger = init_logger("preprocess") - + configs, context = get_preprocess_context(configs, context) meta = load_preprocess_det_select(obs_id, configs=configs, context=context, - dets=dets, meta=meta) + dets=dets, meta=meta, logger=logger) if meta.dets.count == 0: logger.info(f"No detectors left after cuts in obs {obs_id}") @@ -335,35 +347,324 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, return aman -def preproc_or_load_group(obs_id, configs, dets, logger=None, - context=None, overwrite=False): +def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, + dets=None, meta=None, no_signal=None, + logger=None): + """Loads the saved information from the preprocessing pipeline from a + reference and a dependent database and runs the processing section of + the pipeline for each. + + Assumes preprocess_tod and multilayer_preprocess_tod have already been run + on the requested observation. + + Arguments + ---------- + obs_id: multiple + Passed to `context.get_obs` to load AxisManager, see Notes for + `context.get_obs` + configs_init: string or dictionary + Config file or loaded config directory + configs_proc: string or dictionary + Second config file or loaded config dictionary to load + dependent databases generated using multilayer_preprocess_tod.py. + dets: dict + Dets to restrict on from info in det_info. See context.get_meta. + meta: AxisManager + Contains supporting metadata to use for loading. + Can be pre-restricted in any way. See context.get_meta. + no_signal: bool + If True, signal will be set to None. + This is a way to get the axes and pointing info without + the (large) TOD blob. Not all loaders may support this. + logger: PythonLogger + Optional. Logger object or None will generate a new one. 
+ """ + + if logger is None: + logger = init_logger("preprocess") + + configs_init, context_init = get_preprocess_context(configs_init) + meta_init = context_init.get_meta(obs_id, dets=dets, meta=meta) + + configs_proc, context_proc = get_preprocess_context(configs_proc) + meta_proc = context_proc.get_meta(obs_id, dets=dets, meta=meta) + + group_by_init, groups_init, error_init = get_groups(obs_id, configs_init, context_init) + group_by_proc, groups_proc, error_proc = get_groups(obs_id, configs_proc, context_proc) + + if error_init is not None: + raise ValueError(f"{error_init[0]}\n{error_init[1]}\n{error_init[2]}") + + if error_proc is not None: + raise ValueError(f"{error_proc[0]}\n{error_proc[1]}\n{error_proc[2]}") + + if (group_by_init != group_by_proc).any(): + raise ValueError('init and proc groups do not match') + + if meta_init.dets.count == 0 or meta_proc.dets.count == 0: + logger.info(f"No detectors in obs {obs_id}") + return None + else: + pipe_init = Pipeline(configs_init["process_pipe"], logger=logger) + aman_cfgs_ref = get_pcfg_check_aman(pipe_init) + + if check_cfg_match(aman_cfgs_ref, meta_proc.preprocess['pcfg_ref'], + logger=logger): + aman = context_init.get_obs(meta_proc, no_signal=no_signal) + logger.info("Running initial pipeline") + pipe_init.run(aman, aman.preprocess) + + pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger) + logger.info("Running dependent pipeline") + proc_aman = context_proc.get_meta(obs_id, meta=meta_proc) + + aman.preprocess.merge(proc_aman.preprocess) + + pipe_proc.run(aman, aman.preprocess) + + return aman + else: + raise ValueError('Dependency check between configs failed.') + + +def find_db(obs_id, configs, dets, context=None, logger=None): + """This function checks if the manifest db from + a config file exists and searches if it contains + an entry for the provided Obs id and set of detectors. + + Arguments + ---------- + obs_id: str + Obs id to process or load + configs: fpath or dict + Filepath or dictionary containing the preprocess configuration file. + dets: dict + Dictionary specifying which detectors/wafers to load see ``Context.obsdb.get_obs``. + context: core.Context + Optional. Context object used for data loading/querying. + logger: PythonLogger + Optional. Logger object or None will generate a new one. + + Returns + ------- + dbexist : bool + True if db exists and entry for input detectors is found. 
+ """ + + if logger is None: + logger = init_logger("preprocess") + + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + if context is None: + context = core.Context(configs["context_file"]) + group_by, _, _ = get_groups(obs_id, configs, context) + cur_groups = [list(np.fromiter(dets.values(), dtype='/.h5 and then this adds # /_.h5 where xxx is a number that increments up from 0 @@ -493,13 +901,12 @@ def cleanup_mandb(error, outputs, configs, logger=None): folder = os.path.dirname(configs['archive']['policy']['filename']) basename = os.path.splitext(configs['archive']['policy']['filename'])[0] dest_file = basename + '_' + str(nfile).zfill(3) + '.h5' - if not(os.path.exists(folder)): - os.makedirs(folder) + if os.path.isabs(folder) and not(os.path.exists(folder)): + os.makedirs(folder) while os.path.exists(dest_file) and os.path.getsize(dest_file) > 10e9: nfile += 1 dest_file = basename + '_' + str(nfile).zfill(3) + '.h5' group_by = [k.split(':')[-1] for k in outputs['db_data'].keys() if 'dets' in k] - db = get_preprocess_db(configs, group_by, logger) h5_path = os.path.relpath(dest_file, start=os.path.dirname(configs['archive']['index'])) @@ -507,11 +914,15 @@ def cleanup_mandb(error, outputs, configs, logger=None): with h5py.File(dest_file,'a') as f_dest: with h5py.File(src_file,'r') as f_src: for dts in f_src.keys(): + # If the dataset or group already exists, delete it to overwrite + if overwrite and dts in f_dest: + del f_dest[dts] f_src.copy(f_src[f'{dts}'], f_dest, f'{dts}') for member in f_src[dts]: if isinstance(f_src[f'{dts}/{member}'], h5py.Dataset): f_src.copy(f_src[f'{dts}/{member}'], f_dest[f'{dts}'], f'{dts}/{member}') logger.info(f"Saving to database under {outputs['db_data']}") + db = get_preprocess_db(configs, group_by, logger) if len(db.inspect(outputs['db_data'])) == 0: db.add_entry(outputs['db_data'], h5_path) os.remove(src_file) @@ -524,5 +935,66 @@ def cleanup_mandb(error, outputs, configs, logger=None): errlog = os.path.join(folder, 'errlog.txt') f = open(errlog, 'a') f.write(f'{time.time()}, {error}\n') - f.write(f'\t{outputs[0]}\n\t{outputs[1]}\n') + if outputs is not None: + f.write(f'\t{outputs[0]}\n\t{outputs[1]}\n') f.close() + + +def get_pcfg_check_aman(pipe): + """ + Given a preprocess pipeline class return an axis manager containing + the ordered steps of the pipeline with all arguments for each step. + """ + pcfg_ref = core.AxisManager() + for i, pp in enumerate(pipe): + pcfg_ref.wrap(f'{i}_{pp.name}', core.AxisManager()) + for memb in inspect.getmembers(pp, lambda a:not(inspect.isroutine(a))): + if not memb[0][0] == '_': + if type(memb[1]) is dict: + pcfg_ref[f'{i}_{pp.name}'].wrap(memb[0], core.AxisManager()) + for itm in memb[1].items(): + pcfg_ref[f'{i}_{pp.name}'][memb[0]].wrap(itm[0], str(itm[1])) + else: + pcfg_ref[f'{i}_{pp.name}'].wrap(memb[0], memb[1]) + return pcfg_ref + + +def _check_assignment_length(a, b): + """ + Helper function to check if the set of assignments in axis manager ``a`` matches + the length of assignments in axis manager ``b``. 
+ """ + aa = np.fromiter(a._assignments.keys(), dtype='= low_f, freqs <= high_f], axis=0) + self.calc_cfgs['mask'] = fmask + del self.calc_cfgs['psd_mask'] + + _f = attrgetter(self.signal) + try: + signal = _f(aman) + except KeyError: + signal = _f(proc_aman) + stats_aman = tod_ops.flags.get_stats(aman, signal, **self.calc_cfgs) + self.save(proc_aman, stats_aman) + + def save(self, proc_aman, stats_aman): + if not(self.save_cfgs is None): + proc_aman.wrap(self.wrap, stats_aman) + + def plot(self, aman, proc_aman, filename): + if self.plot_cfgs is None: + return + if self.plot_cfgs: + from .preprocess_plot import plot_signal + + filename = filename.replace('{ctime}', f'{str(aman.timestamps[0])[:5]}') + filename = filename.replace('{obsid}', aman.obs_info.obs_id) + det = aman.dets.vals[0] + ufm = det.split('_')[2] + filename = filename.replace('{name}', f'{ufm}_{self.signal}') + + plot_signal(aman, signal_name=self.signal, x_name="timestamps", filename=filename, **self.plot_cfgs) class Noise(_Preprocess): """Estimate the white noise levels in the data. Assumes the PSD has been - wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`. + wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`. Saves the results into the "noise" field of proc_aman. @@ -391,6 +480,8 @@ class Noise(_Preprocess): Example config block:: - name: "noise" + fit: False + subscan: False calc: low_f: 5 high_f: 10 @@ -408,6 +499,7 @@ class Noise(_Preprocess): def __init__(self, step_cfgs): self.psd = step_cfgs.get('psd', 'psd') self.fit = step_cfgs.get('fit', False) + self.subscan = step_cfgs.get('subscan', False) super().__init__(step_cfgs) @@ -415,21 +507,28 @@ def calc_and_save(self, aman, proc_aman): if self.psd not in proc_aman: raise ValueError("PSD is not saved in Preprocessing AxisManager") psd = proc_aman[self.psd] - + pxx = psd.Pxx_ss if self.subscan else psd.Pxx + if self.calc_cfgs is None: self.calc_cfgs = {} - + if self.fit: - calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=psd.Pxx, + if self.calc_cfgs.get('subscan') is None: + self.calc_cfgs['subscan'] = self.subscan + calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=pxx, f=psd.freqs, merge_fit=True, **self.calc_cfgs) else: - wn = tod_ops.fft_ops.calc_wn(aman, pxx=psd.Pxx, + wn = tod_ops.fft_ops.calc_wn(aman, pxx=pxx, freqs=psd.freqs, **self.calc_cfgs) - calc_aman = core.AxisManager(aman.dets) - calc_aman.wrap("white_noise", wn, [(0,"dets")]) + if not self.subscan: + calc_aman = core.AxisManager(aman.dets) + calc_aman.wrap("white_noise", wn, [(0,"dets")]) + else: + calc_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans) + calc_aman.wrap("white_noise", wn, [(0,"dets"), (1,"subscans")]) self.save(proc_aman, calc_aman) @@ -454,13 +553,28 @@ def select(self, meta, proc_aman=None): if proc_aman is None: proc_aman = meta.preprocess - self.select_cfgs['name'] = self.select_cfgs.get('name','noise') + if 'wrap_name' in self.save_cfgs: + self.select_cfgs['name'] = self.select_cfgs.get('name', self.save_cfgs['wrap_name']) + else: + self.select_cfgs['name'] = self.select_cfgs.get('name', 'noise') if self.fit: - keep = proc_aman[self.select_cfgs['name']].fit[:,1] <= self.select_cfgs["max_noise"] + wn = proc_aman[self.select_cfgs['name']].fit[:,1] + fk = proc_aman[self.select_cfgs['name']].fit[:,0] else: - keep = proc_aman[self.select_cfgs['name']].white_noise <= self.select_cfgs["max_noise"] - + wn = proc_aman[self.select_cfgs['name']].white_noise + fk = None + if self.subscan: + wn = 
np.nanmean(wn, axis=-1) # Mean over subscans + if fk is not None: + fk = np.nanmean(fk, axis=-1) # Mean over subscans + keep = np.ones_like(wn, dtype=bool) + if "max_noise" in self.select_cfgs.keys(): + keep &= (wn <= np.float64(self.select_cfgs["max_noise"])) + if "min_noise" in self.select_cfgs.keys(): + keep &= (wn >= np.float64(self.select_cfgs["min_noise"])) + if fk is not None and "max_fknee" in self.select_cfgs.keys(): + keep &= (fk <= np.float64(self.select_cfgs["max_fknee"])) meta.restrict("dets", meta.dets.vals[keep]) return meta @@ -766,7 +880,7 @@ class FlagTurnarounds(_Preprocess): .. autofunction:: sotodlib.tod_ops.flags.get_turnaround_flags """ name = 'flag_turnarounds' - + def calc_and_save(self, aman, proc_aman): if self.calc_cfgs is None: self.calc_cfgs = {} @@ -780,12 +894,15 @@ def calc_and_save(self, aman, proc_aman): calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')]) calc_aman.wrap('left_scan', left, [(0, 'dets'), (1, 'samps')]) calc_aman.wrap('right_scan', right, [(0, 'dets'), (1, 'samps')]) - + if self.calc_cfgs['method'] == 'az': ta = tod_ops.flags.get_turnaround_flags(aman, **self.calc_cfgs) calc_aman = core.AxisManager(aman.dets, aman.samps) calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')]) + if ('merge_subscans' not in self.calc_cfgs) or (self.calc_cfgs['merge_subscans']): + calc_aman.wrap('subscan_info', aman.subscan_info) + self.save(proc_aman, calc_aman) def save(self, proc_aman, turn_aman): @@ -796,7 +913,7 @@ def save(self, proc_aman, turn_aman): def process(self, aman, proc_aman): tod_ops.flags.get_turnaround_flags(aman, **self.process_cfgs) - + class SubPolyf(_Preprocess): """Fit TOD in each subscan with polynominal of given order and subtract it. All process configs go to `sotodlib.tod_ops.sub_polyf`. @@ -904,58 +1021,79 @@ class SourceFlags(_Preprocess): Example config block:: - name : "source_flags" - signal: "signal" # optional calc: mask: {'shape': 'circle', - 'xyr': (0, 0, 1.)} - center_on: 'jupiter' - res: 0.005817764173314432 # np.radians(20/60) - max_pix: 4e6 + 'xyr': [0, 0, 1.]} + center_on: ['jupiter', 'moon'] # list of str + res: 20 # arcmin + max_pix: 4000000 # max number of allowed pixels in map + distance: 0 # max distance of footprint from source in degrees save: True select: True # optional - + .. 
autofunction:: sotodlib.tod_ops.flags.get_source_flags """ name = "source_flags" - + def calc_and_save(self, aman, proc_aman): - center_on = self.calc_cfgs.get('center_on', 'planet') - # Get source from tags - if center_on == 'planet': + source_list = np.atleast_1d(self.calc_cfgs.get('center_on', 'planet')) + if source_list == ['planet']: from sotodlib.coords.planets import SOURCE_LIST - matches = [x for x in aman.tags if x in SOURCE_LIST] - if len(matches) != 0: - source = matches[0] - else: + source_list = [x for x in aman.tags if x in SOURCE_LIST] + if len(source_list) == 0: raise ValueError("No tags match source list") - else: - source = center_on - source_flags = tod_ops.flags.get_source_flags(aman, - merge=self.calc_cfgs.get('merge', False), - overwrite=self.calc_cfgs.get('overwrite', True), - source_flags_name=self.calc_cfgs.get('source_flags_name', 'source_flags'), - mask=self.calc_cfgs.get('mask', None), - center_on=source, - res=self.calc_cfgs.get('res', None), - max_pix=self.calc_cfgs.get('max_pix', None)) + + # find if source is within footprint + distance + positions = planets.get_nearby_sources(tod=aman, source_list=source_list, + distance=self.calc_cfgs.get('distance', 0)) source_aman = core.AxisManager(aman.dets, aman.samps) - source_aman.wrap('source_flags', source_flags, [(0, 'dets'), (1, 'samps')]) + for p in positions: + source_flags = tod_ops.flags.get_source_flags(aman, + merge=self.calc_cfgs.get('merge', False), + overwrite=self.calc_cfgs.get('overwrite', True), + source_flags_name=self.calc_cfgs.get('source_flags_name', None), + mask=self.calc_cfgs.get('mask', None), + center_on=p[0], + res=self.calc_cfgs.get('res', None), + max_pix=self.calc_cfgs.get('max_pix', None)) + + source_aman.wrap(p[0], source_flags, [(0, 'dets'), (1, 'samps')]) + + # add sources that were not nearby from source list + for source in source_list: + if source not in source_aman._fields: + source_aman.wrap(source, RangesMatrix.zeros([aman.dets.count, aman.samps.count]), + [(0, 'dets'), (1, 'samps')]) + self.save(proc_aman, source_aman) - + def save(self, proc_aman, source_aman): if self.save_cfgs is None: return if self.save_cfgs: - proc_aman.wrap("sources", source_aman) + proc_aman.wrap("source_flags", source_aman) def select(self, meta, proc_aman=None): if self.select_cfgs is None: return meta if proc_aman is None: - proc_aman = meta.preprocess - keep = ~has_any_cuts(proc_aman.sources.source_flags) - meta.restrict("dets", meta.dets.vals[keep]) + source_flags = meta.preprocess.source_flags + else: + source_flags = proc_aman.source_flags + + source_list = np.atleast_1d(self.calc_cfgs.get('center_on', 'planet')) + if source_list == ['planet']: + from sotodlib.coords.planets import SOURCE_LIST + source_list = [x for x in aman.tags if x in SOURCE_LIST] + if len(source_list) == 0: + raise ValueError("No tags match source list") + + for source in source_list: + if source in source_flags._fields: + keep = ~has_all_cut(source_flags[source]) + meta.restrict("dets", meta.dets.vals[keep]) + source_flags.restrict("dets", source_flags.dets.vals[keep]) return meta class HWPAngleModel(_Preprocess): @@ -993,7 +1131,7 @@ def save(self, proc_aman, hwp_angle_aman): return if self.save_cfgs: proc_aman.wrap("hwp_angle", hwp_angle_aman) - + class FourierFilter(_Preprocess): """ @@ -1024,6 +1162,7 @@ class FourierFilter(_Preprocess): See :ref:`fourier-filters` documentation for more details. 
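    A minimal standalone sketch of the kind of Fourier-domain filtering this
    step is built around (editorial example, not part of the patch; the
    synthetic AxisManager, the 200 Hz sample rate, and the filter parameters
    are illustrative assumptions, mirroring the ``lpf`` block used by the
    PCARelCal step in this module)::

        import numpy as np
        from sotodlib import core, tod_ops

        # Build a small synthetic TOD: 3 detectors, 2000 samples at ~200 Hz.
        aman = core.AxisManager(
            core.LabelAxis('dets', ['det%i' % i for i in range(3)]),
            core.OffsetAxis('samps', 2000))
        aman.wrap_new('signal', ('dets', 'samps'), dtype='float32')
        aman.wrap_new('timestamps', ('samps',))[:] = np.arange(aman.samps.count) / 200.

        # Low-pass filter built from a config dict of the same form as the
        # "lpf" block below, applied without overwriting aman.signal.
        lpf = tod_ops.filters.get_lpf({'type': 'sine2', 'cutoff': 1., 'trans_width': 0.1})
        filt_signal = tod_ops.fourier_filter(aman, lpf, signal_name='signal')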
""" name = 'fourier_filter' + def __init__(self, step_cfgs): self.signal_name = step_cfgs.get('signal_name', 'signal') # By default signal is overwritted by the filtered signal @@ -1083,9 +1222,9 @@ class PCARelCal(_Preprocess): yfac: 1.5 calc_good_medianw: True lpf: - type: "low_pass_sine2" + type: "sine2" cutoff: 1 - width: 0.1 + trans_width: 0.1 trim_samps: 2000 save: True plot: @@ -1094,6 +1233,7 @@ class PCARelCal(_Preprocess): See :ref:`pca-background` for more details on the method. """ name = 'pca_relcal' + def __init__(self, step_cfgs): self.signal = step_cfgs.get('signal', 'signal') self.run = step_cfgs.get('pca_run', 'run1') @@ -1102,6 +1242,7 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) def calc_and_save(self, aman, proc_aman): + self.plot_signal = self.signal if self.calc_cfgs.get("lpf") is not None: filt = tod_ops.filters.get_lpf(self.calc_cfgs.get("lpf")) filt_tod = tod_ops.fourier_filter(aman, filt, signal_name='signal') @@ -1111,16 +1252,17 @@ def calc_and_save(self, aman, proc_aman): if self.calc_cfgs.get("trim_samps") is not None: trim = self.calc_cfgs["trim_samps"] - aman.restrict('samps', (aman.samps.offset + trim, - aman.samps.offset + aman.samps.count - trim)) proc_aman.restrict('samps', (proc_aman.samps.offset + trim, proc_aman.samps.offset + proc_aman.samps.count - trim)) filt_aman.restrict('samps', (filt_aman.samps.offset + trim, filt_aman.samps.offset + filt_aman.samps.count - trim)) + if self.plot_cfgs: + self.plot_signal = filt_aman[self.signal] bands = np.unique(aman.det_info.wafer.bandpass) bands = bands[bands != 'NC'] - rc_aman = core.AxisManager(aman.dets, aman.samps) + # align samps w/ proc_aman to include samps restriction when loading back from db. + rc_aman = core.AxisManager(proc_aman.dets, proc_aman.samps) pca_det_mask = np.full(aman.dets.count, False, dtype=bool) relcal = np.zeros(aman.dets.count) pca_weight0 = np.zeros(aman.dets.count) @@ -1168,7 +1310,7 @@ def select(self, meta, proc_aman=None): keep = ~proc_aman[self.run_name]['pca_det_mask'] meta.restrict("dets", meta.dets.vals[keep]) return meta - + def plot(self, aman, proc_aman, filename): if self.plot_cfgs is None: return @@ -1184,7 +1326,7 @@ def plot(self, aman, proc_aman, filename): for band in bands: pca_aman = aman.restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False) band_aman = proc_aman[self.run_name].restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False) - plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20)) + plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.plot_signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20)) class PTPFlags(_Preprocess): @@ -1211,13 +1353,13 @@ def calc_and_save(self, aman, proc_aman): ptp_aman = core.AxisManager(aman.dets, aman.samps) ptp_aman.wrap('ptp_flags', mskptps, [(0, 'dets'), (1, 'samps')]) self.save(proc_aman, ptp_aman) - - def save(self, proc_aman, dark_aman): + + def save(self, proc_aman, calc_aman): if self.save_cfgs is None: return if self.save_cfgs: - proc_aman.wrap("ptp_flags", dark_aman) - + proc_aman.wrap("ptp_flags", calc_aman) + def select(self, meta, proc_aman=None): if self.select_cfgs is None: return meta @@ -1251,13 +1393,13 @@ def calc_and_save(self, aman, proc_aman): ptp_aman = core.AxisManager(aman.dets, aman.samps) 
ptp_aman.wrap('inv_var_flags', mskptps, [(0, 'dets'), (1, 'samps')]) self.save(proc_aman, ptp_aman) - + def save(self, proc_aman, dark_aman): if self.save_cfgs is None: return if self.save_cfgs: proc_aman.wrap("inv_var_flags", dark_aman) - + def select(self, meta, proc_aman=None): if self.select_cfgs is None: return meta @@ -1266,7 +1408,7 @@ def select(self, meta, proc_aman=None): keep = ~has_all_cut(proc_aman.inv_var_flags.inv_var_flags) meta.restrict("dets", meta.dets.vals[keep]) return meta - + class EstimateT2P(_Preprocess): """Estimate T to P leakage coefficients. @@ -1294,7 +1436,7 @@ class EstimateT2P(_Preprocess): def calc_and_save(self, aman, proc_aman): t2p_aman = tod_ops.t2pleakage.get_t2p_coeffs(aman, **self.calc_cfgs) self.save(proc_aman, t2p_aman) - + def save(self, proc_aman, t2p_aman): if self.save_cfgs is None: return @@ -1318,6 +1460,7 @@ class SubtractT2P(_Preprocess): def process(self, aman, proc_aman): tod_ops.t2pleakage.subtract_t2p(aman, proc_aman['t2p'], **self.process_cfgs) + class SplitFlags(_Preprocess): """Get flags used for map splitting/bundling. @@ -1345,7 +1488,7 @@ class SplitFlags(_Preprocess): name = "split_flags" def calc_and_save(self, aman, proc_aman): - split_flg_aman = obs_ops.flags.get_split_flags(aman, proc_aman, split_cfg=self.calc_cfgs) + split_flg_aman = obs_ops.splits.get_split_flags(aman, proc_aman, split_cfg=self.calc_cfgs) self.save(proc_aman, split_flg_aman) @@ -1355,6 +1498,153 @@ def save(self, proc_aman, split_flg_aman): if self.save_cfgs: proc_aman.wrap("split_flags", split_flg_aman) +class UnionFlags(_Preprocess): + """Do the union of relevant flags for mapping + Typically you would include turnarounds, glitches, etc. + + Saves results for aman under the "flags.[total_flags_label]" field. + + Example config block:: + + - name : "union_flags" + process: + flag_labels: ['jumps_2pi.jump_flag', 'glitches.glitch_flags', 'turnaround_flags.turnarounds'] + total_flags_label: 'glitch_flags' + + """ + name = "union_flags" + + def process(self, aman, proc_aman): + from so3g.proj import RangesMatrix + total_flags = RangesMatrix.zeros([proc_aman.dets.count, proc_aman.samps.count]) # get an empty flags with shape (Ndets,Nsamps) + for label in self.process_cfgs['flag_labels']: + _label = attrgetter(label) + total_flags += _label(proc_aman) # The + operator is the union operator in this case + + if 'flags' not in aman._fields: + from sotodlib.core import FlagManager + aman.wrap('flags', FlagManager.for_tod(aman)) + if self.process_cfgs['total_flags_label'] in aman['flags']: + aman['flags'].move(self.process_cfgs['total_flags_label'], None) + aman['flags'].wrap(self.process_cfgs['total_flags_label'], total_flags) + +class RotateQU(_Preprocess): + """Rotate Q and U components to/from telescope coordinates. + + Example config block:: + + - name : "rotate_qu" + process: + sign: 1 + offset: 0 + update_focal_plane: True + + .. autofunction:: sotodlib.coords.demod.rotate_demodQU + """ + name = "rotate_qu" + + def process(self, aman, proc_aman): + from sotodlib.coords import demod + demod.rotate_demodQU(aman, **self.process_cfgs) + +class SubtractQUCommonMode(_Preprocess): + """Subtract Q and U common mode. + + Example config block:: + + - name : 'subtract_qu_common_mode' + signal_name_Q: 'demodQ' + signal_name_U: 'demodU' + process: True + calc: True + save: True + + .. 
autofunction:: sotodlib.tod_ops.deproject.subtract_qu_common_mode + """ + name = "subtract_qu_common_mode" + + def __init__(self, step_cfgs): + self.signal_name_Q = step_cfgs.get('signal_Q', 'demodQ') + self.signal_name_U = step_cfgs.get('signal_U', 'demodU') + super().__init__(step_cfgs) + + def calc_and_save(self, aman, proc_aman): + self.save(proc_aman, aman) + + def save(self, proc_aman, aman): + if self.save_cfgs is None: + return + if self.save_cfgs: + proc_aman.wrap('qu_common_mode_coeffs', aman['qu_common_mode_coeffs']) + + def process(self, aman, proc_aman): + if 'qu_common_mode_coeffs' in proc_aman: + tod_ops.deproject.subtract_qu_common_mode(aman, self.signal_name_Q, self.signal_name_U, + coeff_aman=proc_aman['qu_common_mode_coeffs'], + merge=False) + else: + tod_ops.deproject.subtract_qu_common_mode(aman, self.signal_name_Q, + self.signal_name_U, merge=True) + +class FocalplaneNanFlags(_Preprocess): + """Find additional detectors which have nans + in their focal plane coordinates. + + Saves results in proc_aman under the "fp_flags" field. + + Example config block:: + + - name : "fp_flags" + signal: "signal" # optional + calc: + merge: False + save: True + select: True + + .. autofunction:: sotodlib.tod_ops.flags.get_focalplane_flags + """ + name = "fp_flags" + + def calc_and_save(self, aman, proc_aman): + mskfp = tod_ops.flags.get_focalplane_flags(aman, **self.calc_cfgs) + fp_aman = core.AxisManager(aman.dets, aman.samps) + fp_aman.wrap('fp_nans', mskfp, [(0, 'dets'), (1, 'samps')]) + self.save(proc_aman, fp_aman) + + def save(self, proc_aman, fp_aman): + if self.save_cfgs is None: + return + if self.save_cfgs: + proc_aman.wrap("fp_flags", fp_aman) + + def select(self, meta, proc_aman=None): + if self.select_cfgs is None: + return meta + if proc_aman is None: + proc_aman = meta.preprocess + keep = ~has_all_cut(proc_aman.fp_flags.fp_nans) + meta.restrict("dets", meta.dets.vals[keep]) + return meta + +class PointingModel(_Preprocess): + """Apply pointing model to the TOD. + + Saves results in proc_aman under the "pointing" field. + + Example config block:: + + - name : "pointing_model" + process: True + + .. autofunction:: sotodlib.coords.pointing_model.apply_pointing_model + """ + name = "pointing_model" + + def process(self, aman, proc_aman): + from sotodlib.coords import pointing_model + if self.process_cfgs: + pointing_model.apply_pointing_model(aman) + _Preprocess.register(SplitFlags) _Preprocess.register(SubtractT2P) _Preprocess.register(EstimateT2P) @@ -1384,3 +1674,9 @@ def save(self, proc_aman, split_flg_aman): _Preprocess.register(DarkDets) _Preprocess.register(SourceFlags) _Preprocess.register(HWPAngleModel) +_Preprocess.register(GetStats) +_Preprocess.register(UnionFlags) +_Preprocess.register(RotateQU) +_Preprocess.register(SubtractQUCommonMode) +_Preprocess.register(FocalplaneNanFlags) +_Preprocess.register(PointingModel) diff --git a/sotodlib/site_pipeline/check_book.py b/sotodlib/site_pipeline/check_book.py index d24c55479..5b6783d75 100644 --- a/sotodlib/site_pipeline/check_book.py +++ b/sotodlib/site_pipeline/check_book.py @@ -104,16 +104,19 @@ def get_parser(parser=None): return parser -def main(book_dir, config=None, add=None, overwrite=None): - logger = util.init_logger(__name__, 'check_book: ') +def scan_book_dir(book_dir, logger, config, prep_obsfiledb=False): + """Run the BookScanner on book_dir. - logger.info(f'Examining {book_dir}') + Returns: + ok : bool + True only if the book passed the checks. 
+ obsfiledb_info : dict or None + If prep_obsfiledb, and ok, then this dict has the info needed + for updating obsfiledb (see add_to_obsfiledb). - if config is not None: - logger.debug(f'Loading config from {config}') - config = yaml.safe_load(open(config, 'rb')) - else: - config = {} + (Note this function is used by update-obsdb, as well.) + """ + logger.info(f'Examining {book_dir}') bs = check_book.BookScanner(book_dir, config) bs.go() @@ -121,13 +124,24 @@ def main(book_dir, config=None, add=None, overwrite=None): if len(bs.results['errors']): logger.error('Cannot register this obs due to errors.') - sys.exit(1) - - if not add: - sys.exit(0) - - detset_rows, file_rows = bs.prep_obsfiledb(config.get('root_path', '/')) - + return False, None + + if prep_obsfiledb: + detset_rows, file_rows = bs.prep_obsfiledb(config.get('root_path', '/')) + return True, { + 'detset_rows': detset_rows, + 'file_rows': file_rows, + } + return True, None + +def add_to_obsfiledb(info, logger, config, overwrite=False): + """Add observation info to the obsfiledb specified in config. The + "info" dict is the one returned by scan_book_dir. If overwrite, + then file entries in the obsfiledb, for this obs_id, will be + replaced. + + (Note this function is used by update-obsdb, as well.) + """ # Write to obsfiledb obsfiledb_file = config.get('obsfiledb', 'obsfiledb.sqlite') logger.debug('Updating %s ...' % obsfiledb_file) @@ -137,18 +151,36 @@ def main(book_dir, config=None, add=None, overwrite=None): # Note this only drops the obs ... if detsets need to be # rewritten, you'd better start over entirely. logger.debug(' -- removing any existing references.') - db.drop_obs(file_rows[0]['obs_id']) + db.drop_obs(info['file_rows'][0]['obs_id']) logger.debug( - ' -- adding %i detsets and %i file refs' % (len(detset_rows), - len(file_rows)) + ' -- adding %i detsets and %i file refs' % (len(info['detset_rows']), + len(info['file_rows'])) ) - for name, dets in detset_rows: + for name, dets in info['detset_rows']: if len(db.get_dets(name)) == 0: db.add_detset(name, dets) - for row in file_rows: + for row in info['file_rows']: db.add_obsfile(**row) +def main(book_dir, config=None, add=None, overwrite=None): + logger = util.init_logger(__name__, 'check_book: ') + + if config is not None: + logger.debug(f'Loading config from {config}') + config = yaml.safe_load(open(config, 'rb')) + else: + config = {} + + ok, info = scan_book_dir(book_dir, logger, config, prep_obsfiledb=add) + if not ok: + logger.error('Cannot register this obs due to errors.') + sys.exit(1) + + if add: + add_to_obsfiledb(info, logger, config, overwrite=overwrite) + + if __name__ == '__main__': util.main_launcher(main, get_parser) diff --git a/sotodlib/site_pipeline/cleanup_level2.py b/sotodlib/site_pipeline/cleanup_level2.py index e71f916e7..16c3947e9 100644 --- a/sotodlib/site_pipeline/cleanup_level2.py +++ b/sotodlib/site_pipeline/cleanup_level2.py @@ -1,52 +1,221 @@ +import numpy as np import datetime as dt from typing import Optional import argparse from sotodlib.io.imprinter import Imprinter +from sotodlib.io.datapkg_completion import DataPackaging +from sotodlib.site_pipeline.util import init_logger + +logger = init_logger(__name__, "cleanup_level2: ") + +def level2_completion( + dpk: DataPackaging, + lag: Optional[float] = 14, + min_timecode: Optional[int] = None, + max_timecode: Optional[int] = None, + raise_incomplete: Optional[bool] = True, +): + + ## build time range where we require timecodes to be complete + if min_timecode is None: + min_timecode = 
dpk.get_first_timecode_on_disk() + if max_timecode is None: + x = dt.datetime.now() - dt.timedelta(days=lag) + max_timecode = int( x.timestamp() // 1e5) + + logger.info( + f"Checking Timecode completion from {min_timecode} to " + f"{max_timecode}." + ) + + check_list = [] + for timecode in range(min_timecode, max_timecode): + check = dpk.make_timecode_complete(timecode) + if not check[0]: + check_list.append( (timecode, check[1]) ) + continue + check = dpk.verify_timecode_deletable( + timecode, include_hk=True, + verify_with_librarian=False, + ) + if not check[0]: + check_list.append( (timecode, check[1]) ) + + if len( check_list ) > 0 and raise_incomplete: + raise ValueError( + f"Data Packaging cannot be completed for {check_list}" + ) + +def do_delete_level2( + dpk: DataPackaging, + lag: Optional[float] = 28, + min_timecode: Optional[int] = None, + max_timecode: Optional[int] = None, + raise_incomplete: Optional[bool] =True, +): + ## build time range where we should be deleting + if min_timecode is None: + min_timecode = dpk.get_first_timecode_on_disk() + + if max_timecode is None: + x = dt.datetime.now() - dt.timedelta(days=lag) + max_timecode = int( x.timestamp() // 1e5) + + logger.info( + f"Removing Level 2 data from {min_timecode} to " + f"{max_timecode}." + ) + delete_list = [] + for timecode in range(min_timecode, max_timecode): + check = dpk.check_and_delete_timecode(timecode) + if not check[0]: + logger.error(f"Failed to remove level 2 for {timecode}") + delete_list.append( (timecode, check[1])) + continue + if len( delete_list ) > 0 and raise_incomplete: + raise ValueError( + f"Level 2 Deletion not finished for {delete_list}" + ) + +def do_delete_staged( + dpk: DataPackaging, + lag: Optional[float] = 14, + min_timecode: Optional[int] = None, + max_timecode: Optional[int] = None, + raise_incomplete: Optional[bool] =True, +): + ## build time range where we should be deleting + if min_timecode is None: + min_timecode = dpk.get_first_timecode_in_staged() + + if max_timecode is None: + x = dt.datetime.now() - dt.timedelta(days=lag) + max_timecode = int( x.timestamp() // 1e5) + + logger.info( + f"Removing staged from {min_timecode} to " + f"{max_timecode}." 
+ ) + delete_list = [] + for timecode in range(min_timecode, max_timecode): + check = dpk.make_timecode_complete(timecode) + if not check[0]: + delete_list.append( (timecode, check[1]) ) + continue + check = dpk.verify_timecode_deletable( + timecode, include_hk=True, + verify_with_librarian=False, + ) + if not check[0]: + delete_list.append( (timecode, check[1]) ) + continue + check = dpk.delete_timecode_staged(timecode) + if not check[0]: + logger.error(f"Failed to remove staged for {timecode}") + delete_list.append( (timecode, check[1])) + continue + if len( delete_list ) > 0 and raise_incomplete: + raise ValueError( + f"Staged Deletion not finished for {delete_list}" + ) def main( - config: str, - cleanup_delay: float = 7, - max_ctime: Optional[float] = None, - dry_run: Optional[bool] = False, + platform: str, + check_complete: Optional[bool]= False, + delete_staged: Optional[bool] = False, + delete_lvl2: Optional[bool]= False, + completion_lag: Optional[float] = 14, + min_complete_timecode: Optional[int] = None, + max_complete_timecode: Optional[int] = None, + staged_deletion_lag: Optional[float] = 14, + min_staged_delete_timecode: Optional[int] = None, + max_staged_delete_timecode: Optional[int] = None, + lvl2_deletion_lag: Optional[float] = 28, + min_lvl2_delete_timecode: Optional[int] = None, + max_lvl2_delete_timecode: Optional[int] = None, ): """ Use the imprinter database to clean up already bound level 2 files. Parameters ---------- - config : str - Path to config file for imprinter - cleanup_delay : float, optional - The amount of time to delay book deletion in units of days, by default 1 - max_ctime : Optional[datetime], optional - The maximum datetime to delete level 2 data. Overrides cleanup_delay. + platform : str + platform we're running for + completion_lag : float, optional + The number of days in the past where we expect data packaging to be + fully complete. + min_complete_timecode : Optional[datetime], optional + The lowest timecode to run completion checking. over-rides the "start + from beginning" behavior. + max_complete_timecode : Optional[datetime], optional + The highest timecode to run completion checking. over-rides the + completion_lag calculated value. 
dry_run : Optional[bool], If true, only prints deletion to logger """ + dpk = DataPackaging(platform) - if max_ctime is not None: - max_time = dt.datetime.utcfromtimestamp(max_ctime) - else: - max_time = None + if check_complete: + level2_completion( + dpk, completion_lag, + min_complete_timecode, max_complete_timecode, + ) + + if delete_staged: + do_delete_staged( + dpk, staged_deletion_lag, + min_staged_delete_timecode, max_staged_delete_timecode + ) + + if delete_lvl2: + do_delete_level2( + dpk, lvl2_deletion_lag, + min_lvl2_delete_timecode, max_lvl2_delete_timecode, + ) - imprinter = Imprinter(config, db_args={'connect_args': {'check_same_thread': False}}) - book_list = imprinter.get_level2_deleteable_books(max_time=max_time, cleanup_delay=cleanup_delay) - for book in book_list: - imprinter.delete_level2_files(book, dry_run=dry_run) def get_parser(parser=None): if parser is None: parser = argparse.ArgumentParser() - parser.add_argument('config', type=str, help="Config file for Imprinter") - parser.add_argument('--cleanup-delay', type=float, default=7, - help="Days to keep level 2 data before cleaning") - parser.add_argument('--max-ctime', type=float, - help="Maximum ctime to delete to, overrides cleanup_delay ONLY if its an earlier time") - parser.add_argument('--dry-run', action="store_true", - help="if passed, only prints delete behavior") + + parser.add_argument('platform', type=str, help="Platform for Imprinter") + parser.add_argument('--check-complete', action="store_true", + help="If passed, run completion check") + parser.add_argument('--delete-lvl2', action="store_true", + help="If passed, delete lvl2 raw data") + parser.add_argument('--delete-staged', action="store_true", + help="If passed, delete lvl2 staged data") + + parser.add_argument('--completion-lag', type=float, default=14, + help="Buffer days before we start failing completion") + parser.add_argument('--min-complete-timecode', type=int, + help="Minimum timecode to start completion check. Overrides starting " + "from the beginning") + parser.add_argument('--max-complete-timecode', type=int, + help="Maximum timecode to stop completion check. Overrides the " + "completion-lag setting") + + parser.add_argument('--lvl2-deletion-lag', type=float, default=28, + help="Buffer days before we start deleting level 2 raw data") + parser.add_argument('--min-lvl2-delete-timecode', type=int, + help="Minimum timecode to start level 2 raw data deletion. Overrides " + "starting from the beginning") + parser.add_argument('--max-lvl2-delete-timecode', type=int, + help="Maximum timecode to stop level 2 raw data deletion. Overrides the" + " lvl2-deletion-lag setting") + + parser.add_argument('--staged-deletion-lag', type=float, default=28, + help="Buffer days before we start deleting level 2 staged data") + parser.add_argument('--min-staged-delete-timecode', type=int, + help="Minimum timecode to start level 2 staged data deletion. Overrides" + " starting from the beginning") + parser.add_argument('--max-staged-delete-timecode', type=int, + help="Maximum timecode to stop level 2 staged data deletion. 
Overrides" + " the lvl2-deletion-lag setting") + return parser if __name__ == "__main__": diff --git a/sotodlib/site_pipeline/make_atomic_filterbin_map.py b/sotodlib/site_pipeline/make_atomic_filterbin_map.py index becf89db8..193c6e7ff 100644 --- a/sotodlib/site_pipeline/make_atomic_filterbin_map.py +++ b/sotodlib/site_pipeline/make_atomic_filterbin_map.py @@ -1,4 +1,6 @@ from argparse import ArgumentParser +from typing import Optional +from dataclasses import dataclass import time import warnings import os @@ -9,6 +11,7 @@ import sqlite3 import numpy as np import so3g +import sotodlib.site_pipeline.util as util from sotodlib import coords, mapmaking from sotodlib.core import Context from sotodlib.io import hk_utils @@ -16,115 +19,149 @@ from pixell import enmap, utils as putils, bunch from pixell import wcsutils, colors, memory, mpi from concurrent.futures import ProcessPoolExecutor, as_completed -import sotodlib.site_pipeline.util as util -defaults = { - "area": None, - "nside": None, - "query": "type == 'obs' and subtype == 'cmb'", - "odir": "./output", - "update_delay": None, - "comps": "TQU", - "nproc": 1, - "ntod": None, - "tods": None, - "nset": None, - "wafer": None, - "freq": None, - "center_at": None, - "dec_ref": -40.0, - "site": 'so_sat3', - "max_dets": None, # not implemented yet - "verbose": 0, - "quiet": 0, - "tiled": 0, # not implemented yet - "singlestream": False, - "only_hits": False, - "det_in_out": False, - "det_left_right": False, - "det_upper_lower": False, - "scan_left_right": False, - "window": 0.0, # not implemented yet - "dtype_tod": 'float32', - "dtype_map": 'float64', - "atomic_db": "atomic_maps.db", - "fixed_time": None, - "min_dur": None, - "hk_data_path": None, - } - - -def get_parser(parser=None): - if parser is None: - parser = ArgumentParser() - parser.add_argument("--config_file", type=str, default=None, - help="Path to mapmaker config.yaml file") - parser.add_argument("--context", type=str, help='Path to context file') - parser.add_argument("--preprocess_config", type=str, - help='Path to config file to run the\ - preprocessing pipeline') - parser.add_argument("--area", - help='WCS kernel for rectangular pixels') - parser.add_argument("--nside", - help='Nside for HEALPIX pixels') - parser.add_argument("--query", - help='Query, can be a file (list of obs_id)\ - or selection string (will select only CMB scans\ - by default)') - parser.add_argument("--odir", - help='Output directory') - parser.add_argument('--update_delay', type=int, - help="Number of days (unit is days) in the past\ - to start observation list") - parser.add_argument("--nproc", type=int, - help='Number of procs in\ - the multiprocessing pool') - parser.add_argument("--comps", type=str, - help="Components to map (TQU by default)") - parser.add_argument("--singlestream", action="store_true", - help="Map without demodulation (e.g. 
with\ - a static HWP)") - parser.add_argument("--only_hits", action="store_true", - help='Only create a hits map') - - # detector position splits (fixed in time) - parser.add_argument("--det_in_out", action="store_true") - parser.add_argument("--det_left_right", action="store_true") - parser.add_argument("--det_upper_lower", action="store_true") - - # time samples splits - parser.add_argument("--scan_left_right", action="store_true") - - parser.add_argument("--ntod", type=int, ) - parser.add_argument("--tods", type=str, ) - parser.add_argument("--nset", type=int, ) - parser.add_argument("--wafer", type=str, - help="Detector set to map with") - parser.add_argument("--freq", type=str, - help="Frequency band to map with") - parser.add_argument("--dec_ref", type=float, - help="Decl. at which we will calculate the\ - reference R.A.") - parser.add_argument("--center_at", type=str) - parser.add_argument("--max_dets", type=int, ) - parser.add_argument("--fixed_ftime", type=int, ) - parser.add_argument("--min_dur", type=int, ) - parser.add_argument("--site", type=str, ) - parser.add_argument("--verbose", action="count", ) - parser.add_argument("--quiet", action="count", ) - parser.add_argument("--window", type=float, ) - parser.add_argument("--dtype_tod", type=str) - parser.add_argument("--dtype_map", type=str) - parser.add_argument("--atomic_db", type=str, - help='name of the atomic map database, will be\ - saved where this script is being run') - parser.add_argument("--hk_data_path", - help='Path to housekeeping data') - return parser - - -def _get_config(config_file): - return yaml.safe_load(open(config_file, 'r')) +@dataclass +class Cfg: + """ + Class to configure make-atomic-filterbin-map + + Args + -------- + context: str + Path to context file + preprocess_config: str + Path to config file(s) to run the preprocessing pipeline. + If 2 files, representing 2 layers of preprocessing, they + should be separated by a comma. + area: str + WCS kernel for rectangular pixels + nside: int + Nside for HEALPIX pixels + query: str + Query, can be a file (list of obs_id) or selection string (will select only CMB scans by default) + odir: str + Output directory + update_delay: float + Number of days in the past to start obs list + site: str + hk_data_path: str + Path to housekeeping data + nproc: int + Number of procs for the multiprocessing pool + atomic_db: str + Path to the atomic map database + comps: str + Components to map, only TQU implemented + singlestream: bool + Map without demodulation (e.g. 
with a static HWP) + only_hits: bool + Only create a hits map + all_splits: bool + If True, map all implemented splits + det_in_out: bool + Make focal plane split: inner vs outer detector + det_left_right: bool + Make focal plane split: left vs right detector + det_upper_lower: bool + Make focal plane split: upper vs lower detector + scan_left_right: bool + Make samples split: left-going vs right going scans + ntod: int + Run the first ntod observations in your query + tods: str + Run a specific obs + nset: int + Run the first nset wafers + wafer: str + Run a specific wafer + freq: str + Run a specific frequency band + center_at: str + max_dets: int + fixed_time: int + min_dur: int + verbose: int + quiet: int + window: float + dtype_tod: str + Data type for timestreams + dtype_map: str + Data type for maps + """ + def __init__( + self, + context: str, + preprocess_config: str, + area: Optional[str] = None, + nside: Optional[int] = None, + query: str = "type == 'obs' and subtype == 'cmb'", + odir: str = "./output", + update_delay: Optional[float] = None, + site: str = 'so_sat3', + hk_data_path: Optional[str] = None, + nproc: int = 1, + atomic_db: Optional[str] = None, + comps: str = 'TQU', + singlestream: bool = False, + only_hits: bool = False, + all_splits: bool = False, + det_in_out: bool = False, + det_left_right: bool = False, + det_upper_lower: bool = False, + scan_left_right: bool = False, + ntod: Optional[int] = None, + tods: Optional[str] = None, + nset: Optional[int] = None, + wafer: Optional[str] = None, + freq: Optional[str] = None, + center_at: Optional[str] = None, + max_dets: Optional[int] = None, + fixed_time: Optional[int] = None, + min_dur: Optional[int] = None, + verbose: int = 0, + quiet: int = 0, + window: Optional[float] = None, + dtype_tod: str = 'float32', + dtype_map: str = 'float64' + ) -> None: + self.context = context + self.preprocess_config = preprocess_config + self.area = area + self.nside = nside + self.query = query + self.odir = odir + self.update_delay = update_delay + self.site = site + self.hk_data_path = hk_data_path + self.nproc = nproc + self.atomic_db = atomic_db + self.comps = comps + self.singlestream = singlestream + self.only_hits = only_hits + self.all_splits = all_splits + self.det_in_out = det_in_out + self.det_left_right = det_left_right + self.det_upper_lower = det_upper_lower + self.scan_left_right = scan_left_right + self.ntod = ntod + self.tods = tods + self.nset = nset + self.wafer = wafer + self.freq = freq + self.center_at = center_at + self.max_dets = max_dets + self.fixed_time = fixed_time + self.min_dur = min_dur + self.verbose = verbose + self.quiet = quiet + self.window = window + self.dtype_tod = dtype_tod + self.dtype_map = dtype_map + @classmethod + def from_yaml(cls, path) -> "Cfg": + with open(path, "r") as f: + d = yaml.safe_load(f) + return cls(**d) class DataMissing(Exception): @@ -146,8 +183,7 @@ def get_pwv(obs, data_dir): def read_tods(context, obslist, dtype_tod=np.float32, only_hits=False, site='so_sat3', - l2_data=None, - dec_ref=None): + l2_data=None): context = Context(context) # this function will run on multiprocessing and can be returned in any # random order we will also return the obslist to keep track of the order @@ -179,11 +215,11 @@ def read_tods(context, obslist, class ColoredFormatter(logging.Formatter): - def __init__(self, msg, colors={'DEBUG': colors.reset, - 'INFO': colors.lgreen, - 'WARNING': colors.lbrown, - 'ERROR': colors.lred, - 'CRITICAL': colors.lpurple}): + def __init__(self, msg, 
colors={'DEBUG':colors.reset, + 'INFO':colors.lgreen, + 'WARNING':colors.lbrown, + 'ERROR':colors.lred, + 'CRITICAL':colors.lpurple}): logging.Formatter.__init__(self, msg) self.colors = colors @@ -194,7 +230,6 @@ def format(self, record): col = colors.reset return col + logging.Formatter.format(self, record) + colors.reset - class LogInfoFilter(logging.Filter): def __init__(self, rank=0): self.rank = rank @@ -225,8 +260,9 @@ def future_write_to_log(e, errlog): f.write(f'\n{time.time()}, future.result() error\n{errmsg}\n{tb}\n') f.close() +def main(config_file: str) -> None: + args = Cfg.from_yaml(config_file) -def main(config_file=None, defaults=defaults, **args): # Set up logging. L = logging.getLogger(__name__) L.setLevel(logging.INFO) @@ -237,57 +273,48 @@ def main(config_file=None, defaults=defaults, **args): ch.addFilter(LogInfoFilter()) L.addHandler(ch) - cfg = dict(defaults) - # Update the default dict with values provided from a config.yaml file - if config_file is not None: - cfg_from_file = _get_config(config_file) - cfg.update({k: v for k, v in cfg_from_file.items() if v is not None}) - else: - L.error("No config file provided, assuming default values") - # Merge flags from config file and defaults with any passed through CLI - cfg.update({k: v for k, v in args.items() if v is not None}) - # Certain fields are required. Check if they are all supplied here - required_fields = ['context', 'preprocess_config'] - for req in required_fields: - if req not in cfg.keys(): - raise KeyError("{} is a required argument. Please supply it in a\ - config file or via the command line".format(req)) - args = cfg - warnings.simplefilter('ignore') comm = mpi.FAKE_WORLD # Fake communicator since we won't use MPI - verbose = args['verbose'] - args['quiet'] - if args['area'] is not None: - shape, wcs = enmap.read_map_geometry(args['area']) + verbose = args.verbose - args.quiet + if args.area is not None: + shape, wcs = enmap.read_map_geometry(args.area) wcs = wcsutils.WCS(wcs.to_header()) - elif args['nside'] is not None: + elif args.nside is not None: pass # here I will map in healpix else: L.error('Neither rectangular area or nside specified, exiting.') exit(1) noise_model = mapmaking.NmatWhite() - putils.mkdir(args['odir']) + putils.mkdir(args.odir) recenter = None - if args['center_at']: - recenter = mapmaking.parse_recentering(args['center_at']) - preprocess_config = yaml.safe_load(open(args['preprocess_config'], 'r')) - errlog = os.path.join(os.path.dirname( - preprocess_config['archive']['index']), 'errlog.txt') + if args.center_at: + recenter = mapmaking.parse_recentering(args.center_at) + preprocess_config_str = [s.strip() for s in args.preprocess_config.split(",")] + preprocess_config = [] ; errlog = [] + for preproc_cf in preprocess_config_str: + preproc_local = yaml.safe_load(open(preproc_cf, 'r')) + preprocess_config.append( preproc_local ) + errlog.append( os.path.join(os.path.dirname( + preproc_local['archive']['index']), 'errlog.txt') ) multiprocessing.set_start_method('spawn') - if (args['update_delay'] is not None): - min_ctime = int(time.time()) - args['update_delay']*86400 - args['query'] += f" and timestamp>={min_ctime}" + if (args.update_delay is not None): + min_ctime = int(time.time()) - args.update_delay*86400 + args.query += f" and timestamp>={min_ctime}" + + # Check for map data type + if args.dtype_map == 'float32' or args.dtype_map == 'single': + warnings.warn("You are using single precision for maps, we advice to use double precision") - context = 
Context(args['context']) + context_obj = Context(args.context) # obslists is a dict, obskeys is a list, periods is an array, only rank 0 # will do this and broadcast to others. try: obslists, obskeys, periods, obs_infos = mapmaking.build_obslists( - context, args['query'], nset=args['nset'], wafer=args['wafer'], - freq=args['freq'], ntod=args['ntod'], tods=args['tods'], - fixed_time=args['fixed_time'], mindur=args['min_dur']) + context_obj, args.query, nset=args.nset, wafer=args.wafer, + freq=args.freq, ntod=args.ntod, tods=args.tods, + fixed_time=args.fixed_time, mindur=args.min_dur) except mapmaking.NoTODFound as err: L.exception(err) exit(1) @@ -295,16 +322,18 @@ def main(config_file=None, defaults=defaults, **args): cwd = os.getcwd() split_labels = [] - if args['det_in_out']: + if args.all_splits: + raise ValueError('all_splits not implemented yet') + if args.det_in_out: split_labels.append('det_in') split_labels.append('det_out') - if args['det_left_right']: + if args.det_left_right: split_labels.append('det_left') split_labels.append('det_right') - if args['det_upper_lower']: + if args.det_upper_lower: split_labels.append('det_upper') split_labels.append('det_lower') - if args['scan_left_right']: + if args.scan_left_right: split_labels.append('scan_left') split_labels.append('scan_right') if not split_labels: @@ -312,58 +341,61 @@ def main(config_file=None, defaults=defaults, **args): # We open the data base for checking if we have maps already, # if we do we will not run them again. - if os.path.isfile(args['atomic_db']) and not args['only_hits']: - # open the connector, in reading mode only - conn = sqlite3.connect('./'+args['atomic_db']) - cursor = conn.cursor() - keys_to_remove = [] - # Now we have obslists and splits ready, we look through the database - # to remove the maps we already have from it - for key, value in obslists.items(): - missing_split = False - for split_label in split_labels: - query_ = 'SELECT * from atomic where obs_id="%s" and\ - telescope="%s" and freq_channel="%s" and wafer="%s" and\ - split_label="%s"' % ( - value[0][0], obs_infos[value[0][3]].telescope, key[2], - key[1], split_label) - res = cursor.execute(query_) - matches = res.fetchall() - if len(matches) == 0: - # this means one of the requested splits is missing - # in the data base - missing_split = True - break - if missing_split is False: - # this means we have all the splits we requested for the - # particular obs_id/telescope/freq/wafer - keys_to_remove.append(key) - for key in keys_to_remove: - obskeys.remove(key) - del obslists[key] - conn.close() # I close since I only wanted to read + if isinstance(args.atomic_db, str): + if os.path.isfile(args.atomic_db) and not args.only_hits: + # open the connector, in reading mode only + conn = sqlite3.connect(args.atomic_db) + cursor = conn.cursor() + keys_to_remove = [] + # Now we have obslists and splits ready, we look through the database + # to remove the maps we already have from it + for key, value in obslists.items(): + missing_split = False + for split_label in split_labels: + query_ = 'SELECT * from atomic where obs_id="%s" and\ + telescope="%s" and freq_channel="%s" and wafer="%s" and\ + split_label="%s"' % ( + value[0][0], obs_infos[value[0][3]].telescope, key[2], + key[1], split_label) + res = cursor.execute(query_) + matches = res.fetchall() + if len(matches) == 0: + # this means one of the requested splits is missing + # in the data base + missing_split = True + break + if missing_split is False: + # this means we have all the splits we 
requested for the + # particular obs_id/telescope/freq/wafer + keys_to_remove.append(key) + for key in keys_to_remove: + obskeys.remove(key) + del obslists[key] + conn.close() # I close since I only wanted to read obslists_arr = [item for key, item in obslists.items()] tod_list = [] # this list will receive the outputs from read_tods L.info('Starting with read_tods') - with ProcessPoolExecutor(args['nproc']) as exe: + + with ProcessPoolExecutor(args.nproc) as exe: futures = [exe.submit( - read_tods, args['context'], obslist, dtype_tod=args['dtype_tod'], - only_hits=args['only_hits'], l2_data=args['hk_data_path'], - site=args['site'], dec_ref=args['dec_ref']) + read_tods, args.context, obslist, dtype_tod=args.dtype_tod, + only_hits=args.only_hits, l2_data=args.hk_data_path, + site=args.site) for obslist in obslists_arr] for future in as_completed(futures): try: tod_list.append(future.result()) except Exception as e: - future_write_to_log(e, errlog) + # if read_tods fails for some reason we log into the first preproc DB + future_write_to_log(e, errlog[0]) continue futures.remove(future) # flatten the list of lists L.info('Done with read_tods') my_tods = [bb.my_tods for bb in tod_list] - if args['area'] is not None: + if args.area is not None: subgeoms = [] for obs in my_tods: if recenter is None: @@ -373,6 +405,19 @@ def main(config_file=None, defaults=defaults, **args): subshape = shape subwcs = wcs subgeoms.append((subshape, subwcs)) + + # clean up lingering files from previous incomplete runs + for obs in obslists_arr: + obs_id = obs[0][0] + if len(preprocess_config)==1: + preprocess_util.save_group_and_cleanup(obs_id, preprocess_config[0], + subdir='temp', remove=False) + else: + preprocess_util.save_group_and_cleanup(obs_id, preprocess_config[0], + subdir='temp', remove=False) + preprocess_util.save_group_and_cleanup(obs_id, preprocess_config[1], + subdir='temp_proc', remove=False) + run_list = [] for oi in range(len(my_tods)): # tod_list[oi].obslist[0] is the old obslist @@ -383,8 +428,8 @@ def main(config_file=None, defaults=defaults, **args): t = putils.floor(periods[pid, 0]) t5 = ("%05d" % t)[:5] prefix = "%s/%s/atomic_%010d_%s_%s" % ( - args['odir'], t5, t, detset, band) - if args['area'] is not None: + args.odir, t5, t, detset, band) + if args.area is not None: subshape, subwcs = subgeoms[oi] tag = "%5d/%d" % (oi+1, len(obskeys)) @@ -394,7 +439,7 @@ def main(config_file=None, defaults=defaults, **args): # We will write an individual file, # another script will loop over those files # and write into sqlite data base - if not args['only_hits']: + if not args.only_hits: info = [] for split_label in split_labels: info.append(bunch.Bunch( @@ -412,25 +457,26 @@ def main(config_file=None, defaults=defaults, **args): azimuth=obs_infos[obslist[0][3]].az_center, pwv=float(pwv_atomic))) # inputs that are unique per atomic map go into run_list - if args['area'] is not None: + if args.area is not None: run_list.append([obslist, subshape, subwcs, info, prefix, t]) - elif args['nside'] is not None: + elif args.nside is not None: run_list.append([obslist, None, None, info, prefix, t]) # Done with creating run_list - with ProcessPoolExecutor(args['nproc']) as exe: + + with ProcessPoolExecutor(args.nproc) as exe: futures = [exe.submit( - mapmaking.make_demod_map, args['context'], r[0], + mapmaking.make_demod_map, args.context, r[0], noise_model, r[3], preprocess_config, r[4], - shape=r[1], wcs=r[2], nside=args['nside'], + shape=r[1], wcs=r[2], nside=args.nside, comm=comm, t0=r[5], tag=tag, 
recenter=recenter, - dtype_map=args['dtype_map'], - dtype_tod=args['dtype_tod'], - comps=args['comps'], + dtype_map=args.dtype_map, + dtype_tod=args.dtype_tod, + comps=args.comps, verbose=verbose, split_labels=split_labels, - singlestream=args['singlestream'], - site=args['site']) for r in run_list] + singlestream=args.singlestream, + site=args.site) for r in run_list] for future in as_completed(futures): L.info('New future as_completed result') try: @@ -440,11 +486,22 @@ def main(config_file=None, defaults=defaults, **args): continue futures.remove(future) for ii in range(len(errors)): - preprocess_util.cleanup_mandb(errors[ii], outputs[ii], - preprocess_config, L) + for idx_prepoc in range(len(preprocess_config)): + preprocess_util.cleanup_mandb(errors[ii], outputs[ii][idx_prepoc], + preprocess_config[idx_prepoc], L) L.info("Done") return True +def get_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser: + if parser is None: + p = ArgumentParser() + else: + p = parser + p.add_argument( + "--config_file", type=str, help="yaml file with configuration." + ) + return p + if __name__ == '__main__': util.main_launcher(main, get_parser) diff --git a/sotodlib/site_pipeline/multilayer_preprocess_tod.py b/sotodlib/site_pipeline/multilayer_preprocess_tod.py new file mode 100644 index 000000000..4017d1642 --- /dev/null +++ b/sotodlib/site_pipeline/multilayer_preprocess_tod.py @@ -0,0 +1,367 @@ +import os +import yaml +import time +import logging +import numpy as np +import argparse +import traceback +from typing import Optional +import multiprocessing +from concurrent.futures import ProcessPoolExecutor, as_completed +import h5py +import copy +from sotodlib.coords import demod as demod_mm +from sotodlib.hwp import hwp_angle_model +from sotodlib import core +from sotodlib.preprocess import _Preprocess, Pipeline, processes +import sotodlib.preprocess.preprocess_util as pp_util +import sotodlib.site_pipeline.util as sp_util + +logger = pp_util.init_logger("preprocess") + +def multilayer_preprocess_tod(obs_id, + configs_init, + configs_proc, + verbosity=0, + group_list=None, + overwrite=False, + run_parallel=False): + """Meant to be run as part of a batched script. Given a single + Observation ID, this function uses an existing ManifestDb generated + from a previous runof the processing pipeline, runs the pipeline using + a second config, and outputs a new ManifestDb. Det groups must exist in + the first dB to be included in the pipeline run on the second config. + + Arguments + ---------- + obs_id: string or ResultSet entry + obs_id or obs entry that is passed to context.get_obs + configs_init: string or dictionary + config file or loaded config directory for existing database + configs_proc: string or dictionary + config file or loaded config directory for processing database + to be output + group_list: None or list + list of groups to run if you only want to run a partial update + overwrite: bool + if True, overwrite existing entries in ManifestDb + verbosity: log level + 0 = error, 1 = warn, 2 = info, 3 = debug + run_parallel: Bool + If true preprocess_tod is called in a parallel process which returns + dB info and errors and does no sqlite writing inside the function. 
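    Example (editorial sketch; the obs_id and config paths are placeholders)::

        from sotodlib.site_pipeline.multilayer_preprocess_tod import (
            multilayer_preprocess_tod)

        error, outputs_init, outputs_proc = multilayer_preprocess_tod(
            obs_id='obs_1700000000_satp1_1111111',   # hypothetical obs_id
            configs_init='preprocess_init.yaml',     # existing first-layer config
            configs_proc='preprocess_proc.yaml',     # second-layer config to build
            overwrite=False,
            run_parallel=True)
        # With run_parallel=True nothing is written to sqlite here; the caller
        # (see main() below) passes the returned outputs to cleanup_mandb to
        # update the two ManifestDbs.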
+ """ + + logger = pp_util.init_logger("preprocess", verbosity=verbosity) + + # list to hold error, destination file, and db data + outputs_init = [] + outputs_proc = [] + + if type(configs_init) == str: + configs_init = yaml.safe_load(open(configs_init, "r")) + context_init = core.Context(configs_init["context_file"]) + + if type(configs_proc) == str: + configs_proc = yaml.safe_load(open(configs_proc, "r")) + context_proc = core.Context(configs_proc["context_file"]) + + group_by_init, groups_init, error_init = pp_util.get_groups(obs_id, configs_init, context_init) + group_by_proc, groups_proc, error_proc = pp_util.get_groups(obs_id, configs_proc, context_proc) + + if error_init is not None: + if run_parallel: + return error_init[0], [None, None], [None, None] + else: + return + + if error_proc is not None: + if run_parallel: + return error_proc[0], [None, None], [None, None] + else: + return + + if len(groups_init) > 0 and len(groups_proc) > 0: + if (group_by_init != group_by_proc).any(): + raise ValueError('init and proc groups do not match') + + all_groups_proc = groups_proc.copy() + for g in all_groups_proc: + if g not in groups_init: + groups_proc.remove(g) + continue + if group_list is not None: + if g not in group_list: + groups_proc.remove(g) + continue + if 'wafer.bandpass' in group_by_proc: + if 'NC' in g: + groups_proc.remove(g) + continue + try: + meta = context_proc.get_meta(obs_id, dets = {gb:gg for gb, gg in zip(group_by_proc, g)}) + except Exception as e: + errmsg = f'{type(e)}: {e}' + tb = ''.join(traceback.format_tb(e.__traceback__)) + logger.info(f"ERROR: {obs_id} {g}\n{errmsg}\n{tb}") + groups_proc.remove(g) + continue + + if meta.dets.count == 0: + groups_proc.remove(g) + + if len(groups_proc) == 0: + logger.warning(f"group_list:{group_list} contains no overlap with " + f"groups in observation: {obs_id}:{all_groups_proc}. 
" + f"No analysis to run.") + error = 'no_group_overlap' + if run_parallel: + return error, [None, None], [None, None] + else: + return + + if not(run_parallel): + db_init = pp_util.get_preprocess_db(configs_init, group_by_proc) + db_proc = pp_util.get_preprocess_db(configs_proc, group_by_proc) + + # pipeline for init config + pipe_init = Pipeline(configs_init["process_pipe"], + plot_dir=configs_init["plot_dir"], logger=logger) + # pipeline for processing config + pipe_proc = Pipeline(configs_proc["process_pipe"], + plot_dir=configs_proc["plot_dir"], logger=logger) + + if configs_proc.get("lmsi_config", None) is not None: + make_lmsi = True + else: + make_lmsi = False + + # loop through and reduce each group + n_fail = 0 + for group in groups_proc: + logger.info(f"Beginning run for {obs_id}:{group}") + dets = {gb:gg for gb, gg in zip(group_by_proc, group)} + try: + error, outputs_grp_init, _, aman = pp_util.preproc_or_load_group(obs_id, configs_init, + dets=dets, logger=logger, + context_init=context_init) + if error is None: + outputs_init.append(outputs_grp_init) + + init_fields = aman.preprocess._fields.copy() + + outputs_grp_proc = pp_util.save_group(obs_id, configs_proc, dets, + context_proc, subdir='temp_proc') + + # tags from context proc + tags_proc = np.array(context_proc.obsdb.get(aman.obs_info.obs_id, tags=True)['tags']) + if "tags" in aman._fields: + aman.move("tags", None) + aman.wrap('tags', tags_proc) + + # now run the pipeline on the processed axis manager + logger.info(f"Beginning processing pipeline for {obs_id}:{group}") + proc_aman, success = pipe_proc.run(aman) + proc_aman.wrap('pcfg_ref', pp_util.get_pcfg_check_aman(pipe_init)) + + # remove fields found in aman.preprocess from proc_aman + for fld_init in init_fields: + if fld_init in proc_aman: + proc_aman.move(fld_init, None) + + except Exception as e: + errmsg = f'{type(e)}: {e}' + tb = ''.join(traceback.format_tb(e.__traceback__)) + logger.info(f"ERROR: {obs_id} {group}\n{errmsg}\n{tb}") + n_fail += 1 + continue + if success != 'end': + # If a single group fails we don't log anywhere just mis an entry in the db. + logger.info(f"ERROR: {obs_id} {group}\nFailed at step {success}") + n_fail += 1 + continue + + logger.info(f"Saving data to {outputs_grp_proc['temp_file']}:{outputs_grp_proc['db_data']['dataset']}") + proc_aman.save(outputs_grp_proc['temp_file'], outputs_grp_proc['db_data']['dataset'], overwrite) + + if run_parallel: + outputs_proc.append(outputs_grp_proc) + else: + logger.info(f"Saving to database under {outputs_grp_proc['db_data']}") + if len(db_proc.inspect(outputs_grp_proc['db_data'])) == 0: + h5_path = os.path.relpath(outputs_grp_proc['temp_file'], + start=os.path.dirname(configs_proc['archive']['index'])) + db_proc.add_entry(outputs_grp_proc['db_data'], h5_path) + + if make_lmsi: + from pathlib import Path + import lmsi.core as lmsi + + if os.path.exists(new_plots): + lmsi.core([Path(x.name) for x in Path(new_plots).glob("*.png")], + Path(configs1["lmsi_config"]), + Path(os.path.join(new_plots, 'index.html'))) + + if run_parallel: + if n_fail == len(groups_proc): + # If no groups make it to the end of the processing return error. 
+ logger.info(f'ERROR: all groups failed for {obs_id}') + error = 'all_fail' + return error, [obs_id, 'all groups'], [obs_id, 'all groups'] + else: + logger.info('Returning data to futures') + error = None + return error, outputs_init, outputs_proc + + +def get_parser(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + parser.add_argument('configs_init', help="Preprocessing Configuration File for existing database") + parser.add_argument('configs_proc', help="Preprocessing Configuration File for new database") + parser.add_argument( + '--query', + help="Query to pass to the observation list. Use \\'string\\' to " + "pass in strings within the query.", + type=str + ) + parser.add_argument( + '--obs-id', + help="obs-id of particular observation if we want to run on just one" + ) + parser.add_argument( + '--overwrite', + help="If true, overwrites existing entries in the database", + action='store_true', + ) + parser.add_argument( + '--min-ctime', + help="Minimum timestamp for the beginning of an observation list", + ) + parser.add_argument( + '--max-ctime', + help="Maximum timestamp for the beginning of an observation list", + ) + parser.add_argument( + '--update-delay', + help="Number of days (unit is days) in the past to start observation list.", + type=int + ) + parser.add_argument( + '--tags', + help="Observation tags. Ex: --tags 'jupiter' 'setting'", + nargs='*', + type=str + ) + parser.add_argument( + '--planet-obs', + help="If true, takes all planet tags as logical OR and adjusts related configs", + action='store_true', + ) + parser.add_argument( + '--verbosity', + help="increase output verbosity. 0:Error, 1:Warning, 2:Info(default), 3:Debug", + default=2, + type=int + ) + parser.add_argument( + '--nproc', + help="Number of parallel processes to run on.", + type=int, + default=4 + ) + return parser + +def main(configs_init: str, + configs_proc: str, + query: Optional[str] = None, + obs_id: Optional[str] = None, + overwrite: bool = False, + min_ctime: Optional[int] = None, + max_ctime: Optional[int] = None, + update_delay: Optional[int] = None, + tags: Optional[str] = None, + planet_obs: bool = False, + verbosity: Optional[int] = None, + nproc: Optional[int] = 4): + + logger = pp_util.init_logger("preprocess", verbosity=verbosity) + + configs_init, context_init = pp_util.get_preprocess_context(configs_init) + configs_proc, context_proc = pp_util.get_preprocess_context(configs_proc) + + errlog = os.path.join(os.path.dirname(configs_proc['archive']['index']), + 'errlog.txt') + multiprocessing.set_start_method('spawn') + + obs_list = sp_util.get_obslist(context_proc, query=query, obs_id=obs_id, min_ctime=min_ctime, + max_ctime=max_ctime, update_delay=update_delay, tags=tags, + planet_obs=planet_obs) + if len(obs_list)==0: + logger.warning(f"No observations returned from query: {query}") + + # clean up lingering files from previous incomplete runs + policy_dir_init = os.path.join(os.path.dirname(configs_init['archive']['policy']['filename']), 'temp') + policy_dir_proc = os.path.join(os.path.dirname(configs_proc['archive']['policy']['filename']), 'temp_proc') + for obs in obs_list: + obs_id = obs['obs_id'] + pp_util.cleanup_obs(obs_id, policy_dir_init, errlog, configs_init, context_init, + subdir='temp', remove=overwrite) + pp_util.cleanup_obs(obs_id, policy_dir_proc, errlog, configs_proc, context_proc, + subdir='temp_proc', remove=overwrite) + + run_list = [] + + if overwrite or not os.path.exists(configs_proc['archive']['index']): + # run on all if database doesn't 
exist + for obs in obs_list: + #run on all if database doesn't exist + run_list = [ (o,None) for o in obs_list] + group_by_proc = np.atleast_1d(configs_proc['subobs'].get('use', 'detset')) + else: + db = core.metadata.ManifestDb(configs_proc['archive']['index']) + for obs in obs_list: + x = db.inspect({'obs:obs_id': obs["obs_id"]}) + if x is None or len(x) == 0: + run_list.append( (obs, None) ) + else: + group_by_proc, groups_proc, _ = pp_util.get_groups(obs["obs_id"], configs_proc, context_proc) + if len(x) != len(groups_proc): + [groups_proc.remove([a[f'dets:{gb}'] for gb in group_by_proc]) for a in x] + run_list.append( (obs, groups_proc) ) + + logger.info(f'Run list created with {len(run_list)} obsids') + + # run write_block obs-ids in parallel at once then write all to the sqlite db. + with ProcessPoolExecutor(nproc) as exe: + futures = [exe.submit(multilayer_preprocess_tod, obs_id=r[0]['obs_id'], + group_list=r[1], verbosity=verbosity, + configs_init=configs_init, + configs_proc=configs_proc, + overwrite=overwrite, run_parallel=True) for r in run_list] + for future in as_completed(futures): + logger.info('New future as_completed result') + try: + err, db_datasets_init, db_datasets_proc = future.result() + except Exception as e: + errmsg = f'{type(e)}: {e}' + tb = ''.join(traceback.format_tb(e.__traceback__)) + logger.info(f"ERROR: future.result()\n{errmsg}\n{tb}") + f = open(errlog, 'a') + f.write(f'\n{time.time()}, future.result() error\n{errmsg}\n{tb}\n') + f.close() + continue + futures.remove(future) + + if db_datasets_init: + logger.info(f'Processing future result db_dataset: {db_datasets_init}') + for db_dataset in db_datasets_init: + pp_util.cleanup_mandb(err, db_dataset, configs_init, logger, overwrite) + + if db_datasets_proc: + logger.info(f'Processing future dependent result db_dataset: {db_datasets_proc}') + for db_dataset in db_datasets_proc: + pp_util.cleanup_mandb(err, db_dataset, configs_proc, logger, overwrite) + +if __name__ == '__main__': + sp_util.main_launcher(main, get_parser) diff --git a/sotodlib/site_pipeline/preprocess_tod.py b/sotodlib/site_pipeline/preprocess_tod.py index 0605033f8..1c4d14a7f 100644 --- a/sotodlib/site_pipeline/preprocess_tod.py +++ b/sotodlib/site_pipeline/preprocess_tod.py @@ -29,33 +29,26 @@ def dummy_preproc(obs_id, group_list, logger, error = None outputs = [] context = core.Context(configs["context_file"]) - group_by, groups = pp_util.get_groups(obs_id, configs, context) + group_by, groups, error = pp_util.get_groups(obs_id, configs, context) pipe = Pipeline(configs["process_pipe"], plot_dir=configs["plot_dir"], logger=logger) for group in groups: logger.info(f"Beginning run for {obs_id}:{group}") + dets = {gb:gg for gb, gg in zip(group_by, group)} proc_aman = core.AxisManager(core.LabelAxis('dets', ['det%i' % i for i in range(3)]), core.OffsetAxis('samps', 1000)) proc_aman.wrap_new('signal', ('dets', 'samps'), dtype='float32') proc_aman.wrap_new('timestamps', ('samps',))[:] = (np.arange(proc_aman.samps.count) / 200) - policy = pp_util.ArchivePolicy.from_params(configs['archive']['policy']) - dest_file, dest_dataset = policy.get_dest(obs_id) - for gb, g in zip(group_by, group): - if gb == 'detset': - dest_dataset += "_" + g - else: - dest_dataset += "_" + gb + "_" + str(g) - logger.info(f"Saving data to {dest_file}:{dest_dataset}") - proc_aman.save(dest_file, dest_dataset, overwrite) - - # Collect index info. 
- db_data = {'obs:obs_id': obs_id, - 'dataset': dest_dataset} - for gb, g in zip(group_by, group): - db_data['dets:'+gb] = g + + outputs_grp = pp_util.save_group(obs_id, configs, dets, context, subdir='temp') + logger.info(f"Saving data to {outputs_grp['temp_file']}:{outputs_grp['db_data']['dataset']}") + proc_aman.save(outputs_grp['temp_file'], outputs_grp['db_data']['dataset'], overwrite) + if run_parallel: - outputs.append(db_data) + outputs.append(outputs_grp) + if run_parallel: - return error, dest_file, outputs + return error, outputs + def preprocess_tod(obs_id, configs, @@ -90,7 +83,14 @@ def preprocess_tod(obs_id, configs = yaml.safe_load(open(configs, "r")) context = core.Context(configs["context_file"]) - group_by, groups = pp_util.get_groups(obs_id, configs, context) + group_by, groups, error = pp_util.get_groups(obs_id, configs, context) + + if error is not None: + if run_parallel: + return error[0], [None, None] + else: + return + all_groups = groups.copy() for g in all_groups: if group_list is not None: @@ -119,7 +119,7 @@ def preprocess_tod(obs_id, f"No analysis to run.") error = 'no_group_overlap' if run_parallel: - return error, None, [None, None] + return error, [None, None] else: return @@ -136,8 +136,9 @@ def preprocess_tod(obs_id, n_fail = 0 for group in groups: logger.info(f"Beginning run for {obs_id}:{group}") + dets = {gb:gg for gb, gg in zip(group_by, group)} try: - aman = context.get_obs(obs_id, dets={gb:g for gb, g in zip(group_by, group)}) + aman = context.get_obs(obs_id, dets=dets) tags = np.array(context.obsdb.get(aman.obs_info.obs_id, tags=True)['tags']) aman.wrap('tags', tags) proc_aman, success = pipe.run(aman) @@ -161,29 +162,18 @@ def preprocess_tod(obs_id, n_fail += 1 continue - policy = pp_util.ArchivePolicy.from_params(configs['archive']['policy']) - dest_file, dest_dataset = policy.get_dest(obs_id) - for gb, g in zip(group_by, group): - if gb == 'detset': - dest_dataset += "_" + g - else: - dest_dataset += "_" + gb + "_" + str(g) - logger.info(f"Saving data to {dest_file}:{dest_dataset}") - proc_aman.save(dest_file, dest_dataset, overwrite) - - # Collect index info. - db_data = {'obs:obs_id': obs_id, - 'dataset': dest_dataset} - for gb, g in zip(group_by, group): - db_data['dets:'+gb] = g + outputs_grp = pp_util.save_group(obs_id, configs, dets, context, subdir='temp') + logger.info(f"Saving data to {outputs_grp['temp_file']}:{outputs_grp['db_data']['dataset']}") + proc_aman.save(outputs_grp['temp_file'], outputs_grp['db_data']['dataset'], overwrite) + if run_parallel: - outputs.append(db_data) + outputs.append(outputs_grp) else: - logger.info(f"Saving to database under {db_data}") - if len(db.inspect(db_data)) == 0: - h5_path = os.path.relpath(dest_file, + logger.info(f"Saving to database under {outputs_grp['db_data']}") + if len(db.inspect(outputs_grp['db_data'])) == 0: + h5_path = os.path.relpath(outputs_grp['temp_file'], start=os.path.dirname(configs['archive']['index'])) - db.add_entry(db_data, h5_path) + db.add_entry(outputs_grp['db_data'], h5_path) if make_lmsi: from pathlib import Path @@ -199,17 +189,18 @@ def preprocess_tod(obs_id, # If no groups make it to the end of the processing return error. 
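Both pipelines above now unpack three values from pp_util.get_groups() and turn each group into the dets selection dict handed to context.get_obs(). A tiny sketch of that contract, with invented group values:

group_by = ["wafer_slot", "bandpass"]        # e.g. from configs['subobs']['use']
groups = [["ws0", "f090"], ["ws0", "f150"]]  # one entry per sub-observation
error = None                                 # get_groups() reports failures here

if error is None:
    for group in groups:
        dets = {gb: gg for gb, gg in zip(group_by, group)}
        print(dets)  # {'wafer_slot': 'ws0', 'bandpass': 'f090'}, then f150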
logger.info(f'ERROR: all groups failed for {obs_id}') error = 'all_fail' - return error, None, [obs_id, 'all groups'] + return error, [obs_id, 'all groups'] else: logger.info('Returning data to futures') error = None - return error, dest_file, outputs + return error, outputs + def load_preprocess_tod_sim(obs_id, sim_map, configs="preprocess_configs.yaml", context=None, dets=None, meta=None, modulated=True): - """ Loads the saved information from the preprocessing pipeline and runs the + """Loads the saved information from the preprocessing pipeline and runs the processing section of the pipeline on simulated data Assumes preprocess_tod has already been run on the requested observation. @@ -235,7 +226,8 @@ def load_preprocess_tod_sim(obs_id, sim_map, """ configs, context = pp_util.get_preprocess_context(configs, context) meta = pp_util.load_preprocess_det_select(obs_id, configs=configs, - context=context, dets=dets, meta=meta) + context=context, dets=dets, + meta=meta) if meta.dets.count == 0: logger.info(f"No detectors left after cuts in obs {obs_id}") @@ -253,6 +245,7 @@ def load_preprocess_tod_sim(obs_id, sim_map, pipe.run(aman, aman.preprocess, sim=True) return aman + def get_parser(parser=None): if parser is None: parser = argparse.ArgumentParser() @@ -310,6 +303,7 @@ def get_parser(parser=None): ) return parser + def main( configs: str, query: Optional[str] = None, @@ -333,8 +327,17 @@ def main( obs_list = sp_util.get_obslist(context, query=query, obs_id=obs_id, min_ctime=min_ctime, max_ctime=max_ctime, update_delay=update_delay, tags=tags, planet_obs=planet_obs) + if len(obs_list)==0: logger.warning(f"No observations returned from query: {query}") + + # clean up lingering files from previous incomplete runs + policy_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'temp') + for obs in obs_list: + obs_id = obs['obs_id'] + pp_util.cleanup_obs(obs_id, policy_dir, errlog, configs, context, + subdir='temp', remove=overwrite) + run_list = [] if overwrite or not os.path.exists(configs['archive']['index']): @@ -345,40 +348,26 @@ def main( db = core.metadata.ManifestDb(configs['archive']['index']) for obs in obs_list: x = db.inspect({'obs:obs_id': obs["obs_id"]}) - group_by, groups = pp_util.get_groups(obs["obs_id"], configs, context) if x is None or len(x) == 0: run_list.append( (obs, None) ) - elif len(x) != len(groups): - [groups.remove([a[f'dets:{gb}'] for gb in group_by]) for a in x] - run_list.append( (obs, groups) ) + else: + group_by, groups, _ = pp_util.get_groups(obs["obs_id"], configs, context) + if len(x) != len(groups): + [groups.remove([a[f'dets:{gb}'] for gb in group_by]) for a in x] + run_list.append( (obs, groups) ) logger.info(f'Run list created with {len(run_list)} obsids') - # Expects archive policy filename to be /.h5 and then this adds - # /_.h5 where xxx is a number that increments up from 0 - # whenever the file size exceeds 10 GB. - nfile = 0 - folder = os.path.dirname(configs['archive']['policy']['filename']) - basename = os.path.splitext(configs['archive']['policy']['filename'])[0] - dest_file = basename + '_' + str(nfile).zfill(3) + '.h5' - if not(os.path.exists(folder)): - os.makedirs(folder) - while os.path.exists(dest_file) and os.path.getsize(dest_file) > 10e9: - nfile += 1 - dest_file = basename + '_' + str(nfile).zfill(3) + '.h5' - - logger.info(f'Starting dest_file set to {dest_file}') - # Run write_block obs-ids in parallel at once then write all to the sqlite db. 
with ProcessPoolExecutor(nproc) as exe: futures = [exe.submit(preprocess_tod, obs_id=r[0]['obs_id'], group_list=r[1], verbosity=verbosity, - configs=pp_util.swap_archive(configs, f'temp/{r[0]["obs_id"]}.h5'), + configs=configs, overwrite=overwrite, run_parallel=True) for r in run_list] for future in as_completed(futures): logger.info('New future as_completed result') try: - err, src_file, db_datasets = future.result() + err, db_datasets = future.result() except Exception as e: errmsg = f'{type(e)}: {e}' tb = ''.join(traceback.format_tb(e.__traceback__)) @@ -389,37 +378,10 @@ def main( continue futures.remove(future) - logger.info(f'Processing future result db_dataset: {db_datasets}') - db = pp_util.get_preprocess_db(configs, group_by) - logger.info('Database connected') - if os.path.exists(dest_file) and os.path.getsize(dest_file) >= 10e9: - nfile += 1 - dest_file = basename + '_'+str(nfile).zfill(3)+'.h5' - logger.info('Starting a new h5 file.') - - h5_path = os.path.relpath(dest_file, - start=os.path.dirname(configs['archive']['index'])) - - if err is None: - logger.info(f'Moving files from temp to final destination.') - with h5py.File(dest_file,'a') as f_dest: - with h5py.File(src_file,'r') as f_src: - for dts in f_src.keys(): - f_src.copy(f_src[f'{dts}'], f_dest, f'{dts}') - for member in f_src[dts]: - if isinstance(f_src[f'{dts}/{member}'], h5py.Dataset): - f_src.copy(f_src[f'{dts}/{member}'], f_dest[f'{dts}'], f'{dts}/{member}') - for db_data in db_datasets: - logger.info(f"Saving to database under {db_data}") - if len(db.inspect(db_data)) == 0: - db.add_entry(db_data, h5_path) - logger.info(f'Deleting {src_file}.') - os.remove(src_file) - else: - logger.info(f'Writing {db_datasets[0]} to error log') - f = open(errlog, 'a') - f.write(f'\n{time.time()}, {err}, {db_datasets[0]}\n{db_datasets[1]}\n') - f.close() + if db_datasets: + logger.info(f'Processing future result db_dataset: {db_datasets}') + for db_dataset in db_datasets: + pp_util.cleanup_mandb(err, db_dataset, configs, logger) if __name__ == '__main__': sp_util.main_launcher(main, get_parser) diff --git a/sotodlib/site_pipeline/update_book_plan.py b/sotodlib/site_pipeline/update_book_plan.py index d76ac04cd..6a80e603d 100644 --- a/sotodlib/site_pipeline/update_book_plan.py +++ b/sotodlib/site_pipeline/update_book_plan.py @@ -2,11 +2,21 @@ import datetime as dt import time from typing import Optional +from sqlalchemy import not_ -from sotodlib.io.imprinter import Imprinter from sotodlib.site_pipeline.monitor import Monitor from sotodlib.site_pipeline.util import init_logger +from sotodlib.io.imprinter import ( + Imprinter, + Books, + BOUND, + UNBOUND, + UPLOADED, + FAILED, + DONE, +) + logger = init_logger(__name__, "update_book_plan: ") def main( @@ -134,9 +144,13 @@ def record_book_counts(monitor, imprinter): log_tags = {} script_run = time.time() + session = imprinter.get_session() + def get_count( q ): + return session.query(Books).filter(q).count() + monitor.record( "unbound", - [ len(imprinter.get_unbound_books()) ], + [ get_count(Books.status == UNBOUND) ], [script_run], tags, imprinter.config["monitor"]["measurement"], @@ -145,7 +159,7 @@ def record_book_counts(monitor, imprinter): monitor.record( "bound", - [ len(imprinter.get_bound_books()) ], + [ get_count(Books.status == BOUND) ], [script_run], tags, imprinter.config["monitor"]["measurement"], @@ -154,7 +168,7 @@ def record_book_counts(monitor, imprinter): monitor.record( "uploaded", - [ len(imprinter.get_uploaded_books()) ], + [ get_count(Books.status == UPLOADED) 
], [script_run], tags, imprinter.config["monitor"]["measurement"], @@ -163,7 +177,25 @@ def record_book_counts(monitor, imprinter): monitor.record( "failed", - [ len(imprinter.get_failed_books()) ], + [ get_count(Books.status == FAILED) ], + [script_run], + tags, + imprinter.config["monitor"]["measurement"], + log_tags=log_tags + ) + + monitor.record( + "done", + [ get_count(Books.status == DONE) ], + [script_run], + tags, + imprinter.config["monitor"]["measurement"], + log_tags=log_tags + ) + + monitor.record( + "has_level2", + [ get_count(not_(Books.lvl2_deleted)) ], [script_run], tags, imprinter.config["monitor"]["measurement"], diff --git a/sotodlib/site_pipeline/update_obsdb.py b/sotodlib/site_pipeline/update_obsdb.py index 2d4e16014..b41b728fb 100644 --- a/sotodlib/site_pipeline/update_obsdb.py +++ b/sotodlib/site_pipeline/update_obsdb.py @@ -20,6 +20,9 @@ lat_tube_list_file: path to yaml dict matching tubes and bands tolerate_stray_files: True skip_bad_books: True + known_bad_books: + - oper_1736874485_satp3_0000100 + - obs_9999999999_satp0_1111111 extra_extra_files: - Z_bookbinder_log.txt extra_files: @@ -30,7 +33,7 @@ from sotodlib.core.metadata import ObsDb from sotodlib.core import Context -from sotodlib.site_pipeline.check_book import main as checkbook +from sotodlib.site_pipeline import check_book from sotodlib.io import load_book import os import glob @@ -42,6 +45,7 @@ import logging from sotodlib.site_pipeline import util from typing import Optional +from itertools import product logger = util.init_logger('update_obsdb', 'update-obsdb: ') @@ -86,7 +90,8 @@ def main(config: str, recency: float = None, booktype: Optional[str] = "both", verbosity: Optional[int] = 2, - overwrite: Optional[bool] = False): + overwrite: Optional[bool] = False, + fastwalk: Optional[bool] = False): """ Create or update an obsdb for observation or operations data. @@ -104,6 +109,10 @@ def main(config: str, Output verbosity. 0:Error, 1:Warning, 2:Info(default), 3:Debug overwrite : bool if False, do not re-check existing entries + fastwalk : bool + if True, assume the directories have a structure /base_dir/obs|oper/\d{5}/... + Then replace base_dir with only the directories where \d{5} is greater or + equal to recency. 
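record_book_counts() above replaces the imprinter's per-status getters with direct count queries on the Books table. The following self-contained sketch shows that query pattern against an in-memory SQLite database; this Books model is a simplified stand-in for sotodlib.io.imprinter.Books, and the rows are fabricated.

from sqlalchemy import Boolean, Column, Integer, String, create_engine, not_
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()
BOUND, UNBOUND = "bound", "unbound"

class Books(Base):  # simplified stand-in for the imprinter's Books model
    __tablename__ = "books"
    bid = Column(Integer, primary_key=True)
    status = Column(String)
    lvl2_deleted = Column(Boolean, default=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Books(status=BOUND), Books(status=UNBOUND, lvl2_deleted=True)])
    session.commit()

    def get_count(q):
        return session.query(Books).filter(q).count()

    print(get_count(Books.status == BOUND))     # 1
    print(get_count(not_(Books.lvl2_deleted)))  # 1 (only the bound book keeps level 2)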
""" if verbosity == 0: logger.setLevel(logging.ERROR) @@ -144,6 +153,8 @@ def main(config: str, bookcartobsdb.add_obs_columns(col_list) if "skip_bad_books" not in config_dict: config_dict["skip_bad_books"] = False + if "known_bad_books" not in config_dict: + config_dict["known_bad_books"] = [] #How far back we should look tnow = time.time() @@ -156,6 +167,14 @@ def main(config: str, #Check if there are one or multiple base_dir specified if isinstance(base_dir,str): base_dir = [base_dir] + if fastwalk: + abv_tback = int(f"{int(tback):05}"[:5]) #Make sure we have at least five chars + abv_tnow = int(f"{int(tnow):05}"[:5]) + abv_codes = np.arange(abv_tback, abv_tnow+1) + #Build the combinations base_dir/booktype/\d{5} + base_dir = [f"{os.path.join(x[0], x[1], str(x[2]))}" for x in product(base_dir, accept_type, abv_codes)] + logger.info(f"Looking in the following directories only: {str(base_dir)}") + for bd in base_dir: #Find folders that are book-like and recent for dirpath, _, _ in os.walk(bd): @@ -163,6 +182,9 @@ def main(config: str, _, book_id = os.path.split(dirpath) if book_id in existing and not overwrite: continue + if book_id in config_dict["known_bad_books"]: + logger.debug(f"{book_id} known to be bad, skipping it") + continue found_timestamp = re.search(r"\d{10}", book_id)#Find the rough timestamp if found_timestamp and int(found_timestamp.group())>tback: #Looks like a book folder and young enough @@ -173,10 +195,16 @@ def main(config: str, for bookpath in sorted(bookcart): if check_meta_type(bookpath) in accept_type: t1 = time.time() + logger.info(f"Examining book at {bookpath}") try: #obsfiledb creation - checkbook(bookpath, config, add=True, overwrite=True) - logger.info(f"Ran check_book for {bookpath} in {time.time()-t1} s") + ok, obsfiledb_info = check_book.scan_book_dir( + bookpath, logger, config_dict, prep_obsfiledb=True) + if not ok: + raise RuntimeError("check_book found fatal errors, not adding.") + check_book.add_to_obsfiledb( + obsfiledb_info, logger, config_dict, overwrite=True) + logger.info(f"Ran check_book in {time.time()-t1} s") except Exception as e: if config_dict["skip_bad_books"]: logger.warning(f"failed to add {bookpath}") @@ -295,7 +323,7 @@ def main(config: str, tags = [t.strip() for t in tags if t.strip() != ''] bookcartobsdb.update_obs(obs_id, very_clean, tags=tags) - logger.info(f"Added {obs_id} in {time.time()-t1} s") + logger.info(f"Finished {obs_id} in {time.time()-t1} s") else: bookcart.remove(bookpath) @@ -305,7 +333,7 @@ def get_parser(parser=None): parser = argparse.ArgumentParser() parser.add_argument("--config", help="ObsDb, ObsfileDb configuration file", type=str, required=True) - parser.add_argument('--recency', default=None, type=float, + parser.add_argument("--recency", default=None, type=float, help="Days to subtract from now to set as minimum ctime. If None, no minimum") parser.add_argument("--verbosity", default=2, type=int, help="Increase output verbosity. 
0:Error, 1:Warning, 2:Info(default), 3:Debug") @@ -313,6 +341,8 @@ def get_parser(parser=None): help="Select book type to look for: obs, oper, both(default)") parser.add_argument("--overwrite", action="store_true", help="If true, writes over existing entries") + parser.add_argument("--fastwalk", action="store_true", + help="Assume known directory tree shape and speed up walkthrough") return parser diff --git a/sotodlib/site_pipeline/util.py b/sotodlib/site_pipeline/util.py index f0566d4fe..e22eeb34f 100644 --- a/sotodlib/site_pipeline/util.py +++ b/sotodlib/site_pipeline/util.py @@ -8,6 +8,7 @@ import argparse import yaml import numpy as np +import copy from astropy import units as u diff --git a/sotodlib/toast/ops/corotator.py b/sotodlib/toast/ops/corotator.py index 86b3dc16b..b94c8e64b 100644 --- a/sotodlib/toast/ops/corotator.py +++ b/sotodlib/toast/ops/corotator.py @@ -116,7 +116,7 @@ def _exec(self, data, detectors=None, **kwargs): msg = f"LAT Co-rotation: obs {obs.name} at scan El = " msg += f"{scan_el_deg:0.2f} degrees, rotating by " msg += f"{np.mean(corot_deg):0.2f} average degrees" - log.info(msg) + log.debug(msg) obs.shared[self.corotator_angle].set( corot, offset=(0,), @@ -126,7 +126,7 @@ def _exec(self, data, detectors=None, **kwargs): # We are not co-rotating. Set the angle to zero. if obs.comm_col_rank == 0: msg = f"LAT Co-rotation: obs {obs.name} disabled" - log.info(msg) + log.debug(msg) corot = np.zeros(obs.n_local_samples) obs.shared[self.corotator_angle].set( corot, diff --git a/sotodlib/toast/ops/mlmapmaker.py b/sotodlib/toast/ops/mlmapmaker.py index 40d49ad30..cd23a4197 100644 --- a/sotodlib/toast/ops/mlmapmaker.py +++ b/sotodlib/toast/ops/mlmapmaker.py @@ -9,10 +9,10 @@ import numpy as np import traitlets from astropy import units as u -from pixell import enmap, tilemap, fft +from pixell import enmap, tilemap, fft, utils, bunch import toast -from toast.traits import trait_docs, Unicode, Int, Instance, Bool, Float +from toast.traits import trait_docs, Unicode, Int, Instance, Bool, Float, List from toast.ops import Operator from toast.utils import Logger, Environment, rate_from_times from toast.timing import function_timer, Timer @@ -76,7 +76,17 @@ class MLMapmaker(Operator): help="The format is [from=](ra:dec|name),[to=(ra:dec|name)],[up=(ra:dec|name|system)]", ) - comps = Unicode("T", help="Components (must be 'T', 'QU' or 'TQU')") + interpol = Unicode( + "nearest", + help="Either one or a comma-separated list of interpolation modes", + ) + + downsample = List( + [1], + help="Downsample TOD by these factors.", + ) + + comps = Unicode("TQU", help="Components (must be 'T', 'QU' or 'TQU')") Nmat = Instance(klass=mm.Nmat, allow_none=True, help="The noise matrix to use") @@ -142,8 +152,20 @@ class MLMapmaker(Operator): help="If True, clear all observation detector data after accumulating", ) - tiled = Bool( + checkpoint_interval = Int( + 0, + help="If greater than zero, the CG solver will store its state and" + "restart from a checkpoint when available.", + ) + + skip_existing = Bool( False, + help="If True, the mapmaker will not write any map products that " + "already exist on disk. See `checkpoint`." + ) + + tiled = Bool( + True, help="If True, the map will be represented as distributed tiles in memory. " "For large maps this is faster and more memory efficient, but for small " "maps it has some overhead due to extra communication." 
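The interpol and downsample traits above, together with the per-pass maxiter list introduced below, are combined by MLMapmaker.setup_passes later in this diff: the number of passes is the longest list, and shorter lists are padded by repeating their last element. A standalone illustration of that padding, with made-up values:

downsample = [4, 2, 1]
maxiter = [100, 300]                      # shorter list: last value is reused
interpol = "nearest,linear".split(",")    # from the comma-separated trait

npass = max(len(downsample), len(maxiter), len(interpol))
passes = []
for i in range(npass):
    passes.append({
        "downsample": downsample[min(i, len(downsample) - 1)],
        "maxiter": maxiter[min(i, len(maxiter) - 1)],
        "interpol": interpol[min(i, len(interpol) - 1)],
    })
print(passes)
# [{'downsample': 4, 'maxiter': 100, 'interpol': 'nearest'},
#  {'downsample': 2, 'maxiter': 300, 'interpol': 'linear'},
#  {'downsample': 1, 'maxiter': 300, 'interpol': 'linear'}]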
@@ -163,12 +185,15 @@ class MLMapmaker(Operator): weather = Unicode("vacuum", help="Weather to assume when making maps") site = Unicode("so", help="Site to use when making maps") - maxiter = Int(500, help="Maximum number of CG iterations") + maxiter = List( + [500], + help="List of maximum number of CG iterations for each pass.", + ) maxerr = Float(1e-6, help="Maximum error in the CG solver") truncate_tod = Bool( - False, + True, help="Truncate TOD to an easily factorizable length to ensure efficient FFT.", ) @@ -192,6 +217,13 @@ def _check_mode(self, proposal): raise traitlets.TraitError("Invalid comps (must be 'T', 'QU' or 'TQU')") return check + @traitlets.validate("checkpoint_interval") + def _check_checkpoint_interval(self, proposal): + check = proposal["value"] + if check < 0: + raise traitlets.TraitError("Invalid checkpoint_interval. Must be non-negative.") + return check + @traitlets.validate("shared_flag_mask") def _check_shared_flag_mask(self, proposal): check = proposal["value"] @@ -231,13 +263,6 @@ def _check_params(self, proposal): check = 1 return check - @traitlets.validate("maxiter") - def _check_maxiter(self, proposal): - check = proposal["value"] - if check <= 0: - raise traitlets.TraitError("Maxiter should be greater than zero") - return check - @traitlets.validate("maxerr") def _check_maxerr(self, proposal): check = proposal["value"] @@ -262,276 +287,288 @@ def _check_nmat_type(self, proposal): return check def __init__(self, **kwargs): + self.shape = None + self.wcs = None + self.recenter = None + self.signal_map = None + self.mapmaker = None super().__init__(**kwargs) - self._mapmaker = None @function_timer - def _exec(self, data, detectors=None, **kwargs): - log = Logger.get() + def setup_passes(self): + tmp = bunch.Bunch() + tmp.downsample = self.downsample + tmp.maxiter = self.maxiter + tmp.interpol = self.interpol.split(",") + # The entries may have different lengths. We use the max + # and then pad the others by repeating the last element. + # The final output will be a list of bunches + npass = max([len(tmp[key]) for key in tmp]) + passes = [] + for i in range(npass): + entry = bunch.Bunch() + for key in tmp: + entry[key] = tmp[key][min(i, len(tmp[key]) - 1)] + passes.append(entry) + return passes - for trait in ["area"]: - value = getattr(self, trait) - if value is None: - raise RuntimeError( - f"You must set `{trait}` before running MLMapmaker" - ) - - # nmat_type is guaranteed to be a valid Nmat class - self.Nmat = getattr(mm, self.nmat_type)() - - comm = data.comm.comm_world - gcomm = data.comm.comm_group - - if self._mapmaker is None: - # First call- create the mapmaker instance. 
- # Get the timestream dtype from the first observation - self._dtype_tod = data.obs[0].detdata[self.det_data].dtype - - self._shape, self._wcs = enmap.read_map_geometry(self.area) - - self._recenter = None - if self.center_at is not None: - self._recenter = mm.parse_recentering(self.center_at) - - dtype_tod = np.float32 - signal_cut = mm.SignalCut(comm, dtype=dtype_tod) - - signal_map = mm.SignalMap(self._shape, self._wcs, comm, comps=self.comps, - dtype=np.dtype(self.dtype_map), recenter=self._recenter, tiled=self.tiled) - signals = [signal_cut, signal_map] - self._mapmaker = mm.MLMapmaker(signals, noise_model=self.Nmat, dtype=dtype_tod, verbose=self.verbose) - # Store this to be able to output rhs and div later - self._signal_map = signal_map - - for ob in data.obs: - # Get the detectors we are using locally for this observation - dets = ob.select_local_detectors(detectors) - if len(dets) == 0: - # Nothing to do for this observation - continue - - # Get the sample rate from the data. We also have nominal sample rates - # from the noise model and also from the focalplane. - (rate, dt, dt_min, dt_max, dt_std) = rate_from_times( - ob.shared[self.times].data + @function_timer + def _load_noise_model(self, ob, npass, ipass, gcomm): + # Maybe load precomputed noise model + log = Logger.get() + if self.nmat_dir is None: + nmat_dir = os.path.join(self.out_dir, "nmats") + else: + nmat_dir = self.nmat_dir + if npass != 1: + nmat_dir += f"_pass{ipass + 1}" + nmat_file = os.path.join(nmat_dir, f"nmat_{ob.name}.hdf") + there = os.path.isfile(nmat_file) + if self.nmat_mode == "load" and not there: + raise RuntimeError( + f"Nmat mode is 'load' but {nmat_file} does not exist." ) - - # Get the focalplane for this observation - fp = ob.telescope.focalplane - - # Prepare data for the mapmaker. - - axdets = LabelAxis("dets", dets) - - nsample = int(ob.n_local_samples) - ind = slice(None) - ncut = nsample - fft.fft_len(nsample) - if ncut != 0: - if self.truncate_tod: + if self.nmat_mode == "load" or (self.nmat_mode == "cache" and there): + log.debug_rank(f"Loading noise model from '{nmat_file}'", comm=gcomm) + try: + nmat = mm.read_nmat(nmat_file) + except Exception as e: + if self.nmat_mode == "cache": log.info_rank( - f"Truncating {ncut} / {nsample} samples ({100 * ncut / nsample:.3f}%) from " - f"{ob.name} for better FFT performance.", + f"Failed to load noise model from '{nmat_file}'" + f" : '{e}'. Will cache a new one", comm=gcomm, ) - nsample -= ncut - ind = slice(nsample) + nmat = None else: - log.warning_rank( - f"{ob.name} length contains large prime factors. " - F"FFT performance may be degrared. 
Recommend " - f"truncating {ncut} / {nsample} samples ({100 * ncut / nsample:.3f}%).", - comm=gcomm, - ) - - axsamps = OffsetAxis( - "samps", - count=nsample, - offset=ob.local_index_offset, - origin_tag=ob.name, - ) - - # Convert the data view into a RangesMatrix - ranges = so3g.proj.ranges.RangesMatrix.zeros((len(dets), nsample)) - if self.view is not None: - view_ranges = np.array( - [[x.first, min(x.last, nsample) + 1] for x in ob.intervals[self.view]] - ) - ranges += so3g.proj.ranges.Ranges.from_array(view_ranges, nsample) - - # Convert the focalplane offsets into the expected form - det_to_row = {y["name"]: x for x, y in enumerate(fp.detector_data)} - det_quat = np.array([fp.detector_data["quat"][det_to_row[x]] for x in dets]) - xi, eta, gamma = quat_to_xieta(det_quat) - - axfp = AxisManager() - axfp.wrap("xi", xi, axis_map=[(0, axdets)]) - axfp.wrap("eta", eta, axis_map=[(0, axdets)]) - axfp.wrap("gamma", gamma, axis_map=[(0, axdets)]) - - # Convert Az/El quaternion of the detector back into - # angles from the simulation. - theta, phi, pa = toast.qarray.to_iso_angles(ob.shared[self.boresight][ind]) - - # Azimuth is measured in the opposite direction from longitude - az = 2 * np.pi - phi - el = np.pi / 2 - theta - roll = pa - - axbore = AxisManager() - axbore.wrap("az", az, axis_map=[(0, axsamps)]) - axbore.wrap("el", el, axis_map=[(0, axsamps)]) - axbore.wrap("roll", roll, axis_map=[(0, axsamps)]) - - axobs = AxisManager() - axobs.wrap("focal_plane", axfp) - axobs.wrap("timestamps", ob.shared[self.times][ind], axis_map=[(0, axsamps)]) - axobs.wrap( - "signal", - ob.detdata[self.det_data][dets, ind], - axis_map=[(0, axdets), (1, axsamps)], - ) - axobs.wrap("boresight", axbore) - axobs.wrap('flags', FlagManager.for_tod(axobs)) - axobs.flags.wrap("glitch_flags", ranges, axis_map=[(0, axdets), (1, axsamps)]) - axobs.wrap("weather", np.full(1, self.weather)) - axobs.wrap("site", np.full(1, "so")) - - # NOTE: Expected contents look like: - # >>> tod - # AxisManager(signal[dets,samps], timestamps[samps], readout_filter_cal[dets], - # mce_filter_params[6], iir_params[3,5], flags*[samps], boresight*[samps], - # array_data*[dets], pointofs*[dets], focal_plane*[dets], abscal[dets], - # timeconst[dets], glitch_flags[dets,samps], source_flags[dets,samps], - # relcal[dets], dets:LabelAxis(63), samps:OffsetAxis(372680)) - # >>> tod.focal_plane - # AxisManager(xi[dets], eta[dets], gamma[dets], dets:LabelAxis(63)) - # >>> tod.boresight - # AxisManager(az[samps], el[samps], roll[samps], samps:OffsetAxis(372680)) - - # Maybe load precomputed noise model - if self.nmat_dir is None: - nmat_dir = os.path.join(self.out_dir, "nmats") - else: - nmat_dir = self.nmat_dir - nmat_file = nmat_dir + "/nmat_%s.hdf" % ob.name - there = os.path.isfile(nmat_file) - if self.nmat_mode == "load" and not there: - raise RuntimeError( - f"Nmat mode is 'load' but {nmat_file} does not exist." - ) - if self.nmat_mode == "load" or (self.nmat_mode == "cache" and there): - log.info_rank(f"Loading noise model from '{nmat_file}'", comm=gcomm) - try: - nmat = mm.read_nmat(nmat_file) - except Exception as e: - if self.nmat_mode == "cache": - log.info_rank( - f"Failed to load noise model from '{nmat_file}'" - f" : '{e}'. 
Will cache a new one", - comm=gcomm, - ) - nmat = None - else: - msg = f"Failed to load noise model from '{nmat_file}' : {e}" - raise RuntimeError(msg) - else: - nmat = None + msg = f"Failed to load noise model from '{nmat_file}' : {e}" + raise RuntimeError(msg) + else: + nmat = None + return nmat, nmat_file - self._mapmaker.add_obs( - ob.name, axobs, deslope=self.deslope, noise_model=nmat - ) - del axobs - - # Maybe save the noise model we built (only if we actually built one rather than - # reading one in) - if self.nmat_mode in ["save", "cache"] and nmat is None: - log.info_rank(f"Writing noise model to '{nmat_file}'", comm=gcomm) - os.makedirs(nmat_dir, exist_ok=True) - mm.write_nmat(nmat_file, self._mapmaker.data[-1].nmat) + @function_timer + def _save_noise_model(self, mapmaker, nmat, nmat_file, gcomm): + # Maybe save the noise model we built (only if we actually built one rather than + # reading one in) + log = Logger.get() + if self.nmat_mode in ["save", "cache"] and nmat is None: + log.debug_rank(f"Writing noise model to '{nmat_file}'", comm=gcomm) + nmat_dir = os.path.dirname(nmat_file) + os.makedirs(nmat_dir, exist_ok=True) + mm.write_nmat(nmat_file, mapmaker.data[-1].nmat) + return - # Optionally delete the input detector data to save memory, if - # the calling code knows that no additional operators will be - # used afterwards. - if self.purge_det_data: - del ob.detdata[self.det_data] + @function_timer + def _wrap_obs(self, ob, dets, passinfo): + """ Prepare data for the mapmaker """ + + # Get the focalplane for this observation + fp = ob.telescope.focalplane + + # Get the sample rate from the data. We also have nominal sample rates + # from the noise model and also from the focalplane. + # (rate, dt, dt_min, dt_max, dt_std) = rate_from_times( + # ob.shared[self.times].data + # ) + + axdets = LabelAxis("dets", dets) + nsample = int(ob.n_local_samples) + + axsamps = OffsetAxis( + "samps", + count=nsample, + offset=ob.local_index_offset, + origin_tag=ob.name, + ) - return + # Convert the data view into a RangesMatrix + ranges = so3g.proj.ranges.RangesMatrix.zeros((len(dets), nsample)) + if self.view is not None: + view_ranges = np.array( + [[x.first, min(x.last, nsample) + 1] for x in ob.intervals[self.view]] + ) + ranges += so3g.proj.ranges.Ranges.from_array(view_ranges, nsample) + + # Convert the focalplane offsets into the expected form + det_to_row = {y["name"]: x for x, y in enumerate(fp.detector_data)} + det_quat = np.array([fp.detector_data["quat"][det_to_row[x]] for x in dets]) + xi, eta, gamma = quat_to_xieta(det_quat) + + axfp = AxisManager() + axfp.wrap("xi", xi, axis_map=[(0, axdets)]) + axfp.wrap("eta", eta, axis_map=[(0, axdets)]) + axfp.wrap("gamma", gamma, axis_map=[(0, axdets)]) + + # Convert Az/El quaternion of the detector back into + # angles from the simulation. 
+ theta, phi, pa = toast.qarray.to_iso_angles(ob.shared[self.boresight]) + + # Azimuth is measured in the opposite direction from longitude + az = 2 * np.pi - phi + el = np.pi / 2 - theta + roll = pa + + axbore = AxisManager() + axbore.wrap("az", az, axis_map=[(0, axsamps)]) + axbore.wrap("el", el, axis_map=[(0, axsamps)]) + axbore.wrap("roll", roll, axis_map=[(0, axsamps)]) + + axobs = AxisManager() + axobs.wrap("focal_plane", axfp) + axobs.wrap("timestamps", ob.shared[self.times], axis_map=[(0, axsamps)]) + axobs.wrap( + "signal", + ob.detdata[self.det_data][dets, :], + axis_map=[(0, axdets), (1, axsamps)], + ) + axobs.wrap("boresight", axbore) + axobs.wrap('flags', FlagManager.for_tod(axobs)) + axobs.flags.wrap("glitch_flags", ranges, axis_map=[(0, axdets), (1, axsamps)]) + axobs.wrap("weather", np.full(1, self.weather)) + axobs.wrap("site", np.full(1, "so")) + + if self.truncate_tod: + # FFT-truncate for faster fft ops + axobs.restrict("samps", [0, fft.fft_len(axobs.samps.count)]) + + # MLMapmaker.add_obs will apply deslope + # if self.deslope: + # utils.deslope(axobs.signal, w=5, inplace=True) + + if self.downsample != 1: + axobs = mm.downsample_obs(axobs, passinfo.downsample) + + # NOTE: Expected contents look like: + # >>> tod + # AxisManager(signal[dets,samps], timestamps[samps], readout_filter_cal[dets], + # mce_filter_params[6], iir_params[3,5], flags*[samps], boresight*[samps], + # array_data*[dets], pointofs*[dets], focal_plane*[dets], abscal[dets], + # timeconst[dets], glitch_flags[dets,samps], source_flags[dets,samps], + # relcal[dets], dets:LabelAxis(63), samps:OffsetAxis(372680)) + # >>> tod.focal_plane + # AxisManager(xi[dets], eta[dets], gamma[dets], dets:LabelAxis(63)) + # >>> tod.boresight + # AxisManager(az[samps], el[samps], roll[samps], samps:OffsetAxis(372680)) + + return axobs @function_timer - def _finalize(self, data, **kwargs): - # After multiple calls to exec, the finalize step will solve for the map. + def _init_mapmaker( + self, mapmaker, signal_map, mapmaker_prev, x_prev, comm, gcomm, prefix, + ): log = Logger.get() timer = Timer() - comm = data.comm.comm_world - gcomm = data.comm.comm_group timer.start() - self._mapmaker.prepare() + mapmaker.prepare() if self.tiled: - geo_work = self._mapmaker.signals[1].geo_work + # Each group reports how many tiles they are using + geo_work = mapmaker.signals[1].geo_work nactive = len(geo_work.active) ntile = np.prod(geo_work.shape[-2:]) - log.info_rank(f"{nactive} / {ntile} tiles active", comm=gcomm) + log.debug_rank(f"{nactive} / {ntile} tiles active", comm=gcomm) + if comm is not None: + comm.barrier() log.info_rank( f"MLMapmaker finished prepare in", comm=comm, timer=timer, ) - prefix = os.path.join(self.out_dir, f"{self.name}_") - # This will need to be modified for more general cases where we don't solve for # a sky map, or where we solve for multiple sky maps. 
The mapmaker itself supports it, # the problem is the direct access to the rhs, div and idiv members if self.write_rhs: - fname = self._signal_map.write(prefix, "rhs", self._signal_map.rhs) - log.info_rank(f"Wrote rhs to {fname}", comm=comm) + fname = f"{prefix}sky_rhs.fits" + if self.skip_existing and os.path.isfile(fname): + log.info_rank(f"Skipping existing rhs in {fname}", comm=comm) + else: + fname = signal_map.write(prefix, "rhs", signal_map.rhs) + log.info_rank(f"Wrote rhs to {fname}", comm=comm) if self.write_div: - #self._signal_map.write(prefix, "div", self._signal_map.div) - # FIXME : only writing the TT variance to avoid integer overflow in communication - fname = self._signal_map.write(prefix, "div", self._signal_map.div[0, 0]) - log.info_rank(f"Wrote div to {fname}", comm=comm) + fname = f"{prefix}sky_div.fits" + if self.skip_existing and os.path.isfile(fname): + log.info_rank(f"Skipping existing div in {fname}", comm=comm) + else: + # FIXME : only writing the TT variance to avoid integer overflow in communication + fname = signal_map.write(prefix, "div", signal_map.div) + # fname = signal_map.write(prefix, "div", signal_map.div[0, 0]) + log.info_rank(f"Wrote div to {fname}", comm=comm) if self.write_hits: - fname = self._signal_map.write(prefix, "hits", self._signal_map.hits) - log.info_rank(f"Wrote hits to {fname}", comm=comm) + fname = f"{prefix}sky_hits.fits" + if self.skip_existing and os.path.isfile(fname): + log.info_rank(f"Skipping existing div in {fname}", comm=comm) + else: + fname = signal_map.write(prefix, "hits", signal_map.hits) + log.info_rank(f"Wrote hits to {fname}", comm=comm) mmul = tilemap.map_mul if self.tiled else enmap.map_mul if self.write_bin: - fname = self._signal_map.write( - prefix, "bin", mmul(self._signal_map.idiv, self._signal_map.rhs) - ) - log.info_rank(f"Wrote bin to {fname}", comm=comm) + fname = f"{prefix}sky_bin.fits" + if self.skip_existing and os.path.isfile(fname): + log.info_rank(f"Skipping existing bin in {fname}", comm=comm) + else: + fname = signal_map.write( + prefix, "bin", mmul(signal_map.idiv, signal_map.rhs) + ) + log.info_rank(f"Wrote bin to {fname}", comm=comm) if comm is not None: comm.barrier() log.info_rank(f"MLMapmaker finished writing rhs, div, bin in", comm=comm, timer=timer) + # Set up initial condition + + if x_prev is None: + x0 = None + else: + x0 = mapmaker.translate(mapmaker_prev, x_prev) + + return x0 + + @function_timer + def _apply_mapmaker(self, mapmaker, x0, passinfo, prefix, comm): + log = Logger.get() + timer = Timer() + timer.start() tstep = Timer() tstep.start() - for step in self._mapmaker.solve(maxiter=self.maxiter, maxerr=self.maxerr): + if self.checkpoint_interval > 0: + fname_checkpoint = f"{prefix}checkpoint.{comm.rank:04}.hdf" + there = os.path.isfile(fname_checkpoint) + if there: + log.info_rank(f"Checkpoint detected. 
Will start from previous solver state", comm=comm) + else: + fname_checkpoint = None + + for step in mapmaker.solve( + maxiter=passinfo.maxiter, + maxerr=self.maxerr, + x0=x0, + fname_checkpoint=fname_checkpoint, + checkpoint_interval=self.checkpoint_interval, + ): if self.write_iter_map < 1: dump = False else: dump = step.i % self.write_iter_map == 0 - dstr = "" - if dump: - dstr = "(write)" - msg = f"CG step {step.i:4d} {step.err:15.7e} {dstr}" + msg = f"CG step {step.i:4d} {step.err:15.7e} write={dump}" log.info_rank(f"MLMapmaker {msg} ", comm=comm, timer=tstep) if dump: - for signal, val in zip(self._mapmaker.signals, step.x): + for signal, val in zip(mapmaker.signals, step.x): if signal.output: - fname = signal.write(prefix, "map%04d" % step.i, val) - log.info_rank(f"Wrote signal to {fname}", comm=comm) + fname = signal.write(prefix, f"map{step.i:04}", val) + log.info_rank(f"Wrote signal to {fname} in", comm=comm, timer=tstep) log.info_rank(f"MLMapmaker finished solve in", comm=comm, timer=timer) - for signal, val in zip(self._mapmaker.signals, step.x): + for signal, val in zip(mapmaker.signals, step.x): if signal.output: fname = signal.write(prefix, "map", val) log.info_rank(f"Wrote {fname}", comm=comm) @@ -540,6 +577,136 @@ def _finalize(self, data, **kwargs): comm.barrier() log.info_rank(f"MLMapmaker wrote map in", comm=comm, timer=timer) + return mapmaker, step.x + + @function_timer + def _exec(self, data, detectors=None, **kwargs): + log = Logger.get() + timer = Timer() + comm = data.comm.comm_world + gcomm = data.comm.comm_group + timer.start() + + if comm is None and self.tiled: + log.info("WARNING: Tiled mapmaking not supported without MPI.") + self.tiled = False + + for trait in ["area"]: + value = getattr(self, trait) + if value is None: + raise RuntimeError( + f"You must set `{trait}` before running MLMapmaker" + ) + + # nmat_type is guaranteed to be a valid Nmat class + noise_model = getattr(mm, self.nmat_type)() + + shape, wcs = enmap.read_map_geometry(self.area) + + if self.center_at is None: + recenter = None + else: + recenter = mm.parse_recentering(self.center_at) + dtype_tod = np.float32 + dtype_map = np.dtype(self.dtype_map) + + prefix = os.path.join(self.out_dir, f"{self.name}_") + + passes = self.setup_passes() + npass = len(passes) + mapmaker_prev = None + x_prev = None + + for ipass, passinfo in enumerate(passes): + # The multipass mapmaking loop + log.info_rank( + f"Starting pass {ipass + 1}/{npass}, maxit={passinfo.maxiter} " + f"down={passinfo.downsample}, interp={passinfo.interpol}", + comm=comm, + ) + if npass == 1: + pass_prefix = prefix + else: + pass_prefix = f"{prefix}pass{ipass + 1}_" + + signal_cut = mm.SignalCut(comm, dtype=dtype_tod) + signal_map = mm.SignalMap( + shape, + wcs, + comm, + comps=self.comps, + dtype=dtype_map, + recenter=recenter, + tiled=self.tiled, + interpol=passinfo.interpol, + ) + signals = [signal_cut, signal_map] + mapmaker = mm.MLMapmaker( + signals, noise_model=noise_model, dtype=dtype_tod, verbose=self.verbose + ) + + for ob in data.obs: + # Get the detectors we are using locally for this observation + dets = ob.select_local_detectors(detectors) + if len(dets) == 0: + # Nothing to do for this observation + continue + + nmat, nmat_file = self._load_noise_model(ob, npass, ipass, gcomm) + + axobs = self._wrap_obs(ob, dets, passinfo) + mapmaker.add_obs( + ob.name, + axobs, + deslope=self.deslope, + noise_model=nmat, + signal_estimate=None, + ) + del axobs + + self._save_noise_model(mapmaker, nmat, nmat_file, gcomm) + + # 
Optionally delete the input detector data to save memory, if + # the calling code knows that no additional operators will be + # used afterwards. + if ipass == npass - 1 and self.purge_det_data: + del ob.detdata[self.det_data] + + if comm is not None: + comm.barrier() + log.info_rank( + f"MLMapmaker wrapped observations in", + comm=comm, + timer=timer, + ) + + x0 = self._init_mapmaker( + mapmaker, + signal_map, + mapmaker_prev, + x_prev, + comm, + gcomm, + pass_prefix, + ) + mapmaker_prev, x_prev = self._apply_mapmaker( + mapmaker, x0, passinfo, pass_prefix, comm + ) + + # Save metadata, may get dropped later + + self.shape = shape + self.wcs = wcs + self.recenter = recenter + self.signal_map = signal_map + self.mapmaker = mapmaker + + return + + @function_timer + def _finalize(self, data, **kwargs): + pass + def _requires(self): req = { "meta": [self.noise_model], @@ -565,51 +732,3 @@ def _provides(self): def _accelerators(self): return list() - - -# class NmatToast(mm.Nmat): -# """Noise matrix class that uses a TOAST noise model. - -# This takes an existing TOAST noise model and uses it for a MLMapmaker compatible -# noise matrix. - -# Args: -# model (toast.Noise): The toast noise model. -# det_order (dict): The mapping from detector order in the AxisManager -# to name in the Noise object. - -# """ -# def __init__(self, model, n_sample, det_order): -# self.model = model -# self.det_order = det_order -# self.n_sample = n_sample - -# # Compute the radix-2 FFT length to use -# self.fftlen = 2 -# while self.fftlen <= self.n_sample: -# self.fftlen *= 2 -# self.npsd = self.fftlen // 2 + 1 - -# # Compute the time domain offset that centers our data within the -# # buffer -# self.padded_start = (self.fftlen - self.n_sample) // 2 - -# # Compute the common frequency values -# self.nyquist = model.freq(model.keys[0])[-1].to_value(u.Hz) -# self.rate = 2 * self.nyquist -# self.freqs = np.fft.rfftfreq(self.fftlen, 1 / self.rate)) - -# # Interpolate the PSDs to desired spacing and store for later -# # application. - -# def build(self, tod, **kwargs): -# """Build method is a no-op, we do all set up in the constructor.""" -# return self - -# def apply(self, tod, inplace=False): -# """Apply our noise filter to the TOD. - -# We use our pre-built Fourier domain kernels. - -# """ -# return tod diff --git a/sotodlib/toast/workflows/scripting.py b/sotodlib/toast/workflows/scripting.py index 239290b56..f216dc550 100644 --- a/sotodlib/toast/workflows/scripting.py +++ b/sotodlib/toast/workflows/scripting.py @@ -34,7 +34,7 @@ def load_or_simulate_observing(job, otherargs, runargs, comm): timer = toast.timing.Timer() timer.start() - if job_ops.sim_ground.enabled: + if job_ops.sim_ground.enabled or otherargs.schedule is not None: data = wrk.simulate_observing(job, otherargs, runargs, comm) wrk.select_pointing(job, otherargs, runargs, data) wrk.simple_noise_models(job, otherargs, runargs, data) diff --git a/sotodlib/toast/workflows/sim_observe.py b/sotodlib/toast/workflows/sim_observe.py index 4668606ef..5ba820eb2 100644 --- a/sotodlib/toast/workflows/sim_observe.py +++ b/sotodlib/toast/workflows/sim_observe.py @@ -83,6 +83,14 @@ def setup_simulate_observing(parser, operators): parser.add_argument( "--schedule", required=False, default=None, help="Input observing schedule" ) + parser.add_argument( + "--sort_schedule", + required=False, + default=False, + action="store_true", + help="Sort the observing schedule by mean boresight RA. 
" + "This can limit the area of sky each process group deals with.", + ) parser.add_argument( "--realization", required=False, @@ -142,10 +150,6 @@ def simulate_observing(job, otherargs, runargs, comm): # Configured operators for this job job_ops = job.operators - if not job_ops.sim_ground.enabled: - log.info_rank("Simulated observing is disabled", comm=comm) - return None - # Make sure we have the required bands and schedule. These might # not be set during a dry-run, but if we got this far they need to # be set. @@ -171,11 +175,21 @@ def simulate_observing(job, otherargs, runargs, comm): thinfp=otherargs.thinfp, comm=comm, ) + ndet = len(telescope.focalplane.detectors) + log.info_rank( + f" Simulated focalplane with {ndet} detectors " + f"(thinfp = {otherargs.thinfp}) in", + comm=comm, + timer=timer, + ) # Load the schedule file schedule = toast.schedule.GroundSchedule() schedule.read(otherargs.schedule, comm=comm) log.info_rank(" Loaded schedule in", comm=comm, timer=timer) + if otherargs.sort_schedule: + schedule.sort_by_RA() + log.info_rank(" Sorted schedule in", comm=comm, timer=timer) mem = toast.utils.memreport(msg="(whole node)", comm=comm, silent=True) log.info_rank(f" After loading schedule: {mem}", comm) @@ -203,12 +217,13 @@ def simulate_observing(job, otherargs, runargs, comm): # Simulate the telescope pointing job_ops.sim_ground.telescope = telescope + job_ops.sim_ground.enabled = True job_ops.sim_ground.schedule = schedule if job_ops.sim_ground.weather is None: job_ops.sim_ground.weather = telescope.site.name if otherargs.realization is not None: job_ops.sim_ground.realization = otherargs.realization - log.info_rank(" Running simulated observing...", comm=data.comm.comm_world) + log.info_rank(" Running simulated observing...", comm=comm) job_ops.sim_ground.apply(data) log.info_rank(" Simulated telescope pointing in", comm=comm, timer=timer) @@ -217,17 +232,13 @@ def simulate_observing(job, otherargs, runargs, comm): # Apply LAT co-rotation if job_ops.corotate_lat.enabled: - log.info_rank( - " Running simulated LAT corotation...", comm=data.comm.comm_world - ) + log.info_rank(" Running simulated LAT corotation...", comm=comm) job_ops.corotate_lat.apply(data) log.info_rank(" Apply LAT co-rotation in", comm=comm, timer=timer) # Perturb HWP spin if job_ops.perturb_hwp.enabled: - log.info_rank( - " Running simulated HWP perturbation...", comm=data.comm.comm_world - ) + log.info_rank(" Running simulated HWP perturbation...", comm=comm) job_ops.perturb_hwp.apply(data) log.info_rank(" Perturbed HWP rotation in", comm=comm, timer=timer) return data diff --git a/sotodlib/tod_ops/__init__.py b/sotodlib/tod_ops/__init__.py index ea4e2f6b1..57a82133f 100644 --- a/sotodlib/tod_ops/__init__.py +++ b/sotodlib/tod_ops/__init__.py @@ -11,3 +11,4 @@ from .sub_polyf import subscan_polyfilter from .azss import get_azss from .t2pleakage import get_t2p_coeffs, subtract_t2p +from . 
import deproject diff --git a/sotodlib/tod_ops/azss.py b/sotodlib/tod_ops/azss.py index c07ba79c7..e95f3ad43 100644 --- a/sotodlib/tod_ops/azss.py +++ b/sotodlib/tod_ops/azss.py @@ -144,7 +144,7 @@ def fit_azss(az, azss_stats, max_mode, fit_range=None): def get_azss(aman, signal='signal', az=None, range=None, bins=100, flags=None, apodize_edges=True, apodize_edges_samps=40000, apodize_flags=True, apodize_flags_samps=200, - apply_prefilt=True, prefilt_cfg=None, prefilt_detrend='linear', + apply_prefilt=True, prefilt_cfg=None, prefilt_detrend='linear', method='interpolate', max_mode=None, subtract_in_place=False, merge_stats=True, azss_stats_name='azss_stats', merge_model=True, azss_model_name='azss_model'): diff --git a/sotodlib/tod_ops/deproject.py b/sotodlib/tod_ops/deproject.py new file mode 100644 index 000000000..feb3271ca --- /dev/null +++ b/sotodlib/tod_ops/deproject.py @@ -0,0 +1,122 @@ +"""Module for deprojecting median Q/U from the data""" +import numpy as np +from sotodlib import core + +def get_qu_common_mode_coeffs(aman, Q_signal=None, U_signal=None, merge=False): + """ + Gets the median signal (template) and coefficients for the coupling to that + signal for each detector for both the Q and U signals. Returns an + AxisManager with the template and coefficients wrapped. + + Arguments: + ---------- + aman: AxisManager + Contains the signal to operate on. + Q_signal: ndarray or str + array or string with field in aman containing the demodulated Q signal. + U_signal: ndarray or str + array or string with field in aman containing the demodulated U signal. + merge: bool + If True wrap the returned AxisManager into aman. + + Returns: + -------- + output_aman: AxisManager + Contains the template signals for Q/U and coefficients coupling + each detector to the templates. + """ + if Q_signal is None: + Q_signal = aman['demodQ'] + if isinstance(Q_signal, str): + Q_signal = aman[Q_signal] + if not isinstance(Q_signal, np.ndarray): + raise TypeError("Signal is not an array") + + if U_signal is None: + U_signal = aman['demodU'] + if isinstance(U_signal, str): + U_signal = aman[U_signal] + if not isinstance(U_signal, np.ndarray): + raise TypeError("Signal is not an array") + + output_aman = core.AxisManager(aman.dets, aman.samps) + for sig, name in zip([Q_signal, U_signal], ['Q','U']): + coeffs, med = _get_qu_template(aman, sig, False) + output_aman.wrap(f'coeffs_{name}', coeffs[:,0], [(0, 'dets')]) + output_aman.wrap(f'med_{name}', med, [(0, 'samps')]) + if merge: + aman.wrap('qu_common_mode_coeffs', output_aman) + return output_aman + +def subtract_qu_common_mode(aman, Q_signal=None, U_signal=None, coeff_aman=None, + merge=False): + """ + Subtracts the median signal (template) from each detector scaled by the a + coupling coefficient per detector. + + Arguments: + ---------- + aman: AxisManager + Contains the signal to operate on. + Q_signal: ndarray or str + array or string with field in aman containing the demodulated Q signal. + U_signal: ndarray or str + array or string with field in aman containing the demodulated U signal. + coeff_aman: AxisManager + contains the coefficients and templates to use for subtraction. + See ``get_qu_common_mode_coeffs``. + merge: bool + If True wrap the returned AxisManager into aman. 
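A hedged usage sketch for the two deprojection helpers above, assuming sotodlib is importable; the AxisManager and the random demodQ/demodU data are fabricated purely for illustration.

import numpy as np
from sotodlib import core
from sotodlib.tod_ops import deproject

dets = core.LabelAxis("dets", ["det0", "det1", "det2"])
samps = core.OffsetAxis("samps", 1000)
aman = core.AxisManager(dets, samps)
rng = np.random.default_rng(0)
aman.wrap("demodQ", rng.normal(size=(3, 1000)), [(0, "dets"), (1, "samps")])
aman.wrap("demodU", rng.normal(size=(3, 1000)), [(0, "dets"), (1, "samps")])

# Compute templates and coefficients once, then subtract in place.
coeff_aman = deproject.get_qu_common_mode_coeffs(aman)
deproject.subtract_qu_common_mode(aman, coeff_aman=coeff_aman)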
+ """ + if Q_signal is None: + Q_signal = aman['demodQ'] + Q_signal_name = 'demodQ' + if isinstance(Q_signal, str): + Q_signal_name = Q_signal + Q_signal = aman[Q_signal] + if not isinstance(Q_signal, np.ndarray): + raise TypeError("Signal is not an array") + + if U_signal is None: + U_signal = aman['demodU'] + U_signal_name = 'demodU' + if isinstance(U_signal, str): + U_signal_name = U_signal + U_signal = aman[U_signal] + if not isinstance(U_signal, np.ndarray): + raise TypeError("Signal is not an array") + + if coeff_aman is None: + if 'QU_common_mode_coeffs' in aman: + coeff_aman = aman['QU_common_mode_coeffs'] + else: + coeff_aman = get_qu_common_mode_coeffs(aman, Q_signal, U_signal, merge) + + aman[Q_signal_name] -= np.atleast_2d(coeff_aman['coeffs_Q']).T*coeff_aman['med_Q'] + aman[U_signal_name] -= np.atleast_2d(coeff_aman['coeffs_U']).T*coeff_aman['med_U'] + + +def _get_qu_template(aman, signal, correct): + """ + Calculates coefficients and median for the given demodulated Q or U data + used for the deprojection. + + Parameters: + ----------- + aman : AxisManager + An AxisManager containing the demodulated Q and U components. + signal : str + The AxisManager field to access the specific signal in the aman object. + + Returns: + -------- + tuple: A tuple containing: + - coeffs (numpy.ndarray): The deprojected coefficients. + - med (numpy.ndarray): The median values of the input data along the first axis. + """ + med = np.median(signal, axis=0) + vects = np.atleast_2d(med) + I = np.linalg.inv(np.tensordot(vects, vects, (1, 1))) + coeffs = np.matmul(signal, vects.T) + coeffs = np.dot(I, coeffs.T).T + return coeffs, med diff --git a/sotodlib/tod_ops/fft_ops.py b/sotodlib/tod_ops/fft_ops.py index 4f173e8e6..9b64f43fd 100644 --- a/sotodlib/tod_ops/fft_ops.py +++ b/sotodlib/tod_ops/fft_ops.py @@ -213,7 +213,8 @@ def calc_psd( prefer='center', freq_spacing=None, merge=False, - overwrite=True, + overwrite=True, + subscan=False, **kwargs ): """Calculates the power spectrum density of an input signal using signal.welch(). @@ -234,6 +235,7 @@ def calc_psd( If an nperseg is explicitly passed then that will be used. merge (bool): if True merge results into axismanager. overwrite (bool): if true will overwrite f, pxx axes. + subscan (bool): if True, compute psd on subscans. **kwargs: keyword args to be passed to signal.welch(). 
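The template fit in _get_qu_template above is an ordinary least-squares projection onto the across-detector median. A pure-numpy illustration with toy data, with amplitudes chosen so the recovered coefficients are easy to check:

import numpy as np

rng = np.random.default_rng(1)
common = np.sin(np.linspace(0, 10, 500))                  # shared mode
signal = np.outer([0.5, 1.0, 2.0], common) + 0.01 * rng.normal(size=(3, 500))

med = np.median(signal, axis=0)                           # template, shape (samps,)
vects = np.atleast_2d(med)                                # (1, samps)
I = np.linalg.inv(np.tensordot(vects, vects, (1, 1)))     # 1x1 normal matrix
coeffs = np.dot(I, np.matmul(signal, vects.T).T).T        # (dets, 1)
print(coeffs[:, 0])                                       # approximately [0.5, 1.0, 2.0]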
Returns: @@ -242,34 +244,40 @@ def calc_psd( """ if signal is None: signal = aman.signal - if timestamps is None: - timestamps = aman.timestamps - - n_samps = signal.shape[-1] - if n_samps <= max_samples: - start = 0 - stop = n_samps + if subscan: + freqs, Pxx = _calc_psd_subscan(aman, signal=signal, freq_spacing=freq_spacing, **kwargs) + axis_map_pxx = [(0, "dets"), (1, "nusamps"), (2, "subscans")] else: - offset = n_samps - max_samples - if prefer == "left": - offset = 0 - elif prefer == "center": - offset //= 2 - elif prefer == "right": - pass - else: - raise ValueError(f"Invalid choise prefer='{prefer}'") - start = offset - stop = offset + max_samples - fs = 1 / np.nanmedian(np.diff(timestamps[start:stop])) - if "nperseg" not in kwargs: - if freq_spacing is not None: - nperseg = int(2 ** (np.around(np.log2(fs / freq_spacing)))) + if timestamps is None: + timestamps = aman.timestamps + + n_samps = signal.shape[-1] + if n_samps <= max_samples: + start = 0 + stop = n_samps else: - nperseg = int(2 ** (np.around(np.log2((stop - start) / 50.0)))) - kwargs["nperseg"] = nperseg + offset = n_samps - max_samples + if prefer == "left": + offset = 0 + elif prefer == "center": + offset //= 2 + elif prefer == "right": + pass + else: + raise ValueError(f"Invalid choice prefer='{prefer}'") + start = offset + stop = offset + max_samples + fs = 1 / np.nanmedian(np.diff(timestamps[start:stop])) + if "nperseg" not in kwargs: + if freq_spacing is not None: + nperseg = int(2 ** (np.around(np.log2(fs / freq_spacing)))) + else: + nperseg = int(2 ** (np.around(np.log2((stop - start) / 50.0)))) + kwargs["nperseg"] = nperseg + + freqs, Pxx = welch(signal[:, start:stop], fs, **kwargs) + axis_map_pxx = [(0, aman.dets), (1, "nusamps")] - freqs, Pxx = welch(signal[:, start:stop], fs, **kwargs) if merge: aman.merge( core.AxisManager(core.OffsetAxis("nusamps", len(freqs)))) if overwrite: @@ -278,9 +286,42 @@ def calc_psd( if "Pxx" in aman._fields: aman.move("Pxx", None) aman.wrap("freqs", freqs, [(0,"nusamps")]) - aman.wrap("Pxx", Pxx, [(0,"dets"),(1,"nusamps")]) + aman.wrap("Pxx", Pxx, axis_map_pxx) return freqs, Pxx +def _calc_psd_subscan(aman, signal=None, freq_spacing=None, **kwargs): + """ + Calculate the power spectrum density of subscans using signal.welch(). + Data defaults to aman.signal. aman.timestamps is used for times. + aman.subscan_info is used to identify subscans. + See calc_psd for arguments. 
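Both calc_psd and _calc_psd_subscan above pick nperseg as the power of two whose Welch bin width fs/nperseg approximately matches the requested freq_spacing. A standalone sketch of that heuristic with arbitrary numbers:

import numpy as np
from scipy.signal import welch

fs = 200.0            # sample rate [Hz]
freq_spacing = 0.1    # requested PSD bin width [Hz]
nperseg = int(2 ** np.around(np.log2(fs / freq_spacing)))
print(nperseg)        # 2048

rng = np.random.default_rng(2)
signal = rng.normal(size=(3, 20000))          # (dets, samps) toy white noise
freqs, Pxx = welch(signal, fs, nperseg=nperseg)
print(freqs[1] - freqs[0])                    # ~0.098 Hz, close to freq_spacing
print(Pxx.shape)                              # (3, 1025)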
+ """ + from .flags import get_subscan_signal + if signal is None: + signal = aman.signal + + fs = 1 / np.nanmedian(np.diff(aman.timestamps)) + if "nperseg" not in kwargs: + if freq_spacing is not None: + nperseg = int(2 ** (np.around(np.log2(fs / freq_spacing)))) + else: + duration_samps = np.asarray([np.ptp(x.ranges()) if x.ranges().size > 0 else 0 for x in aman.subscan_info.subscan_flags]) + duration_samps = duration_samps[duration_samps > 0] + nperseg = int(2 ** (np.around(np.log2(np.median(duration_samps) / 4)))) + kwargs["nperseg"] = nperseg + + Pxx = [] + for iss in range(aman.subscan_info.subscans.count): + signal_ss = get_subscan_signal(aman, signal, iss) + axis = -1 if "axis" not in kwargs else kwargs["axis"] + if signal_ss.shape[axis] >= kwargs["nperseg"]: + freqs, pxx_sub = welch(signal_ss, fs, **kwargs) + Pxx.append(pxx_sub) + else: + Pxx.append(np.full((signal.shape[0], kwargs["nperseg"]//2+1), np.nan)) # Add nans if subscan is too short + Pxx = np.array(Pxx) + Pxx = Pxx.transpose(1, 2, 0) # Dets, nusamps, subscans + return freqs, Pxx def calc_wn(aman, pxx=None, freqs=None, low_f=5, high_f=10): """ @@ -346,13 +387,15 @@ def fit_noise_model( signal=None, f=None, pxx=None, - psdargs=None, + psdargs={}, fwhite=(10, 100), lowf=1, merge_fit=False, f_max=100, merge_name="noise_fit_stats", merge_psd=True, + freq_spacing=None, + subscan=False ): """ Fits noise model with white and 1/f noise to the PSD of signal. @@ -392,6 +435,10 @@ def fit_noise_model( If ``merge_fit`` is True then addes into axis manager with merge_name. merge_psd : bool If ``merg_psd`` is True then adds fres and Pxx to the axis manager. + freq_spacing : float + The approximate desired frequency spacing of the PSD. Passed to calc_psd. + subscan : bool + If True, fit noise on subscans. Returns ------- noise_fit_stats : AxisManager @@ -403,42 +450,48 @@ def fit_noise_model( signal = aman.signal if f is None or pxx is None: - if psdargs is None: - f, pxx = calc_psd( - aman, signal=signal, timestamps=aman.timestamps, merge=merge_psd - ) - else: - f, pxx = calc_psd( - aman, - signal=signal, - timestamps=aman.timestamps, - merge=merge_psd, - **psdargs, - ) - eix = np.argmin(np.abs(f - f_max)) - f = f[1:eix] - pxx = pxx[:, 1:eix] - - fitout = np.zeros((aman.dets.count, 3)) - # This is equal to np.sqrt(np.diag(cov)) when doing curve_fit - covout = np.zeros((aman.dets.count, 3, 3)) - for i in range(aman.dets.count): - p = pxx[i] - wnest = np.median(p[((f > fwhite[0]) & (f < fwhite[1]))]) - pfit = np.polyfit(np.log10(f[f < lowf]), np.log10(p[f < lowf]), 1) - fidx = np.argmin(np.abs(10 ** np.polyval(pfit, np.log10(f)) - wnest)) - p0 = [f[fidx], wnest, -pfit[0]] - bounds = [(0, None), (sys.float_info.min, None), (None, None)] - res = minimize(neglnlike, p0, args=(f, p), bounds=bounds, method="Nelder-Mead") - try: - Hfun = ndt.Hessian(lambda params: neglnlike(params, f, p), full_output=True) - hessian_ndt, _ = Hfun(res["x"]) - # Inverse of the hessian is an estimator of the covariance matrix - # sqrt of the diagonals gives you the standard errors. 
- covout[i] = np.linalg.inv(hessian_ndt) - except np.linalg.LinAlgError: - covout[i] = np.full((3, 3), np.nan) - fitout[i] = res.x + f, pxx = calc_psd( + aman, + signal=signal, + timestamps=aman.timestamps, + freq_spacing=freq_spacing, + merge=merge_psd, + subscan=subscan, + **psdargs, + ) + if subscan: + fitout, covout = _fit_noise_model_subscan(aman, signal, f, pxx, psdargs=psdargs, + fwhite=fwhite, lowf=lowf, f_max=f_max, + freq_spacing=freq_spacing) + axis_map_fit = [(0, "dets"), (1, "noise_model_coeffs"), (2, aman.subscans)] + axis_map_cov = [(0, "dets"), (1, "noise_model_coeffs"), (2, "noise_model_coeffs"), (3, aman.subscans)] + else: + eix = np.argmin(np.abs(f - f_max)) + f = f[1:eix] + pxx = pxx[:, 1:eix] + + fitout = np.zeros((aman.dets.count, 3)) + # This is equal to np.sqrt(np.diag(cov)) when doing curve_fit + covout = np.zeros((aman.dets.count, 3, 3)) + for i in range(aman.dets.count): + p = pxx[i] + wnest = np.median(p[((f > fwhite[0]) & (f < fwhite[1]))]) + pfit = np.polyfit(np.log10(f[f < lowf]), np.log10(p[f < lowf]), 1) + fidx = np.argmin(np.abs(10 ** np.polyval(pfit, np.log10(f)) - wnest)) + p0 = [f[fidx], wnest, -pfit[0]] + bounds = [(0, None), (sys.float_info.min, None), (None, None)] + res = minimize(neglnlike, p0, args=(f, p), bounds=bounds, method="Nelder-Mead") + try: + Hfun = ndt.Hessian(lambda params: neglnlike(params, f, p), full_output=True) + hessian_ndt, _ = Hfun(res["x"]) + # Inverse of the hessian is an estimator of the covariance matrix + # sqrt of the diagonals gives you the standard errors. + covout[i] = np.linalg.inv(hessian_ndt) + except np.linalg.LinAlgError: + covout[i] = np.full((3, 3), np.nan) + fitout[i] = res.x + axis_map_fit = [(0, "dets"), (1, "noise_model_coeffs")] + axis_map_cov = [(0, "dets"), (1, "noise_model_coeffs"), (2, "noise_model_coeffs")] noise_model_coeffs = ["fknee", "white_noise", "alpha"] noise_fit_stats = core.AxisManager( @@ -447,18 +500,46 @@ def fit_noise_model( name="noise_model_coeffs", vals=np.array(noise_model_coeffs, dtype=" iqr_range[:, None] * n_sig + + if subscan: + # We include turnarounds + subscan_indices = np.concatenate([aman.flags.left_scan.ranges(), (~aman.flags.left_scan).ranges()]) + else: + subscan_indices = np.array([[0, fvec.shape[1]]]) + + msk = np.zeros_like(fvec, dtype='bool') + for ss in subscan_indices: + iqr_range = 0.741 * stats.iqr(fvec[:,ss[0]:ss[1]:ds], axis=1) + # get flags + msk[:,ss[0]:ss[1]] = fvec[:,ss[0]:ss[1]] > iqr_range[:, None] * n_sig msk[:,:edge_guard] = False msk[:,-edge_guard:] = False flag = RangesMatrix([Ranges.from_bitmask(m) for m in msk]) @@ -523,6 +544,30 @@ def get_trending_flags(aman, return cut def get_dark_dets(aman, merge=True, overwrite=True, dark_flags_name='darks'): + """ + Identify and flag dark detectors in the given aman object. + + Parameters: + ---------- + aman : AxisManager + The tod. + merge : bool, optional + If True, merge the dark detector flags into the aman.flags. Default is True. + overwrite : bool, optional + If True, overwrite existing flags with the same name. Default is True. + dark_flags_name : str, optional + The name to use for the dark detector flags in aman.flags. Default is 'darks'. + + Returns: + ------- + mskdarks: RangesMatrix + A matrix of ranges indicating the dark detectors. + + Raises: + ------- + ValueError + If merge is True and dark_flags_name already exists in aman.flags and overwrite is False. 
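    Example (sketch; assumes ``aman.det_info.wafer.type`` is populated)::

        >>> mskdarks = get_dark_dets(aman, merge=True, dark_flags_name='darks')
        >>> aman.flags.darks    # RangesMatrix, all-True rows for non-'OPTC' detectors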
+ """ darks = np.array(aman.det_info.wafer.type != 'OPTC') x = Ranges(aman.samps.count) mskdarks = RangesMatrix([Ranges.ones_like(x) if Y @@ -538,17 +583,19 @@ def get_dark_dets(aman, merge=True, overwrite=True, dark_flags_name='darks'): return mskdarks -def get_source_flags(aman, merge=True, overwrite=True, source_flags_name='source_flags', +def get_source_flags(aman, merge=True, overwrite=True, source_flags_name=None, mask=None, center_on=None, res=None, max_pix=None): - if merge: - wrap = source_flags_name - else: - wrap = None + if res: - res = np.radians(res/60) - source_flags = coords.planets.compute_source_flags(tod=aman, wrap=wrap, mask=mask, center_on=center_on, res=res, max_pix=max_pix) - + res = np.radians(res/60) # config input in arcminutes + + source_flags = coords.planets.compute_source_flags(tod=aman, mask=mask, + center_on=center_on, + res=res, max_pix=max_pix) + if merge: + if source_flags_name is None: + source_flags_name = center_on if source_flags_name in aman.flags and not overwrite: raise ValueError(f"Flag name {source_flags_name} already exists in aman.flags") if source_flags_name in aman.flags: @@ -672,3 +719,284 @@ def get_inv_var_flags(aman, signal_name='signal', nsigma=5, aman.flags.wrap(inv_var_flag_name, mskinvar, [(0, 'dets'), (1, 'samps')]) return mskinvar + +def get_subscans(aman, merge=True, include_turnarounds=False, overwrite=True): + """ + Returns an axis manager with information about subscans. + This includes direction and a ranges matrix (subscans samps) + True inside each subscan. + + Parameters + ---------- + aman : AxisManager + Input AxisManager. + merge : bool + Merge into aman as 'subscan_info' + include_turnarounds : bool + Include turnarounds in the subscan ranges + overwrite : bool + If true, write over subscan_info. + + Returns + ------- + subscan_aman : AxisManager + AxisManager containing information about the subscans. + "direction" is a (subscans,) array of strings 'left' or 'right' + "subscan_flags" is a (subscans, samps) RangesMatrix; True inside the subscan. + """ + if not include_turnarounds: + ss_ind = (~aman.flags.turnarounds).ranges() # sliceable indices (first inclusive, last exclusive) for subscans + else: + left = aman.flags.left_scan.ranges() + right = aman.flags.right_scan.ranges() + start_left = 0 if (left[0,0] < right[0,0]) else 1 + ss_ind = np.empty((left.shape[0] + right.shape[0], 2), dtype=left.dtype) + ss_ind[start_left::2] = left + ss_ind[(start_left-1)%2::2] = right + + start_inds, end_inds = ss_ind.T + n_subscan = ss_ind.shape[0] + tt = aman.timestamps + subscan_aman = core.AxisManager(aman.samps, core.IndexAxis("subscans", n_subscan)) + + is_left = aman.flags.left_scan.mask()[start_inds] + subscan_aman.wrap('direction', np.array(['left' if is_left[ii] else 'right' for ii in range(n_subscan)]), [(0, 'subscans')]) + + rm = RangesMatrix([Ranges.from_array(np.atleast_2d(ss), tt.size) for ss in ss_ind]) + subscan_aman.wrap('subscan_flags', rm, [(0, 'subscans'), (1, 'samps')]) # True in the subscan + if merge: + name = 'subscan_info' + if overwrite and name in aman: + aman.move(name, None) + aman.wrap(name, subscan_aman) + return subscan_aman + +def get_subscan_signal(aman, arr, isub=None, trim=False): + """ + Split an array into subscans. + + Parameters + ---------- + aman : AxisManager + Input AxisManager. + arr : Array + Input array. + isub : int + Index of the desired subscan. May also be a list of indices. + If None, all are used. + trim : bool + Do not include size-zero arrays from empty subscans in the output. 
+ + Returns + ------- + out : list + If isub is a scalar, return an Array of arr cut on the samps axis to the given subscan. + If isub is a list or None, return a list of such Arrays. + """ + if isinstance(arr, str): + arr = aman[arr] + if np.isscalar(isub): + out = apply_rng(arr, aman.subscan_info.subscan_flags[isub]) + if trim and out.size == 0: + out = None + else: + if isub is None: + isub = range(len(aman.subscan_info.subscan_flags)) + out = [apply_rng(arr, aman.subscan_info.subscan_flags[ii]) for ii in isub] + if trim: + out = [x for x in out if x.size > 0] + + return out + + +def apply_rng(arr, rng): + """ + Apply a Ranges object to an array. rng should be True on the samples you want to keep. + + Parameters + ---------- + arr : Array + Array containing the signal. Should have one axis of len (samps). + rng : Ranges + Ranges object of len (samps) selecting the desired range. + """ + if rng.ranges().size == 0: + slices = [slice(0,0)] # Return an empty array if rng is empty + else: + slices = [slice(*irng) for irng in rng.ranges()] + + # Identify the samps axis + isamps = np.where(np.array(arr.shape) == rng.count)[0] + if isamps.size != 1: + # Check for axis mismatch between arr and rng, or multiple axes with the same size + raise RuntimeError("Could not identify axis matching Ranges") + # Apply ranges + out = [] + for slc in slices: + ndslice = tuple((slice(None) if ii != isamps[0] else slc for ii in range(arr.ndim))) + out.append(arr[ndslice]) + return np.concatenate(out, axis=isamps[0]) + +def wrap_stats(aman, info_aman_name, info, info_names, merge=True): + """ + Wrap multiple stats into a new aman, checking for subscan information. Stats can be (dets,) or (dets, subscans). + + Parameters + ---------- + aman : AxisManager + Input AxisManager. + info_aman_name : str + Name for info_aman when wrapped into input. + info : Array + (stats, dets,) or (stats, dets, subscans) containing the information you want to wrap. + info_names : list + List of str names for each entry in the new aman. + merge : bool + If True merge info_aman into aman. + + Returns + ------- + info_aman : AxisManager + (dets,) or (dets, subscans) aman with a field for each item in info_names. + """ + info_names = np.atleast_1d(info_names) + info = np.atleast_2d(info) + if info.shape == (len(info_names), aman.dets.count): # (stats, dets) + if len(info_names) == aman.dets.count and aman.dets.count == aman.subscan_info.subscans.count: + raise RuntimeError("Cannot infer axis mapping") # Catch corner case + info_aman = core.AxisManager(aman.dets) + axmap = [(0, 'dets')] + + else: + info = np.atleast_3d(info) # (stats, dets, subscans) + info_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans) + axmap = [(0, 'dets'), (1, 'subscans')] + + for ii in range(len(info_names)): + info_aman.wrap(info_names[ii], info[ii], axmap) + if merge: + if info_aman_name in aman.keys(): + aman[info_aman_name].merge(info_aman) + else: + aman.wrap(info_aman_name, info_aman) + return info_aman + +def get_stats(aman, signal, stat_names, split_subscans=False, mask=None, name="stats", merge=False): + """ + Calculate basic statistics on a TOD or power spectrum. + + Parameters + ---------- + aman : AxisManager + Input AxisManager. + signal : Array + Input signal. Statistics will be computed over *axis 1*. + stat_names : list + List of strings identifying which statistics to run. + split_subscans : bool + If True statistics will be computed on subscans. Assumes aman.subscan_info exists already. 
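    For example (sketch; the per-subscan result is wrapped via ``wrap_stats`` and
    carries one field per requested statistic)::

        >>> get_stats(aman, 'signal', ['mean', 'std', 'kurtosis'],
        ...           split_subscans=True, merge=True, name='tod_stats')
        >>> aman.tod_stats.std    # (dets, subscans)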
+ mask : Array + Mask to apply before computation. 1d array for advanced indexing (keep True), or a slice object. + name : str + Name of axis manager to add to aman if merge is True. + """ + stat_names = np.atleast_1d(stat_names) + fn_dict = {'mean': np.mean, 'median': np.median, 'ptp': np.ptp, 'std': np.std, + 'kurtosis': stats.kurtosis, 'skew': stats.skew} + + if isinstance(signal, str): + signal = aman[signal] + if split_subscans: + if mask is not None: + raise ValueError("Cannot mask samples and split subscans") + stats_arr = [] + for iss in range(aman.subscan_info.subscans.count): + data = get_subscan_signal(aman, signal, iss) + if data.size > 0: + stats_arr.append([fn_dict[name](data, axis=1) for name in stat_names]) # Samps axis assumed to be 1 + else: + stats_arr.append(np.full((len(stat_names), signal.shape[0]), np.nan)) # Add nans if subscan has been entirely cut + stats_arr = np.array(stats_arr).transpose(1, 2, 0) # stat, dets, subscan + else: + if mask is None: + mask = slice(None) + stats_arr = np.array([fn_dict[name](signal[:, mask], axis=1) for name in stat_names]) # Samps axis assumed to be 1 + + info_aman = wrap_stats(aman, name, stats_arr, stat_names, merge) + return info_aman + +def get_focalplane_flags(aman, merge=True, overwrite=True, invalid_flags_name='fp_flags'): + """ + Generate flags for invalid detectors in the focal plane. + The tod. + merge : bool + If true, merges the generated flag into aman. + overwrite : bool + If true, write over flag. If false, don't. + invalid_flags_name : str + Name of flag to add to aman.flags if merge is True. + + Returns + ------- + msk_invalid_fp : RangesMatrix + RangesMatrix of invalid detectors in the focal plane. + """ + # Available detectors in focalplane + xi_nan = np.isnan(aman.focal_plane.xi) + eta_nan = np.isnan(aman.focal_plane.eta) + gamma_nan = np.isnan(aman.focal_plane.gamma) + x = Ranges(aman.samps.count) + flag_invalid_fp = np.sum([xi_nan, eta_nan, gamma_nan], axis=0) != 0 + msk_invalid_fp = RangesMatrix([Ranges.ones_like(x) if Y else Ranges.zeros_like(x) for Y in flag_invalid_fp]) + + if merge: + if invalid_flags_name in aman.flags and not overwrite: + raise ValueError(f"Flag name {invalid_flags_name} already exists in aman.flags") + if invalid_flags_name in aman.flags: + aman.flags[invalid_flags_name] = msk_invalid_fp + else: + aman.flags.wrap(invalid_flags_name, msk_invalid_fp, [(0, 'dets'), (1, 'samps')]) + + return msk_invalid_fp + +def noise_fit_flags(aman, low_wn, high_wn, high_fk): + """ + Evaluate white noise and fknee cuts based on provided boundaries. + + Parameters: + aman : object + An object containing noise fit statistics and noise model coefficients. + low_wn : float or None + The lower boundary for white noise. If None, white noise flagging is skipped. + high_wn : float or None + The upper boundary for white noise. If None, white noise flagging is skipped. + high_fk : float or None + The upper boundary for fknee. If None, fknee flagging is skipped. + + Returns: + tuple or None + A tuple containing flags for valid white noise and fknee if both boundaries are provided. + If only one boundary is provided, returns the corresponding flag. + If no boundaries are provided, returns None. 
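    Example (sketch; the boundary values below are placeholders, and the white-noise
    limits are compared against the fitted values scaled by 1e6)::

        >>> flag_wn, flag_fk = noise_fit_flags(aman, low_wn=10., high_wn=50., high_fk=1.0)
        >>> keep = flag_wn & flag_fk    # per-detector boolean mask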
+ """ + noise = aman.noise_fit_stats_signal.fit + fk = noise[:, 0] + wn = noise[:, 1] + if low_wn is None: + print(f"white noise boundaries are not defined, skipping.") + flag_valid_wn = None + else: + flag_valid_wn = (low_wn < wn * 1e6) & (wn * 1e6 < high_wn) + if high_fk is None: + print(f"fknee boundaries are not defined, skipping.") + flag_valid_fk = None + else: + flag_valid_fk = fk < high_fk + if low_wn is not None and high_fk is not None: + return flag_valid_wn, flag_valid_fk + elif low_wn is not None: + return flag_valid_wn + elif high_fk is not None: + return flag_valid_fk + else: + return None diff --git a/sotodlib/tod_ops/jumps.py b/sotodlib/tod_ops/jumps.py index 079bb0e11..0b2707353 100644 --- a/sotodlib/tod_ops/jumps.py +++ b/sotodlib/tod_ops/jumps.py @@ -1,26 +1,28 @@ -import concurrent.futures -import os from typing import Literal, Optional, Tuple, Union, cast, overload import numpy as np import scipy.ndimage as simg -import scipy.signal as sig import scipy.stats as ss from numpy.typing import NDArray -from pixell.utils import block_expand, block_reduce +from pixell.utils import block_expand, block_reduce, moveaxis from scipy.sparse import csr_array from skimage.restoration import denoise_tv_chambolle +from so3g import ( + matched_jumps, + matched_jumps64, + find_quantized_jumps, + find_quantized_jumps64, +) from so3g.proj import Ranges, RangesMatrix from sotodlib.core import AxisManager from ..flag_utils import _merge -NFUTURE = int(os.environ.get("NUM_FUTURES", min(32, int(os.cpu_count() or 0) + 4))) - def std_est( x: NDArray[np.floating], ds: int = 1, + win_size: int = 20, axis: int = -1, method: str = "median_unbiased", ) -> NDArray[np.floating]: @@ -32,7 +34,9 @@ def std_est( x: Data to compute standard deviation of. - ds: Downsample factor to use, does a naive slicing. + ds: Downsample factor to use, does a naive slicing in blocks of ``win_size``. + + win_size: Window size to downsample by. axis: The axis to compute along. @@ -44,12 +48,22 @@ def std_est( """ if ds > 2 * x.shape[axis]: ds = 1 - sl = [slice(None)] * len(x.shape) if ds > 1: - sl[axis] = slice(None, None, ds) + x = np.moveaxis(x, axis, -1) + x = x[..., : -1 * (x.shape[-1] % win_size)] + shape = list(x.shape) + [win_size] + shape[-2] = -1 + x = x.reshape(tuple(shape)) + x = np.moveaxis(x, -2, 0) + diff = np.diff(x[::ds], axis=-1) + diff = moveaxis(diff, 0, -2) + diff = diff.reshape(shape[:-1]) + diff = np.moveaxis(diff, -1, axis) + else: + diff = np.diff(x, axis=axis) # Find ~1 sigma limits of differenced data lims = np.quantile( - np.diff(x, axis=axis)[tuple(sl)], + diff, np.array([0.159, 0.841]), axis=axis, method=method, @@ -62,7 +76,7 @@ def _jumpfinder( x: NDArray[np.floating], min_size: Optional[Union[float, NDArray[np.floating]]] = None, win_size: int = 20, - nsigma: float = 25, + exact: bool = False, ) -> NDArray[np.bool_]: """ Matched filter jump finder. @@ -73,9 +87,10 @@ def _jumpfinder( min_size: The smallest jump size counted as a jump. - win_size: Size of window used by SG filter when peak finding. + win_size: Size of window used when peak finding. - nsigma: Number of sigma above the mean for something to be a peak. + exact: Flag only the jump locations if True. + If False flag the whole window (cheaper). 
Returns: @@ -89,80 +104,49 @@ def _jumpfinder( # and in the case of 2d data we find jumps along rows orig_shape = x.shape x = np.atleast_2d(x) + dtype = x.dtype.name + if len(x.shape) > 2: + raise ValueError("x may not have more than 2 dimensions") + if dtype == "float32": + matched_filt = matched_jumps + elif dtype == "float64": + matched_filt = matched_jumps64 + else: + raise TypeError("x must be float32 or float64") - jumps = np.zeros(x.shape, dtype=bool) + jumps = np.zeros(x.shape, dtype=bool, order="C") if x.shape[-1] < win_size: return jumps.reshape(orig_shape) - size_msk = (np.max(x, axis=-1) - np.min(x, axis=-1)) < min_size - if np.all(size_msk): - return jumps.reshape(orig_shape) - - # If std is basically 0 no need to check for jumps - std = np.std(x, axis=-1) - std_msk = np.isclose(std, 0.0) + np.isclose(std_est(x, ds=win_size, axis=-1), std) - - msk = ~(size_msk + std_msk) + msk = np.ptp(x, axis=-1) > min_size if not np.any(msk): return jumps.reshape(orig_shape) - # Take cumulative sum, this is equivalent to convolving with a step - x_step = np.cumsum(x[msk], axis=-1) - - # Smooth and take the second derivative - sg_x_step = np.abs(sig.savgol_filter(x_step, win_size, 2, deriv=2, axis=-1)) - - # Peaks should be jumps - # Doing the simple thing and looking for things much larger than the median - peaks = ( - sg_x_step - > ( - np.median(sg_x_step, axis=-1) - + nsigma * std_est(sg_x_step, ds=win_size, axis=-1) - )[..., None] - ) - if not np.any(peaks): - return jumps.reshape(orig_shape) - - # The peak may have multiple points above this criteria - peak_idx = np.where(peaks) - peak_idx_padded = peak_idx[1] + (x.shape[-1] + win_size) * peak_idx[0] - gaps = np.diff(peak_idx_padded) >= win_size - begins = np.insert(peak_idx_padded[1:][gaps], 0, peak_idx_padded[0]) - ends = np.append(peak_idx_padded[:-1][gaps], peak_idx_padded[-1]) - jump_idx = ((begins + ends) / 2).astype(int) + 1 - jump_rows = jump_idx // (x.shape[1] + win_size) - jump_cols = jump_idx % (x.shape[1] + win_size) - - # Estimate jump heights and get better positions - # TODO: Pad things to avoid np.diff annoyance - half_win = int(win_size / 2) - win_rows = np.repeat(jump_rows, 2 * half_win) - win_cols = np.repeat(jump_cols, 2 * half_win) + np.tile( - np.arange(-1 * half_win, half_win, dtype=int), len(jump_cols) - ) - win_cols = np.clip(win_cols, 0, x.shape[-1] - 3) - d2x_step = np.abs(np.diff(x_step, n=2, axis=-1))[win_rows, win_cols].reshape( - (len(jump_idx), 2 * half_win) - ) - jump_sizes = np.amax(d2x_step, axis=-1) - jump_cols = ( - win_cols.reshape(d2x_step.shape)[ - np.arange(len(jump_idx)), np.argmax(d2x_step, axis=-1) - ] - + 2 - ) - - # Make a jump size cut + # Flag with a matched filter + win_size += win_size % 2 # Odd win size adds a wierd phasing issue + _x = np.ascontiguousarray(x[msk]) + _jumps = np.ascontiguousarray(np.empty_like(_x), "int32") if isinstance(min_size, np.ndarray): - _min_size = min_size[jump_rows] + _min_size = min_size[msk].astype(_x.dtype) + elif min_size is None: + raise TypeError("min_size is None") else: - _min_size = min_size - size_cut = jump_sizes > _min_size - jump_rows = jump_rows[size_cut] - jump_cols = jump_cols[size_cut] - - jumps[np.flatnonzero(msk)[jump_rows], jump_cols] = True + _min_size = (min_size * np.ones(len(_x))).astype(_x.dtype) + matched_filt(_x, _jumps, _min_size, win_size) + jumps[msk] = _jumps > 0 + + if exact: + structure = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]]) + labels, _ = simg.label(jumps, structure) + peak_idx = np.array( + simg.maximum_position( + 
np.diff(_x, axis=-1, prepend=np.zeros(len(_x))), labels + ) + ) + jump_rows = [peak_idx[:, 0]] + jump_cols = peak_idx[:, 1] + jumps[:] = False + jumps[np.flatnonzero(msk)[jump_rows], jump_cols] = True return jumps.reshape(orig_shape) @@ -199,12 +183,12 @@ def jumpfix_subtract_heights( If inplace is True this is just a reference to x. """ - def _fix(i, jump_ranges, heights, x_fixed): + def _fix(jump_ranges, heights, x_fixed): for j, jump_range in enumerate(jump_ranges): for start, end in jump_range.ranges(): - _heights = heights[i + j, start:end] + _heights = heights[j, start:end] height = _heights[np.argmax(np.abs(_heights))] - x_fixed[i + j, int((start + end) / 2):] -= height + x_fixed[j, int((start + end) / 2) :] -= height x_fixed = x if not inplace: @@ -215,6 +199,8 @@ def _fix(i, jump_ranges, heights, x_fixed): jumps = RangesMatrix.from_mask(np.atleast_2d(jumps)) elif isinstance(jumps, Ranges): jumps = RangesMatrix.from_mask(np.atleast_2d(jumps.mask())) + if not isinstance(jumps, RangesMatrix): + raise TypeError("jumps not RangesMatrix or convertable to RangesMatrix") if heights is None: heights = estimate_heights(x_fixed, jumps.mask(), **kwargs) @@ -222,14 +208,7 @@ def _fix(i, jump_ranges, heights, x_fixed): heights = heights.toarray() heights = cast(NDArray[np.floating], heights) - nfuture = min(len(x_fixed), NFUTURE) - slices = [slice(i * nfuture, (i + 1) * nfuture) for i in range(nfuture)] - slices[-1] = slice(slices[-1].start, len(x_fixed)) - with concurrent.futures.ThreadPoolExecutor() as e: - _ = [ - e.submit(_fix, i, jumps.ranges[s], heights[s], x_fixed[s]) - for i, s in enumerate(slices) - ] + _fix(jumps.ranges, heights, x_fixed) return x_fixed.reshape(orig_shape) @@ -254,14 +233,16 @@ def _diff_buffed( win_size: int, make_step: bool, ) -> NDArray[np.floating]: - win_size = int(win_size) - pad = np.zeros((len(signal.shape), 2), dtype=int) - half_win = int(win_size / 2) - pad[-1, :] = half_win + win_size = int(win_size + win_size % 2) if jumps is not None and make_step: signal = _make_step(signal, jumps) - padded = np.pad(signal, pad, mode="edge") - diff_buffed = padded[..., win_size:] - padded[..., : (-1 * win_size)] + diff_buffed = np.empty_like(signal) + diff_buffed[..., :win_size] = 0 + diff_buffed[..., win_size:] = np.subtract( + signal[..., win_size:], + signal[..., : (-1 * win_size)], + out=diff_buffed[..., win_size:], + ) return diff_buffed @@ -342,11 +323,13 @@ def twopi_jumps( win_size=..., nsigma=..., atol=..., + max_tol=..., fix: Literal[True] = True, inplace=..., merge=..., overwrite=..., name=..., + ds=..., **filter_pars, ) -> Tuple[RangesMatrix, csr_array, NDArray[np.floating]]: ... @@ -359,11 +342,13 @@ def twopi_jumps( win_size=..., nsigma=..., atol=..., + max_tol=..., fix: Literal[False] = False, inplace=..., merge=..., overwrite=..., name=..., + ds=..., **filter_pars, ) -> Tuple[RangesMatrix, csr_array]: ... @@ -375,11 +360,13 @@ def twopi_jumps( win_size: int = 20, nsigma: float = 5.0, atol: Optional[Union[float, NDArray[np.floating]]] = None, + max_tol: float = 0.0314, fix: bool = True, inplace: bool = False, merge: bool = True, overwrite: bool = False, name: str = "jumps_2pi", + ds: int = 10, **filter_pars, ) -> Union[ Tuple[RangesMatrix, csr_array], Tuple[RangesMatrix, csr_array, NDArray[np.floating]] @@ -405,6 +392,9 @@ def twopi_jumps( If set to None, then nsigma times the WN level of the TOD is used. Note that in general this is faster than nsigma. + max_tol: Upper bound of the nsigma based thresh. + atol ignores this. 
+ fix: If True the jumps will be fixed by adding N*2*pi at the jump locations. inplace: If True jumps will be fixed inplace. @@ -415,6 +405,8 @@ def twopi_jumps( name: String used to populate field in flagmanager if merge is True. + ds: Downsample factor used when computing noise level, the actual factor used is `ds*win_size`. + **filter_pars: Parameters to pass to _filter Returns: @@ -432,34 +424,36 @@ def twopi_jumps( if not isinstance(signal, np.ndarray): raise TypeError("Signal is not an array") if atol is None: - atol = nsigma * std_est(signal.astype(float), ds=win_size) - np.clip(atol, 1e-8, 1e-2) + atol = nsigma * std_est( + signal.astype(float), ds=win_size * ds, win_size=win_size + ) + np.clip(atol, 1e-8, max_tol) _signal = _filter(signal, **filter_pars) - diff_buffed = _diff_buffed(_signal, None, win_size, False) - - if isinstance(atol, int): - atol = float(atol) - if isinstance(atol, float): - ratio = np.abs(diff_buffed) / (2 * np.pi) - jumps = (np.abs(ratio - np.round(ratio, 0)) <= atol) & (ratio >= 0.5) - jumps[..., :win_size] = False - elif isinstance(atol, np.ndarray): - jumps = np.atleast_2d(np.zeros_like(signal, dtype=bool)) - diff_buffed = np.atleast_2d(diff_buffed) - if len(atol) != len(jumps): - raise ValueError(f"Non-scalar atol provided with length {len(atol)}") - ratio = np.abs(diff_buffed / (2 * np.pi)) - jumps = (np.abs(ratio - np.round(ratio, 0)) <= atol[..., None]) & (ratio >= 0.5) - jumps.reshape(signal.shape) - else: + _signal = np.atleast_2d(_signal) + if isinstance(atol, int) or isinstance(atol, float): + atol = np.ones(len(_signal), float) * float(atol) + elif np.isscalar(atol): raise TypeError(f"Invalid atol type: {type(atol)}") + if len(atol) != len(signal): + raise ValueError(f"Non-scalar atol provided with length {len(atol)}") + + _signal = np.ascontiguousarray(_signal) + heights = np.empty_like(_signal) + atol = np.ascontiguousarray(atol, dtype=_signal.dtype) + if _signal.dtype.name == "float32": + find_quantized_jumps(_signal, heights, atol, win_size, 2 * np.pi) + elif _signal.dtype.name == "float64": + find_quantized_jumps64(_signal, heights, atol, win_size, 2 * np.pi) + else: + raise TypeError("signal must be float32 or float64") + + # Shift things by half the window + heights = np.roll(heights, -1 * int(win_size / 2), -1) + heights[:, (-1 * int(win_size / 2)) :] = 0 + jumps = heights != 0 jump_ranges = RangesMatrix.from_mask(jumps).buffer(int(win_size / 2)) - jumps = jump_ranges.mask() - heights = estimate_heights( - signal, jumps, win_size=win_size, twopi=True, diff_buffed=diff_buffed - ) if merge: _merge(aman, jump_ranges, name, overwrite) @@ -596,16 +590,16 @@ def slow_jumps( def find_jumps( aman, signal=..., - max_iters=..., min_sigma=..., min_size=..., win_size=..., - nsigma=..., + exact=..., fix: Literal[False] = False, inplace=..., merge=..., overwrite=..., name=..., + ds=..., **filter_pars, ) -> Tuple[RangesMatrix, csr_array]: ... @@ -615,16 +609,16 @@ def find_jumps( def find_jumps( aman, signal=..., - max_iters=..., min_sigma=..., min_size=..., win_size=..., - nsigma=..., + exact=..., fix: Literal[True] = True, inplace=..., merge=..., overwrite=..., name=..., + ds=..., **filter_pars, ) -> Tuple[RangesMatrix, csr_array, NDArray[np.floating]]: ... 
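A short usage sketch for the 2*pi jump finder as revised above (illustrative; ``aman``
is assumed to be a TOD AxisManager whose ``signal`` is phase-like, and the return
order follows the overloads: flag RangesMatrix, sparse heights, and, with ``fix=True``,
the fixed signal):

    from sotodlib.tod_ops import jumps

    jump_flags, jump_heights, fixed = jumps.twopi_jumps(
        aman, fix=True, inplace=False, merge=True, name="jumps_2pi"
    )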
@@ -633,16 +627,16 @@ def find_jumps( def find_jumps( aman: AxisManager, signal: Optional[NDArray[np.floating]] = None, - max_iters: int = 1, min_sigma: Optional[float] = None, min_size: Optional[Union[float, NDArray[np.floating]]] = None, win_size: int = 20, - nsigma: float = 25, + exact: bool = False, fix: bool = False, inplace: bool = False, merge: bool = True, overwrite: bool = False, name: str = "jumps", + ds: int = 10, **filter_pars, ) -> Union[ Tuple[RangesMatrix, csr_array], Tuple[RangesMatrix, csr_array, NDArray[np.floating]] @@ -668,10 +662,11 @@ def find_jumps( if set this will override min_sigma. If both min_sigma and min_size are None then the IQR is used as min_size. - win_size: Size of window used by SG filter when peak finding. + win_size: Size of window used when peak finding. Also used for height estimation, should be of order jump width. - nsigma: Number of sigma above the mean for something to be a peak. + exact: If True search for the exact jump location. + If False flag allow some undertainty within the window (cheaper). fix: Set to True to fix. @@ -683,6 +678,8 @@ def find_jumps( name: String used to populate field in flagmanager if merge is True. + ds: Downsample factor used when computing noise level, the actual factor used is `ds*win_size`. + **filter_pars: Parameters to pass to _filter Returns: @@ -702,35 +699,28 @@ def find_jumps( raise TypeError("Signal is not an array") orig_shape = signal.shape + _signal = _filter(signal, **filter_pars) + _signal = np.atleast_2d(_signal) + if len(orig_shape) > 2: raise ValueError("Jumpfinder only works on 1D or 2D data") if min_size is None and min_sigma is not None: - min_size = min_sigma * std_est(signal, ds=win_size, axis=-1) + min_size = min_sigma * std_est( + signal, ds=win_size * ds, win_size=win_size, axis=-1 + ) if min_size is None: raise ValueError("min_size is somehow still None") if isinstance(min_size, np.ndarray) and np.ndim(min_size) > 1: # type: ignore raise ValueError("min_size must be 1d or a scalar") elif isinstance(min_size, (float, int)): - min_size = float(min_size) * np.ones(len(signal)) + min_size = float(min_size) * np.ones(len(_signal)) - _signal = _filter(signal, **filter_pars) - if max_iters > 1: - _signal = signal.copy() - _signal = np.atleast_2d(_signal) - # Median subtract, if we don't do this then when we cumsum we get floats + # Mean subtract, if we don't do this then when we cumsum we get floats # that are too big and lack the precicion to find jumps well - _signal -= np.median(_signal, axis=-1)[..., None] - - nfuture = min(len(_signal), NFUTURE) - slices = [slice(i * nfuture, (i + 1) * nfuture) for i in range(nfuture)] - slices[-1] = slice(slices[-1].start, len(_signal)) - with concurrent.futures.ThreadPoolExecutor() as e: - jump_futures = [ - e.submit(_jumpfinder, _signal[s], min_size[s], win_size, nsigma) - for s in slices - ] - jumps = np.vstack([j.result() for j in jump_futures]).reshape(orig_shape) + _signal -= np.mean(_signal, axis=-1)[..., None] + + jumps = _jumpfinder(_signal, min_size, win_size, exact).reshape(orig_shape) jump_ranges = RangesMatrix.from_mask(jumps).buffer(int(win_size / 2)) jumps = jump_ranges.mask() diff --git a/sotodlib/tod_ops/pca.py b/sotodlib/tod_ops/pca.py index dd20e74d6..b78825255 100644 --- a/sotodlib/tod_ops/pca.py +++ b/sotodlib/tod_ops/pca.py @@ -364,4 +364,5 @@ def get_common_mode( raise ValueError("method flag must be median or average") if wrap is not None: tod.wrap(wrap, common_mode, [(0, 'samps')]) - return common_mode \ No newline at end of file + + 
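A similar sketch for the generic finder, using the new ``exact`` and ``ds`` arguments
in place of the removed ``nsigma``/``max_iters`` (illustrative; with ``fix=False`` the
return is the flag RangesMatrix and a sparse array of estimated jump heights):

    from sotodlib.tod_ops import jumps

    jump_flags, jump_heights = jumps.find_jumps(
        aman, min_sigma=5.0, win_size=20, exact=True, fix=False
    )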
return common_mode diff --git a/sotodlib/tod_ops/t2pleakage.py b/sotodlib/tod_ops/t2pleakage.py index 10cdf821f..24252e45c 100644 --- a/sotodlib/tod_ops/t2pleakage.py +++ b/sotodlib/tod_ops/t2pleakage.py @@ -183,11 +183,11 @@ def leakage_model(dT, AQ, AU, lamQ, lamU): yU = aman[U_sig_name][di][mask[di]][::ds_factor] try: - model = LmfitModel(leakage_model, independent_vars=['dT'], - weights=np.ones_like(x)/sigma_demod[di]) + model = LmfitModel(leakage_model, independent_vars=['dT']) params = model.make_params(AQ=np.median(yQ), AU=np.median(yU), lamQ=0., lamU=0.) - result = model.fit(yQ + 1j * yU, params, dT=x) + result = model.fit(yQ + 1j * yU, params, dT=x, + weights=np.ones_like(x)/sigma_demod[di]) A_Q_array[di] = result.params['AQ'].value A_U_array[di] = result.params['AU'].value lambda_Q_array[di] = result.params['lamQ'].value @@ -215,7 +215,6 @@ def leakage_model(dT, AQ, AU, lamQ, lamU): out_aman.wrap('AU', A_U_array, [(0, 'dets')]) out_aman.wrap('lamQ', lambda_Q_array, [(0, 'dets')]) out_aman.wrap('lamU', lambda_U_array, [(0, 'dets')]) - out_aman.wrap('AQ_error', A_Q_error, [(0, 'dets')]) out_aman.wrap('AU_error', A_U_error, [(0, 'dets')]) out_aman.wrap('lamQ_error', lambda_Q_error, [(0, 'dets')]) @@ -315,7 +314,6 @@ def subtract_t2p(aman, t2p_aman, T_signal=None): Temperature signal to scale and subtract from Q/U. Default is ``aman['dsT']``. """ - if T_signal is None: T_signal = aman['dsT'] @@ -326,4 +324,4 @@ def subtract_t2p(aman, t2p_aman, T_signal=None): aman.demodQ -= np.multiply(T_signal.T, t2p_aman.coeffsQ).T aman.demodU -= np.multiply(T_signal.T, t2p_aman.coeffsU).T else: - raise ValueError('no leakage coefficients found in axis manager') \ No newline at end of file + raise ValueError('no leakage coefficients found in axis manager') diff --git a/sotodlib/tod_ops/utils.py b/sotodlib/tod_ops/utils.py new file mode 100644 index 000000000..25b16a91c --- /dev/null +++ b/sotodlib/tod_ops/utils.py @@ -0,0 +1,73 @@ +""" +Generically useful utility functions. +""" +from typing import Optional +import numpy as np +from numpy.typing import NDArray +from so3g import block_moment, block_moment64 + + +def get_block_moment( + tod: NDArray[np.floating], + block_size: int, + moment: int = 1, + central: bool = True, + shift: int = 0, + output: Optional[NDArray[np.floating]] = None, +) -> NDArray[np.floating]: + """ + Compute the n'th moment of data in blocks along each row. + Note that the blocks are made to be exclusive, + so any samples left at the end will be in a smaller standalone block. + This is a wrapper around ``so3g.block_moment``. + + Arguments: + + tod: Data to compute the moment of. + Should be (ndet, nsamp) or (nsamp). + Must be float32 or float64. + + block_size: Size of block to use. + + moment: Which moment to compute. + Must be >= 1. + + central: If True compute the mean centered moment. + + shift: Sample to start the blocks at, will be 0 before this. + + output: Array to put the blocked moment into. + If provided must be the same shape as tod. + If None, will be intialized from tod. + Returns: + + block_moment: The blocked moment. + Will have the same shape as tod. + If output is provided it is modified in place and retured here. 
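    Example (sketch): a block-wise variance, i.e. the mean-centered second moment over
    non-overlapping blocks of 1000 samples, returned at the input shape::

        >>> import numpy as np
        >>> from sotodlib.tod_ops.utils import get_block_moment
        >>> tod = np.random.randn(3, 10000).astype('float32')
        >>> var_blocks = get_block_moment(tod, block_size=1000, moment=2, central=True)
        >>> var_blocks.shape
        (3, 10000)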
+ """ + if not np.any(np.isfinite(tod)): + raise ValueError("Only finite values allowed in tod") + orig_shape = tod.shape + dtype = tod.dtype.name + tod = np.atleast_2d(tod) + if len(tod.shape) > 2: + raise ValueError("tod may not have more than 2 dimensions") + if dtype not in ["float32", "float64"]: + raise TypeError("tod must be float32 or float64") + + if output is None: + output = np.ascontiguousarray(np.empty_like(tod)) + if output.shape != tod.shape: + raise ValueError("output shape does not match tod") + if output.dtype.name != dtype: + raise TypeError("output type does not match tod") + + if moment < 1: + raise ValueError("moment must be at least 1") + + if dtype == "float32": + block_moment(tod, output, block_size, moment, central, shift) + else: + block_moment64(tod, output, block_size, moment, central, shift) + + return output.reshape(orig_shape) diff --git a/tests/test_coords.py b/tests/test_coords.py index a5e46c87b..143a448a3 100644 --- a/tests/test_coords.py +++ b/tests/test_coords.py @@ -185,6 +185,33 @@ def test_cover(self): atol=R*0.05) self.assertEqual(len(xi), 16) + # Works with nans? + xy[0,0] = np.nan + coords.helpers.get_focal_plane_cover(xieta=xy) + + # Exclude dets using det_weights? + det_weights = np.ones(xy.shape[1]) + det_weights[3:34] = 0. + for dtype in ['float', 'int', 'bool']: + coords.helpers.get_focal_plane_cover( + xieta=xy, det_weights=det_weights.astype(dtype)) + + # Works for only a single det? + det_weights[2:] = 0. + (xi0, eta0), R0, _ = \ + coords.helpers.get_focal_plane_cover(xieta=xy, det_weights=det_weights) + + # Fails if all dets excluded somehow? + det_weights[1] = 0. + with self.assertRaises(ValueError): + coords.helpers.get_focal_plane_cover(xieta=xy, det_weights=det_weights) + + # Fails with all nans? + xy[1,1:] = np.nan + with self.assertRaises(ValueError): + coords.helpers.get_focal_plane_cover(xieta=xy) + + class OpticsTest(unittest.TestCase): def test_sat_fp(self): x = np.array([-100, 0, 100]) diff --git a/tests/test_core.py b/tests/test_core.py index 86a3d1b06..af6e0cd7c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,7 +1,6 @@ import unittest import tempfile import os -import shutil import numpy as np import astropy.units as u @@ -66,7 +65,7 @@ def test_130_not_inplace(self): # This should return a separate thing. rman = aman.restrict('samps', (10, 30), in_place=False) - #self.assertNotEqual(aman.a1[0], 0.) + # self.assertNotEqual(aman.a1[0], 0.) self.assertEqual(len(aman.a1), 100) self.assertEqual(len(rman.a1), 20) self.assertNotEqual(aman.a1[10], 0.) @@ -190,23 +189,23 @@ def test_170_concat(self): # ... other_fields="exact" aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + ## add scalars amanA.wrap("ans", 42) amanB.wrap("ans", 42) aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" amanB.azimuth[:] = 2. with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" and arrays of different shapes amanB.move("azimuth", None) amanB.wrap("azimuth", np.array([43,5,2,3])) with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="fail" amanB.move("azimuth",None) amanB.wrap_new('azimuth', shape=('samps',))[:] = 2. @@ -269,6 +268,64 @@ def test_180_overwrite(self): self.assertNotEqual(aman.a1[2,11], 0) self.assertNotEqual(aman.a1[1,10], 1.) 
+ def test_190_get_set(self): + dets = ["det0", "det1", "det2"] + n, ofs = 1000, 0 + aman = core.AxisManager( + core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs) + ) + child = core.AxisManager( + core.LabelAxis("dets", dets + ["det3"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) + + child2 = core.AxisManager( + core.LabelAxis("dets2", ["det4", "det5"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) + child2.wrap("tod", np.zeros((2, 1000))) + aman.wrap("child", child) + aman["child"].wrap("child2", child2) + self.assertEqual(aman["child.child2.dets2"].count, 2) + self.assertEqual(aman["child.dets"].name, "dets") + np.testing.assert_array_equal( + aman["child.child2.dets2"].vals, np.array(["det4", "det5"]) + ) + self.assertEqual(aman["child.child2.samps"].count, n // 2) + self.assertEqual(aman["child.child2.samps"].offset, 0) + self.assertEqual( + aman["child.child2.samps"].count, aman.child.child2.samps.count + ) + self.assertEqual( + aman["child.child2.samps"].offset, aman.child.child2.samps.offset + ) + + np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2, 1000))) + + with self.assertRaises(KeyError): + aman["child2"] + + with self.assertRaises(AttributeError): + aman["child.dets.an_extra_layer"] + + self.assertIn("child.dets", aman) + self.assertIn("child.dets2", aman) # I am not sure why this is true + self.assertNotIn("child.child2.someentry", aman) + self.assertNotIn("child.child2.someentry.someotherentry", aman) + + with self.assertRaises(ValueError): + aman["child"] = child2 + + new_tods = np.ones((2, 500)) + aman.child.child2.tod = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500))) + + new_tods = np.ones((2, 1500)) + aman["child.child2.tod"] = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500))) + # Multi-dimensional restrictions. def test_200_multid(self): diff --git a/tests/test_mapmaker_pointing.py b/tests/test_mapmaker_pointing.py index 500d0300b..7e46939bb 100644 --- a/tests/test_mapmaker_pointing.py +++ b/tests/test_mapmaker_pointing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Simons Observatory. +# Copyright (c) 2023-2024 Simons Observatory. # Full license can be found in the top level "LICENSE" file. """Check that pointing expanded with TOAST is compatible with the MLMapmaker @@ -26,7 +26,7 @@ if toast_available: import healpy as hp -from ._helpers import calibration_schedule, close_data_and_comm, simulation_test_data +from . 
import _helpers as helpers class MapmakerPointingTest(unittest.TestCase): @@ -44,7 +44,7 @@ def test_mapmaker_pointing(self): testdir = tempfile.TemporaryDirectory() comm, procs, rank = toast.get_world() - data = simulation_test_data( + data = helpers.simulation_test_data( comm, telescope_name=None, wafer_slot="w00", @@ -134,7 +134,7 @@ def test_mapmaker_pointing(self): out_dir=testdir.name, comps="TQU", nmat_type="Nmat", - maxiter=3, + maxiter=[3], truncate_tod=False, write_hits=True, write_rhs=False, @@ -147,10 +147,10 @@ def test_mapmaker_pointing(self): if rank == 0: # Direct comparison of pointing obs = data.obs[0] - pmap = mapmaker._signal_map.data[obs.name].pmap + pmap = mapmaker.signal_map.data[obs.name].pmap det_quats = pmap._get_asm().dets coords = np.array(pmap.sight.coords(det_quats)) - dets = mapmaker._mapmaker.data[0].dets + dets = mapmaker.mapmaker.data[0].dets ndet = len(dets) pointing.apply(data) @@ -240,7 +240,7 @@ def test_mapmaker_pointing(self): if np.abs(means[2]) > tol: raise RuntimeError("Found non-zero U") - close_data_and_comm(data) + helpers.close_data_and_comm(data) if __name__ == '__main__': unittest.main() diff --git a/tests/test_pmat.py b/tests/test_pmat.py index 32e5dae4e..32a6cdeb0 100644 --- a/tests/test_pmat.py +++ b/tests/test_pmat.py @@ -24,8 +24,8 @@ def test_pmat_rectpix(self): comps = 'T' out = run_test(obs, (shape, wcs), comps, None, False, False) # Basic _ = run_test(obs, None, comps, wcs, False, False) # Use wcs_kernel - #out2 = run_test(obs, tilemap.geometry(shape, wcs, tile_shape=(100, 100)), comps, None, False, True) # Tiled - #assert np.array_equal(out, out2) + out2 = run_test(obs, tilemap.geometry(shape, wcs, tile_shape=(100, 100)), comps, None, False, True) # Tiled + assert np.array_equal(out, out2) def test_pmat_healpix(self): obs = quick_tod(10, 10000) @@ -57,6 +57,22 @@ def run_test(obs, geom, comps, wcs_kernel, is_healpix, is_tiled): tod = pmat.from_map(remove_weights) TOL = 1e-9 assert np.all(np.abs(tod-obs.signal) < TOL) + + # Confirm we can do map-space ops without a pointing op first + pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=geom, wcs_kernel=wcs_kernel) + _ = pmat.to_inverse_weights(weights) + pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=geom, wcs_kernel=wcs_kernel) + _ = pmat.remove_weights(weights) + + # Confirm from_map works on uninitialized pmat + pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=geom, wcs_kernel=wcs_kernel) + tod2 = pmat.from_map(remove_weights) + assert np.all(np.abs(tod - tod2) < TOL) + + # And also zeros. 
+ pmat = coords.pmat.P.for_tod(obs, comps=comps, geom=geom, wcs_kernel=wcs_kernel) + pmat.zeros() + if is_tiled: if is_healpix: remove_weights = hp_utils.tiled_to_full(remove_weights) diff --git a/tests/test_toast_workflow.py b/tests/test_toast_workflow.py index cc73dea3f..3d47bf916 100644 --- a/tests/test_toast_workflow.py +++ b/tests/test_toast_workflow.py @@ -60,7 +60,7 @@ def test_workflow_config(self): operators=operators, opts={ "sim_atmosphere.enable": True, - "sim_atmosphere.xstep": "10.0 m", + "sim_atmosphere.xstep": "Quantity('10.0 m')", }, ) diff --git a/tests/test_tod_ops.py b/tests/test_tod_ops.py index 30a732b30..831d215a8 100644 --- a/tests/test_tod_ops.py +++ b/tests/test_tod_ops.py @@ -293,27 +293,11 @@ def test_jumpfinder(self): tod.wrap('sig_jumps', sig_jumps, [(0, 'samps')]) # Find jumps without filtering - jumps_nf, _ = tod_ops.jumps.find_jumps(tod, signal=tod.sig_jumps, min_size=5) + jumps_nf, _ = tod_ops.jumps.find_jumps(tod, signal=tod.sig_jumps, min_size=5, win_size=23) jumps_nf = jumps_nf.ranges().flatten() - # Find jumps with TV filtering - jumps_tv, _ = tod_ops.jumps.find_jumps(tod, signal=tod.sig_jumps, tv_weight=.5, min_size=5) - jumps_tv = jumps_tv.ranges().flatten() - - # Find jumps with gaussian filtering - jumps_gauss, _ = tod_ops.jumps.find_jumps(tod, signal=tod.sig_jumps, gaussian_width=.5, min_size=5) - jumps_gauss = jumps_gauss.ranges().flatten() - # Remove double counted jumps and round to remove uncertainty jumps_nf = np.unique(np.round(jumps_nf, -2)) - jumps_tv = np.unique(np.round(jumps_tv, -2)) - jumps_gauss = np.unique(np.round(jumps_gauss, -2)) - - # Check that all methods agree - self.assertEqual(len(jumps_tv), len(jumps_gauss)) - self.assertTrue(np.all(np.abs(jumps_tv - jumps_gauss) == 0)) - self.assertEqual(len(jumps_nf), len(jumps_gauss)) - self.assertTrue(np.all(np.abs(jumps_nf - jumps_gauss) == 0)) # Check that they agree with the input self.assertEqual(len(jump_locs), len(jumps_nf))