Skip to content

Commit

Permalink
added tests and solved issues in pull request
Browse files Browse the repository at this point in the history
  • Loading branch information
eugenioLR committed Sep 2, 2024
1 parent df1b4ed commit 201195e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/seaborn_image/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,8 @@ def __init__(
)

if aspect == "auto":
aspect = 1e10

# Select minimum aspect ratio among all the images
for img in data:
aspect_aux = img.shape[1] / img.shape[0]
if aspect_aux < aspect:
aspect = aspect_aux
aspect_ratios = [img.shape[1] / img.shape[0] for img in data]
aspect = min(aspect_ratios)

elif not isinstance(data, np.ndarray):
raise ValueError("image data must be a list of images or a 3d or 4d array.")
Expand Down Expand Up @@ -433,7 +428,7 @@ def __init__(

if aspect == "auto":
# Select axis where width and height are
height_idx, width_idx = [i for i in range(data.ndim) if i != axis][:2]
height_idx, width_idx = [i for i in range(data.ndim) if i != (axis % data.ndim)][:2]

aspect = data.shape[width_idx] / data.shape[height_idx]

Expand Down Expand Up @@ -1192,10 +1187,10 @@ def __init__(
ncol = col_wrap
nrow = int(np.ceil(len(kwargs[f"{col}"]) / col_wrap))

# Calculate the base figure size
if aspect == "auto":
aspect = data.shape[1]/data.shape[0]


# Calculate the base figure size
figsize = (ncol * height * aspect, nrow * height)

fig = plt.figure(figsize=figsize)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,28 @@ def test_figure_size(self):
)
np.testing.assert_array_equal(g4.fig.get_size_inches(), (3 * 2 * 1.5, 2 * 2))
plt.close()

def test_auto_aspect(self):
imgsize0 = (10, 10)
g0 = isns.ImageGrid([np.zeros(imgsize0) for i in range(10)], aspect='auto')
assert np.isclose(imgsize0[1]/imgsize0[0], g0.aspect)
plt.close()

imgsize1 = (10, 5)
g1 = isns.ImageGrid([np.zeros(imgsize1) for i in range(10)], aspect='auto')
assert np.isclose(imgsize1[1]/imgsize1[0], g1.aspect)
plt.close()

imgsize2 = (5, 10)
g2 = isns.ImageGrid([np.zeros(imgsize2) for i in range(10)], aspect='auto')
assert np.isclose(imgsize2[1]/imgsize2[0], g2.aspect)
plt.close()

imglist = [np.zeros(imgsize0) for i in range(4)] + [np.zeros(imgsize1) for i in range(4)] + [np.zeros(imgsize2) for i in range(4)]
g3 = isns.ImageGrid(imglist, aspect='auto')
assert np.isclose(min([imgsize0[1]/imgsize0[0], imgsize1[1]/imgsize1[0], imgsize2[1]/imgsize2[0]]), g3.aspect)
plt.close()

def test_vmin_vmax(self):
g = isns.ImageGrid(cells, vmin=0.5, vmax=0.75)
for ax in g.axes.ravel():
Expand Down

0 comments on commit 201195e

Please sign in to comment.