Skip to content

Add probability option to aug nodes #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion gunpowder/nodes/batch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def autoskip_enabled(self):
return self._autoskip_enabled

def provide(self, request):
skip = self.__can_skip(request)
skip = self.__can_skip(request) or self.skip_node(request)

timing_prepare = Timing(self, "prepare")
timing_prepare.start()
Expand Down Expand Up @@ -206,6 +206,14 @@ def __can_skip(self, request):

return True

def skip_node(self, request):
"""To be implemented in subclasses.

Skip a node if a condition is met. Can be useful if using a probability
to determine whether to use an augmentation, for example.
"""
pass

def setup(self):
"""To be implemented in subclasses.

Expand Down
12 changes: 12 additions & 0 deletions gunpowder/nodes/defect_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ class DefectAugment(BatchFilter):
axis (``int``, optional):

Along which axis sections are cut.

p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -82,6 +89,7 @@ def __init__(
artifacts_mask=None,
deformation_strength=20,
axis=0,
p=1.0,
):
self.intensities = intensities
self.prob_missing = prob_missing
Expand All @@ -94,6 +102,7 @@ def __init__(
self.artifacts_mask = artifacts_mask
self.deformation_strength = deformation_strength
self.axis = axis
self.p = p

def setup(self):
if self.artifact_source is not None:
Expand All @@ -103,6 +112,9 @@ def teardown(self):
if self.artifact_source is not None:
self.artifact_source.teardown()

def skip_node(self, request):
return random.random() > self.p

# send roi request to data-source upstream
def prepare(self, request):
deps = BatchRequest()
Expand Down
14 changes: 13 additions & 1 deletion gunpowder/nodes/deform_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ class DeformAugment(BatchFilter):

Whether or not to compute the elastic transform node wise for nodes
that were lossed during the fast elastic transform process.


p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -95,6 +103,7 @@ def __init__(
recompute_missing_points=True,
transform_key: ArrayKey = None,
graph_raster_voxel_size: Coordinate = None,
p: float = 1.0,
):
self.control_point_spacing = Coordinate(control_point_spacing)
self.jitter_sigma = Coordinate(jitter_sigma)
Expand All @@ -107,6 +116,7 @@ def __init__(
self.recompute_missing_points = recompute_missing_points
self.transform_key = transform_key
self.graph_raster_voxel_size = Coordinate(graph_raster_voxel_size)
self.p = p
assert (
self.control_point_spacing.dims
== self.jitter_sigma.dims
Expand All @@ -128,8 +138,10 @@ def setup(self):

self.provides(self.transform_key, spec)

def prepare(self, request):
def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
# get the total ROI of all requests
total_roi = request.get_total_roi()
logger.debug("total ROI is %s" % total_roi)
Expand Down
13 changes: 12 additions & 1 deletion gunpowder/nodes/elastic_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class ElasticAugment(BatchFilter):

Whether or not to compute the elastic transform node wise for nodes
that were lossed during the fast elastic transform process.

p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -103,6 +110,7 @@ def __init__(
spatial_dims=3,
use_fast_points_transform=False,
recompute_missing_points=True,
p=1.0,
):
warnings.warn(
"ElasticAugment is deprecated, please use the DeformAugment",
Expand All @@ -122,9 +130,12 @@ def __init__(
self.spatial_dims = spatial_dims
self.use_fast_points_transform = use_fast_points_transform
self.recompute_missing_points = recompute_missing_points
self.p = p

def prepare(self, request):
def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
# get the voxel size
self.voxel_size = self.__get_common_voxel_size(request)

Expand Down
13 changes: 13 additions & 0 deletions gunpowder/nodes/intensity_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import random

from gunpowder.batch_request import BatchRequest

Expand Down Expand Up @@ -34,6 +35,13 @@ class IntensityAugment(BatchFilter):

Set to False if modified values should not be clipped to [0, 1]
Disables range check!

p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -45,6 +53,7 @@ def __init__(
shift_max,
z_section_wise=False,
clip=True,
p=1.0,
):
self.array = array
self.scale_min = scale_min
Expand All @@ -53,11 +62,15 @@ def __init__(
self.shift_max = shift_max
self.z_section_wise = z_section_wise
self.clip = clip
self.p = p

def setup(self):
self.enable_autoskip()
self.updates(self.array, self.spec[self.array])

def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
deps = BatchRequest()
deps[self.array] = request[self.array].copy()
Expand Down
16 changes: 13 additions & 3 deletions gunpowder/nodes/noise_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import random
import skimage

from gunpowder.batch_request import BatchRequest
Expand All @@ -24,18 +25,29 @@ class NoiseAugment(BatchFilter):

Whether to preserve the image range (either [-1, 1] or [0, 1]) by clipping values in the end, see
scikit-image documentation

p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(self, array, mode="gaussian", clip=True, **kwargs):
def __init__(self, array, mode="gaussian", clip=True, p=1.0, **kwargs):
self.array = array
self.mode = mode
self.clip = clip
self.p = p
self.kwargs = kwargs

def setup(self):
self.enable_autoskip()
self.updates(self.array, self.spec[self.array])

def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
deps = BatchRequest()
deps[self.array] = request[self.array].copy()
Expand All @@ -57,13 +69,11 @@ def process(self, batch, request):
seed = request.random_seed

try:

raw.data = skimage.util.random_noise(
raw.data, mode=self.mode, rng=seed, clip=self.clip, **self.kwargs
).astype(raw.data.dtype)

except ValueError:

# legacy version of skimage random_noise
raw.data = skimage.util.random_noise(
raw.data, mode=self.mode, seed=seed, clip=self.clip, **self.kwargs
Expand Down
7 changes: 5 additions & 2 deletions gunpowder/nodes/shift_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@


class ShiftAugment(BatchFilter):
def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0):
def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0, p=1.0):
self.prob_slip = prob_slip
self.prob_shift = prob_shift
self.sigma = sigma
self.shift_axis = shift_axis
self.p = p

self.ndim = None
self.shift_sigmas = None
self.shift_array = None
self.lcm_voxel_size = None

def prepare(self, request):
def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
self.ndim = request.get_total_roi().dims
assert self.shift_axis in range(self.ndim)

Expand Down
13 changes: 12 additions & 1 deletion gunpowder/nodes/simple_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class SimpleAugment(BatchFilter):
and attempt to weight them appropriately. A weight of 0 means
this axis will never be transposed, a weight of 1 means this axis
will always be transposed.

p (``float``, optional):

Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -55,6 +62,7 @@ def __init__(
transpose_only=None,
mirror_probs=None,
transpose_probs=None,
p=1.0,
):
self.mirror_only = mirror_only
self.mirror_probs = mirror_probs
Expand All @@ -63,6 +71,7 @@ def __init__(
self.mirror_mask = None
self.dims = None
self.transpose_dims = None
self.p = p

def setup(self):
self.dims = self.spec.get_total_roi().dims
Expand Down Expand Up @@ -105,8 +114,10 @@ def setup(self):
if valid:
self.permutation_dict[k] = v

def prepare(self, request):
def skip_node(self, request):
return random.random() > self.p

def prepare(self, request):
self.mirror = [
random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0
for d in range(self.dims)
Expand Down
53 changes: 53 additions & 0 deletions tests/cases/batch_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from .helper_sources import ArraySource
from gunpowder import (
ArrayKey,
build,
Array,
ArraySpec,
Roi,
Coordinate,
BatchRequest,
BatchFilter,
)

import numpy as np
import random


class DummyNode(BatchFilter):
def __init__(self, array, p=1.0):
self.array = array
self.p = p

def skip_node(self, request):
return random.random() > self.p

def process(self, batch, request):
batch[self.array].data = batch[self.array].data + 1


def test_skip():
raw_key = ArrayKey("RAW")
array = Array(
np.ones((10, 10)),
ArraySpec(Roi((0, 0), (10, 10)), Coordinate(1, 1)),
)
source = ArraySource(raw_key, array)

request_1 = BatchRequest(random_seed=1)
request_2 = BatchRequest(random_seed=2)

request_1.add(raw_key, Coordinate(10, 10))
request_2.add(raw_key, Coordinate(10, 10))

pipeline = source + DummyNode(raw_key, p=0.5)

with build(pipeline):
batch_1 = pipeline.request_batch(request_1)
batch_2 = pipeline.request_batch(request_2)

x_1 = batch_1.arrays[raw_key].data
x_2 = batch_2.arrays[raw_key].data

assert x_1.max() == 2
assert x_2.max() == 1