Skip to content

Commit

Permalink
Merge pull request #4 from austinpeel/tests
Browse files Browse the repository at this point in the history
Add new tests
  • Loading branch information
austinpeel authored Jun 15, 2021
2 parents 9c4eea6 + 4e01f3b commit 7eb1c6f
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 0 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ matrix:
branches:
only:
- master
- tests

# install package and dependencies
install:
Expand Down
79 changes: 79 additions & 0 deletions lenspack/starlet_l1norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-

"""STARLET L1-NORM MODULE
This module contains functions for computing the starlet l1norm
as defined in Eq. (1) of https://arxiv.org/pdf/2101.01542.pdf.
"""

import numpy as np
from astropy.stats import mad_std
from lenspack.image.transforms import starlet2d


def noise_coeff(image, nscales):
"""Compute the noise coefficients :math:`\sigma^{e}_{j}`
to get the estimate of the noise at the scale j
following Starck and Murtagh (1998).
Parameters
----------
image : array_like
Two-dimensional input image.
nscales : int
Number of wavelet scales to compute. Should not exceed log2(N), where
N is the smaller of the two input dimensions.
Returns
-------
coeff_j : numpy.ndarray
Values of the standard deviation of the noise at scale j
"""
noise_sigma = np.random.randn(image.shape[0], image.shape[0])
noise_wavelet = starlet2d(noise_sigma, nscales)
coeff_j = np.array([np.std(scale) for scale in noise_wavelet])
return coeff_j


def get_l1norm_noisy(image, noise, nscales, nbins):
"""Compute the starlet :math:`\ell_1`-norm of a noisy image
following Eq. (1) of https://arxiv.org/abs/2101.01542.
Parameters
----------
image : array_like
Two-dimensional input noiseless image.
noise : array_like
Two-dimensional input of the noise to be added to image
nscales : int
Number of wavelet scales to compute. Should not exceed log2(N), where
N is the smaller of the two input dimensions.
nbins : int
Number of bins in S/N desired for the summary statistic
Returns
-------
bins_snr, starlet_l1norm : tuple of 1D numpy arrays
Bin centers in S/N and Starlet :math:`\ell_1`-norm of the noisy image
"""

# add noise to noiseless image
image_noisy = image + noise
# perform starlet decomposition
image_starlet = starlet2d(image_noisy, nscales)
# estimate of the noise
noise_estimate = mad_std(image_noisy)
coeff_j = noise_coeff(image, nscales)

l1_coll = []
bins_coll = []
for image_j, std_co in zip(image_starlet, coeff_j):

sigma_j = std_co * noise_estimate

snr = image_j / sigma_j
thresholds_snr = np.linspace(np.min(snr), np.max(snr), nbins + 1)
bins_snr = 0.5 * (thresholds_snr[:-1] + thresholds_snr[1:])
digitized = np.digitize(snr, thresholds_snr)
bin_l1_norm = [np.sum(np.abs(snr[digitized == i]))
for i in range(1, len(thresholds_snr))]
l1_coll.append(bin_l1_norm)
bins_coll.append(bins_snr)
return np.array(bins_coll), np.array(l1_coll)
66 changes: 66 additions & 0 deletions lenspack/tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

"""UNIT TESTS FOR IMAGE
This module contains unit tests for the image module.
"""

from unittest import TestCase
import numpy as np
import numpy.testing as npt
from lenspack.image.transforms import starlet2d, dct2d, idct2d


class TransformsTestCase(TestCase):

def setUp(self):

self.nscales = 5
self.npix = 64
self.image = img = 10 * np.random.normal(size=(self.npix, self.npix))
spike = np.zeros_like(self.image)
spike[self.npix // 2, self.npix // 2] = 1
self.spike = spike

def tearDown(self):

self.nscales = None
self.npix = None
self.image = None
self.spike = None

def test_starlet2d(self):

# Test output shape of starlet transform
wt = starlet2d(self.image, self.nscales)
output_shape = (self.nscales + 1, self.npix, self.npix)
npt.assert_equal(output_shape, wt.shape,
err_msg="Incorrect starlet2d output shape.")

# Test reconstruction
rec = np.sum(wt, axis=0)
npt.assert_allclose(rec, self.image,
err_msg="Incorrect starlet reconstruction.")

# Test wavelet filter norms
wt_spike = starlet2d(self.spike, self.nscales)
norms = np.sqrt(np.sum(wt_spike[:-1]**2, axis=(1, 2)))
expected = [0.890796310279, 0.2006638510244, 0.0855075047534]
if len(norms > 2):
npt.assert_allclose(norms[:3], expected,
err_msg="Incorrect filter norms.")

def test_dct2d(self):

# Test reconstruction
dct = dct2d(self.image)
rec = idct2d(dct)
npt.assert_allclose(rec, self.image,
err_msg="Incorrect DCT reconstruction.")

# Test exceptions
npt.assert_raises(Exception, dct2d, self.image[0])
npt.assert_raises(Exception, dct2d, self.image, 'symmetric')
npt.assert_raises(Exception, idct2d, self.image[0])
npt.assert_raises(Exception, idct2d, self.image, 'symmetric')
106 changes: 106 additions & 0 deletions lenspack/tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-

"""UNIT TESTS FOR IMAGE
This module contains unit tests for the iamge module.
"""

from unittest import TestCase
import numpy as np
import numpy.testing as npt
from scipy import stats
from lenspack.stats import mad, skew, kurt, mu_n, kappa_n, fdr, hc


class StatsTestCase(TestCase):

def setUp(self):

# [-5., -4., -3., ... 3., 4., 5.]
self.array = np.arange(11.) - 5

def tearDown(self):

self.array = None

def test_mad(self):

# Test output value
npt.assert_equal(mad(self.array), 3.0, err_msg="Incorrect MAD value.")

def test_skew(self):

# Test output value and agreement with scipy
npt.assert_equal(skew(self.array), 0, err_msg="Incorrect skew value.")
npt.assert_equal(skew(self.array**2), 0.5661385170722978,
err_msg="Incorrect skew value.")
npt.assert_almost_equal(skew(self.array**2), stats.skew(self.array**2),
decimal=15,
err_msg="Does not match scipy.skew.")

def test_kurt(self):

# Test output value and agreement with scipy
npt.assert_almost_equal(kurt(self.array), -1.22,
decimal=15,
err_msg="Incorrect kurt value.")
npt.assert_almost_equal(kurt(self.array), stats.kurtosis(self.array),
decimal=15,
err_msg="Does not match scipy.kurtosis.")

def test_mu_n(self):

# Test output value
npt.assert_equal(mu_n(self.array, order=1), 0,
err_msg="Incorrect mu_n for order 1.")
npt.assert_equal(mu_n(self.array, order=2), 10,
err_msg="Incorrect mu_n for order 2.")
npt.assert_equal(mu_n(self.array, order=3), 0,
err_msg="Incorrect mu_n for order 3.")
npt.assert_equal(mu_n(self.array, order=4), 178,
err_msg="Incorrect mu_n for order 4.")
npt.assert_equal(mu_n(self.array, order=5), 0,
err_msg="Incorrect mu_n for order 5.")
npt.assert_equal(mu_n(self.array, order=6), 3730,
err_msg="Incorrect mu_n for order 6.")

# Test agreement with scipy
npt.assert_equal(mu_n(self.array, order=1),
stats.moment(self.array, moment=1),
err_msg="Does not match scipy.moment for order 1.")
npt.assert_equal(mu_n(self.array, order=2),
stats.moment(self.array, moment=2),
err_msg="Does not match scipy.moment for order 2.")
npt.assert_equal(mu_n(self.array, order=3),
stats.moment(self.array, moment=3),
err_msg="Does not match scipy.moment for order 3.")
npt.assert_equal(mu_n(self.array, order=4),
stats.moment(self.array, moment=4),
err_msg="Does not match scipy.moment for order 4.")
npt.assert_equal(mu_n(self.array, order=5),
stats.moment(self.array, moment=5),
err_msg="Does not match scipy.moment for order 5.")
npt.assert_equal(mu_n(self.array, order=6),
stats.moment(self.array, moment=6),
err_msg="Does not match scipy.moment for order 6.")

# Test exceptions
npt.assert_raises(Exception, mu_n, self.array, order=0)

def test_kappa_n(self):

# Test output value
npt.assert_equal(kappa_n(self.array, order=2), 10,
err_msg="Incorrect mu_n for order 2.")
npt.assert_equal(kappa_n(self.array, order=3), 0,
err_msg="Incorrect mu_n for order 3.")
npt.assert_equal(kappa_n(self.array, order=4), -122,
err_msg="Incorrect mu_n for order 4.")
npt.assert_equal(kappa_n(self.array, order=5), 0,
err_msg="Incorrect mu_n for order 5.")
npt.assert_equal(kappa_n(self.array, order=6), 7030,
err_msg="Incorrect mu_n for order 6.")

# Test exceptions
npt.assert_raises(Exception, kappa_n, self.array, order=1)

0 comments on commit 7eb1c6f

Please sign in to comment.