From 15e634f11a0146129fb0db9f422ef3fe7b3cc03d Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Fri, 25 Oct 2024 17:41:25 +0200 Subject: [PATCH 1/2] Use SolutionTracks when relabeling the segmentation by track --- .../data_views/views/layers/track_graph.py | 24 +++++++++---------- .../views_coordinator/tracks_viewer.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/motile_plugin/data_views/views/layers/track_graph.py b/src/motile_plugin/data_views/views/layers/track_graph.py index fd1c810..adcd409 100644 --- a/src/motile_plugin/data_views/views/layers/track_graph.py +++ b/src/motile_plugin/data_views/views/layers/track_graph.py @@ -4,22 +4,21 @@ from typing import TYPE_CHECKING import napari -import networkx as nx import numpy as np -from motile_toolbox.candidate_graph import NodeAttr if TYPE_CHECKING: + from motile_plugin.data_model.solution_tracks import SolutionTracks from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer def update_napari_tracks( - graph: nx.DiGraph, + tracks: SolutionTracks, ): """Function to take a networkx graph with assigned track_ids and return the data needed to add to a napari tracks layer. Args: - graph (nx.DiGraph): graph that already has track_ids, position, and time assigned to the nodes. + tracks (SolutionTracks): tracks that have track_ids and have a tree structure Returns: data: array (N, D+1) @@ -34,7 +33,8 @@ def update_napari_tracks( parents, but only one child) in the case of track merging. """ - ndim = len(graph.nodes[next(iter(graph.nodes))][NodeAttr.POS.value]) + ndim = tracks.ndim - 1 + graph = tracks.graph napari_data = np.zeros((graph.number_of_nodes(), ndim + 2)) napari_edges = {} @@ -51,16 +51,16 @@ def update_napari_tracks( for index, node in enumerate(graph.nodes(data=True)): node_id, data = node - location = graph.nodes[node_id][NodeAttr.POS.value] + location = tracks.get_position(node_id) napari_data[index] = [ - data[NodeAttr.TRACK_ID.value], - data[NodeAttr.TIME.value], + tracks.get_track_id(node_id), + tracks.get_time(node_id), *location, ] for parent, child in intertrack_edges: - parent_track_id = graph.nodes[parent][NodeAttr.TRACK_ID.value] - child_track_id = graph.nodes[child][NodeAttr.TRACK_ID.value] + parent_track_id = tracks.get_track_id(parent) + child_track_id = tracks.get_track_id(child) if child_track_id in napari_edges: napari_edges[child_track_id].append(parent_track_id) else: @@ -80,7 +80,7 @@ def __init__( ): self.tracks_viewer = tracks_viewer track_data, track_edges = update_napari_tracks( - self.tracks_viewer.tracks.graph, + self.tracks_viewer.tracks, ) super().__init__( @@ -99,7 +99,7 @@ def _refresh(self): """Refreshes the displayed tracks based on the graph in the current tracks_viewer.tracks""" track_data, track_edges = update_napari_tracks( - self.tracks_viewer.tracks.graph, + self.tracks_viewer.tracks, ) self.data = track_data diff --git a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py index 3f26fcd..6a3bada 100644 --- a/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py +++ b/src/motile_plugin/data_views/views_coordinator/tracks_viewer.py @@ -210,7 +210,7 @@ def view_external_tracks(self, tracks: SolutionTracks, name: str) -> None: tracks (Tracks): A tracks object to view, created externally from the plugin name (str): The name to display in napari layers """ - tracks.segmentation = np.expand_dims(tracks.segmentation, axis=1) + # tracks.segmentation = np.expand_dims(tracks.segmentation, axis=1) tracks.segmentation = relabel_segmentation(tracks.graph, tracks.segmentation) self.update_tracks(tracks, name) From 7147eb7d50cf58b91834fbbf8e00c92bcf160f40 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 4 Nov 2024 14:59:28 -0500 Subject: [PATCH 2/2] Speed up points layer inititalization --- .../data_views/views/layers/track_points.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/motile_plugin/data_views/views/layers/track_points.py b/src/motile_plugin/data_views/views/layers/track_points.py index 85f2f76..39e9100 100644 --- a/src/motile_plugin/data_views/views/layers/track_points.py +++ b/src/motile_plugin/data_views/views/layers/track_points.py @@ -28,13 +28,7 @@ def __init__( ): self.tracks_viewer = tracks_viewer self.nodes = list(tracks_viewer.tracks.graph.nodes) - self.node_index_dict = dict( - zip( - self.nodes, - [self.nodes.index(node) for node in self.nodes], - strict=False, - ) - ) + self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)} points = self.tracks_viewer.tracks.get_positions(self.nodes, incl_time=True) track_ids = [ @@ -104,13 +98,7 @@ def _refresh(self): ) # do not listen to new events until updates are complete self.nodes = list(self.tracks_viewer.tracks.graph.nodes) - self.node_index_dict = dict( - zip( - self.nodes, - [self.nodes.index(node) for node in self.nodes], - strict=False, - ) - ) + self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)} track_ids = [ self.tracks_viewer.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value]