-
Notifications
You must be signed in to change notification settings - Fork 18
Added pseudo alignment strategy based on phoneme duration #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,8 @@ | |
| import typing | ||
| from dataclasses import dataclass, replace | ||
|
|
||
| import transphone | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you use a lazy import? We want to keep the mandatory dependencies for the core code small. |
||
|
|
||
| from meeteval.io.pbjson import zip_strict | ||
| from meeteval.io.stm import STM | ||
| from meeteval.io.seglst import SegLST, seglst_map, asseglst, SegLstSegment | ||
|
|
@@ -40,7 +42,7 @@ class Segment(TypedDict): | |
|
|
||
|
|
||
| # pseudo-timestamp strategies | ||
| def equidistant_intervals(interval, words): | ||
| def equidistant_intervals(interval, words, *args): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would prefer |
||
| """Divides the interval into `count` equally sized intervals | ||
| """ | ||
| count = len(words) | ||
|
|
@@ -57,7 +59,7 @@ def equidistant_intervals(interval, words): | |
| ] | ||
|
|
||
|
|
||
| def equidistant_points(interval, words): | ||
| def equidistant_points(interval, words, *args): | ||
| """Places `count` points (intervals of size zero) in `interval` with equal distance""" | ||
| count = len(words) | ||
| if count == 0: | ||
|
|
@@ -72,7 +74,7 @@ def equidistant_points(interval, words): | |
| ] | ||
|
|
||
|
|
||
| def character_based(interval, words): | ||
| def character_based(interval, words, *args): | ||
| """Divides the interval into one interval per word where the size of the interval is | ||
| proportional to the word length in characters.""" | ||
| if len(words) == 0: | ||
|
|
@@ -93,7 +95,30 @@ def character_based(interval, words): | |
| ] | ||
|
|
||
|
|
||
| def character_based_points(interval, words): | ||
| def phoneme_based(interval, words, language): | ||
| """Divides the interval into one interval per word where the size of the interval is | ||
| proportional to the number of phonemes in the word.""" | ||
|
|
||
| g2p = transphone.read_tokenizer(language) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The call to `transphone.read_tokenizer` sounds like it loads a model. Does it have caching? |
||
| if len(words) == 0: | ||
| return [] | ||
| elif len(words) == 1: | ||
| return [interval] | ||
| import numpy as np | ||
|
|
||
| word_lengths = np.asarray([len(g2p.tokenize(w)) for w in words]) | ||
| end_points = np.cumsum(word_lengths) | ||
| total_num_characters = end_points[-1] | ||
| character_length = (interval[1] - interval[0]) / total_num_characters | ||
| return [ | ||
| ( | ||
| interval[0] + character_length * start, | ||
| interval[0] + character_length * end | ||
| ) | ||
| for start, end in zip([0] + list(end_points[:-1]), end_points) | ||
| ] | ||
|
|
||
| def character_based_points(interval, words, *args): | ||
| """Places points in the center of the character-based intervals""" | ||
| intervals = character_based(interval, words) | ||
| intervals = [ | ||
|
|
@@ -102,13 +127,22 @@ def character_based_points(interval, words): | |
| ] | ||
| return intervals | ||
|
|
||
| def phoneme_based_points(interval, words, language): | ||
| """Places points in the center of the phoneme-based intervals""" | ||
| intervals = phoneme_based(interval, words, language) | ||
| intervals = [ | ||
| ((interval[1] + interval[0]) / 2,) * 2 | ||
| for interval in intervals | ||
| ] | ||
| return intervals | ||
|
|
||
|
|
||
| def full_segment(interval, words): | ||
| def full_segment(interval, words, *args): | ||
| """Outputs `interval` for each word""" | ||
| return [interval] * len(words) | ||
|
|
||
|
|
||
| def no_segmentation(interval, words): | ||
| def no_segmentation(interval, words, *args): | ||
| if len(words) != 1: | ||
| if len(words) > 1: | ||
| raise ValueError( | ||
|
|
@@ -131,6 +165,8 @@ def no_segmentation(interval, words): | |
| 'equidistant_intervals': equidistant_intervals, | ||
| 'equidistant_points': equidistant_points, | ||
| 'full_segment': full_segment, | ||
| 'phoneme_based': phoneme_based, | ||
| 'phoneme_based_points': phoneme_based_points, | ||
| 'character_based': character_based, | ||
| 'character_based_points': character_based_points, | ||
| 'none': no_segmentation, | ||
|
|
@@ -585,8 +621,7 @@ def time_constrained_siso_word_error_rate( | |
| reference_pseudo_word_level_timing='character_based', | ||
| hypothesis_pseudo_word_level_timing='character_based_points', | ||
| reference_sort='segment', | ||
| hypothesis_sort='segment', | ||
| ): | ||
| hypothesis_sort='segment'): | ||
| """ | ||
| Time-constrained word error rate for single-speaker transcripts. | ||
|
|
||
|
|
@@ -669,6 +704,7 @@ def time_constrained_minimum_permutation_word_error_rate( | |
| collar, | ||
| reference_pseudo_word_level_timing='character_based', | ||
| hypothesis_pseudo_word_level_timing='character_based_points', | ||
| language=None, | ||
| reference_sort='segment', | ||
| hypothesis_sort='segment', | ||
| ) -> CPErrorRate: | ||
|
|
@@ -707,7 +743,7 @@ def time_constrained_minimum_permutation_word_error_rate( | |
| collar=collar, | ||
| reference_pseudo_word_level_timing=reference_pseudo_word_level_timing, | ||
| hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing, | ||
| segment_representation='word', | ||
| segment_representation='word', language=language | ||
| ) | ||
|
|
||
| er = _minimum_permutation_word_error_rate( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Can you change this to `words, language=language`?