diff --git a/scintillometry/functions.py b/scintillometry/functions.py index 92f0ac21..c38c7963 100644 --- a/scintillometry/functions.py +++ b/scintillometry/functions.py @@ -115,3 +115,32 @@ def task(self, data): out[2] = c.real out[3] = c.imag return result + + +class Digitize(TaskBase): + """Digitize a stream to a given number of bits. + + Output values are between -(2**(bps-1)) and 2**(bps-1) -1. + For instance, between -8 and 7 for bps=4. + + Parameters + ---------- + ih : stream handle + Handle of a stream reader or another task. + nbits : int + Number of bits to digitize too. For complex data, the real + and imaginary components are digitized separately. + """ + def __init__(self, ih, bps, scale=1.): + super().__init__(ih) + self._low = - (1 << (bps-1)) + self._high = (1 << (bps-1)) - 1 + if self.complex_data: + real_dtype = np.zeros(1, self.dtype).real.dtype + self.task = lambda data: self._digitize( + data.view(real_dtype)).view(self.dtype) + else: + self.task = self._digitize + + def _digitize(self, data): + return np.clip(data.round(), self._low, self._high) diff --git a/scintillometry/tests/test_functions.py b/scintillometry/tests/test_functions.py index 2b563630..2cfb2195 100644 --- a/scintillometry/tests/test_functions.py +++ b/scintillometry/tests/test_functions.py @@ -6,8 +6,8 @@ import astropy.units as u from astropy.time import Time -from ..functions import Square, Power -from ..generators import EmptyStreamGenerator +from ..functions import Square, Power, Digitize +from ..generators import EmptyStreamGenerator, StreamGenerator from baseband import vdif, dada from baseband.data import SAMPLE_VDIF, SAMPLE_DADA @@ -167,3 +167,39 @@ def test_frequency_sideband_mismatch(self): frequency=frequency, sideband=bad_side) with pytest.raises(ValueError): Power(eh, polarization=polarization) + + +class TestDigitize: + def setup(self): + self.stream = StreamGenerator( + lambda stream: np.arange(-128, 128, 0.5), (256,), + sample_rate=1.*u.Hz, start_time=Time('2018-01-01'), + samples_per_frame=256) + + def test_basics(self): + fh = dada.open(SAMPLE_DADA) + ref_data = fh.read() + dig = Digitize(fh, bps=8) + data1 = dig.read() + assert np.all(data1 == ref_data) + + dig2 = Digitize(fh, bps=4) + data2 = dig2.read() + assert np.all(data2.real >= -16.) and np.all(data2.real <= 15.) + assert np.all(data2.imag >= -16.) and np.all(data2.imag <= 15.) + assert np.all(data2.real % 1. == 0.) + assert np.all(data2.imag % 1. == 0.) + expected = np.clip(ref_data.view(np.float32), + -8., 7.).view(np.complex64) + assert np.all(data2 == expected) + dig2.close() + + def test_stram(self): + ref_data = self.stream.read() + for bps in (1, 3, 5, 8): + r = 1 << (bps - 1) + expected = np.clip(ref_data, -r, r-1).round() + dig = Digitize(self.stream, bps=bps) + data = dig.read() + assert np.all(data == expected) + dig.close()