diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 43d34dd9..3102ac3e 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -224,7 +224,7 @@ def _output_cast(self, output): return output.to_host().astype(np_dtypes[self.dest_dtype]) case _: return output - + def save_output(self, function_name, output): if isinstance(output, tuple) or isinstance(output, list): for i in output: @@ -257,7 +257,7 @@ def __call__(self, function_name, inputs: list): else: output = self._run(function_name, inputs) if self.save_outputs: - self.save_output(function_name, output) + self.save_output(function_name, output) output = self._output_cast(output) return output