diff --git a/src/silx/gui/data/DataViews.py b/src/silx/gui/data/DataViews.py index 40a1edebda..a5e43ef5aa 100644 --- a/src/silx/gui/data/DataViews.py +++ b/src/silx/gui/data/DataViews.py @@ -28,6 +28,7 @@ import numpy import os +from silx.gui.data.NXdataWidgets import ArrayImagePlot import silx.io from silx.gui import qt, icons from silx.gui.data.TextFormatter import TextFormatter @@ -1760,27 +1761,15 @@ def setData(self, data): self._updateColormap(nxd) - # last two axes are Y & X - img_slicing = slice(-2, None) if not isRgba else slice(-3, -1) - y_axis, x_axis = nxd.axes[img_slicing] - y_label, x_label = nxd.axes_names[img_slicing] - y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing] - x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None - y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None - - self.getWidget().setImageData( + widget: ArrayImagePlot = self.getWidget() + widget.setImageData( [nxd.signal] + nxd.auxiliary_signals, - x_axis=x_axis, - y_axis=y_axis, + axes=nxd.axes, signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names, axes_names=nxd.axes_names, - xlabel=x_label, - ylabel=y_label, + axes_scales=nxd.plot_style.axes_scale_types, title=nxd.title, isRgba=isRgba, - xscale=x_scale, - yscale=y_scale, - keep_ratio=(x_units == y_units), ) def getDataPriority(self, data, info: DataInfo): diff --git a/src/silx/gui/data/NXdataWidgets.py b/src/silx/gui/data/NXdataWidgets.py index 7d7939a08a..3cf7ac3792 100644 --- a/src/silx/gui/data/NXdataWidgets.py +++ b/src/silx/gui/data/NXdataWidgets.py @@ -38,6 +38,7 @@ from silx.gui.colors import Colormap from silx.gui.data._SignalSelector import SignalSelector +from silx.io.nxdata._utils import get_attr_as_unicode from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration @@ -388,10 +389,7 @@ def __init__(self, parent=None): self.__signals = None self.__signals_names = None - self.__x_axis = None - self.__x_axis_name = None - self.__y_axis = None - self.__y_axis_name = None + self.__resetZoomNextTime = True self._plot = Plot2D(self) self._plot.setDefaultColormap( @@ -404,10 +402,9 @@ def __init__(self, parent=None): maskToolWidget = self._plot.getMaskToolsDockWidget().widget() maskToolWidget.setItemMaskUpdated(True) - # not closable self._axesSelector = NumpyAxesSelector(self) - self._axesSelector.setNamedAxesSelectorVisibility(False) self._axesSelector.selectionChanged.connect(self._updateImage) + self._axesSelector.selectedAxisChanged.connect(self._clearImage) self._signalSelector = SignalSelector(parent=self) self._signalSelector.selectionChanged.connect(self._signalChanges) @@ -454,47 +451,31 @@ def getPlot(self): def setImageData( self, signals, - x_axis=None, - y_axis=None, + axes=None, signals_names=None, axes_names=None, - xlabel=None, - ylabel=None, + axes_scales=None title=None, isRgba=False, - xscale=None, - yscale=None, - keep_ratio: bool = True, ): """ :param signals: list of n-D datasets, whose last 2 dimensions are used as the image's values, or list of 3D datasets interpreted as RGBA image. - :param x_axis: 1-D dataset used as the image's x coordinates. If - provided, its lengths must be equal to the length of the last - dimension of ``signal``. - :param y_axis: 1-D dataset used as the image's y. If provided, - its lengths must be equal to the length of the 2nd to last - dimension of ``signal``. :param signals_names: Names for each image, used as subtitle and legend. - :param xlabel: Label for X axis - :param ylabel: Label for Y axis :param title: Graph title :param isRgba: True if data is a 3D RGBA image - :param str xscale: Scale of X axis in (None, 'linear', 'log') - :param str yscale: Scale of Y axis in (None, 'linear', 'log') - :param keep_ratio: Toggle plot keep aspect ratio + :param str axes_scales: Scale of axes in (None, 'linear', 'log') """ self._axesSelector.selectionChanged.disconnect(self._updateImage) + self._axesSelector.selectedAxisChanged.disconnect(self._clearImage) self._signalSelector.selectionChanged.disconnect(self._signalChanges) self.__signals = signals self.__signals_names = signals_names - self.__axis_names = axes_names - self.__x_axis = x_axis - self.__x_axis_name = xlabel - self.__y_axis = y_axis - self.__y_axis_name = ylabel + self.__axes = axes + self.__axes_names = axes_names + self.__axes_scales = axes_scales self.__title = title self._axesSelector.clear() @@ -511,8 +492,8 @@ def setImageData( else: self._axesSelector.show() - if self.__axis_names: - self._axesSelector.setLabels(self.__axis_names) + if self.__axes_names: + self._axesSelector.setLabels(self.__axes_names) self._signalSelector.setSignalNames(signals_names) if len(signals) > 1: @@ -521,26 +502,38 @@ def setImageData( self._signalSelector.hide() self._signalSelector.setSignalIndex(0) - self._axis_scales = xscale, yscale self._axesSelector.selectionChanged.connect(self._updateImage) + self._axesSelector.selectedAxisChanged.connect(self._clearImage) self._signalSelector.selectionChanged.connect(self._signalChanges) self._updateImage() - self._plot.setKeepDataAspectRatio(keep_ratio) - self._plot.resetZoom() def _updateImage(self): axes_selection = self._axesSelector.selection() signal_index = self._signalSelector.getSignalIndex() + print(f"{axes_selection}, {signal_index}") legend = self.__signals_names[signal_index] images = [img[axes_selection] for img in self.__signals] image = images[signal_index] - x_axis = self.__x_axis - y_axis = self.__y_axis + axis_indices = self._axesSelector.getIndicesOfNamedAxes() + x_axis_index = axis_indices["X"] + y_axis_index = axis_indices["X"] + + if self.__axes: + x_axis = self.__axes[x_axis_index] + y_axis = self.__axes[y_axis_index] + x_units = get_attr_as_unicode(x_axis, "units") + y_units = get_attr_as_unicode(y_axis, "units") + self._plot.setKeepDataAspectRatio(x_units == y_units) + else: + x_axis = None + y_axis = None + self._plot.setKeepDataAspectRatio(False) + if x_axis is None and y_axis is None: xcalib = NoCalibration() @@ -601,7 +594,12 @@ def _updateImage(self): self._plot.addItem(imageItem) self._plot.setActiveImage(imageItem) else: - xaxisscale, yaxisscale = self._axis_scales + if self.__axes_scales: + xaxisscale = self.__axes_scales[x_axis_index] + yaxisscale = self.__axes_scales[y_axis_index] + else: + xaxisscale = None + yaxisscale = None if xaxisscale is not None: self._plot.getXAxis().setScale( @@ -621,16 +619,20 @@ def _updateImage(self): legend=legend, ) + if self.__resetZoomNextTime: + self._plot.resetZoom() + self.__resetZoomNextTime = False + if self.__title: title = self.__title if len(self.__signals_names) > 1: - # Append dataset name only when there is many datasets + # Append dataset name only when there are many datasets title += "\n" + self.__signals_names[signal_index] else: title = self.__signals_names[signal_index] self._plot.setGraphTitle(title) - self._plot.getXAxis().setLabel(self.__x_axis_name) - self._plot.getYAxis().setLabel(self.__y_axis_name) + self._plot.getXAxis().setLabel(self.__axes_names[x_axis_index]) + self._plot.getYAxis().setLabel(self.__axes_names[y_axis_index]) def clear(self): old = self._axesSelector.blockSignals(True) @@ -638,6 +640,10 @@ def clear(self): self._axesSelector.blockSignals(old) self._plot.clear() + def _clearImage(self): + self._plot.clear() + self.__resetZoomNextTime = True + class ArrayComplexImagePlot(qt.QWidget): """ diff --git a/src/silx/gui/data/NumpyAxesSelector.py b/src/silx/gui/data/NumpyAxesSelector.py index 453397e77e..79db5eae23 100644 --- a/src/silx/gui/data/NumpyAxesSelector.py +++ b/src/silx/gui/data/NumpyAxesSelector.py @@ -595,3 +595,11 @@ def setNamedAxesSelectorVisibility(self, visible): self.__namedAxesVisibility = visible for axis in self.__axis: axis.setNamedAxisSelectorVisibility(visible) + + def getIndicesOfNamedAxes(self) -> dict[str, int]: + result: dict[str, int] = {} + for i, axis in enumerate(self.__axis): + name = axis.axisName() + if name: + result[name] = i + return result