Move outlier detection utility functions from jwst to stcal (#270)
braingram authored Jul 30, 2024
2 parents 2474152 + 2112773 commit 809de23
Showing 10 changed files with 548 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGES.rst
@@ -9,7 +9,8 @@ General
Changes to API
--------------

-
- Add ``outlier_detection`` submodule with ``utils`` included
from jwst. [#270]

Bug Fixes
---------
4 changes: 4 additions & 0 deletions docs/stcal/outlier_detection/description.rst
@@ -0,0 +1,4 @@
Description
============

This sub-package contains functions useful for outlier detection.
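
(Not part of the diff — an orientation note: with a stcal release that contains this commit, the helpers documented here live under the new module path shown below.)

```python
# Hypothetical usage; assumes a stcal version that includes this change.
from stcal.outlier_detection.utils import (
    medfilt,
    compute_weight_threshold,
    flag_crs,
    flag_resampled_crs,
    gwcs_blot,
    calc_gwcs_pixmap,
    reproject,
)
```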
12 changes: 12 additions & 0 deletions docs/stcal/outlier_detection/index.rst
@@ -0,0 +1,12 @@
.. _outlier_detection:

=======================
Outlier Detection Utils
=======================

.. toctree::
:maxdepth: 2

description.rst

.. automodapi:: stcal.outlier_detection.utils
1 change: 1 addition & 0 deletions docs/stcal/package_index.rst
@@ -8,3 +8,4 @@ Package Index
ramp_fitting/index.rst
alignment/index.rst
tweakreg/index.rst
outlier_detection/index.rst
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -14,7 +14,9 @@ classifiers = [
]
dependencies = [
"astropy >=5.0.4",
"drizzle>=1.15.0",
"scipy >=1.7.2",
"scikit-image>=0.19",
"numpy >=1.21.2",
"opencv-python-headless >=4.6.0.66",
"asdf >=2.15.0",
@@ -209,6 +211,7 @@ module = [
"stdatamodels.*",
"asdf.*",
"scipy.*",
"drizzle.*",
# don't complain about the installed c parts of this library
"stcal.ramp_fitting.ols_cas22._fit",
"stcal.ramp_fitting.ols_cas22._jump",
Empty file.
339 changes: 339 additions & 0 deletions src/stcal/outlier_detection/utils.py
@@ -0,0 +1,339 @@
"""
Utility functions for outlier detection routines
"""
import warnings

import numpy as np
from astropy.stats import sigma_clip
from drizzle.cdrizzle import tblot
from scipy import ndimage
from skimage.util import view_as_windows
import gwcs

from stcal.alignment.util import wcs_bbox_from_shape

import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


__all__ = [
    "medfilt",
    "compute_weight_threshold",
    "flag_crs",
    "flag_resampled_crs",
    "gwcs_blot",
    "calc_gwcs_pixmap",
    "reproject",
]


def medfilt(arr, kern_size):
    """
    Apply a median filter that ignores NaN values.

    scipy.signal.medfilt (and many other median filters) have undefined behavior
    for nan inputs. See: https://github.com/scipy/scipy/issues/4800

    Parameters
    ----------
    arr : numpy.ndarray
        The input array

    kern_size : list of int
        List of kernel dimensions, length must be equal to arr.ndim.

    Returns
    -------
    filtered_arr : numpy.ndarray
        Input array median filtered with a kernel of size kern_size
    """
    padded = np.pad(arr, [[k // 2] for k in kern_size])
    windows = view_as_windows(padded, kern_size, np.ones(len(kern_size), dtype='int'))
    return np.nanmedian(windows, axis=np.arange(-len(kern_size), 0))
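
Illustrative aside (not part of the diff): a minimal sketch of calling the NaN-aware filter, assuming a stcal install that includes this commit; the array values are arbitrary.

```python
import numpy as np
from stcal.outlier_detection.utils import medfilt

arr = np.array([[1.0, 2.0, 3.0],
                [4.0, np.nan, 6.0],
                [7.0, 8.0, 9.0]])
# kern_size has one entry per array dimension; here a 3x3 window.
filtered = medfilt(arr, [3, 3])
print(filtered[1, 1])  # 5.0 -- the NaN is ignored by np.nanmedian
```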


def compute_weight_threshold(weight, maskpt):
    '''
    Compute the weight threshold for a single image or cube.

    Parameters
    ----------
    weight : numpy.ndarray
        The weight array

    maskpt : float
        The percentage of the mean weight to use as a threshold for masking.

    Returns
    -------
    float
        The weight threshold for this integration.
    '''
    # necessary in order to assure that mask gets applied correctly
    if hasattr(weight, '_mask'):
        del weight._mask
    mask_zero_weight = np.equal(weight, 0.)
    mask_nans = np.isnan(weight)
    # Combine the masks
    weight_masked = np.ma.array(weight, mask=np.logical_or(
        mask_zero_weight, mask_nans))
    # Sigma-clip the unmasked data
    weight_masked = sigma_clip(weight_masked, sigma=3, maxiters=5)
    mean_weight = np.mean(weight_masked)
    # Mask pixels where weight falls below maskpt percent
    weight_threshold = mean_weight * maskpt
    return weight_threshold
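
Illustrative aside: computing a masking threshold from a synthetic weight image; the maskpt value of 0.7 is just an example setting.

```python
import numpy as np
from stcal.outlier_detection.utils import compute_weight_threshold

weight = np.full((100, 100), 10.0, dtype=np.float32)
weight[:10, :] = 0.0   # zero-weight pixels are masked before sigma-clipping
threshold = compute_weight_threshold(weight, maskpt=0.7)
print(threshold)       # ~7.0: maskpt times the sigma-clipped mean weight
```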


def _abs_deriv(array):
    """
    Do not use this function.

    Take the absolute derivative of a numpy array.

    This function assumes off-edge pixel values are 0
    and leads to erroneous derivative values and should
    likely not be used.
    """
    tmp = np.zeros(array.shape, dtype=np.float64)
    out = np.zeros(array.shape, dtype=np.float64)

    tmp[1:, :] = array[:-1, :]
    tmp, out = _absolute_subtract(array, tmp, out)
    tmp[:-1, :] = array[1:, :]
    tmp, out = _absolute_subtract(array, tmp, out)

    tmp[:, 1:] = array[:, :-1]
    tmp, out = _absolute_subtract(array, tmp, out)
    tmp[:, :-1] = array[:, 1:]
    tmp, out = _absolute_subtract(array, tmp, out)

    return out


def _absolute_subtract(array, tmp, out):
    """
    Do not use this function.

    A helper function for _abs_deriv.
    """
    tmp = np.abs(array - tmp)
    out = np.maximum(tmp, out)
    tmp = tmp * 0.
    return tmp, out


def flag_crs(
    sci_data,
    sci_err,
    blot_data,
    snr,
):
    """
    Straightforward detection of outliers for non-dithered data since
    sci_err includes all noise sources (photon, read, and flat for baseline).

    Parameters
    ----------
    sci_data : numpy.ndarray
        "Science" data possibly containing outliers.

    sci_err : numpy.ndarray
        Error estimates for sci_data.

    blot_data : numpy.ndarray
        Reference data used to detect outliers.

    snr : float
        Signal-to-noise ratio used during detection.

    Returns
    -------
    cr_mask : numpy.ndarray
        Boolean array where outliers (CRs) are true.
    """
    return np.greater(np.abs(sci_data - blot_data), snr * np.nan_to_num(sci_err))
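
Illustrative aside: flagging a single synthetic cosmic-ray hit against a reference (blotted) image; the snr value is arbitrary.

```python
import numpy as np
from stcal.outlier_detection.utils import flag_crs

sci = np.zeros((5, 5), dtype=np.float32)
err = np.full((5, 5), 1.0, dtype=np.float32)
blot = np.zeros((5, 5), dtype=np.float32)
sci[2, 2] = 100.0  # a large spike relative to the reference image

cr_mask = flag_crs(sci, err, blot, snr=5.0)
# True only where |sci - blot| > snr * err, i.e. at (2, 2).
```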


def flag_resampled_crs(
    sci_data,
    sci_err,
    blot_data,
    snr1,
    snr2,
    scale1,
    scale2,
    backg,
):
    """
    Detect outliers (CRs) using resampled reference data.

    Parameters
    ----------
    sci_data : numpy.ndarray
        "Science" data possibly containing outliers

    sci_err : numpy.ndarray
        Error estimates for sci_data

    blot_data : numpy.ndarray
        Reference data used to detect outliers.

    snr1 : float
        Signal-to-noise ratio threshold used prior to smoothing.

    snr2 : float
        Signal-to-noise ratio threshold used after smoothing.

    scale1 : float
        Scale used prior to smoothing.

    scale2 : float
        Scale used after smoothing.

    backg : float
        Scalar background to subtract from the difference.

    Returns
    -------
    cr_mask : numpy.ndarray
        boolean array where outliers (CRs) are true
    """
    err_data = np.nan_to_num(sci_err)

    blot_deriv = _abs_deriv(blot_data)
    diff_noise = np.abs(sci_data - blot_data - backg)

    # Create a boolean mask based on a scaled version of
    # the derivative image (dealing with interpolating issues?)
    # and the standard n*sigma above the noise
    threshold1 = scale1 * blot_deriv + snr1 * err_data
    mask1 = np.greater(diff_noise, threshold1)

    # Smooth the boolean mask with a 3x3 boxcar kernel
    kernel = np.ones((3, 3), dtype=int)
    mask1_smoothed = ndimage.convolve(mask1, kernel, mode='nearest')

    # Create a 2nd boolean mask based on the 2nd set of
    # scale and threshold values
    threshold2 = scale2 * blot_deriv + snr2 * err_data
    mask2 = np.greater(diff_noise, threshold2)

    # Final boolean mask
    return mask1_smoothed & mask2
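
Illustrative aside for the resampled case; the snr/scale values below are plausible example settings, not values required by this function.

```python
import numpy as np
from stcal.outlier_detection.utils import flag_resampled_crs

sci = np.zeros((5, 5), dtype=np.float32)
err = np.full((5, 5), 1.0, dtype=np.float32)
blot = np.zeros((5, 5), dtype=np.float32)
sci[2, 2] = 100.0

cr_mask = flag_resampled_crs(sci, err, blot,
                             snr1=5.0, snr2=4.0,
                             scale1=1.2, scale2=0.7,
                             backg=0.0)
# The spike exceeds both thresholds and survives the 3x3 smoothing step.
```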


def gwcs_blot(median_data, median_wcs, blot_shape, blot_wcs, pix_ratio):
    """
    Resample the median data to recreate an input image based on
    the blot wcs.

    Parameters
    ----------
    median_data : numpy.ndarray
        The data to blot.

    median_wcs : gwcs.wcs.WCS
        The wcs for the median data.

    blot_shape : list of int
        The target blot data shape.

    blot_wcs : gwcs.wcs.WCS
        The target/blotted wcs.

    pix_ratio : float
        Pixel ratio.

    Returns
    -------
    blotted : numpy.ndarray
        The blotted median data.
    """
    # Compute the mapping between the input and output pixel coordinates
    pixmap = calc_gwcs_pixmap(blot_wcs, median_wcs, blot_shape)
    log.debug("Pixmap shape: {}".format(pixmap[:, :, 0].shape))
    log.debug("Sci shape: {}".format(blot_shape))
    log.info('Blotting {} <-- {}'.format(blot_shape, median_data.shape))

    outsci = np.zeros(blot_shape, dtype=np.float32)

    # Currently tblot cannot handle nans in the pixmap, so we need to give some
    # other value. -1 is not optimal and may have side effects. But this is
    # what we've been doing up until now, so more investigation is needed
    # before a change is made. Preferably, fix tblot in drizzle.
    pixmap[np.isnan(pixmap)] = -1
    tblot(median_data, pixmap, outsci, scale=pix_ratio, kscale=1.0,
          interp='linear', exptime=1.0, misval=0.0, sinscl=1.0)

    return outsci
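
Illustrative aside (not part of the diff): a self-contained sketch of blotting a constant median image through two simple TAN-projection gwcs pipelines. The make_wcs helper, model parameters, array shapes, and pixel ratio are all invented for demonstration.

```python
import numpy as np
import gwcs
from gwcs import coordinate_frames as cf
from astropy import coordinates as coord
from astropy import units as u
from astropy.modeling import models
from stcal.outlier_detection.utils import gwcs_blot


def make_wcs():
    """Build a minimal detector -> ICRS gwcs pipeline (hypothetical example)."""
    det2sky = (
        (models.Shift(-50) & models.Shift(-50))        # move reference pixel to origin
        | (models.Scale(1e-4) & models.Scale(1e-4))    # pixel scale in deg/pixel
        | models.Pix2Sky_TAN()
        | models.RotateNative2Celestial(30.0, 45.0, 180.0)
    )
    detector = cf.Frame2D(name="detector", unit=(u.pix, u.pix))
    sky = cf.CelestialFrame(name="icrs", reference_frame=coord.ICRS(),
                            unit=(u.deg, u.deg))
    return gwcs.wcs.WCS([(detector, det2sky), (sky, None)])


median_wcs = make_wcs()                 # WCS of the median (resampled) image
blot_wcs = make_wcs()                   # WCS of the input image to blot back onto
median_data = np.ones((100, 100), dtype=np.float32)

blotted = gwcs_blot(median_data, median_wcs, (100, 100), blot_wcs, pix_ratio=1.0)
```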


def calc_gwcs_pixmap(in_wcs, out_wcs, in_shape):
    """
    Return a pixel grid map from input frame to output frame.

    Parameters
    ----------
    in_wcs : gwcs.wcs.WCS
        Input/source wcs.

    out_wcs : gwcs.wcs.WCS
        Output/projected wcs.

    in_shape : list of int
        Input shape used to compute the input bounding box.

    Returns
    -------
    pixmap : numpy.ndarray
        Computed pixmap.
    """
    bb = wcs_bbox_from_shape(in_shape)
    log.debug("Bounding box from data shape: {}".format(bb))

    grid = gwcs.wcstools.grid_from_bounding_box(bb)
    return np.dstack(reproject(in_wcs, out_wcs)(grid[0], grid[1]))


def reproject(wcs1, wcs2):
    """
    Given two WCSs return a function which takes pixel
    coordinates in wcs1 and computes them in wcs2.

    It performs the forward transformation of ``wcs1`` followed by the
    inverse of ``wcs2``.

    Parameters
    ----------
    wcs1, wcs2 : gwcs.wcs.WCS
        WCS objects that have `pixel_to_world_values` and `world_to_pixel_values`
        methods.

    Returns
    -------
    _reproject : callable
        Function to compute the transformations. It takes x, y
        positions in ``wcs1`` and returns x, y positions in ``wcs2``.
    """

    try:
        forward_transform = wcs1.pixel_to_world_values
        backward_transform = wcs2.world_to_pixel_values
    except AttributeError as err:
        raise TypeError("Input should be a WCS") from err

    def _reproject(x, y):
        sky = forward_transform(x, y)
        flat_sky = []
        for axis in sky:
            flat_sky.append(axis.flatten())
        det = backward_transform(*tuple(flat_sky))
        det_reshaped = []
        for axis in det:
            det_reshaped.append(axis.reshape(x.shape))
        return tuple(det_reshaped)

    return _reproject
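
Illustrative aside: because reproject only relies on the ``pixel_to_world_values`` / ``world_to_pixel_values`` methods named in the docstring, a plain astropy.wcs.WCS can stand in for a gwcs object in a quick check; the header values below are arbitrary.

```python
import numpy as np
from astropy.wcs import WCS
from stcal.outlier_detection.utils import reproject


def make_fits_wcs(crpix):
    """Build a small TAN FITS WCS (values are arbitrary)."""
    w = WCS(naxis=2)
    w.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    w.wcs.crval = [30.0, 45.0]
    w.wcs.crpix = crpix
    w.wcs.cdelt = [-1e-4, 1e-4]
    return w


w1 = make_fits_wcs([50.0, 50.0])
w2 = make_fits_wcs([25.0, 25.0])   # same sky center, shifted reference pixel

transform = reproject(w1, w2)
x, y = np.meshgrid(np.arange(10, dtype=float), np.arange(10, dtype=float))
x2, y2 = transform(x, y)           # pixel positions in w2 for pixels (x, y) of w1
```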
Empty file.
