Skip to content

Commit

Permalink
fix dtype checking for float types for numpy >= 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 29, 2024
1 parent aac278b commit 3a8be9b
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 28 deletions.
18 changes: 10 additions & 8 deletions gunpowder/nodes/array_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@ class ArraySource(BatchProvider):
array (``Array``):
A `funlib.persistence.Array` object.
interpolatable (``bool``, optional):
Whether the array is interpolatable. If not given it is
guessed based on dtype.
"""

def __init__(
self,
key: ArrayKey,
array: PersistenceArray,
interpolatable: bool | None = None,
nonspatial: bool = False,
):
self.key = key
self.array = array
self.array_spec = ArraySpec(
self.array.roi,
self.array.voxel_size,
interpolatable,
nonspatial,
False,
self.array.dtype,
)

Expand All @@ -46,10 +51,7 @@ def setup(self):

def provide(self, request):
outputs = Batch()
if self.array_spec.nonspatial:
outputs[self.key] = Array(self.array[:], self.array_spec.copy())
else:
out_spec = self.array_spec.copy()
out_spec.roi = request[self.key].roi
outputs[self.key] = Array(self.array[out_spec.roi], out_spec)
out_spec = self.array_spec.copy()
out_spec.roi = request[self.key].roi
outputs[self.key] = Array(self.array[out_spec.roi], out_spec)
return outputs
7 changes: 2 additions & 5 deletions gunpowder/nodes/dvid_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,8 @@ def __get_spec(self, array_key):
spec.dtype = data_dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or (
spec.dtype == np.uint8
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s. "
Expand Down
7 changes: 2 additions & 5 deletions gunpowder/nodes/hdf5like_source_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,8 @@ def __read_spec(self, array_key, data_file, ds_name):
spec.dtype = dataset.dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or (
spec.dtype == np.uint8
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s "
Expand Down
7 changes: 2 additions & 5 deletions gunpowder/nodes/klb_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,8 @@ def __read_spec(self, headers):
spec.dtype = dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or (
spec.dtype == np.uint8
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s. "
Expand Down
7 changes: 2 additions & 5 deletions gunpowder/nodes/zarr_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,8 @@ def __read_spec(self, array_key, data_file, ds_name):
spec.dtype = dataset.dtype

if spec.interpolatable is None:
spec.interpolatable = spec.dtype in (
np.sctypes["float"]
+ [
np.uint8, # assuming this is not used for labels
]
spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or (
spec.dtype == np.uint8
)
logger.warning(
"WARNING: You didn't set 'interpolatable' for %s "
Expand Down

0 comments on commit 3a8be9b

Please sign in to comment.