Skip to content

Commit 675cf43

Browse files
committed
overwrite the provide function in RandomLocation
no need for supporting skip
1 parent d02b696 commit 675cf43

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

gunpowder/nodes/random_location.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from gunpowder.array import Array
1313
from gunpowder.array_spec import ArraySpec
1414
from .batch_filter import BatchFilter
15+
from gunpowder.profiling import Timing
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -210,6 +211,33 @@ def prepare(self, request):
210211

211212
return request
212213

214+
def provide(self, request):
215+
216+
timing_prepare = Timing(self, "prepare")
217+
timing_prepare.start()
218+
219+
downstream_request = request.copy()
220+
221+
self.prepare(request)
222+
223+
self.remove_provided(request)
224+
225+
timing_prepare.stop()
226+
227+
batch = self.get_upstream_provider().request_batch(request)
228+
229+
timing_process = Timing(self, "process")
230+
timing_process.start()
231+
232+
self.process(batch, downstream_request)
233+
234+
timing_process.stop()
235+
236+
batch.profiling_stats.add(timing_prepare)
237+
batch.profiling_stats.add(timing_process)
238+
239+
return batch
240+
213241
def process(self, batch, request):
214242
if self.random_shift_key is not None:
215243
batch[self.random_shift_key] = Array(
@@ -429,13 +457,14 @@ def __select_random_location_with_points(
429457

430458
# count all points inside the shifted ROI
431459
points = self.__get_points_in_roi(request_points_roi.shift(random_shift))
432-
assert (
433-
point in points
434-
), "Requested batch to contain point %s, but got points " "%s" % (
435-
point,
436-
points,
460+
assert point in points, (
461+
"Requested batch to contain point %s, but got points "
462+
"%s"
463+
% (
464+
point,
465+
points,
466+
)
437467
)
438-
num_points = len(points)
439468

440469
return random_shift
441470

tests/cases/random_location.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ def test_output():
6060
a = ArrayKey("A")
6161
b = ArrayKey("B")
6262
random_shift_key = ArrayKey("RANDOM_SHIFT")
63-
source_a = ExampleSourceRandomLocation(a)
64-
source_b = ExampleSourceRandomLocation(b)
6563

6664
pipeline = (
67-
(source_a, source_b)
65+
(ExampleSourceRandomLocation(a), ExampleSourceRandomLocation(b))
6866
+ MergeProvider()
69-
+ CustomRandomLocation(a, random_store_key=random_shift_key)
67+
+ CustomRandomLocation(a, random_shift_key=random_shift_key)
7068
)
71-
pipeline_no_random = (source_a, source_b) + MergeProvider()
69+
pipeline_no_random = (
70+
ExampleSourceRandomLocation(a),
71+
ExampleSourceRandomLocation(b),
72+
) + MergeProvider()
7273

7374
with build(pipeline), build(pipeline_no_random):
7475
sums = set()
@@ -95,8 +96,7 @@ def test_output():
9596
),
9697
b: ArraySpec(
9798
roi=Roi(batch[random_shift_key].data, (20, 20, 20))
98-
),
99-
random_shift_key: ArraySpec(nonspatial=True),
99+
)
100100
}
101101
)
102102
)
@@ -106,8 +106,8 @@ def test_output():
106106
sums.add(batch[a].data.sum())
107107

108108
# Request a ROI with the same shape as the entire ROI
109-
full_roi_a = Roi((0, 0, 0), source_a.roi.shape)
110-
full_roi_b = Roi((0, 0, 0), source_b.roi.shape)
109+
full_roi_a = Roi((0, 0, 0), ExampleSourceRandomLocation(a).roi.shape)
110+
full_roi_b = Roi((0, 0, 0), ExampleSourceRandomLocation(b).roi.shape)
111111
batch = pipeline.request_batch(
112112
BatchRequest(
113113
{a: ArraySpec(roi=full_roi_a), b: ArraySpec(roi=full_roi_b)}

0 commit comments

Comments
 (0)