diff --git a/src/seaborn_image/_grid.py b/src/seaborn_image/_grid.py index d4fb8b1..cb6b9f5 100644 --- a/src/seaborn_image/_grid.py +++ b/src/seaborn_image/_grid.py @@ -43,8 +43,8 @@ class ImageGrid: Number of columns to display. Defaults to None. height : int or float, optional Size of the individual images. Defaults to 3. - aspect : int or float, optional - Aspect ratio of individual images. Defaults to 1. + aspect : int, float or 'auto', optional + Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'. cmap : str or `matplotlib.colors.Colormap` or list, optional Image colormap. If input data is a list of images, `cmap` can be a list of colormaps. Defaults to None. @@ -290,7 +290,8 @@ class ImageGrid: ... [pol, pl, retina], ... map_func=[gaussian, median, hessian], ... dx=[15, 100, None], - ... units="nm") + ... units="nm", + ... aspect=1) Change colorbar orientation @@ -321,7 +322,7 @@ def __init__( map_func_kw=None, col_wrap=None, height=3, - aspect=1, + aspect="auto", cmap=None, robust=False, perc=(2, 98), @@ -365,6 +366,10 @@ def __init__( len(map_func) if len(map_func) >= len(data) else len(data) ) + if aspect == "auto": + aspect_ratios = [img.shape[1] / img.shape[0] for img in data] + aspect = min(aspect_ratios) + elif not isinstance(data, np.ndarray): raise ValueError("image data must be a list of images or a 3d or 4d array.") @@ -385,6 +390,9 @@ def __init__( # no of columns should now be len of map_func list col_wrap = len(map_func) if col_wrap is None else col_wrap + if aspect == "auto": + aspect = data.shape[1] / data.shape[0] + elif data.ndim in [3, 4]: if data.ndim == 4 and data.shape[-1] not in [1, 3, 4]: raise ValueError( @@ -418,6 +426,12 @@ def __init__( _nimages = len(slices) + if aspect == "auto": + # Select axis where width and height are + height_idx, width_idx = [i for i in range(data.ndim) if i != (axis % data.ndim)][:2] + + aspect = data.shape[width_idx] / data.shape[height_idx] + # ---- 3D or 4D image with an individual map_func ---- map_func_type = self._check_map_func(map_func, map_func_kw) # raise a ValueError if a list of map_func is provided for 3d image @@ -765,7 +779,7 @@ def rgbplot( *, col_wrap=3, height=3, - aspect=1, + aspect="auto", cmap=None, alpha=None, origin=None, @@ -794,8 +808,8 @@ def rgbplot( Number of columns to display. Defaults to 3. height : int or float, optional Size of the individual images. Defaults to 3. - aspect : int or float, optional - Aspect ratio of individual images. Defaults to 1. + aspect : int, float or 'auto', optional + Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'. cmap : str or `matplotlib.colors.Colormap` or list, optional Image colormap or a list of colormaps. Defaults to None. alpha : float or array-like, optional @@ -975,8 +989,8 @@ class ParamGrid(object): is not None and `row` is None. Defaults to None. height : int or float, optional Size of the individual images. Defaults to 3. - aspect : int or float, optional - Aspect ratio of individual images. Defaults to 1. + aspect : int, float or 'auto', optional + Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'. cmap : str or `matplotlib.colors.Colormap`, optional Image colormap. Defaults to None. alpha : float or array-like, optional @@ -1113,7 +1127,7 @@ def __init__( col=None, col_wrap=None, height=3, - aspect=1, + aspect="auto", cmap=None, alpha=None, origin=None, @@ -1173,6 +1187,9 @@ def __init__( ncol = col_wrap nrow = int(np.ceil(len(kwargs[f"{col}"]) / col_wrap)) + if aspect == "auto": + aspect = data.shape[1]/data.shape[0] + # Calculate the base figure size figsize = (ncol * height * aspect, nrow * height) diff --git a/tests/test_grid.py b/tests/test_grid.py index 47ed310..156a80b 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -871,7 +871,28 @@ def test_figure_size(self): ) np.testing.assert_array_equal(g4.fig.get_size_inches(), (3 * 2 * 1.5, 2 * 2)) plt.close() + + def test_auto_aspect(self): + imgsize0 = (10, 10) + g0 = isns.ImageGrid([np.zeros(imgsize0) for i in range(10)], aspect='auto') + assert np.isclose(imgsize0[1]/imgsize0[0], g0.aspect) + plt.close() + imgsize1 = (10, 5) + g1 = isns.ImageGrid([np.zeros(imgsize1) for i in range(10)], aspect='auto') + assert np.isclose(imgsize1[1]/imgsize1[0], g1.aspect) + plt.close() + + imgsize2 = (5, 10) + g2 = isns.ImageGrid([np.zeros(imgsize2) for i in range(10)], aspect='auto') + assert np.isclose(imgsize2[1]/imgsize2[0], g2.aspect) + plt.close() + + imglist = [np.zeros(imgsize0) for i in range(4)] + [np.zeros(imgsize1) for i in range(4)] + [np.zeros(imgsize2) for i in range(4)] + g3 = isns.ImageGrid(imglist, aspect='auto') + assert np.isclose(min([imgsize0[1]/imgsize0[0], imgsize1[1]/imgsize1[0], imgsize2[1]/imgsize2[0]]), g3.aspect) + plt.close() + def test_vmin_vmax(self): g = isns.ImageGrid(cells, vmin=0.5, vmax=0.75) for ax in g.axes.ravel():