Skip to content

Commit d0e28a5

Browse files
committed
better plot scales
1 parent b6fc3b3 commit d0e28a5

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,6 @@ class PlotResults2DJob(Job):
616616
Plot results
617617
"""
618618

619-
__sis_version__ = 1
620-
621619
def __init__(self, *, x_axis_name: str, y_axis_name: str, results: Dict[Tuple[float, float], tk.Path]):
622620
self.x_axis_name = x_axis_name
623621
self.y_axis_name = y_axis_name
@@ -631,23 +629,37 @@ def tasks(self):
631629
def run(self):
632630
from ast import literal_eval
633631
import matplotlib.pyplot as plt
632+
import matplotlib.ticker as ticker
634633

635634
xs = sorted(set(x for x, _ in self.results.keys()))
636635
ys = sorted(set(y for _, y in self.results.keys()))
637636
results = {k: literal_eval(open(v).read()) for k, v in self.results.items()}
638637
first_res = results[next(iter(results.keys()))]
639638
assert isinstance(first_res, dict)
640639

640+
plt.figure(figsize=(8, 8 * len(first_res)))
641+
641642
for key_idx, key in enumerate(first_res.keys()):
642643
zs = np.zeros((len(ys), len(xs)))
643644
for y_idx, y in enumerate(ys):
644645
for x_idx, x in enumerate(xs):
645646
zs[y_idx, x_idx] = results[(x, y)][key]
646647

647-
plt.subplot(len(first_res), 1, 1 + key_idx)
648-
plt.contourf(xs, ys, zs)
648+
best = np.min(zs.flatten())
649+
worst_limit = best * 1.3
650+
651+
ax = plt.subplot(len(first_res), 1, 1 + key_idx)
652+
plt.contourf(xs, ys, zs, levels=np.geomspace(best, worst_limit, 30))
653+
654+
ax.set_title(f"{key}")
655+
ax.set_ylabel(self.y_axis_name)
656+
ax.set_xlabel(self.x_axis_name)
657+
ax.xaxis.set_major_locator(ticker.AutoLocator())
658+
ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())
659+
ax.yaxis.set_major_locator(ticker.AutoLocator())
660+
ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
649661

650-
plt.axis("scaled")
651-
plt.colorbar()
662+
cbar = plt.colorbar()
663+
cbar.set_label("WER [%]")
652664

653665
plt.savefig(self.out_plot.get_path())

0 commit comments

Comments
 (0)