diff --git a/users/dorian_koch/jobs/ScliteToWerDistributionGraph.py b/users/dorian_koch/jobs/ScliteToWerDistributionGraph.py index a04120c0b..2ccd5b3d4 100644 --- a/users/dorian_koch/jobs/ScliteToWerDistributionGraph.py +++ b/users/dorian_koch/jobs/ScliteToWerDistributionGraph.py @@ -1,5 +1,6 @@ -from typing import List, Tuple +from typing import List, Tuple, Union from sisyphus import Job, Task, tk +from sisyphus.delayed_ops import DelayedBase from i6_core.util import uopen @@ -10,11 +11,13 @@ def __init__( *, report_dir: tk.AbstractPath, num_bins: int = 10, - plot_title: str = "WER distribution", + plot_title: Union[str, DelayedBase] = "WER distribution", + plot_metrics: bool = True, ): self.report_dir = report_dir self.num_bins = num_bins self.plot_title = plot_title + self.plot_metrics = plot_metrics self.out_file = self.output_path("vals.csv") self.distrib_file = self.output_path("distrib.csv") @@ -28,7 +31,6 @@ def tasks(self): def run(self): import matplotlib.pyplot as plt - import matplotlib.ticker as ticker output_dir = self.report_dir.get_path() values = [] @@ -53,37 +55,49 @@ def run(self): print(f"Wrote to {self.out_file}") bins = [0] * self.num_bins - + avg = 0 + total = 0 for (c,s,d,i) in values: if s+d+c == 0: print("Warning: empty sequence") continue wer = 100.0 * (s + d + i) / (s + d + c) bin_idx = min(int(wer / 100.0 * self.num_bins), self.num_bins - 1) - bins[bin_idx] += 1 + val = (s + d + c) + avg += wer * val + bins[bin_idx] += val + total += val print("WER distribution:") with uopen(self.distrib_file, "wt") as out: - out.write("bin_start,bin_end,count,relative_count\n") + out.write("bin_start,bin_end,count,relative_count_weighed_by_ref_length\n") for i, count in enumerate(bins): - print(f"{i/self.num_bins:.4f}-{(i+1)/self.num_bins:.4f}: {count / len(values) * 100:.3f}%") - out.write(f"{i/self.num_bins:.4f},{(i+1)/self.num_bins:.4f},{count},{count / len(values):.6f}\n") + print(f"{i/self.num_bins:.4f}-{(i+1)/self.num_bins:.4f}: {count / total * 100:.3f}%") + out.write(f"{i/self.num_bins:.4f},{(i+1)/self.num_bins:.4f},{count},{count / total:.6f}\n") plt.figure(figsize=(8, 8)) # show relative count - plt.bar(range(self.num_bins), [count / len(values) for count in bins]) + plt.bar(range(self.num_bins), [count / total for count in bins], align="edge") + + if self.plot_metrics: + # plot avg (this should be the wer score as reported by sclite) + plt.axvline(x=avg / total / 100 * self.num_bins, color="red", label=f"WER: {avg / total:.2f}") plt.xlabel("WER") plt.ylabel("fraction") plt.ylim(0, 1) - plt.title(self.plot_title) + if isinstance(self.plot_title, DelayedBase): + plt.title(self.plot_title.get()) + else: + plt.title(self.plot_title) plt.xticks(range(0, self.num_bins, max(1, self.num_bins // 10)), [f"{i/self.num_bins:.2f}" for i in range(0, self.num_bins, max(1, self.num_bins // 10))]) plt.grid(axis="y") + plt.legend(loc="upper right") plt.savefig(self.out_plot) plt.autoscale(axis="y") plt.savefig(self.out_plot_no_ylim) - new_ylim = max([count / len(values) for count in bins[1:]]) * 1.1 + new_ylim = max([count / total for count in bins[1:]]) * 1.1 plt.ylim(0, new_ylim) plt.savefig(self.out_plot_ylim_without_first_bin) plt.ylim(0, 0.1)