diff --git a/nenupy/io/tf.py b/nenupy/io/tf.py index 364099e..d58b89d 100644 --- a/nenupy/io/tf.py +++ b/nenupy/io/tf.py @@ -95,11 +95,12 @@ class TFTask: """ - def __init__(self, name: str, func: Callable, args_to_update: List[str] = []): + def __init__(self, name: str, func: Callable, args_to_update: List[str] = [], repeatable: bool = False): self.name = name self.is_active = False self._func = func self.args_to_update = args_to_update + self.repeatable = repeatable self._extra_params = {} def __repr__(self) -> str: @@ -118,7 +119,7 @@ def correct_bandpass(cls): def wrapper_task(data, channels): return utils.correct_bandpass(data=data, n_channels=channels) - return cls("Correct bandpass", wrapper_task, ["channels"]) + return cls("Correct bandpass", wrapper_task, ["channels"], repeatable=False) @classmethod def remove_channels(cls): @@ -132,7 +133,7 @@ def wrapper_task(data, channels, remove_channels): ) return cls( - "Remove subband channels", wrapper_task, ["channels", "remove_channels"] + "Remove subband channels", wrapper_task, ["channels", "remove_channels"], repeatable=False ) @classmethod @@ -176,6 +177,7 @@ def wrapper_task( "dreambeam_dt", "dreambeam_parallactic", ], + repeatable=False ) @classmethod @@ -191,7 +193,7 @@ def apply_faraday(frequency_hz, data, rotation_measure): rotation_measure=rotation_measure, ) - return cls("Correct faraday rotation", apply_faraday, ["rotation_measure"]) + return cls("Correct faraday rotation", apply_faraday, ["rotation_measure"], repeatable=False) @classmethod def de_disperse(cls): @@ -235,6 +237,7 @@ def wrapper_task( "De-disperse", wrapper_task, ["dt", "dispersion_measure", "ignore_volume_warning"], + repeatable=False ) @classmethod @@ -251,7 +254,7 @@ def rebin_time(time_unix, data, dt, rebin_dt): new_dx=rebin_dt.to_value(u.s), ) - return cls("Rebin in time", rebin_time, ["dt", "rebin_dt"]) + return cls("Rebin in time", rebin_time, ["dt", "rebin_dt"], repeatable=True) @classmethod def frequency_rebin(cls): @@ -267,7 +270,7 @@ def rebin_freq(frequency_hz, data, df, rebin_df): new_dx=rebin_df.to_value(u.Hz), ) - return cls("Rebin in frequency", rebin_freq, ["df", "rebin_df"]) + return cls("Rebin in frequency", rebin_freq, ["df", "rebin_df"], repeatable=True) @classmethod def get_stokes(cls): @@ -276,7 +279,7 @@ def compute_stokes(data, stokes): return data return utils.compute_stokes_parameters(data_array=data, stokes=stokes) - return cls("Compute Stokes parameters", compute_stokes, ["stokes"]) + return cls("Compute Stokes parameters", compute_stokes, ["stokes"], repeatable=False) def update(self, parameters: utils.TFPipelineParameters) -> None: log.debug(f"Updating TFTask {self.name} parameters before running it.") @@ -458,6 +461,9 @@ def insert(self, operation: TFTask, index: int) -> None: """ if operation.__class__.__name__ != TFTask.__name__: raise TypeError(f"Tried to append {type(operation)} instead of {TFTask}.") + if self.contains(operation.name) and (not operation.repeatable): + log.warning(f"{operation} is already registered in the pipeline and is not repeatable.") + return self.tasks.insert(index, operation) def append(self, operation: TFTask) -> None: @@ -475,6 +481,9 @@ def append(self, operation: TFTask) -> None: """ if operation.__class__.__name__ != TFTask.__name__: raise TypeError(f"Tried to append {type(operation)} instead of {TFTask}.") + if self.contains(operation.name) and (not operation.repeatable): + log.warning(f"{operation} is already registered in the pipeline and is not repeatable.") + return self.tasks.append(operation) def remove(self, *args: Union[str, int]) -> None: