Skip to content

Commit

Permalink
fix tests for new numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
esheldon committed Aug 16, 2024
1 parent ce5f4d5 commit 3b04c9c
Showing 1 changed file with 61 additions and 42 deletions.
103 changes: 61 additions & 42 deletions fitsio/tests/test_image_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,71 +230,90 @@ def test_compress_preserve_zeros():
]
)
@pytest.mark.parametrize(
'use_seed',
'match_seed',
[False, True],
)
@pytest.mark.parametrize(
'use_fits_object',
[False, True],
)
def test_compressed_seed(compress, use_seed, use_fits_object):
@pytest.mark.parametrize(
'dtype',
['f4', 'f8'],
)
def test_compressed_seed(compress, match_seed, use_fits_object, dtype):
"""
Test writing and reading a rice compressed image
"""
nrows = 5
ncols = 20
dtypes = ['f4', 'f8']

qlevel = 16

seed = 1919
rng = np.random.RandomState(seed)

if use_seed:
dither_seed = 9881
if match_seed:
# dither_seed = 9881
dither_seed1 = 9881
dither_seed2 = 9881
else:
dither_seed = None
# dither_seed = None
dither_seed1 = 3
dither_seed2 = 4

with tempfile.TemporaryDirectory() as tmpdir:
fname1 = os.path.join(tmpdir, 'test1.fits')
fname2 = os.path.join(tmpdir, 'test2.fits')

for ext, dtype in enumerate(dtypes):
data = rng.normal(size=(nrows, ncols))
if compress == 'plio':
data = data.clip(min=0)
data = data.astype(dtype)

if use_fits_object:
with FITS(fname1, 'rw') as fits1:
fits1.write(
data, compress=compress, qlevel=qlevel,
dither_seed=dither_seed,
)
rdata1 = fits1[-1].read()

with FITS(fname2, 'rw') as fits2:
fits2.write(
data, compress=compress, qlevel=qlevel,
dither_seed=dither_seed,
)
rdata2 = fits2[-1].read()
else:
write(
fname1, data, compress=compress, qlevel=qlevel,
dither_seed=dither_seed,
data = rng.normal(size=(nrows, ncols))
if compress == 'plio':
data = data.clip(min=0)
data = data.astype(dtype)

if use_fits_object:
with FITS(fname1, 'rw') as fits1:
fits1.write(
data, compress=compress, qlevel=qlevel,
# dither_seed=dither_seed,
dither_seed=dither_seed1,
)
rdata1 = read(fname1, ext=ext+1)
rdata1 = fits1[-1].read()

write(
fname2, data, compress=compress, qlevel=qlevel,
dither_seed=dither_seed,
with FITS(fname2, 'rw') as fits2:
fits2.write(
data, compress=compress, qlevel=qlevel,
# dither_seed=dither_seed,
dither_seed=dither_seed2,
)
rdata2 = read(fname2, ext=ext+1)

mess = "%s compressed images ('%s')" % (compress, dtype)

if use_seed:
assert np.all(rdata1 == rdata2), mess
else:
assert np.all(rdata1 != rdata2), mess
rdata2 = fits2[-1].read()
else:
write(
fname1, data, compress=compress, qlevel=qlevel,
# dither_seed=dither_seed,
dither_seed=dither_seed1,
)
rdata1 = read(fname1)

write(
fname2, data, compress=compress, qlevel=qlevel,
# dither_seed=dither_seed,
dither_seed=dither_seed2,
)
rdata2 = read(fname2)

mess = "%s compressed images ('%s')" % (compress, dtype)

if match_seed:
assert np.all(rdata1 == rdata2), mess
else:
assert np.all(rdata1 != rdata2), mess


if __name__ == '__main__':
test_compressed_seed(
compress='rice',
match_seed=False,
use_fits_object=True,
dtype='f4',
)

0 comments on commit 3b04c9c

Please sign in to comment.