Skip to content

Commit

Permalink
Fix integration with rest of stcal
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Oct 4, 2023
1 parent 7b2bc1f commit e79bb89
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/stcal/ramp_fitting/ols_cas22/_fit_ramps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def fit_ramps(np.ndarray[float, ndim=2] resultants,
the resultants in electrons
dq : np.ndarry[n_resultants, n_pixel]
the dq array. dq != 0 implies bad pixel / CR.
read noise : float
the read noise in electrons
read_noise : np.ndarray[n_pixel]
the read noise in electrons for each pixel
read_time : float
Time to perform a readout. For Roman data, this is FRAME_TIME.
read_pattern : list[list[int]]
Expand Down
15 changes: 7 additions & 8 deletions src/stcal/ramp_fitting/ols_cas22_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import numpy as np

from . import ols_cas22
from .ols_cas22_util import ma_table_to_tau, ma_table_to_tbar, readpattern_to_matable
from .ols_cas22_util import ma_table_to_tau, ma_table_to_tbar, read_pattern_to_ma_table, ma_table_to_read_pattern


def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, read_pattern=None, use_jump=False):
Expand Down Expand Up @@ -76,7 +76,7 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re
# Get the Multi-accum table, either as given or from the read pattern
if read_pattern is None:
if ma_table is not None:
read_pattern = ma_table_to_readpattern(ma_table)
read_pattern = ma_table_to_read_pattern(ma_table)
if read_pattern is None:
raise RuntimeError('One of `ma_table` or `read_pattern` must be given.')

Expand All @@ -86,8 +86,7 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re

resultants = np.array(resultants).astype(np.float32)

dq = np.array(dq).astype(np.float32)

dq = np.array(dq).astype(np.int32)
if np.ndim(read_noise) <= 1:
read_noise = read_noise * np.ones(resultants.shape[1:])
read_noise = np.array(read_noise).astype(np.float32)
Expand All @@ -112,12 +111,12 @@ def fit_ramps_casertano(resultants, dq, read_noise, read_time, ma_table=None, re

# Extract the data request from the ramp fits
for index, ramp_fit in enumerate(ramp_fits):
parameters[1, :] = ramp_fit['average']['slope']
parameters[index, 1] = ramp_fit['average']['slope']

variances[0, :] = ramp_fit['average']['read_var']
variances[1, :] = ramp_fit['average']['poisson_var']
variances[index, 0] = ramp_fit['average']['read_var']
variances[index, 1] = ramp_fit['average']['poisson_var']

variances[2, :] = variances[0, :] + variances[1, :]
variances[:, 2] = (variances[:, 0] + variances[:, 1]).astype(np.float32)

if resultants.shape != orig_shape:
parameters = parameters[0]
Expand Down
14 changes: 9 additions & 5 deletions src/stcal/ramp_fitting/ols_cas22_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
"""
import numpy as np

__all__ = ['ma_table_to_tau', 'ma_table_to_tbar']
__all__ = [
'ma_table_to_read_pattern',
'ma_table_to_tau',
'ma_table_to_tbar',
'read_pattern_to_ma_table']


def matable_to_readpattern(ma_table):
def ma_table_to_read_pattern(ma_table):
"""Convert read patterns to multi-accum lists
Using Roman terminology, a "read pattern" is a list of resultants. Each element of this list
Expand All @@ -26,7 +30,7 @@ def matable_to_readpattern(ma_table):
[[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]]
The example above, using this function, should perform as follows:
>>> matable_to_readpattern([[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]])
>>> ma_table_to_read_pattern([[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]])
[[1], [2, 3], [4], [5, 6, 7, 8], [9, 10], [11]]
Parameters
Expand Down Expand Up @@ -123,7 +127,7 @@ def ma_table_to_tbar(ma_table, read_time):
return meantimes


def readpattern_to_matable(read_pattern):
def read_pattern_to_ma_table(read_pattern):
"""Convert read patterns to multi-accum lists
Using Roman terminology, a "read pattern" is a list of resultants. Each element of this list
Expand All @@ -144,7 +148,7 @@ def readpattern_to_matable(read_pattern):
[[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]]
The example above, using this function, should perform as follows:
>>> readpattern_to_matable([[1], [2, 3], [4], [5, 6, 7, 8], [9, 10], [11]])
>>> read_pattern_to_ma_table([[1], [2, 3], [4], [5, 6, 7, 8], [9, 10], [11]])
[[1, 1], [2, 2], [4, 1], [5, 4], [9, 2], [11, 1]]
Parameters
Expand Down
16 changes: 10 additions & 6 deletions tests/test_ramp_fitting_cas22.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@
ROMAN_READ_TIME = 3.04


def test_matable_to_readpattern():
def test_ma_table_to_read_pattern():
"""Test conversion from read pattern to multi-accum table"""
ma_table = [[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]]
expected = [[1], [2, 3], [4], [5, 6, 7, 8], [9, 10], [11]]

result = ols_cas22_util.matable_to_readpattern(ma_table)
result = ols_cas22_util.ma_table_to_read_pattern(ma_table)

assert result == expected


def test_readpattern_to_matable():
def test_read_pattern_to_ma_table():
"""Test conversion from read pattern to multi-accum table"""
pattern = [[1], [2, 3], [4], [5, 6, 7, 8], [9, 10], [11]]
expected = [[1, 1], [2, 2], [4, 1], [5, 4], [9,2], [11,1]]

result = ols_cas22_util.readpattern_to_matable(pattern)
result = ols_cas22_util.read_pattern_to_ma_table(pattern)

assert result == expected

Expand All @@ -38,14 +38,18 @@ def test_simulated_ramps():
ntrial = 100000
ma_table, flux, read_noise, resultants = simulate_many_ramps(ntrial=ntrial)

dq = np.zeros(resultants.shape, dtype=np.int32)
read_noise = np.ones(resultants.shape[1], dtype=np.float32) * read_noise

par, var = ramp.fit_ramps_casertano(
resultants, resultants * 0, read_noise, ROMAN_READ_TIME, ma_table=ma_table)
resultants, dq, read_noise, ROMAN_READ_TIME, ma_table=ma_table)

chi2dof_slope = np.sum((par[:, 1] - flux)**2 / var[:, 2]) / ntrial
assert np.abs(chi2dof_slope - 1) < 0.03

# now let's mark a bunch of the ramps as compromised.
bad = np.random.uniform(size=resultants.shape) > 0.7
dq = resultants * 0 + bad
dq += bad
par, var = ramp.fit_ramps_casertano(
resultants, dq, read_noise, ROMAN_READ_TIME, ma_table=ma_table)
# only use okay ramps
Expand Down

0 comments on commit e79bb89

Please sign in to comment.