diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py index e6f05507..ad4573e1 100644 --- a/gunpowder/nodes/array_source.py +++ b/gunpowder/nodes/array_source.py @@ -22,6 +22,12 @@ class ArraySource(BatchProvider): array (``Array``): A `funlib.persistence.Array` object. + + interpolatable (``bool``, optional): + + Whether the array is interpolatable. If not given it is + guessed based on dtype. + """ def __init__( @@ -29,7 +35,6 @@ def __init__( key: ArrayKey, array: PersistenceArray, interpolatable: bool | None = None, - nonspatial: bool = False, ): self.key = key self.array = array @@ -37,7 +42,7 @@ def __init__( self.array.roi, self.array.voxel_size, interpolatable, - nonspatial, + False, self.array.dtype, ) @@ -46,10 +51,7 @@ def setup(self): def provide(self, request): outputs = Batch() - if self.array_spec.nonspatial: - outputs[self.key] = Array(self.array[:], self.array_spec.copy()) - else: - out_spec = self.array_spec.copy() - out_spec.roi = request[self.key].roi - outputs[self.key] = Array(self.array[out_spec.roi], out_spec) + out_spec = self.array_spec.copy() + out_spec.roi = request[self.key].roi + outputs[self.key] = Array(self.array[out_spec.roi], out_spec) return outputs diff --git a/gunpowder/nodes/dvid_source.py b/gunpowder/nodes/dvid_source.py index d285a502..312dd59e 100644 --- a/gunpowder/nodes/dvid_source.py +++ b/gunpowder/nodes/dvid_source.py @@ -182,11 +182,8 @@ def __get_spec(self, array_key): spec.dtype = data_dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/hdf5like_source_base.py b/gunpowder/nodes/hdf5like_source_base.py index d7c63149..f5a8e58b 100644 --- a/gunpowder/nodes/hdf5like_source_base.py +++ b/gunpowder/nodes/hdf5like_source_base.py @@ -174,11 +174,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s " diff --git a/gunpowder/nodes/klb_source.py b/gunpowder/nodes/klb_source.py index 53eca5c4..d4a55049 100644 --- a/gunpowder/nodes/klb_source.py +++ b/gunpowder/nodes/klb_source.py @@ -155,11 +155,8 @@ def __read_spec(self, headers): spec.dtype = dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 2f1c15fc..82831fa3 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -206,11 +206,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s "