Skip to content

Commit

Permalink
update the pad tests
Browse files Browse the repository at this point in the history
parametrized the use of constant or reflect padding.

Now avoids using the unittest framework
  • Loading branch information
pattonw committed Dec 19, 2023
1 parent 443c666 commit 531d81d
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions tests/cases/pad.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,56 @@
from .provider_test import ProviderTest
from .helper_sources import ArraySource, GraphSource
from gunpowder import (
BatchProvider,
BatchRequest,
Batch,
ArrayKeys,
ArraySpec,
Roi,
Coordinate,
Graph,
GraphKey,
GraphKeys,
GraphSpec,
Array,
ArrayKey,
Pad,
build,
MergeProvider,
)
import numpy as np


class ExampleSourcePad(BatchProvider):
def setup(self):
self.provides(
ArrayKeys.TEST_LABELS,
ArraySpec(roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)),
)

self.provides(
GraphKeys.TEST_GRAPH, GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180)))
)
import pytest
import numpy as np

def provide(self, request):
batch = Batch()

roi_array = request[ArrayKeys.TEST_LABELS].roi
roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size
@pytest.mark.parametrize("mode", ["constant", "reflect"])
def test_output(mode):
array_key = ArrayKey("TEST_ARRAY")
graph_key = GraphKey("TEST_GRAPH")

data = np.zeros(roi_voxel.shape, dtype=np.uint32)
data[:, ::2] = 100
array_spec = ArraySpec(
roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)
)
roi_voxel = array_spec.roi / array_spec.voxel_size
data = np.zeros(roi_voxel.shape, dtype=np.uint32)
data[:, ::2] = 100
array = Array(data, spec=array_spec)

spec = self.spec[ArrayKeys.TEST_LABELS].copy()
spec.roi = roi_array
batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec)
graph_spec = GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180)))
graph = Graph([], [], graph_spec)

return batch
source = (
ArraySource(array_key, array),
GraphSource(graph_key, graph),
) + MergeProvider()

pipeline = (
source
+ Pad(array_key, Coordinate((20, 20, 20)), value=1, mode=mode)
+ Pad(graph_key, Coordinate((10, 10, 10)), mode=mode)
)

class TestPad(ProviderTest):
def test_output(self):
graph = GraphKey("TEST_GRAPH")
labels = ArrayKey("TEST_LABELS")
with build(pipeline):
assert pipeline.spec[array_key].roi == Roi((180, 0, 0), (1840, 220, 220))
assert pipeline.spec[graph_key].roi == Roi((190, 10, 10), (1820, 200, 200))

pipeline = (
ExampleSourcePad()
+ Pad(labels, Coordinate((20, 20, 20)), value=1)
+ Pad(graph, Coordinate((10, 10, 10)))
batch = pipeline.request_batch(
BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))})
)

with build(pipeline):
self.assertTrue(
pipeline.spec[labels].roi == Roi((180, 0, 0), (1840, 220, 220))
)
self.assertTrue(
pipeline.spec[graph].roi == Roi((190, 10, 10), (1820, 200, 200))
)

batch = pipeline.request_batch(
BatchRequest({labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))})
)

self.assertEqual(np.sum(batch.arrays[labels].data), 1 * 10 * 10)
assert np.sum(batch.arrays[array_key].data) == 1 * 10 * 10

0 comments on commit 531d81d

Please sign in to comment.