Skip to content

Commit

Permalink
Add option to regularize linear fits
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Aug 22, 2023
1 parent a4ff591 commit 9aa7d0b
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def calc_width(filter_size, real_delta, nsamples):
lthresh = nsamples
return (uthresh, lthresh)

def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode, ridge_alpha=0.0,
filter_dims=1, skip_wgt=0.1, zero_residual_flags=True, **filter_kwargs):
'''
A filtering function that wraps up all functionality of high_pass_fourier_filter
Expand Down Expand Up @@ -352,6 +352,11 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
'dpss_matrix' method (see above)
'dayenu_clean', apply dayenu filter to data. Deconvolve
subtracted foregrounds with 'clean'.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0, if value is equal to zero,
then no regularization is applied. If value is greater than zeros, ridge_alpha is used as
the regularization parameter in ridge regression. Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix).
zero_residual_flags : bool, optional.
If true, set flagged channels in the residual equal to zero.
Default is True.
Expand Down Expand Up @@ -479,6 +484,8 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
raise ValueError("data must be a 1D or 2D ndarray")
if not ndim_wgts == ndim_data:
raise ValueError("Number of dimensions in weights, %d does not equal number of dimensions in data, %d!"%(ndim_wgts, ndim_data))

assert ridge_alpha >= 0.0, "ridge_alpha must be greater than or equal to zero."
#The core code of this method will always assume 2d data
if ndim_data == 1:
data = np.asarray([data])
Expand Down Expand Up @@ -574,7 +581,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[1], method=mode[2], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
info['info_deconv']=info_deconv

elif mode[0] in ['dft', 'dpss']:
Expand All @@ -594,7 +601,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[0], method=mode[1], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
elif mode[0] == 'clean':
if zero_residual_flags is None:
zero_residual_flags = False
Expand Down Expand Up @@ -1561,7 +1568,7 @@ def delay_filter_leastsq(data, flags, sigma, nmax, add_noise=False,

def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
basis_options, suppression_factors=None, hash_decimal=10,
method='leastsq', basis='dft', cache=None):
method='leastsq', basis='dft', cache=None, ridge_alpha=0.0):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
y_model = A @ c
Expand Down Expand Up @@ -1689,7 +1696,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
x=x, hash_decimal=hash_decimal, label='covariance')
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, ridge_alpha=ridge_alpha)
if square_key in cache:
covmat = cache[square_key]
else:
Expand All @@ -1698,6 +1705,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,

if not fm_key in cache:
XTX = covmat - np.conj(amat[flags]).T @ amat[flags]
XTX.flat[::XTX.shape[0] + 1] += ridge_alpha # add regularization term

Xy = np.conj(amat[mask]).T @ y[mask]

Expand Down Expand Up @@ -1750,23 +1758,24 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
elif method == 'matrix':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis)
label='fitting matrix', basis=basis, ridge_alpha=ridge_alpha)
if basis.lower() == 'dft':
fm_key = fm_key + (basis_options['fundamental_period'], )
elif basis.lower() == 'dpss':
fm_key = fm_key + tuple(nterms)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key, ridge_alpha=ridge_alpha)
info['fitting_matrix'] = fmat
cn_out = fmat @ y

elif method == 'solve':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, alpha=ridge_alpha)
if fm_key in cache:
L = cache[fm_key]
else:
XTX = np.dot(np.conj(amat).T * w, amat)
XTX.flat[::XTX.shape[0] + 1] += ridge_alpha # add regularization term
L = linalg.lu_factor(XTX)
cache[fm_key] = L

Expand Down Expand Up @@ -1942,7 +1951,7 @@ def _clean_filter(x, data, wgts, filter_centers, filter_half_widths,
def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
basis_options, suppression_factors=None,
method='leastsq', basis='dft', cache=None,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5, ridge_alpha=0.0,
zero_residual_flags=True):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
Expand Down Expand Up @@ -2077,7 +2086,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[1],
suppression_factors=suppression_factors[1],
basis_options=basis_options[1], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_1'][i] = 'skipped'
else:
Expand Down Expand Up @@ -2107,7 +2116,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[0],
suppression_factors=suppression_factors[0],
basis_options=basis_options[0], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_0'][i] = 'skipped'
else:
Expand Down Expand Up @@ -2137,7 +2146,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
return model, residual, info


def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None):
def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None, ridge_alpha=0.0):
"""
Calculate the linear least squares solution matrix
from a design matrix, A and a weights matrix W
Expand All @@ -2156,6 +2165,9 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
fit_mat_key: optional hashable variable
optional key. If none is used, hash fit matrix against design and
weighting matrix.
alpha: float, optional
Regularization parameter. If non-zero, adds alpha * I to the
fitting matrix. Default is 0.0.
Returns
-----------
Expand Down Expand Up @@ -2185,6 +2197,8 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
xwmat = np.conj(design_matrix.T) @ weights
cmat = xwmat @ design_matrix

cmat.flat[::cmat.shape[0] + 1] += ridge_alpha

#should there be a conjugation!?!
if np.linalg.cond(cmat)>=1e9:
warn('Warning!!!!: Poorly conditioned matrix! Your linear inpainting IS WRONG!')
Expand Down

0 comments on commit 9aa7d0b

Please sign in to comment.