Skip to content

Commit

Permalink
Merge pull request #213 from funkelab/dev-v1.4
Browse files Browse the repository at this point in the history
changelog:
features

add probability for applying augmentation nodes via an overrideable can_skip method.
errors now print in reversed order
improve CSVReader to use built in python csv reader
torch predict supports arrays as position arguments
add funlib.persistence.Array source
add ScanCallback
bugfixes:

pytorch train: move hooks to being added in start method. This caused problems when trying to run the model in some multithreaded use cases.
Deform Augment subsampling fixed
avoid np.sctypes["float"] to work on numpy >= 2.0


* remove duplicated for loop

* increment patch number

* ArraySpec docs

fix documentation to be more accurate around nonspatial arrays

* ArraySpec bug fix:

allow None roi/voxel size for spatial arrays

* Add probability option to aug nodes

* Revert can_skip to private method

* fix the deform augment test

no longer assumes a deformed label will still exist in an array

* better bounds on required packages

* ignore missing imports from packages that don't provide type hints

* fix typehint mistakes

* format pyproject.toml

* black format

* move register hooks to the start method

This is to get around local functions (i.e. the hooks) not being
pickle-able which we need for the "spawn" start function
(spawn is the default on windows and recent macs)

* fix typo

* support non-spatial arrays in ArraySource

* overhaul torch tests

* remove multiprocess set start method monkey patch

We want to test with both fork and spawn start methods, but this
seems to interfere with the torch tests

* only deploy docs on tagged commits to main

* minor black formatting and configuration changes

* properly skip torch tests if torch not installed

* black formatting

* avoid testing on python 3.7, instead use 3.11

numpy is no longer releasing updates for python 3.7, they are on 1.24
but the last release for 3.7 was 1.21.
I don't think we need to support it either, but we should test on 3.11

* add typed libraries to dev dependencies

* test subsampling in deform augment

test fails

* fix bugs associated with subsampling

* deform augment

fix bug with checking dims of graph_raster_voxel_size

* Add progress callback to Scan node

* pass torch train test

if using start method = "spawn" and the "start_subprocess" flag
for the predict node, we now pass our test.

* pass torch train test

if using the start method "spawn", and the "spawn_subprocess" flag for
the train node, we now pass our test

* remove extra error printing

* switch error printing order

Now prints the errors in reverse order of execution so the initial pipeline error is printed first

* black format docs and examples

* Squashed commit of the following:

commit 1686b949766b76960534ede1105751591fd91c9f
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Dec 19 08:43:11 2023 -0700

    black reformatting

commit 26d2c7cfff3f2702f56a5bb4249a0811f54b45ef
Author: Mohinta2892 <samiamohinta2892@gmail.com>
Date:   Thu Nov 2 19:09:15 2023 +0000

    Revert "black reformatted"

    This reverts commit 66dd69b.

    Only format changed files, since black does not consider formatting history

commit a273fd3813fc16b516c2438ad5af0c4ee3f0686b
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 17:12:26 2023 +0000

    black reformatted

commit bb37769eec33af5921386f283e2579055bb34e6d
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 16:40:32 2023 +0000

    add device arg

    Allow passing cuda device to Predict. Issue #188

commit a3b3588a1406d609ae95370cf2c5339872616011
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 16:39:09 2023 +0000

    add device arg

    allow passing cuda device to Train

* parameterize tests for cuda devices

currently failing a few of them, some are expected failures.

* Added support for reflect padding

Squashed commit of the following:

commit 0fb29c8
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 08:54:17 2024 -0800

    replace custom padding code with np.pad

commit c6928bd
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 08:54:06 2024 -0800

    simplify/expand padding test

    test padding on both sides

commit 3782525
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Dec 19 11:30:31 2023 -0700

    pass the fixed tests

commit a7027c6
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Dec 19 10:37:48 2023 -0700

    fix the test case

commit 531d81d
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Dec 19 10:06:44 2023 -0700

    update the pad tests

    parametrized the use of constant or reflect padding.

    Now avoids using the unittest framework

commit 443c666
Author: Manan Lalit <34229641+lmanan@users.noreply.github.com>
Date:   Fri Nov 3 00:09:33 2023 -0400

    Replace .ndim by len()

commit a7503d7
Author: lmanan <manan.lalit@gmail.com>
Date:   Thu Nov 2 11:52:27 2023 -0400

    Update pad.py to include reflective padding

* Fix bug in rasterize graph

we were using `graph.data.items()` to iterate over nodes instead of `graph.nodes`

Squashed commit of the following:

commit d027f5a260a1e2a9cf851efca85b7318434675d6
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 09:44:21 2024 -0800

    refactor rasterize_points test to use pytest

commit eadb0476d8475b55120486df6cf30f95b6df86f4
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 09:25:11 2024 -0800

    remove extra roi handling

    The node only needs to request the data it needs for its
    own operations.
    If you request a mask for a set of points that extend outside
    the bounds of your mask you will get an error

commit 29507f1f21d69cf76e34e7b0f05cd780100fd68b
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 09:22:21 2024 -0800

    remove type cast

    we do a bitwise during the `__rasterize` call which
    results fails if you change the dtype

commit 96e93e53ce0bc8240357259dab92f1ca64a08199
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Jan 2 09:21:16 2024 -0800

    remove matplotlib

commit eb2977a187a1cad95da54a515c84ce44d73b8315
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Dec 14 15:27:41 2023 +0000

    fix mask intersection with request

    outputs must match request rois when a mask is provided

commit 682189dac2ef6b94876bd30df813717da6530060
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Dec 14 15:25:12 2023 +0000

    Update rasterize_graph.py

commit e36dcf179ccd1aec6a5cafd31e7a9a858352faa1
Author: Mohinta2892 <samiamohinta2892@gmail.com>
Date:   Thu Nov 2 19:18:00 2023 +0000

    reformat rasterize_graph and rasterize_points

commit 42da2702e746f702d3d07144ab9fc1d4352b0c0d
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 14:23:01 2023 +0000

    Test for issue #193

    Test added to pass mask to `RasterizeGraph()` via `RasterizationSettings`.

commit b17cfad413f5ad7f48045a2167ec20d89674d939
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 14:19:42 2023 +0000

    fix for issue #193

    lines 224-226: replace graph.data.items() with graph.nodes
    lines 255-257: explicitly cast the boolean mask data to the original dtype of mask_array

* ruff: remove unused imports and fix small typos.

* Custom BatchRequestError handling in pipeline.request_batch

We can filter out some more of the excess error traceback that isn't helpful to the readers.

* black formatting

* mypy workflow use dev dependencies

* avoid testing on python 3.8, it doesn't support typing very well

* fix type hint for logdir in torch train

* switch order of decorators to avoid trying to determine if cuda is available if torch isn't installed

* check if torch is installed before checking if cuda is available

* update funlib.geometry version for mypy typing

* remove batch.id

replaced in tensorflow predict node debug statments with the request. This better indicates
the roi being predicted on.
replaced in snapshot node with an internal counter

* Provide the separator to the csv points source

* Use csv reader in csv points source

* Test csv points source with new dev dependencies

* Update required python to 3.9

* Black test cases

* Fix typos in pytest unordered dependency

* Remove pytest unordered dependency

* Black and ruff CSVPointsSource tests

* Correctly read and document ids in CsvPointsSource

* Automatically detect header in CSVPointsSource

* Test all CSVPointsSource functionality

* add support for args as inputs to predict.py

Its often not so straightforward to know the key word argument name for the forward function of your model. Especially if you use something like `torch.nn.Sequential`

* black reformat pad.py test

* remove excessive seed setting. I don't think this is necessary since as soon as the seeds are set, the rest of the tests are determanistic

* Pytorch Train: let users specify model inputs as args instead of kwargs

* PyTorch Train: add tests for using arg indexes for model inputs

* depend on overhauled funlib.persistence

* add funlib.persistence array source

* black formatting fix

* Add basic `ArraySource` node that accepts any `funlib.persistence.Array`

* add ArraySource to docs

* fix dtype checking for float types for numpy >= 2.0

* add documentation for gradients argument of torch `Train` node

* add typehint for dict

* black reformatting

* add support for python 3.12

* remove distutils

* black formatting

---------

Co-authored-by: sheridana <arlo@e11.bio>
Co-authored-by: Jan Funke <funkej@janelia.hhmi.org>
Co-authored-by: Caroline Malin-Mayor <malinmayorc@janelia.hhmi.org>
  • Loading branch information
4 people authored Aug 30, 2024
2 parents 185b1f0 + 9693615 commit 5b595e9
Show file tree
Hide file tree
Showing 33 changed files with 591 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
platform: [ubuntu-latest]

steps:
Expand Down
6 changes: 6 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ BatchFilter
Source Nodes
------------

ArraySource
^^^^^^^^^^^

.. autoclass:: ArraySource

ZarrSource
^^^^^^^^^^
.. autoclass:: ZarrSource
Expand Down Expand Up @@ -334,6 +339,7 @@ Iterative Processing Nodes
Scan
^^^^
.. autoclass:: Scan
.. autoclass:: ScanCallback

DaisyRequestBlocks
^^^^^^^^^^^^^^^^^^
Expand Down
10 changes: 0 additions & 10 deletions gunpowder/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,7 @@ class Batch(Freezable):
Contains all graphs that have been requested for this batch.
"""

__next_id = multiprocessing.Value("L")

@staticmethod
def get_next_id():
with Batch.__next_id.get_lock():
next_id = Batch.__next_id.value
Batch.__next_id.value += 1
return next_id

def __init__(self):
self.id = Batch.get_next_id()
self.profiling_stats = ProfilingStats()
self.arrays = {}
self.graphs = {}
Expand Down
15 changes: 12 additions & 3 deletions gunpowder/contrib/nodes/dvid_partner_annotation_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import distutils.util
import numpy as np
import logging
import requests
Expand All @@ -15,6 +14,16 @@
logger = logging.getLogger(__name__)


def strtobool(val):
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return 1
elif val in ("n", "no", "f", "false", "off", "0"):
return 0
else:
raise ValueError(f"Invalid truth value: {val}")


class DvidPartnerAnnoationSourceReadException(Exception):
pass

Expand Down Expand Up @@ -198,10 +207,10 @@ def __read_syn_points(self, roi):
props["agent"] = str(node["Prop"]["agent"])
if "flagged" in node["Prop"]:
str_value_flagged = str(node["Prop"]["flagged"])
props["flagged"] = bool(distutils.util.strtobool(str_value_flagged))
props["flagged"] = bool(strtobool(str_value_flagged))
if "multi" in node["Prop"]:
str_value_multi = str(node["Prop"]["multi"])
props["multi"] = bool(distutils.util.strtobool(str_value_multi))
props["multi"] = bool(strtobool(str_value_multi))

# create synPoint with information collected so far (partner_ids not completed yet)
if kind == "PreSyn":
Expand Down
3 changes: 2 additions & 1 deletion gunpowder/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import

from .array_source import ArraySource
from .add_affinities import AddAffinities
from .astype import AsType
from .balance_labels import BalanceLabels
Expand Down Expand Up @@ -34,7 +35,7 @@
from .reject import Reject
from .renumber_connected_components import RenumberConnectedComponents
from .resample import Resample
from .scan import Scan
from .scan import Scan, ScanCallback
from .shift_augment import ShiftAugment
from .simple_augment import SimpleAugment
from .snapshot import Snapshot
Expand Down
57 changes: 57 additions & 0 deletions gunpowder/nodes/array_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from funlib.persistence.arrays import Array as PersistenceArray
from gunpowder.array import Array, ArrayKey
from gunpowder.array_spec import ArraySpec
from gunpowder.batch import Batch
from .batch_provider import BatchProvider


class ArraySource(BatchProvider):
"""A `array <https://github.com/funkelab/funlib.persistence>`_ source.
Provides a source for any array that can fit into the funkelab
funlib.persistence.Array format. This class comes with assumptions about
the available metadata and convenient methods for indexing the data
with a :class:`Roi` in world units.
Args:
key (:class:`ArrayKey`):
The ArrayKey for accessing this array.
array (``Array``):
A `funlib.persistence.Array` object.
interpolatable (``bool``, optional):
Whether the array is interpolatable. If not given it is
guessed based on dtype.
"""

def __init__(
self,
key: ArrayKey,
array: PersistenceArray,
interpolatable: bool | None = None,
):
self.key = key
self.array = array
self.array_spec = ArraySpec(
self.array.roi,
self.array.voxel_size,
interpolatable,
False,
self.array.dtype,
)

def setup(self):
self.provides(self.key, self.array_spec)

def provide(self, request):
outputs = Batch()
out_spec = self.array_spec.copy()
out_spec.roi = request[self.key].roi
outputs[self.key] = Array(self.array[out_spec.roi], out_spec)
return outputs
10 changes: 9 additions & 1 deletion gunpowder/nodes/batch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def autoskip_enabled(self):
return self._autoskip_enabled

def provide(self, request):
skip = self.__can_skip(request)
skip = self.__can_skip(request) or self.skip_node(request)

timing_prepare = Timing(self, "prepare")
timing_prepare.start()
Expand Down Expand Up @@ -206,6 +206,14 @@ def __can_skip(self, request):

return True

def skip_node(self, request):
"""To be implemented in subclasses.
Skip a node if a condition is met. Can be useful if using a probability
to determine whether to use an augmentation, for example.
"""
pass

def setup(self):
"""To be implemented in subclasses.
Expand Down
96 changes: 59 additions & 37 deletions gunpowder/nodes/csv_points_source.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from typing import Union, Optional
import numpy as np
import logging
from gunpowder.batch import Batch
from gunpowder.coordinate import Coordinate
from gunpowder.nodes.batch_provider import BatchProvider
from gunpowder.graph import Node, Graph
from gunpowder.graph import Node, Graph, GraphKey
from gunpowder.graph_spec import GraphSpec
from gunpowder.profiling import Timing
from gunpowder.roi import Roi
import csv

logger = logging.getLogger(__name__)


class CsvPointsSource(BatchProvider):
"""Read a set of points from a comma-separated-values text file. Each line
in the file represents one point, e.g. z y x (id)
in the file represents one point, e.g. z y x (id). Note: this reads all
points into memory and finds the ones in the given roi by iterating
over all the points. For large datasets, this may be too slow.
Args:
Expand All @@ -25,6 +29,11 @@ class CsvPointsSource(BatchProvider):
The key of the points set to create.
spatial_cols (list[``int``]):
The columns of the csv that hold the coordinates of the points
(in the order that you want them to be used in training)
points_spec (:class:`GraphSpec`, optional):
An optional :class:`GraphSpec` to overwrite the points specs
Expand All @@ -37,28 +46,36 @@ class CsvPointsSource(BatchProvider):
from the CSV file. This is useful if the points refer to voxel
positions to convert them to world units.
ndims (``int``):
id_col (``int``, optional):
If ``ndims`` is None, all values in one line are considered as the
location of the point. If positive, only the first ``ndims`` are used.
If negative, all but the last ``-ndims`` are used.
The column of the csv that holds an id for each point. If not
provided, the index of the rows are used as the ids. When read
from file, ids are left as strings and not cast to anything.
id_dim (``int``):
delimiter (``str``, optional):
Each line may optionally contain an id for each point. This parameter
specifies its location, has to come after the position values.
Delimiter to pass to the csv reader. Defaults to ",".
"""

def __init__(
self, filename, points, points_spec=None, scale=None, ndims=None, id_dim=None
self,
filename: str,
points: GraphKey,
spatial_cols: list[int],
points_spec: Optional[GraphSpec] = None,
scale: Optional[Union[int, float, tuple, list, np.ndarray]] = None,
id_col: Optional[int] = None,
delimiter: str = ",",
):
self.filename = filename
self.points = points
self.points_spec = points_spec
self.scale = scale
self.ndims = ndims
self.id_dim = id_dim
self.data = None
self.spatial_cols = spatial_cols
self.id_dim = id_col
self.delimiter = delimiter
self.data: Optional[np.ndarray] = None
self.ids: Optional[list] = None

def setup(self):
self._parse_csv()
Expand All @@ -67,8 +84,8 @@ def setup(self):
self.provides(self.points, self.points_spec)
return

min_bb = Coordinate(np.floor(np.amin(self.data[:, : self.ndims], 0)))
max_bb = Coordinate(np.ceil(np.amax(self.data[:, : self.ndims], 0)) + 1)
min_bb = Coordinate(np.floor(np.amin(self.data, 0)))
max_bb = Coordinate(np.ceil(np.amax(self.data, 0)) + 1)

roi = Roi(min_bb, max_bb - min_bb)

Expand All @@ -84,7 +101,7 @@ def provide(self, request):
logger.debug("CSV points source got request for %s", request[self.points].roi)

point_filter = np.ones((self.data.shape[0],), dtype=bool)
for d in range(self.ndims):
for d in range(len(self.spatial_cols)):
point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d])
point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d])

Expand All @@ -100,30 +117,35 @@ def provide(self, request):
return batch

def _get_points(self, point_filter):
filtered = self.data[point_filter][:, : self.ndims]

if self.id_dim is not None:
ids = self.data[point_filter][:, self.id_dim]
else:
ids = np.arange(len(self.data))[point_filter]

filtered = self.data[point_filter]
ids = self.ids[point_filter]
return [Node(id=i, location=p) for i, p in zip(ids, filtered)]

def _parse_csv(self):
"""Read one point per line. If ``ndims`` is None, all values in one line
are considered as the location of the point. If positive, only the
first ``ndims`` are used. If negative, all but the last ``-ndims`` are
used.
"""Read one point per line, with spatial and id columns determined by
self.spatial_cols and self.id_col.
"""

with open(self.filename, "r") as f:
self.data = np.array(
[[float(t.strip(",")) for t in line.split()] for line in f],
dtype=np.float32,
)

if self.ndims is None:
self.ndims = self.data.shape[1]
data = []
ids = []
with open(self.filename, "r", newline="") as f:
has_header = csv.Sniffer().has_header(f.read(1024))
f.seek(0)
first_line = True
reader = csv.reader(f, delimiter=self.delimiter)
for line in reader:
if first_line and has_header:
first_line = False
continue
space = [float(line[c]) for c in self.spatial_cols]
data.append(space)
if self.id_dim is not None:
ids.append(line[self.id_dim])

self.data = np.array(data, dtype=np.float32)
if self.id_dim:
self.ids = np.array(ids)
else:
self.ids = np.arange(len(self.data))

if self.scale is not None:
self.data[:, : self.ndims] *= self.scale
self.data *= self.scale
12 changes: 12 additions & 0 deletions gunpowder/nodes/defect_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class DefectAugment(BatchFilter):
axis (``int``, optional):
Along which axis sections are cut.
p (``float``, optional):
Probability applying the augmentation. Default is 1.0 (always
apply). Should be a float value between 0 and 1. Lowering this value
could be useful for computational efficiency and increasing
augmentation space.
"""

def __init__(
Expand All @@ -80,6 +87,7 @@ def __init__(
artifacts_mask=None,
deformation_strength=20,
axis=0,
p=1.0,
):
self.intensities = intensities
self.prob_missing = prob_missing
Expand All @@ -92,6 +100,7 @@ def __init__(
self.artifacts_mask = artifacts_mask
self.deformation_strength = deformation_strength
self.axis = axis
self.p = p

def setup(self):
if self.artifact_source is not None:
Expand All @@ -101,6 +110,9 @@ def teardown(self):
if self.artifact_source is not None:
self.artifact_source.teardown()

def skip_node(self, request):
return random.random() > self.p

# send roi request to data-source upstream
def prepare(self, request):
deps = BatchRequest()
Expand Down
Loading

0 comments on commit 5b595e9

Please sign in to comment.