From 71bcb184145e5b222aae141b71075341ecf5835f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carl=20Thom=C3=A9?= Date: Thu, 15 Aug 2024 17:12:36 +0530 Subject: [PATCH] Fix ImportError (#385) * Change BrokenBarHCollection to PolyCollection * Remove unused import * np.Inf -> np.inf * Use `ann.set_clip_path` * Set expected legend location * Run black --- mir_eval/display.py | 8 ++------ mir_eval/separation.py | 2 +- tests/test_display.py | 8 ++++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/mir_eval/display.py b/mir_eval/display.py index f5d5e1ef..0dee99c4 100644 --- a/mir_eval/display.py +++ b/mir_eval/display.py @@ -12,7 +12,6 @@ from matplotlib.ticker import FuncFormatter, MultipleLocator from matplotlib.ticker import Formatter from matplotlib.colors import LinearSegmentedColormap, LogNorm, ColorConverter -from matplotlib.collections import BrokenBarHCollection from matplotlib.transforms import Bbox, TransformedBbox from .melody import freq_to_voicing @@ -184,18 +183,15 @@ def segments( seg_map[lab].pop("label", None) if text: - bbox = Bbox.from_extents(ival[0], base, ival[1], height) - tbbox = TransformedBbox(bbox, transform) ann = ax.annotate( lab, xy=(ival[0], height), xycoords=transform, xytext=(8, -10), textcoords="offset points", - clip_path=rect, - clip_box=tbbox, **text_kw ) + ann.set_clip_path(rect) return ax @@ -264,7 +260,7 @@ def labeled_intervals( **kwargs Additional keyword arguments to pass to - `matplotlib.collection.BrokenBarHCollection`. + `matplotlib.collection.PolyCollection`. Returns ------- diff --git a/mir_eval/separation.py b/mir_eval/separation.py index 0bb0704e..96570eaf 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -837,7 +837,7 @@ def _safe_db(num, den): be 0. """ if den == 0: - return np.Inf + return np.inf return 10 * np.log10(num / den) diff --git a/tests/test_display.py b/tests/test_display.py index d2c254b0..943a6674 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -151,7 +151,7 @@ def test_display_labeled_intervals_compare_noextend(): est_int, est_labels, extend_labels=False, alpha=0.5, label="Estimate" ) - plt.legend() + plt.legend(loc="upper right") return plt.gcf() @@ -178,7 +178,7 @@ def test_display_labeled_intervals_compare_common(): est_int, est_labels, label_set=label_set, alpha=0.5, label="Estimate" ) - plt.legend() + plt.legend(loc="upper right") return plt.gcf() @@ -344,7 +344,7 @@ def test_display_piano_roll(): est_t, est_p, label="Estimate", alpha=0.5, facecolor="r" ) - plt.legend() + plt.legend(loc="upper right") return plt.gcf() @@ -367,7 +367,7 @@ def test_display_piano_roll_midi(): est_t, midi=est_midi, label="Estimate", alpha=0.5, facecolor="r" ) - plt.legend() + plt.legend(loc="upper right") return plt.gcf()