Skip to content

Commit 88a2dec

Browse files
committed
first pass at masking interface:
- add mask/unmask and apply_mask methods to TreeNeuron, MeshNeuron and Dotprops - add is_masked property for all neurons - add `navis.NeuronMask` class - add __length__ to all neurons - dotprops: clear `_tree` with temporary attributes
1 parent 6a501ae commit 88a2dec

File tree

8 files changed

+1261
-316
lines changed

8 files changed

+1261
-316
lines changed

docs/api.md

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ learn more!
4747
``TreeNeurons``, ``MeshNeurons``, ``VoxelNeurons`` and ``Dotprops`` are neuron
4848
classes. ``NeuronLists`` are containers thereof.
4949

50-
| Class | Description |
51-
|------|------|
52-
| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. |
53-
| [`navis.MeshNeuron`][] | Meshes with vertices and faces. |
54-
| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). |
55-
| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. |
56-
| [`navis.NeuronList`][] | Containers for neurons. |
50+
| Class | Description |
51+
|-------------------------|---------------------------------------------------------|
52+
| [`navis.TreeNeuron`][] | Skeleton representation of a neuron. |
53+
| [`navis.MeshNeuron`][] | Meshes with vertices and faces. |
54+
| [`navis.VoxelNeuron`][] | 3D images (e.g. from confocal stacks). |
55+
| [`navis.Dotprops`][] | Point cloud + vector representations, used for NBLAST. |
56+
| [`navis.NeuronList`][] | Containers for neurons. |
5757

5858
### General Neuron methods
5959

@@ -89,6 +89,7 @@ to all neurons:
8989
| `Neuron.type` | {{ autosummary("navis.BaseNeuron.type") }} |
9090
| `Neuron.soma` | {{ autosummary("navis.BaseNeuron.soma") }} |
9191
| `Neuron.bbox` | {{ autosummary("navis.BaseNeuron.bbox") }} |
92+
| `Neuron.is_masked` | {{ autosummary("navis.BaseNeuron.is_masked") }} |
9293

9394
!!! note
9495

@@ -119,6 +120,8 @@ this neuron type. Note that most of them are simply short-hands for the other
119120
| [`TreeNeuron.reroot()`][navis.TreeNeuron.reroot] | {{ autosummary("navis.TreeNeuron.reroot") }} |
120121
| [`TreeNeuron.resample()`][navis.TreeNeuron.resample] | {{ autosummary("navis.TreeNeuron.resample") }} |
121122
| [`TreeNeuron.snap()`][navis.TreeNeuron.snap] | {{ autosummary("navis.TreeNeuron.snap") }} |
123+
| [`TreeNeuron.mask()`][navis.TreeNeuron.mask] | {{ autosummary("navis.TreeNeuron.mask") }} |
124+
| [`TreeNeuron.unmask()`][navis.TreeNeuron.unmask] | {{ autosummary("navis.TreeNeuron.unmask") }} |
122125

123126
In addition, a [`navis.TreeNeuron`][] has a range of different properties:
124127

@@ -146,7 +149,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties:
146149
| [`TreeNeuron.vertices`][navis.TreeNeuron.vertices] | {{ autosummary("navis.TreeNeuron.vertices") }} |
147150
| [`TreeNeuron.volume`][navis.TreeNeuron.volume] | {{ autosummary("navis.TreeNeuron.volume") }} |
148151

149-
150152
#### Skeleton utility functions
151153

152154
| Function | Description |
@@ -158,7 +160,6 @@ In addition, a [`navis.TreeNeuron`][] has a range of different properties:
158160
| [`navis.graph.skeleton_adjacency_matrix()`][navis.graph.skeleton_adjacency_matrix] | {{ autosummary("navis.graph.skeleton_adjacency_matrix") }} |
159161

160162

161-
162163
### Mesh neurons
163164

164165
Properties specific to [`navis.MeshNeuron`][]:
@@ -178,6 +179,8 @@ Methods specific to [`navis.MeshNeuron`][]:
178179
| [`MeshNeuron.skeletonize()`][navis.MeshNeuron.skeletonize] | {{ autosummary("navis.MeshNeuron.skeletonize") }} |
179180
| [`MeshNeuron.snap()`][navis.MeshNeuron.snap] | {{ autosummary("navis.MeshNeuron.snap") }} |
180181
| [`MeshNeuron.validate()`][navis.MeshNeuron.validate] | {{ autosummary("navis.MeshNeuron.validate") }} |
182+
| [`MeshNeuron.mask()`][navis.MeshNeuron.mask] | {{ autosummary("navis.MeshNeuron.mask") }} |
183+
| [`MeshNeuron.unmask()`][navis.MeshNeuron.unmask] | {{ autosummary("navis.MeshNeuron.unmask") }} |
181184

182185

183186
### Voxel neurons
@@ -215,6 +218,8 @@ These are methods and properties specific to [Dotprops][navis.Dotprops]:
215218
| [`Dotprops.alpha`][navis.Dotprops.alpha] | {{ autosummary("navis.Dotprops.alpha") }} |
216219
| [`Dotprops.to_skeleton()`][navis.Dotprops.to_skeleton] | {{ autosummary("navis.Dotprops.to_skeleton") }} |
217220
| [`Dotprops.snap()`][navis.Dotprops.snap] | {{ autosummary("navis.Dotprops.snap") }} |
221+
| [`Dotprops.mask()`][navis.Dotprops.mask] | {{ autosummary("navis.Dotprops.mask") }} |
222+
| [`Dotprops.unmask()`][navis.Dotprops.unmask] | {{ autosummary("navis.Dotprops.unmask") }} |
218223

219224
### Converting between types
220225

navis/core/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,22 @@
1818
from .dotprop import Dotprops
1919
from .voxel import VoxelNeuron
2020
from .neuronlist import NeuronList
21+
from .masking import NeuronMask
2122
from .core_utils import make_dotprops, to_neuron_space, NeuronProcessor
2223

2324
from typing import Union
2425

2526
NeuronObject = Union[NeuronList, TreeNeuron, BaseNeuron, MeshNeuron]
2627

27-
__all__ = ['Volume', 'Neuron', 'BaseNeuron', 'TreeNeuron', 'MeshNeuron',
28-
'Dotprops', 'VoxelNeuron', 'NeuronList', 'make_dotprops']
28+
__all__ = [
29+
"Volume",
30+
"Neuron",
31+
"BaseNeuron",
32+
"TreeNeuron",
33+
"MeshNeuron",
34+
"NeuronMask",
35+
"Dotprops",
36+
"VoxelNeuron",
37+
"NeuronList",
38+
"make_dotprops",
39+
]

navis/core/base.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646

4747

4848
def Neuron(
49-
x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"], **metadata
49+
x: Union[nx.DiGraph, str, pd.DataFrame, "TreeNeuron", "MeshNeuron"],
50+
**metadata, # noqa: F821
5051
):
5152
"""Constructor for Neuron objects. Depending on the input, either a
5253
`TreeNeuron` or a `MeshNeuron` is returned.
@@ -195,6 +196,9 @@ class BaseNeuron(UnitObject):
195196
#: Core data table(s) used to calculate hash
196197
CORE_DATA = []
197198

199+
#: Property used to calculate length of neuron
200+
_LENGTH_DATA = None
201+
198202
def __init__(self, **kwargs):
199203
# Set a random ID -> may be replaced later
200204
self.id = uuid.uuid4()
@@ -303,6 +307,14 @@ def __isub__(self, other):
303307
"""Subtraction with assignment (-=)."""
304308
return self.__sub__(other, copy=False)
305309

310+
def __len__(self):
311+
if self._LENGTH_DATA is None:
312+
return None
313+
# Deal with potential empty neurons
314+
if not hasattr(self, self._LENGTH_DATA):
315+
return 0
316+
return len(getattr(self, self._LENGTH_DATA))
317+
306318
def _repr_html_(self):
307319
frame = self.summary().to_frame()
308320
frame.columns = [""]
@@ -654,6 +666,7 @@ def copy(self, deepcopy=False) -> "BaseNeuron":
654666

655667
def summary(self, add_props=None) -> pd.Series:
656668
"""Get a summary of this neuron."""
669+
657670
# Do not remove the list -> otherwise we might change the original!
658671
props = list(self.SUMMARY_PROPS)
659672

@@ -721,6 +734,87 @@ def plot3d(self, **kwargs):
721734

722735
return plot3d(core.NeuronList(self, make_copy=False), **kwargs)
723736

737+
@property
738+
def is_masked(self):
739+
"""Test if neuron is masked.
740+
741+
See Also
742+
--------
743+
[`navis.BaseNeuron.mask`][]
744+
Mask neuron.
745+
[`navis.BaseNeuron.unmask`][]
746+
Remove mask from neuron.
747+
[`navis.NeuronMask`][]
748+
Context manager for masking neurons.
749+
"""
750+
return hasattr(self, "_masked_data")
751+
752+
def mask(self, mask):
753+
"""Mask neuron."""
754+
raise NotImplementedError(
755+
f"Masking not implemented for neuron of type {type(self)}."
756+
)
757+
758+
def unmask(self):
759+
"""Unmask neuron.
760+
761+
Returns the neuron to its original state before masking.
762+
763+
Returns
764+
-------
765+
self
766+
767+
See Also
768+
--------
769+
[`Neuron.is_masked`][navis.BaseNeuron.is_masked]
770+
Check if neuron. is masked.
771+
[`Neuron.mask`][navis.BaseNeuron.unmask]
772+
Mask neuron.
773+
[`navis.NeuronMask`][]
774+
Context manager for masking neurons.
775+
776+
"""
777+
if not self.is_masked:
778+
raise ValueError("Neuron is not masked.")
779+
780+
for k, v in self._masked_data.items():
781+
if hasattr(self, k):
782+
setattr(self, k, v)
783+
784+
delattr(self, "_mask")
785+
delattr(self, "_masked_data")
786+
self._clear_temp_attr()
787+
788+
return self
789+
790+
def apply_mask(self, inplace=False):
791+
"""Apply mask to neuron.
792+
793+
This will effectively make the mask permanent.
794+
795+
Parameters
796+
----------
797+
inplace : bool
798+
If True will apply mask in-place. If False
799+
will return a copy and the original neuron
800+
will remain masked.
801+
802+
Returns
803+
-------
804+
Neuron
805+
Neuron with mask applied.
806+
807+
"""
808+
if not self.is_masked:
809+
raise ValueError("Neuron is not masked.")
810+
811+
n = self if inplace else self.copy()
812+
813+
delattr(n, "_mask")
814+
delattr(n, "_masked_data")
815+
816+
return n
817+
724818
def map_units(
725819
self,
726820
units: Union[pint.Unit, str],

navis/core/dotprop.py

Lines changed: 135 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,14 @@ class Dotprops(BaseNeuron):
105105
EQ_ATTRIBUTES = ['name', 'n_points', 'k']
106106

107107
#: Temporary attributes that need clearing when neuron data changes
108-
TEMP_ATTR = ['_memory_usage']
108+
TEMP_ATTR = ['_memory_usage', "_tree"]
109109

110110
#: Core data table(s) used to calculate hash
111111
_CORE_DATA = ['points', 'vect']
112112

113+
#: Property used to calculate length of neuron
114+
_LENGTH_DATA = 'points'
115+
113116
def __init__(self,
114117
points: np.ndarray,
115118
k: int,
@@ -230,9 +233,6 @@ def __getstate__(self):
230233

231234
return state
232235

233-
def __len__(self):
234-
return len(self.points)
235-
236236
@property
237237
def alpha(self):
238238
"""Alpha value for tangent vectors (optional)."""
@@ -539,6 +539,137 @@ def drop_fluff(self, epsilon, keep_size: int = None, n_largest: int = None, inpl
539539
if not inplace:
540540
return x
541541

542+
def mask(self, mask, copy=True):
543+
"""Mask neuron with given mask.
544+
545+
This is always done in-place!
546+
547+
Parameters
548+
----------
549+
mask : np.ndarray
550+
Mask to apply. Can be:
551+
- 1D array with boolean values
552+
- callable that accepts a neuron and returns a mask
553+
- string with property name
554+
555+
Returns
556+
-------
557+
self
558+
The masked neuron.
559+
560+
See Also
561+
--------
562+
[`Dotprops.unmask`][navis.Dotprops.unmask]
563+
Remove mask from neuron.
564+
[`Dotprops.is_masked`][navis.Dotprops.is_masked]
565+
Check if neuron is masked.
566+
[`navis.NeuronMask`][]
567+
Context manager for masking neurons.
568+
569+
"""
570+
if self.is_masked:
571+
raise ValueError(
572+
"Neuron already masked. Layering multiple masks is currently not supported, please unmask first."
573+
)
574+
575+
if callable(mask):
576+
mask = mask(self)
577+
elif isinstance(mask, str):
578+
mask = getattr(self, mask)
579+
580+
mask = np.asarray(mask)
581+
582+
if mask.dtype != bool:
583+
raise ValueError("Mask must be boolean array.")
584+
elif mask.shape[0] != len(self):
585+
raise ValueError("Mask must have same length as points.")
586+
587+
self._mask = mask
588+
self._masked_data = {}
589+
self._masked_data['_points'] = self.points
590+
591+
# Drop soma if masked out
592+
if self.soma is not None:
593+
if isinstance(self.soma, (list, np.ndarray)):
594+
soma_left = self.soma[mask[self.soma]]
595+
self._masked_data['_soma'] = self.soma
596+
597+
if any(soma_left):
598+
self.soma = soma_left
599+
else:
600+
self.soma = None
601+
elif not mask[self.soma]:
602+
self._masked_data['_soma'] = self.soma
603+
self.soma = None
604+
605+
# N.B. we're directly setting `._nodes`` to avoid overhead from checks
606+
for att in ("_points", "_vect", "_alpha"):
607+
if hasattr(self, att):
608+
self._masked_data[att] = getattr(self, att)
609+
setattr(self, att, getattr(self, att)[mask])
610+
611+
if copy:
612+
setattr(self, att, getattr(self, att).copy())
613+
614+
if hasattr(self, "_connectors") and "point_ix" in self._connectors.columns:
615+
self._masked_data['connectors'] = self.connectors
616+
self._connectors = self._connectors.loc[
617+
self.connectors.point_ix.isin(np.arange(len(mask))[mask])
618+
]
619+
if copy:
620+
self._connectors = self._connectors.copy()
621+
622+
self._clear_temp_attr()
623+
624+
return self
625+
626+
def unmask(self, reset=True):
627+
"""Unmask neuron.
628+
629+
Returns the neuron to its original state before masking.
630+
631+
Parameters
632+
----------
633+
reset : bool
634+
Whether to reset the neuron to its original state before masking.
635+
If False, edits made to the neuron after masking will be kept.
636+
637+
Returns
638+
-------
639+
self
640+
641+
See Also
642+
--------
643+
[`Dotprops.is_masked`][navis.Dotprops.is_masked]
644+
Check if neuron is masked.
645+
[`Dotprops.mask`][navis.Dotprops.mask]
646+
Mask neuron.
647+
[`navis.NeuronMask`][]
648+
Context manager for masking neurons.
649+
650+
"""
651+
if not self.is_masked:
652+
raise ValueError("Neuron is not masked.")
653+
654+
if reset:
655+
# Unmask and reset to original state
656+
super().unmask()
657+
return self
658+
659+
mask = self._mask
660+
for k, v in self._masked_data.items():
661+
# Combine with current data
662+
if hasattr(self, k):
663+
v = np.concatenate((v[~mask], getattr(self, k)), axis=0)
664+
setattr(self, k, v)
665+
666+
del self._mask
667+
del self._masked_data
668+
669+
self._clear_temp_attr()
670+
671+
return self
672+
542673
def recalculate_tangents(self, k: int, inplace=False):
543674
"""Recalculate tangent vectors and alpha with a new `k`.
544675

0 commit comments

Comments
 (0)