Skip to content

Commit

Permalink
Better wer distrib
Browse files Browse the repository at this point in the history
  • Loading branch information
dorian-K committed Feb 14, 2025
1 parent 2ad3a9b commit 181b8d9
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions users/dorian_koch/jobs/ScliteToWerDistributionGraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Tuple
from typing import List, Tuple, Union
from sisyphus import Job, Task, tk
from sisyphus.delayed_ops import DelayedBase
from i6_core.util import uopen


Expand All @@ -10,11 +11,13 @@ def __init__(
*,
report_dir: tk.AbstractPath,
num_bins: int = 10,
plot_title: str = "WER distribution",
plot_title: Union[str, DelayedBase] = "WER distribution",
plot_metrics: bool = True,
):
self.report_dir = report_dir
self.num_bins = num_bins
self.plot_title = plot_title
self.plot_metrics = plot_metrics

self.out_file = self.output_path("vals.csv")
self.distrib_file = self.output_path("distrib.csv")
Expand All @@ -28,7 +31,6 @@ def tasks(self):

def run(self):
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
output_dir = self.report_dir.get_path()

values = []
Expand All @@ -53,37 +55,49 @@ def run(self):
print(f"Wrote to {self.out_file}")

bins = [0] * self.num_bins

avg = 0
total = 0
for (c,s,d,i) in values:
if s+d+c == 0:
print("Warning: empty sequence")
continue
wer = 100.0 * (s + d + i) / (s + d + c)
bin_idx = min(int(wer / 100.0 * self.num_bins), self.num_bins - 1)
bins[bin_idx] += 1
val = (s + d + c)
avg += wer * val
bins[bin_idx] += val
total += val

print("WER distribution:")
with uopen(self.distrib_file, "wt") as out:
out.write("bin_start,bin_end,count,relative_count\n")
out.write("bin_start,bin_end,count,relative_count_weighed_by_ref_length\n")
for i, count in enumerate(bins):
print(f"{i/self.num_bins:.4f}-{(i+1)/self.num_bins:.4f}: {count / len(values) * 100:.3f}%")
out.write(f"{i/self.num_bins:.4f},{(i+1)/self.num_bins:.4f},{count},{count / len(values):.6f}\n")
print(f"{i/self.num_bins:.4f}-{(i+1)/self.num_bins:.4f}: {count / total * 100:.3f}%")
out.write(f"{i/self.num_bins:.4f},{(i+1)/self.num_bins:.4f},{count},{count / total:.6f}\n")


plt.figure(figsize=(8, 8))
# plot each bin's count as a fraction of the total
plt.bar(range(self.num_bins), [count / len(values) for count in bins])
plt.bar(range(self.num_bins), [count / total for count in bins], align="edge")

if self.plot_metrics:
# mark the average WER, weighted by reference length (s + d + c); this should match the WER score reported by sclite
plt.axvline(x=avg / total / 100 * self.num_bins, color="red", label=f"WER: {avg / total:.2f}")

plt.xlabel("WER")
plt.ylabel("fraction")
plt.ylim(0, 1)
plt.title(self.plot_title)
if isinstance(self.plot_title, DelayedBase):
plt.title(self.plot_title.get())
else:
plt.title(self.plot_title)
plt.xticks(range(0, self.num_bins, max(1, self.num_bins // 10)), [f"{i/self.num_bins:.2f}" for i in range(0, self.num_bins, max(1, self.num_bins // 10))])
plt.grid(axis="y")
plt.legend(loc="upper right")
plt.savefig(self.out_plot)
plt.autoscale(axis="y")
plt.savefig(self.out_plot_no_ylim)
new_ylim = max([count / len(values) for count in bins[1:]]) * 1.1
new_ylim = max([count / total for count in bins[1:]]) * 1.1
plt.ylim(0, new_ylim)
plt.savefig(self.out_plot_ylim_without_first_bin)
plt.ylim(0, 0.1)
Expand Down

0 comments on commit 181b8d9

Please sign in to comment.