Skip to content

Commit

Permalink
Merge pull request #104 from funkelab/bugfix_external_tracks
Browse files Browse the repository at this point in the history
Bugfix external tracks: allow different attribute names
  • Loading branch information
cmalinmayor authored Nov 4, 2024
2 parents 240db32 + 7147eb7 commit 377d603
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 27 deletions.
24 changes: 12 additions & 12 deletions src/motile_plugin/data_views/views/layers/track_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
from typing import TYPE_CHECKING

import napari
import networkx as nx
import numpy as np
from motile_toolbox.candidate_graph import NodeAttr

if TYPE_CHECKING:
from motile_plugin.data_model.solution_tracks import SolutionTracks
from motile_plugin.data_views.views_coordinator.tracks_viewer import TracksViewer


def update_napari_tracks(
graph: nx.DiGraph,
tracks: SolutionTracks,
):
"""Function to take a networkx graph with assigned track_ids and return the data needed to add to
a napari tracks layer.
Args:
graph (nx.DiGraph): graph that already has track_ids, position, and time assigned to the nodes.
tracks (SolutionTracks): tracks that have track_ids and have a tree structure
Returns:
data: array (N, D+1)
Expand All @@ -34,7 +33,8 @@ def update_napari_tracks(
parents, but only one child) in the case of track merging.
"""

ndim = len(graph.nodes[next(iter(graph.nodes))][NodeAttr.POS.value])
ndim = tracks.ndim - 1
graph = tracks.graph
napari_data = np.zeros((graph.number_of_nodes(), ndim + 2))
napari_edges = {}

Expand All @@ -51,16 +51,16 @@ def update_napari_tracks(

for index, node in enumerate(graph.nodes(data=True)):
node_id, data = node
location = graph.nodes[node_id][NodeAttr.POS.value]
location = tracks.get_position(node_id)
napari_data[index] = [
data[NodeAttr.TRACK_ID.value],
data[NodeAttr.TIME.value],
tracks.get_track_id(node_id),
tracks.get_time(node_id),
*location,
]

for parent, child in intertrack_edges:
parent_track_id = graph.nodes[parent][NodeAttr.TRACK_ID.value]
child_track_id = graph.nodes[child][NodeAttr.TRACK_ID.value]
parent_track_id = tracks.get_track_id(parent)
child_track_id = tracks.get_track_id(child)
if child_track_id in napari_edges:
napari_edges[child_track_id].append(parent_track_id)
else:
Expand All @@ -80,7 +80,7 @@ def __init__(
):
self.tracks_viewer = tracks_viewer
track_data, track_edges = update_napari_tracks(
self.tracks_viewer.tracks.graph,
self.tracks_viewer.tracks,
)

super().__init__(
Expand All @@ -99,7 +99,7 @@ def _refresh(self):
"""Refreshes the displayed tracks based on the graph in the current tracks_viewer.tracks"""

track_data, track_edges = update_napari_tracks(
self.tracks_viewer.tracks.graph,
self.tracks_viewer.tracks,
)

self.data = track_data
Expand Down
16 changes: 2 additions & 14 deletions src/motile_plugin/data_views/views/layers/track_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,7 @@ def __init__(
):
self.tracks_viewer = tracks_viewer
self.nodes = list(tracks_viewer.tracks.graph.nodes)
self.node_index_dict = dict(
zip(
self.nodes,
[self.nodes.index(node) for node in self.nodes],
strict=False,
)
)
self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)}

points = self.tracks_viewer.tracks.get_positions(self.nodes, incl_time=True)
track_ids = [
Expand Down Expand Up @@ -104,13 +98,7 @@ def _refresh(self):
) # do not listen to new events until updates are complete
self.nodes = list(self.tracks_viewer.tracks.graph.nodes)

self.node_index_dict = dict(
zip(
self.nodes,
[self.nodes.index(node) for node in self.nodes],
strict=False,
)
)
self.node_index_dict = {node: idx for idx, node in enumerate(self.nodes)}

track_ids = [
self.tracks_viewer.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def view_external_tracks(self, tracks: SolutionTracks, name: str) -> None:
tracks (Tracks): A tracks object to view, created externally from the plugin
name (str): The name to display in napari layers
"""
tracks.segmentation = np.expand_dims(tracks.segmentation, axis=1)
# tracks.segmentation = np.expand_dims(tracks.segmentation, axis=1)
tracks.segmentation = relabel_segmentation(tracks.graph, tracks.segmentation)
self.update_tracks(tracks, name)

Expand Down

0 comments on commit 377d603

Please sign in to comment.