Skip to content

Commit

Permalink
Add 4D functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
Mallory Wittwer committed Aug 22, 2024
1 parent 5a36449 commit 012a645
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 97 deletions.
119 changes: 71 additions & 48 deletions src/napari_label_focus/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,34 @@
import pandas as pd
import skimage.measure
from qtpy.QtWidgets import (
QFileDialog,
QGridLayout,
QHBoxLayout,
QPushButton,
QTableWidget,
QTableWidgetItem,
QWidget,
)


class Table(QWidget):
def __init__(self, layer: napari.layers.Layer = None, viewer: napari.Viewer = None):
def __init__(
self, layer: napari.layers.Layer = None, viewer: napari.Viewer = None
):
super().__init__()
self._layer = layer
self._labels_layer = layer
self._viewer = viewer
self._view = QTableWidget()
self._view.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
self._view.setColumnCount(2)
self._view.setRowCount(1)
self._view.setColumnWidth(0, 30)
self._view.setColumnWidth(1, 120)
self._view.setHorizontalHeaderItem(0, QTableWidgetItem('label'))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem('volume'))
self._view.setHorizontalHeaderItem(0, QTableWidgetItem("label"))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem("volume"))
self._view.clicked.connect(self._clicked_table)

save_button = QPushButton("Save as CSV")
save_button.clicked.connect(lambda _: self._save_csv())

self.setLayout(QGridLayout())
action_widget = QWidget()
action_widget.setLayout(QHBoxLayout())
action_widget.layout().addWidget(save_button)
self.layout().addWidget(action_widget)
self.layout().addWidget(self._view)
action_widget.layout().setSpacing(3)
Expand All @@ -46,35 +43,58 @@ def __init__(self, layer: napari.layers.Layer = None, viewer: napari.Viewer = No
def axes(self):
if self._viewer.dims.ndisplay == 3:
return

# 2D case
axes = list(self._viewer.dims.displayed)

# 3D case
if self._layer.data.ndim == 3:
axes.insert(0, list(set([0, 1, 2]) - set(list(self._viewer.dims.displayed)))[0])
if self._labels_layer.data.ndim == 3:
axes.insert(
0,
list(set([0, 1, 2]) - set(list(self._viewer.dims.displayed)))[
0
],
)

# 4D case (not used yet)
elif self._layer.data.ndim == 4:
elif self._labels_layer.data.ndim == 4:
xxx = set(self._viewer.dims.displayed)
to_add = list(set([0, 1, 2, 3]) - xxx)
axes = to_add + axes

return axes

def _clicked_table(self):
row = self._view.currentRow()
if self._layer is None:
if self._labels_layer is None:
return

selected_table_label = self.df["label"].values[row]

self.handle_selected_table_label_changed(selected_table_label)

def handle_selected_table_label_changed(self, selected_table_label):

if not selected_table_label in self.df['label'].unique():
print(f"Label {selected_table_label} is not present.")
return

self._layer.selected_label = self.df["label"].values[row]
self._labels_layer.selected_label = selected_table_label

x0 = int(self.df["bbox-0"].values[row])
x1 = int(self.df["bbox-3"].values[row])
y0 = int(self.df["bbox-1"].values[row])
y1 = int(self.df["bbox-4"].values[row])
z0 = int(self.df["bbox-2"].values[row])
z1 = int(self.df["bbox-5"].values[row])
sub_df = self.df[self.df['label'] == selected_table_label]

x0 = int(sub_df['bbox-0'].values[0])
x1 = int(sub_df['bbox-3'].values[0])
y0 = int(sub_df['bbox-1'].values[0])
y1 = int(sub_df['bbox-4'].values[0])
z0 = int(sub_df['bbox-2'].values[0])
z1 = int(sub_df['bbox-5'].values[0])
# x0 = int(self.df["bbox-0"].values[row])
# x1 = int(self.df["bbox-3"].values[row])
# y0 = int(self.df["bbox-1"].values[row])
# y1 = int(self.df["bbox-4"].values[row])
# z0 = int(self.df["bbox-2"].values[row])
# z1 = int(self.df["bbox-5"].values[row])

label_size = max(x1 - x0, y1 - y0, z1 - z0)

Expand All @@ -89,47 +109,48 @@ def _clicked_table(self):

if len(self.axes) == 2:
current_center[1] = centers[1:][self.axes][0]
current_center[2] = centers[1:][self.axes][1]
current_center[2] = centers[1:][self.axes][1]
elif len(self.axes) == 3:
current_center[1] = centers[self.axes[1]]
current_center[2] = centers[self.axes[2]]
# In 3D, also adjust the current step
current_step = np.array(self._viewer.dims.current_step)[self.axes]
current_step = np.array(self._viewer.dims.current_step)[
self.axes
]
current_step[self.axes[0]] = int(centers[self.axes[0]])
self._viewer.dims.current_step = tuple(current_step)

elif len(self.axes) == 4:
print("4D case not implemented yet.")
# TODO - This is very experimental (probably not working when layers are transposed)
current_center[1] = centers[self.axes[2]-1]
current_center[2] = centers[self.axes[3]-1]
current_step = np.array(self._viewer.dims.current_step)[
self.axes
]
current_step[self.axes[1]] = int(centers[self.axes[1]-1])
self._viewer.dims.current_step = tuple(current_step)

self._viewer.camera.center = tuple(current_center)

self._viewer.camera.zoom = max(3 - label_size * 0.005, 1.0)

def _save_csv(self):
if self._layer is None:
return

filename, _ = QFileDialog.getSaveFileName(
self, "Save as CSV", ".", "*.csv"
)

pd.DataFrame(self.df).to_csv(filename)

def updated_content_2D_or_3D(self, labels):
"""Compute volumes and update the table UI in the 2D and 3D cases."""
properties = skimage.measure.regionprops_table(labels, properties=["label", "area", "bbox"])
properties = skimage.measure.regionprops_table(
labels, properties=["label", "area", "bbox"]
)
self.df = pd.DataFrame.from_dict(properties)
self.df.rename(columns={"area": "volume"}, inplace=True)
self.df.sort_values(by="volume", ascending=False, inplace=True)

# Regenerate the table UI
self._view.clear()
self._view.setRowCount(len(self.df))
self._view.setHorizontalHeaderItem(0, QTableWidgetItem('label'))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem('volume'))
self._view.setHorizontalHeaderItem(0, QTableWidgetItem("label"))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem("volume"))

k = 0
for _, (lab, vol) in self.df[['label', 'volume']].iterrows():
for _, (lab, vol) in self.df[["label", "volume"]].iterrows():
self._view.setItem(k, 0, QTableWidgetItem(str(lab)))
self._view.setItem(k, 1, QTableWidgetItem(str(vol)))
k += 1
Expand All @@ -139,24 +160,26 @@ def handle_time_axis_changed(self, event, source_layer):
if (current_time != self.current_time) | (self.current_time is None):
# This gets called multiple times when moving forward in time. Why?
self.current_time = current_time
self.update_content(source_layer)
current_selected_label = self._labels_layer.selected_label
self.update_table_content(source_layer)
self.handle_selected_table_label_changed(current_selected_label)

def update_content(self, layer: napari.layers.Labels):
self._layer = layer
if self._layer is None:
def update_table_content(self, labels_layer: napari.layers.Labels):
self._labels_layer = labels_layer
if self._labels_layer is None:
self._view.clear()
self._view.setRowCount(1)
self._view.setColumnWidth(0, 30)
self._view.setColumnWidth(1, 120)
self._view.setHorizontalHeaderItem(0, QTableWidgetItem('label'))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem('volume'))
self._view.setHorizontalHeaderItem(0, QTableWidgetItem("label"))
self._view.setHorizontalHeaderItem(1, QTableWidgetItem("volume"))
return

labels = self._layer.data#.copy()
labels = self._labels_layer.data

if len(labels.shape) == 2:
labels = labels[None] # Add an extra dimension in the 2D case

elif len(labels.shape) == 4:
labels = labels[self._viewer.dims.current_step[0]]

Expand Down
89 changes: 40 additions & 49 deletions src/napari_label_focus/_widget.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,57 @@
import napari.layers
from qtpy.QtWidgets import QComboBox, QGridLayout, QWidget

import napari.layers

from qtpy.QtWidgets import QGridLayout, QWidget
from ._table import Table


class TableWidget(QWidget):
def __init__(self, napari_viewer):
super().__init__()
self.viewer = napari_viewer

self.labels_layer = None
self.selected_labels_layer = None

self.setLayout(QGridLayout())
self.cb = QComboBox()
self.cb.currentTextChanged.connect(self._on_cb_change)
self.layout().addWidget(self.cb, 0, 0)

self.table = Table(viewer=self.viewer)
self.layout().addWidget(self.table, 1, 0)
self.layout().addWidget(self.table, 0, 0)

self.viewer.layers.events.inserted.connect(
lambda e: e.value.events.name.connect(self._on_layer_change)
self.viewer.layers.selection.events.changed.connect(
self._on_layer_selection_changed
)
self.viewer.layers.events.inserted.connect(self._on_layer_change)
self.viewer.layers.events.removed.connect(self._on_layer_change)
self._on_layer_change(None)

def _on_layer_change(self, e):
self.cb.clear()
for x in self.viewer.layers:
if isinstance(x, napari.layers.Labels):
if len(x.data.shape) in [2, 3, 4]: # Only 2D-4D data are supported.
self.cb.addItem(x.name, x.data)

def _on_cb_change(self, selection: str):
if selection == '':
self.table.update_content(None)
return

selected_layer = self.viewer.layers[selection]
def _on_layer_selection_changed(self, event):
selected_layer = event.source.active
if not isinstance(selected_layer, napari.layers.Labels):
return

if self.labels_layer is not None:
# self.labels_layer.events.labels_update.disconnect(lambda _: self.table.update_content(self.labels_layer))
self.labels_layer.events.paint.disconnect(lambda _: self.table.update_content(self.labels_layer))
self.labels_layer.events.data.disconnect(lambda _: self.table.update_content(self.labels_layer))

# Updating live as pixels are drawn is too expensive (labels_update)
# selected_layer.events.labels_update.connect(lambda _: self.table.update_content(selected_layer))

# Not sure what this does - it's probably useful
selected_layer.events.data.connect(lambda _: self.table.update_content(selected_layer))

# Instead we update the table only when the mouse is up after drawing.
selected_layer.events.paint.connect(lambda _: self.table.update_content(selected_layer))

# Temporary
if self.selected_labels_layer is not None:
self.selected_labels_layer.events.paint.disconnect(
lambda _: self.table.update_table_content(
self.selected_labels_layer
)
)
self.selected_labels_layer.events.data.disconnect(
lambda _: self.table.update_table_content(
self.selected_labels_layer
)
)
if selected_layer.data.ndim == 4:
self.viewer.dims.events.current_step.disconnect(
lambda e: self.table.handle_time_axis_changed(
e, self.selected_labels_layer
)
)

selected_layer.events.data.connect(
lambda _: self.table.update_table_content(selected_layer)
)
selected_layer.events.paint.connect(
lambda _: self.table.update_table_content(selected_layer)
)
if selected_layer.data.ndim == 4:
self.viewer.dims.events.current_step.connect(lambda e: self.table.handle_time_axis_changed(e, selected_layer))

self.labels_layer = selected_layer

self.table.update_content(selected_layer)
self.viewer.dims.events.current_step.connect(
lambda e: self.table.handle_time_axis_changed(
e, selected_layer
)
)

self.selected_labels_layer = selected_layer
self.table.update_table_content(selected_layer)

0 comments on commit 012a645

Please sign in to comment.