11 changes: 8 additions & 3 deletions meeteval/wer/preprocess.py
@@ -17,6 +17,7 @@ def split_words(
         keys=('words',),
         word_level_timing_strategy=None,
         segment_representation='word',  # 'segment', 'word', 'speaker'
+        language=None
 ):
     """
     Splits segments into words and copies all other entries.
@@ -77,8 +78,7 @@ def get_words(s):
         words = s['words'] or ['']
         timestamps = word_level_timing_strategy(
             (s['start_time'], s['end_time']),
-            words
-        )
+            words, language)
Member: Can you change this to `words, language=language)`?
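For illustration, the suggested call would then read (a sketch; indentation as in the surrounding function):

```python
timestamps = word_level_timing_strategy(
    (s['start_time'], s['end_time']),
    words, language=language)
```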
         s['start_time'] = [s for s, _ in timestamps]
         s['end_time'] = [s for _, s in timestamps]

@@ -263,6 +263,7 @@ def _preprocess_single(
         name=None,
         segment_index=False,  # 'segment', 'word', False
         segment_representation='word',  # 'segment', 'word', 'speaker'
+        language=None
 ):
     """
     >>> from paderbox.utils.pretty import pprint
@@ -428,7 +429,8 @@ def _preprocess_single(
     words = split_words(
         segments,
         word_level_timing_strategy=word_level_timing_strategy,
-        segment_representation=segment_representation
+        segment_representation=segment_representation,
+        language=language
     )

# Warn or raise an exception if the order of the words contradicts the
@@ -517,6 +519,7 @@ def preprocess(
         hypothesis_pseudo_word_level_timing=None,
         segment_representation='segment',  # 'segment', 'word', 'speaker'
         ensure_single_session=True,
+        language=None
 ):
     """
     Preprocessing.
@@ -538,6 +541,7 @@
         collar=None,  # collar is not applied to the reference
         word_level_timing_strategy=reference_pseudo_word_level_timing,
         segment_representation=segment_representation,
+        language=language
     )
     hypothesis, hypothesis_self_overlap = _preprocess_single(
         hypothesis,
@@ -549,6 +553,7 @@
         collar=collar,
         word_level_timing_strategy=hypothesis_pseudo_word_level_timing,
         segment_representation=segment_representation,
+        language=language
     )


54 changes: 45 additions & 9 deletions meeteval/wer/wer/time_constrained.py
@@ -5,6 +5,8 @@
 import typing
 from dataclasses import dataclass, replace
 
+import transphone
Member: Can you use a lazy import? We want to keep the mandatory dependencies for the core code small.
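For illustration, a lazy-import sketch (the helper name `_read_tokenizer` is hypothetical, not part of this PR):

```python
def _read_tokenizer(language):
    # Deferred import: transphone is only needed when a phoneme-based
    # timing strategy is actually selected, so it stays an optional
    # dependency of the core code.
    import transphone
    return transphone.read_tokenizer(language)
```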

from meeteval.io.pbjson import zip_strict
from meeteval.io.stm import STM
from meeteval.io.seglst import SegLST, seglst_map, asseglst, SegLstSegment
@@ -40,7 +42,7 @@ class Segment(TypedDict):
 
 
 # pseudo-timestamp strategies
-def equidistant_intervals(interval, words):
+def equidistant_intervals(interval, words, *args):
Member: I would prefer `(interval, words, language)` as signature, or at least `(interval, words, **kwargs)`.
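As a sketch, the two suggested variants (the `language=None` default is an assumption):

```python
# Option 1: explicit parameter, ignored by strategies that don't need it:
def equidistant_intervals(interval, words, language=None):
    ...

# Option 2: swallow any extra keyword arguments explicitly:
def equidistant_intervals(interval, words, **kwargs):
    ...
```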
"""Divides the interval into `count` equally sized intervals
"""
count = len(words)
@@ -57,7 +59,7 @@ def equidistant_intervals(interval, words):
     ]
 
 
-def equidistant_points(interval, words):
+def equidistant_points(interval, words, *args):
"""Places `count` points (intervals of size zero) in `interval` with equal distance"""
count = len(words)
if count == 0:
@@ -72,7 +74,7 @@ def equidistant_points(interval, words):
     ]
 
 
-def character_based(interval, words):
+def character_based(interval, words, *args):
"""Divides the interval into one interval per word where the size of the interval is
proportional to the word length in characters."""
if len(words) == 0:
@@ -93,7 +95,30 @@
     ]
 
 
-def character_based_points(interval, words):
+def phoneme_based(interval, words, language):
+    """Divides the interval into one interval per word where the size of the interval is
+    proportional to the number of phonemes in the word."""
+
+    g2p = transphone.read_tokenizer(language)
Member: The call transphone.read_tokenizer sounds like it loads a model. Does it cache the result?
If not, we should add caching, since this function is called for every segment.
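If transphone does not cache internally, a per-language cache would be one possible fix; a sketch (hypothetical helper, combining the lazy import suggested above):

```python
import functools

@functools.lru_cache(maxsize=None)
def _read_tokenizer(language):
    # Load the tokenizer model once per language and reuse it for every
    # segment, instead of re-loading it on each call.
    import transphone  # lazy import, per the comment on the import above
    return transphone.read_tokenizer(language)
```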
+    if len(words) == 0:
+        return []
+    elif len(words) == 1:
+        return [interval]
+    import numpy as np
+
+    word_lengths = np.asarray([len(g2p.tokenize(w)) for w in words])
+    end_points = np.cumsum(word_lengths)
+    total_num_characters = end_points[-1]
+    character_length = (interval[1] - interval[0]) / total_num_characters
+    return [
+        (
+            interval[0] + character_length * start,
+            interval[0] + character_length * end
+        )
+        for start, end in zip([0] + list(end_points[:-1]), end_points)
+    ]

+def character_based_points(interval, words, *args):
"""Places points in the center of the character-based intervals"""
intervals = character_based(interval, words)
intervals = [
@@ -102,13 +127,22 @@ def character_based_points(interval, words):
     ]
     return intervals
 
+def phoneme_based_points(interval, words, language):
+    """Places points in the center of the phoneme-based intervals"""
+    intervals = phoneme_based(interval, words, language)
+    intervals = [
+        ((interval[1] + interval[0]) / 2,) * 2
+        for interval in intervals
+    ]
+    return intervals


-def full_segment(interval, words):
+def full_segment(interval, words, *args):
"""Outputs `interval` for each word"""
return [interval] * len(words)


-def no_segmentation(interval, words):
+def no_segmentation(interval, words, *args):
     if len(words) != 1:
         if len(words) > 1:
             raise ValueError(
@@ -131,6 +165,8 @@ def no_segmentation(interval, words):
     'equidistant_intervals': equidistant_intervals,
     'equidistant_points': equidistant_points,
     'full_segment': full_segment,
+    'phoneme_based': phoneme_based,
+    'phoneme_based_points': phoneme_based_points,
     'character_based': character_based,
     'character_based_points': character_based_points,
     'none': no_segmentation,
@@ -585,8 +621,7 @@ def time_constrained_siso_word_error_rate(
         reference_pseudo_word_level_timing='character_based',
         hypothesis_pseudo_word_level_timing='character_based_points',
         reference_sort='segment',
-        hypothesis_sort='segment',
-):
+        hypothesis_sort='segment'):
"""
Time-constrained word error rate for single-speaker transcripts.

@@ -669,6 +704,7 @@ def time_constrained_minimum_permutation_word_error_rate(
         collar,
         reference_pseudo_word_level_timing='character_based',
         hypothesis_pseudo_word_level_timing='character_based_points',
+        language=None,
         reference_sort='segment',
         hypothesis_sort='segment',
 ) -> CPErrorRate:
@@ -707,7 +743,7 @@ def time_constrained_minimum_permutation_word_error_rate(
         collar=collar,
         reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
         hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
-        segment_representation='word',
+        segment_representation='word', language=language
)

er = _minimum_permutation_word_error_rate(
39 changes: 39 additions & 0 deletions tests/test_time_constrained.py
@@ -204,6 +204,15 @@ def test_time_constrained_sorting_options():
     )
     assert er.error_rate == 0

+    er = time_constrained_minimum_permutation_word_error_rate(
+        r1, r1, reference_sort='word',
+        reference_pseudo_word_level_timing='phoneme_based',
+        hypothesis_pseudo_word_level_timing='phoneme_based',
+        language="eng",
+        collar=0,
+    )
+    assert er.error_rate == 0

     r1 = SegLST([
         {'words': 'a b c d', 'start_time': 0, 'end_time': 4, 'speaker': 'A'},
         {'words': 'e f g h', 'start_time': 2, 'end_time': 6, 'speaker': 'A'},
@@ -217,6 +226,16 @@ def test_time_constrained_sorting_options():
     )
     assert er.error_rate == 0.75
 
+    er = time_constrained_minimum_permutation_word_error_rate(
+        r1, r2, reference_sort='word',
+        reference_pseudo_word_level_timing='phoneme_based',
+        hypothesis_pseudo_word_level_timing='phoneme_based_points',
+        language="eng",
+        collar=0,
+    )
+    # reference will be: ['ʌ'], ['b', 'i'], ['s', 'i'], ['d', 'i'], ['i'], ['ɛ', 'f'], ['d͡ʒ', 'i'], ['e', 'j', 't͡ʃ']
+    assert er.error_rate == 0.625

     er = time_constrained_minimum_permutation_word_error_rate(
         r1, r2, reference_sort='segment',
         collar=0,
@@ -255,6 +274,26 @@ def test_time_constrained_sorting_options():
     )
     assert er.error_rate == 1
 
+    # Japanese test with kanji characters,
+    # separated by whitespace
+    r1 = SegLST([
+        {'words': '\u4f11 \u65e5', 'start_time': 4, 'end_time': 8, 'speaker': 'A'},  # holiday
+        {'words': '\u4eca \u65e5', 'start_time': 0, 'end_time': 4, 'speaker': 'A'},  # today
+    ])
+    r2 = SegLST([
+        {'words': '\u4f11 \u65e5 \u4eca \u65e5', 'start_time': 0, 'end_time': 8, 'speaker': 'A'},
+    ])
+
+    er = time_constrained_minimum_permutation_word_error_rate(
+        r1, r2, reference_sort='word',
+        reference_pseudo_word_level_timing='phoneme_based',
+        hypothesis_pseudo_word_level_timing='phoneme_based_points',
+        language="jpn",
+        collar=5,
+    )
+    assert er.error_rate == 0.5



def test_examples_zero_self_overlap():
"""Tests that self-overlap is measured correctly (0) for the example files"""