Skip to content

Commit

Permalink
Merge pull request #24 from austinpeel/master
Browse files Browse the repository at this point in the history
Add new tests
  • Loading branch information
austinpeel authored Jun 15, 2021
2 parents 8af72da + fa75daf commit cef8d9a
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 0 deletions.
66 changes: 66 additions & 0 deletions lenspack/tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

"""UNIT TESTS FOR IMAGE
This module contains unit tests for the image module.
"""

from unittest import TestCase
import numpy as np
import numpy.testing as npt
from lenspack.image.transforms import starlet2d, dct2d, idct2d


class TransformsTestCase(TestCase):

def setUp(self):

self.nscales = 5
self.npix = 64
self.image = img = 10 * np.random.normal(size=(self.npix, self.npix))
spike = np.zeros_like(self.image)
spike[self.npix // 2, self.npix // 2] = 1
self.spike = spike

def tearDown(self):

self.nscales = None
self.npix = None
self.image = None
self.spike = None

def test_starlet2d(self):

# Test output shape of starlet transform
wt = starlet2d(self.image, self.nscales)
output_shape = (self.nscales + 1, self.npix, self.npix)
npt.assert_equal(output_shape, wt.shape,
err_msg="Incorrect starlet2d output shape.")

# Test reconstruction
rec = np.sum(wt, axis=0)
npt.assert_allclose(rec, self.image,
err_msg="Incorrect starlet reconstruction.")

# Test wavelet filter norms
wt_spike = starlet2d(self.spike, self.nscales)
norms = np.sqrt(np.sum(wt_spike[:-1]**2, axis=(1, 2)))
expected = [0.890796310279, 0.2006638510244, 0.0855075047534]
if len(norms > 2):
npt.assert_allclose(norms[:3], expected,
err_msg="Incorrect filter norms.")

def test_dct2d(self):

# Test reconstruction
dct = dct2d(self.image)
rec = idct2d(dct)
npt.assert_allclose(rec, self.image,
err_msg="Incorrect DCT reconstruction.")

# Test exceptions
npt.assert_raises(Exception, dct2d, self.image[0])
npt.assert_raises(Exception, dct2d, self.image, 'symmetric')
npt.assert_raises(Exception, idct2d, self.image[0])
npt.assert_raises(Exception, idct2d, self.image, 'symmetric')
106 changes: 106 additions & 0 deletions lenspack/tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-

"""UNIT TESTS FOR IMAGE
This module contains unit tests for the iamge module.
"""

from unittest import TestCase
import numpy as np
import numpy.testing as npt
from scipy import stats
from lenspack.stats import mad, skew, kurt, mu_n, kappa_n, fdr, hc


class StatsTestCase(TestCase):

def setUp(self):

# [-5., -4., -3., ... 3., 4., 5.]
self.array = np.arange(11.) - 5

def tearDown(self):

self.array = None

def test_mad(self):

# Test output value
npt.assert_equal(mad(self.array), 3.0, err_msg="Incorrect MAD value.")

def test_skew(self):

# Test output value and agreement with scipy
npt.assert_equal(skew(self.array), 0, err_msg="Incorrect skew value.")
npt.assert_equal(skew(self.array**2), 0.5661385170722978,
err_msg="Incorrect skew value.")
npt.assert_almost_equal(skew(self.array**2), stats.skew(self.array**2),
decimal=15,
err_msg="Does not match scipy.skew.")

def test_kurt(self):

# Test output value and agreement with scipy
npt.assert_almost_equal(kurt(self.array), -1.22,
decimal=15,
err_msg="Incorrect kurt value.")
npt.assert_almost_equal(kurt(self.array), stats.kurtosis(self.array),
decimal=15,
err_msg="Does not match scipy.kurtosis.")

def test_mu_n(self):

# Test output value
npt.assert_equal(mu_n(self.array, order=1), 0,
err_msg="Incorrect mu_n for order 1.")
npt.assert_equal(mu_n(self.array, order=2), 10,
err_msg="Incorrect mu_n for order 2.")
npt.assert_equal(mu_n(self.array, order=3), 0,
err_msg="Incorrect mu_n for order 3.")
npt.assert_equal(mu_n(self.array, order=4), 178,
err_msg="Incorrect mu_n for order 4.")
npt.assert_equal(mu_n(self.array, order=5), 0,
err_msg="Incorrect mu_n for order 5.")
npt.assert_equal(mu_n(self.array, order=6), 3730,
err_msg="Incorrect mu_n for order 6.")

# Test agreement with scipy
npt.assert_equal(mu_n(self.array, order=1),
stats.moment(self.array, moment=1),
err_msg="Does not match scipy.moment for order 1.")
npt.assert_equal(mu_n(self.array, order=2),
stats.moment(self.array, moment=2),
err_msg="Does not match scipy.moment for order 2.")
npt.assert_equal(mu_n(self.array, order=3),
stats.moment(self.array, moment=3),
err_msg="Does not match scipy.moment for order 3.")
npt.assert_equal(mu_n(self.array, order=4),
stats.moment(self.array, moment=4),
err_msg="Does not match scipy.moment for order 4.")
npt.assert_equal(mu_n(self.array, order=5),
stats.moment(self.array, moment=5),
err_msg="Does not match scipy.moment for order 5.")
npt.assert_equal(mu_n(self.array, order=6),
stats.moment(self.array, moment=6),
err_msg="Does not match scipy.moment for order 6.")

# Test exceptions
npt.assert_raises(Exception, mu_n, self.array, order=0)

def test_kappa_n(self):

# Test output value
npt.assert_equal(kappa_n(self.array, order=2), 10,
err_msg="Incorrect mu_n for order 2.")
npt.assert_equal(kappa_n(self.array, order=3), 0,
err_msg="Incorrect mu_n for order 3.")
npt.assert_equal(kappa_n(self.array, order=4), -122,
err_msg="Incorrect mu_n for order 4.")
npt.assert_equal(kappa_n(self.array, order=5), 0,
err_msg="Incorrect mu_n for order 5.")
npt.assert_equal(kappa_n(self.array, order=6), 7030,
err_msg="Incorrect mu_n for order 6.")

# Test exceptions
npt.assert_raises(Exception, kappa_n, self.array, order=1)

0 comments on commit cef8d9a

Please sign in to comment.