diff --git a/lenspack/tests/test_image.py b/lenspack/tests/test_image.py new file mode 100644 index 0000000..ab899f5 --- /dev/null +++ b/lenspack/tests/test_image.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +"""UNIT TESTS FOR IMAGE + +This module contains unit tests for the image module. + +""" + +from unittest import TestCase +import numpy as np +import numpy.testing as npt +from lenspack.image.transforms import starlet2d, dct2d, idct2d + + +class TransformsTestCase(TestCase): + + def setUp(self): + + self.nscales = 5 + self.npix = 64 + self.image = img = 10 * np.random.normal(size=(self.npix, self.npix)) + spike = np.zeros_like(self.image) + spike[self.npix // 2, self.npix // 2] = 1 + self.spike = spike + + def tearDown(self): + + self.nscales = None + self.npix = None + self.image = None + self.spike = None + + def test_starlet2d(self): + + # Test output shape of starlet transform + wt = starlet2d(self.image, self.nscales) + output_shape = (self.nscales + 1, self.npix, self.npix) + npt.assert_equal(output_shape, wt.shape, + err_msg="Incorrect starlet2d output shape.") + + # Test reconstruction + rec = np.sum(wt, axis=0) + npt.assert_allclose(rec, self.image, + err_msg="Incorrect starlet reconstruction.") + + # Test wavelet filter norms + wt_spike = starlet2d(self.spike, self.nscales) + norms = np.sqrt(np.sum(wt_spike[:-1]**2, axis=(1, 2))) + expected = [0.890796310279, 0.2006638510244, 0.0855075047534] + if len(norms > 2): + npt.assert_allclose(norms[:3], expected, + err_msg="Incorrect filter norms.") + + def test_dct2d(self): + + # Test reconstruction + dct = dct2d(self.image) + rec = idct2d(dct) + npt.assert_allclose(rec, self.image, + err_msg="Incorrect DCT reconstruction.") + + # Test exceptions + npt.assert_raises(Exception, dct2d, self.image[0]) + npt.assert_raises(Exception, dct2d, self.image, 'symmetric') + npt.assert_raises(Exception, idct2d, self.image[0]) + npt.assert_raises(Exception, idct2d, self.image, 'symmetric') diff --git a/lenspack/tests/test_stats.py b/lenspack/tests/test_stats.py new file mode 100644 index 0000000..807768d --- /dev/null +++ b/lenspack/tests/test_stats.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +"""UNIT TESTS FOR IMAGE + +This module contains unit tests for the iamge module. + +""" + +from unittest import TestCase +import numpy as np +import numpy.testing as npt +from scipy import stats +from lenspack.stats import mad, skew, kurt, mu_n, kappa_n, fdr, hc + + +class StatsTestCase(TestCase): + + def setUp(self): + + # [-5., -4., -3., ... 3., 4., 5.] + self.array = np.arange(11.) - 5 + + def tearDown(self): + + self.array = None + + def test_mad(self): + + # Test output value + npt.assert_equal(mad(self.array), 3.0, err_msg="Incorrect MAD value.") + + def test_skew(self): + + # Test output value and agreement with scipy + npt.assert_equal(skew(self.array), 0, err_msg="Incorrect skew value.") + npt.assert_equal(skew(self.array**2), 0.5661385170722978, + err_msg="Incorrect skew value.") + npt.assert_almost_equal(skew(self.array**2), stats.skew(self.array**2), + decimal=15, + err_msg="Does not match scipy.skew.") + + def test_kurt(self): + + # Test output value and agreement with scipy + npt.assert_almost_equal(kurt(self.array), -1.22, + decimal=15, + err_msg="Incorrect kurt value.") + npt.assert_almost_equal(kurt(self.array), stats.kurtosis(self.array), + decimal=15, + err_msg="Does not match scipy.kurtosis.") + + def test_mu_n(self): + + # Test output value + npt.assert_equal(mu_n(self.array, order=1), 0, + err_msg="Incorrect mu_n for order 1.") + npt.assert_equal(mu_n(self.array, order=2), 10, + err_msg="Incorrect mu_n for order 2.") + npt.assert_equal(mu_n(self.array, order=3), 0, + err_msg="Incorrect mu_n for order 3.") + npt.assert_equal(mu_n(self.array, order=4), 178, + err_msg="Incorrect mu_n for order 4.") + npt.assert_equal(mu_n(self.array, order=5), 0, + err_msg="Incorrect mu_n for order 5.") + npt.assert_equal(mu_n(self.array, order=6), 3730, + err_msg="Incorrect mu_n for order 6.") + + # Test agreement with scipy + npt.assert_equal(mu_n(self.array, order=1), + stats.moment(self.array, moment=1), + err_msg="Does not match scipy.moment for order 1.") + npt.assert_equal(mu_n(self.array, order=2), + stats.moment(self.array, moment=2), + err_msg="Does not match scipy.moment for order 2.") + npt.assert_equal(mu_n(self.array, order=3), + stats.moment(self.array, moment=3), + err_msg="Does not match scipy.moment for order 3.") + npt.assert_equal(mu_n(self.array, order=4), + stats.moment(self.array, moment=4), + err_msg="Does not match scipy.moment for order 4.") + npt.assert_equal(mu_n(self.array, order=5), + stats.moment(self.array, moment=5), + err_msg="Does not match scipy.moment for order 5.") + npt.assert_equal(mu_n(self.array, order=6), + stats.moment(self.array, moment=6), + err_msg="Does not match scipy.moment for order 6.") + + # Test exceptions + npt.assert_raises(Exception, mu_n, self.array, order=0) + + def test_kappa_n(self): + + # Test output value + npt.assert_equal(kappa_n(self.array, order=2), 10, + err_msg="Incorrect mu_n for order 2.") + npt.assert_equal(kappa_n(self.array, order=3), 0, + err_msg="Incorrect mu_n for order 3.") + npt.assert_equal(kappa_n(self.array, order=4), -122, + err_msg="Incorrect mu_n for order 4.") + npt.assert_equal(kappa_n(self.array, order=5), 0, + err_msg="Incorrect mu_n for order 5.") + npt.assert_equal(kappa_n(self.array, order=6), 7030, + err_msg="Incorrect mu_n for order 6.") + + # Test exceptions + npt.assert_raises(Exception, kappa_n, self.array, order=1)