Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions src/silx/gui/data/DataViews.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy
import os

from silx.gui.data.NXdataWidgets import ArrayImagePlot
import silx.io
from silx.gui import qt, icons
from silx.gui.data.TextFormatter import TextFormatter
Expand Down Expand Up @@ -1760,27 +1761,15 @@ def setData(self, data):

self._updateColormap(nxd)

# last two axes are Y & X
img_slicing = slice(-2, None) if not isRgba else slice(-3, -1)
y_axis, x_axis = nxd.axes[img_slicing]
y_label, x_label = nxd.axes_names[img_slicing]
y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing]
x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None

self.getWidget().setImageData(
widget: ArrayImagePlot = self.getWidget()
widget.setImageData(
[nxd.signal] + nxd.auxiliary_signals,
x_axis=x_axis,
y_axis=y_axis,
axes=nxd.axes,
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
axes_names=nxd.axes_names,
xlabel=x_label,
ylabel=y_label,
axes_scales=nxd.plot_style.axes_scale_types,
title=nxd.title,
isRgba=isRgba,
xscale=x_scale,
yscale=y_scale,
keep_ratio=(x_units == y_units),
)

def getDataPriority(self, data, info: DataInfo):
Expand Down
86 changes: 46 additions & 40 deletions src/silx/gui/data/NXdataWidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from silx.gui.colors import Colormap
from silx.gui.data._SignalSelector import SignalSelector

from silx.io.nxdata._utils import get_attr_as_unicode
from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration


Expand Down Expand Up @@ -388,10 +389,7 @@ def __init__(self, parent=None):

self.__signals = None
self.__signals_names = None
self.__x_axis = None
self.__x_axis_name = None
self.__y_axis = None
self.__y_axis_name = None
self.__resetZoomNextTime = True

self._plot = Plot2D(self)
self._plot.setDefaultColormap(
Expand All @@ -404,10 +402,9 @@ def __init__(self, parent=None):
maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
maskToolWidget.setItemMaskUpdated(True)

# not closable
self._axesSelector = NumpyAxesSelector(self)
self._axesSelector.setNamedAxesSelectorVisibility(False)
self._axesSelector.selectionChanged.connect(self._updateImage)
self._axesSelector.selectedAxisChanged.connect(self._clearImage)

self._signalSelector = SignalSelector(parent=self)
self._signalSelector.selectionChanged.connect(self._signalChanges)
Expand Down Expand Up @@ -454,47 +451,31 @@ def getPlot(self):
def setImageData(
self,
signals,
x_axis=None,
y_axis=None,
axes=None,
signals_names=None,
axes_names=None,
xlabel=None,
ylabel=None,
axes_scales=None
title=None,
isRgba=False,
xscale=None,
yscale=None,
keep_ratio: bool = True,
):
"""

:param signals: list of n-D datasets, whose last 2 dimensions are used as the
image's values, or list of 3D datasets interpreted as RGBA image.
:param x_axis: 1-D dataset used as the image's x coordinates. If
provided, its lengths must be equal to the length of the last
dimension of ``signal``.
:param y_axis: 1-D dataset used as the image's y. If provided,
its lengths must be equal to the length of the 2nd to last
dimension of ``signal``.
:param signals_names: Names for each image, used as subtitle and legend.
:param xlabel: Label for X axis
:param ylabel: Label for Y axis
:param title: Graph title
:param isRgba: True if data is a 3D RGBA image
:param str xscale: Scale of X axis in (None, 'linear', 'log')
:param str yscale: Scale of Y axis in (None, 'linear', 'log')
:param keep_ratio: Toggle plot keep aspect ratio
:param str axes_scales: Scale of axes in (None, 'linear', 'log')
"""
self._axesSelector.selectionChanged.disconnect(self._updateImage)
self._axesSelector.selectedAxisChanged.disconnect(self._clearImage)
self._signalSelector.selectionChanged.disconnect(self._signalChanges)

self.__signals = signals
self.__signals_names = signals_names
self.__axis_names = axes_names
self.__x_axis = x_axis
self.__x_axis_name = xlabel
self.__y_axis = y_axis
self.__y_axis_name = ylabel
self.__axes = axes
self.__axes_names = axes_names
self.__axes_scales = axes_scales
self.__title = title

self._axesSelector.clear()
Expand All @@ -511,8 +492,8 @@ def setImageData(
else:
self._axesSelector.show()

if self.__axis_names:
self._axesSelector.setLabels(self.__axis_names)
if self.__axes_names:
self._axesSelector.setLabels(self.__axes_names)

self._signalSelector.setSignalNames(signals_names)
if len(signals) > 1:
Expand All @@ -521,26 +502,38 @@ def setImageData(
self._signalSelector.hide()
self._signalSelector.setSignalIndex(0)

self._axis_scales = xscale, yscale

self._axesSelector.selectionChanged.connect(self._updateImage)
self._axesSelector.selectedAxisChanged.connect(self._clearImage)
self._signalSelector.selectionChanged.connect(self._signalChanges)

self._updateImage()
self._plot.setKeepDataAspectRatio(keep_ratio)
self._plot.resetZoom()

def _updateImage(self):
axes_selection = self._axesSelector.selection()
signal_index = self._signalSelector.getSignalIndex()
print(f"{axes_selection}, {signal_index}")

legend = self.__signals_names[signal_index]

images = [img[axes_selection] for img in self.__signals]
image = images[signal_index]

x_axis = self.__x_axis
y_axis = self.__y_axis
axis_indices = self._axesSelector.getIndicesOfNamedAxes()
x_axis_index = axis_indices["X"]
y_axis_index = axis_indices["X"]

if self.__axes:
x_axis = self.__axes[x_axis_index]
y_axis = self.__axes[y_axis_index]
x_units = get_attr_as_unicode(x_axis, "units")
y_units = get_attr_as_unicode(y_axis, "units")
self._plot.setKeepDataAspectRatio(x_units == y_units)
else:
x_axis = None
y_axis = None
self._plot.setKeepDataAspectRatio(False)


if x_axis is None and y_axis is None:
xcalib = NoCalibration()
Expand Down Expand Up @@ -601,7 +594,12 @@ def _updateImage(self):
self._plot.addItem(imageItem)
self._plot.setActiveImage(imageItem)
else:
xaxisscale, yaxisscale = self._axis_scales
if self.__axes_scales:
xaxisscale = self.__axes_scales[x_axis_index]
yaxisscale = self.__axes_scales[y_axis_index]
else:
xaxisscale = None
yaxisscale = None

if xaxisscale is not None:
self._plot.getXAxis().setScale(
Expand All @@ -621,23 +619,31 @@ def _updateImage(self):
legend=legend,
)

if self.__resetZoomNextTime:
self._plot.resetZoom()
self.__resetZoomNextTime = False

if self.__title:
title = self.__title
if len(self.__signals_names) > 1:
# Append dataset name only when there is many datasets
# Append dataset name only when there are many datasets
title += "\n" + self.__signals_names[signal_index]
else:
title = self.__signals_names[signal_index]
self._plot.setGraphTitle(title)
self._plot.getXAxis().setLabel(self.__x_axis_name)
self._plot.getYAxis().setLabel(self.__y_axis_name)
self._plot.getXAxis().setLabel(self.__axes_names[x_axis_index])
self._plot.getYAxis().setLabel(self.__axes_names[y_axis_index])

def clear(self):
old = self._axesSelector.blockSignals(True)
self._axesSelector.clear()
self._axesSelector.blockSignals(old)
self._plot.clear()

def _clearImage(self):
self._plot.clear()
self.__resetZoomNextTime = True


class ArrayComplexImagePlot(qt.QWidget):
"""
Expand Down
8 changes: 8 additions & 0 deletions src/silx/gui/data/NumpyAxesSelector.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,11 @@ def setNamedAxesSelectorVisibility(self, visible):
self.__namedAxesVisibility = visible
for axis in self.__axis:
axis.setNamedAxisSelectorVisibility(visible)

def getIndicesOfNamedAxes(self) -> dict[str, int]:
result: dict[str, int] = {}
for i, axis in enumerate(self.__axis):
name = axis.axisName()
if name:
result[name] = i
return result
Loading