diff --git a/fitsio/tests/test_image_compression.py b/fitsio/tests/test_image_compression.py index 6367c2a..23c8eac 100644 --- a/fitsio/tests/test_image_compression.py +++ b/fitsio/tests/test_image_compression.py @@ -230,71 +230,90 @@ def test_compress_preserve_zeros(): ] ) @pytest.mark.parametrize( - 'use_seed', + 'match_seed', [False, True], ) @pytest.mark.parametrize( 'use_fits_object', [False, True], ) -def test_compressed_seed(compress, use_seed, use_fits_object): +@pytest.mark.parametrize( + 'dtype', + ['f4', 'f8'], +) +def test_compressed_seed(compress, match_seed, use_fits_object, dtype): """ Test writing and reading a rice compressed image """ nrows = 5 ncols = 20 - dtypes = ['f4', 'f8'] qlevel = 16 seed = 1919 rng = np.random.RandomState(seed) - if use_seed: - dither_seed = 9881 + if match_seed: + # dither_seed = 9881 + dither_seed1 = 9881 + dither_seed2 = 9881 else: - dither_seed = None + # dither_seed = None + dither_seed1 = 3 + dither_seed2 = 4 with tempfile.TemporaryDirectory() as tmpdir: fname1 = os.path.join(tmpdir, 'test1.fits') fname2 = os.path.join(tmpdir, 'test2.fits') - for ext, dtype in enumerate(dtypes): - data = rng.normal(size=(nrows, ncols)) - if compress == 'plio': - data = data.clip(min=0) - data = data.astype(dtype) - - if use_fits_object: - with FITS(fname1, 'rw') as fits1: - fits1.write( - data, compress=compress, qlevel=qlevel, - dither_seed=dither_seed, - ) - rdata1 = fits1[-1].read() - - with FITS(fname2, 'rw') as fits2: - fits2.write( - data, compress=compress, qlevel=qlevel, - dither_seed=dither_seed, - ) - rdata2 = fits2[-1].read() - else: - write( - fname1, data, compress=compress, qlevel=qlevel, - dither_seed=dither_seed, + data = rng.normal(size=(nrows, ncols)) + if compress == 'plio': + data = data.clip(min=0) + data = data.astype(dtype) + + if use_fits_object: + with FITS(fname1, 'rw') as fits1: + fits1.write( + data, compress=compress, qlevel=qlevel, + # dither_seed=dither_seed, + dither_seed=dither_seed1, ) - rdata1 = read(fname1, ext=ext+1) + rdata1 = fits1[-1].read() - write( - fname2, data, compress=compress, qlevel=qlevel, - dither_seed=dither_seed, + with FITS(fname2, 'rw') as fits2: + fits2.write( + data, compress=compress, qlevel=qlevel, + # dither_seed=dither_seed, + dither_seed=dither_seed2, ) - rdata2 = read(fname2, ext=ext+1) - - mess = "%s compressed images ('%s')" % (compress, dtype) - - if use_seed: - assert np.all(rdata1 == rdata2), mess - else: - assert np.all(rdata1 != rdata2), mess + rdata2 = fits2[-1].read() + else: + write( + fname1, data, compress=compress, qlevel=qlevel, + # dither_seed=dither_seed, + dither_seed=dither_seed1, + ) + rdata1 = read(fname1) + + write( + fname2, data, compress=compress, qlevel=qlevel, + # dither_seed=dither_seed, + dither_seed=dither_seed2, + ) + rdata2 = read(fname2) + + mess = "%s compressed images ('%s')" % (compress, dtype) + + if match_seed: + assert np.all(rdata1 == rdata2), mess + else: + assert np.all(rdata1 != rdata2), mess + + +if __name__ == '__main__': + test_compressed_seed( + compress='rice', + match_seed=False, + use_fits_object=True, + dtype='f4', + )