diff --git a/.travis.yml b/.travis.yml index a9cfc69..992fdaa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ matrix: branches: only: - master + - tests # install package and dependencies install: diff --git a/lenspack/starlet_l1norm.py b/lenspack/starlet_l1norm.py new file mode 100644 index 0000000..9efcb23 --- /dev/null +++ b/lenspack/starlet_l1norm.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +"""STARLET L1-NORM MODULE + +This module contains functions for computing the starlet l1norm +as defined in Eq. (1) of https://arxiv.org/pdf/2101.01542.pdf. + +""" + +import numpy as np +from astropy.stats import mad_std +from lenspack.image.transforms import starlet2d + + +def noise_coeff(image, nscales): + """Compute the noise coefficients :math:`\sigma^{e}_{j}` + to get the estimate of the noise at the scale j + following Starck and Murtagh (1998). + Parameters + ---------- + image : array_like + Two-dimensional input image. + nscales : int + Number of wavelet scales to compute. Should not exceed log2(N), where + N is the smaller of the two input dimensions. + Returns + ------- + coeff_j : numpy.ndarray + Values of the standard deviation of the noise at scale j + """ + noise_sigma = np.random.randn(image.shape[0], image.shape[0]) + noise_wavelet = starlet2d(noise_sigma, nscales) + coeff_j = np.array([np.std(scale) for scale in noise_wavelet]) + return coeff_j + + +def get_l1norm_noisy(image, noise, nscales, nbins): + """Compute the starlet :math:`\ell_1`-norm of a noisy image + following Eq. (1) of https://arxiv.org/abs/2101.01542. + Parameters + ---------- + image : array_like + Two-dimensional input noiseless image. + noise : array_like + Two-dimensional input of the noise to be added to image + nscales : int + Number of wavelet scales to compute. Should not exceed log2(N), where + N is the smaller of the two input dimensions. + nbins : int + Number of bins in S/N desired for the summary statistic + Returns + ------- + bins_snr, starlet_l1norm : tuple of 1D numpy arrays + Bin centers in S/N and Starlet :math:`\ell_1`-norm of the noisy image + """ + + # add noise to noiseless image + image_noisy = image + noise + # perform starlet decomposition + image_starlet = starlet2d(image_noisy, nscales) + # estimate of the noise + noise_estimate = mad_std(image_noisy) + coeff_j = noise_coeff(image, nscales) + + l1_coll = [] + bins_coll = [] + for image_j, std_co in zip(image_starlet, coeff_j): + + sigma_j = std_co * noise_estimate + + snr = image_j / sigma_j + thresholds_snr = np.linspace(np.min(snr), np.max(snr), nbins + 1) + bins_snr = 0.5 * (thresholds_snr[:-1] + thresholds_snr[1:]) + digitized = np.digitize(snr, thresholds_snr) + bin_l1_norm = [np.sum(np.abs(snr[digitized == i])) + for i in range(1, len(thresholds_snr))] + l1_coll.append(bin_l1_norm) + bins_coll.append(bins_snr) + return np.array(bins_coll), np.array(l1_coll) diff --git a/lenspack/tests/test_image.py b/lenspack/tests/test_image.py new file mode 100644 index 0000000..ab899f5 --- /dev/null +++ b/lenspack/tests/test_image.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +"""UNIT TESTS FOR IMAGE + +This module contains unit tests for the image module. + +""" + +from unittest import TestCase +import numpy as np +import numpy.testing as npt +from lenspack.image.transforms import starlet2d, dct2d, idct2d + + +class TransformsTestCase(TestCase): + + def setUp(self): + + self.nscales = 5 + self.npix = 64 + self.image = img = 10 * np.random.normal(size=(self.npix, self.npix)) + spike = np.zeros_like(self.image) + spike[self.npix // 2, self.npix // 2] = 1 + self.spike = spike + + def tearDown(self): + + self.nscales = None + self.npix = None + self.image = None + self.spike = None + + def test_starlet2d(self): + + # Test output shape of starlet transform + wt = starlet2d(self.image, self.nscales) + output_shape = (self.nscales + 1, self.npix, self.npix) + npt.assert_equal(output_shape, wt.shape, + err_msg="Incorrect starlet2d output shape.") + + # Test reconstruction + rec = np.sum(wt, axis=0) + npt.assert_allclose(rec, self.image, + err_msg="Incorrect starlet reconstruction.") + + # Test wavelet filter norms + wt_spike = starlet2d(self.spike, self.nscales) + norms = np.sqrt(np.sum(wt_spike[:-1]**2, axis=(1, 2))) + expected = [0.890796310279, 0.2006638510244, 0.0855075047534] + if len(norms > 2): + npt.assert_allclose(norms[:3], expected, + err_msg="Incorrect filter norms.") + + def test_dct2d(self): + + # Test reconstruction + dct = dct2d(self.image) + rec = idct2d(dct) + npt.assert_allclose(rec, self.image, + err_msg="Incorrect DCT reconstruction.") + + # Test exceptions + npt.assert_raises(Exception, dct2d, self.image[0]) + npt.assert_raises(Exception, dct2d, self.image, 'symmetric') + npt.assert_raises(Exception, idct2d, self.image[0]) + npt.assert_raises(Exception, idct2d, self.image, 'symmetric') diff --git a/lenspack/tests/test_stats.py b/lenspack/tests/test_stats.py new file mode 100644 index 0000000..807768d --- /dev/null +++ b/lenspack/tests/test_stats.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +"""UNIT TESTS FOR IMAGE + +This module contains unit tests for the iamge module. + +""" + +from unittest import TestCase +import numpy as np +import numpy.testing as npt +from scipy import stats +from lenspack.stats import mad, skew, kurt, mu_n, kappa_n, fdr, hc + + +class StatsTestCase(TestCase): + + def setUp(self): + + # [-5., -4., -3., ... 3., 4., 5.] + self.array = np.arange(11.) - 5 + + def tearDown(self): + + self.array = None + + def test_mad(self): + + # Test output value + npt.assert_equal(mad(self.array), 3.0, err_msg="Incorrect MAD value.") + + def test_skew(self): + + # Test output value and agreement with scipy + npt.assert_equal(skew(self.array), 0, err_msg="Incorrect skew value.") + npt.assert_equal(skew(self.array**2), 0.5661385170722978, + err_msg="Incorrect skew value.") + npt.assert_almost_equal(skew(self.array**2), stats.skew(self.array**2), + decimal=15, + err_msg="Does not match scipy.skew.") + + def test_kurt(self): + + # Test output value and agreement with scipy + npt.assert_almost_equal(kurt(self.array), -1.22, + decimal=15, + err_msg="Incorrect kurt value.") + npt.assert_almost_equal(kurt(self.array), stats.kurtosis(self.array), + decimal=15, + err_msg="Does not match scipy.kurtosis.") + + def test_mu_n(self): + + # Test output value + npt.assert_equal(mu_n(self.array, order=1), 0, + err_msg="Incorrect mu_n for order 1.") + npt.assert_equal(mu_n(self.array, order=2), 10, + err_msg="Incorrect mu_n for order 2.") + npt.assert_equal(mu_n(self.array, order=3), 0, + err_msg="Incorrect mu_n for order 3.") + npt.assert_equal(mu_n(self.array, order=4), 178, + err_msg="Incorrect mu_n for order 4.") + npt.assert_equal(mu_n(self.array, order=5), 0, + err_msg="Incorrect mu_n for order 5.") + npt.assert_equal(mu_n(self.array, order=6), 3730, + err_msg="Incorrect mu_n for order 6.") + + # Test agreement with scipy + npt.assert_equal(mu_n(self.array, order=1), + stats.moment(self.array, moment=1), + err_msg="Does not match scipy.moment for order 1.") + npt.assert_equal(mu_n(self.array, order=2), + stats.moment(self.array, moment=2), + err_msg="Does not match scipy.moment for order 2.") + npt.assert_equal(mu_n(self.array, order=3), + stats.moment(self.array, moment=3), + err_msg="Does not match scipy.moment for order 3.") + npt.assert_equal(mu_n(self.array, order=4), + stats.moment(self.array, moment=4), + err_msg="Does not match scipy.moment for order 4.") + npt.assert_equal(mu_n(self.array, order=5), + stats.moment(self.array, moment=5), + err_msg="Does not match scipy.moment for order 5.") + npt.assert_equal(mu_n(self.array, order=6), + stats.moment(self.array, moment=6), + err_msg="Does not match scipy.moment for order 6.") + + # Test exceptions + npt.assert_raises(Exception, mu_n, self.array, order=0) + + def test_kappa_n(self): + + # Test output value + npt.assert_equal(kappa_n(self.array, order=2), 10, + err_msg="Incorrect mu_n for order 2.") + npt.assert_equal(kappa_n(self.array, order=3), 0, + err_msg="Incorrect mu_n for order 3.") + npt.assert_equal(kappa_n(self.array, order=4), -122, + err_msg="Incorrect mu_n for order 4.") + npt.assert_equal(kappa_n(self.array, order=5), 0, + err_msg="Incorrect mu_n for order 5.") + npt.assert_equal(kappa_n(self.array, order=6), 7030, + err_msg="Incorrect mu_n for order 6.") + + # Test exceptions + npt.assert_raises(Exception, kappa_n, self.array, order=1)