
Better log
AlanLoh committed Oct 23, 2023
1 parent 00e3084 commit 885a739
Showing 2 changed files with 65 additions and 27 deletions.
70 changes: 49 additions & 21 deletions nenupy/io/tf.py
@@ -34,7 +34,6 @@

SUBBAND_WIDTH = 195312.5 * u.Hz


# ============================================================= #
# ----------------- _ProcessingConfiguration ------------------ #
class _ProcessingConfiguration:
@@ -46,32 +45,35 @@ def __init__(self,
frequency_max: u.Quantity,
available_beams: np.ndarray
):
self.available_beams = available_beams
self.time_min = time_min
self.time_max = time_max
self.frequency_min = frequency_min
self.frequency_max = frequency_max
self._available_beams = available_beams
self._time_min = time_min
self._time_max = time_max
self._frequency_min = frequency_min
self._frequency_max = frequency_max

self.time_range = Time([time_min.isot, time_max.isot], precision=max(time_min.precision, time_max.precision))
self.frequency_range = [frequency_min.to_value(u.Hz), frequency_max.to_value(u.Hz)]*u.Hz
self.beam = 0
self.dispersion_measure = None
self.rotation_measure = None
# self.rotation_measure = None
self.rebin_dt = None
self.rebin_df = None
self.jump_correction = False
self.dreambeam_inputs = (None, None, None)
# self.jump_correction = False
# self.dreambeam_inputs = (None, None, None)
self.correct_bandpass = True
self.edge_channels_to_remove = 0

def __repr__(self) -> str:
return "\n".join([f"{attr}: {getattr(self, attr)}" for attr in dir(self) if not attr.startswith("_")])

@property
def beam(self) -> int:
return self._beam
@beam.setter
def beam(self, selected_beam: int) -> None:
if not isinstance(selected_beam, int):
raise TypeError("Selected beam is expected as an integer value.")
elif selected_beam not in self.available_beams:
elif selected_beam not in self._available_beams:
raise IndexError(f"Requested beam #{selected_beam} not found among available beam indices {self._available_beams}.")
self._beam = selected_beam
log.info(f"Beam #{self._beam} selected.")
@@ -87,7 +89,7 @@ def time_range(self, selected_range: Time):
raise ValueError("time_range should be a length-2 Time array.")
if selected_range[0] >= selected_range[1]:
raise ValueError("time_range start >= stop.")
if (selected_range[1] < self.time_min) or (selected_range[0] > self.time_max):
if (selected_range[1] < self._time_min) or (selected_range[0] > self._time_max):
log.warning("Requested time_range outside available data!")
self._time_range = selected_range
log.info(f"Time range set: {selected_range[0].isot} to {selected_range[1].isot}")
@@ -103,6 +105,8 @@ def frequency_range(self, selected_range: u.Quantity):
raise ValueError("frequency_range should be a length-2 Quantity array.")
if selected_range[0] >= selected_range[1]:
raise ValueError("frequency_range min >= max.")
if (selected_range[1] < self._frequency_min) or (selected_range[0] > self._frequency_max):
log.warning("Requested frequency_range outside available data!")
self._frequency_range = selected_range
log.info(f"Frequency range set: {selected_range[0].to(u.MHz)} to {selected_range[1].to(u.MHz)}")

@@ -115,10 +119,21 @@ def edge_channels_to_remove(self, channels: Union[int, Tuple[int, int]]):
if not len(channels) == 2:
raise IndexError("If a `tuple` is given to the edge_channels_to_remove argument, it must be of length 2: (lower_edge_channels_to_remove, higher_edge_channels_to_remove).")
elif not np.all([isinstance(chan, int) for chan in channels]):
raise TypeError("Edge channels to remove muste be integers.")
raise TypeError("Edge channels to remove must be integers.")
elif not isinstance(channels, int):
raise TypeError("Edge channels to remove must be integers.")
self._edge_channels_to_remove = channels
log.info(f"Edge channels to remove set: {channels}.")

@property
def correct_bandpass(self) -> bool:
return self._correct_bandpass
@correct_bandpass.setter
def correct_bandpass(self, correct: bool):
if not isinstance(correct, bool):
raise TypeError("correct_bandpass is expected as a boolean value.")
self._correct_bandpass = correct
log.info(f"Bandpass correction set: {correct}.")

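A minimal usage sketch of the configuration interface above, assuming a Spectra reader constructed from a file path (the path, the chosen beam index, and the time/frequency/rebinning values are placeholders, not taken from the repository):

from astropy.time import Time
import astropy.units as u
from nenupy.io.tf import Spectra

sp = Spectra("/path/to/file.spectra")  # hypothetical file path
sp.configuration.beam = 0  # must be one of the beam indices available in the file
sp.configuration.time_range = Time(["2023-10-23T00:00:00", "2023-10-23T01:00:00"])
sp.configuration.frequency_range = [20, 60] * u.MHz
sp.configuration.rebin_dt = 1 * u.s
sp.configuration.rebin_df = 50 * u.kHz
sp.configuration.edge_channels_to_remove = (2, 2)  # (lower, higher) channels per subband
sp.configuration.correct_bandpass = True
data = sp.get(stokes="I")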
# ============================================================= #
# -------------------------- Spectra -------------------------- #
@@ -201,7 +216,7 @@ def frequency_max(self) -> u.Quantity:
def info(self) -> None:
""" Display information about the file. """
message = "\n".join([
f"{self.filename}",
f"filename: {self.filename}",
f"time_min: {self.time_min.isot}",
f"time_max: {self.time_max.isot}",
f"dt: {self.dt.to(u.ms)}",
@@ -223,13 +238,11 @@ def get(self, stokes: Union[str, List[str]] = "I"):

# Correct for the bandpass
if self.configuration.correct_bandpass:
log.info("Correcting for bandpass.")
data = self._correct_bandpass(data=data, n_channels=self.n_channels)

# Remove subband edge channels
edge_chans = self.configuration.edge_channels_to_remove
if edge_chans not in [0, (0, 0)]:
log.info("\tRemoving edge channels...")
data = self._remove_edge_channels(
data=data,
n_channels=self.n_channels,
@@ -239,7 +252,7 @@ def get(self, stokes: Union[str, List[str]] = "I"):

# Rebin the data
if self.configuration.rebin_dt is not None:
log.info("\tRebinning in time...")
log.info("Rebinning in time...")
time_unix, data = utils.rebin_along_dimension(
data=data,
axis_array=time_unix,
@@ -248,21 +261,23 @@ def get(self, stokes: Union[str, List[str]] = "I"):
new_dx=self.configuration.rebin_dt.to_value(u.s)
)
if self.configuration.rebin_df is not None:
log.info("\tRebinning in frequency...")
log.info("Rebinning in frequency...")
frequency_hz, data = utils.rebin_along_dimension(
data=data,
axis_array=frequency_hz,
axis=1,
dx=self.df.to_value(u.Hz),
new_dx=self.configuration.rebin_df.to_value(u.Hz)
)
log.info(f"Shape of rebinned data: {data.shape}.")

# Compute the selected Stokes parameters
data = utils.compute_stokes_parameters(data_array=data, stokes=stokes)

log.info("Computing the data...")
with ProgressBar():
data = data.compute()
log.info(f"\tData of shape (time, frequency, stokes) = {data.shape} produced.")

return SData(
data=data,
@@ -362,23 +377,27 @@ def _assemble_to_tf(self, data: np.ndarray, mask: np.ndarray) -> da.Array:
def _select_data(self) -> Tuple[np.ndarray, np.ndarray, da.Array]:
""" """

log.info("\tComputing the time selection...")
log.info("Computing the time selection...")
tmin, tmax = self.configuration.time_range.unix

# Find out which block indices are at the edges of the desired time range
block_idx_min = int(np.argmin(np.abs(np.ceil(self._block_start_unix - tmin))))# n_blocks - np.argmax(((self._block_start_unix - tmin) <= 0)[::-1]) - 1
block_idx_max = int(np.argmin(np.abs(np.ceil(self._block_start_unix - tmax))))# n_blocks - np.argmax(((self._block_start_unix - tmax) <= 0)[::-1]) - 1
log.info(f"\tClosest time blocks from requested range are #{block_idx_min} and #{block_idx_max}.")

# Get the closest time index within each of the two bounding blocks
dt_sec = self.dt.to_value(u.s)
# Compute the time index within the block and bound it between 0 and the number of spectra in each block
time_idx_min_in_block = int(np.round((tmin - self._block_start_unix[block_idx_min])/dt_sec))
time_idx_min_in_block = max(0, min(self._n_time_per_block - 1, time_idx_min_in_block)) # bound the value between in between channels indices
time_idx_min = block_idx_min * self._n_time_per_block + time_idx_min_in_block
# Do the same for the higher edge of the desired time range
time_idx_max_in_block = int(np.round((tmax - self._block_start_unix[block_idx_max])/dt_sec))
time_idx_max_in_block = self._n_time_per_block - 1 if time_idx_max_in_block >= self._n_time_per_block else time_idx_max_in_block
time_idx_max_in_block = max(0, min(self._n_time_per_block - 1, time_idx_max_in_block))
time_idx_max = block_idx_max * self._n_time_per_block + time_idx_max_in_block
log.info(f"\t{time_idx_max - time_idx_min + 1} time samples selected.")

# Raise warnings if the time selection results in no data selected
if time_idx_min == time_idx_max:
if (time_idx_min > 0) and (time_idx_min < self._block_start_unix.size * self._n_time_per_block - 1):
log.warning("Desired time selection encompasses missing data.")
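A small self-contained sketch of the block/time index arithmetic above, with illustrative values (3 blocks of 100 spectra, dt = 0.1 s, blocks starting 10 s apart; none of these numbers come from a real file):

import numpy as np

block_start_unix = np.array([0.0, 10.0, 20.0])
n_time_per_block = 100
dt_sec = 0.1
tmin, tmax = 4.2, 15.7

# Closest starting block for each bound, same expression as in _select_data
block_idx_min = int(np.argmin(np.abs(np.ceil(block_start_unix - tmin))))  # -> 0
block_idx_max = int(np.argmin(np.abs(np.ceil(block_start_unix - tmax))))  # -> 1

# Index within each bounding block, clipped to the block size, then made global
time_idx_min_in_block = max(0, min(n_time_per_block - 1, int(np.round((tmin - block_start_unix[block_idx_min]) / dt_sec))))  # -> 42
time_idx_max_in_block = max(0, min(n_time_per_block - 1, int(np.round((tmax - block_start_unix[block_idx_max]) / dt_sec))))  # -> 57
time_idx_min = block_idx_min * n_time_per_block + time_idx_min_in_block  # -> 42
time_idx_max = block_idx_max * n_time_per_block + time_idx_max_in_block  # -> 157
# 157 - 42 + 1 = 116 time samples selected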
@@ -403,14 +422,15 @@ def _select_data(self) -> Tuple[np.ndarray, np.ndarray, da.Array]:
# Cut down the first and last time blocks
time_unix = time_unix[time_idx_min_in_block:time_unix.size - (self._n_time_per_block - time_idx_max_in_block) + 1]

log.info("\tComputing the frequency selection...")
log.info("Computing the frequency selection...")
fmin, fmax = self.configuration.frequency_range.to_value(u.Hz)
beam_idx_start, beam_idx_stop = self.beam_indices_dict[str(self.configuration.beam)]

# Find out the subband edges covering the selected frequency range
subbands_in_beam = self._subband_start_hz[int(beam_idx_start/self.n_channels):int((beam_idx_stop + 1)/self.n_channels)]
sb_idx_min = int(np.argmin(np.abs(np.ceil(subbands_in_beam - fmin))))
sb_idx_max = int(np.argmin(np.abs(np.ceil(subbands_in_beam - fmax))))
log.info(f"\tClosest beamlet indices from requested range are #{sb_idx_min} and #{sb_idx_max}.")

# Select frequencies at the subband granularity at minimum
# Later, we want to correct for bandpass, edge channels and so on...
@@ -421,15 +441,19 @@ def _select_data(self) -> Tuple[np.ndarray, np.ndarray, da.Array]:
n_channels=self.n_channels,
frequency_step_hz=self.df.to_value(u.Hz)
)
log.info(f"\t{frequency_idx_max - frequency_idx_min} frequency samples selected.")

selected_data = self.data[:, beam_idx_start:beam_idx_stop + 1, ...][time_idx_min:time_idx_max + 1, frequency_idx_min:frequency_idx_max, ...]
log.debug(f"Data of shape {selected_data.shape} selected.")

return frequency_hz.compute(), time_unix.compute(), selected_data

@staticmethod
def _correct_bandpass(data: da.Array, n_channels: int) -> da.Array:
""" """

log.info("Correcting for bandpass...")

# Compute the bandpass
bandpass = utils.get_bandpass(n_channels=n_channels)

@@ -447,13 +471,17 @@ def _correct_bandpass(data: da.Array, n_channels: int) -> da.Array:
# Multiply the channels by the bandpass to correct them
data *= bandpass[None, None, :, None, None]

log.debug(f"\tEach subband corrected by the bandpass of size {bandpass.size}.")

# Re-reshape the data into time, frequency, (2, 2) array
return data.reshape((n_times, n_freqs, 2, 2))

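A minimal NumPy sketch of the reshape-and-multiply step used by _correct_bandpass; the hand-made linear bandpass only stands in for utils.get_bandpass, whose actual values are not shown here:

import numpy as np

n_times, n_subbands, n_channels = 4, 2, 8
data = np.ones((n_times, n_subbands * n_channels, 2, 2))

# One weight per channel within a subband (placeholder for utils.get_bandpass)
bandpass = np.linspace(0.5, 1.5, n_channels)

# Isolate the subbands, weight each channel, then restore the (time, frequency, 2, 2) shape
data = data.reshape((n_times, n_subbands, n_channels, 2, 2))
data *= bandpass[None, None, :, None, None]
data = data.reshape((n_times, n_subbands * n_channels, 2, 2))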
@staticmethod
def _remove_edge_channels(data: da.Array, n_channels: int, lower_edge_channels: int, higher_edge_channels: int) -> da.Array:
""" """

log.info("Removing edge channels...")

# Reshape the data array to isolate individual subbands
n_times, n_freqs, _, _ = data.shape
data = data.reshape(
@@ -471,7 +499,7 @@ def _remove_edge_channels(data: da.Array, n_channels: int, lower_edge_channels:
data = data.reshape((n_times, n_freqs, 2, 2))

log.info(
f"{lower_edge_channels} lower and {higher_edge_channels} higher "
f"\t{lower_edge_channels} lower and {higher_edge_channels} higher "
"band channels have been set to NaN at the subband edges."
)

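The edge-channel blanking reuses the same subband reshape; a short illustrative sketch in which the NaN assignments reconstruct the collapsed lines and should be read as an assumption:

import numpy as np

n_times, n_subbands, n_channels = 4, 2, 8
lower_edge_channels, higher_edge_channels = 2, 1
data = np.ones((n_times, n_subbands * n_channels, 2, 2))

# Isolate the subbands and blank the first/last channels of each one
data = data.reshape((n_times, n_subbands, n_channels, 2, 2))
data[:, :, :lower_edge_channels, :, :] = np.nan
data[:, :, n_channels - higher_edge_channels:, :, :] = np.nan
data = data.reshape((n_times, n_subbands * n_channels, 2, 2))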
22 changes: 16 additions & 6 deletions nenupy/io/tf_utils.py
@@ -61,16 +61,14 @@ def blocks_to_tf_data(data: da.Array, n_block_times: int, n_channels: int) -> da
def compute_spectra_frequencies(subband_start_hz: np.ndarray, n_channels: int, frequency_step_hz: float) -> da.Array:
""" """

log.info("\tComputing the frequency ramp...")

# Construct the frequency array
frequencies = da.tile(np.arange(n_channels) - n_channels / 2, subband_start_hz.size)
frequencies = frequencies.reshape((subband_start_hz.size, n_channels))
frequencies *= frequency_step_hz
frequencies += subband_start_hz[:, None]
frequencies = frequencies.ravel()

log.info(f"Frequency axis computed (size={frequencies.size}).")
log.debug(f"\tFrequency axis computed (size={frequencies.size}).")

return frequencies

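A worked example of the frequency ramp construction above, in plain NumPy with illustrative values (two subbands of 4 channels, 25 kHz channel width):

import numpy as np

subband_start_hz = np.array([50.0e6, 50.1e6])
n_channels = 4
frequency_step_hz = 25.0e3

frequencies = np.tile(np.arange(n_channels) - n_channels / 2, subband_start_hz.size)
frequencies = frequencies.reshape((subband_start_hz.size, n_channels))
frequencies = frequencies * frequency_step_hz + subband_start_hz[:, None]
frequencies = frequencies.ravel()
# -> [49.95, 49.975, 50.0, 50.025, 50.05, 50.075, 50.1, 50.125] MHz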
@@ -79,8 +77,6 @@ def compute_spectra_time(block_start_time_unix: np.ndarray, ntime_per_block: int
def compute_spectra_time(block_start_time_unix: np.ndarray, ntime_per_block: int, time_step_s: float) -> da.Array:
""" """

log.info("\tComputing the time ramp...")

# Construct the elapsed time per block (1D array)
time_seconds_per_block = da.arange(ntime_per_block, dtype="float64") * time_step_s

@@ -90,7 +86,7 @@ def compute_spectra_time(block_start_time_unix: np.ndarray, ntime_per_block: int
# Return the flatten array
unix_time = unix_time.ravel()

log.info(f"Time axis computed (size={unix_time.size}).")
log.debug(f"\tTime axis computed (size={unix_time.size}).")

return unix_time

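The time ramp follows the same outer-sum pattern; a worked example with illustrative values (3 blocks starting 10 s apart, 4 spectra per block, dt = 2.5 s):

import numpy as np

block_start_time_unix = np.array([0.0, 10.0, 20.0])
ntime_per_block, time_step_s = 4, 2.5

time_seconds_per_block = np.arange(ntime_per_block, dtype="float64") * time_step_s
unix_time = (block_start_time_unix[:, None] + time_seconds_per_block[None, :]).ravel()
# -> [0.0, 2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5, 25.0, 27.5]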
@@ -99,9 +95,15 @@ def compute_stokes_parameters(data_array: np.ndarray, stokes: Union[List[str], s
def compute_stokes_parameters(data_array: np.ndarray, stokes: Union[List[str], str]) -> np.ndarray:
""" data_array: >2 D, last 2 dimensions are ((XX, XY), (YX, YY))
"""

log.info("Computing Stokes parameters...")

# Assert that the last dimensions are shaped like a cross correlation electric field matrix
if data_array.shape[-2:] != (2, 2):
raise Exception("The data_array last 2 dimensions are not of shape (2, 2).")

result = None

for stokes_i in stokes:
# Compute the correct Stokes value
if stokes_i.upper() == "I":
@@ -114,11 +116,15 @@ def compute_stokes_parameters(data_array: np.ndarray, stokes: Union[List[str], s
data_i = data_array[..., 0, 1].imag * 2
else:
raise NotImplementedError(f"Stokes parameter {stokes_i} unknown.")

log.info(f"\tStokes {stokes_i} computed.")

# Stack everything
if result is None:
result = np.expand_dims(data_i, axis=-1)
else:
result = np.concatenate([result, data_i[..., None]], axis=-1)

return result

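For reference, the usual Stokes combinations of the ((XX, XY), (YX, YY)) matrix, sketched in plain NumPy; the hunk only shows the 2*Im(XY) branch body, so the other expressions and the sign conventions below are standard assumptions rather than nenupy's exact choices:

import numpy as np

rng = np.random.default_rng(0)
e_field = rng.normal(size=(16, 8, 2)) + 1j * rng.normal(size=(16, 8, 2))
# Cross-correlation matrix whose last two axes are ((XX, XY), (YX, YY))
data_array = e_field[..., :, None] * e_field[..., None, :].conj()

stokes_i = data_array[..., 0, 0].real + data_array[..., 1, 1].real  # XX + YY
stokes_q = data_array[..., 0, 0].real - data_array[..., 1, 1].real  # XX - YY
stokes_u = data_array[..., 0, 1].real * 2                           # 2 Re(XY)
stokes_v = data_array[..., 0, 1].imag * 2                           # 2 Im(XY), sign convention assumed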
# ============================================================= #
Expand Down Expand Up @@ -163,6 +169,8 @@ def rebin_along_dimension(data: np.ndarray, axis_array: np.ndarray, axis: int, d

d_shape = data.shape

log.info(f"\tdx: {dx} | new_dx: {new_dx} -> rebin factor: {bin_size}.")

# Reshape the data and the axis to ease the averaging
data = data[tuple([slice(None) if i != axis else slice(None, initial_size - leftovers) for i in range(len(d_shape))])].reshape(
d_shape[:axis] + (final_size, int((initial_size - leftovers) / final_size)) + d_shape[axis + 1:]
@@ -175,6 +183,8 @@ def rebin_along_dimension(data: np.ndarray, axis_array: np.ndarray, axis: int, d
data = np.nanmean(data, axis=axis + 1)
axis_array = np.nanmean(axis_array, axis=1)

log.info(f"\tData rebinned, last {leftovers} samples were not considered.")

return axis_array, data

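A worked example of the rebinning above, with 1-D data for simplicity; the bin_size, final_size and leftovers definitions are reconstructed from the visible lines and are assumptions:

import numpy as np

dx, new_dx = 1.0, 4.0                              # rebin 1 s samples into 4 s bins
bin_size = int(np.floor(new_dx / dx))              # -> 4
axis_array = np.arange(10, dtype=float)
data = np.arange(10, dtype=float) ** 2

initial_size = axis_array.size
final_size = initial_size // bin_size              # -> 2 full bins
leftovers = initial_size - final_size * bin_size   # -> 2 trailing samples dropped

data = np.nanmean(data[:initial_size - leftovers].reshape(final_size, bin_size), axis=1)
axis_array = np.nanmean(axis_array[:initial_size - leftovers].reshape(final_size, bin_size), axis=1)
# data -> [3.5, 31.5], axis_array -> [1.5, 5.5]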
# ============================================================= #
