Skip to content

Commit

Permalink
prevent certain TFTasks from being repeated
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanLoh committed Feb 1, 2024
1 parent fcd3ac6 commit ade4174
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions nenupy/io/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ class TFTask:
"""

def __init__(self, name: str, func: Callable, args_to_update: List[str] = []):
def __init__(self, name: str, func: Callable, args_to_update: List[str] = [], repeatable: bool = False):
self.name = name
self.is_active = False
self._func = func
self.args_to_update = args_to_update
self.repeatable = repeatable
self._extra_params = {}

def __repr__(self) -> str:
Expand All @@ -118,7 +119,7 @@ def correct_bandpass(cls):
def wrapper_task(data, channels):
return utils.correct_bandpass(data=data, n_channels=channels)

return cls("Correct bandpass", wrapper_task, ["channels"])
return cls("Correct bandpass", wrapper_task, ["channels"], repeatable=False)

@classmethod
def remove_channels(cls):
Expand All @@ -132,7 +133,7 @@ def wrapper_task(data, channels, remove_channels):
)

return cls(
"Remove subband channels", wrapper_task, ["channels", "remove_channels"]
"Remove subband channels", wrapper_task, ["channels", "remove_channels"], repeatable=False
)

@classmethod
Expand Down Expand Up @@ -176,6 +177,7 @@ def wrapper_task(
"dreambeam_dt",
"dreambeam_parallactic",
],
repeatable=False
)

@classmethod
Expand All @@ -191,7 +193,7 @@ def apply_faraday(frequency_hz, data, rotation_measure):
rotation_measure=rotation_measure,
)

return cls("Correct faraday rotation", apply_faraday, ["rotation_measure"])
return cls("Correct faraday rotation", apply_faraday, ["rotation_measure"], repeatable=False)

@classmethod
def de_disperse(cls):
Expand Down Expand Up @@ -235,6 +237,7 @@ def wrapper_task(
"De-disperse",
wrapper_task,
["dt", "dispersion_measure", "ignore_volume_warning"],
repeatable=False
)

@classmethod
Expand All @@ -251,7 +254,7 @@ def rebin_time(time_unix, data, dt, rebin_dt):
new_dx=rebin_dt.to_value(u.s),
)

return cls("Rebin in time", rebin_time, ["dt", "rebin_dt"])
return cls("Rebin in time", rebin_time, ["dt", "rebin_dt"], repeatable=True)

@classmethod
def frequency_rebin(cls):
Expand All @@ -267,7 +270,7 @@ def rebin_freq(frequency_hz, data, df, rebin_df):
new_dx=rebin_df.to_value(u.Hz),
)

return cls("Rebin in frequency", rebin_freq, ["df", "rebin_df"])
return cls("Rebin in frequency", rebin_freq, ["df", "rebin_df"], repeatable=True)

@classmethod
def get_stokes(cls):
Expand All @@ -276,7 +279,7 @@ def compute_stokes(data, stokes):
return data
return utils.compute_stokes_parameters(data_array=data, stokes=stokes)

return cls("Compute Stokes parameters", compute_stokes, ["stokes"])
return cls("Compute Stokes parameters", compute_stokes, ["stokes"], repeatable=False)

def update(self, parameters: utils.TFPipelineParameters) -> None:
log.debug(f"Updating TFTask {self.name} parameters before running it.")
Expand Down Expand Up @@ -458,6 +461,9 @@ def insert(self, operation: TFTask, index: int) -> None:
"""
if operation.__class__.__name__ != TFTask.__name__:
raise TypeError(f"Tried to append {type(operation)} instead of {TFTask}.")
if self.contains(operation.name) and (not operation.repeatable):
log.warning(f"{operation} is already registered in the pipeline and is not repeatable.")
return
self.tasks.insert(index, operation)

def append(self, operation: TFTask) -> None:
Expand All @@ -475,6 +481,9 @@ def append(self, operation: TFTask) -> None:
"""
if operation.__class__.__name__ != TFTask.__name__:
raise TypeError(f"Tried to append {type(operation)} instead of {TFTask}.")
if self.contains(operation.name) and (not operation.repeatable):
log.warning(f"{operation} is already registered in the pipeline and is not repeatable.")
return
self.tasks.append(operation)

def remove(self, *args: Union[str, int]) -> None:
Expand Down

0 comments on commit ade4174

Please sign in to comment.