Skip to content

Commit

Permalink
Add basic ArraySource node that accepts any funlib.persistence.Array
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 29, 2024
1 parent 470a238 commit 59fea58
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
1 change: 1 addition & 0 deletions gunpowder/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import

from .array_source import ArraySource
from .add_affinities import AddAffinities
from .astype import AsType
from .balance_labels import BalanceLabels
Expand Down
14 changes: 7 additions & 7 deletions gunpowder/nodes/array_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from funlib.persistence.arrays import Array as PersistenceArray
from gunpowder import Array, ArrayKey, Batch, BatchProvider, ArraySpec
from gunpowder.array import Array, ArrayKey
from gunpowder.array_spec import ArraySpec
from gunpowder.batch import Batch
from .batch_provider import BatchProvider


class ArraySource(BatchProvider):
Expand Down Expand Up @@ -33,20 +36,17 @@ def __init__(
self.array_spec = ArraySpec(
self.array.roi,
self.array.voxel_size,
self.interpolatable,
self.nonspatial,
interpolatable,
nonspatial,
self.array.dtype,
)

self.interpolatable = interpolatable
self.nonspatial = nonspatial

def setup(self):
self.provides(self.key, self.array_spec)

def provide(self, request):
outputs = Batch()
if self.nonspatial:
if self.array_spec.nonspatial:
outputs[self.key] = Array(self.array[:], self.array_spec.copy())
else:
out_spec = self.array_spec.copy()
Expand Down
29 changes: 29 additions & 0 deletions tests/cases/array_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from funlib.persistence import prepare_ds
from funlib.geometry import Roi
from gunpowder.nodes import ArraySource
from gunpowder import ArrayKey, build, BatchRequest, ArraySpec

import numpy as np


def test_array_source(tmpdir):
array = prepare_ds(
tmpdir / "data.zarr",
shape=(100, 102, 108),
offset=(100, 50, 0),
voxel_size=(1, 2, 3),
dtype="uint8",
)
array[:] = np.arange(100 * 102 * 108).reshape((100, 102, 108)) % 255

key = ArrayKey("TEST")

source = ArraySource(key=key, array=array)

with build(source):
request = BatchRequest()

roi = Roi((100, 100, 102), (30, 30, 30))
request[key] = ArraySpec(roi)

assert np.array_equal(source.request_batch(request)[key].data, array[roi])

0 comments on commit 59fea58

Please sign in to comment.