From 6b7fe5bceebdaffd309d3b6baa2c54e4f748c2a6 Mon Sep 17 00:00:00 2001 From: Alan Loh Date: Wed, 25 Oct 2023 14:34:52 +0200 Subject: [PATCH] TF : dreambeam, defaraday, dedisperse --- nenupy/__init__.py | 2 +- nenupy/astro/astro_tools.py | 5 + nenupy/astro/beam_correction.py | 45 ++--- nenupy/io/tf.py | 234 ++++++++++++-------------- nenupy/io/tf_utils.py | 284 +++++++++++++++++++++++++++++++- 5 files changed, 418 insertions(+), 152 deletions(-) diff --git a/nenupy/__init__.py b/nenupy/__init__.py index 45b8705..010b296 100644 --- a/nenupy/__init__.py +++ b/nenupy/__init__.py @@ -5,7 +5,7 @@ __copyright__ = "Copyright 2023, nenupy" __credits__ = ["Alan Loh"] __license__ = "MIT" -__version__ = "2.4.3" +__version__ = "2.4.4" __maintainer__ = "Alan Loh" __email__ = "alan.loh@obspm.fr" diff --git a/nenupy/astro/astro_tools.py b/nenupy/astro/astro_tools.py index 7f67c14..5c6bc67 100644 --- a/nenupy/astro/astro_tools.py +++ b/nenupy/astro/astro_tools.py @@ -1204,6 +1204,11 @@ def faraday_angle(frequency: u.Quantity, rotation_measure: u.Quantity, inverse: :rtype: :class:`~astropy.units.Quantity` """ + # Check that the unit is correct + try: + rotation_measure = rotation_measure.to(u.rad/u.m**2) + except: + raise wavelength = frequency.to(u.m, equivalencies=u.spectral()) angle = rotation_measure * wavelength**2 if inverse: diff --git a/nenupy/astro/beam_correction.py b/nenupy/astro/beam_correction.py index 8c3f94d..e7f6f74 100644 --- a/nenupy/astro/beam_correction.py +++ b/nenupy/astro/beam_correction.py @@ -20,13 +20,9 @@ try: from dreambeam.rime.scenarios import on_pointing_axis_tracking -except ImportError: - log.error( - "DreamBeam is not installed. " - "See installation instructions https://github.com/2baOrNot2ba/dreamBeam" - ) - raise - +except ModuleNotFoundError: + # This will raise an error eventually with an appropriate message + pass # ============================================================= # # ------------------ compute_jones_matrices ------------------- # @@ -39,18 +35,29 @@ def compute_jones_matrices( ) -> Tuple[Time, u.Quantity, JonesMatrix]: """ """ - log.info("Computing Jones matrices using DreamBeam...") - times, frequencies, Jn, _ = on_pointing_axis_tracking( - telescopename="NenuFAR", - stnid="NenuFAR", - band="LBA", - antmodel="Hamaker-NEC4_Charrier_v1r1", - obstimebeg=start_time.datetime, - obsdur=duration.datetime, - obstimestp=time_step.datetime, - pointingdir=(skycoord.ra.rad, skycoord.dec.rad, "J2000"), - do_parallactic_rot=parallactic - ) + log.info("\tComputing Jones matrices using DreamBeam...") + + if time_step.sec <= 1.: + raise ValueError("DreamBeam does not allow for time intervals lesser than 1 sec.") + + try: + times, frequencies, Jn, _ = on_pointing_axis_tracking( + telescopename="NenuFAR", + stnid="NenuFAR", + band="LBA", + antmodel="Hamaker-NEC4_Charrier_v1r1", + obstimebeg=start_time.datetime, + obsdur=duration.datetime, + obstimestp=time_step.datetime, + pointingdir=(skycoord.ra.rad, skycoord.dec.rad, "J2000"), + do_parallactic_rot=parallactic + ) + except NameError: + log.error( + "DreamBeam is not installed. " + "See installation instructions https://github.com/2baOrNot2ba/dreamBeam" + ) + raise # import numpy as np # times = start_time + TimeDelta(3600, format="sec")*np.arange(12) # frequencies = np.array([30e6, 50e6])*u.MHz diff --git a/nenupy/io/tf.py b/nenupy/io/tf.py index ba00cd4..b769feb 100644 --- a/nenupy/io/tf.py +++ b/nenupy/io/tf.py @@ -12,6 +12,7 @@ from dask.diagnostics import ProgressBar import astropy.units as u from astropy.time import Time +from astropy.coordinates import SkyCoord from typing import Union, Tuple, List import logging log = logging.getLogger(__name__) @@ -55,11 +56,11 @@ def __init__(self, self.frequency_range = [frequency_min.to_value(u.Hz), frequency_max.to_value(u.Hz)]*u.Hz self.beam = 0 self.dispersion_measure = None - # self.rotation_measure = None + self.rotation_measure = None self.rebin_dt = None self.rebin_df = None # self.jump_correction = False - # self.dreambeam_inputs = (None, None, None) + self.dreambeam = (None, None, None) self.correct_bandpass = True self.edge_channels_to_remove = 0 @@ -76,7 +77,7 @@ def beam(self, selected_beam: int) -> None: elif selected_beam not in self._available_beams: raise IndexError(f"Requested beam #{selected_beam} not found among available beam indices {self.available_beams}.") self._beam = selected_beam - log.info(f"Beam #{self._beam} selected.") + log.info(f"\tBeam #{self._beam} selected.") @property def time_range(self) -> Time: @@ -92,7 +93,7 @@ def time_range(self, selected_range: Time): if (selected_range[1] < self._time_min) or (selected_range[0] > self._time_max): log.warning("Requested time_range outside availaible data!") self._time_range = selected_range - log.info(f"Time range set: {selected_range[0].isot} to {selected_range[1].isot}") + log.info(f"\tTime range: {selected_range[0].isot} to {selected_range[1].isot}") @property def frequency_range(self) -> u.Quantity: @@ -108,7 +109,7 @@ def frequency_range(self, selected_range: u.Quantity): if (selected_range[1] < self._frequency_min) or (selected_range[0] > self._frequency_max): log.warning("Requested time_range outside availaible data!") self._frequency_range = selected_range - log.info(f"Frequency range set: {selected_range[0].to(u.MHz)} to {selected_range[1].to(u.MHz)}") + log.info(f"\tFrequency range: {selected_range[0].to(u.MHz)} to {selected_range[1].to(u.MHz)}") @property def edge_channels_to_remove(self) -> Union[int, Tuple[int, int]]: @@ -123,7 +124,7 @@ def edge_channels_to_remove(self, channels: Union[int, Tuple[int, int]]): elif not isinstance(channels, int): raise TypeError("Edge channels to remove muste be integers.") self._edge_channels_to_remove = channels - log.info("Edge channels to remove set: {channels}.") + log.info(f"\tEdge channels to remove set: {channels}.") @property def correct_bandpass(self) -> bool: @@ -133,14 +134,56 @@ def correct_bandpass(self, correct: bool): if not isinstance(correct, bool): raise TypeError("test") self._correct_bandpass = correct - log.info(f"Bandpass correction set: {correct}.") + log.info(f"\tBandpass correction: {correct}.") + + @property + def dreambeam(self) -> Tuple[float, SkyCoord, bool]: + return self._dreambeam + @dreambeam.setter + def dreambeam(self, db_inputs: Tuple[float, SkyCoord, bool]): + if len(db_inputs) != 3: + raise IndexError("dreambeam inputs should be a length 3 tuple.") + if db_inputs != (None,)*3: + if not isinstance(db_inputs[0], float): + raise TypeError("First element of dreambeam must be the time resolution in sec (float).") + if not isinstance(db_inputs[1], SkyCoord): + raise TypeError("Second element of dreambeam must be the tracked coordinates (astropy.SkyCoord).") + if not isinstance(db_inputs[2], bool): + raise TypeError("Third element of dreambeam must be the parallactic angle correction (bool).") + log.info(f"\tDreamBeam correction set (time_res, coord, parallactic)={db_inputs}") + self._dreambeam = db_inputs + + @property + def rotation_measure(self) -> u.Quantity: + return self._rotation_measure + @rotation_measure.setter + def rotation_measure(self, rm: u.Quantity): + if rm is None: + pass + elif isinstance(rm, u.Quantity): + log.info(f"\tRotation Measure set: {rm}") + else: + raise TypeError("RM should be an astropy.units.Quantity object.") + self._rotation_measure = rm + + @property + def dispersion_measure(self) -> u.Quantity: + return self._dispersion_measure + @dispersion_measure.setter + def dispersion_measure(self, dm: u.Quantity): + if dm is None: + pass + elif isinstance(dm, u.Quantity): + log.info(f"\tDispersion Measure set: {dm}") + else: + raise TypeError("DM should be an astropy.units.Quantity object.") + self._dispersion_measure = dm # ============================================================= # # -------------------------- Spectra -------------------------- # class Spectra: def __init__(self, filename: str): - log.info(f"\tReading {filename}...") self.filename = filename # Decode the main header and lazy load the data @@ -154,20 +197,9 @@ def __init__(self, filename: str): # Compute the boolean mask of bad blocks bad_block_mask = self._get_bad_data_mask(data) + # Compute the main data block descriptors (time / frequency / beam) self._block_start_unix = data["TIMESTAMP"][~bad_block_mask] + data["BLOCKSEQNUMBER"][~bad_block_mask] / SUBBAND_WIDTH.to_value(u.Hz) self._subband_start_hz = data["data"]["channel"][0, :] * SUBBAND_WIDTH.to_value(u.Hz) # Assumed constant over time - # Compute the frequency, time and beam axes - # self.frequency_hz = utils.compute_spectra_frequencies( - # n_channels=self.n_channels, - # n_subbands=self.n_subbands, - # frequency_step_hz=self.df.to_value(u.Hz), - # channel_indices=data["data"]["channel"][0, :] # Assumed constant over time - # ) - # self.time_unix = utils.compute_spectra_time( - # block_start_time_unix=data["TIMESTAMP"][~bad_block_mask] + data["BLOCKSEQNUMBER"][~bad_block_mask] / SUBBAND_WIDTH.to_value(u.Hz), - # ntime_per_block=self._n_time_per_block, - # time_step_s=self.dt.to_value(u.s) - # ) self.beam_indices_dict = utils.sort_beam_edges( beam_array=data["data"]["beam"][0], # Asummed same for all time step n_channels=self.n_channels, @@ -176,7 +208,7 @@ def __init__(self, filename: str): # Transform the data in Dask Array, once correctly reshaped self.data = self._assemble_to_tf(data=data, mask=bad_block_mask) - log.info("\tSetting up default configuration:") + log.info("Setting up default configuration:") self.configuration = _ProcessingConfiguration( time_min=self.time_min, time_max=self.time_max, @@ -187,6 +219,16 @@ def __init__(self, filename: str): # --------------------------------------------------------- # # --------------------- Getter/Setter --------------------- # + @property + def filename(self) -> str: + return self._filename + @filename.setter + def filename(self, name: str): + log.info(f"Reading {name}...") + if not name.endswith(".spectra"): + raise ValueError("A file whose extension is '.spectra' is expected.") + self._filename = name + @property def time_min(self) -> Time: return Time(self._block_start_unix[0], format="unix",precision=7) @@ -225,9 +267,10 @@ def info(self) -> None: f"df: {self.df.to(u.kHz)}", f"Available beam indices: {list(self.beam_indices_dict.keys())}" ]) - print(message) + print(message) def get(self, stokes: Union[str, List[str]] = "I"): + """ """ # Select the data in time and frequency (the beam selection is implicit on the frequency idexing) frequency_hz, time_unix, data = self._select_data() @@ -238,18 +281,51 @@ def get(self, stokes: Union[str, List[str]] = "I"): # Correct for the bandpass if self.configuration.correct_bandpass: - data = self._correct_bandpass(data=data, n_channels=self.n_channels) + data = utils.correct_bandpass(data=data, n_channels=self.n_channels) # Remove subband edge channels edge_chans = self.configuration.edge_channels_to_remove if edge_chans not in [0, (0, 0)]: - data = self._remove_edge_channels( + data = utils.crop_subband_edges( data=data, n_channels=self.n_channels, lower_edge_channels=edge_chans[0] if isinstance(edge_chans, tuple) else edge_chans, higher_edge_channels=edge_chans[1] if isinstance(edge_chans, tuple) else edge_chans, ) + # DreamBeam correction (beam gain + parallactic angle) + db_dt, db_coord, db_par = self.configuration.dreambeam + if not ((db_dt is None) or (db_coord is None) or (db_par is None)): + data = utils.apply_dreambeam_corrections( + time_unix=time_unix, + frequency_hz=frequency_hz, + data=data, + dt_sec=self.dt.to_value(u.s), + time_step_sec=db_dt, + n_channels=self.n_channels, + skycoord=db_coord, + parallactic=db_par + ) + + # De-faraday + if not (self.configuration.rotation_measure is None): + data = utils.de_faraday_data( + data=data, + frequency=frequency_hz*u.Hz, + rotation_measure=self.configuration.rotation_measure + ) + + # De-disperse array + if not (self.configuration.dispersion_measure is None): + tmp_chuncks = data.chunks + data = utils.de_disperse_array( + data=data.compute(), # forced to leave Dask + frequencies=frequency_hz*u.Hz, + time_step=self.dt, + dispersion_measure=self.configuration.dispersion_measure, + ) + data = da.from_array(data, chunks=tmp_chuncks) + # Rebin the data if self.configuration.rebin_dt is not None: log.info("Rebinning in time...") @@ -327,7 +403,7 @@ def _lazy_load_data(self) -> np.ndarray: with open(self.filename, "rb") as rf: tmp = np.memmap(rf, dtype="int8", mode="r") - log.info(f"{self.filename} has been correctly parsed.") + log.info(f"\t{self.filename} has been correctly parsed.") return tmp.view(np.dtype(global_struct)) @@ -335,7 +411,7 @@ def _lazy_load_data(self) -> np.ndarray: def _get_bad_data_mask(data: np.ndarray) -> np.ndarray: """ """ - log.info("\tChecking for missing data...") + log.info("Checking for missing data...") # Either the TIMESTAMP is set to 0, the first idx, or the SB number is negative # which indicates missing data. In all those cases we will ignore the associated data @@ -348,7 +424,7 @@ def _get_bad_data_mask(data: np.ndarray) -> np.ndarray: block_start_idx_mask[0] = False # Fake value, just to trick the mask bad_block_mask = block_timestamp_mask + block_start_idx_mask + block_nsubbands_mask - log.info(f"There are {np.sum(bad_block_mask)}/{block_timestamp_mask.size} blocks containing missing data and/or wrong time information.") + log.info(f"\tThere are {np.sum(bad_block_mask)}/{block_timestamp_mask.size} blocks containing missing data and/or wrong time information.") return bad_block_mask @@ -447,107 +523,3 @@ def _select_data(self) -> Tuple[np.ndarray, np.ndarray, da.Array]: log.debug(f"Data of shape {selected_data.shape} selected.") return frequency_hz.compute(), time_unix.compute(), selected_data - - @staticmethod - def _correct_bandpass(data: da.Array, n_channels: int) -> da.Array: - """ """ - - log.info("Correcting for bandpass...") - - # Compute the bandpass - bandpass = utils.get_bandpass(n_channels=n_channels) - - # Reshape the data array to isolate individual subbands - n_times, n_freqs, _, _ = data.shape - data = data.reshape( - ( - n_times, - int(n_freqs / n_channels), # subband - n_channels, # channels - 2, 2 - ) - ) - - # Multiply the channels by the bandpass to correct them - data *= bandpass[None, None, :, None, None] - - log.debug(f"\tEach subband corrected by the bandpass of size {bandpass.size}.") - - # Re-reshape the data into time, frequency, (2, 2) array - return data.reshape((n_times, n_freqs, 2, 2)) - - @staticmethod - def _remove_edge_channels(data: da.Array, n_channels: int, lower_edge_channels: int, higher_edge_channels: int) -> da.Array: - """ """ - - log.info("Removing edge channels...") - - # Reshape the data array to isolate individual subbands - n_times, n_freqs, _, _ = data.shape - data = data.reshape( - ( - n_times, - int(n_freqs / n_channels), # subband - n_channels, # channels - 2, 2 - ) - ) - - # Set to NaN edge channels - data[:, :, : lower_edge_channels, :, :] = np.nan # lower edge - data[:, :, n_channels - higher_edge_channels :, :] = np.nan # upper edge - data = data.reshape((n_times, n_freqs, 2, 2)) - - log.info( - f"\t{lower_edge_channels} lower and {higher_edge_channels} higher " - "band channels have been set to NaN at the subband edges." - ) - - return data - - def _time_frequency_rebin(self, data: da.Array, times: da.Array, freqs: da.Array) -> Tuple[da.Array, da.Array, da.Array]: - """ data: (time, frequency, ...) - .. versionadded:: 1.1.0 - """ - - ntimes_i, nfreqs_i, npols_i = data.shape - - if not (self.configuration.rebin_dt is None): - # Rebin in time - tbins = int(np.floor(self.configuration.rebin_dt / self.dt)) - log.info(f"Time-averaging {tbins} spectra, dt={tbins*self.dt}...") - ntimes = int(np.floor(ntimes_i / tbins)) - tleftover = ntimes_i % ntimes - log.info(f"Last {tleftover} spectra are left over for time-averaging.") - data = data[: -tleftover if tleftover != 0 else ntimes_i, :, :].reshape( - (ntimes, int((ntimes_i - tleftover) / ntimes), nfreqs_i, npols_i) - ) - times = times[: -tleftover if tleftover != 0 else ntimes_i].reshape( - (ntimes, int((ntimes_i - tleftover) / ntimes)) - ) - data = np.nanmean(data, axis=1) - times = np.nanmean(times, axis=1) - ntimes_i, nfreqs_i, npols_i = data.shape - log.info("Data are time-averaged.") - - if not (self.configuration.rebin_df is None): - # Rebin in frequency - fbins = int(np.floor(self.configuration.rebin_df / self.df)) - log.info(f"Frequency-averaging {fbins} channels: df={fbins*self.df}...") - nfreqs = int(np.floor(nfreqs_i / fbins)) - fleftover = nfreqs_i % nfreqs - log.info( - f"Last {fleftover} channels are left over for frequency-averaging." - ) - data = data[:, : -fleftover if fleftover != 0 else nfreqs_i, :].reshape( - (ntimes_i, nfreqs, int((nfreqs_i - fleftover) / nfreqs), npols_i) - ) - freqs = freqs[: -fleftover if fleftover != 0 else nfreqs_i].reshape( - (nfreqs, int((nfreqs_i - fleftover) / nfreqs)) - ) - data = np.nanmean(data, axis=2) - freqs = np.nanmean(freqs, axis=1) - log.info("Data are frequency-averaged.") - - return freqs, times, data - diff --git a/nenupy/io/tf_utils.py b/nenupy/io/tf_utils.py index 03b866c..48b9ee1 100644 --- a/nenupy/io/tf_utils.py +++ b/nenupy/io/tf_utils.py @@ -6,22 +6,130 @@ import os import dask.array as da import astropy.units as u -from astropy.time import Time +from astropy.coordinates import SkyCoord +from astropy.time import Time, TimeDelta from typing import Union, List, Tuple import logging log = logging.getLogger(__name__) +from nenupy.astro import dispersion_delay, faraday_angle +from nenupy.astro.beam_correction import compute_jones_matrices __all__ = [ "blocks_to_tf_data", "compute_spectra_frequencies", "compute_spectra_time", "compute_stokes_parameters", + "correct_bandpass", + "crop_subband_edges", + "de_disperse_array", + "de_faraday_data", + "get_bandpass", + "polarization_angle", "rebin_along_dimension", "sort_beam_edges", "spectra_data_to_matrix" ] +def apply_dreambeam_corrections( + time_unix: np.ndarray, + frequency_hz: np.ndarray, + data: np.ndarray, + dt_sec: float, + time_step_sec: float, + n_channels: int, + skycoord: SkyCoord, + parallactic: bool = True + ) -> np.ndarray: + """ """ + + log.info("Applying DreamBeam corrections...") + + # Basic checks to make sure the dimensions are correct + freq_size = frequency_hz.size + time_size = time_unix.size + if time_size != data.shape[0]: + raise ValueError("There is a problem in the time dimension!") + if (freq_size != data.shape[1]) or (freq_size % n_channels != 0): + raise ValueError("There is a problem in the frequency dimension!") + n_subbands = int(freq_size/n_channels) + + # Compute the number of time samples that will be corrected together + time_group_size = int(np.round(time_step_sec/dt_sec)) + log.debug(f"\tGroups of {time_group_size} time blocks will be corrected altogether ({dt_sec*time_group_size} sec resolution).") + n_time_groups = time_size // time_group_size + leftover_time_samples = time_size % time_group_size + + # Computing DreamBeam matrices + db_time, db_frequency, db_jones = compute_jones_matrices( + start_time=Time(time_unix[0], format="unix", precision=7), + time_step=TimeDelta(time_group_size * dt_sec, format="sec"), + duration=TimeDelta(time_unix[-1] - time_unix[0], format="sec"), + skycoord=skycoord, + parallactic=parallactic + ) + db_time = db_time.unix + db_frequency = db_frequency.to_value(u.Hz) + db_jones = np.swapaxes(db_jones, 0, 1) + + # Invert the matrices that will be used to correct the observed signals + # Jones matrices are at the subband resolution and an arbitrary time resolution + jones_matrices = np.linalg.inv(db_jones) + + # Reshape the data at the time and frequency resolutions + # Take into account leftover times + data = data[: time_size - leftover_time_samples, ...].reshape( + ( + n_time_groups, + time_group_size, + n_subbands, + n_channels, + 2, 2 + ) + ) + data_leftover = data[-leftover_time_samples:, ...].reshape( + ( + leftover_time_samples, + n_subbands, + n_channels, + 2, 2 + ) + ) + + # Compute the frequency indices to select the corresponding Jones matrices + subband_start_frequencies = frequency_hz.reshape((n_subbands, n_channels))[:, 0] + freq_start_idx = np.argmax(db_frequency >= subband_start_frequencies[0]) + freq_stop_idx = db_frequency.size - np.argmax(db_frequency[::-1] < subband_start_frequencies[-1]) + + # Do the same with the time + group_start_time = time_unix[: time_size - leftover_time_samples].reshape((n_time_groups, time_group_size))[:, 0] + time_start_idx = np.argmax(db_time >= group_start_time[0]) + time_stop_idx = db_time.size - np.argmax(db_time[::-1] < group_start_time[-1]) + + jones = jones_matrices[time_start_idx:time_stop_idx + 1, freq_start_idx:freq_stop_idx + 1, :, :][:, None, :, None, :, :] + jones_leftover = jones_matrices[-1, freq_start_idx:freq_stop_idx + 1, :, :][None, :, None, :, :] + + # Compute the Hermitian matrices + jones_transpose = np.swapaxes(jones, -2, -1) + jones_leftover_transpose = np.swapaxes(jones_leftover, -2, -1) + jones_hermitian = np.conjugate(jones_transpose) + jones_leftover_hermitian = np.conjugate(jones_leftover_transpose) + + # This would raise an indexerror if jones_values are at smaller t/f range than data + return np.concatenate( + ( + np.matmul( + jones, + np.matmul(data, jones_hermitian) + ).reshape((time_size - leftover_time_samples, freq_size, 2, 2)), + np.matmul( + jones_leftover, + np.matmul(data_leftover, jones_leftover_hermitian) + ).reshape((leftover_time_samples, freq_size, 2, 2)) + ), + axis=0 + ) + # ============================================================= # # --------------------- blocks_to_tf_data --------------------- # def blocks_to_tf_data(data: da.Array, n_block_times: int, n_channels: int) -> da.Array: @@ -114,6 +222,12 @@ def compute_stokes_parameters(data_array: np.ndarray, stokes: Union[List[str], s data_i = data_array[..., 0, 1].real * 2 elif stokes_i.upper() == "V": data_i = data_array[..., 0, 1].imag * 2 + elif stokes_i.upper() == "Q/I": + data_i = (data_array[..., 0, 0].real - data_array[..., 1, 1].real)/(data_array[..., 0, 0].real + data_array[..., 1, 1].real) + elif stokes_i.upper() == "U/I": + data_i = data_array[..., 0, 1].real * 2 / (data_array[..., 0, 0].real + data_array[..., 1, 1].real) + elif stokes_i.upper() == "V/I": + data_i = data_array[..., 0, 1].imag * 2 / (data_array[..., 0, 0].real + data_array[..., 1, 1].real) else: raise NotImplementedError(f"Stokes parameter {stokes_i} unknown.") @@ -127,6 +241,166 @@ def compute_stokes_parameters(data_array: np.ndarray, stokes: Union[List[str], s return result +# ============================================================= # +# --------------------- correct_bandpass ---------------------- # +def correct_bandpass(data: np.ndarray, n_channels: int) -> np.ndarray: + """ """ + + log.info("Correcting for bandpass...") + + # Compute the bandpass + bandpass = get_bandpass(n_channels=n_channels) + + # Reshape the data array to isolate individual subbands + n_times, n_freqs, _, _ = data.shape + if n_freqs % n_channels != 0: + raise ValueError("The frequency dimension of `data` doesn't match the argument `n_channels`.") + data = data.reshape( + ( + n_times, + int(n_freqs / n_channels), # subband + n_channels, # channels + 2, 2 + ) + ) + + # Multiply the channels by the bandpass to correct them + data *= bandpass[None, None, :, None, None] + + log.debug(f"\tEach subband corrected by the bandpass of size {bandpass.size}.") + + # Re-reshape the data into time, frequency, (2, 2) array + return data.reshape((n_times, n_freqs, 2, 2)) + +# ============================================================= # +# -------------------- crop_subband_edges --------------------- # +def crop_subband_edges(data: np.ndarray, n_channels: int, lower_edge_channels: int, higher_edge_channels: int) -> np.ndarray: + """ """ + + log.info("Removing edge channels...") + + if lower_edge_channels + higher_edge_channels >= n_channels: + raise ValueError(f"{lower_edge_channels + higher_edge_channels} channels to crop out of {n_channels} channels subbands.") + + # Reshape the data array to isolate individual subbands + n_times, n_freqs, _, _ = data.shape + if n_freqs % n_channels != 0: + raise ValueError("The frequency dimension of `data` doesn't match the argument `n_channels`.") + + data = data.reshape( + ( + n_times, + int(n_freqs / n_channels), # subband + n_channels, # channels + 2, 2 + ) + ) + + # Set to NaN edge channels + data[:, :, : lower_edge_channels, :, :] = np.nan # lower edge + data[:, :, n_channels - higher_edge_channels :, :] = np.nan # upper edge + data = data.reshape((n_times, n_freqs, 2, 2)) + + log.info( + f"\t{lower_edge_channels} lower and {higher_edge_channels} higher " + "band channels have been set to NaN at the subband edges." + ) + + return data + +# ============================================================= # +# --------------------- de_disperse_array --------------------- # +def de_disperse_array( + data: np.ndarray, + frequencies: u.Quantity, + time_step: u.Quantity, + dispersion_measure: u.Quantity, +) -> np.ndarray: + """De-disperse in time an array ``data`` whose first two + dimensions are time and frequency respectively. The array + must be regularly sampled in time. The de-dispersion is made + relatively to the highest frequency. De-dedispersed array + is filled with ``NaN`` in time-frequency places where the + shifted values were. + + :param data: + Data array to de-disperse. + :type data: + :class:`~numpy.ndarray` + :param frequencies: + 1D array of frequencies corresponding to the second + dimension of ``data``. + :type frequencies: + :class:`~astropy.units.Quantity` + :param time_step: + Time step between two spectra. + :type time_step: + :class:`~astropy.units.Quantity` + :param dispersion_measure: + Dispersion Measure (in pc/cm3). + :type dispersion_measure: + :class:`~astropy.units.Quantity` + """ + + log.info("De-dispersing data...") + + if data.ndim < 2: + raise Exception( + f"Input data is {data.shape}. >2D array is required " + "(time, frequency, ...)." + ) + if data.shape[1] != frequencies.size: + raise ValueError( + f"The size of frequencies ({frequencies.size}) does " + f"not match dimension 1 of data ({data.shape[1]})." + ) + + # Compute the relative delays + delays = dispersion_delay( + frequency=frequencies, dispersion_measure=dispersion_measure + ) + delays -= dispersion_delay( + frequency=frequencies.max(), dispersion_measure=dispersion_measure + ) + + # Convert the delays into indices + cell_delays = np.round((delays / time_step).decompose().to_value()).astype(int) + + # Shift the array in time + for i in range(frequencies.size): + data[:, i, ...] = np.roll(data[:, i, ...], -cell_delays[i], 0) + # # Mask right edge of dynspec + data[-cell_delays[i] :, i, ...] = np.nan + + return data + +# ============================================================= # +# ---------------------- de_faraday_data ---------------------- # +def de_faraday_data(data: np.ndarray, frequency: u.Quantity, rotation_measure: u.Quantity) -> np.ndarray: + """ """ + + log.info("Correcting for Faraday rotation...") + + # Check the dimensions + if (data.ndim != 4) or (data.shape[1:] != (frequency.size, 2, 2)): + raise Exception("Wrong data dimensions!") + + # Computing the Faraday angles compared to infinite frequency + log.info(f"\tComputing {frequency.size} Faraday rotation angles at the RM={rotation_measure}...") + rotation_angle = faraday_angle( + frequency=frequency, + rotation_measure=rotation_measure, + inverse=True + ).to_value(u.rad) + + log.info("\tApplying Faraday rotation Jones matrices...") + cosa = np.cos(rotation_angle) + sina = np.sin(rotation_angle) + jones = np.transpose(np.array([[cosa, -sina], [sina, cosa]]), (2, 0, 1)) + jones_transpose = np.transpose(jones, (0, 2, 1)) + + return np.matmul(np.matmul(jones, data), jones_transpose) + # ============================================================= # # ----------------------- get_bandpass ------------------------ # def get_bandpass(n_channels: int) -> np.ndarray: @@ -151,6 +425,14 @@ def get_bandpass(n_channels: int) -> np.ndarray: return g**2.0 +# ============================================================= # +# -------------------- polarization_angle --------------------- # +def polarization_angle(stokes_u: np.ndarray, stokes_q: np.ndarray) -> np.ndarray: + """ """ + return 0.5 * np.arctan(stokes_u / stokes_q) + +# ============================================================= # +# ------------------- rebin_along_dimension ------------------- # def rebin_along_dimension(data: np.ndarray, axis_array: np.ndarray, axis: int, dx: float, new_dx: float) -> Tuple[np.ndarray, np.ndarray]: """ """