From 201195e3f776373e866c8f3dbe8e6cd46dc5f15e Mon Sep 17 00:00:00 2001 From: eugenioLR Date: Mon, 2 Sep 2024 12:52:38 +0200 Subject: [PATCH] added tests and solved issues in pull request --- src/seaborn_image/_grid.py | 15 +++++---------- tests/test_grid.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/seaborn_image/_grid.py b/src/seaborn_image/_grid.py index 7e019b4..cb6b9f5 100644 --- a/src/seaborn_image/_grid.py +++ b/src/seaborn_image/_grid.py @@ -367,13 +367,8 @@ def __init__( ) if aspect == "auto": - aspect = 1e10 - - # Select minimum aspect ratio among all the images - for img in data: - aspect_aux = img.shape[1] / img.shape[0] - if aspect_aux < aspect: - aspect = aspect_aux + aspect_ratios = [img.shape[1] / img.shape[0] for img in data] + aspect = min(aspect_ratios) elif not isinstance(data, np.ndarray): raise ValueError("image data must be a list of images or a 3d or 4d array.") @@ -433,7 +428,7 @@ def __init__( if aspect == "auto": # Select axis where width and height are - height_idx, width_idx = [i for i in range(data.ndim) if i != axis][:2] + height_idx, width_idx = [i for i in range(data.ndim) if i != (axis % data.ndim)][:2] aspect = data.shape[width_idx] / data.shape[height_idx] @@ -1192,10 +1187,10 @@ def __init__( ncol = col_wrap nrow = int(np.ceil(len(kwargs[f"{col}"]) / col_wrap)) - # Calculate the base figure size if aspect == "auto": aspect = data.shape[1]/data.shape[0] - + + # Calculate the base figure size figsize = (ncol * height * aspect, nrow * height) fig = plt.figure(figsize=figsize) diff --git a/tests/test_grid.py b/tests/test_grid.py index 47ed310..156a80b 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -871,7 +871,28 @@ def test_figure_size(self): ) np.testing.assert_array_equal(g4.fig.get_size_inches(), (3 * 2 * 1.5, 2 * 2)) plt.close() + + def test_auto_aspect(self): + imgsize0 = (10, 10) + g0 = isns.ImageGrid([np.zeros(imgsize0) for i in range(10)], aspect='auto') + assert np.isclose(imgsize0[1]/imgsize0[0], g0.aspect) + plt.close() + imgsize1 = (10, 5) + g1 = isns.ImageGrid([np.zeros(imgsize1) for i in range(10)], aspect='auto') + assert np.isclose(imgsize1[1]/imgsize1[0], g1.aspect) + plt.close() + + imgsize2 = (5, 10) + g2 = isns.ImageGrid([np.zeros(imgsize2) for i in range(10)], aspect='auto') + assert np.isclose(imgsize2[1]/imgsize2[0], g2.aspect) + plt.close() + + imglist = [np.zeros(imgsize0) for i in range(4)] + [np.zeros(imgsize1) for i in range(4)] + [np.zeros(imgsize2) for i in range(4)] + g3 = isns.ImageGrid(imglist, aspect='auto') + assert np.isclose(min([imgsize0[1]/imgsize0[0], imgsize1[1]/imgsize1[0], imgsize2[1]/imgsize2[0]]), g3.aspect) + plt.close() + def test_vmin_vmax(self): g = isns.ImageGrid(cells, vmin=0.5, vmax=0.75) for ax in g.axes.ravel():