From d0521a7dffe900cdb5ac1ab5b3aafd45ba7557cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 7 Mar 2024 11:16:22 -0500 Subject: [PATCH 01/69] Extending Lhotse dataloading to text/multimodal data (#1295) * WIP: extending lhotse to support non-Cut/audio data for sampling and dataloading * Tests showing how to use lhotse for text-only + audio-only dataloading * Add missing TextExample class * some fixes * num tokens attr for TextPairExample * fix property * Fix unit tests * Docs --- docs/datasets.rst | 70 ++++++++++ lhotse/custom.py | 130 +++++++++++++++++ lhotse/cut/data.py | 108 +------------- lhotse/cut/set.py | 18 ++- lhotse/cut/text.py | 38 +++++ lhotse/dataset/sampling/__init__.py | 10 ++ lhotse/dataset/sampling/base.py | 133 +++++++++++++++++- lhotse/dataset/sampling/dynamic.py | 54 +++++-- lhotse/dataset/sampling/dynamic_bucketing.py | 96 ++++++++----- lhotse/lazy.py | 73 ++++++++-- lhotse/supervision.py | 51 +------ test/dataset/sampling/test_text_sampling.py | 140 +++++++++++++++++++ test/test_lazy.py | 40 +++++- 13 files changed, 744 insertions(+), 217 deletions(-) create mode 100644 lhotse/custom.py create mode 100644 lhotse/cut/text.py create mode 100644 test/dataset/sampling/test_text_sampling.py diff --git a/docs/datasets.rst b/docs/datasets.rst index efcc6a47d..03df15ea1 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -141,6 +141,76 @@ However, many functions and classes in Lhotse accept either a random seed or an .. note:: The lazy seed resolution is done by calling :func:`lhotse.dataset.dataloading.resolve_seed`. +Customizing sampling constraints +-------------------------------- + +Since version 1.22.0, Lhotse provides a mechanism to customize how samplers measure the "length" of each example +for the purpose of determining dynamic batch size. To leverage this option, use the keyword argument ``constraint`` +in :class:`~lhotse.dataset.sampling.DynamicCutSampler` or :class:`~lhotse.dataset.sampling.DynamicBucketingSampler`. +The sampling criteria are defined by implementing a subclass of :class:`~lhotse.dataset.sampling.base.SamplingConstraint`: + +.. autoclass:: lhotse.dataset.sampling.base.SamplingConstraint + :members: + +The default constraint is :class:`~lhotse.dataset.sampling.base.TimeConstraint` which is created from +``max_duration``, ``max_cuts``, and ``quadratic_duration`` args passed to samplers constructor. + +Sampling non-audio data +*********************** + +Because :class:`~lhotse.dataset.sampling.base.SamplingConstraint` defines the method ``measure_length``, +it's possible to use a different attribute than duration (or a different formula) for computing the effective batch size. +This enables re-using Lhotse's sampling algorithms for other data than speech, and passing around other objects than :class:`~lhotse.cut.Cut`. + +To showcase this, we added an experimental support for text-only dataloading. We introduced a few classes specifically for this purpose: + +.. autoclass:: lhotse.cut.text.TextExample + :members: + +.. autoclass:: lhotse.cut.text.TextPairExample + :members: + +.. autoclass:: lhotse.lazy.LazyTxtIterator + :members: + +.. autoclass:: lhotse.dataset.sampling.base.TokenConstraint + :members: + +A minimal example of how to perform text-only dataloading is available below (note that any of these classes may be replaced by your own implementation if that is more suitable to your work):: + + import torch + import numpy as np + from lhotse import CutSet + from lhotse.lazy import LazyTxtIterator + from lhotse.cut.text import TextPairExample + from lhotse.dataset import DynamicBucketingSampler, TokenConstraint + from lhotse.dataset.collation import collate_vectors + + examples = CutSet(LazyTxtIterator("data.txt")) + + def tokenize(example): + # tokenize as individual bytes; BPE or another technique may be used here instead + example.tokens = np.frombuffer(example.text.encode("utf-8"), np.int8) + return example + + examples = examples.map(tokenize, apply_fn=None) + + sampler = DynamicBucketingSampler(examples, constraint=TokenConstraint(max_tokens=1024, quadratic_length=128), num_buckets=2) + + class ExampleTextDataset(torch.utils.data.Dataset): + def __getitem__(self, examples: CutSet): + tokens = [ex.tokens for ex in examples] + token_lens = torch.tensor([len(t) for t in tokens]) + tokens = collate_vectors(tokens, padding_value=-1) + return tokens, token_lens + + dloader = torch.utils.data.DataLoader(ExampleTextDataset(), sampler=sampler, batch_size=None) + + for batch in dloader: + print(batch) + +.. note:: Support for this kind of dataloading is experimental in Lhotse. If you run into any rough edges, please let us know. + Dataset's list -------------- diff --git a/lhotse/custom.py b/lhotse/custom.py new file mode 100644 index 000000000..132385a12 --- /dev/null +++ b/lhotse/custom.py @@ -0,0 +1,130 @@ +from functools import partial +from typing import Any, Dict, Optional + +import numpy as np + +from lhotse import Recording +from lhotse.utils import ifnone + + +class CustomFieldMixin: + """ + :class:`CustomFieldMixin` is intended for classes such as Cut or SupervisionSegment + that support holding custom, user-defined fields. + + .. caution:: Due to the way inheritance and dataclasses work before Python 3.10, + it is necessary to re-define ``custom`` attribute in dataclasses that + inherit from this mixin. + """ + + def __init__(self, custom: Optional[Dict[str, Any]]) -> None: + self.custom: Optional[Dict[str, Any]] = custom + + def __setattr__(self, key: str, value: Any) -> None: + """ + This magic function is called when the user tries to set an attribute. + We use it as syntactic sugar to store custom attributes in ``self.custom`` + field, so that they can be (de)serialized later. + Setting a ``None`` value will remove the attribute from ``custom``. + """ + if key in self.__dataclass_fields__: + super().__setattr__(key, value) + else: + custom = ifnone(self.custom, {}) + if value is None: + custom.pop(key, None) + else: + custom[key] = value + if custom: + self.custom = custom + + def __getattr__(self, name: str) -> Any: + """ + This magic function is called when the user tries to access an attribute + of :class:`.MonoCut` that doesn't exist. It is used for accessing the custom + attributes of cuts. + + We use it to look up the ``custom`` field: when it's None or empty, + we'll just raise AttributeError as usual. + If ``item`` is found in ``custom``, we'll return ``custom[item]``. + If ``item`` starts with "load_", we'll assume the name of the relevant + attribute comes after that, and that value of that field is of type + :class:`~lhotse.array.Array` or :class:`~lhotse.array.TemporalArray`. + We'll return its ``load`` method to call by the user. + + Example of attaching and reading an alignment as TemporalArray:: + + >>> cut = MonoCut('cut1', start=0, duration=4, channel=0) + >>> cut.alignment = TemporalArray(...) + >>> ali = cut.load_alignment() + + """ + custom = self.custom + if custom is None: + raise AttributeError(f"No such attribute: {name}") + if name in custom: + # Somebody accesses raw [Temporal]Array manifest + # or wrote a custom piece of metadata into MonoCut. + return self.custom[name] + elif name.startswith("load_"): + # Return the method for loading [Temporal]Arrays, + # to be invoked by the user. + attr_name = name[5:] + return partial(self.load_custom, attr_name) + raise AttributeError(f"No such attribute: {name}") + + def __delattr__(self, key: str) -> None: + """Used to support ``del cut.custom_attr`` syntax.""" + if key in self.__dataclass_fields__: + super().__delattr__(key) + if self.custom is None or key not in self.custom: + raise AttributeError(f"No such member: '{key}'") + del self.custom[key] + + def load_custom(self, name: str) -> np.ndarray: + """ + Load custom data as numpy array. The custom data is expected to have + been stored in cuts ``custom`` field as an :class:`~lhotse.array.Array` or + :class:`~lhotse.array.TemporalArray` manifest. + + .. note:: It works with Array manifests stored via attribute assignments, + e.g.: ``cut.my_custom_data = Array(...)``. + + :param name: name of the custom attribute. + :return: a numpy array with the data. + """ + from lhotse.array import Array, TemporalArray + + value = self.custom.get(name) + if isinstance(value, Array): + # Array does not support slicing. + return value.load() + elif isinstance(value, TemporalArray): + # TemporalArray supports slicing. + return value.load(start=self.start, duration=self.duration) + elif isinstance(value, Recording): + # Recording supports slicing. Note: we will not slice the channels + # as cut.channels referes to cut.recording and not the custom field. + return value.load_audio(offset=self.start, duration=self.duration) + else: + raise ValueError( + f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) " + f"defined, and its value has to be a manifest of type Array or TemporalArray." + ) + + def has_custom(self, name: str) -> bool: + """ + Check if the Cut has a custom attribute with name ``name``. + + :param name: name of the custom attribute. + :return: a boolean. + """ + if self.custom is None: + return False + return name in self.custom + + def drop_custom(self, name: str): + if self.custom is None or name not in self.custom: + return None + del self.custom[name] + return self diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index 43401b46c..438805eb3 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -2,7 +2,6 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass, field from decimal import ROUND_DOWN -from functools import partial from math import isclose from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -12,6 +11,7 @@ from lhotse.audio import Recording, VideoInfo from lhotse.augmentation import AugmentFn +from lhotse.custom import CustomFieldMixin from lhotse.cut.base import Cut from lhotse.features import FeatureExtractor, Features from lhotse.features.io import FeaturesWriter @@ -25,7 +25,6 @@ compute_num_frames, compute_num_samples, fastcopy, - ifnone, measure_overlap, overlaps, overspans, @@ -36,7 +35,7 @@ @dataclass -class DataCut(Cut, metaclass=ABCMeta): +class DataCut(Cut, CustomFieldMixin, metaclass=ABCMeta): """ :class:`~lhotse.cut.DataCut` is a base class for cuts that point to actual audio data. It can be either a :class:`~lhotse.cut.MonoCut` or a :class:`~lhotse.cut.MultiCut`. @@ -71,109 +70,6 @@ class DataCut(Cut, metaclass=ABCMeta): # Store anything else the user might want. custom: Optional[Dict[str, Any]] = None - def __setattr__(self, key: str, value: Any) -> None: - """ - This magic function is called when the user tries to set an attribute. - We use it as syntactic sugar to store custom attributes in ``self.custom`` - field, so that they can be (de)serialized later. - Setting a ``None`` value will remove the attribute from ``custom``. - """ - if key in self.__dataclass_fields__: - super().__setattr__(key, value) - else: - custom = ifnone(self.custom, {}) - if value is None: - custom.pop(key, None) - else: - custom[key] = value - if custom: - self.custom = custom - - def __getattr__(self, name: str) -> Any: - """ - This magic function is called when the user tries to access an attribute - of :class:`.MonoCut` that doesn't exist. It is used for accessing the custom - attributes of cuts. - - We use it to look up the ``custom`` field: when it's None or empty, - we'll just raise AttributeError as usual. - If ``item`` is found in ``custom``, we'll return ``custom[item]``. - If ``item`` starts with "load_", we'll assume the name of the relevant - attribute comes after that, and that value of that field is of type - :class:`~lhotse.array.Array` or :class:`~lhotse.array.TemporalArray`. - We'll return its ``load`` method to call by the user. - - Example of attaching and reading an alignment as TemporalArray:: - - >>> cut = MonoCut('cut1', start=0, duration=4, channel=0) - >>> cut.alignment = TemporalArray(...) - >>> ali = cut.load_alignment() - - """ - custom = self.custom - if custom is None: - raise AttributeError(f"No such attribute: {name}") - if name in custom: - # Somebody accesses raw [Temporal]Array manifest - # or wrote a custom piece of metadata into MonoCut. - return self.custom[name] - elif name.startswith("load_"): - # Return the method for loading [Temporal]Arrays, - # to be invoked by the user. - attr_name = name[5:] - return partial(self.load_custom, attr_name) - raise AttributeError(f"No such attribute: {name}") - - def __delattr__(self, key: str) -> None: - """Used to support ``del cut.custom_attr`` syntax.""" - if key in self.__dataclass_fields__: - super().__delattr__(key) - if self.custom is None or key not in self.custom: - raise AttributeError(f"No such member: '{key}'") - del self.custom[key] - - def load_custom(self, name: str) -> np.ndarray: - """ - Load custom data as numpy array. The custom data is expected to have - been stored in cuts ``custom`` field as an :class:`~lhotse.array.Array` or - :class:`~lhotse.array.TemporalArray` manifest. - - .. note:: It works with Array manifests stored via attribute assignments, - e.g.: ``cut.my_custom_data = Array(...)``. - - :param name: name of the custom attribute. - :return: a numpy array with the data. - """ - from lhotse.array import Array, TemporalArray - - value = self.custom.get(name) - if isinstance(value, Array): - # Array does not support slicing. - return value.load() - elif isinstance(value, TemporalArray): - # TemporalArray supports slicing. - return value.load(start=self.start, duration=self.duration) - elif isinstance(value, Recording): - # Recording supports slicing. Note: we will not slice the channels - # as cut.channels referes to cut.recording and not the custom field. - return value.load_audio(offset=self.start, duration=self.duration) - else: - raise ValueError( - f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) " - f"defined, and its value has to be a manifest of type Array or TemporalArray." - ) - - def has_custom(self, name: str) -> bool: - """ - Check if the Cut has a custom attribute with name ``name``. - - :param name: name of the custom attribute. - :return: a boolean. - """ - if self.custom is None: - return False - return name in self.custom - @property def recording_id(self) -> str: return self.recording.id if self.has_recording else self.features.recording_id diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index d4df42e2e..ef082900e 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -45,11 +45,13 @@ from lhotse.lazy import ( AlgorithmMixin, Dillable, + LazyFilter, LazyFlattener, LazyIteratorChain, LazyManifestIterator, LazyMapper, LazySlicer, + T, ) from lhotse.serialization import Serializable from lhotse.supervision import SupervisionSegment, SupervisionSet @@ -73,6 +75,10 @@ FW = TypeVar("FW", bound=FeaturesWriter) +def is_cut(example) -> bool: + return isinstance(example, Cut) + + class CutSet(Serializable, AlgorithmMixin): """ :class:`~lhotse.cut.CutSet` represents a collection of cuts. @@ -937,6 +943,16 @@ def subset( # Restore the requested cut_ids order. return cuts.sort_like(cut_ids) + def map( + self, + transform_fn: Callable[[T], T], + apply_fn: Optional[Callable[[T], bool]] = is_cut, + ) -> "CutSet": + ans = CutSet(LazyMapper(self.data, fn=transform_fn, apply_fn=apply_fn)) + if self.is_lazy: + return ans + return ans.to_eager() + def filter_supervisions( self, predicate: Callable[[SupervisionSegment], bool] ) -> "CutSet": @@ -3471,7 +3487,7 @@ def __iter__(self): for cut in self.source: # Check whether we're going to mix something into the current cut # or pass it through unchanged. - if rng.uniform(0.0, 1.0) > self.mix_prob: + if not is_cut(cut) or rng.uniform(0.0, 1.0) > self.mix_prob: yield cut continue to_mix = next(mix_in_cuts) diff --git a/lhotse/cut/text.py b/lhotse/cut/text.py new file mode 100644 index 000000000..5fa55e337 --- /dev/null +++ b/lhotse/cut/text.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import numpy as np + +from lhotse.custom import CustomFieldMixin + + +@dataclass +class TextExample(CustomFieldMixin): + """ + Represents a single text example. Useful e.g. for language modeling. + """ + + text: str + tokens: Optional[np.ndarray] = None + custom: Optional[Dict[str, Any]] = None + + @property + def num_tokens(self) -> Optional[int]: + if self.tokens is None: + return None + return len(self.tokens) + + +@dataclass +class TextPairExample(CustomFieldMixin): + """ + Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks. + """ + + source: TextExample + target: TextExample + custom: Optional[Dict[str, Any]] = None + + @property + def num_tokens(self) -> Optional[int]: + return self.source.num_tokens diff --git a/lhotse/dataset/sampling/__init__.py b/lhotse/dataset/sampling/__init__.py index 31cdc1bfd..2644d7fe7 100644 --- a/lhotse/dataset/sampling/__init__.py +++ b/lhotse/dataset/sampling/__init__.py @@ -1,3 +1,9 @@ +from .base import ( + SamplingConstraint, + SamplingDiagnostics, + TimeConstraint, + TokenConstraint, +) from .bucketing import BucketingSampler from .cut_pairs import CutPairsSampler from .dynamic import DynamicCutSampler @@ -9,6 +15,10 @@ from .zip import ZipSampler __all__ = [ + "TokenConstraint", + "TimeConstraint", + "SamplingDiagnostics", + "SamplingConstraint", "BucketingSampler", "CutPairsSampler", "DynamicCutSampler", diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index e839f1c97..ca43b36fc 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -1,5 +1,7 @@ +import copy import os import warnings +from abc import ABCMeta, abstractmethod from copy import deepcopy from dataclasses import asdict, dataclass from math import isclose @@ -10,6 +12,7 @@ from torch.utils.data import Sampler from lhotse.cut import Cut, CutSet +from lhotse.cut.text import TextExample from lhotse.lazy import Dillable from lhotse.manipulation import combine from lhotse.utils import Seconds, ifnone, is_none_or_gt @@ -366,8 +369,51 @@ def attach_dataloading_info(cuts: CutSet, rank: int, world_size: int) -> None: cut.dataloading_info = info +class SamplingConstraint(metaclass=ABCMeta): + """ + Defines the interface for sampling constraints. A sampling constraint + keeps track of the sampled examples and lets the sampler know when it + should yield a mini-batch. + """ + + @abstractmethod + def add(self, example: Any) -> None: + """ + Update the sampling constraint with the information about the sampled example + (e.g. current batch size, total duration). + """ + pass + + @abstractmethod + def exceeded(self) -> bool: + """Inform if the sampling constraint has been exceeded.""" + pass + + @abstractmethod + def close_to_exceeding(self) -> bool: + """Inform if we're going to exceed the sampling constraint after adding one more example.""" + pass + + @abstractmethod + def reset(self) -> None: + """Resets the internal state (called after yielding a mini-batch).""" + pass + + @abstractmethod + def measure_length(self, example: Any) -> float: + """ + Returns the "size" of an example, used to create bucket distribution for bucketing samplers + (e.g., for audio it may be duration; for text it may be number of tokens; etc.). + """ + pass + + def copy(self) -> "SamplingConstraint": + """Return a shallow copy of this constraint.""" + return copy.copy(self) + + @dataclass -class TimeConstraint: +class TimeConstraint(SamplingConstraint): """ Represents a time-based constraint for sampler classes. It is defined as maximum total batch duration (in seconds) and/or the total number of cuts. @@ -402,13 +448,13 @@ def is_active(self) -> bool: """Is it an actual constraint, or a dummy one (i.e. never exceeded).""" return self.max_duration is not None or self.max_cuts is not None - def add(self, cut: Cut) -> None: + def add(self, example: Cut) -> None: """ Increment the internal counter for the time constraint, selecting the right property from the input ``cut`` object. """ if self.max_duration is not None: - duration = self._maybe_apply_quadratic_correction(cut.duration) + duration = self._maybe_apply_quadratic_correction(example.duration) self.current += duration self.longest_seen = max(self.longest_seen, duration) self.num_cuts += 1 @@ -454,6 +500,9 @@ def reset(self) -> None: self.num_cuts = 0 self.longest_seen = 0 + def measure_length(self, example: Cut) -> float: + return example.duration + def state_dict(self) -> Dict[str, Any]: return asdict(self) @@ -499,6 +548,84 @@ def __eq__(self, other: "TimeConstraint") -> bool: ) +@dataclass +class TokenConstraint(SamplingConstraint): + """ + Represents a token-based constraint for sampler classes that sample text data. + It is defined as maximum total number of tokens in a mini-batch and/or max batch size. + + Similarly to :class:`TimeConstraint`, we support ``quadratic_length`` for quadratic + token penalty when sampling longer texts. + """ + + max_tokens: int = None + max_examples: int = None + current: int = 0 + num_examples: int = 0 + longest_seen: int = 0 + quadratic_length: Optional[int] = None + + def __post_init__(self) -> None: + assert is_none_or_gt(self.max_tokens, 0) + assert is_none_or_gt(self.max_examples, 0) + assert is_none_or_gt(self.quadratic_length, 0) + + def add(self, example: TextExample) -> None: + """ + Increment the internal token counter for the constraint, + selecting the right property from the input object. + """ + if self.max_tokens is not None: + size = self._maybe_apply_quadratic_correction(self.measure_length(example)) + self.current += size + self.longest_seen = max(self.longest_seen, size) + self.num_examples += 1 + + def _maybe_apply_quadratic_correction(self, size: int) -> int: + if self.quadratic_length is None: + return size + # For the quadratic complexity case, we add a term that accounts for + # extra memory occupied by the model. The 1/quadratic_length term causes + # the effective length to be doubled when it's equal to quadratic_length. + return size + (size**2) / self.quadratic_length + + def exceeded(self) -> bool: + """Is the constraint exceeded or not.""" + if self.max_examples is not None and self.num_examples > self.max_examples: + return True + if self.max_tokens is None: + return False + effective_duration = self.num_examples * self.longest_seen + return effective_duration > self.max_tokens + + def close_to_exceeding(self) -> bool: + """ + Check if the batch is close to satisfying the constraints. + We define "closeness" as: if we added one more cut that has + duration/num_frames/num_samples equal to the longest seen cut + in the current batch, then the batch would have exceeded the constraints. + """ + if self.max_examples is not None and self.num_examples >= self.max_examples: + return True + + if self.max_tokens is not None: + effective_size = (self.num_examples + 1) * self.longest_seen + return effective_size > self.max_tokens + return False + + def reset(self) -> None: + """ + Reset the internal counter (to be used after a batch was created, + to start collecting a new one). + """ + self.current = 0 + self.num_examples = 0 + self.longest_seen = 0 + + def measure_length(self, example: TextExample) -> float: + return example.num_tokens + + @dataclass class EpochDiagnostics: epoch: int = 0 diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index f0cff289a..a42726d4a 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -20,6 +20,7 @@ from lhotse.dataset.sampling.base import ( CutSampler, EpochDiagnostics, + SamplingConstraint, SamplingDiagnostics, TimeConstraint, ) @@ -69,9 +70,10 @@ class DynamicCutSampler(CutSampler): def __init__( self, - *cuts: Iterable[Cut], + *cuts: Iterable, max_duration: Optional[Seconds] = None, max_cuts: Optional[int] = None, + constraint: Optional[SamplingConstraint] = None, shuffle: bool = False, drop_last: bool = False, consistent_ids: bool = True, @@ -88,6 +90,12 @@ def __init__( Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet. :param max_cuts: The maximum total number of ``cuts`` per batch. When only ``max_duration`` is specified, this sampler yields static batch sizes. + :param constraint: Provide a :class:`~lhotse.dataset.sampling.base.SamplingConstraint` object + defining how the sampler decides when a mini-batch is complete. It also affects which + attribute of the input examples decides the "size" of the example (by default it's ``.duration``). + Before this parameter was introduced, Lhotse samplers used + :class:`~lhotse.dataset.sampling.base.TimeConstraint` implicitly. + Introduced in Lhotse v1.22.0. :param shuffle: When ``True``, the cuts will be shuffled dynamically with a reservoir-sampling-based algorithm. Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.: @@ -121,14 +129,12 @@ def __init__( self.cuts = cuts self.max_duration = max_duration self.max_cuts = max_cuts + self.constraint = constraint self.shuffle = shuffle self.consistent_ids = consistent_ids self.shuffle_buffer_size = shuffle_buffer_size self.quadratic_duration = quadratic_duration self.rng = None - assert any( - v is not None for v in (self.max_duration, self.max_cuts) - ), "At least one of max_duration or max_cuts has to be set." if strict is not None: warnings.warn( @@ -138,6 +144,9 @@ def __init__( ) def state_dict(self) -> Dict[str, Any]: + assert ( + self.constraint is None + ), "state_dict() is not supported with samplers that use a custom constraint." sd = super().state_dict() sd.update( { @@ -212,6 +221,7 @@ def __iter__(self) -> "DynamicCutSampler": self.cuts_iter, max_duration=self.max_duration, max_cuts=self.max_cuts, + constraint=self.constraint, drop_last=self.drop_last, quadratic_duration=self.quadratic_duration, diagnostics=self.diagnostics, @@ -251,6 +261,7 @@ def __init__( datapipe: Iterable[Union[Cut, Tuple[Cut]]], max_duration: Seconds = None, max_cuts: Optional[int] = None, + constraint: Optional[SamplingConstraint] = None, drop_last: bool = False, quadratic_duration: Optional[Seconds] = None, diagnostics: Optional[SamplingDiagnostics] = None, @@ -259,11 +270,15 @@ def __init__( self.reuse_cuts_buffer = deque() self.drop_last = drop_last self.diagnostics = ifnone(diagnostics, SamplingDiagnostics()) - self.time_constraint = TimeConstraint( - max_duration=max_duration, - max_cuts=max_cuts, - quadratic_duration=quadratic_duration, - ) + check_constraint(constraint, max_duration, max_cuts) + if constraint is not None: + self.constraint = constraint + else: + self.constraint = TimeConstraint( + max_duration=max_duration, + max_cuts=max_cuts, + quadratic_duration=quadratic_duration, + ) def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]: self.cuts_iter = iter(self.datapipe) @@ -289,7 +304,7 @@ def detuplify( else: return CutSet.from_cuts(cuts) - self.time_constraint.reset() + self.constraint.reset() cuts = [] while True: # Check that we have not reached the end of the dataset. @@ -301,7 +316,7 @@ def detuplify( # we may output it, unless the user requested to drop it. # We also check if the batch is "almost there" to override drop_last. if cuts and ( - not self.drop_last or self.time_constraint.close_to_exceeding() + not self.drop_last or self.constraint.close_to_exceeding() ): # We have a partial batch and we can return it. return detuplify(cuts) @@ -316,16 +331,16 @@ def detuplify( # Track the duration/frames/etc. constraints. cuts.append(next_cut_or_tpl) - self.time_constraint.add( + self.constraint.add( next_cut_or_tpl[0] if isinstance(next_cut_or_tpl, tuple) else next_cut_or_tpl ) # Did we exceed the max_frames and max_cuts constraints? - if self.time_constraint.close_to_exceeding(): + if self.constraint.close_to_exceeding(): # Yes. Finish sampling this batch. - if self.time_constraint.exceeded() and len(cuts) == 1: + if self.constraint.exceeded() and len(cuts) == 1: warnings.warn( "We have exceeded the max_duration constraint during sampling but have only 1 cut. " "This is likely because max_duration was set to a very low value ~10s, " @@ -366,3 +381,14 @@ def __iter__(self) -> Iterable: yield item else: self.diagnostics.discard(item) + + +def check_constraint(constraint: Optional, max_duration: Optional, max_cuts: Optional): + if constraint is not None: + assert ( + max_duration is None and max_cuts is None + ), "Cannot specify both constraint= and max_duration=/max_cuts=" + else: + assert ( + max_duration is not None or max_cuts is not None + ), "At least one of max_duration= or max_cuts= has to be defined (or provide constraint=)." diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index caf205664..7dcbfd8d8 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -25,10 +25,11 @@ from lhotse.dataset.sampling.base import ( CutSampler, EpochDiagnostics, + SamplingConstraint, SamplingDiagnostics, TimeConstraint, ) -from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter +from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter, check_constraint from lhotse.utils import ifnone @@ -75,9 +76,10 @@ class DynamicBucketingSampler(CutSampler): def __init__( self, - *cuts: Iterable[Cut], - max_duration: Seconds, + *cuts: Iterable, + max_duration: Optional[Seconds] = None, max_cuts: Optional[int] = None, + constraint: Optional[SamplingConstraint] = None, num_buckets: Optional[int] = 10, shuffle: bool = False, drop_last: bool = False, @@ -139,15 +141,14 @@ def __init__( self.cuts = cuts self.max_duration = max_duration self.max_cuts = max_cuts + self.constraint = constraint self.shuffle = shuffle self.consistent_ids = consistent_ids self.num_cuts_for_bins_estimate = num_cuts_for_bins_estimate self.buffer_size = buffer_size self.quadratic_duration = quadratic_duration self.rng = None - assert any( - v is not None for v in (self.max_duration, self.max_cuts) - ), "At least one of max_duration or max_cuts has to be set." + check_constraint(constraint, max_duration, max_cuts) if strict is not None: warnings.warn( @@ -171,12 +172,22 @@ def __init__( ), "Duration bins must be sorted ascendingly." self.duration_bins = duration_bins else: + if constraint is None: + constraint = TimeConstraint( + max_duration=self.max_duration, + max_cuts=self.max_cuts, + quadratic_duration=self.quadratic_duration, + ) self.duration_bins = estimate_duration_buckets( islice(self.cuts[0], num_cuts_for_bins_estimate), num_buckets=num_buckets, + constraint=constraint, ) def state_dict(self) -> Dict[str, Any]: + assert ( + self.constraint is None + ), "state_dict() is not supported with samplers that use a custom constraint." sd = super().state_dict() sd.update( { @@ -246,6 +257,7 @@ def __iter__(self) -> "DynamicBucketingSampler": duration_bins=self.duration_bins, max_duration=self.max_duration, max_cuts=self.max_cuts, + constraint=self.constraint, drop_last=self.drop_last, buffer_size=self.buffer_size, quadratic_duration=self.quadratic_duration, @@ -281,7 +293,11 @@ def num_cuts(self) -> Optional[int]: return None -def estimate_duration_buckets(cuts: Iterable[Cut], num_buckets: int) -> List[Seconds]: +def estimate_duration_buckets( + cuts: Iterable[Cut], + num_buckets: int, + constraint: Optional[SamplingConstraint] = None, +) -> List[float]: """ Given an iterable of cuts and a desired number of buckets, select duration values that should start each bucket. @@ -293,25 +309,30 @@ def estimate_duration_buckets(cuts: Iterable[Cut], num_buckets: int) -> List[Sec :param cuts: an iterable of :class:`lhotse.cut.Cut`. :param num_buckets: desired number of buckets. + :param constraint: object with ``.measure_length()`` method that's used to determine + the size of each sample. If ``None``, we'll use ``TimeConstraint``. :return: a list of boundary duration values (floats). """ assert num_buckets > 1 - durs = np.array([c.duration for c in cuts]) - durs.sort() - assert num_buckets <= durs.shape[0], ( + if constraint is None: + constraint = TimeConstraint() + + sizes = np.array([constraint.measure_length(c) for c in cuts]) + sizes.sort() + assert num_buckets <= sizes.shape[0], ( f"The number of buckets ({num_buckets}) must be smaller than " - f"or equal to the number of cuts ({durs.shape[0]})." + f"or equal to the number of cuts ({sizes.shape[0]})." ) - bucket_duration = durs.sum() / num_buckets + size_per_bucket = sizes.sum() / num_buckets bins = [] tot = 0.0 - for dur in durs: - if tot > bucket_duration: - bins.append(dur) + for size in sizes: + if tot > size_per_bucket: + bins.append(size) tot = 0.0 - tot += dur + tot += size return bins @@ -321,8 +342,9 @@ def __init__( self, cuts: Iterable[Union[Cut, Tuple[Cut]]], duration_bins: List[Seconds], - max_duration: float, + max_duration: Optional[Seconds] = None, max_cuts: Optional[int] = None, + constraint: Optional[SamplingConstraint] = None, drop_last: bool = False, buffer_size: int = 10000, quadratic_duration: Optional[Seconds] = None, @@ -334,6 +356,7 @@ def __init__( self.duration_bins = duration_bins self.max_duration = max_duration self.max_cuts = max_cuts + self.constraint = constraint self.drop_last = drop_last self.buffer_size = buffer_size self.quadratic_duration = quadratic_duration @@ -347,16 +370,27 @@ def __init__( f"Argument list for 'duration_bins' is expected to be in " f"sorted order (got: {duration_bins})." ) + check_constraint(constraint, max_duration, max_cuts) + + if self.constraint is None: + self.constraint = TimeConstraint( + max_duration=self.max_duration, + max_cuts=self.max_cuts, + quadratic_duration=self.quadratic_duration, + ) # A heuristic diagnostic first, for finding the right settings. - mean_duration = np.mean(duration_bins) - expected_buffer_duration = buffer_size * mean_duration - expected_bucket_duration = expected_buffer_duration / (len(duration_bins) + 1) - if expected_bucket_duration < max_duration: - warnings.warn( - f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy " - f"a 'max_duration' of {max_duration} (given our best guess)." + if max_duration is not None: + mean_duration = np.mean(duration_bins) + expected_buffer_duration = buffer_size * mean_duration + expected_bucket_duration = expected_buffer_duration / ( + len(duration_bins) + 1 ) + if expected_bucket_duration < max_duration: + warnings.warn( + f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy " + f"a 'max_duration' of {max_duration} (given our best guess)." + ) # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`). self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [ @@ -370,11 +404,7 @@ def __iter__(self) -> Generator[CutSet, None, None]: # Init: determine which buckets are "ready" def is_ready(bucket: Deque[Cut]): - tot = TimeConstraint( - max_duration=self.max_duration, - max_cuts=self.max_cuts, - quadratic_duration=self.quadratic_duration, - ) + tot = self.constraint.copy() for c in bucket: tot.add(c[0] if isinstance(c, tuple) else c) if tot.close_to_exceeding(): @@ -408,9 +438,7 @@ def is_ready(bucket: Deque[Cut]): # Sample one batch from that bucket and yield it to the caller. batcher = DurationBatcher( maybe_shuffled, - max_duration=self.max_duration, - max_cuts=self.max_cuts, - quadratic_duration=self.quadratic_duration, + constraint=self.constraint.copy(), diagnostics=self.diagnostics, ) batch = next(iter(batcher)) @@ -441,8 +469,8 @@ def _collect_cuts_in_buckets(self, n_cuts: int): try: for _ in range(n_cuts): cuts = next(self.cuts_iter) - duration = ( - cuts[0].duration if isinstance(cuts, tuple) else cuts.duration + duration = self.constraint.measure_length( + cuts[0] if isinstance(cuts, tuple) else cuts ) bucket_idx = bisect_right(self.duration_bins, duration) self.buckets[bucket_idx].append(cuts) diff --git a/lhotse/lazy.py b/lhotse/lazy.py index 549784d1e..ed01aa7d1 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -56,11 +56,10 @@ def map(self, transform_fn: Callable[[T], T]): :return: a new ``CutSet`` with transformed cuts. """ cls = type(self) - + ans = cls(LazyMapper(self.data, fn=transform_fn)) if self.is_lazy: - return cls(LazyMapper(self.data, fn=transform_fn)) - - return cls.from_items(transform_fn(item) for item in self) + return ans + return ans.to_eager() @classmethod def mux( @@ -241,6 +240,38 @@ def dill_enabled(value: bool): set_dill_enabled(previous) +class LazyTxtIterator: + """ + LazyTxtIterator is a thin wrapper over builtin ``open`` function to + iterate over lines in a (possibly compressed) text file. + It can also provide the number of lines via __len__ via fast newlines counting. + """ + + def __init__(self, path: Pathlike, as_text_example: bool = True) -> None: + self.path = path + self.as_text_example = as_text_example + self._len = None + + def __iter__(self): + from lhotse.cut.text import TextExample + + tot = 0 + with open_best(self.path, "r") as f: + for line in f: + line = line.strip() + if self.as_text_example: + line = TextExample(line) + yield line + tot += 1 + if self._len is None: + self._len = tot + + def __len__(self) -> int: + if self._len is None: + self._len = count_newlines_fast(self.path) + return self._len + + class LazyJsonlIterator: """ LazyJsonlIterator provides the ability to read JSON lines as Python dicts. @@ -622,16 +653,32 @@ class LazyMapper(Dillable): A wrapper over an iterable that enables lazy function evaluation on each item. It works like Python's `map` built-in by applying a callable ``fn`` to each element ``x`` and yielding the result of ``fn(x)`` further. + + New in Lhotse v1.22.0: ``apply_fn`` can be provided to decide whether ``fn`` should be applied + to a given example or not (in which case it will return it as-is, i.e., it does not filter). """ - def __init__(self, iterator: Iterable, fn: Callable[[Any], Any]) -> None: + def __init__( + self, + iterator: Iterable, + fn: Callable[[Any], Any], + apply_fn: Optional[Callable[[Any], bool]] = None, + ) -> None: self.iterator = iterator self.fn = fn + self.apply_fn = apply_fn assert callable(self.fn), f"LazyMapper: 'fn' arg must be callable (got {fn})." + if self.apply_fn is not None: + assert callable( + self.apply_fn + ), f"LazyMapper: 'apply_fn' arg must be callable (got {fn})." if ( - isinstance(self.fn, types.LambdaType) - and self.fn.__name__ == "" - and not is_module_available("dill") + (isinstance(self.fn, types.LambdaType) and self.fn.__name__ == "") + or ( + isinstance(self.apply_fn, types.LambdaType) + and self.apply_fn.__name__ == "" + ) + and not is_dill_enabled() ): warnings.warn( "A lambda was passed to LazyMapper: it may prevent you from forking this process. " @@ -640,7 +687,15 @@ def __init__(self, iterator: Iterable, fn: Callable[[Any], Any]) -> None: ) def __iter__(self): - return map(self.fn, self.iterator) + if self.apply_fn is None: + yield from map(self.fn, self.iterator) + else: + for item in self.iterator: + if self.apply_fn(item): + ans = self.fn(item) + else: + ans = item + yield ans def __len__(self) -> int: return len(self.iterator) diff --git a/lhotse/supervision.py b/lhotse/supervision.py index a905cfcab..5208593ab 100644 --- a/lhotse/supervision.py +++ b/lhotse/supervision.py @@ -16,6 +16,7 @@ from tqdm import tqdm +from lhotse.custom import CustomFieldMixin from lhotse.lazy import AlgorithmMixin from lhotse.serialization import Serializable from lhotse.utils import ( @@ -117,7 +118,7 @@ def transform(self, transform_fn: Callable[[str], str]) -> "AlignmentItem": @dataclass -class SupervisionSegment: +class SupervisionSegment(CustomFieldMixin): """ :class:`~lhotse.supervsion.SupervisionSegment` represents a time interval (segment) annotated with some supervision labels and/or metadata, such as the transcription, the speaker identity, the language, etc. @@ -452,54 +453,6 @@ def from_dict(data: dict) -> "SupervisionSegment": return SupervisionSegment(**data) - def __setattr__(self, key: str, value: Any) -> None: - """ - This magic function is called when the user tries to set an attribute. - We use it as syntactic sugar to store custom attributes in ``self.custom`` - field, so that they can be (de)serialized later. - """ - if key in self.__dataclass_fields__: - super().__setattr__(key, value) - else: - custom = ifnone(self.custom, {}) - if value is None: - custom.pop(key, None) - else: - custom[key] = value - if custom: - self.custom = custom - - def __getattr__(self, name: str) -> Any: - """ - This magic function is called when the user tries to access an attribute - of :class:`.SupervisionSegment` that doesn't exist. - It is used as syntactic sugar for accessing the custom supervision attributes. - - We use it to look up the ``custom`` field: when it's None or empty, - we'll just raise AttributeError as usual. - If ``item`` is found in ``custom``, we'll return ``self.custom[item]``. - - Example of adding custom metadata and retrieving it as an attribute:: - - >>> sup = SupervisionSegment('utt1', recording_id='rec1', start=0, - ... duration=1, channel=0, text='Yummy.') - >>> sup.gps_coordinates = "34.1021097,-79.1553182" - >>> coordinates = sup.gps_coordinates - - """ - try: - return self.custom[name] - except: - raise AttributeError(f"No such attribute: {name}") - - def __delattr__(self, key: str) -> None: - """Used to support ``del supervision.custom_attr`` syntax.""" - if key in self.__dataclass_fields__: - super().__delattr__(key) - if self.custom is None or key not in self.custom: - raise AttributeError(f"No such member: '{key}'") - del self.custom[key] - class SupervisionSet(Serializable, AlgorithmMixin): """ diff --git a/test/dataset/sampling/test_text_sampling.py b/test/dataset/sampling/test_text_sampling.py new file mode 100644 index 000000000..7d1749f06 --- /dev/null +++ b/test/dataset/sampling/test_text_sampling.py @@ -0,0 +1,140 @@ +from typing import Tuple + +import numpy as np +import pytest +import torch + +from lhotse import CutSet, Fbank, compute_num_frames +from lhotse.cut import Cut +from lhotse.cut.text import TextExample +from lhotse.dataset import DynamicBucketingSampler, DynamicCutSampler +from lhotse.dataset.collation import collate_audio +from lhotse.dataset.sampling.base import TokenConstraint +from lhotse.testing.dummies import DummyManifest + + +@pytest.fixture +def text_source(): + def get_text_source(): + while True: + for item in ("hello world", "example text", "this is my text data"): + # for this example, "bytes are all you need", could be BPE, etc. + yield TextExample(item, np.frombuffer(item.encode("utf-8"), np.int8)) + + return get_text_source() + + +def test_text_dynamic_cut_sampler_static_batch_size(text_source): + sampler = DynamicCutSampler( + text_source, constraint=TokenConstraint(max_examples=16) + ) + batch = next(iter(sampler)) + assert len(batch) == 16 + assert isinstance(batch[0], TextExample) + assert isinstance(batch[0].text, str) + + +def test_text_dynamic_cut_sampler_dynamic_batch_size(text_source): + sampler = DynamicCutSampler(text_source, constraint=TokenConstraint(max_tokens=256)) + batch = next(iter(sampler)) + assert isinstance(batch[0], TextExample) + assert isinstance(batch[0].text, str) + assert len(batch) == 12 + + +def test_text_dynamic_bucketing_sampler(text_source): + sampler = DynamicBucketingSampler( + text_source, + num_buckets=2, + constraint=TokenConstraint(max_tokens=256, quadratic_length=128), + ) + batch = next(iter(sampler)) + assert isinstance(batch[0], TextExample) + assert isinstance(batch[0].text, str) + assert len(batch) == 11 + + +class TextDataset(torch.utils.data.Dataset): + def __getitem__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.Tensor]: + from lhotse.dataset.collation import collate_vectors + + tokens = collate_vectors( + [item.tokens.astype(np.int32) for item in cuts], padding_value=-1 + ) + token_lens = torch.LongTensor([item.tokens.shape[0] for item in cuts]) + return tokens, token_lens + + +def test_text_dataloader_with_dynamic_bucketing_sampler(text_source): + sampler = DynamicBucketingSampler( + text_source, + num_buckets=2, + constraint=TokenConstraint(max_tokens=256, quadratic_length=128), + ) + dloader = torch.utils.data.DataLoader( + TextDataset(), sampler=sampler, batch_size=None + ) + batch = next(iter(dloader)) + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape == (11, 20) # (batch_size, seq_len) + assert isinstance(batch[1], torch.Tensor) + assert batch[1].shape == (11,) + + +class MixedAudioTextDataset(torch.utils.data.Dataset): + def __init__(self): + self.text_dataset = TextDataset() + + def __getitem__( + self, cuts: CutSet + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + text_cuts = cuts.filter(lambda c: isinstance(c, TextExample)).to_eager() + if text_cuts: + tokens, token_lens = self.text_dataset[text_cuts] + else: + tokens, token_lens = None, None + + audio_cuts = cuts.filter(lambda c: isinstance(c, Cut)).to_eager() + if audio_cuts: + audio, audio_lens = collate_audio(audio_cuts) + else: + audio, audio_lens = None, None + + return tokens, token_lens, audio, audio_lens + + +def _assign_num_tokens_to_cut(cut, frame_shift=0.01): + cut.num_tokens = compute_num_frames( + cut.duration, frame_shift=frame_shift, sampling_rate=cut.sampling_rate + ) + return cut + + +@pytest.fixture +def audio_source(): + return ( + DummyManifest(CutSet, begin_id=0, end_id=10, with_data=True) + .map(_assign_num_tokens_to_cut) + .repeat() + ) + + +def test_audio_and_text_dataloader_with_dynamic_sampler(text_source, audio_source): + mixed = CutSet.mux(text_source, audio_source, weights=[0.7, 0.3]) + sampler = DynamicCutSampler( + mixed, + constraint=TokenConstraint(max_tokens=1024, quadratic_length=128), + ) + dloader = torch.utils.data.DataLoader( + MixedAudioTextDataset(), sampler=sampler, batch_size=None + ) + batch = next(iter(dloader)) + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape == (3, 20) # (batch_size, seq_len) + assert isinstance(batch[1], torch.Tensor) + assert batch[1].shape == (3,) + assert isinstance(batch[2], torch.Tensor) + assert batch[2].shape == (2, 16000) # (batch_size, seq_len) + assert isinstance(batch[3], torch.Tensor) + assert batch[3].shape == (2,) diff --git a/test/test_lazy.py b/test/test_lazy.py index 636612caf..336227310 100644 --- a/test/test_lazy.py +++ b/test/test_lazy.py @@ -5,11 +5,13 @@ """ import random from concurrent.futures import ProcessPoolExecutor +from pathlib import Path import pytest from lhotse import CutSet, FeatureSet, RecordingSet, SupervisionSet, combine -from lhotse.lazy import LazyJsonlIterator +from lhotse.cut.text import TextExample +from lhotse.lazy import LazyJsonlIterator, LazyTxtIterator from lhotse.testing.dummies import DummyManifest, as_lazy from lhotse.testing.fixtures import with_dill_enabled from lhotse.utils import fastcopy, is_module_available @@ -265,3 +267,39 @@ def test_lazy_jsonl_iterator_caches_len(): assert it._len is not None assert it._len == expected_len assert len(it) == expected_len + + +def test_lazy_txt_iterator(tmp_path: Path): + txt = tmp_path / "test.txt" + txt.write_text("a\nb\nc\n") + + it = LazyTxtIterator(txt) + + # Supports len + assert len(it) == 3 + + # Can be iterated, strips newlines + texts = [t for t in it] + assert texts == [TextExample("a"), TextExample("b"), TextExample("c")] + + # Can be iterated again + texts = [t for t in it] + assert texts == [TextExample("a"), TextExample("b"), TextExample("c")] + + +def test_lazy_txt_iterator_raw_text(tmp_path: Path): + txt = tmp_path / "test.txt" + txt.write_text("a\nb\nc\n") + + it = LazyTxtIterator(txt, as_text_example=False) + + # Supports len + assert len(it) == 3 + + # Can be iterated, strips newlines + texts = [t for t in it] + assert texts == ["a", "b", "c"] + + # Can be iterated again + texts = [t for t in it] + assert texts == ["a", "b", "c"] From 77543c3405df7505e80ceae546ca971744d04c84 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 7 Mar 2024 23:18:05 +0700 Subject: [PATCH 02/69] Add new recipe: speechio (#1297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add new recipe: speechio Co-authored-by: Piotr Żelasko --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/speechio.py | 22 +++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/speechio.py | 138 +++++++++++++++++++++++++++ 5 files changed, 164 insertions(+) create mode 100644 lhotse/bin/modes/recipes/speechio.py create mode 100644 lhotse/recipes/speechio.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 6909affbc..9735df37f 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -167,6 +167,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_rir_noise` * - Speech Commands - :func:`lhotse.recipes.prepare_speechcommands` + * - SpeechIO + - :func:`lhotse.recipes.prepare_speechio` * - SPGISpeech - :func:`lhotse.recipes.prepare_spgispeech` * - Switchboard diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index 054cd606d..8d14d5b0a 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -64,6 +64,7 @@ from .rir_noise import * from .slu import * from .speechcommands import * +from .speechio import * from .spgispeech import * from .stcmds import * from .switchboard import * diff --git a/lhotse/bin/modes/recipes/speechio.py b/lhotse/bin/modes/recipes/speechio.py new file mode 100644 index 000000000..aa7706fcd --- /dev/null +++ b/lhotse/bin/modes/recipes/speechio.py @@ -0,0 +1,22 @@ +from typing import Dict, List, Optional, Tuple, Union + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.speechio import prepare_speechio +from lhotse.utils import Pathlike + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +def speechio( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + num_jobs: int = 1, +): + """SpeechIO data preparation. See https://github.com/SpeechColab/Leaderboard""" + prepare_speechio( + corpus_dir=corpus_dir, + output_dir=output_dir, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 03f568893..272ef594d 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -65,6 +65,7 @@ from .rir_noise import download_rir_noise, prepare_rir_noise from .slu import prepare_slu from .speechcommands import download_speechcommands, prepare_speechcommands +from .speechio import prepare_speechio from .spgispeech import download_spgispeech, prepare_spgispeech from .stcmds import download_stcmds, prepare_stcmds from .switchboard import prepare_switchboard diff --git a/lhotse/recipes/speechio.py b/lhotse/recipes/speechio.py new file mode 100644 index 000000000..d53d9abce --- /dev/null +++ b/lhotse/recipes/speechio.py @@ -0,0 +1,138 @@ +""" +The SpeechIO Chinese data is a collection of test sets covering wide range of speech recognition tasks & scenarios. + +Participants can obtain the datasets at https://github.com/SpeechColab/Leaderboard - please download the datasets manually. +""" + +import logging +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from tqdm.auto import tqdm + +from lhotse.audio import AudioSource, Recording, RecordingSet +from lhotse.qa import fix_manifests, validate_recordings_and_supervisions +from lhotse.recipes.utils import manifests_exist +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, is_module_available + +SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. + + +def _parse_one_subset( + corpus_dir: Pathlike, +) -> Optional[Tuple[Recording, SupervisionSegment]]: + recordings = [] + segments = [] + + if not is_module_available("pandas"): + raise ValueError("To prepare speechio data, please 'pip install pandas' first.") + import pandas as pd + + df = pd.read_csv(f"{str(corpus_dir)}/metadata.tsv", sep="\t") + + recording_ids = df["ID"].tolist() + texts = df["TEXT"].tolist() + wav_paths = df["AUDIO"].tolist() + + for idx, audio_path in enumerate(wav_paths): + audio_path = str(corpus_dir / audio_path) + if not os.path.exists(audio_path): + logging.warning(f"Audio file {audio_path} does not exist - skipping.") + continue + recording = Recording.from_file(audio_path) + recordings.append(recording) + recording_id = recording_ids[idx] + text = texts[idx] + speaker = recording_id.split("_")[0] + + segment = SupervisionSegment( + id=f"{corpus_dir}-{recording_id}", + recording_id=recording_id, + start=0, + duration=recording.duration, + channel=0, + language="Chinese", + speaker=speaker, + text=text, + ) + segments.append(segment) + + return recordings, segments + + +def _prepare_subset( + subset: str, + corpus_dir: Pathlike, +) -> Tuple[RecordingSet, SupervisionSet]: + """ + Returns the RecodingSet and SupervisionSet given a dataset part. + :param subset: str, the name of the subset. + :param corpus_dir: Pathlike, the path of the data dir. + :return: the RecodingSet and SupervisionSet for train and valid. + """ + corpus_dir = Path(corpus_dir) + part_path = corpus_dir / subset + + recording_set, supervision_set = _parse_one_subset(part_path) + recording_set = RecordingSet.from_recordings(recording_set) + supervision_set = SupervisionSet.from_segments(supervision_set) + + # Fix manifests + recording_set, supervision_set = fix_manifests(recording_set, supervision_set) + validate_recordings_and_supervisions(recording_set, supervision_set) + + return recording_set, supervision_set + + +def prepare_speechio( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns the manifests which consist of the Recordings and Supervisions + :param corpus_dir: Path to the SpeechIO dataset. + :param output_dir: Pathlike, the path where to write the manifests. + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'recordings' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + logging.info("Preparing SpeechIO...") + + subsets = [] + for i in range(SPEECHIO_TESTSET_INDEX + 1): + idx = f"{i}".zfill(2) + subsets.append(f"SPEECHIO_ASR_ZH000{idx}") + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + manifests = defaultdict(dict) + + for part in tqdm(subsets, desc="Dataset parts"): + logging.info(f"Processing SpeechIO subset: {part}") + if manifests_exist( + part=part, + output_dir=output_dir, + prefix=f"speechio", + suffix="jsonl.gz", + ): + logging.info(f"SpeechIO subset: {part} already prepared - skipping.") + continue + + recording_set, supervision_set = _prepare_subset(part, corpus_dir) + + if output_dir is not None: + supervision_set.to_file( + output_dir / f"speechio_supervisions_{part}.jsonl.gz" + ) + recording_set.to_file(output_dir / f"speechio_recordings_{part}.jsonl.gz") + + manifests[part] = {"recordings": recording_set, "supervisions": supervision_set} + + return manifests From 0a9a532860cd7b2aa7203266f9be165e05f0c19f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Mar 2024 00:18:49 +0800 Subject: [PATCH 03/69] tedlium2 recipe (#1296) * init commit * updated imports * Update corpus.rst --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/tedlium2.py | 60 ++++++++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/tedlium2.py | 170 +++++++++++++++++++++++++++ 5 files changed, 234 insertions(+) create mode 100644 lhotse/bin/modes/recipes/tedlium2.py create mode 100644 lhotse/recipes/tedlium2.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 9735df37f..bd5284b56 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -173,6 +173,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_spgispeech` * - Switchboard - :func:`lhotse.recipes.prepare_switchboard` + * - TED-LIUM v2 + - :func:`lhotse.recipes.prepare_tedlium2` * - TED-LIUM v3 - :func:`lhotse.recipes.prepare_tedlium` * - TIMIT diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index 8d14d5b0a..a208a25c6 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -71,6 +71,7 @@ from .tal_asr import * from .tal_csasr import * from .tedlium import * +from .tedlium2 import * from .thchs_30 import * from .this_american_life import * from .timit import * diff --git a/lhotse/bin/modes/recipes/tedlium2.py b/lhotse/bin/modes/recipes/tedlium2.py new file mode 100644 index 000000000..a50e8c266 --- /dev/null +++ b/lhotse/bin/modes/recipes/tedlium2.py @@ -0,0 +1,60 @@ +from typing import List + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.tedlium2 import TEDLIUM_PARTS, download_tedlium2, prepare_tedlium2 +from lhotse.utils import Pathlike + + +@prepare.command() +@click.argument( + "tedlium_dir", type=click.Path(exists=True, dir_okay=True, file_okay=False) +) +@click.argument("output_dir", type=click.Path()) +@click.option( + "--parts", + "-p", + type=click.Choice(TEDLIUM_PARTS), + multiple=True, + default=list(TEDLIUM_PARTS), + help=f"Which parts of TED-LIUM v2 to prepare (by default all, i.e., {TEDLIUM_PARTS}).", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +@click.option( + "--normalize-text", + type=click.Choice(["none", "upper", "kaldi"], case_sensitive=False), + default="none", + help="Type of text normalization to apply (no normalization, by default). " + "Selecting `kaldi` will remove tokens and join suffixes.", +) +def tedlium2( + tedlium_dir: Pathlike, + output_dir: Pathlike, + parts: List[str], + num_jobs: int, + normalize_text: str, +): + """ + TED-LIUM v2 recording and supervision manifest preparation. + """ + prepare_tedlium2( + tedlium_root=tedlium_dir, + output_dir=output_dir, + dataset_parts=parts, + num_jobs=num_jobs, + normalize_text=normalize_text, + ) + + +@download.command() +@click.argument("target_dir", type=click.Path()) +def tedlium(target_dir: Pathlike): + """TED-LIUM v2 download (approx. 35GB).""" + download_tedlium2(target_dir) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 272ef594d..3b4c34ac9 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -70,6 +70,7 @@ from .stcmds import download_stcmds, prepare_stcmds from .switchboard import prepare_switchboard from .tedlium import download_tedlium, prepare_tedlium +from .tedlium2 import download_tedlium2, prepare_tedlium2 from .thchs_30 import download_thchs_30, prepare_thchs_30 from .this_american_life import download_this_american_life, prepare_this_american_life from .timit import download_timit, prepare_timit diff --git a/lhotse/recipes/tedlium2.py b/lhotse/recipes/tedlium2.py new file mode 100644 index 000000000..81c06f749 --- /dev/null +++ b/lhotse/recipes/tedlium2.py @@ -0,0 +1,170 @@ +""" +The following are the original TED-LIUM 2 README contents. + +This is theTED-LIUM corpus release 2, English speech recognition training corpus from TED talks, created by Laboratoire d’Informatique de l’Université du Maine (LIUM) (mirrored here) +licensed under Creative Commons BY-NC-ND 3.0 (http://creativecommons.org/licenses/by-nc-nd/3.0/deed.en). + +All talks and text are property of TED Conferences LLC. + +--- + +The TED-LIUM corpus was made from audio talks and their transcriptions available on the TED website. We have prepared and filtered these data in order to train acoustic models to participate to the International Workshop on Spoken Language Translation 2011 (the LIUM English/French SLT system reached the first rank in the SLT task). + +More details are given in this paper: + +A. Rousseau, P. Deléglise, and Y. Estève, "Enhancing the TED-LIUM Corpus with Selected Data for Language Modeling and More TED Talks", +in Proceedings of the Ninth International Conference on Language Resources and Evaluation (LREC’14), May 2014. + + +Please cite this reference if you use these data in your research work. + +--- + +Contents: + +- 1495 audio talks in NIST sphere format (SPH) +- 1495 transcripts in STM format +- Dictionary with pronunciation (159848 entries) +- Selected monolingual data for language modeling from WMT12 publicly available corpora + + +SPH format info: + +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Bit Rate : 256k +Sample Encoding : 16-bit Signed Integer PCM + +""" + +import logging +import shutil +import tarfile +from concurrent.futures.thread import ThreadPoolExecutor +from functools import partial +from pathlib import Path +from typing import Dict, Optional, Sequence, Union + +from lhotse import ( + RecordingSet, + SupervisionSegment, + SupervisionSet, + validate_recordings_and_supervisions, +) +from lhotse.qa import fix_manifests +from lhotse.recipes.utils import normalize_text_tedlium +from lhotse.utils import Pathlike, resumable_download, safe_extract + +TEDLIUM_PARTS = ("train", "dev", "test") + + +def download_tedlium2( + target_dir: Pathlike = ".", force_download: Optional[bool] = False +) -> Path: + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + tar_path = target_dir / "TEDLIUM_release2.tar.gz" + corpus_dir = target_dir / "TEDLIUM_release2" + completed_detector = corpus_dir / ".completed" + if completed_detector.is_file(): + logging.info(f"Skipping {tar_path.name} because {completed_detector} exists.") + return corpus_dir + resumable_download( + "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz", + filename=tar_path, + force_download=force_download, + ) + shutil.rmtree(corpus_dir, ignore_errors=True) + with tarfile.open(tar_path) as tar: + safe_extract(tar, path=target_dir) + completed_detector.touch() + return corpus_dir + + +def prepare_tedlium2( + tedlium_root: Pathlike, + output_dir: Optional[Pathlike] = None, + dataset_parts: Union[str, Sequence[str]] = TEDLIUM_PARTS, + num_jobs: int = 1, + normalize_text: str = "none", +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Prepare manifests for the TED-LIUM v2 corpus. + + The manifests are created in a dict with three splits: train, dev and test. + Each split contains a RecordingSet and SupervisionSet in a dict under keys 'recordings' and 'supervisions'. + + :param tedlium_root: Path to the unpacked TED-LIUM data. + :param output_dir: Path where the manifests should be written. + :param dataset_parts: Which parts of the dataset to prepare. + By default, all parts are prepared. + :param num_jobs: Number of parallel jobs to use. + :return: A dict with standard corpus splits containing the manifests. + """ + tedlium_root = Path(tedlium_root) + output_dir = Path(output_dir) if output_dir is not None else None + corpus = {} + + dataset_parts = [dataset_parts] if isinstance(dataset_parts, str) else dataset_parts + + with ThreadPoolExecutor(num_jobs) as ex: + for split in dataset_parts: + logging.info(f"Processing {split} split...") + root = tedlium_root / split + recordings = RecordingSet.from_dir( + root / "sph", pattern="*.sph", num_jobs=num_jobs + ) + stms = list((root / "stm").glob("*.stm")) + assert len(stms) == len(recordings), ( + f"Mismatch: found {len(recordings)} " + f"sphere files and {len(stms)} STM files. " + f"You might be missing some parts of TEDLIUM..." + ) + futures = [] + _parse_stm_worker = partial(_parse_stm_file, normalize_text=normalize_text) + for stm in stms: + futures.append(ex.submit(_parse_stm_worker, stm)) + + segments = [] + for future in futures: + segments.extend(future.result()) + + supervisions = SupervisionSet.from_segments(segments) + recordings, supervisions = fix_manifests(recordings, supervisions) + + corpus[split] = {"recordings": recordings, "supervisions": supervisions} + validate_recordings_and_supervisions(**corpus[split]) + + if output_dir is not None: + recordings.to_file(output_dir / f"tedlium2_recordings_{split}.jsonl.gz") + supervisions.to_file( + output_dir / f"tedlium2_supervisions_{split}.jsonl.gz" + ) + + return corpus + + +def _parse_stm_file(stm: str, normalize_text: str = "none") -> SupervisionSegment: + """Helper function to parse a single STM file.""" + segments = [] + with stm.open() as f: + for idx, l in enumerate(f): + rec_id, _, _, start, end, _, *words = l.split() + start, end = float(start), float(end) + text = " ".join(words).replace("{NOISE}", "[NOISE]") + if text == "ignore_time_segment_in_scoring": + continue + segments.append( + SupervisionSegment( + id=f"{rec_id}-{idx}", + recording_id=rec_id, + start=start, + duration=round(end - start, ndigits=8), + channel=0, + text=normalize_text_tedlium(text, normalize_text), + language="English", + speaker=rec_id, + ) + ) + return segments From 7cc8fb4f2a674b15874bf3e08129a7cc8e2434ad Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 7 Mar 2024 23:20:33 +0700 Subject: [PATCH 04/69] fix whisper for multi-channel data (#1289) * fix whisper for multi-channel data * update warning for multi-channel data --- lhotse/features/whisper_fbank.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) mode change 100644 => 100755 lhotse/features/whisper_fbank.py diff --git a/lhotse/features/whisper_fbank.py b/lhotse/features/whisper_fbank.py old mode 100644 new mode 100755 index df69071ff..c42f79621 --- a/lhotse/features/whisper_fbank.py +++ b/lhotse/features/whisper_fbank.py @@ -55,7 +55,16 @@ def log_mel_spectrogram( if device is not None: audio = audio.to(device) - audio = audio.squeeze(0) + + if len(audio.shape) == 2: + if audio.shape[0] > 1: + raise ValueError("Whisper Fbank works only with single-channel recordings.") + else: + audio = audio[0] + assert ( + len(audio.shape) == 1 + ), f"Whisper Fbank works only with single-channel recordings (shape: {audio.shape})" + window = torch.hann_window(n_fft).to(audio.device) stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True) magnitudes = stft[..., :-1].abs() ** 2 From b34e805156679ce97a2e564dda1188c0e84c6221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 7 Mar 2024 13:17:44 -0500 Subject: [PATCH 05/69] Channel selection for multi-channel custom recording fields (#1299) * Channel selection for multi-channel custom recording fields * fix * fix exporting multicuts to shar --- lhotse/custom.py | 21 ++++++++-- lhotse/cut/multi.py | 35 ++++++++++++++++- lhotse/cut/set.py | 73 ++++++++++++++++++++++------------- lhotse/shar/writers/shar.py | 23 +++++++++++ lhotse/testing/dummies.py | 3 ++ test/cut/test_custom_attrs.py | 40 +++++++++++++++++++ test/cut/test_cut_truncate.py | 9 ++++- 7 files changed, 171 insertions(+), 33 deletions(-) diff --git a/lhotse/custom.py b/lhotse/custom.py index 132385a12..0feeb5fc9 100644 --- a/lhotse/custom.py +++ b/lhotse/custom.py @@ -4,7 +4,7 @@ import numpy as np from lhotse import Recording -from lhotse.utils import ifnone +from lhotse.utils import fastcopy, ifnone class CustomFieldMixin: @@ -81,6 +81,14 @@ def __delattr__(self, key: str) -> None: raise AttributeError(f"No such member: '{key}'") del self.custom[key] + def with_custom(self, name: str, value: Any): + """Return a copy of this object with an extra custom field assigned to it.""" + cpy = fastcopy( + self, custom=self.custom.copy() if self.custom is not None else {} + ) + cpy.custom[name] = value + return cpy + def load_custom(self, name: str) -> np.ndarray: """ Load custom data as numpy array. The custom data is expected to have @@ -103,9 +111,14 @@ def load_custom(self, name: str) -> np.ndarray: # TemporalArray supports slicing. return value.load(start=self.start, duration=self.duration) elif isinstance(value, Recording): - # Recording supports slicing. Note: we will not slice the channels - # as cut.channels referes to cut.recording and not the custom field. - return value.load_audio(offset=self.start, duration=self.duration) + # Recording supports slicing. + # Note: cut.channels referes to cut.recording and not the custom field. + # We have to use a special channel selector field instead; e.g.: + # if this is "target_recording", we'll look for "target_recording_channel_selector" + channels = self.custom.get(f"{name}_channel_selector") + return value.load_audio( + channels=channels, offset=self.start, duration=self.duration + ) else: raise ValueError( f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) " diff --git a/lhotse/cut/multi.py b/lhotse/cut/multi.py index 45557b811..439f9d217 100644 --- a/lhotse/cut/multi.py +++ b/lhotse/cut/multi.py @@ -4,7 +4,7 @@ from functools import partial, reduce from itertools import groupby from operator import add -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -365,6 +365,39 @@ def merge_supervisions( return fastcopy(self, supervisions=msups) + def with_channels(self, channels: Union[List[int], int]) -> DataCut: + """ + Select specified channels from this cut. + Supports extending to other channels available in the underlying :class:`Recording`. + If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`, + otherwise we'll return a :class:`~lhotse.cut.MultiCut'. + """ + mono = isinstance(channels, int) or len(channels) == 1 + assert set([channels] if mono else channels).issubset( + set(self.recording.channel_ids) + ), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}" + + if mono: + from .mono import MonoCut + + if isinstance(channels, Sequence): + (channels,) = channels + return MonoCut( + id=f"{self.id}-{channels}", + recording=self.recording, + start=self.start, + duration=self.duration, + channel=channels, + supervisions=[ + fastcopy(s, channel=channels) + for s in self.supervisions + if is_equal_or_contains(s.channel, channels) + ], + custom=self.custom, + ) + + return fastcopy(self, channel=channels) + @staticmethod def from_mono(*cuts: DataCut) -> "MultiCut": """ diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index ef082900e..652b8942d 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -1394,34 +1394,21 @@ def truncate( :param rng: optional random number generator to be used with a 'random' ``offset_type``. :return: a new CutSet instance with truncated cuts. """ - truncated_cuts = [] - for cut in self: - if cut.duration <= max_duration: - truncated_cuts.append(cut) - continue - - def compute_offset(): - if offset_type == "start": - return 0.0 - last_offset = cut.duration - max_duration - if offset_type == "end": - return last_offset - if offset_type == "random": - if rng is None: - return random.uniform(0.0, last_offset) - else: - return rng.uniform(0.0, last_offset) - raise ValueError(f"Unknown 'offset_type' option: {offset_type}") - - truncated_cuts.append( - cut.truncate( - offset=compute_offset(), - duration=max_duration, - keep_excessive_supervisions=keep_excessive_supervisions, - preserve_id=preserve_id, - ) + assert offset_type in ( + "start", + "end", + "random", + ), f"Unknown offset type: '{offset_type}'" + return self.map( + partial( + _truncate_single, + max_duration=max_duration, + offset_type=offset_type, + keep_excessive_supervisions=keep_excessive_supervisions, + preserve_id=preserve_id, + rng=rng, ) - return CutSet(truncated_cuts) + ) def extend_by( self, @@ -3368,6 +3355,38 @@ def _drop_supervisions(cut, *args, **kwargs): return cut.drop_supervisions(*args, **kwargs) +def _truncate_single( + cut: Cut, + max_duration: Seconds, + offset_type: str, + keep_excessive_supervisions: bool = True, + preserve_id: bool = False, + rng: Optional[random.Random] = None, +) -> Cut: + if cut.duration <= max_duration: + return cut + + def compute_offset(): + if offset_type == "start": + return 0.0 + last_offset = cut.duration - max_duration + if offset_type == "end": + return last_offset + if offset_type == "random": + if rng is None: + return random.uniform(0.0, last_offset) + else: + return rng.uniform(0.0, last_offset) + raise ValueError(f"Unknown 'offset_type' option: {offset_type}") + + return cut.truncate( + offset=compute_offset(), + duration=max_duration, + keep_excessive_supervisions=keep_excessive_supervisions, + preserve_id=preserve_id, + ) + + def _export_to_shar_single( cuts: CutSet, output_dir: Pathlike, diff --git a/lhotse/shar/writers/shar.py b/lhotse/shar/writers/shar.py index 5ec5ac050..229a073b6 100644 --- a/lhotse/shar/writers/shar.py +++ b/lhotse/shar/writers/shar.py @@ -128,6 +128,12 @@ def write(self, cut: Cut) -> None: if cut.has_recording: data = cut.load_audio() recording = to_shar_placeholder(cut.recording, cut) + cut_channels = _aslist(cut.channel) + if recording.channel_ids != cut_channels: + # If recording is multi-channel but the cut refers to a subset of them, + # we have to update the recording manifest accordingly + recording.sources[0].channels = cut_channels + recording.channel_ids = cut_channels self.writers["recording"].write( cut.id, data, cut.sampling_rate, manifest=recording ) @@ -171,13 +177,24 @@ def write(self, cut: Cut) -> None: else: data = cut.load_custom(key) placeholder_obj = to_shar_placeholder(val, cut) + channel_selector_key = f"{key}_channel_selector" kwargs = {} if isinstance(val, Recording): kwargs["sampling_rate"] = val.sampling_rate + if cut.has_custom(channel_selector_key): + # override custom recording channels since the audio was loaded via cut + # and used the channel selector + placeholder_obj.sources[0].channels = cut.custom[ + channel_selector_key + ] + placeholder_obj.channel_ids = cut.custom[ + channel_selector_key + ] self.writers[key].write( cut.id, data, manifest=placeholder_obj, **kwargs ) cut = fastcopy(cut, custom=cut.custom.copy()) + cut.custom.pop(channel_selector_key, None) # no longer needed setattr(cut, key, placeholder_obj) else: self.writers[key].write_placeholder(cut.id) @@ -224,3 +241,9 @@ def _create_cuts_output_url(base_output_url: str, shard_suffix: str) -> str: base_output_url = base_output_url.replace("pipe:", "pipe:gzip -c | ") return f"{base_output_url}/cuts{shard_suffix}.jsonl.gz" + + +def _aslist(x): + if isinstance(x, list): + return x + return [x] diff --git a/lhotse/testing/dummies.py b/lhotse/testing/dummies.py index 7963b74d3..aec6a7581 100644 --- a/lhotse/testing/dummies.py +++ b/lhotse/testing/dummies.py @@ -99,6 +99,9 @@ def dummy_audio_source( data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples)) if len(channels) > 1: data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1) + # ensure each channel has different data for channel selection testing + mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)]) + data = data * mults binary_data = BytesIO() soundfile.write( binary_data, diff --git a/test/cut/test_custom_attrs.py b/test/cut/test_custom_attrs.py index 8706fd6f6..693dfa4eb 100644 --- a/test/cut/test_custom_attrs.py +++ b/test/cut/test_custom_attrs.py @@ -18,6 +18,7 @@ from lhotse.serialization import deserialize_item from lhotse.testing.dummies import ( dummy_cut, + dummy_multi_channel_recording, dummy_multi_cut, dummy_recording, dummy_supervision, @@ -401,3 +402,42 @@ def test_del_attr_mono_cut(cut): with pytest.raises(AttributeError): del cut.extra_metadata assert "extra_metadata" not in cut.custom + + +def test_multi_cut_custom_multi_recording_channel_selector(): + cut = dummy_multi_cut(0, channel=[0, 1, 2, 3], with_data=True) + cut.target_recording = dummy_multi_channel_recording( + 1, channel_ids=[0, 1, 2, 3], with_data=True + ) + + # All input channels + ref_audio = cut.load_audio() + assert ref_audio.shape == (4, 16000) + + # Input channel selection + two_channel_in = cut.with_channels([0, 1]) + audio = two_channel_in.load_audio() + assert audio.shape == (2, 16000) + np.testing.assert_allclose(ref_audio[:2], audio) + + # Input channel selection, different channels + two_channel_in = cut.with_channels([0, 3]) + audio = two_channel_in.load_audio() + assert audio.shape == (2, 16000) + np.testing.assert_allclose(ref_audio[::3], audio) + + # All output channels + ref_tgt_audio = cut.load_target_recording() + assert ref_tgt_audio.shape == (4, 16000) + + # Output channel selection + two_channel_out = cut.with_custom("target_recording_channel_selector", [0, 1]) + audio = two_channel_out.load_target_recording() + assert audio.shape == (2, 16000) + np.testing.assert_allclose(ref_tgt_audio[:2], audio) + + # Output channel selection, different channels + two_channel_out = cut.with_custom("target_recording_channel_selector", [0, 3]) + audio = two_channel_out.load_target_recording() + assert audio.shape == (2, 16000) + np.testing.assert_allclose(ref_tgt_audio[::3], audio) diff --git a/test/cut/test_cut_truncate.py b/test/cut/test_cut_truncate.py index e9651866b..9b17e6cef 100644 --- a/test/cut/test_cut_truncate.py +++ b/test/cut/test_cut_truncate.py @@ -7,7 +7,7 @@ from lhotse.cut import CutSet, MixedCut, MixTrack, MonoCut, PaddingCut from lhotse.features import Features from lhotse.supervision import SupervisionSegment, SupervisionSet -from lhotse.testing.dummies import DummyManifest, dummy_cut, dummy_recording +from lhotse.testing.dummies import DummyManifest, as_lazy, dummy_cut, dummy_recording from lhotse.testing.random import deterministic_rng @@ -238,6 +238,13 @@ def test_truncate_mixed_cut_gap_or_padding(gapped_mixed_cut, offset): assert audio is not None +def test_truncate_cut_set_lazy_result(cut_set): + with as_lazy(cut_set, ".jsonl") as lazy_cuts: + truncated_cut_set = lazy_cuts.truncate(max_duration=5, offset_type="start") + assert truncated_cut_set.is_lazy + assert all(c.duration == pytest.approx(5.0) for c in truncated_cut_set) + + def test_truncate_cut_set_offset_start(cut_set): truncated_cut_set = cut_set.truncate(max_duration=5, offset_type="start") cut1, cut2 = truncated_cut_set From 52b626bae710a049b7466c341784fe24c4e964d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 7 Mar 2024 14:10:24 -0500 Subject: [PATCH 06/69] Xfail flaky SileroVAD tests (#1300) --- test/workflows/test_activity_detection.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/workflows/test_activity_detection.py b/test/workflows/test_activity_detection.py index cc0768cb6..6f0b0f444 100644 --- a/test/workflows/test_activity_detection.py +++ b/test/workflows/test_activity_detection.py @@ -24,6 +24,9 @@ def _check_torch_version(greater_than: str): return False +@pytest.mark.xfail( + reason="This test fails too often in the CI on downloading the SileroVAD model." +) def test_silero_vad_init(): if not _check_torch_version("1.12"): pytest.skip("torch >= 1.12 is required for this test") @@ -39,6 +42,9 @@ def test_silero_vad_init(): assert activity[0].start + activity[0].duration < recording.duration +@pytest.mark.xfail( + reason="This test fails too often in the CI on downloading the SileroVAD model." +) def test_silero_vad_in_parallel(): if not _check_torch_version("1.12"): pytest.skip("torch >= 1.12 is required for this test") @@ -66,6 +72,9 @@ def temporary_directory(): yield temp_dir +@pytest.mark.xfail( + reason="This test fails too often in the CI on downloading the SileroVAD model." +) def test_silero_vad_workflow_simple(temporary_directory: str): if not _check_torch_version("1.12"): pytest.skip("torch >= 1.12 is required for this test") From d26d4763e405e3c861363d6b74fab46b9c599cbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 7 Mar 2024 14:29:04 -0500 Subject: [PATCH 07/69] Bump dev version to 1.23.0 (#1301) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 57807d6d0..a6c2798a4 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.22.0 +1.23.0 From 067392511142b8da7bbae81dc7518ac0c32a64b6 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 9 Mar 2024 04:13:20 +0800 Subject: [PATCH 08/69] MDCC recipe (#1302) * init commit * updated * files renamed * files renamed * Update mdcc.py * misc. fix * misc. updates * lowercase the function names to align with other recipes --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/mdcc.py | 51 ++++++++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/mdcc.py | 138 +++++++++++++++++++++++++++ 5 files changed, 193 insertions(+) create mode 100644 lhotse/bin/modes/recipes/mdcc.py create mode 100644 lhotse/recipes/mdcc.py diff --git a/docs/corpus.rst b/docs/corpus.rst index bd5284b56..70393cc38 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -145,6 +145,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_librittsr` * - LJ Speech - :func:`lhotse.recipes.prepare_ljspeech` + * - MDCC + - :func:`lhotse.recipes.prepare_mdcc` * - Medical - :func:`lhotse.recipes.prepare_medical` * - MiniLibriMix diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index a208a25c6..3dbc31c14 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -52,6 +52,7 @@ from .libritts import * from .ljspeech import * from .magicdata import * +from .mdcc import * from .medical import * from .mgb2 import * from .mls import * diff --git a/lhotse/bin/modes/recipes/mdcc.py b/lhotse/bin/modes/recipes/mdcc.py new file mode 100644 index 000000000..2302ae7bf --- /dev/null +++ b/lhotse/bin/modes/recipes/mdcc.py @@ -0,0 +1,51 @@ +from typing import Optional, Sequence + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.mdcc import download_mdcc, prepare_mdcc +from lhotse.utils import Pathlike + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-p", + "--dataset-parts", + type=str, + default=["all"], + multiple=True, + help="List of dataset parts to prepare. To prepare multiple parts, pass each with `-p` " + "Example: `-p train -p valid`", +) +def MDCC( + corpus_dir: Pathlike, + dataset_parts: Sequence[str], + output_dir: Optional[Pathlike] = None, +): + """MDCC data preparation.""" + prepare_mdcc( + corpus_dir=corpus_dir, + dataset_parts=dataset_parts, + output_dir=output_dir, + ) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +@click.option( + "--force-download", + is_flag=True, + default=False, + help="if True, it will download the MDCC data even if it is already present.", +) +def MDCC( + target_dir: Pathlike, + force_download: Optional[bool] = False, +): + """MDCC download.""" + download_mdcc( + target_dir=target_dir, + force_download=force_download, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 3b4c34ac9..583585f46 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -54,6 +54,7 @@ ) from .ljspeech import download_ljspeech, prepare_ljspeech from .magicdata import download_magicdata, prepare_magicdata +from .mdcc import download_mdcc, prepare_mdcc from .medical import download_medical, prepare_medical from .mgb2 import prepare_mgb2 from .mls import prepare_mls diff --git a/lhotse/recipes/mdcc.py b/lhotse/recipes/mdcc.py new file mode 100644 index 000000000..a8d16e36f --- /dev/null +++ b/lhotse/recipes/mdcc.py @@ -0,0 +1,138 @@ +""" +Multi-Domain Cantonese Corpus (MDCC), consists of 73.6 hours of clean read speech paired with +transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises philosophy, +politics, education, culture, lifestyle and family domains, covering a wide range of topics. + +Manuscript can be found at: https://arxiv.org/abs/2201.02419 +""" + +import logging +import zipfile +from pathlib import Path +from typing import Dict, Sequence, Union + +from tqdm.auto import tqdm + +from lhotse import validate_recordings_and_supervisions +from lhotse.audio import Recording, RecordingSet +from lhotse.qa import fix_manifests +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, is_module_available + +MDCC_URL = "https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view" + +MDCC_PARTS = ["train", "valid", "test"] + + +def download_mdcc(target_dir: Pathlike, force_download: bool = False) -> Path: + """ + Downloads the MDCC data from the Google Drive and extracts it. + :param target_dir: the directory where MDCC data will be saved. + :param force_download: if True, it will download the MDCC data even if it is already present. + :return: the path to downloaded and extracted directory with data. + """ + if not is_module_available("gdown"): + raise ValueError("Please run 'pip install gdown' to download MDCC.") + + import gdown + + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + corpus_dir = target_dir / "dataset" + corpus_zip = corpus_dir.with_suffix(".zip") + + if not force_download and corpus_zip.exists(): + logging.info(f"{corpus_zip} already exists. Skipping download.") + else: + logging.info(f"Running: gdown --fuzzy {MDCC_URL}") + gdown.download(MDCC_URL, str(corpus_zip), fuzzy=True, quiet=False) + + # Extract the zipped file + if not corpus_dir.exists() or force_download: + logging.info(f"Extracting {corpus_zip} to {target_dir}") + with zipfile.ZipFile(corpus_zip) as zf: + zf.extractall(path=target_dir) + + return corpus_dir + + +def prepare_mdcc( + corpus_dir: Pathlike, + dataset_parts: Union[str, Sequence[str]] = "all", + output_dir: Pathlike = None, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Create RecordingSet and SupervisionSet manifests for MDCC from a raw corpus distribution. + + :param corpus_dir: Pathlike, the path to the extracted corpus. + :param output_dir: Pathlike, the path where to write the manifests. + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + audio_dir = corpus_dir / "audio" + assert (audio_dir).is_dir(), f"Missing {audio_dir} in {corpus_dir}." + manifests = {} + + if dataset_parts == "all" or dataset_parts[0] == "all": + dataset_parts = MDCC_PARTS + elif isinstance(dataset_parts, str): + assert dataset_parts in MDCC_PARTS, f"Unknown dataset part: {dataset_parts}" + dataset_parts = [dataset_parts] + + for part in dataset_parts: + recordings = [] + supervisions = [] + + metadata = corpus_dir / f"cnt_asr_{part}_metadata.csv" + assert (metadata).is_file(), f"Missing {part} metadata in {corpus_dir}." + + # read cvs file in an ugly way as there are no more than 80k lines + # and i don't want to depend on pandas + with open(metadata, "r") as f: + lines = f.readlines() + + # remove the header + lines = lines[1:] + + for line in tqdm(lines, desc=f"Processing {part} metadata"): + # audio_path, text_path, gender, duration + audio_path, text_path, gender, _ = line.strip().split(",") + audio_path = audio_dir / Path(audio_path).name + text_path = corpus_dir / text_path + + recording_id = make_recording_id(Path(audio_path)) + recording = Recording.from_file(audio_path, recording_id=recording_id) + recordings.append(recording) + + supervision_segment = SupervisionSegment( + id=recording_id, + recording_id=recording_id, + start=0.0, + duration=recording.duration, + channel=0, + text=text_path.read_text().strip(), + gender=gender, + language="yue", + ) + supervisions.append(supervision_segment) + + recordings = RecordingSet.from_recordings(recordings) + supervisions = SupervisionSet.from_segments(supervisions) + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions( + recordings=recordings, supervisions=supervisions + ) + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + recordings.to_file(output_dir / f"mdcc_recordings_{part}.jsonl.gz") + supervisions.to_file(output_dir / f"mdcc_supervisions_{part}.jsonl.gz") + + manifests[part] = {"recordings": recordings, "supervisions": supervisions} + + return manifests + + +def make_recording_id(path: Path) -> str: + return f"mdcc_{path.stem}" From d3106cf374df5dc33b16a5fcdd2ecb87c0cd0359 Mon Sep 17 00:00:00 2001 From: Feiteng Date: Mon, 11 Mar 2024 20:05:18 +0800 Subject: [PATCH 09/69] Fix _get_strided_batch device (#1303) --- lhotse/features/kaldi/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/features/kaldi/layers.py b/lhotse/features/kaldi/layers.py index 4d974b804..e2b271556 100644 --- a/lhotse/features/kaldi/layers.py +++ b/lhotse/features/kaldi/layers.py @@ -776,7 +776,7 @@ def _get_strided_batch( if npad_right >= 0: pad_right = torch.flip(waveform[:, -npad_right:], (1,)) else: - pad_right = torch.zeros(0, dtype=waveform.dtype) + pad_right = torch.zeros(0, dtype=waveform.dtype, device=waveform.device) waveform = torch.cat((pad_left, waveform, pad_right), dim=1) strides = ( From bdb692f1e63cfe8e88aef763271f4402d109d862 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:07:52 +0800 Subject: [PATCH 10/69] Fix typo in README.md (#1308) --- lhotse/workflows/activity_detection/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/workflows/activity_detection/README.md b/lhotse/workflows/activity_detection/README.md index 42113d82e..f3b938c46 100644 --- a/lhotse/workflows/activity_detection/README.md +++ b/lhotse/workflows/activity_detection/README.md @@ -51,7 +51,7 @@ The Activity Detection module provides tools for detecting activity in audio rec .../librispeech_recordings_train-clean-5.jsonl.gz ``` -## Trubleshooting +## Troubleshooting If you encounter the following errors while running the activity detection. From dd605c17959e3a6556fdf883e34dad89ccab3b5d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 27 Mar 2024 21:43:35 +0800 Subject: [PATCH 11/69] Updated text_norm for `aishell` recipe (#1305) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update aishell.py Co-authored-by: Piotr Żelasko --- lhotse/recipes/aishell.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lhotse/recipes/aishell.py b/lhotse/recipes/aishell.py index d73a85474..62a7abd66 100644 --- a/lhotse/recipes/aishell.py +++ b/lhotse/recipes/aishell.py @@ -135,7 +135,9 @@ def prepare_aishell( channel=0, language="Chinese", speaker=speaker, - text=text.strip(), + text=text.strip().replace(" ", ""), + # here we remove the space between words in the text + # in advance. ) supervisions.append(segment) From 1c2a1b53fb0bcaafc6cbc884ea542f752e16d839 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 27 Mar 2024 21:48:27 +0800 Subject: [PATCH 12/69] Add dataset for audio tagging (#1241) * add audio tagging dataset * minor fix * add test for audio tagging --- lhotse/dataset/__init__.py | 1 + lhotse/dataset/audio_tagging.py | 137 +++++++++++++++++++++++++++++ test/dataset/test_audio_tagging.py | 24 +++++ 3 files changed, 162 insertions(+) create mode 100644 lhotse/dataset/audio_tagging.py create mode 100644 test/dataset/test_audio_tagging.py diff --git a/lhotse/dataset/__init__.py b/lhotse/dataset/__init__.py index 0c2e856af..a6688e6ee 100644 --- a/lhotse/dataset/__init__.py +++ b/lhotse/dataset/__init__.py @@ -1,4 +1,5 @@ from . import cut_transforms, input_strategies, sampling, signal_transforms +from .audio_tagging import AudioTaggingDataset from .cut_transforms import * from .dataloading import make_worker_init_fn from .diarization import DiarizationDataset diff --git a/lhotse/dataset/audio_tagging.py b/lhotse/dataset/audio_tagging.py new file mode 100644 index 000000000..0ca44a687 --- /dev/null +++ b/lhotse/dataset/audio_tagging.py @@ -0,0 +1,137 @@ +from typing import Callable, Dict, List, Union + +import torch +from torch.utils.data.dataloader import DataLoader, default_collate + +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from lhotse.workarounds import Hdf5MemoryIssueFix + + +class AudioTaggingDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the audio tagging task. + + .. code-block:: + + { + 'inputs': float tensor with shape determined by :attr:`input_strategy`: + - single-channel: + - features: (B, T, F) + - audio: (B, T) + - multi-channel: currently not supported + 'supervisions': [ + { + # For audio event, which can be mapped to a multi-hot tensor + 'audio_event': string separated by semicolon + + # For feature input strategies + 'start_frame': Tensor[int] of shape (S,) + 'num_frames': Tensor[int] of shape (S,) + + # For audio input strategies + 'start_sample': Tensor[int] of shape (S,) + 'num_samples': Tensor[int] of shape (S,) + + # Optionally, when return_cuts=True + 'cut': List[AnyCut] of len S + } + ] + } + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + """ + Audio tagging IterableDataset constructor. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # This attribute is a workaround to constantly growing HDF5 memory + # throughout the epoch. It regularly closes open file handles to + # reset the internal HDF5 caches. + self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_frames and max_cuts. + """ + self.hdf5_fix.update() + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "supervisions": default_collate( + [ + { + "audio_event": supervision.audio_event, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + return batch diff --git a/test/dataset/test_audio_tagging.py b/test/dataset/test_audio_tagging.py new file mode 100644 index 000000000..24ca416cb --- /dev/null +++ b/test/dataset/test_audio_tagging.py @@ -0,0 +1,24 @@ +import pytest + +from lhotse.cut import CutSet +from lhotse.dataset import AudioTaggingDataset + + +@pytest.fixture +def dummy_cut_set(): + cuts = CutSet.from_json("test/fixtures/libri/cuts.json") + + def _add_audio_event(c): + c.supervisions[0].audio_event = "Speech; Whisper" + return c + + cuts = cuts.map(_add_audio_event) + return cuts + + +def test_audio_tagging_dataset(dummy_cut_set): + dataset = AudioTaggingDataset() + out = dataset[dummy_cut_set] + supervisions = out["supervisions"] + assert "audio_event" in supervisions + print("Pass the test") From 393a72ae8c6d2bc1c74c207b83b40e93634bf11c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 3 Apr 2024 10:37:29 -0400 Subject: [PATCH 13/69] Enhance `CutSet.mix()` randomness and data utilization (#1315) * Enhance randomness of CutSet.mix * Use as much noise signal as possible from the first drawn noise in LazyCutMixer --- lhotse/cut/set.py | 60 +++++++++++++++++++++++++++++------- test/cut/test_cut_set_mix.py | 17 ++++++++++ 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 652b8942d..5afcb2b48 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -3464,6 +3464,9 @@ class LazyCutMixer(Dillable): :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. + :param stateful: when True, each time this object is iterated we will shuffle the noise cuts + using a different random seed. This is useful when you often re-start the iteration and + don't want to keep seeing the same noise examples. Enabled by default. """ def __init__( @@ -3477,6 +3480,7 @@ def __init__( mix_prob: float = 1.0, seed: Union[int, Literal["trng", "randomized"]] = 42, random_mix_offset: bool = False, + stateful: bool = True, ) -> None: self.source = cuts self.mix_in_cuts = mix_in_cuts @@ -3487,6 +3491,8 @@ def __init__( self.mix_prob = mix_prob self.seed = seed self.random_mix_offset = random_mix_offset + self.stateful = stateful + self.num_times_iterated = 0 assert 0.0 <= self.mix_prob <= 1.0 assert self.duration is None or self.duration > 0 @@ -3500,28 +3506,49 @@ def __init__( def __iter__(self): from lhotse.dataset.dataloading import resolve_seed - rng = random.Random(resolve_seed(self.seed)) - mix_in_cuts = iter(self.mix_in_cuts.repeat().shuffle(rng=rng, buffer_size=100)) + rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated) + if self.stateful: + self.num_times_iterated += 1 + if self.mix_in_cuts.is_lazy: + # If the noise input is lazy, we'll shuffle it approximately. + # We set the shuffling buffer size to 2000 because that's the size of MUSAN, + # so even if the user forgets to convert MUSAN to an eager manifest, they will + # get roughly the same quality of noise randomness. + # Note: we can't just call .to_eager() as the noise CutSet can technically be + # very large, or even hold data in-memory in case of webdataset/Lhotse Shar sources. + def noise_gen(): + yield from self.mix_in_cuts.repeat().shuffle(rng=rng, buffer_size=2000) + + else: + # Eager nose cuts are just fully reshuffled in a different order on each noise "epoch". + def noise_gen(): + # + while True: + yield from self.mix_in_cuts.shuffle(rng=rng) + + mix_in_cuts = iter(noise_gen()) for cut in self.source: # Check whether we're going to mix something into the current cut # or pass it through unchanged. if not is_cut(cut) or rng.uniform(0.0, 1.0) > self.mix_prob: yield cut continue - to_mix = next(mix_in_cuts) # Determine the SNR - either it's specified or we need to sample one. cut_snr = ( rng.uniform(*self.snr) if isinstance(self.snr, (list, tuple)) else self.snr ) - if self.random_mix_offset and to_mix.duration > cut.duration: - to_mix = to_mix.truncate( - offset=rng.uniform(0, to_mix.duration - cut.duration), - duration=cut.duration, - ) + # Note: we subtract 0.05s (50ms) from the target duration to avoid edge cases + # where we mix in some noise cut that effectively has 0 frames of features. + target_mixed_duration = round( + self.duration if self.duration is not None else cut.duration - 0.05, + ndigits=8, + ) # Actual mixing + to_mix = next(mix_in_cuts) + to_mix = self._maybe_truncate_cut(to_mix, target_mixed_duration, rng) mixed = cut.mix(other=to_mix, snr=cut_snr, preserve_id=self.preserve_id) # Did the user specify a duration? # If yes, we will ensure that shorter cuts have more noise mixed in @@ -3532,10 +3559,9 @@ def __iter__(self): # Keep sampling until we mixed in a "duration" amount of noise. # Note: we subtract 0.05s (50ms) from the target duration to avoid edge cases # where we mix in some noise cut that effectively has 0 frames of features. - while mixed_in_duration < ( - self.duration if self.duration is not None else cut.duration - 0.05 - ): + while mixed_in_duration < target_mixed_duration: to_mix = next(mix_in_cuts) + to_mix = self._maybe_truncate_cut(to_mix, target_mixed_duration, rng) # Keep the SNR constant for each cut from "self". mixed = mixed.mix( other=to_mix, @@ -3550,12 +3576,24 @@ def __iter__(self): mixed_in_duration + to_mix.duration, ndigits=8 ) # We truncate the mixed to either the original duration or the requested duration. + # Note: we don't use 'target_mixed_duration' here because it may have subtracted + # a tiny bit of actual target duration to avoid errors related to edge effects. mixed = mixed.truncate( duration=self.duration if self.duration is not None else cut.duration, preserve_id=self.preserve_id is not None, ) yield mixed + def _maybe_truncate_cut( + self, cut: Cut, target_duration: Seconds, rng: random.Random + ) -> Cut: + if self.random_mix_offset and cut.duration > target_duration: + cut = cut.truncate( + offset=rng.uniform(0, cut.duration - target_duration), + duration=target_duration, + ) + return cut + def __len__(self) -> int: return len(self.source) diff --git a/test/cut/test_cut_set_mix.py b/test/cut/test_cut_set_mix.py index 03019052c..7cc1334a9 100644 --- a/test/cut/test_cut_set_mix.py +++ b/test/cut/test_cut_set_mix.py @@ -2,6 +2,7 @@ import pytest from lhotse.cut import CutSet, MixedCut +from lhotse.testing.dummies import DummyManifest, as_lazy from lhotse.testing.fixtures import random_cut_set @@ -59,3 +60,19 @@ def test_cut_set_mixing_with_random_mix_offset(): offset_mix = speech_cuts.mix(noise_cuts, random_mix_offset=True) for a, b in zip(normal_mix, offset_mix): assert not np.array_equal(a.load_audio(), b.load_audio()) + + +def test_cut_set_mixing_less_noise_cuts_than_speech_cuts_eager_noise_cutset(): + speech_cuts = DummyManifest(CutSet, begin_id=0, end_id=2) + noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=101) + mixed_cuts = speech_cuts.mix(noise_cuts) + for c in mixed_cuts: + assert isinstance(c, MixedCut) + + +def test_cut_set_mixing_less_noise_cuts_than_speech_cuts_lazy_noise_cutset(): + speech_cuts = DummyManifest(CutSet, begin_id=0, end_id=10) + noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=101).repeat(2) + mixed_cuts = speech_cuts.mix(noise_cuts) + for c in mixed_cuts: + assert isinstance(c, MixedCut) From 1b68036d20e5a674c45c01d7719ddacc2d6742ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 3 Apr 2024 14:05:38 -0400 Subject: [PATCH 14/69] Fix randomness in CutMix transform (#1316) * Fix randomness in CutMix transform * nitpicks --- lhotse/cut/set.py | 13 ++++++++----- lhotse/dataset/cut_transforms/mix.py | 22 +++++++++++++++++++--- test/dataset/test_cut_transforms.py | 11 +++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 5afcb2b48..4f0872f3d 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -1670,7 +1670,7 @@ def mix( snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20, preserve_id: Optional[str] = None, mix_prob: float = 1.0, - seed: Union[int, Literal["trng", "randomized"]] = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, ) -> "CutSet": """ @@ -1699,7 +1699,7 @@ def mix( Values lower than 1.0 mean that some cuts in the output will be unchanged. :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results - on each iteration. + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -3460,7 +3460,7 @@ class LazyCutMixer(Dillable): Values lower than 1.0 mean that some cuts in the output will be unchanged. :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results - on each iteration. + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -3478,7 +3478,7 @@ def __init__( snr: Optional[Union[Decibels, Sequence[Decibels]]] = 20, preserve_id: Optional[str] = None, mix_prob: float = 1.0, - seed: Union[int, Literal["trng", "randomized"]] = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, stateful: bool = True, ) -> None: @@ -3506,7 +3506,10 @@ def __init__( def __iter__(self): from lhotse.dataset.dataloading import resolve_seed - rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated) + if isinstance(self.seed, random.Random): + rng = self.seed + else: + rng = random.Random(resolve_seed(self.seed) + self.num_times_iterated) if self.stateful: self.num_times_iterated += 1 diff --git a/lhotse/dataset/cut_transforms/mix.py b/lhotse/dataset/cut_transforms/mix.py index 44c4c6264..2e244281f 100644 --- a/lhotse/dataset/cut_transforms/mix.py +++ b/lhotse/dataset/cut_transforms/mix.py @@ -1,7 +1,9 @@ +import random import warnings -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union from lhotse import CutSet +from lhotse.dataset.dataloading import resolve_seed from lhotse.utils import Decibels @@ -18,7 +20,7 @@ def __init__( p: float = 0.5, pad_to_longest: bool = True, preserve_id: bool = False, - seed: int = 42, + seed: Union[int, Literal["trng", "randomized"], random.Random] = 42, random_mix_offset: bool = False, ) -> None: """ @@ -34,6 +36,9 @@ def __init__( to match the duration of the longest Cut in a batch. :param preserve_id: When ``True``, preserves the IDs the cuts had before augmentation. Otherwise, new random IDs are generated for the augmented cuts (default). + :param seed: an optional int or "trng". Random seed for choosing the cuts to mix and the SNR. + If "trng" is provided, we'll use the ``secrets`` module for non-deterministic results + on each iteration. You can also directly pass a ``random.Random`` instance here. :param random_mix_offset: an optional bool. When ``True`` and the duration of the to be mixed in cut in longer than the original cut, select a random sub-region from the to be mixed in cut. @@ -48,6 +53,7 @@ def __init__( self.pad_to_longest = pad_to_longest self.preserve_id = preserve_id self.seed = seed + self.rng = None self.random_mix_offset = random_mix_offset def __call__(self, cuts: CutSet) -> CutSet: @@ -56,6 +62,8 @@ def __call__(self, cuts: CutSet) -> CutSet: if len(self.cuts) == 0: return cuts + self._lazy_rng_init() + maybe_max_duration = ( max(c.duration for c in cuts) if self.pad_to_longest else None ) @@ -65,6 +73,14 @@ def __call__(self, cuts: CutSet) -> CutSet: snr=self.snr, mix_prob=self.p, preserve_id="left" if self.preserve_id else None, - seed=self.seed, + seed=self.rng, random_mix_offset=self.random_mix_offset, ).to_eager() + + def _lazy_rng_init(self): + if self.rng is not None: + return + if isinstance(self.seed, random.Random): + self.rng = self.seed + else: + self.rng = random.Random(resolve_seed(self.seed)) diff --git a/test/dataset/test_cut_transforms.py b/test/dataset/test_cut_transforms.py index 9e7a90955..1c24763a4 100644 --- a/test/dataset/test_cut_transforms.py +++ b/test/dataset/test_cut_transforms.py @@ -118,6 +118,17 @@ def test_cutmix(preserve_id: bool): ) +def test_cut_mix_is_stateful(): + speech_cuts = DummyManifest(CutSet, begin_id=0, end_id=10) + noise_cuts = DummyManifest(CutSet, begin_id=100, end_id=102) + + # called twice on the same input, expecting different results + tnfm = CutMix(noise_cuts, snr=None, p=1.0, seed=0, preserve_id=True) + out1 = tnfm(speech_cuts) + out2 = tnfm(speech_cuts) + assert list(out1) != list(out2) + + def test_cutmix_random_mix_offset(): speech_cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json").resample(16000) noise_cuts = CutSet.from_json("test/fixtures/libri/cuts.json") From 82afe2a37cff290da9c8e4b73b45269d4f1335d4 Mon Sep 17 00:00:00 2001 From: Omid Sadjadi Date: Wed, 17 Apr 2024 09:24:53 -0400 Subject: [PATCH 15/69] select a random sub-region of the noise based on the delta duration (#1317) * select a random sub-region of the noise based on the delta duration we select a random sub-region of the noise based on the delta duration needed to reach the target duration * fix failing unit tests in cut mix * apply tolerance in while loop condition unconditionally --- lhotse/cut/set.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 4f0872f3d..58900306b 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -3562,9 +3562,11 @@ def noise_gen(): # Keep sampling until we mixed in a "duration" amount of noise. # Note: we subtract 0.05s (50ms) from the target duration to avoid edge cases # where we mix in some noise cut that effectively has 0 frames of features. - while mixed_in_duration < target_mixed_duration: + while mixed_in_duration < target_mixed_duration - 0.05: to_mix = next(mix_in_cuts) - to_mix = self._maybe_truncate_cut(to_mix, target_mixed_duration, rng) + to_mix = self._maybe_truncate_cut( + to_mix, target_mixed_duration - mixed_in_duration, rng + ) # Keep the SNR constant for each cut from "self". mixed = mixed.mix( other=to_mix, From bfce956aa4523ab80419f3e6394333e8951c505a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Apr 2024 10:53:40 -0400 Subject: [PATCH 16/69] Fix export of features/array to shar (#1323) * Fix export of features/array to shar * Add unit test and fix the changes --- lhotse/shar/utils.py | 33 +++++++++++++++++----------- test/shar/test_write.py | 48 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/lhotse/shar/utils.py b/lhotse/shar/utils.py index c4554cfd8..253ece9ed 100644 --- a/lhotse/shar/utils.py +++ b/lhotse/shar/utils.py @@ -11,14 +11,6 @@ def to_shar_placeholder(manifest: Manifest, cut: Optional[Cut] = None) -> Manifest: if isinstance(manifest, Recording): - kwargs = ( - {} - if cut is None - else dict( - duration=cut.duration, - num_samples=compute_num_samples(cut.duration, manifest.sampling_rate), - ) - ) return fastcopy( manifest, # Creates a single AudioSource out of multiple ones. @@ -27,18 +19,35 @@ def to_shar_placeholder(manifest: Manifest, cut: Optional[Cut] = None) -> Manife ], # Removes the transform metadata because they were already executed. transforms=None, - **kwargs, + duration=cut.duration if cut is not None else manifest.duration, + num_samples=compute_num_samples(cut.duration, manifest.sampling_rate) + if cut is not None + else manifest.num_samples, ) - # TODO: modify Features/TemporalArray's start/duration/num_frames if needed to match the Cut (in case we read subset of array) - elif isinstance(manifest, (Array, Features)): + elif isinstance(manifest, Array): return fastcopy(manifest, storage_type="shar", storage_path="", storage_key="") + elif isinstance(manifest, Features): + return fastcopy( + manifest, + start=0, + duration=cut.duration if cut is not None else manifest.duration, + storage_type="shar", + storage_path="", + storage_key="", + ) elif isinstance(manifest, TemporalArray): return fastcopy( manifest, + start=0, array=fastcopy( - manifest.array, storage_type="shar", storage_path="", storage_key="" + manifest.array, + storage_type="shar", + storage_path="", + storage_key="", ), ) + else: + raise RuntimeError(f"Unexpected manifest type: {type(manifest)}") def fill_shar_placeholder( diff --git a/test/shar/test_write.py b/test/shar/test_write.py index a70b665c7..cee35de1c 100644 --- a/test/shar/test_write.py +++ b/test/shar/test_write.py @@ -9,7 +9,7 @@ from lhotse.audio.backend import audio_backend, check_torchaudio_version_gt from lhotse.lazy import LazyJsonlIterator from lhotse.shar import AudioTarWriter, SharWriter, TarIterator, TarWriter -from lhotse.testing.dummies import DummyManifest +from lhotse.testing.dummies import DummyManifest, dummy_cut def test_tar_writer(tmp_path: Path): @@ -687,3 +687,49 @@ def test_shar_writer_pipe(tmp_path: Path): assert cut.custom_recording.sources[0].type == "shar" with pytest.raises(RuntimeError): cut.load_custom_recording() + + +def test_shar_writer_truncates_temporal_array_and_features(tmp_path: Path): + # Basic data and sanity check of shapes. + cut = dummy_cut(0, with_data=True) + for k in "custom_embedding custom_features custom_recording".split(): + cut = cut.drop_custom(k) + ref_audio = cut.load_audio() + ref_feats = cut.load_features() + ref_indxs = cut.load_custom_indexes() + assert ref_audio.shape == (1, 16000) + assert ref_feats.shape == (100, 23) + assert ref_indxs.shape == (100,) + + # Truncated cut before writing to Shar and sanity check of shapes and content. + cut = cut.truncate(offset=0.2, duration=0.6) + trunc_audio = cut.load_audio() + trunc_feats = cut.load_features() + trunc_indxs = cut.load_custom_indexes() + assert trunc_audio.shape == (1, 9600) + np.testing.assert_array_equal(trunc_audio, ref_audio[:, 3200:-3200]) + assert trunc_feats.shape == (60, 23) + np.testing.assert_array_equal(trunc_feats, ref_feats[20:-20, :]) + assert trunc_indxs.shape == (60,) + np.testing.assert_array_equal(trunc_indxs, ref_indxs[20:-20]) + + # System under test. + with SharWriter( + tmp_path, + fields={"recording": "wav", "features": "numpy", "custom_indexes": "numpy"}, + shard_size=None, + ) as writer: + writer.write(cut) + + # Truncated cut restored from Shar and sanity check of shapes and content. + sharcuts = CutSet.from_shar(in_dir=writer.output_dir) + cut = sharcuts[0] + trunc_audio = cut.load_audio() + trunc_feats = cut.load_features() + trunc_indxs = cut.load_custom_indexes() + assert trunc_audio.shape == (1, 9600) + np.testing.assert_array_equal(trunc_audio, ref_audio[:, 3200:-3200]) + assert trunc_feats.shape == (60, 23) + np.testing.assert_array_equal(trunc_feats, ref_feats[20:-20, :]) + assert trunc_indxs.shape == (60,) + np.testing.assert_array_equal(trunc_indxs, ref_indxs[20:-20]) From 9bf1b8f757441a9dfaa1d73b5bfc574667c2411d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Apr 2024 10:53:56 -0400 Subject: [PATCH 17/69] Fix `trim_to_supervision_groups` (#1322) Fix trim_to_supervision_groups --- lhotse/cut/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/cut/base.py b/lhotse/cut/base.py index 1f7c046a0..7ab95d991 100644 --- a/lhotse/cut/base.py +++ b/lhotse/cut/base.py @@ -677,7 +677,7 @@ def trim_to_supervision_groups( from .set import CutSet if not self.supervisions: - return self + return CutSet([self]) supervisions = sorted(self.supervisions, key=lambda s: s.start) supervision_group = [supervisions[0]] cur_end = supervisions[0].end From ad668890f9bc7ccf31ace59a7b47f40c8ab73749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 23 Apr 2024 10:54:20 -0400 Subject: [PATCH 18/69] Allow skipping missing files in AMI download (#1318) --- lhotse/recipes/ami.py | 24 ++++++++++-------------- lhotse/utils.py | 9 ++++++++- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/lhotse/recipes/ami.py b/lhotse/recipes/ami.py index 8f3f080f8..f18b7a247 100644 --- a/lhotse/recipes/ami.py +++ b/lhotse/recipes/ami.py @@ -198,7 +198,10 @@ def download_audio( wav_dir.mkdir(parents=True, exist_ok=True) wav_path = wav_dir / wav_name resumable_download( - wav_url, filename=wav_path, force_download=force_download + wav_url, + filename=wav_path, + force_download=force_download, + missing_ok=True, ) elif mic == "mdm": for array in MDM_ARRAYS: @@ -208,19 +211,12 @@ def download_audio( wav_dir = target_dir / "wav_db" / item / "audio" wav_dir.mkdir(parents=True, exist_ok=True) wav_path = wav_dir / wav_name - try: - resumable_download( - wav_url, filename=wav_path, force_download=force_download - ) - except urllib.error.HTTPError as err: - if err.code == 404: - logging.warning( - f"{wav_url} does not exist. Skipping this file." - ) - if os.path.exists(wav_path) and os.path.isfile(wav_path): - os.remove(wav_path) - else: - raise err + resumable_download( + wav_url, + filename=wav_path, + force_download=force_download, + missing_ok=True, + ) elif mic == "mdm8-bf": wav_name = f"{item}_MDM8.wav" wav_url = f"{url}/AMICorpusMirror/amicorpus/beamformed/{item}/{wav_name}" diff --git a/lhotse/utils.py b/lhotse/utils.py index f303100d3..23ed1fd45 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -454,6 +454,7 @@ def resumable_download( filename: Pathlike, force_download: bool = False, completed_file_size: Optional[int] = None, + missing_ok: bool = False, ) -> None: # Check if the file exists and get its size file_exists = os.path.exists(filename) @@ -518,7 +519,13 @@ def _download(rq, size): except urllib.error.HTTPError as e: # "Request Range Not Satisfiable" means the requested range # starts after the file ends OR that the server does not support range requests. - if e.code == 416: + if e.code == 404 and missing_ok: + logging.warning( + f"{url} does not exist (error 404). Skipping this file." + ) + if Path(filename).is_file(): + os.remove(filename) + elif e.code == 416: content_range = e.headers.get("Content-Range", None) if content_range is None: # sometimes, the server actually supports range requests From ed5797c1d40cfa305bbdd7950917a6ffd170d5de Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 23 Apr 2024 22:55:39 +0800 Subject: [PATCH 19/69] Add Chinese TTS dataset `baker`. (#1304) --- lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/baker_zh.py | 25 ++++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/baker_zh.py | 113 +++++++++++++++++++++++++++ 4 files changed, 140 insertions(+) create mode 100644 lhotse/bin/modes/recipes/baker_zh.py create mode 100644 lhotse/recipes/baker_zh.py diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index 3dbc31c14..b5fe10981 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -10,6 +10,7 @@ from .atcosim import * from .audio_mnist import * from .babel import * +from .baker_zh import * from .bengaliai_speech import * from .broadcast_news import * from .but_reverb_db import * diff --git a/lhotse/bin/modes/recipes/baker_zh.py b/lhotse/bin/modes/recipes/baker_zh.py new file mode 100644 index 000000000..bcdd26718 --- /dev/null +++ b/lhotse/bin/modes/recipes/baker_zh.py @@ -0,0 +1,25 @@ +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.baker_zh import download_baker_zh, prepare_baker_zh +from lhotse.utils import Pathlike + +__all__ = ["baker_zh"] + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path(), default=".") +def baker_zh(target_dir: Pathlike): + """bazker_zh download.""" + download_baker_zh(target_dir) + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +def baker_zh( + corpus_dir: Pathlike, + output_dir: Pathlike, +): + """bazker_zh data preparation.""" + prepare_baker_zh(corpus_dir, output_dir=output_dir) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 583585f46..99bde7d97 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -7,6 +7,7 @@ from .aspire import prepare_aspire from .atcosim import download_atcosim, prepare_atcosim from .babel import prepare_single_babel_language +from .baker_zh import download_baker_zh, prepare_baker_zh from .bengaliai_speech import prepare_bengaliai_speech from .broadcast_news import prepare_broadcast_news from .but_reverb_db import download_but_reverb_db, prepare_but_reverb_db diff --git a/lhotse/recipes/baker_zh.py b/lhotse/recipes/baker_zh.py new file mode 100644 index 000000000..04172a689 --- /dev/null +++ b/lhotse/recipes/baker_zh.py @@ -0,0 +1,113 @@ +""" +See https://en.data-baker.com/datasets/freeDatasets/ + +It is a Chinese TTS dataset, containing 12 hours of data. +""" + +import logging +import re +import shutil +import tarfile +from pathlib import Path +from typing import Dict, Optional, Union + +from lhotse import fix_manifests, validate_recordings_and_supervisions +from lhotse.audio import Recording, RecordingSet +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, resumable_download, safe_extract + + +def download_baker_zh( + target_dir: Pathlike = ".", force_download: Optional[bool] = False +) -> Path: + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + dataset_name = "BZNSYP" + tar_path = target_dir / f"{dataset_name}.tar.bz2" + corpus_dir = target_dir / dataset_name + completed_detector = corpus_dir / ".completed" + if completed_detector.is_file(): + logging.info(f"Skipping {dataset_name} because {completed_detector} exists.") + return corpus_dir + resumable_download( + f"https://huggingface.co/openspeech/BZNSYP/resolve/main/{dataset_name}.tar.bz2", + filename=tar_path, + force_download=force_download, + ) + shutil.rmtree(corpus_dir, ignore_errors=True) + with tarfile.open(tar_path) as tar: + safe_extract(tar, path=target_dir) + completed_detector.touch() + + return corpus_dir + + +def prepare_baker_zh( + corpus_dir: Pathlike, output_dir: Optional[Pathlike] = None +) -> Dict[str, Union[RecordingSet, SupervisionSet]]: + """ + Returns the manifests which consist of the Recordings and Supervisions + + :param corpus_dir: Pathlike, the path of the data dir. + :param output_dir: Pathlike, the path where to write the manifests. + :return: The RecordingSet and SupervisionSet with the keys 'audio' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # The corpus_dir contains three sub directories + # PhoneLabeling ProsodyLabeling Wave + + # Generate a mapping: utt_id -> (audio_path, audio_info, text) + labeling_file = corpus_dir / "ProsodyLabeling" / "000001-010000.txt" + if not labeling_file.is_file(): + raise ValueError(f"{labeling_file} does not exist") + + recordings = [] + supervisions = [] + logging.info("Started preparing. It may take 30 seconds") + pattern = re.compile("#[12345]") + with open(labeling_file) as f: + try: + while True: + first = next(f).strip() + pinyin = next(f).strip() + recording_id, original_text = first.split(None, maxsplit=1) + normalized_text = re.sub(pattern, "", original_text) + audio_path = corpus_dir / "Wave" / f"{recording_id}.wav" + + if not audio_path.is_file(): + logging.warning(f"No such file: {audio_path}") + continue + recording = Recording.from_file(audio_path) + + segment = SupervisionSegment( + id=recording_id, + recording_id=recording_id, + start=0.0, + duration=recording.duration, + channel=0, + language="Chinese", + gender="female", + text=original_text, + custom={"pinyin": pinyin, "normalized_text": normalized_text}, + ) + recordings.append(recording) + supervisions.append(segment) + except StopIteration: + pass + + recording_set = RecordingSet.from_recordings(recordings) + supervision_set = SupervisionSet.from_segments(supervisions) + + recording_set, supervision_set = fix_manifests(recording_set, supervision_set) + validate_recordings_and_supervisions(recording_set, supervision_set) + + if output_dir is not None: + supervision_set.to_file(output_dir / "baker_zh_supervisions_all.jsonl.gz") + recording_set.to_file(output_dir / "baker_zh_recordings_all.jsonl.gz") + + return {"recordings": recording_set, "supervisions": supervision_set} From b2dce7813c3e5236f1b8f48908196cf66e0ecf72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Do=C3=B1a?= <23705091+daniel-dona@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:33:40 +0200 Subject: [PATCH 20/69] In CommonVoice corpus, use .tsv headers to parse and not column index (#1328) * Fix for cv corpus * Fix for cv corpus x2 * Debug serialization problem * Debug serialization problem * Undo * Handle quote polution in CV dataset --- lhotse/recipes/commonvoice.py | 40 +++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/lhotse/recipes/commonvoice.py b/lhotse/recipes/commonvoice.py index fd84329da..46a08cd91 100644 --- a/lhotse/recipes/commonvoice.py +++ b/lhotse/recipes/commonvoice.py @@ -9,6 +9,7 @@ How does it work? We are crowdsourcing an open-source dataset of voices. Donate your voice, validate the accuracy of other people's clips, make the dataset better for everyone. """ +import csv import logging import math import numbers @@ -149,14 +150,13 @@ def _parse_utterance( language: str, audio_info: str, ) -> Optional[Tuple[Recording, SupervisionSegment]]: - audio_info = audio_info.split("\t", -1) - audio_path = lang_path / "clips" / audio_info[1] + audio_path = lang_path / "clips" / audio_info["path"] if not audio_path.is_file(): logging.info(f"No such file: {audio_path}") return None - recording_id = Path(audio_info[1]).stem + recording_id = Path(audio_info["path"]).stem recording = Recording.from_file(path=audio_path, recording_id=recording_id) segment = SupervisionSegment( @@ -166,12 +166,13 @@ def _parse_utterance( duration=recording.duration, channel=0, language=language, - speaker=audio_info[0], - text=audio_info[2].strip(), - gender=audio_info[6], + speaker=audio_info["client_id"], + text=audio_info["sentence"].strip(), + gender=audio_info["gender"], custom={ - "age": audio_info[5], - "accents": audio_info[7], + "age": audio_info["age"], + "accents": audio_info["accents"], + "variant": audio_info["variant"], }, ) return recording, segment @@ -207,19 +208,22 @@ def _prepare_part( futures = [] recordings = [] supervisions = [] + audio_infos = [] - with open(tsv_path) as f: - audio_infos = iter(f.readlines()) + with open(tsv_path, "r") as f: - for audio_info in tqdm(audio_infos, desc="Distributing tasks"): - futures.append( - ex.submit( - _parse_utterance, - lang_path, - lang, - audio_info, + # Note: using QUOTE_NONE as CV dataset contains unbalanced quotes, cleanup needed later + audio_infos = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE) + + for audio_info in tqdm(audio_infos, desc="Distributing tasks"): + futures.append( + ex.submit( + _parse_utterance, + lang_path, + lang, + audio_info, + ) ) - ) for future in tqdm(futures, desc="Processing"): result = future.result() From 4f014b13202c724d484e0471343053a261487b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 30 Apr 2024 14:46:36 -0400 Subject: [PATCH 21/69] Bump dev version to 1.24.0 (#1329) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index a6c2798a4..53cc1a6f9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.23.0 +1.24.0 From ddde5bd15a769962ba7ec872fd9212bc89b08b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Do=C3=B1a?= <23705091+daniel-dona@users.noreply.github.com> Date: Fri, 10 May 2024 01:49:21 +0200 Subject: [PATCH 22/69] Missing 'subset' parameter (#1336) --- lhotse/bin/modes/recipes/voxpopuli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/bin/modes/recipes/voxpopuli.py b/lhotse/bin/modes/recipes/voxpopuli.py index af9e3fb8e..0653da14d 100644 --- a/lhotse/bin/modes/recipes/voxpopuli.py +++ b/lhotse/bin/modes/recipes/voxpopuli.py @@ -81,4 +81,4 @@ def voxpopuli( ) def voxpopuli(target_dir: Pathlike, subset: str): """voxpopuli download.""" - download_voxpopuli(target_dir) + download_voxpopuli(target_dir, subset) From d02a168ebb763f053f91c5713ecff39b1a815e4c Mon Sep 17 00:00:00 2001 From: Kee Koo Date: Wed, 15 May 2024 22:13:43 +0800 Subject: [PATCH 23/69] Fix describe on cuts (#1340) Update describe.py Fix #1339 --- lhotse/cut/describe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/cut/describe.py b/lhotse/cut/describe.py index 97c0eb079..6be3a1855 100644 --- a/lhotse/cut/describe.py +++ b/lhotse/cut/describe.py @@ -175,7 +175,7 @@ def time_as_str_(seconds: Seconds) -> str: if self.sup_custom: print("SUPERVISION custom fields:") for key, val in self.sup_custom.most_common(): - cut_stats.append(f"- {key} (in {val} cuts)") + print(f"- {key} (in {val} cuts)") total_speech = np.array(self.speech_durations).sum() total_speaking_time = np.array(self.speaking_time_durations).sum() From 4db7f5f6491cf2b8f797e55b21b5d120ed529cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 16 May 2024 12:05:13 -0400 Subject: [PATCH 24/69] `reverb_rir`: support Cut input and in memory data (#1332) * Support Cut input in reverb_rir as well as in memory data * unit test fixes --- lhotse/audio/recording.py | 44 +++++++++++++++---------- lhotse/augmentation/rir.py | 39 +++++++++++++++++----- lhotse/cut/base.py | 1 - lhotse/cut/data.py | 11 +++++++ lhotse/cut/mono.py | 52 +++++++++++++++++++++++++++-- lhotse/cut/multi.py | 13 ++++---- test/cut/test_cut_augmentation.py | 54 +++++++++++++++++++++++++++++-- 7 files changed, 177 insertions(+), 37 deletions(-) diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index 0c3ea2ebc..971166995 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -2,7 +2,7 @@ from io import BytesIO from math import ceil, isclose from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -126,7 +126,7 @@ class Recording: num_samples: int duration: Seconds channel_ids: Optional[List[int]] = None - transforms: Optional[List[Dict]] = None + transforms: Optional[List[Union[AudioTransform, Dict]]] = None def __post_init__(self): if self.channel_ids is None: @@ -334,7 +334,10 @@ def _aslist(x): ) def to_dict(self) -> dict: - return asdict_nonull(self) + d = asdict_nonull(self) + if self.transforms is not None: + d["transforms"] = [t.to_dict() for t in self.transforms] + return d def to_cut(self): """ @@ -395,7 +398,8 @@ def load_audio( ) transforms = [ - AudioTransform.from_dict(params) for params in self.transforms or [] + tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm) + for tnfm in self.transforms or [] ] # Do a "backward pass" over data augmentation transforms to get the @@ -488,10 +492,15 @@ def load_video( ) for t in ifnone(self.transforms, ()): - assert t["name"] not in ( - "Speed", - "Tempo", - ), "Recording.load_video() does not support speed/tempo perturbation." + if isinstance(t, dict): + assert t["name"] not in ( + "Speed", + "Tempo", + ), "Recording.load_video() does not support speed/tempo perturbation." + else: + assert not isinstance( + t, (Speed, Tempo) + ), "Recording.load_video() does not support speed/tempo perturbation." if not with_audio: video, _ = self._video_source.load_video( @@ -519,7 +528,8 @@ def load_video( ) transforms = [ - AudioTransform.from_dict(params) for params in self.transforms or [] + tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm) + for tnfm in self.transforms or [] ] # Do a "backward pass" over data augmentation transforms to get the @@ -659,7 +669,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Speed(factor=factor).to_dict()) + transforms.append(Speed(factor=factor)) new_num_samples = perturb_num_samples(self.num_samples, factor) new_duration = new_num_samples / self.sampling_rate return fastcopy( @@ -684,7 +694,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Tempo(factor=factor).to_dict()) + transforms.append(Tempo(factor=factor)) new_num_samples = perturb_num_samples(self.num_samples, factor) new_duration = new_num_samples / self.sampling_rate return fastcopy( @@ -705,7 +715,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Volume(factor=factor).to_dict()) + transforms.append(Volume(factor=factor)) return fastcopy( self, id=f"{self.id}_vp{factor}" if affix_id else self.id, @@ -722,7 +732,7 @@ def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recordin :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(LoudnessNormalization(target=target).to_dict()) + transforms.append(LoudnessNormalization(target=target)) return fastcopy( self, id=f"{self.id}_ln{target}" if affix_id else self.id, @@ -738,7 +748,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(DereverbWPE().to_dict()) + transforms.append(DereverbWPE()) return fastcopy( self, id=f"{self.id}_wpe" if affix_id else self.id, @@ -751,7 +761,7 @@ def reverb_rir( normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: Optional[List[int]] = None, + rir_channels: Optional[Sequence[int]] = None, room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> "Recording": @@ -812,7 +822,7 @@ def reverb_rir( early_only=early_only, rir_channels=rir_channels if rir_channels is not None else [0], rir_generator=rir_generator, - ).to_dict() + ) ) return fastcopy( self, @@ -835,7 +845,7 @@ def resample(self, sampling_rate: int) -> "Recording": Resample( source_sampling_rate=self.sampling_rate, target_sampling_rate=sampling_rate, - ).to_dict() + ) ) new_num_samples = compute_num_samples( diff --git a/lhotse/augmentation/rir.py b/lhotse/augmentation/rir.py index 09f5dc38c..bf2e8a1a5 100644 --- a/lhotse/augmentation/rir.py +++ b/lhotse/augmentation/rir.py @@ -35,10 +35,14 @@ class ReverbWithImpulseResponse(AudioTransform): def __post_init__(self): if isinstance(self.rir, dict): - from lhotse import Recording + from lhotse.serialization import deserialize_item - # Pass a shallow copy of the RIR dict since `from_dict()` pops the `sources` key. - self.rir = Recording.from_dict(self.rir.copy()) + # Pass a shallow copy of the RIR dict since deserialization is destructive + # If RIR is a Cut, we have to perform one extra copy (hacky but better than deepcopy). + rir = self.rir.copy() + if "recording" in self.rir: + rir["recording"] = rir["recording"].copy() + self.rir = deserialize_item(rir) assert ( self.rir is not None or self.rir_generator is not None @@ -52,6 +56,23 @@ def __post_init__(self): if self.rir_generator is not None and isinstance(self.rir_generator, dict): self.rir_generator = FastRandomRIRGenerator(**self.rir_generator) + def to_dict(self) -> dict: + from lhotse import Recording + from lhotse.cut import Cut + + return { + "name": type(self).__name__, + "kwargs": { + "rir": self.rir.to_dict() + if isinstance(self.rir, (Recording, Cut)) + else self.rir, + "normalize_output": self.normalize_output, + "early_only": self.early_only, + "rir_channels": list(self.rir_channels), + "rir_generator": self.rir_generator, + }, + } + def __call__( self, samples: np.ndarray, @@ -92,11 +113,13 @@ def __call__( if self.rir is None: rir_ = self.rir_generator(nsource=1) else: - rir_ = ( - self.rir.load_audio(channels=self.rir_channels) - if not self.early_only - else self.rir.load_audio(channels=self.rir_channels, duration=0.05) - ) + from lhotse import Recording + + rir = self.rir.to_cut() if isinstance(self.rir, Recording) else self.rir + rir = rir.with_channels(self.rir_channels) + if self.early_only: + rir = rir.truncate(duration=0.05) + rir_ = rir.load_audio() D_rir, N_rir = rir_.shape N_out = N_in # Enforce shift output diff --git a/lhotse/cut/base.py b/lhotse/cut/base.py index 7ab95d991..fa1400a4a 100644 --- a/lhotse/cut/base.py +++ b/lhotse/cut/base.py @@ -22,7 +22,6 @@ compute_start_duration_for_extended_cut, fastcopy, ifnone, - is_torchaudio_available, overlaps, to_hashable, ) diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index 438805eb3..7229353af 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -22,6 +22,7 @@ Seconds, TimeSpan, add_durations, + asdict_nonull, compute_num_frames, compute_num_samples, fastcopy, @@ -70,6 +71,16 @@ class DataCut(Cut, CustomFieldMixin, metaclass=ABCMeta): # Store anything else the user might want. custom: Optional[Dict[str, Any]] = None + def to_dict(self) -> dict: + d = asdict_nonull(self) + if self.has_recording: + d["recording"] = self.recording.to_dict() + if self.custom is not None: + for k, v in self.custom.items(): + if isinstance(v, Recording): + d["custom"][k] = v.to_dict() + return {**d, "type": type(self).__name__} + @property def recording_id(self) -> str: return self.recording.id if self.has_recording else self.features.recording_id diff --git a/lhotse/cut/mono.py b/lhotse/cut/mono.py index 1cecfb2c0..75bbe23f7 100644 --- a/lhotse/cut/mono.py +++ b/lhotse/cut/mono.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from functools import partial, reduce from operator import add -from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -16,6 +16,7 @@ add_durations, fastcopy, hash_str_to_int, + is_equal_or_contains, merge_items_with_delimiter, overlaps, rich_exception_info, @@ -102,13 +103,58 @@ def load_video( ) return None + def with_channels(self, channels: Union[List[int], int]) -> DataCut: + """ + Select specified channels from this cut. + Supports extending to other channels available in the underlying :class:`Recording`. + If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`, + otherwise we'll return a :class:`~lhotse.cut.MultiCut`. + """ + channel_is_int = isinstance(channels, int) + assert set([channels] if channel_is_int else channels).issubset( + set(self.recording.channel_ids) + ), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}" + + mono = channel_is_int or len(channels) == 1 + if mono: + if not channel_is_int: + (channels,) = channels + return MonoCut( + id=f"{self.id}-{channels}", + recording=self.recording, + start=self.start, + duration=self.duration, + channel=channels, + supervisions=[ + fastcopy(s, channel=channels) + for s in self.supervisions + if is_equal_or_contains(s.channel, channels) + ], + custom=self.custom, + ) + else: + from lhotse import MultiCut + + return MultiCut( + id=f"{self.id}-{len(channels)}chan", + start=self.start, + duration=self.duration, + channel=channels, + supervisions=[ + s + for s in self.supervisions + if is_equal_or_contains(channels, s.channel) + ], + custom=self.custom, + ) + def reverb_rir( self, - rir_recording: Optional["Recording"] = None, + rir_recording: Optional[Union[Recording, DataCut]] = None, normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: List[int] = [0], + rir_channels: Sequence[int] = (0,), room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> DataCut: diff --git a/lhotse/cut/multi.py b/lhotse/cut/multi.py index 439f9d217..3544b7c4d 100644 --- a/lhotse/cut/multi.py +++ b/lhotse/cut/multi.py @@ -157,11 +157,11 @@ def load_video( def reverb_rir( self, - rir_recording: Optional["Recording"] = None, + rir_recording: Optional[Union[Recording, DataCut]] = None, normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: List[int] = [0], + rir_channels: Sequence[int] = (0,), room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> "MultiCut": @@ -370,17 +370,18 @@ def with_channels(self, channels: Union[List[int], int]) -> DataCut: Select specified channels from this cut. Supports extending to other channels available in the underlying :class:`Recording`. If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`, - otherwise we'll return a :class:`~lhotse.cut.MultiCut'. + otherwise we'll return a :class:`~lhotse.cut.MultiCut`. """ - mono = isinstance(channels, int) or len(channels) == 1 - assert set([channels] if mono else channels).issubset( + channel_is_int = isinstance(channels, int) + assert set([channels] if channel_is_int else channels).issubset( set(self.recording.channel_ids) ), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}" + mono = channel_is_int or len(channels) == 1 if mono: from .mono import MonoCut - if isinstance(channels, Sequence): + if not channel_is_int: (channels,) = channels return MonoCut( id=f"{self.id}-{channels}", diff --git a/test/cut/test_cut_augmentation.py b/test/cut/test_cut_augmentation.py index 133321ce5..9219cebbc 100644 --- a/test/cut/test_cut_augmentation.py +++ b/test/cut/test_cut_augmentation.py @@ -1,3 +1,6 @@ +import os +from tempfile import NamedTemporaryFile + import numpy as np import pytest import torch @@ -6,7 +9,7 @@ from lhotse.audio import RecordingSet from lhotse.cut import PaddingCut from lhotse.testing.dummies import dummy_cut, dummy_multi_cut -from lhotse.utils import fastcopy, is_module_available +from lhotse.utils import fastcopy, is_module_available, nullcontext @pytest.fixture @@ -652,9 +655,14 @@ def test_cut_normalize_loudness(libri_cut_set, target, mix_first): assert loudness == pytest.approx(target, abs=0.5) -def test_cut_reverb_rir(libri_cut_with_supervision, libri_recording_rvb, rir): +@pytest.mark.parametrize("in_memory", [True, False]) +def test_cut_reverb_rir( + libri_cut_with_supervision, libri_recording_rvb, rir, in_memory +): cut = libri_cut_with_supervision + if in_memory: + rir = rir.move_to_memory() cut_rvb = cut.reverb_rir(rir) assert cut_rvb.start == cut.start assert cut_rvb.duration == cut.duration @@ -676,6 +684,48 @@ def test_cut_reverb_rir(libri_cut_with_supervision, libri_recording_rvb, rir): np.testing.assert_array_almost_equal(cut_rvb.load_audio(), rvb_audio_from_fixture) +@pytest.mark.parametrize("with_serialization", [True, False]) +def test_cut_reverb_rir_input_is_cut( + libri_cut_with_supervision, libri_recording_rvb, rir, with_serialization +): + + cut = libri_cut_with_supervision + rir = rir.to_cut() + + with ( + NamedTemporaryFile(suffix=".jsonl", mode="w") + if with_serialization + else nullcontext() + ) as f: + if with_serialization: + CutSet([rir]).to_file(f.name) + f.flush() + os.fsync(f.fileno()) + rir = CutSet.from_file(f.name)[0] + + cut_rvb = cut.reverb_rir(rir) + assert cut_rvb.start == cut.start + assert cut_rvb.duration == cut.duration + assert cut_rvb.end == cut.end + assert cut_rvb.num_samples == cut.num_samples + + assert cut_rvb.recording.duration == cut.recording.duration + assert cut_rvb.recording.num_samples == cut.recording.num_samples + + assert cut_rvb.supervisions[0].start == cut.supervisions[0].start + assert cut_rvb.supervisions[0].duration == cut.supervisions[0].duration + assert cut_rvb.supervisions[0].end == cut.supervisions[0].end + + assert cut_rvb.load_audio().shape == cut.load_audio().shape + assert cut_rvb.recording.load_audio().shape == cut.recording.load_audio().shape + + rvb_audio_from_fixture = libri_recording_rvb.load_audio() + + np.testing.assert_array_almost_equal( + cut_rvb.load_audio(), rvb_audio_from_fixture + ) + + def test_cut_reverb_rir_assert_sampling_rate(libri_cut_with_supervision, rir): cut = libri_cut_with_supervision rir_new = rir.resample(8000) From bbb3fccd3022b3e2cee5f510f452c51a45439723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 16 May 2024 15:56:46 -0400 Subject: [PATCH 25/69] Use libsndfile in recording chunk dataset (#1335) --- lhotse/dataset/unsupervised.py | 80 +++++----------------------------- 1 file changed, 10 insertions(+), 70 deletions(-) diff --git a/lhotse/dataset/unsupervised.py b/lhotse/dataset/unsupervised.py index 2ac215d1d..e3d91239e 100644 --- a/lhotse/dataset/unsupervised.py +++ b/lhotse/dataset/unsupervised.py @@ -1,7 +1,7 @@ import math -import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional +import numpy as np import torch from torch.utils.data import IterableDataset @@ -165,12 +165,6 @@ def __init__( self.validate() def validate(self) -> None: - if not torchaudio_supports_ffmpeg(): - raise RuntimeError( - "Using FFMPEG streamer backend for reading is supported only " - "with PyTorch 1.12+ and torchaudio 0.12+" - ) - for r in self.recordings: assert ( len(r.sources) == 1 @@ -183,25 +177,21 @@ def validate(self) -> None: ), f"We currently only support single-channel audio in this dataset (got {r.num_channels} channels in recording {r.id})." def __iter__(self): - import torchaudio + import soundfile as sf for r in self.recordings[self.start : self.end]: chunk_size = compute_num_samples(self.chunk_size, r.sampling_rate) - chunk_shift = compute_num_samples(self.chunk_shift, r.sampling_rate) - - streamer = torchaudio.io.StreamReader(src=r.sources[0].source) - assert streamer.num_src_streams == 1, ( - "Lhotse doesn't support files with more than one FFMPEG source stream yet " - "(not to be confused with multi-channel)." + chunk_overlap = compute_num_samples( + self.chunk_size - self.chunk_shift, r.sampling_rate ) - streamer.add_basic_audio_stream(frames_per_chunk=chunk_size) begin_time = 0 end_time = self.chunk_size - buffer = ShiftingBuffer(chunk_size=chunk_size, chunk_shift=chunk_shift) - for (incoming_audio,) in streamer.stream(): - buffer.push(incoming_audio.squeeze()) - for chunk in buffer.get_chunks(): + with sf.SoundFile(r.sources[0].source, "rb") as stream: + for chunk in stream.blocks( + chunk_size, overlap=chunk_overlap, dtype=np.float32 + ): + chunk = torch.as_tensor(chunk) yield { "recording_id": r.id, "begin_time": torch.as_tensor(begin_time, dtype=torch.float32), @@ -210,56 +200,6 @@ def __iter__(self): } begin_time += self.chunk_shift end_time = begin_time + self.chunk_size - remainder = buffer.flush() - if remainder.shape[0] > 0: - yield { - "recording_id": r.id, - "begin_time": torch.as_tensor(begin_time, dtype=torch.float32), - "end_time": torch.as_tensor(end_time, dtype=torch.float32), - "audio": remainder, - } - - -class ShiftingBuffer: - """ - Utility for iterating over streaming audio chunks that supports chunk_shift < chunk_size. - It is useful when running model predictions on overlapping chunks of audio data. - """ - - def __init__(self, chunk_size: int, chunk_shift: int): - self.buf = torch.empty(1, 0) - self.chunk_size = chunk_size - self.chunk_shift = chunk_shift - - def push(self, audio: torch.Tensor) -> None: - """Add new chunk of audio to the buffer. Expects shape (num_samples, ).""" - self.buf = torch.cat([self.buf, audio.unsqueeze(0)], dim=1) - - def get_chunks(self) -> torch.Tensor: - """ - Retrieve chunks accumulated so far, adjusted for chunk_shift. - For chunk_shift < chunk_size, there will typically be more chunks - returned from this function than were pushed into the buffer because - of overlap. - The returned shape is (num_chunks, chunk_size). - """ - out, self.buf = _get_strided_batch_streaming( - self.buf, - window_shift=self.chunk_shift, - window_length=self.chunk_size, - snip_edges=True, - ) - return out.squeeze(0) - - def flush(self) -> torch.Tensor: - """ - Flush out the remainder chunk from the buffer. - Typically it will be shorter than chunk_size. - The returned shape is (remainder_size, ). - """ - out = self.buf.squeeze(0) - self.buf = torch.empty(1, 0) - return out def audio_chunk_collate(batch: List[Dict]): From 26c3911556f1426caac283f97e6bc2e116095d16 Mon Sep 17 00:00:00 2001 From: Marc Harkonen Date: Wed, 29 May 2024 19:31:39 +0800 Subject: [PATCH 26/69] Fix librispeech manifest caching (#1343) fix librispeech manifest caching Now passes the correct prefix to `manifests_exist`. --- lhotse/recipes/librispeech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/recipes/librispeech.py b/lhotse/recipes/librispeech.py index faa48f2a9..ceedd8285 100644 --- a/lhotse/recipes/librispeech.py +++ b/lhotse/recipes/librispeech.py @@ -166,7 +166,7 @@ def prepare_librispeech( with ThreadPoolExecutor(num_jobs) as ex: for part in tqdm(dataset_parts, desc="Dataset parts"): logging.info(f"Processing LibriSpeech subset: {part}") - if manifests_exist(part=part, output_dir=output_dir): + if manifests_exist(part=part, output_dir=output_dir, prefix="librispeech"): logging.info(f"LibriSpeech subset: {part} already prepared - skipping.") continue recordings = [] From c7785201dca9ab8284dce18135e6d2b0fd5bad7a Mon Sep 17 00:00:00 2001 From: Triplecq Date: Wed, 29 May 2024 16:43:01 -0400 Subject: [PATCH 27/69] Add the ReazonSpeech recipe (#1330) * Add stub ReazonSpeech recipe I created this recipe by copying "aishell4" recipe, and stripping the most of the contents. Signed-off-by: Fujimoto Seiji * Format the script with black to meet style guidelines * Add ReazonSpeech to the dataset table * Add a download method and refactor the prepare function * Fix the TypeError when download the subset * Format to follow the code style * Change to local import * Format to follow the code style --------- Signed-off-by: Fujimoto Seiji Co-authored-by: Fujimoto Seiji Co-authored-by: Chen Co-authored-by: root --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/reazonspeech.py | 52 +++++ lhotse/recipes/__init__.py | 3 + lhotse/recipes/reazonspeech.py | 243 +++++++++++++++++++++++ 5 files changed, 301 insertions(+) create mode 100644 lhotse/bin/modes/recipes/reazonspeech.py create mode 100644 lhotse/recipes/reazonspeech.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 70393cc38..9f037380b 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -165,6 +165,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_nsc` * - People's Speech - :func:`lhotse.recipes.prepare_peoples_speech` + * - ReazonSpeech + - :func:`lhotse.recipes.prepare_reazonspeech` * - RIRs and Noises Corpus (OpenSLR 28) - :func:`lhotse.recipes.prepare_rir_noise` * - Speech Commands diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index b5fe10981..20f9b5c9a 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -63,6 +63,7 @@ from .nsc import * from .peoples_speech import * from .primewords import * +from .reazonspeech import * from .rir_noise import * from .slu import * from .speechcommands import * diff --git a/lhotse/bin/modes/recipes/reazonspeech.py b/lhotse/bin/modes/recipes/reazonspeech.py new file mode 100644 index 000000000..125098c89 --- /dev/null +++ b/lhotse/bin/modes/recipes/reazonspeech.py @@ -0,0 +1,52 @@ +import logging +from typing import List + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.reazonspeech import ( + REAZONSPEECH, + download_reazonspeech, + prepare_reazonspeech, +) +from lhotse.utils import Pathlike + +__all__ = ["reazonspeech"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +def reazonspeech( + corpus_dir: Pathlike, + output_dir: Pathlike, + num_jobs: int, +): + """ReazonSpeech ASR data preparation.""" + logging.basicConfig(level=logging.INFO) + prepare_reazonspeech(corpus_dir, output_dir=output_dir, num_jobs=num_jobs) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +@click.option( + "--subset", + type=click.Choice(("auto",) + REAZONSPEECH), + multiple=True, + default=["auto"], + help="List of dataset parts to prepare (default: small-v1). To prepare multiple parts, pass each with `--subset` " + "Example: `--subset all", +) +def reazonspeech(target_dir: Pathlike, subset: List[str]): + """ReazonSpeech download.""" + logging.basicConfig(level=logging.INFO) + if "auto" in subset: + subset = "auto" + download_reazonspeech(target_dir, dataset_parts=subset) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 99bde7d97..5e59b613c 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -64,6 +64,7 @@ from .musan import download_musan, prepare_musan from .nsc import prepare_nsc from .peoples_speech import prepare_peoples_speech +from .reazonspeech import download_reazonspeech, prepare_reazonspeech from .rir_noise import download_rir_noise, prepare_rir_noise from .slu import prepare_slu from .speechcommands import download_speechcommands, prepare_speechcommands @@ -180,6 +181,8 @@ "prepare_musan", "prepare_nsc", "prepare_peoples_speech", + "download_reazonspeech", + "prepare_reazonspeech", "download_rir_noise", "prepare_rir_noise", "prepare_slu", diff --git a/lhotse/recipes/reazonspeech.py b/lhotse/recipes/reazonspeech.py new file mode 100644 index 000000000..bebae77ff --- /dev/null +++ b/lhotse/recipes/reazonspeech.py @@ -0,0 +1,243 @@ +""" +ReazonSpeech is an open-source dataset that contains a diverse set of natural Japanese speech, +collected from terrestrial television streams. It contains more than 35,000 hours of audio. + +The dataset is available on Hugging Face. For more details, please visit: + +Dataset: https://huggingface.co/datasets/reazon-research/reazonspeech +Paper: https://research.reazon.jp/_static/reazonspeech_nlp2023.pdf +""" + +import json +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +from tqdm.auto import tqdm + +from lhotse import CutSet, fix_manifests, validate_recordings_and_supervisions +from lhotse.audio import Recording, RecordingSet +from lhotse.parallel import parallel_map +from lhotse.recipes.utils import manifests_exist, read_manifests_if_cached +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, is_module_available + +REAZONSPEECH = ( + "tiny", + "small", + "medium", + "large", + "all", + "small-v1", + "medium-v1", + "all-v1", +) + +PUNCTUATIONS = {ord(x): "" for x in "、。「」『』,,?!!!?!?"} +ZENKAKU = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +HANKAKU = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +ZEN2HAN = str.maketrans(ZENKAKU, HANKAKU) + + +def normalize(s): + """ + Convert full-width characters to half-width, and remove punctuations. + :param s: str, input string. + :return: str, normalized string. + """ + if is_module_available("num2words"): + import num2words + else: + raise ImportError( + "To process the ReazonSpeech corpus, please install optional dependency: pip install num2words" + ) + s = s.translate(PUNCTUATIONS).translate(ZEN2HAN) + conv = lambda m: num2words.num2words(m.group(0), lang="ja") + return re.sub(r"\d+\.?\d*", conv, s) + + +def write_to_json(data, filename): + """ + Writes data to a JSON file. + :param data: The data to write. + :param filename: The name of the file to write to. + """ + + with open(filename, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + + +def download_reazonspeech( + target_dir: Pathlike = ".", + dataset_parts: Optional[Union[str, Sequence[str]]] = "auto", +) -> Path: + """ + Download the ReazonSpeech dataset. + :param target_dir: Pathlike, the path of the dir to storage the dataset. + :param dataset_parts: the parts of the dataset to download (e.g. small, medium, or large). + :return: the path to downloaded data and the JSON file. + """ + if is_module_available("datasets"): + from datasets import load_dataset + else: + raise ImportError( + "To process the ReazonSpeech corpus, please install optional dependency: pip install datasets" + ) + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + corpus_dir = target_dir / "ReazonSpeech" + + if dataset_parts == "auto": + dataset_parts = ("small-v1",) + elif isinstance(dataset_parts, str): + dataset_parts = [dataset_parts] + + for part in dataset_parts: + logging.info(f"Downloading ReazonSpeech part: {part}") + ds = load_dataset( + "reazon-research/reazonspeech", + part, + trust_remote_code=True, + cache_dir=corpus_dir, + )["train"] + + # Prepare data for JSON export + data_for_json = [] + idx = 0 + for item in ds: + # Calculate the duration of the audio file + audio_array = item["audio"]["array"] + sampling_rate = item["audio"]["sampling_rate"] + duration = len(audio_array) / float(sampling_rate) + + # Create a dictionary for the current record + record = { + "id": str(idx), + "audio_filepath": item["audio"]["path"], + "text": normalize(item["transcription"]), + "duration": duration, + } + + # Append the record to the list + data_for_json.append(record) + idx += 1 + + # Write data to a JSON file + write_to_json(data_for_json, corpus_dir / "dataset.json") + + return corpus_dir + + +def prepare_reazonspeech( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike], + num_jobs: int = 1, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns the manifests which consist of the Recordings and Supervisions. + When all the manifests are available in the ``output_dir``, it will simply read and return them. + :param corpus_dir: Pathlike, the path of the data dir. + :param output_dir: Pathlike, the path where to write the manifests. + :param num_jobs: int, number of parallel threads used for 'parse_utterance' calls. + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + # Split the dataset into train, dev, and test + with open(corpus_dir / "dataset.json", "r", encoding="utf-8") as file: + full = json.load(file) + dev = full[:1000] + test = full[1000:1100] + train = full[1100:] + + write_to_json(train, corpus_dir / "train.json") + write_to_json(dev, corpus_dir / "dev.json") + write_to_json(test, corpus_dir / "test.json") + + parts = ("train", "dev", "test") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # Maybe some manifests already exist: we can read them and save a bit of preparation time. + manifests = read_manifests_if_cached( + dataset_parts=parts, + output_dir=output_dir, + prefix="reazonspeech", + suffix="jsonl.gz", + lazy=True, + ) + + for part in parts: + logging.info(f"Processing ReazonSpeech subset: {part}") + if manifests_exist( + part=part, output_dir=output_dir, prefix="reazonspeech", suffix="jsonl.gz" + ): + logging.info(f"ReazonSpeech subset: {part} already prepared - skipping.") + continue + + filename = corpus_dir / f"{part}.json" + with open(filename, "r", encoding="utf-8") as file: + items = json.load(file) + + with RecordingSet.open_writer( + output_dir / f"reazonspeech_recordings_{part}.jsonl.gz" + ) as rec_writer, SupervisionSet.open_writer( + output_dir / f"reazonspeech_supervisions_{part}.jsonl.gz" + ) as sup_writer, CutSet.open_writer( + output_dir / f"reazonspeech_cuts_{part}.jsonl.gz" + ) as cut_writer: + for recording, segment in tqdm( + parallel_map( + parse_utterance, + items, + num_jobs=num_jobs, + ), + desc="Processing reazonspeech JSON entries", + ): + # Fix and validate the recording + supervisions + recordings, segments = fix_manifests( + recordings=RecordingSet.from_recordings([recording]), + supervisions=SupervisionSet.from_segments([segment]), + ) + validate_recordings_and_supervisions( + recordings=recordings, supervisions=segments + ) + # Create the cut since most users will need it anyway. + # There will be exactly one cut since there's exactly one recording. + cuts = CutSet.from_manifests( + recordings=recordings, supervisions=segments + ) + # Write the manifests + rec_writer.write(recordings[0]) + sup_writer.write(segments[0]) + cut_writer.write(cuts[0]) + + manifests[part] = { + "recordings": RecordingSet.from_jsonl_lazy(rec_writer.path), + "supervisions": SupervisionSet.from_jsonl_lazy(sup_writer.path), + "cuts": CutSet.from_jsonl_lazy(cut_writer.path), + } + + return dict(manifests) + + +def parse_utterance(item: Any) -> Optional[Tuple[Recording, SupervisionSegment]]: + """ + Process a single utterance from the ReazonSpeech dataset. + :param item: The utterance to process. + :return: A tuple containing the Recording and SupervisionSegment. + """ + recording = Recording.from_file(item["audio_filepath"], recording_id=item["id"]) + segments = SupervisionSegment( + id=item["id"], + recording_id=item["id"], + start=0.0, + duration=item["duration"], + channel=0, + language="Japanese", + text=item["text"], + ) + return recording, segments From c1d443288373ceaf50aee665a9bc9831a82356a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 31 May 2024 11:34:41 -0400 Subject: [PATCH 28/69] Fix one-off edge case in split_lazy (#1347) --- lhotse/utils.py | 8 ++++++-- test/test_manipulation.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/lhotse/utils.py b/lhotse/utils.py index 23ed1fd45..fb171537f 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -310,9 +310,13 @@ def split_manifest_lazy( if prefix == "": prefix = "split" - items = iter(it) split_idx = start_idx splits = [] + items = iter(it) + try: + item = next(items) + except StopIteration: + return splits while True: try: written = 0 @@ -321,9 +325,9 @@ def split_manifest_lazy( (output_dir / prefix).with_suffix(f".{idx}.jsonl.gz") ) as writer: while written < chunk_size: - item = next(items) writer.write(item) written += 1 + item = next(items) split_idx += 1 except StopIteration: break diff --git a/test/test_manipulation.py b/test/test_manipulation.py index 6f692ed9c..9afd04dfe 100644 --- a/test/test_manipulation.py +++ b/test/test_manipulation.py @@ -113,6 +113,17 @@ def test_split_lazy_even(manifest_type): ) +def test_split_lazy_edge_case_extra_shard(tmp_path): + N = 512 + chsz = 32 + nshrd = 16 + manifest = DummyManifest(CutSet, begin_id=0, end_id=N - 1) + manifest_subsets = manifest.split_lazy(output_dir=tmp_path, chunk_size=chsz) + assert len(manifest_subsets) == nshrd + for item in sorted(tmp_path.glob("*")): + print(item) + + @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) def test_combine(manifest_type): expected = DummyManifest(manifest_type, begin_id=0, end_id=200) From f29e30be7a7566bb722d25fac77648610876f4f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 4 Jun 2024 12:04:19 -0400 Subject: [PATCH 29/69] Increase the start diff tolerance for feature loading (#1349) --- lhotse/features/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/features/base.py b/lhotse/features/base.py index 60bb909d5..ef4b161a1 100644 --- a/lhotse/features/base.py +++ b/lhotse/features/base.py @@ -472,7 +472,7 @@ def load( start = self.start # In case the caller requested only a sub-span of the features, trim them. # Left trim - if start < self.start - 1e-5: + if start < self.start - 1e-3: raise ValueError( f"Cannot load features for recording {self.recording_id} starting from {start}s. " f"The available range is ({self.start}, {self.end}) seconds." From c542e29ee039862544c2b5de2a40e78df8c2bd02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 5 Jun 2024 14:48:07 -0400 Subject: [PATCH 30/69] More test coverage for lhotse subset (#1345) --- lhotse/cut/set.py | 4 ++-- lhotse/features/base.py | 5 +++-- lhotse/utils.py | 2 +- test/test_manipulation.py | 23 +++++++++++++++-------- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 58900306b..e0a2f0e16 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -916,9 +916,9 @@ def subset( if last is not None: assert last > 0 - if last > len(self): - return self N = len(self) + if last > N: + return self return CutSet.from_cuts(islice(self, N - last, N)) if supervision_ids is not None: diff --git a/lhotse/features/base.py b/lhotse/features/base.py index ef4b161a1..c8cba931a 100644 --- a/lhotse/features/base.py +++ b/lhotse/features/base.py @@ -701,9 +701,10 @@ def subset( if last is not None: assert last > 0 - if last > len(self): + N = len(self) + if last > N: return self - return FeatureSet.from_features(self.features[-last:]) + return FeatureSet.from_items(islice(self, N - last, N)) def find( self, diff --git a/lhotse/utils.py b/lhotse/utils.py index fb171537f..be891ede5 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -616,7 +616,7 @@ class nullcontext(AbstractContextManager): Note(pzelasko): This is copied from Python 3.7 stdlib so that we can use it in 3.6. """ - def __init__(self, enter_result=None): + def __init__(self, enter_result=None, *args, **kwargs): self.enter_result = enter_result def __enter__(self): diff --git a/test/test_manipulation.py b/test/test_manipulation.py index 9afd04dfe..ceff26e5d 100644 --- a/test/test_manipulation.py +++ b/test/test_manipulation.py @@ -10,6 +10,7 @@ from lhotse.manipulation import combine from lhotse.supervision import SupervisionSet from lhotse.testing.dummies import DummyManifest, as_lazy +from lhotse.utils import nullcontext @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) @@ -144,19 +145,25 @@ def test_combine(manifest_type): @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) -def test_subset_first(manifest_type): +@mark.parametrize("lazy", [False, True]) +def test_subset_first(manifest_type, lazy): + ctx = as_lazy if lazy else nullcontext any_set = DummyManifest(manifest_type, begin_id=0, end_id=200) - expected = DummyManifest(manifest_type, begin_id=0, end_id=10) - subset = any_set.subset(first=10) - assert subset == expected + with ctx(any_set, ".jsonl") as any_set: + expected = DummyManifest(manifest_type, begin_id=0, end_id=10) + subset = any_set.subset(first=10) + assert subset == expected @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) -def test_subset_last(manifest_type): +@mark.parametrize("lazy", [False, True]) +def test_subset_last(manifest_type, lazy): + ctx = as_lazy if lazy else nullcontext any_set = DummyManifest(manifest_type, begin_id=0, end_id=200) - expected = DummyManifest(manifest_type, begin_id=190, end_id=200) - subset = any_set.subset(last=10) - assert subset == expected + with ctx(any_set, ".jsonl") as any_set: + expected = DummyManifest(manifest_type, begin_id=190, end_id=200) + subset = any_set.subset(last=10) + assert subset == expected @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) From cf6cde80100ea8b1f951cb70899ecc5d1ed200f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 5 Jun 2024 14:57:59 -0400 Subject: [PATCH 31/69] Dynamic bucket selection rng sync (#1341) * Support syncing dynamic bucket selection across DDP ranks * fix * Fix dataset tail iteration with sync_buckets=True * Tests. Fix the sync_buckets support for map-style dataset usage and early bucket depletion cases. --- lhotse/dataset/sampling/dynamic.py | 2 - lhotse/dataset/sampling/dynamic_bucketing.py | 184 +++++++++++++++--- .../sampling/test_dynamic_bucketing.py | 162 ++++++++++++--- 3 files changed, 297 insertions(+), 51 deletions(-) diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index a42726d4a..2d36b4130 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -134,7 +134,6 @@ def __init__( self.consistent_ids = consistent_ids self.shuffle_buffer_size = shuffle_buffer_size self.quadratic_duration = quadratic_duration - self.rng = None if strict is not None: warnings.warn( @@ -195,7 +194,6 @@ def __iter__(self) -> "DynamicCutSampler": # than are actually available per epoch would have broken the checkpoint restoration. self.diagnostics.reset_current_epoch() seed = resolve_seed(self.seed) - self.rng = random.Random(seed + self.epoch) # Initiate iteration self.cuts_iter = [iter(cs) for cs in self.cuts] # Optionally shuffle diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index 7dcbfd8d8..cf4da23f2 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -2,9 +2,11 @@ import warnings from bisect import bisect_right from collections import deque +from dataclasses import dataclass from itertools import islice from typing import ( Any, + Callable, Deque, Dict, Generator, @@ -18,6 +20,7 @@ ) import numpy as np +import torch from lhotse import CutSet, Seconds from lhotse.cut import Cut @@ -91,6 +94,7 @@ def __init__( world_size: Optional[int] = None, rank: Optional[int] = None, seed: Union[int, Literal["randomized", "trng"]] = 0, + sync_buckets: bool = True, strict=None, shuffle_buffer_size=None, ) -> None: @@ -125,6 +129,8 @@ def __init__( :param quadratic_duration: When set, it adds an extra penalty that's quadratic in size w.r.t. a cuts duration. This helps get a more even GPU utilization across different input lengths when models have quadratic input complexity. Set between 15 and 40 for transformers. + :param sync_buckets: When set, we'll try to make each DDP rank sample from as close + duration buckets as possible to minimize the tail worker effect. :param world_size: Total number of distributed nodes. We will try to infer it by default. :param rank: Index of distributed node. We will try to infer it by default. :param seed: Random seed used to consistently shuffle the dataset across different processes. @@ -147,6 +153,7 @@ def __init__( self.num_cuts_for_bins_estimate = num_cuts_for_bins_estimate self.buffer_size = buffer_size self.quadratic_duration = quadratic_duration + self.sync_buckets = sync_buckets self.rng = None check_constraint(constraint, max_duration, max_cuts) @@ -238,6 +245,20 @@ def __iter__(self) -> "DynamicBucketingSampler": return self seed = resolve_seed(self.seed) self.rng = random.Random(seed + self.epoch) + if self.sync_buckets: + # Bucket sync requested. To achieve that we will fix the RNG seed for a special bucket RNG + # in a deterministic way. We also consider whether the sampler object lives in the training loop + # process (map-style dataset or num_workers=0) or the dataloading subprocess (iterable-style dataset). + # In the latter case, we want each worker to choose different buckets but still be in sync + # with workers with the same IDs on other ranks. + # Note: PyTorch dataloader always iterates workers sequentially, so they won't get out-of-order. + bucket_rng_seed = 1234 + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + bucket_rng_seed += worker_info.id + bucket_rng = random.Random(bucket_rng_seed) + else: + bucket_rng = None # Why reset the current epoch? # Either we are iterating the epoch for the first time and it's a no-op, # or we are iterating the same epoch again, in which case setting more steps @@ -255,6 +276,7 @@ def __iter__(self) -> "DynamicBucketingSampler": cuts_iter = DynamicBucketer( cuts_iter, duration_bins=self.duration_bins, + world_size=self.world_size, max_duration=self.max_duration, max_cuts=self.max_cuts, constraint=self.constraint, @@ -263,6 +285,7 @@ def __iter__(self) -> "DynamicBucketingSampler": quadratic_duration=self.quadratic_duration, shuffle=self.shuffle, rng=self.rng, + bucket_rng=bucket_rng, diagnostics=self.diagnostics, ) self.cuts_iter = iter(cuts_iter) @@ -337,11 +360,50 @@ def estimate_duration_buckets( return bins +class BucketSelectionState: + """ + Helper class used in the context of bucket selection synchronization across DDP ranks. + It's only necessary when using a map-style dataset (i.e., the sampler lives in the training loop process) + and world_size is greater than 1. In these cases we have to use the same bucket idx ``world_size`` times + to ensure each rank uses the same bucket. This is due to how CutSampler distributes mini-batches + across ranks, ensuring the number of steps is always equal for each rank. + """ + + def __init__( + self, bucket_rng: random.Random, num_buckets: int, world_size: int + ) -> None: + self._bucket_rng = bucket_rng + self._num_buckets = num_buckets + self._world_size = world_size + self._usage_count = 0 + self._bucket_idx = None + + def select_bucket_idx(self) -> int: + if self._bucket_idx is None or self._usage_count == self._world_size: + self._bucket_idx = self._bucket_rng.randrange(self._num_buckets) + self._usage_count = 0 + self._usage_count += 1 + return self._bucket_idx + + def save(self) -> Dict[str, Any]: + return { + "_bucket_rng": self._bucket_rng.getstate(), + "_bucket_idx": self._bucket_idx, + "_usage_count": self._usage_count, + } + + def restore(self, ckpt: Dict[str, Any]) -> None: + self._bucket_rng.setstate(ckpt["_bucket_rng"]) + self._bucket_idx = ckpt["_bucket_idx"] + self._usage_count = ckpt["_usage_count"] + + class DynamicBucketer: def __init__( self, cuts: Iterable[Union[Cut, Tuple[Cut]]], duration_bins: List[Seconds], + world_size: int, max_duration: Optional[Seconds] = None, max_cuts: Optional[int] = None, constraint: Optional[SamplingConstraint] = None, @@ -350,10 +412,12 @@ def __init__( quadratic_duration: Optional[Seconds] = None, shuffle: bool = False, rng: random.Random = None, + bucket_rng: random.Random = None, diagnostics: Optional[SamplingDiagnostics] = None, ) -> None: self.cuts = cuts self.duration_bins = duration_bins + self.world_size = world_size self.max_duration = max_duration self.max_cuts = max_cuts self.constraint = constraint @@ -364,6 +428,7 @@ def __init__( if rng is None: rng = random.Random() self.rng = rng + self.bucket_rng = bucket_rng self.shuffle = shuffle assert duration_bins == sorted(duration_bins), ( @@ -402,32 +467,17 @@ def __iter__(self) -> Generator[CutSet, None, None]: self.cuts_iter = iter(self.cuts) self._collect_cuts_in_buckets(self.buffer_size) - # Init: determine which buckets are "ready" - def is_ready(bucket: Deque[Cut]): - tot = self.constraint.copy() - for c in bucket: - tot.add(c[0] if isinstance(c, tuple) else c) - if tot.close_to_exceeding(): - return True - return False + state = BucketSelectionState( + bucket_rng=self.bucket_rng, + num_buckets=len(self.buckets), + world_size=self.world_size, + ) # The iteration code starts here. # On each step we're sampling a new batch. try: while True: - ready_buckets = [b for b in self.buckets if is_ready(b)] - if not ready_buckets: - # No bucket has enough data to yield for the last full batch. - non_empty_buckets = [b for b in self.buckets if b] - if self.drop_last or len(non_empty_buckets) == 0: - # Either the user requested only full batches, or we have nothing left. - raise StopIteration() - else: - # Sample from partial batches that are left. - ready_buckets = non_empty_buckets - # Choose a bucket to sample from. - # We'll only select from the buckets that have a full batch available. - sampling_bucket = self.rng.choice(ready_buckets) + sampling_bucket = self._select_bucket(state) # Apply random shuffling if requested: we'll shuffle the items present within the bucket. maybe_shuffled = sampling_bucket indexes_used = [] @@ -465,7 +515,93 @@ def is_ready(bucket: Deque[Cut]): # Cleanup. self.cuts_iter = None - def _collect_cuts_in_buckets(self, n_cuts: int): + def _select_bucket(self, state: BucketSelectionState) -> Deque[Cut]: + if self.bucket_rng is None: + # Bucket selection algo 1: + # * there is just one RNG for choosing buckets and choosing samples randomly from the buckets + # * check which buckets are ready, and then use the RNG to select one of them. + # * no guarantees about bucket selection sync across GPUs. + ready_buckets = [b for b in self.buckets if self._is_ready(b)] + if not ready_buckets: + # No bucket has enough data to yield for the last full batch. + non_empty_buckets = [b for b in self.buckets if b] + if self.drop_last or len(non_empty_buckets) == 0: + # Either the user requested only full batches, or we have nothing left. + raise StopIteration() + else: + # Sample from partial batches that are left. + ready_buckets = non_empty_buckets + # Choose a bucket to sample from. + # We'll only select from the buckets that have a full batch available. + return self.rng.choice(ready_buckets) + else: + # Bucket selection algo 2: + # * bucket selection has its own independent RNG. + # * when bucket selection RNG is initialized identically on all ranks/workers, + # then each rank will initially select the same bucket for batch sampling + # * if one of the ranks selects a bucket that is not filled enough, + # it will scan the neighbouring buckets until it finds one that's ready + # * if no bucket is ready, we end iteration + + def scan_buckets(predicate: Callable[[Deque[Cut]], bool]) -> int: + bucket_idx = state.select_bucket_idx() + + def valid_idx() -> bool: + return 0 <= bucket_idx < len(self.buckets) + + num_attempts = 0 + seen_min, seen_max = bucket_idx, bucket_idx + while not (valid_idx() and predicate(self.buckets[bucket_idx])): + if seen_min < 0 and seen_max >= len(self.buckets): + raise BucketsDontHaveEnoughData() + num_attempts += 1 + bucket_idx = ( + bucket_idx + (1 if num_attempts % 2 == 0 else -1) * num_attempts + ) + seen_min = min(seen_min, bucket_idx) + seen_max = max(seen_max, bucket_idx) + + return bucket_idx + + # This try/except is first trying to choose a bucket to sample a full mini-batch from, + # and if that fails and drop_last=False, it tries again, this time accepting partial mini-batch. + # Because we have no guarantee that samplers in different ranks will start exhausting the buckets + # at the same time, it takes only a single occurrence of all buckets not being ready to permanently + # run out-of-sync. + # For this reason, we create a checkpoint of the bucket sampling state, and if we go into except + # fallback, we restore this state first to ensure we use the bucket_rng exactly the same number + # of times on each rank, no matter the circumstance. + ckpt = state.save() + try: + # Typical case: at least one bucket has enough data to sample from. + selected_bucket_idx = scan_buckets(self._is_ready) + except BucketsDontHaveEnoughData: + # We didn't hit the typical case either because we are finishing + # the epoch, or because the buffers are too small. + if self.drop_last: + # The user doesn't want partial mini-batches: early exit. + raise StopIteration() + # The user wants to iterate the full dataset. + # We'll try again, this time accepting buckets that have any amount of data available, + # which may yield partial batches. + try: + state.restore(ckpt) + selected_bucket_idx = scan_buckets(lambda b: len(b) > 0) + except BucketsDontHaveEnoughData: + # We exhausted the full dataset. + raise StopIteration() + + return self.buckets[selected_bucket_idx] + + def _is_ready(self, bucket: Deque[Cut]) -> bool: + tot = self.constraint.copy() + for c in bucket: + tot.add(c[0] if isinstance(c, tuple) else c) + if tot.close_to_exceeding(): + return True + return False + + def _collect_cuts_in_buckets(self, n_cuts: int) -> None: try: for _ in range(n_cuts): cuts = next(self.cuts_iter) @@ -494,6 +630,10 @@ def pick_at_random( yield bucket[idx] +class BucketsDontHaveEnoughData(Exception): + pass + + def _emit_shuffle_buffer_size_warning(): warnings.warn( "Since Lhotse v1.20 'shuffle_buffer_size' is deprecated, because DynamicBucketingSampler " diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index 9596a4c5b..bb3253b30 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -1,4 +1,7 @@ import random +from itertools import islice + +import pytest from lhotse import CutSet from lhotse.dataset.sampling.dynamic_bucketing import ( @@ -6,7 +9,7 @@ DynamicBucketingSampler, estimate_duration_buckets, ) -from lhotse.testing.dummies import DummyManifest +from lhotse.testing.dummies import DummyManifest, dummy_cut def test_estimate_duration_buckets_2b(): @@ -48,7 +51,9 @@ def test_dynamic_bucketing_drop_last_false(): c.duration = 2 rng = random.Random(0) - sampler = DynamicBucketer(cuts, duration_bins=[2], max_duration=5, rng=rng) + sampler = DynamicBucketer( + cuts, duration_bins=[2], max_duration=5, rng=rng, world_size=1 + ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -84,7 +89,7 @@ def test_dynamic_bucketing_drop_last_true(): rng = random.Random(0) sampler = DynamicBucketer( - cuts, duration_bins=[2], max_duration=5, rng=rng, drop_last=True + cuts, duration_bins=[2], max_duration=5, rng=rng, drop_last=True, world_size=1 ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -133,11 +138,11 @@ def test_dynamic_bucketing_sampler(): assert len(batches[0]) == 2 assert sum(c.duration for c in batches[0]) == 4 - assert len(batches[1]) == 2 - assert sum(c.duration for c in batches[1]) == 4 + assert len(batches[1]) == 5 + assert sum(c.duration for c in batches[1]) == 5 - assert len(batches[2]) == 5 - assert sum(c.duration for c in batches[2]) == 5 + assert len(batches[2]) == 2 + assert sum(c.duration for c in batches[2]) == 4 assert len(batches[3]) == 1 assert sum(c.duration for c in batches[3]) == 2 @@ -177,14 +182,14 @@ def test_dynamic_bucketing_sampler_precomputed_duration_bins(): assert len(batches[0]) == 2 assert sum(c.duration for c in batches[0]) == 4 - assert len(batches[1]) == 2 - assert sum(c.duration for c in batches[1]) == 3 + assert len(batches[1]) == 4 + assert sum(c.duration for c in batches[1]) == 5 assert len(batches[2]) == 2 assert sum(c.duration for c in batches[2]) == 3 - assert len(batches[3]) == 4 - assert sum(c.duration for c in batches[3]) == 5 + assert len(batches[3]) == 2 + assert sum(c.duration for c in batches[3]) == 3 def test_dynamic_bucketing_sampler_max_duration_and_max_cuts(): @@ -353,18 +358,18 @@ def test_dynamic_bucketing_sampler_cut_pairs(): bidx = 1 sc, tc = batches[bidx][0], batches[bidx][1] - assert len(sc) == 2 - assert len(tc) == 2 - assert sum(c.duration for c in sc) == 4 - assert sum(c.duration for c in tc) == 4 - - bidx = 2 - sc, tc = batches[bidx][0], batches[bidx][1] assert len(sc) == 5 assert len(tc) == 5 assert sum(c.duration for c in sc) == 5 assert sum(c.duration for c in tc) == 5 + bidx = 2 + sc, tc = batches[bidx][0], batches[bidx][1] + assert len(sc) == 2 + assert len(tc) == 2 + assert sum(c.duration for c in sc) == 4 + assert sum(c.duration for c in tc) == 4 + bidx = 3 sc, tc = batches[bidx][0], batches[bidx][1] assert len(sc) == 1 @@ -494,15 +499,6 @@ def test_dynamic_bucketing_sampler_cut_triplets(): bidx = 1 c1, c2, c3 = batches[bidx][0], batches[bidx][1], batches[bidx][2] - assert len(c1) == 2 - assert len(c2) == 2 - assert len(c3) == 2 - assert sum(c.duration for c in c1) == 4 - assert sum(c.duration for c in c2) == 4 - assert sum(c.duration for c in c3) == 4 - - bidx = 2 - c1, c2, c3 = batches[bidx][0], batches[bidx][1], batches[bidx][2] assert len(c1) == 5 assert len(c2) == 5 assert len(c3) == 5 @@ -510,6 +506,15 @@ def test_dynamic_bucketing_sampler_cut_triplets(): assert sum(c.duration for c in c2) == 5 assert sum(c.duration for c in c3) == 5 + bidx = 2 + c1, c2, c3 = batches[bidx][0], batches[bidx][1], batches[bidx][2] + assert len(c1) == 2 + assert len(c2) == 2 + assert len(c3) == 2 + assert sum(c.duration for c in c1) == 4 + assert sum(c.duration for c in c2) == 4 + assert sum(c.duration for c in c3) == 4 + bidx = 3 c1, c2, c3 = batches[bidx][0], batches[bidx][1], batches[bidx][2] assert len(c1) == 1 @@ -562,3 +567,106 @@ def test_dynamic_bucketing_quadratic_duration(): b = batches[3] assert len(b) == 1 # single cut assert sum(c.duration for c in b) == 30 # 30s long + + +@pytest.mark.parametrize("sync_buckets", [True, False]) +def test_dynamic_bucketing_sampler_sync_buckets_iterable_dataset_usage(sync_buckets): + # With iterable datasets a sampler replica will be placed in each dataloading worker, + # given world_size=1, and have its data shuffled differently than other replicas. + # To simulate that in this test, we provide a different seed and rank=0 world_size=1. + dur_rng = random.Random(0) + cuts = CutSet( + [ + dummy_cut(i, duration=dur_rng.choices([1, 10], weights=[0.9, 0.1])[0]) + for i in range(10000) + ] + ) + + common = dict( + max_duration=5, + num_buckets=2, + rank=0, + sync_buckets=sync_buckets, + world_size=1, + drop_last=True, + shuffle=True, + duration_bins=[5.0], + ) + s0 = DynamicBucketingSampler(cuts, seed=0, **common) + s1 = DynamicBucketingSampler(cuts, seed=1, **common) + + # check the first 30 mini-batches + batches0 = [b for b in islice(s0, 30)] + batches1 = [b for b in islice(s1, 30)] + cuts0 = CutSet([c for b in batches0 for c in b]) + cuts1 = CutSet([c for b in batches1 for c in b]) + + # Invariant: no duplicated cut IDs across ranks + assert set(cuts0.ids) & set(cuts1.ids) == set() + + if sync_buckets: + matching_ids = [] + # Ensure identical batch sizes and example durations + for bidx, (b0, b1) in enumerate(zip(batches0, batches1)): + assert len(b0) == len(b1), bidx + for c0, c1 in zip(b0, b1): + assert c0.duration == c1.duration + matching_ids.append(c0.id == c1.id) + # At least some IDs are mismatching because despite identical shapes, the actual sampled data is different. + assert not all(matching_ids) + if not sync_buckets: + # some shapes will be mismatched because different buckets were selected. + matching_shapes = [len(b0) == len(b1) for b0, b1 in zip(batches0, batches1)] + assert not all(matching_shapes) + + +@pytest.mark.parametrize("sync_buckets", [True, False]) +def test_dynamic_bucketing_sampler_sync_buckets_map_dataset_usage(sync_buckets): + # With map datasets the sampler lives in the training loop process and must have synced random seed + # with other ranks in DDP. + # The data is de-duplicated by sampling world_size batches and keeping the batch at rank index. + # To simulate that in this test, we provide the same seed, world_size=2 and set rank appropriately. + dur_rng = random.Random(0) + cuts = CutSet( + [ + dummy_cut(i, duration=dur_rng.choices([1, 10], weights=[0.9, 0.1])[0]) + for i in range(10000) + ] + ) + + common = dict( + max_duration=5, + num_buckets=2, + seed=0, + sync_buckets=sync_buckets, + world_size=2, + drop_last=True, + shuffle=True, + duration_bins=[5.0], + ) + s0 = DynamicBucketingSampler(cuts, rank=0, **common) + s1 = DynamicBucketingSampler(cuts, rank=1, **common) + + # check the first 30 mini-batches + batches0 = [b for b in islice(s0, 30)] + batches1 = [b for b in islice(s1, 30)] + cuts0 = CutSet([c for b in batches0 for c in b]) + cuts1 = CutSet([c for b in batches1 for c in b]) + + # Invariant: no duplicated cut IDs across ranks + assert set(cuts0.ids) & set(cuts1.ids) == set() + + if sync_buckets: + matching_ids = [] + # Ensure identical batch sizes and example durations + for bidx, (b0, b1) in enumerate(zip(batches0, batches1)): + assert len(b0) == len(b1), bidx + for c0, c1 in zip(b0, b1): + assert c0.duration == c1.duration + matching_ids.append(c0.id == c1.id) + # At least some IDs are mismatching because despite identical shapes, the actual sampled data is different. + assert not all(matching_ids) + if not sync_buckets: + # some shapes will be mismatched because different buckets were selected. + matching_shapes = [len(b0) == len(b1) for b0, b1 in zip(batches0, batches1)] + assert not all(matching_shapes) From 4d57d53dae0f76279590e6819aa7a30499fa3d57 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 6 Jun 2024 03:17:39 +0800 Subject: [PATCH 32/69] Add new sampler: weighted sampler (#1344) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add file * add a weighted data source to enable sampling based on per-sample weight; do not allow duplicated sample within the same epoch * add a weighted sampler; do not allow lazy mode; do not allow duplicated cut in the same batch * modify init file accordingly * add more documentations * use numpy for sampling; pre-compute the indexes in __iter__ to save time * add more documentation * minor changes to the arguments * remove unused file * add test * add more docs * fix isort * inherit from SimpleCutSampler; remove duplicated code * minor fix * Add changes requested in code review --------- Co-authored-by: Piotr Żelasko --- lhotse/dataset/sampling/__init__.py | 2 + lhotse/dataset/sampling/data_source.py | 77 ++++++++++- lhotse/dataset/sampling/weighted_simple.py | 147 +++++++++++++++++++++ test/dataset/sampling/test_sampling.py | 49 +++++++ 4 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 lhotse/dataset/sampling/weighted_simple.py diff --git a/lhotse/dataset/sampling/__init__.py b/lhotse/dataset/sampling/__init__.py index 2644d7fe7..c7c16293b 100644 --- a/lhotse/dataset/sampling/__init__.py +++ b/lhotse/dataset/sampling/__init__.py @@ -12,6 +12,7 @@ from .simple import SimpleCutSampler from .stateless import StatelessSampler from .utils import find_pessimistic_batches, report_padding_ratio_estimate +from .weighted_simple import WeightedSimpleCutSampler from .zip import ZipSampler __all__ = [ @@ -25,6 +26,7 @@ "DynamicBucketingSampler", "RoundRobinSampler", "SimpleCutSampler", + "WeightedSimpleCutSampler", "StatelessSampler", "ZipSampler", "find_pessimistic_batches", diff --git a/lhotse/dataset/sampling/data_source.py b/lhotse/dataset/sampling/data_source.py index 154ae1c1f..c6ba73926 100644 --- a/lhotse/dataset/sampling/data_source.py +++ b/lhotse/dataset/sampling/data_source.py @@ -1,6 +1,8 @@ import random from collections import deque -from typing import Optional +from typing import List, Optional + +import numpy as np from lhotse import CutSet from lhotse.cut import Cut @@ -98,3 +100,76 @@ def __next__(self) -> Cut: def __len__(self) -> int: return len(self._shuffled_items) + + +class WeightedDataSource(DataSource): + """ + An iterator wrapper over CutSet that helps with the sampling process: + it allows for deterministic re-shuffling of elements and "returning" + sampled elements to be yielded again. + + Every cut has a sampling weight. At the beginning of each epoch, we + pre-compute the indexes by sampling from multi-nomial distribution without + replacement. The data source will be exhausted if the number of drawn cuts + exceed num_samples + """ + + def __init__(self, items: CutSet, weights: List, num_samples: int): + """The constructor of the weighted data source + + Args: + items (CutSet): The cutset itself + weights (List): A list of values representing the weight of each cut. All values must be positive + num_samples (int): The number of samples to be drawn. Must smaller than the total number of cuts + """ + super().__init__(items=items) + assert len(items) == len(weights), "The length should match" + assert num_samples < len( + weights + ), "The number of samples to be drawn should not exceed the dataset size" + + # normalize the weight + weights = np.array(weights) + weights = weights / weights.sum() + + self.weights = weights + self.num_samples = num_samples + self.sampled_indexes = None + + def reset(self) -> None: + """Reset the iterable state of DataSource.""" + self._iter = None + self.sampled_indexes = None + self._reusable.clear() + self._remaining_duration = self._total_duration + self.remaining_cuts = self._total_cuts + + def fast_forward(self, steps: int) -> None: + """Advance the data source by ``steps`` amount of steps.""" + assert steps >= 0 + iter(self) + for i in range(steps): + next(self.sampled_indexes) + + def __iter__(self) -> "WeightedDataSource": + self.reset() + self._iter = iter(self._shuffled_items) + self.sampled_indexes = np.random.choice( + len(self.weights), + self.num_samples, + p=self.weights, + replace=False, + ) + self.sampled_indexes = iter(self.sampled_indexes) + return self + + def __next__(self) -> Cut: + if self._reusable: + next_cut = self._reusable.popleft() + else: + next_cut = self._orig_items[next(self.sampled_indexes)] + + if not self.is_lazy: + self._remaining_duration -= next_cut.duration + self.remaining_cuts -= 1 + return next_cut diff --git a/lhotse/dataset/sampling/weighted_simple.py b/lhotse/dataset/sampling/weighted_simple.py new file mode 100644 index 000000000..7c3f76034 --- /dev/null +++ b/lhotse/dataset/sampling/weighted_simple.py @@ -0,0 +1,147 @@ +import warnings +from typing import Any, Dict, List, Optional + +from lhotse import CutSet, Seconds +from lhotse.dataset.sampling.base import TimeConstraint +from lhotse.dataset.sampling.data_source import WeightedDataSource +from lhotse.dataset.sampling.simple import SimpleCutSampler + + +class WeightedSimpleCutSampler(SimpleCutSampler): + """ + Samples cuts from a CutSet, where the sampling prob is given by a list. + To enable global sampling, cuts must be in eager mode. + + When performing sampling, it avoids having duplicated cuts in the same batch. + The sampler terminates if the number of sampled cuts reach :attr:`num_samples` + + When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + the batch size is dynamic. + + Example usage: + + >>> dataset = K2SpeechRecognitionDataset(cuts) + >>> weights = get_weights(cuts) + >>> sampler = WeightedSimpleCutSampler(cuts, weights, num_samples=100, max_duration=200.0) + >>> loader = DataLoader(dataset, sampler=sampler, batch_size=None) + >>> for epoch in range(start_epoch, n_epochs): + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__( + self, + cuts: CutSet, + cuts_weight: List, + num_samples: int, + max_duration: Seconds = None, + max_cuts: Optional[int] = None, + shuffle: bool = False, + drop_last: bool = False, + world_size: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + ): + """ + WeightedSimpleCutSampler's constructor + + :param cuts: the ``CutSet`` to sample data from. + :param cuts_weight: the weight of each cut for sampling. + :param num_samples: the number of samples to be drawn. + :param max_duration: The maximum total recording duration from ``cuts``. + :param max_cuts: The maximum number of cuts sampled to form a mini-batch. + By default, this constraint is off. + :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration. + Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.: + `for epoch in range(10): for batch in dataset: ...` as every epoch will see a + different cuts order. + :param drop_last: When ``True``, the last batch is dropped if it's incomplete. + :param world_size: Total number of distributed nodes. We will try to infer it by default. + :param rank: Index of distributed node. We will try to infer it by default. + :param seed: Random seed used to consistently shuffle the dataset across different processes. + """ + super().__init__( + cuts=cuts, + drop_last=drop_last, + shuffle=shuffle, + world_size=world_size, + rank=rank, + max_duration=max_duration, + max_cuts=max_cuts, + seed=seed, + ) + assert not cuts.is_lazy, "This sampler does not support lazy mode!" + self.data_source = WeightedDataSource( + cuts, weights=cuts_weight, num_samples=num_samples + ) + + self.weights = cuts_weight + self.num_samples = num_samples + + def state_dict(self) -> Dict[str, Any]: + """ + Return the current state of the sampler in a state_dict. + Together with ``load_state_dict()``, this can be used to restore the + training loop's state to the one stored in the state_dict. + """ + state_dict = super().state_dict() + state_dict.update( + { + "time_constraint": self.time_constraint.state_dict(), + "weights": self.weights, + "num_samples": self.num_samples, + } + ) + return state_dict + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Restore the state of the sampler that is described in a state_dict. + This will result in the sampler yielding batches from where the previous training left it off. + + .. caution:: + The samplers are expected to be initialized with the same CutSets, + but this is not explicitly checked anywhere. + + .. caution:: + The input ``state_dict`` is being mutated: we remove each consumed key, and expect + it to be empty at the end of loading. If you don't want this behavior, pass a copy + inside of this function (e.g., using ``import deepcopy``). + + .. note:: + For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be + handled in ``__iter__`` to make it avoid resetting the just-restored state (only once). + """ + time_constraint = TimeConstraint(**state_dict.pop("time_constraint")) + if self.time_constraint != time_constraint: + warnings.warn( + "SimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n" + f"expected {self.time_constraint}\n" + f"received {time_constraint}\n" + f"We will overwrite the settings with the received state_dict." + ) + self.time_constraint = time_constraint + + super().load_state_dict(state_dict) + + # Restore the data source's state + self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts) + + self.weights = state_dict.pop("weights") + self.num_samples = state_dict.pop("num_samples") + + def __iter__(self) -> "WeightedSimpleCutSampler": + """ + Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested. + """ + # Restored state with load_state_dict()? Skip resetting only this once. + if self._just_restored_state: + return self + # Why reset the current epoch? + # Either we are iterating the epoch for the first time and it's a no-op, + # or we are iterating the same epoch again, in which case setting more steps + # than are actually available per epoch would have broken the checkpoint restoration. + self.diagnostics.reset_current_epoch() + # Reset the state to the beginning of the epoch. + iter(self.data_source) + return self diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py index 820cd2ad1..74736794e 100644 --- a/test/dataset/sampling/test_sampling.py +++ b/test/dataset/sampling/test_sampling.py @@ -25,6 +25,7 @@ BucketingSampler, CutPairsSampler, SimpleCutSampler, + WeightedSimpleCutSampler, ZipSampler, ) from lhotse.dataset.sampling.base import SamplingDiagnostics, TimeConstraint @@ -1024,6 +1025,54 @@ def test_cut_pairs_sampler_lazy_shuffle(sampler_cls): assert [c.id for c in sampled_src_cuts] != [c.id for c in lazy_cuts] +def test_weighted_sampler_num_samples(): + cut_set = DummyManifest(CutSet, begin_id=0, end_id=100) + weight = [random.random() for i in range(100)] + num_samples = 32 + + sampler = WeightedSimpleCutSampler( + cut_set, + weight, + num_samples=num_samples, + max_duration=10.0, + drop_last=True, + ) + + sampled_cuts = [] + num_cuts = 0 + for batch in sampler: + sampled_cuts.extend(batch) + num_cuts += len(batch) + + assert num_cuts <= num_samples + + +def test_weighted_sampler_across_epochs(): + cut_set = DummyManifest(CutSet, begin_id=0, end_id=100) + weight = [random.random() for i in range(100)] + num_samples = 32 + + sampler = WeightedSimpleCutSampler( + cut_set, + weight, + num_samples=num_samples, + max_duration=10.0, + drop_last=True, + ) + + # 1st epoch + sampler.set_epoch(1) + batch = next(iter(sampler)) + cut_ids1 = [c.id for c in batch] + + # 2st epoch + sampler.set_epoch(2) + batch = next(iter(sampler)) + cut_ids2 = [c.id for c in batch] + + assert set(cut_ids1) != set(cut_ids2) + + @pytest.mark.parametrize("datasize", [10, 1000, 20000]) @pytest.mark.parametrize("bufsize", [100, 1000, 10000]) def test_streaming_shuffle(datasize, bufsize): From 866e4a80b0a4a2ea1f44b796f5ffa64a603431d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 10 Jun 2024 14:29:53 -0400 Subject: [PATCH 33/69] Support for reading data from AIStore using Python SDK (#1354) * Support for reading data from AIStore using Python SDK * More AIStore related docs --- README.md | 2 ++ docs/getting-started.rst | 5 +++ lhotse/serialization.py | 70 +++++++++++++++++++++++++++++++++------- lhotse/utils.py | 9 ++++++ 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index b5dfbc5ee..3d4bb17f6 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,7 @@ Lhotse uses several environment variables to customize it's behavior. They are a - `LHOTSE_LEGACY_OPUS_LOADING` - (`=1`) reverts to a legacy OPUS loading mechanism that triggered a new ffmpeg subprocess for each OPUS file. - `LHOTSE_PREPARING_RELEASE` - used internally by developers when releasing a new version of Lhotse. - `TORCHAUDIO_USE_BACKEND_DISPATCHER` - when set to `1` and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio. +- `AIS_ENDPOINT` is read by AIStore client to determine AIStore endpoint URL. Required for AIStore dataloading. - `RANK`, `WORLD_SIZE`, `WORKER`, and `NUM_WORKERS` are internally used to inform Lhotse Shar dataloading subprocesses. - `READTHEDOCS` is internally used for documentation builds. @@ -121,6 +122,7 @@ Lhotse uses several environment variables to customize it's behavior. They are a - `pip install lhotse[webdataset]`. We support "compiling" your data into WebDataset tarball format for more effective IO. You can still interact with the data as if it was a regular lazy CutSet. To learn more, check out the following tutorial: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/02-webdataset-integration.ipynb) - `pip install h5py` if you want to extract speech features and store them as HDF5 arrays. - `pip install dill`. When `dill` is installed, we'll use it to pickle CutSet that uses a lambda function in calls such as `.map` or `.filter`. This is helpful in PyTorch DataLoader with `num_jobs>0`. Without `dill`, depending on your environment, you'll see an exception or a hanging script. +- `pip install aistore` to read manifests, tar fles, and other data from AIStore using AIStore-supported URLs (set `AIS_ENDPOINT` environment variable to activate it). See [AIStore documentation](https://aiatscale.org) for more details. - `pip install smart_open` to read and write manifests and data in any location supported by `smart_open` (e.g. cloud, http). - `pip install opensmile` for feature extraction using the OpenSmile toolkit's Python wrapper. diff --git a/docs/getting-started.rst b/docs/getting-started.rst index c6e17085c..9a299c973 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -133,6 +133,8 @@ Lhotse uses several environment variables to customize it's behavior. They are a * ``TORCHAUDIO_USE_BACKEND_DISPATCHER`` - when set to 1 and torchaudio version is below 2.1, we'll enable the experimental ffmpeg backend of torchaudio. +* ``AIS_ENDPOINT`` is read by AIStore client to determine AIStore endpoint URL. Required for AIStore dataloading. + * ``RANK``, ``WORLD_SIZE``, ``WORKER``, and ``NUM_WORKERS`` are internally used to inform Lhotse Shar dataloading subprocesses. * ``READTHEDOCS`` is internally used for documentation builds. @@ -153,6 +155,8 @@ Optional dependencies * ``pip install dill``. When ``dill`` is installed, we'll use it to pickle CutSet that uses a lambda function in calls such as ``.map`` or ``.filter``. This is helpful in PyTorch DataLoader with ``num_jobs>0``. Without ``dill``, depending on your environment, you'll see an exception or a hanging script. +* ``pip install aistore`` to read manifests, tar fles, and other data from AIStore using AIStore-supported URLs (set ``AIS_ENDPOINT`` environment variable to activate it). See |AIStore| for more details. + * ``pip install smart_open`` to read and write manifests and data in any location supported by ``smart_open`` (e.g. cloud, http). * ``pip install opensmile`` for feature extraction using the OpenSmile toolkit's Python wrapper. @@ -225,3 +229,4 @@ the speech starts roughly at the first second (100 frames): .. _Kaldi: https://github.com/kaldi-asr/kaldi .. _Icefall recipes: https://github.com/k2-fsa/icefall .. _orjson: https://pypi.org/project/orjson/ +.. _AIStore: https://aiatscale.org diff --git a/lhotse/serialization.py b/lhotse/serialization.py index 0a90d98ff..7f18b4dd4 100644 --- a/lhotse/serialization.py +++ b/lhotse/serialization.py @@ -1,15 +1,17 @@ import itertools import json +import os import sys import warnings from codecs import StreamReader, StreamWriter +from functools import lru_cache from io import BytesIO, StringIO from pathlib import Path from typing import Any, Dict, Generator, Iterable, Optional, Type, Union import yaml -from lhotse.utils import Pathlike, is_module_available +from lhotse.utils import Pathlike, SmartOpen, is_module_available, is_valid_url from lhotse.workarounds import gzip_open_robust # TODO: figure out how to use some sort of typing stubs @@ -28,7 +30,8 @@ def open_best(path: Pathlike, mode: str = "r"): either stdin or stdout depending on the mode. The concept is similar to Kaldi's "generalized pipes", but uses WebDataset syntax. """ - if str(path) == "-": + strpath = str(path) + if strpath == "-": if mode == "r": return StdStreamWrapper(sys.stdin) elif mode == "w": @@ -41,22 +44,32 @@ def open_best(path: Pathlike, mode: str = "r"): if isinstance(path, (BytesIO, StringIO, StreamWriter, StreamReader)): return path - if str(path).startswith("pipe:"): + if strpath.startswith("pipe:"): return open_pipe(path[5:], mode) - if is_module_available("smart_open"): - from smart_open import smart_open + if strpath.startswith("ais://"): + return open_aistore(path, mode) - # This will work with JSONL anywhere that smart_open supports, e.g. cloud storage. - open_fn = smart_open - else: - compressed = str(path).endswith(".gz") - if compressed and "t" not in mode and "b" not in mode: + if is_valid_url(strpath): + if is_aistore_available(): + return open_aistore(path, mode) + elif is_module_available("smart_open"): + return SmartOpen.open(path, mode) + else: + raise ValueError( + f"In order to open URLs/URIs please run 'pip install smart_open' " + f"(if you're trying to use AIStore, either the Python SDK is not installed " + f"or {AIS_ENDPOINT_ENVVAR} is not defined." + ) + + compressed = strpath.endswith(".gz") + if compressed: + if "t" not in mode and "b" not in mode: # Opening as bytes not requested explicitly, use "t" to tell gzip to handle unicode. mode = mode + "t" - open_fn = gzip_open_robust if compressed else open + return gzip_open_robust(path, mode) - return open_fn(path, mode) + return open(path, mode) def open_pipe(cmd: str, mode: str): @@ -69,6 +82,39 @@ def open_pipe(cmd: str, mode: str): return Pipe(cmd, mode=mode, shell=True, bufsize=8092) +AIS_ENDPOINT_ENVVAR = "AIS_ENDPOINT" + + +@lru_cache +def is_aistore_available() -> bool: + return AIS_ENDPOINT_ENVVAR in os.environ and is_valid_url( + os.environ[AIS_ENDPOINT_ENVVAR] + ) + + +@lru_cache +def get_aistore_client(): + if not is_module_available("aistore"): + raise ImportError( + "Please run 'pip install aistore' in order to read data from AIStore." + ) + if not is_aistore_available(): + raise ValueError( + "Set a valid URL as AIS_ENDPOINT environment variable's value to read data from AIStore." + ) + from aistore import Client + + endpoint_url = os.environ["AIS_ENDPOINT"] + return Client(endpoint_url) + + +def open_aistore(uri: str, mode: str): + assert "r" in mode, "We only support reading from AIStore at this time." + client = get_aistore_client() + object = client.fetch_object_by_url(uri) + return object.get().raw() + + def save_to_yaml(data: Any, path: Pathlike) -> None: with open_best(path, "w") as f: try: diff --git a/lhotse/utils.py b/lhotse/utils.py index be891ede5..2a44a5a04 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -32,6 +32,7 @@ TypeVar, Union, ) +from urllib.parse import urlparse import click import numpy as np @@ -128,6 +129,14 @@ def open(cls, uri, mode="rb", transport_params=None, **kwargs): ) +def is_valid_url(value: str) -> bool: + try: + result = urlparse(value) + return bool(result.scheme) and bool(result.netloc) + except AttributeError: + return False + + def fix_random_seed(random_seed: int): """ Set the same random seed for the libraries and modules that Lhotse interacts with. From 012532f2f0668fa7c9c7d6eea83374ba9ac29329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 11 Jun 2024 08:34:44 -0400 Subject: [PATCH 34/69] Bump dev version to 1.25.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 53cc1a6f9..ad2191947 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.24.0 +1.25.0 From f9fb181cbb733a95ebc2eda130d4db5b5a802677 Mon Sep 17 00:00:00 2001 From: Seung Hyun Lee Date: Wed, 12 Jun 2024 22:57:52 +0900 Subject: [PATCH 35/69] Add KsponSpeech recipe (#1353) * Add KsponSpeech recipe * Fix an error occured during prepare ksponspeech. * Make normalization optional in prepare_ksponspeech * Modify pcm_to_wav -> pcm_to_flac --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/ksponspeech.py | 51 ++++++ lhotse/recipes/__init__.py | 2 + lhotse/recipes/ksponspeech.py | 227 ++++++++++++++++++++++++ 5 files changed, 283 insertions(+) create mode 100644 lhotse/bin/modes/recipes/ksponspeech.py create mode 100644 lhotse/recipes/ksponspeech.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 9f037380b..79299eac1 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -131,6 +131,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_iwslt22_ta` * - KeSpeech - :func:`lhotse.recipes.prepare_kespeech` + * - KsponSpeech + - :func:`lhotse.recipes.prepare_ksponspeech` * - L2 Arctic - :func:`lhotse.recipes.prepare_l2_arctic` * - LibriCSS diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index 20f9b5c9a..aafc871e3 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -45,6 +45,7 @@ from .icsi import * from .iwslt22_ta import * from .kespeech import * +from .ksponspeech import * from .l2_arctic import * from .libricss import * from .librilight import * diff --git a/lhotse/bin/modes/recipes/ksponspeech.py b/lhotse/bin/modes/recipes/ksponspeech.py new file mode 100644 index 000000000..4f4a9d5bd --- /dev/null +++ b/lhotse/bin/modes/recipes/ksponspeech.py @@ -0,0 +1,51 @@ +from typing import Sequence + +import click + +from lhotse.bin.modes import prepare +from lhotse.recipes.ksponspeech import prepare_ksponspeech +from lhotse.utils import Pathlike + +__all__ = ["ksponspeech"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-p", + "--dataset-parts", + type=str, + default=["all"], + multiple=True, + help="List of dataset parts to prepare. To prepare multiple parts, pass each with `-p` " + "Example: `-p train -p test`", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +@click.option( + "--normalize-text", + type=click.Choice(["none", "default"], case_sensitive=False), + default="default", + help="Type of text normalization to apply.", +) +def ksponspeech( + corpus_dir: Pathlike, + output_dir: Pathlike, + dataset_parts: Sequence[str], + num_jobs: int, +): + """KsponSpeech ASR data preparation.""" + if len(dataset_parts) == 1: + dataset_parts = dataset_parts[0] + prepare_ksponspeech( + corpus_dir, + output_dir=output_dir, + num_jobs=num_jobs, + dataset_parts=dataset_parts, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 5e59b613c..2b5ec8338 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -42,6 +42,7 @@ from .icsi import download_icsi, prepare_icsi from .iwslt22_ta import prepare_iwslt22_ta from .kespeech import prepare_kespeech +from .ksponspeech import prepare_ksponspeech from .l2_arctic import prepare_l2_arctic from .libricss import download_libricss, prepare_libricss from .librilight import prepare_librilight @@ -153,6 +154,7 @@ "prepare_icsi", "prepare_iwslt22_ta", "prepare_kespeech", + "prepare_ksponspeech", "prepare_l2_arctic", "download_libricss", "prepare_libricss", diff --git a/lhotse/recipes/ksponspeech.py b/lhotse/recipes/ksponspeech.py new file mode 100644 index 000000000..6dde6ed9b --- /dev/null +++ b/lhotse/recipes/ksponspeech.py @@ -0,0 +1,227 @@ +""" +KsponSpeech is a large-scale spontaneous speech corpus of Korean. +This corpus contains 969 hours of open-domain dialog utterances, +spoken by about 2,000 native Korean speakers in a clean environment. + +All data were constructed by recording the dialogue of two people +freely conversing on a variety of topics and manually transcribing the utterances. + +The transcription provides a dual transcription consisting of orthography and pronunciation, +and disfluency tags for spontaneity of speech, such as filler words, repeated words, and word fragments. + +The original audio data has a pcm extension. +During preprocessing, it is converted into a file in the flac extension and saved anew. + +KsponSpeech is publicly available on an open data hub site of the Korea government. +The dataset must be downloaded manually. + +For more details, please visit: + +Dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123 +Paper: https://www.mdpi.com/2076-3417/10/19/6936 +""" + +import logging +import re +from concurrent.futures.thread import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import soundfile as sf +from tqdm.auto import tqdm + +from lhotse import fix_manifests, validate_recordings_and_supervisions +from lhotse.audio import Recording, RecordingSet +from lhotse.recipes.utils import manifests_exist, read_manifests_if_cached +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike + +KSPONSPEECH = ( + "train", + "dev", + "eval_clean", + "eval_other", +) + + +def normalize( + raw_content: str, + normalize_text: str = "default", +) -> Tuple[str, str]: + """ + Normalizing KsponSpeech text datasets with '.trn' extension. + Perform the following processing. + + 1. Separate file name and text labeling from raw content using separator '::'. + 2. Remove noise labeling characters. (e.g. `o/`, `b/`...) + 3. Remove the actual pronunciation from the text labeling, Use the spelling content. + 4. Remove other special characters and double spaces from text labeling. + + :param raw_content: A raw text labeling content containing file name and text labeling. + :param normalize_text: str, the text normalization type. Available options: "default", "none". + :return: A tuple with file name and normalized text labeling. + """ + if len(raw_content) == 0: + return "" + + original_content_id, content = raw_content.split(" :: ") + + if normalize_text == "none": + return original_content_id, content + + elif normalize_text == "default": + content = re.sub(r"[a-z]/", "", content) + content = re.sub(r"\((.*?)\)/\((.*?)\)", r"\1", content) + content = content.replace("*", "") + content = content.replace("+", "") + content = content.replace("/", "") + while " " in content: + content = content.replace(" ", " ") + + return original_content_id, content.strip() + + +def prepare_ksponspeech( + corpus_dir: Pathlike, + dataset_parts: Union[str, Sequence[str]] = "all", + output_dir: Optional[Pathlike] = None, + num_jobs: int = 1, + normalize_text: str = "default", +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns the manifests which consist of the Recordings and Supervisions. + When all the manifests are available in the ``output_dir``, it will simply read and return them. + + :param corpus_dir: Pathlike, the path of the data dir. + :param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'train', 'test'. + By default we will infer which parts are available in ``corpus_dir``. + :param output_dir: Pathlike, the path where to write the manifests. + :param num_jobs: int, number of parallel threads used for 'parse_utterance' calls. + :param normalize_text: str, the text normalization type. Available options: "default", "none". + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + if dataset_parts == "all": + dataset_parts = set(KSPONSPEECH) + + elif isinstance(dataset_parts, str): + dataset_parts = [dataset_parts] + + manifests = {} + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # Maybe the manifests already exist: we can read them and save a bit of preparation time. + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, output_dir=output_dir + ) + + with ThreadPoolExecutor(num_jobs) as ex: + for part in tqdm(dataset_parts, desc="Dataset parts"): + logging.info(f"Processing KsponSpeech subset: {part}") + if manifests_exist(part=part, output_dir=output_dir): + logging.info(f"KsponSpeech subset: {part} already prepared - skipping.") + continue + recordings = [] + supervisions = [] + futures = [] + + trans_path = corpus_dir / f"{part}.trn" + with open(trans_path) as f: + for line in f: + futures.append( + ex.submit( + parse_utterance, corpus_dir, part, line, normalize_text + ) + ) + + for future in tqdm(futures, desc="Processing", leave=False): + result = future.result() + if result is None: + continue + recording, segment = result + recordings.append(recording) + supervisions.append(segment) + + recording_set = RecordingSet.from_recordings(recordings) + supervision_set = SupervisionSet.from_segments(supervisions) + + recording_set, supervision_set = fix_manifests( + recording_set, supervision_set + ) + validate_recordings_and_supervisions(recording_set, supervision_set) + + if output_dir is not None: + supervision_set.to_file( + output_dir / f"ksponspeech_supervisions_{part}.jsonl.gz" + ) + recording_set.to_file( + output_dir / f"ksponspeech_recordings_{part}.jsonl.gz" + ) + + manifests[part] = { + "recordings": recording_set, + "supervisions": supervision_set, + } + + return manifests + + +def pcm_to_flac( + pcm_path: Union[str, Path], + flac_path: Union[str, Path], + sample_rate: Optional[int] = 16000, + channels: Optional[int] = 1, + bit_depth: Optional[int] = 16, +) -> Path: + # typecasting + pcm_path = Path(pcm_path) + flac_path = Path(flac_path) + + data, _ = sf.read( + pcm_path, + channels=channels, + samplerate=sample_rate, + format="RAW", + subtype="PCM_16", + ) + + sf.write(flac_path, data, sample_rate, format="FLAC") + return flac_path + + +def parse_utterance( + corpus_dir: Pathlike, + part: str, + line: str, + normalize_text: str = "default", +) -> Optional[Tuple[Recording, SupervisionSegment]]: + corpus_dir = Path(corpus_dir) + audio_path, normalized_line = normalize(line, normalize_text) + if "eval" in part: + audio_path = audio_path.split("/", maxsplit=1)[1] + + audio_path = corpus_dir / audio_path + recording_id = audio_path.stem + + # Create the Recording first + if not audio_path.is_file(): + logging.warning(f"No such file: {audio_path}") + return None + flac_path = audio_path.with_suffix(".flac") + flac_path = pcm_to_flac(audio_path, flac_path) + recording = Recording.from_file(flac_path, recording_id=recording_id) + # Then, create the corresponding supervisions + segment = SupervisionSegment( + id=recording_id, + recording_id=recording_id, + start=0.0, + duration=recording.duration, + channel=0, + language="Korean", + text=normalized_line, + ) + return recording, segment From 9930ae407469590beb0f9b1005ea338c207c7d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 24 Jun 2024 11:04:39 -0400 Subject: [PATCH 36/69] Restoring smart open for local files if available (#1360) * Restoring smart open for local files if available * lilcom workaround * Use numpy<2 in CI * fix --- .github/workflows/missing_torchaudio.yml | 2 +- .github/workflows/unit_tests.yml | 4 ++-- lhotse/serialization.py | 5 ++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/missing_torchaudio.yml b/.github/workflows/missing_torchaudio.yml index 0ffa8e909..1f51cfc5d 100644 --- a/.github/workflows/missing_torchaudio.yml +++ b/.github/workflows/missing_torchaudio.yml @@ -38,7 +38,7 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install wheel numpy scipy + pip install wheel 'numpy<2' scipy # Force the installation of a CPU-only PyTorch ${{ matrix.torch-install-cmd }} # the torchaudio env var does nothing when torchaudio is installed, but doesn't require it's presence when it's not diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 1a121aee4..157ea9635 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -51,11 +51,11 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install wheel numpy + pip install wheel 'numpy<2' # Force the installation of a CPU-only PyTorch ${{ matrix.torch-install-cmd }} # the torchaudio env var does nothing when torchaudio is installed, but doesn't require it's presence when it's not - pip install '.[tests]' + pip install lilcom '.[tests]' # Enable some optional tests pip install h5py dill smart_open[http] kaldi_native_io webdataset==0.2.5 s3prl scipy nara_wpe pyloudnorm ${{ matrix.extra_deps }} - name: Install sph2pipe diff --git a/lhotse/serialization.py b/lhotse/serialization.py index 7f18b4dd4..c11390fb4 100644 --- a/lhotse/serialization.py +++ b/lhotse/serialization.py @@ -58,10 +58,13 @@ def open_best(path: Pathlike, mode: str = "r"): else: raise ValueError( f"In order to open URLs/URIs please run 'pip install smart_open' " - f"(if you're trying to use AIStore, either the Python SDK is not installed " + f"(if you're trying to use AIStore, either the Python SDK is not installed (pip install aistore) " f"or {AIS_ENDPOINT_ENVVAR} is not defined." ) + if is_module_available("smart_open"): + return SmartOpen.open(path, mode) + compressed = strpath.endswith(".gz") if compressed: if "t" not in mode and "b" not in mode: From d1f94c078aaa2ab7f4f7ec8aa4bc4b86695dabeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 24 Jun 2024 11:58:00 -0400 Subject: [PATCH 37/69] Fix Recording.to_dict() when transforms are dicts and transform pickling issues (#1355) * Fix Recording.to_dict() when transforms are dicts * fix * fix pickling issues with transforms * fix * fix * fix --- lhotse/audio/recording.py | 13 +++++++++++-- lhotse/augmentation/rir.py | 6 ++++-- lhotse/augmentation/utils.py | 5 ++++- lhotse/dataset/cut_transforms/perturb_speed.py | 2 +- lhotse/dataset/cut_transforms/reverberate.py | 2 +- test/audio/test_recording_set.py | 10 ++++++++++ 6 files changed, 31 insertions(+), 7 deletions(-) diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index 971166995..aaf55e15f 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -336,7 +336,9 @@ def _aslist(x): def to_dict(self) -> dict: d = asdict_nonull(self) if self.transforms is not None: - d["transforms"] = [t.to_dict() for t in self.transforms] + d["transforms"] = [ + t if isinstance(t, dict) else t.to_dict() for t in self.transforms + ] return d def to_cut(self): @@ -866,8 +868,15 @@ def resample(self, sampling_rate: int) -> "Recording": @staticmethod def from_dict(data: dict) -> "Recording": raw_sources = data.pop("sources") + try: + transforms = data.pop("transforms") + transforms = [AudioTransform.from_dict(t) for t in transforms] + except KeyError: + transforms = None return Recording( - sources=[AudioSource.from_dict(s) for s in raw_sources], **data + sources=[AudioSource.from_dict(s) for s in raw_sources], + transforms=transforms, + **data, ) diff --git a/lhotse/augmentation/rir.py b/lhotse/augmentation/rir.py index bf2e8a1a5..482c7d0e5 100644 --- a/lhotse/augmentation/rir.py +++ b/lhotse/augmentation/rir.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -69,7 +69,9 @@ def to_dict(self) -> dict: "normalize_output": self.normalize_output, "early_only": self.early_only, "rir_channels": list(self.rir_channels), - "rir_generator": self.rir_generator, + "rir_generator": self.rir_generator + if self.rir_generator is None or isinstance(self.rir_generator, dict) + else self.rir_generator.to_dict(), }, } diff --git a/lhotse/augmentation/utils.py b/lhotse/augmentation/utils.py index 2d87eb8e2..75e9a9efa 100644 --- a/lhotse/augmentation/utils.py +++ b/lhotse/augmentation/utils.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from typing import List, Optional import numpy as np @@ -100,6 +100,9 @@ def __post_init__(self): else np.random.default_rng() ) + def to_dict(self): + return asdict(self) + def __call__(self, nsource: int = 1) -> np.ndarray: """ :param nsource: number of sources (RIR filters) to simulate. Default: 1. diff --git a/lhotse/dataset/cut_transforms/perturb_speed.py b/lhotse/dataset/cut_transforms/perturb_speed.py index 903162b71..7e46933bb 100644 --- a/lhotse/dataset/cut_transforms/perturb_speed.py +++ b/lhotse/dataset/cut_transforms/perturb_speed.py @@ -27,7 +27,7 @@ def __init__( def __call__(self, cuts: CutSet) -> CutSet: if self.random is None: - self.random = random + self.random = random.Random() return CutSet.from_cuts( cut.perturb_speed( factor=self.random.choice(self.factors), affix_id=not self.preserve_id diff --git a/lhotse/dataset/cut_transforms/reverberate.py b/lhotse/dataset/cut_transforms/reverberate.py index 6fab7e4cd..bce7b2e2b 100644 --- a/lhotse/dataset/cut_transforms/reverberate.py +++ b/lhotse/dataset/cut_transforms/reverberate.py @@ -33,7 +33,7 @@ def __init__( def __call__(self, cuts: CutSet) -> CutSet: if self.random is None: - self.random = random + self.random = random.Random() return CutSet.from_cuts( cut.reverb_rir( rir_recording=self.random.choice(self.rir_recordings) diff --git a/test/audio/test_recording_set.py b/test/audio/test_recording_set.py index 2e3249a6d..b45c8393a 100644 --- a/test/audio/test_recording_set.py +++ b/test/audio/test_recording_set.py @@ -14,6 +14,7 @@ ) from lhotse.audio import DurationMismatchError from lhotse.audio.mixer import AudioMixer +from lhotse.augmentation import ReverbWithImpulseResponse from lhotse.testing.dummies import DummyManifest from lhotse.utils import INT16MAX, fastcopy, is_module_available from lhotse.utils import nullcontext as does_not_raise @@ -632,3 +633,12 @@ def test_memory_recording_dict_serialization(): rec_reconstructed = Recording.from_dict(data) assert rec == rec_reconstructed np.testing.assert_equal(rec_reconstructed.load_audio(), rec.load_audio()) + + +def test_recording_to_dict_with_transform_dict(): + path = "test/fixtures/mono_c0.wav" + recording = Recording.from_file(path) + recording = recording.reverb_rir() + serialized = recording.to_dict() + recording_restored = Recording.from_dict(serialized) + assert recording == recording_restored From e3bed730e5b3430c3620bb19368e03ec650f55e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 24 Jun 2024 16:40:09 -0400 Subject: [PATCH 38/69] Utils for discovering attached data and dropping in-memory data (#1361) --- lhotse/array.py | 20 +++++- lhotse/audio/recording.py | 8 +++ lhotse/cut/base.py | 3 + lhotse/cut/data.py | 68 ++++++++++++++++++- lhotse/cut/mixed.py | 55 ++++++++++++++- lhotse/cut/padding.py | 18 ++++- lhotse/cut/set.py | 12 ++++ lhotse/features/base.py | 10 ++- lhotse/features/io.py | 4 ++ test/cut/test_cut_with_in_memory_data.py | 85 ++++++++++++++++++++++++ 10 files changed, 274 insertions(+), 9 deletions(-) diff --git a/lhotse/array.py b/lhotse/array.py index 52f633f4d..9849f0696 100644 --- a/lhotse/array.py +++ b/lhotse/array.py @@ -7,7 +7,7 @@ import numpy as np -from lhotse.utils import Pathlike, Seconds, fastcopy, ifnone +from lhotse.utils import Pathlike, Seconds, fastcopy @dataclass @@ -51,6 +51,16 @@ class Array: def ndim(self) -> int: return len(self.shape) + @property + def is_in_memory(self) -> bool: + from lhotse.features.io import is_in_memory + + return is_in_memory(self.storage_type) + + @property + def is_placeholder(self) -> bool: + return self.storage_type == "shar" + def to_dict(self) -> dict: return asdict(self) @@ -157,6 +167,14 @@ class TemporalArray: # the shape, temporal_dim, and frame_shift. start: Seconds + @property + def is_in_memory(self) -> bool: + return self.array.is_in_memory + + @property + def is_placeholder(self) -> bool: + return self.array.is_placeholder + @property def shape(self) -> List[int]: return self.array.shape diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index aaf55e15f..e555acb1e 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -155,6 +155,14 @@ def _video_source(self) -> Optional[AudioSource]: return s return None + @property + def is_in_memory(self) -> bool: + return any(s.type == "memory" for s in self.sources) + + @property + def is_placeholder(self) -> bool: + return any(s.type == "shar" for s in self.sources) + @property def num_channels(self) -> int: return len(self.channel_ids) diff --git a/lhotse/cut/base.py b/lhotse/cut/base.py index fa1400a4a..1e473cafc 100644 --- a/lhotse/cut/base.py +++ b/lhotse/cut/base.py @@ -177,6 +177,9 @@ class Cut: drop_features: Callable drop_recording: Callable drop_supervisions: Callable + drop_alignments: Callable + drop_in_memory_data: Callable + iter_data: Callable truncate: Callable pad: Callable extend_by: Callable diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index 7229353af..36637e895 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -3,12 +3,23 @@ from dataclasses import dataclass, field from decimal import ROUND_DOWN from math import isclose -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) import numpy as np import torch from intervaltree import IntervalTree +from lhotse.array import Array, TemporalArray from lhotse.audio import Recording, VideoInfo from lhotse.augmentation import AugmentFn from lhotse.custom import CustomFieldMixin @@ -81,6 +92,32 @@ def to_dict(self) -> dict: d["custom"][k] = v.to_dict() return {**d, "type": type(self).__name__} + def iter_data( + self, + ) -> Generator[ + Tuple[str, Union[Recording, Features, Array, TemporalArray]], None, None + ]: + """ + Iterate over each data piece attached to this cut. + Returns a generator yielding tuples of ``(key, manifest)``, where + ``key`` is the name of the attribute under which ``manifest`` is found. + ``manifest`` is of type :class:`~lhotse.Recording`, :class:`~lhotse.Features`, + :class:`~lhotse.TemporalArray`, or :class:`~lhotse.Array`. + + For example, if ``key`` is ``recording``, then ``manifest`` is ``self.recording``. + """ + if self.has_recording: + yield "recording", self.recording + if self.has_features: + yield "features", self.features + for k, v in (self.custom or {}).items(): + if isinstance(v, (Recording, Features, Array, TemporalArray)): + yield k, v + + @property + def is_in_memory(self) -> bool: + return any(v.is_in_memory for k, v in self.iter_data()) + @property def recording_id(self) -> str: return self.recording.id if self.has_recording else self.features.recording_id @@ -327,6 +364,35 @@ def drop_alignments(self) -> "DataCut": self, supervisions=[fastcopy(s, alignment={}) for s in self.supervisions] ) + def drop_in_memory_data(self) -> "DataCut": + """ + Return a copy of the current :class:`.DataCut`, detached from any in-memory data. + The manifests for in-memory data are converted into placeholders that can still be looked up for + metadata, but will fail on attempts to load the data. + """ + from lhotse.shar.utils import to_shar_placeholder + + custom = None + if self.custom is not None: + custom = self.custom.copy() + for k in custom: + v = custom[k] + if ( + isinstance(v, (Recording, Features, Array, TemporalArray)) + and v.is_in_memory + ): + custom[k] = to_shar_placeholder(v) + return fastcopy( + self, + recording=to_shar_placeholder(self.recording) + if self.has_recording and self.recording.is_in_memory + else self.recording, + features=to_shar_placeholder(self.features) + if self.has_features and self.features.is_in_memory + else self.features, + custom=custom, + ) + def fill_supervision( self, add_empty: bool = True, shrink_ok: bool = False ) -> "DataCut": diff --git a/lhotse/cut/mixed.py b/lhotse/cut/mixed.py index 23ee32c1e..e74888eee 100644 --- a/lhotse/cut/mixed.py +++ b/lhotse/cut/mixed.py @@ -4,12 +4,23 @@ from functools import partial, reduce from io import BytesIO from operator import add -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) import numpy as np import torch from intervaltree import IntervalTree +from lhotse.array import Array, TemporalArray from lhotse.audio import Recording, VideoInfo, get_audio_duration_mismatch_tolerance from lhotse.audio.backend import save_audio from lhotse.audio.mixer import AudioMixer, VideoMixer, audio_energy @@ -27,6 +38,7 @@ FeatureMixer, create_default_feature_extractor, ) +from lhotse.features.base import Features from lhotse.features.io import FeaturesWriter from lhotse.supervision import SupervisionSegment from lhotse.utils import ( @@ -153,6 +165,10 @@ def has_recording(self) -> bool: def has_video(self) -> bool: return self._first_non_padding_cut.has_video + @property + def is_in_memory(self) -> bool: + return any(track.cut.is_in_memory for track in self.tracks) + def has(self, field: str) -> bool: return self._first_non_padding_cut.has(field) @@ -191,6 +207,22 @@ def num_channels(self) -> Optional[int]: def features_type(self) -> Optional[str]: return self._first_non_padding_cut.features.type if self.has_features else None + def iter_data( + self, + ) -> Generator[ + Tuple[str, Union[Recording, Features, Array, TemporalArray]], None, None + ]: + """ + Iterate over each data piece attached to this cut. + Returns a generator yielding tuples of ``(key, manifest)``, where + ``key`` is the name of the attribute under which ``manifest`` is found. + ``manifest`` is of type :class:`~lhotse.Recording`, :class:`~lhotse.Features`, + :class:`~lhotse.TemporalArray`, or :class:`~lhotse.Array`. + + For example, if ``key`` is ``recording``, then ``manifest`` is ``self.recording``. + """ + return self._first_non_padding_cut.iter_data() + def __getattr__(self, name: str) -> Any: """ This magic function is called when the user tries to access an attribute @@ -1212,6 +1244,13 @@ def drop_alignments(self) -> "MixedCut": tracks=[fastcopy(t, cut=t.cut.drop_alignments()) for t in self.tracks], ) + def drop_in_memory_data(self) -> "MixedCut": + """Return a copy of the current :class:`MixedCut`, which doesn't contain any in-memory data.""" + return fastcopy( + self, + tracks=[fastcopy(t, cut=t.cut.drop_in_memory_data()) for t in self.tracks], + ) + def compute_and_store_features( self, extractor: FeatureExtractor, @@ -1540,9 +1579,19 @@ def with_recording_path_prefix(self, path: Pathlike) -> "MixedCut": ) @property - def _first_non_padding_cut(self) -> DataCut: + def first_non_padding_cut(self) -> DataCut: return self._first_non_padding_track.cut @property - def _first_non_padding_track(self) -> MixTrack: + def first_non_padding_track(self) -> MixTrack: return [t for t in self.tracks if not isinstance(t.cut, PaddingCut)][0] + + # Note: the private properties below are kept for backward compatibility. + + @property + def _first_non_padding_cut(self) -> DataCut: + return self.first_non_padding_cut + + @property + def _first_non_padding_track(self) -> MixTrack: + return self.first_non_padding_track diff --git a/lhotse/cut/padding.py b/lhotse/cut/padding.py index 7e569bcea..c535bde2b 100644 --- a/lhotse/cut/padding.py +++ b/lhotse/cut/padding.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -85,6 +85,10 @@ def has_video(self) -> bool: def num_channels(self) -> int: return 1 + @property + def is_in_memory(self) -> bool: + return False + def has(self, field: str) -> bool: if field == "recording": return self.has_recording @@ -99,6 +103,10 @@ def has(self, field: str) -> bool: def recording_id(self) -> str: return "PAD" + def iter_data(self) -> Iterable: + """Empty iterable.""" + return () + # noinspection PyUnusedLocal def load_features(self, *args, **kwargs) -> Optional[np.ndarray]: if self.has_features: @@ -421,11 +429,15 @@ def drop_recording(self) -> "PaddingCut": return fastcopy(self, num_samples=None) def drop_supervisions(self) -> "PaddingCut": - """Return a copy of the current :class:`.PaddingCut`, detached from ``supervisions``.""" + """No-op""" return self def drop_alignments(self) -> "PaddingCut": - """Return a copy of the current :class:`.PaddingCut`, detached from ``alignments``.""" + """No-op""" + return self + + def drop_in_memory_data(self) -> "PaddingCut": + """No-op.""" return self def compute_and_store_features( diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index e0a2f0e16..60836db72 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -1743,6 +1743,14 @@ def drop_alignments(self) -> "CutSet": """ return self.map(_drop_alignments) + def drop_in_memory_data(self) -> "CutSet": + """ + Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from any in-memory data it held. + The manifests for in-memory data are converted into placeholders that can still be looked up for + metadata, but will fail on attempts to load the data. + """ + return self.map(_drop_in_memory_data) + def compute_and_store_features( self, extractor: FeatureExtractor, @@ -3355,6 +3363,10 @@ def _drop_supervisions(cut, *args, **kwargs): return cut.drop_supervisions(*args, **kwargs) +def _drop_in_memory_data(cut, *args, **kwargs): + return cut.drop_in_memory_data(*args, **kwargs) + + def _truncate_single( cut: Cut, max_duration: Seconds, diff --git a/lhotse/features/base.py b/lhotse/features/base.py index c8cba931a..698f93b3d 100644 --- a/lhotse/features/base.py +++ b/lhotse/features/base.py @@ -16,7 +16,7 @@ from lhotse.audio import Recording from lhotse.augmentation import AugmentFn -from lhotse.features.io import FeaturesWriter, get_reader +from lhotse.features.io import FeaturesWriter, get_reader, is_in_memory from lhotse.lazy import AlgorithmMixin from lhotse.serialization import LazyMixin, Serializable, load_yaml, save_to_yaml from lhotse.utils import ( @@ -458,6 +458,14 @@ class Features: def end(self) -> Seconds: return self.start + self.duration + @property + def is_in_memory(self) -> bool: + return is_in_memory(self.storage_type) + + @property + def is_placeholder(self) -> bool: + return self.storage_type == "shar" + def load( self, start: Optional[Seconds] = None, diff --git a/lhotse/features/io.py b/lhotse/features/io.py index 4a7b812ab..a76fe6d60 100644 --- a/lhotse/features/io.py +++ b/lhotse/features/io.py @@ -1108,6 +1108,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): """ +def is_in_memory(storage_type: str) -> bool: + return "memory" in storage_type + + def get_memory_writer(name: str): assert "memory" in name return get_writer(name) diff --git a/test/cut/test_cut_with_in_memory_data.py b/test/cut/test_cut_with_in_memory_data.py index 5ae1db6b8..b56fabde4 100644 --- a/test/cut/test_cut_with_in_memory_data.py +++ b/test/cut/test_cut_with_in_memory_data.py @@ -2,12 +2,14 @@ from tempfile import TemporaryDirectory import numpy as np +import pytest from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording from lhotse.array import Array from lhotse.cut import MixedCut, PaddingCut from lhotse.testing.dummies import dummy_cut from lhotse.utils import compute_num_frames +from lhotse.utils import nullcontext as does_not_raise def test_features_move_to_memory(): @@ -233,3 +235,86 @@ def test_mixed_cut_to_mono_with_custom(): audio = cut.load_audio() audio_mem = cut_mem.load_audio() np.testing.assert_almost_equal(audio, audio_mem, decimal=1) + + +def test_drop_in_memory_data(): + cut = dummy_cut(0, with_data=True) + + # Assertions about test data (not the actual test) + assert cut.is_in_memory + expected_keys = { + "recording", + "features", + "custom_recording", + "custom_features", + "custom_indexes", + "custom_embedding", + } + observed_keys = set() + for k, v in cut.iter_data(): + observed_keys.add(k) + if k == "features": + assert not v.is_in_memory + else: + assert v.is_in_memory + assert expected_keys == observed_keys + + # The actual test + cut_nomem = cut.drop_in_memory_data() + assert not cut_nomem.is_in_memory + observed_keys = set() + for k, v in cut_nomem.iter_data(): + observed_keys.add(k) + assert not v.is_in_memory + if k == "recording": + with pytest.raises(Exception): + cut_nomem.load_audio() + elif k == "features": + with does_not_raise(): + cut_nomem.load_features() + else: + with pytest.raises(Exception): + cut_nomem.load_custom(k) + assert expected_keys == observed_keys + + +def test_drop_in_memory_data_mixed(): + cut = dummy_cut(0, with_data=True) + cut = cut.pad(duration=cut.duration + 2.0) + + # Assertions about test data (not the actual test) + assert cut.is_in_memory + expected_keys = { + "recording", + "features", + "custom_recording", + "custom_features", + "custom_indexes", + "custom_embedding", + } + observed_keys = set() + for k, v in cut.iter_data(): + observed_keys.add(k) + if k == "features": + assert not v.is_in_memory + else: + assert v.is_in_memory + assert expected_keys == observed_keys + + # The actual test + cut_nomem = cut.drop_in_memory_data() + assert not cut_nomem.is_in_memory + observed_keys = set() + for k, v in cut_nomem.iter_data(): + observed_keys.add(k) + assert not v.is_in_memory + if k == "recording": + with pytest.raises(Exception): + cut_nomem.load_audio() + elif k == "features": + with does_not_raise(): + cut_nomem.load_features() + else: + with pytest.raises(Exception): + cut_nomem.load_custom(k) + assert expected_keys == observed_keys From da4d70d7affc477eb8dc3a51f9b13d387817059a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 25 Jun 2024 11:54:29 -0400 Subject: [PATCH 39/69] Numpy 2.0 compatibility (#1362) * Remove numpy version limit * Bump CI torch version in some tests * Upgrade torch version in most CI tests, fix one failing test with numpy 2.0 * remove kaldifeat from python3.9 CI due to compilation issues --- .github/workflows/missing_torchaudio.yml | 6 +++--- .github/workflows/unit_tests.yml | 12 ++++++------ test/cut/test_padding_cut.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/missing_torchaudio.yml b/.github/workflows/missing_torchaudio.yml index 1f51cfc5d..d665c3b66 100644 --- a/.github/workflows/missing_torchaudio.yml +++ b/.github/workflows/missing_torchaudio.yml @@ -16,8 +16,8 @@ jobs: strategy: matrix: include: - - python-version: "3.11" - torch-install-cmd: "pip install torch==2.0 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.12" + torch-install-cmd: "pip install torch==2.3 --index-url https://download.pytorch.org/whl/cpu" fail-fast: false @@ -38,7 +38,7 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install wheel 'numpy<2' scipy + pip install wheel numpy scipy # Force the installation of a CPU-only PyTorch ${{ matrix.torch-install-cmd }} # the torchaudio env var does nothing when torchaudio is installed, but doesn't require it's presence when it's not diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 157ea9635..33310c313 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,16 +20,16 @@ jobs: torch-install-cmd: "pip install torch==1.12.1 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: kaldifeat - python-version: "3.9" - torch-install-cmd: "pip install torch==2.0 torchaudio==2.0 --extra-index-url https://download.pytorch.org/whl/cpu" - extra_deps: kaldifeat + torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + extra_deps: "" - python-version: "3.10" - torch-install-cmd: "pip install torch==2.1 torchaudio==2.1 --extra-index-url https://download.pytorch.org/whl/cpu" + torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - python-version: "3.11" - torch-install-cmd: "pip install torch==2.2 torchaudio==2.2 --extra-index-url https://download.pytorch.org/whl/cpu" + torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - python-version: "3.12" - torch-install-cmd: "pip install torch==2.2 torchaudio==2.2 --extra-index-url https://download.pytorch.org/whl/cpu" + torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" fail-fast: false @@ -51,7 +51,7 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - pip install wheel 'numpy<2' + pip install wheel numpy # Force the installation of a CPU-only PyTorch ${{ matrix.torch-install-cmd }} # the torchaudio env var does nothing when torchaudio is installed, but doesn't require it's presence when it's not diff --git a/test/cut/test_padding_cut.py b/test/cut/test_padding_cut.py index 55623d92e..a20ed5eb0 100644 --- a/test/cut/test_padding_cut.py +++ b/test/cut/test_padding_cut.py @@ -33,7 +33,7 @@ def test_load_features_log(padding_cut, expected_value): feats = padding_cut.load_features() assert feats.shape[0] == 1000 assert feats.shape[1] == 40 - np.testing.assert_almost_equal(feats, expected_value) + np.testing.assert_almost_equal(feats, expected_value, decimal=6) def test_frame_shift(padding_cut): From 0a4aed49754d61b781c14de85f7772dda71c6226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 15 Jul 2024 12:43:54 -0400 Subject: [PATCH 40/69] Fix MixedCut transforms serialization (#1370) * Fix MixedCut transforms serialization * fix --- lhotse/cut/mixed.py | 35 +++++++++++++++++++++++--- test/cut/test_cut_augmentation.py | 41 ++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/lhotse/cut/mixed.py b/lhotse/cut/mixed.py index e74888eee..19439d84a 100644 --- a/lhotse/cut/mixed.py +++ b/lhotse/cut/mixed.py @@ -84,6 +84,16 @@ def from_dict(data: dict): cut_dict["type"] = data.pop("type") return MixTrack(deserialize_cut(cut_dict), **data) + def to_dict(self) -> Dict: + ans = { + "cut": self.cut.to_dict(), + "type": self.type, + "offset": self.offset, + } + if self.snr is not None: + ans["snr"] = self.snr + return ans + @dataclass class MixedCut(Cut): @@ -125,7 +135,7 @@ class MixedCut(Cut): id: str tracks: List[MixTrack] - transforms: Optional[List[Dict]] = None + transforms: Optional[List[AudioTransform]] = None @property def supervisions(self) -> List[SupervisionSegment]: @@ -207,6 +217,16 @@ def num_channels(self) -> Optional[int]: def features_type(self) -> Optional[str]: return self._first_non_padding_cut.features.type if self.has_features else None + def to_dict(self) -> dict: + ans = { + "id": self.id, + "tracks": [t.to_dict() for t in self.tracks], + "type": type(self).__name__, + } + if self.transforms: + ans["transforms"] = [t.to_dict() for t in self.transforms] + return ans + def iter_data( self, ) -> Generator[ @@ -793,7 +813,7 @@ def normalize_loudness( if mix_first: transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(LoudnessNormalization(target=target).to_dict()) + transforms.append(LoudnessNormalization(target=target)) return fastcopy( self, id=f"{self.id}_ln{target}" if affix_id else self.id, @@ -908,7 +928,7 @@ def reverb_rir( early_only=early_only, rir_channels=rir_channels if rir_channels is not None else [0], rir_generator=rir_generator, - ).to_dict() + ) ) return fastcopy( self, @@ -1133,7 +1153,10 @@ def load_audio( # We'll apply the transforms now (if any). transforms = [ - AudioTransform.from_dict(params) for params in self.transforms or [] + tnfm + if isinstance(tnfm, AudioTransform) + else AudioTransform.from_dict(tnfm) + for tnfm in self.transforms or [] ] for tfn in transforms: audio = tfn(audio, self.sampling_rate) @@ -1551,9 +1574,13 @@ def filter_supervisions( def from_dict(data: dict) -> "MixedCut": if "type" in data: data.pop("type") + transforms = None + if "transforms" in data: + transforms = [AudioTransform.from_dict(t) for t in data["transforms"]] return MixedCut( id=data["id"], tracks=[MixTrack.from_dict(track) for track in data["tracks"]], + transforms=transforms, ) def with_features_path_prefix(self, path: Pathlike) -> "MixedCut": diff --git a/test/cut/test_cut_augmentation.py b/test/cut/test_cut_augmentation.py index 9219cebbc..e4be8247c 100644 --- a/test/cut/test_cut_augmentation.py +++ b/test/cut/test_cut_augmentation.py @@ -7,7 +7,7 @@ from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment from lhotse.audio import RecordingSet -from lhotse.cut import PaddingCut +from lhotse.cut import Cut, MixedCut, PaddingCut from lhotse.testing.dummies import dummy_cut, dummy_multi_cut from lhotse.utils import fastcopy, is_module_available, nullcontext @@ -424,6 +424,27 @@ def test_mixed_cut_start01_reverb_rir_mix_first(cut_with_supervision_start01, ri ) +def test_mixed_cut_start01_reverb_rir_mix_first_deserialized( + cut_with_supervision_start01, rir +): + mixed_rvb_orig = cut_with_supervision_start01.pad(duration=0.5).reverb_rir( + rir_recording=rir, mix_first=True + ) + mixed_rvb = MixedCut.from_dict(mixed_rvb_orig.to_dict()) + assert mixed_rvb.start == 0 # MixedCut always starts at 0 + assert mixed_rvb.duration == 0.5 + assert mixed_rvb.end == 0.5 + assert mixed_rvb.num_samples == 4000 + + # Check that the padding part should not be all zeros afte + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_almost_equal, + mixed_rvb.load_audio()[:, 3200:], + np.zeros((1, 800)), + ) + + def test_mixed_cut_start01_reverb_rir_with_fast_random( cut_with_supervision_start01, rir ): @@ -498,6 +519,24 @@ def test_mixed_cut_normalize_loudness(cut_with_supervision_start01, target, mix_ assert loudness == pytest.approx(target, abs=0.5) +@pytest.mark.skipif( + not is_module_available("pyloudnorm"), + reason="This test requires pyloudnorm to be installed.", +) +def test_mixed_cut_normalize_loudness_deserialized(cut_with_supervision_start01): + target = -15.0 + mixed_cut = cut_with_supervision_start01.append(cut_with_supervision_start01) + mixed_cut_ln_orig = mixed_cut.normalize_loudness(target, mix_first=True) + mixed_cut_ln = MixedCut.from_dict(mixed_cut_ln_orig.to_dict()) + + import pyloudnorm as pyln + + # check if loudness is correct + meter = pyln.Meter(mixed_cut_ln.sampling_rate) # create BS.1770 meter + loudness = meter.integrated_loudness(mixed_cut_ln.load_audio().T) + assert loudness == pytest.approx(target, abs=0.5) + + @pytest.mark.skipif( not is_module_available("nara_wpe"), reason="This test requires nara_wpe to be installed.", From c286f28fe5a883ccc64fffb087b46ce124a8efe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 18 Jul 2024 18:35:00 -0400 Subject: [PATCH 41/69] Support for pre-determined batch sizes in DynamicBucketingSampler (#1372) --- lhotse/dataset/sampling/dynamic_bucketing.py | 117 ++++++++++++++++-- .../sampling/test_dynamic_bucketing.py | 58 +++++++++ 2 files changed, 168 insertions(+), 7 deletions(-) diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index cf4da23f2..9b9a41f1c 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -2,7 +2,7 @@ import warnings from bisect import bisect_right from collections import deque -from dataclasses import dataclass +from dataclasses import asdict, dataclass from itertools import islice from typing import ( Any, @@ -104,7 +104,7 @@ def __init__( Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet. :param max_cuts: The maximum total number of ``cuts`` per batch. When only ``max_duration`` is specified, this sampler yields static batch sizes. - :param num_buckets: how many buckets to create. + :param num_buckets: how many buckets to create. Ignored if duration_bins are provided. :param shuffle: When ``True``, the cuts will be shuffled dynamically with a reservoir-sampling-based algorithm. Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.: @@ -169,15 +169,11 @@ def __init__( self.buffer_size += shuffle_buffer_size if duration_bins is not None: - if num_buckets is not None: - assert len(duration_bins) == num_buckets - 1, ( - f"num_buckets=={num_buckets} but len(duration_bins)=={len(duration_bins)} " - f"(bins are the boundaries, it should be one less than the number of buckets)." - ) assert list(duration_bins) == sorted( duration_bins ), "Duration bins must be sorted ascendingly." self.duration_bins = duration_bins + self.num_buckets = len(duration_bins) + 1 else: if constraint is None: constraint = TimeConstraint( @@ -316,6 +312,113 @@ def num_cuts(self) -> Optional[int]: return None +@dataclass +class FixedBucketBatchSizeConstraint(SamplingConstraint): + """ + Sampling constraint that accepts a pre-defined batch size for each bucket. + It uses the example's sequence length to determine which bucket we're sampling for, + and otherwise the batch size is locally static for each bucket. + + This constraint doesn't support samples longer than the upper bound of the last bucket; + if such sample is provided, we will raise an exception. + """ + + max_seq_len_buckets: List[float] + batch_sizes: List[int] + current_bucket: Union[int, None] = None + num_cuts: int = 0 + + def __post_init__(self): + assert sorted(self.max_seq_len_buckets) == list(self.max_seq_len_buckets) + + def is_active(self) -> bool: + return True + + def add(self, example: Cut) -> None: + """ + Increment the internal counter for the time constraint, + selecting the right property from the input ``cut`` object. + """ + seqlen = self.measure_length(example) + bucket_idx = bisect_right(self.max_seq_len_buckets, seqlen) + assert bucket_idx < len(self.max_seq_len_buckets), ( + f"Received example with sequence length {seqlen} that exceeds " + f"the highest allowed length {self.max_seq_len_buckets[-1]}." + ) + if self.current_bucket is None: + self.current_bucket = bucket_idx + else: + assert self.current_bucket == bucket_idx, ( + f"User error: FixedBucketBatchSizeConstraint is supposed to be used only on one bucket. " + f"The example we received has sequence length {seqlen} which is outside of the allowed bounds " + f"for bucket index {bucket_idx} in buckets {self.max_seq_len_buckets}." + ) + self.num_cuts += 1 + + def exceeded(self) -> bool: + """Is the constraint exceeded or not.""" + return self.num_cuts > self.batch_sizes[self.current_bucket] + + def close_to_exceeding(self) -> bool: + """ + Check if the batch is close to satisfying the constraints. + We define "closeness" as: if we added one more cut that has + duration/num_frames/num_samples equal to the longest seen cut + in the current batch, then the batch would have exceeded the constraints. + """ + return self.num_cuts >= self.batch_sizes[self.current_bucket] + + def reset(self) -> None: + """ + Reset the internal counter (to be used after a batch was created, + to start collecting a new one). + """ + self.current_bucket = None + self.num_cuts = 0 + + def measure_length(self, example: Cut) -> float: + return example.duration + + def state_dict(self) -> Dict[str, Any]: + return asdict(self) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.max_seq_len_buckets = state_dict.pop("max_seq_len_buckets") + self.batch_sizes = state_dict.pop("batch_sizes") + self.current_bucket = state_dict.pop("current_bucket") + self.num_cuts = state_dict.pop("num_cuts") + assert len(state_dict) == 0, ( + "Error in FixedBucketBatchSizeConstraint.load_state_dict(): Unexpected keys:\n- " + + "\n- ".join(state_dict.keys()) + ) + + def __add__( + self, other: "FixedBucketBatchSizeConstraint" + ) -> "FixedBucketBatchSizeConstraint": + for key in ("max_seq_len_buckets", "batch_sizes", "current_bucket"): + self_attr = getattr(self, key) + other_attr = getattr(other, key) + is_none = self_attr is None and other_attr is None + assert is_none or self_attr == other_attr, ( + f"To add two TimeConstraint objects, they need to represent the same constraint " + f"(got self.{key}={self_attr} != other.{key}={other_attr})." + ) + return FixedBucketBatchSizeConstraint( + max_seq_len_buckets=self.max_seq_len_buckets, + batch_sizes=self.batch_sizes, + current_bucket=self.current_bucket, + num_cuts=self.num_cuts + other.num_cuts, + ) + + def __eq__(self, other: "TimeConstraint") -> bool: + return ( + isinstance(other, FixedBucketBatchSizeConstraint) + and self.max_seq_len_buckets == other.max_seq_len_buckets + and self.batch_sizes == other.batch_sizes + and self.current_bucket == other.current_bucket + ) + + def estimate_duration_buckets( cuts: Iterable[Cut], num_buckets: int, diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index bb3253b30..454bff58f 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -7,6 +7,7 @@ from lhotse.dataset.sampling.dynamic_bucketing import ( DynamicBucketer, DynamicBucketingSampler, + FixedBucketBatchSizeConstraint, estimate_duration_buckets, ) from lhotse.testing.dummies import DummyManifest, dummy_cut @@ -670,3 +671,60 @@ def test_dynamic_bucketing_sampler_sync_buckets_map_dataset_usage(sync_buckets): # some shapes will be mismatched because different buckets were selected. matching_shapes = [len(b0) == len(b1) for b0, b1 in zip(batches0, batches1)] assert not all(matching_shapes) + + +def test_dynamic_bucketing_sampler_fixed_batch_constraint(): + cuts = DummyManifest(CutSet, begin_id=0, end_id=10) + for i, c in enumerate(cuts): + if i < 5: + c.duration = 1 + else: + c.duration = 2 + + duration_bins = [1.5, 2.5] + sampler = DynamicBucketingSampler( + cuts, + duration_bins=duration_bins, + constraint=FixedBucketBatchSizeConstraint( + max_seq_len_buckets=duration_bins, batch_sizes=[2, 1] + ), + seed=0, + shuffle=True, + ) + + batches = [b for b in sampler] + sampled_cuts = [c for b in batches for c in b] + + # Invariant: no duplicated cut IDs + assert len(set(c.id for b in batches for c in b)) == len(sampled_cuts) + + # Same number of sampled and source cuts. + assert len(sampled_cuts) == len(cuts) + + # We sampled the follwoing batches with this RNG: + assert len(batches) == 8 + print([len(b) for b in batches]) + + assert len(batches[0]) == 1 + assert sum(c.duration for c in batches[0]) == 2 + + assert len(batches[1]) == 2 + assert sum(c.duration for c in batches[1]) == 2 + + assert len(batches[2]) == 2 + assert sum(c.duration for c in batches[2]) == 2 + + assert len(batches[3]) == 1 + assert sum(c.duration for c in batches[3]) == 2 + + assert len(batches[4]) == 1 + assert sum(c.duration for c in batches[4]) == 2 + + assert len(batches[5]) == 1 + assert sum(c.duration for c in batches[5]) == 2 + + assert len(batches[6]) == 1 + assert sum(c.duration for c in batches[6]) == 2 + + assert len(batches[7]) == 1 + assert sum(c.duration for c in batches[7]) == 1 From 18436e9a3461d611a0b738a659fa9dbe81d0ffa5 Mon Sep 17 00:00:00 2001 From: Peter Ross Date: Fri, 19 Jul 2024 09:38:31 +1000 Subject: [PATCH 42/69] augmentation/torchaudio: add Phone effect (mulaw, lpc10 codecs) (#1348) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * augmentation/torchaudio: add Phone effect (mulaw, lpc10 codecs) * restore_orig_sr option --------- Co-authored-by: Piotr Żelasko --- lhotse/audio/recording.py | 31 +++++ lhotse/augmentation/torchaudio.py | 168 +++++++++++++++++++++++++++ lhotse/cut/base.py | 1 + lhotse/cut/data.py | 39 +++++++ lhotse/cut/set.py | 21 ++++ lhotse/supervision.py | 18 +++ test/augmentation/test_torchaudio.py | 9 ++ 7 files changed, 287 insertions(+) diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index e555acb1e..063084af7 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -20,6 +20,7 @@ AudioTransform, DereverbWPE, LoudnessNormalization, + Narrowband, Resample, ReverbWithImpulseResponse, Speed, @@ -732,6 +733,36 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording": transforms=transforms, ) + def narrowband( + self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True + ) -> "Recording": + """ + Return a new ``Recording`` that will lazily apply narrowband effect while loading audio. + by affixing it with "_nb_{codec}". + + :return: a modified copy of the current ``Recording``. + """ + transforms = self.transforms.copy() if self.transforms is not None else [] + transforms.append( + Narrowband( + codec=codec, + source_sampling_rate=self.sampling_rate, + restore_orig_sr=restore_orig_sr, + ).to_dict() + ) + new_num_samples = compute_num_samples( + self.duration, + self.sampling_rate if restore_orig_sr else 8000, + rounding=ROUND_HALF_UP, + ) + return fastcopy( + self, + id=f"{self.id}_nb_{codec}" if affix_id else self.id, + num_samples=new_num_samples, + sampling_rate=self.sampling_rate if restore_orig_sr else 8000, + transforms=transforms, + ) + def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recording": """ Return a new ``Recording`` that will lazily apply WPE dereverberation. diff --git a/lhotse/augmentation/torchaudio.py b/lhotse/augmentation/torchaudio.py index 910fb0a3f..c6306aefd 100644 --- a/lhotse/augmentation/torchaudio.py +++ b/lhotse/augmentation/torchaudio.py @@ -212,6 +212,174 @@ def reverse_timestamps( ) +class Codec: + def __call__(self, samples: np.ndarray) -> np.ndarray: + """ + Apply encoder then decoder. + + To be implemented in derived classes. + """ + raise NotImplementedError + + +class MuLawCodec(Codec): + def __init__(self): + import torchaudio + + self.encoder = torchaudio.transforms.MuLawEncoding() + self.decoder = torchaudio.transforms.MuLawDecoding() + + def __call__(self, samples): + return self.decoder(self.encoder(samples)) + + +from ctypes import CDLL, POINTER, c_int, c_short, c_uint8, c_void_p + +LPC10_FRAME_SAMPLES = 180 +LPC10_FRAME_BYTES = 7 + + +def libspandsp_api(): + try: + api = CDLL("libspandsp.so") + except OSError as e: + raise RuntimeError( + "We cannot apply the narrowband transformation using the LPC10 codec as the SpanDSP library cannot be found. " + "To install use `apt-get install libspandsp-dev` or visit ." + ) + + api.lpc10_encode_init.restype = c_void_p + api.lpc10_encode_init.argtypes = [c_void_p, c_int] + + api.lpc10_encode.restype = c_int + api.lpc10_encode.argtypes = [c_void_p, POINTER(c_uint8), POINTER(c_short), c_int] + + api.lpc10_encode_free.argtypes = [c_void_p] + + api.lpc10_decode_init.restype = c_void_p + api.lpc10_decode_init.argtypes = [c_void_p, c_int] + + api.lpc10_decode.restype = c_int + api.lpc10_decode.argtypes = [c_void_p, POINTER(c_short), POINTER(c_uint8), c_int] + + api.lpc10_decode_free.argtypes = [c_void_p] + + return api + + +class LPC10Codec(Codec): + def __init__(self): + self.api = libspandsp_api() + self.c_data = (c_uint8 * LPC10_FRAME_BYTES)() + self.c_samples = (c_short * LPC10_FRAME_SAMPLES)() + + def __call__(self, samples): + encoder = self.api.lpc10_encode_init(None, 0) + decoder = self.api.lpc10_decode_init(None, 0) + + frames = samples[0].split(LPC10_FRAME_SAMPLES) + + idx = 0 + out = torch.zeros([1, len(frames) * LPC10_FRAME_SAMPLES]) + + for frame in frames: + + samples_int = (frame * 32768).to(torch.int16) + + for i in range(0, samples_int.shape[0]): + self.c_samples[i] = samples_int[i] + + for i in range(samples_int.shape[0], LPC10_FRAME_SAMPLES): + self.c_samples[i] = 0 + + assert ( + self.api.lpc10_encode( + encoder, self.c_data, self.c_samples, len(self.c_samples) + ) + == LPC10_FRAME_BYTES + ) + assert ( + self.api.lpc10_decode( + decoder, self.c_samples, self.c_data, LPC10_FRAME_BYTES + ) + == LPC10_FRAME_SAMPLES + ) + + for i in range(0, LPC10_FRAME_SAMPLES): + out[0][idx] = self.c_samples[i] + idx = idx + 1 + + self.api.lpc10_encode_free(encoder) + self.api.lpc10_decode_free(decoder) + + return out / 32768 + + +CODECS = { + "lpc10": LPC10Codec, + "mulaw": MuLawCodec, +} + + +@dataclass +class Narrowband(AudioTransform): + """ + Narrowband effect. + + Resample input audio to 8000 Hz, apply codec (encode then immediately decode), then (optionally) resample back to the original sampling rate. + """ + + codec: str + source_sampling_rate: int + restore_orig_sr: bool + + def __post_init__(self): + check_torchaudio_version() + import torchaudio + + if self.codec in CODECS: + self.codec_instance = CODECS[self.codec]() + else: + raise ValueError(f"unsupported codec: {self.codec}") + + def __call__(self, samples: np.ndarray, sampling_rate: int) -> np.ndarray: + import torchaudio + + orig_size = samples.size + + samples = torch.from_numpy(samples) + + if self.source_sampling_rate != 8000: + resampler_down = get_or_create_resampler(self.source_sampling_rate, 8000) + samples = resampler_down(samples) + + samples = self.codec_instance(samples) + + if self.restore_orig_sr and self.source_sampling_rate != 8000: + resampler_up = get_or_create_resampler(8000, self.source_sampling_rate) + samples = resampler_up(samples) + + samples = samples.numpy() + + if self.restore_orig_sr and orig_size != samples.size: + samples = np.resize(samples, (1, orig_size)) + + return samples + + def reverse_timestamps( + self, + offset: Seconds, + duration: Optional[Seconds], + sampling_rate: Optional[int], + ) -> Tuple[Seconds, Optional[Seconds]]: + """ + This method just returnes the original offset and duration as the narrowband effect + doesn't change any these audio properies. + """ + + return offset, duration + + @dataclass class Volume(AudioTransform): """ diff --git a/lhotse/cut/base.py b/lhotse/cut/base.py index 1e473cafc..268db2b7a 100644 --- a/lhotse/cut/base.py +++ b/lhotse/cut/base.py @@ -187,6 +187,7 @@ class Cut: perturb_speed: Callable perturb_tempo: Callable perturb_volume: Callable + phone: Callable reverb_rir: Callable map_supervisions: Callable merge_supervisions: Callable diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index 36637e895..ad47ca381 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -918,6 +918,45 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "DataCut": supervisions=supervisions_vp, ) + def narrowband( + self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True + ) -> "DataCut": + """ + Return a new ``DataCut`` that will lazily apply narrowband effect. + + :param codec: Codec name. + :param restore_orig_sr: Restore original sampling rate. + :param affix_id: When true, we will modify the ``DataCut.id`` field + by affixing it with "_nb_{codec}". + :return: a modified copy of the current ``DataCut``. + """ + # Pre-conditions + assert ( + self.has_recording + ), "Cannot apply narrowband effect on a DataCut without Recording." + if self.has_features: + logging.warning( + "Attempting to apply narrowband effect on a DataCut that references pre-computed features. " + "The feature manifest will be detached, as we do not support feature-domain " + "volume perturbation." + ) + self.features = None + # Actual audio perturbation. + recording_nb = self.recording.narrowband( + codec=codec, restore_orig_sr=restore_orig_sr, affix_id=affix_id + ) + # Match the supervision's id (and it's underlying recording id). + supervisions_nb = [ + s.narrowband(codec=codec, affix_id=affix_id) for s in self.supervisions + ] + + return fastcopy( + self, + id=f"{self.id}_nb_{codec}" if affix_id else self.id, + recording=recording_nb, + supervisions=supervisions_nb, + ) + def normalize_loudness( self, target: float, affix_id: bool = False, **kwargs ) -> "DataCut": diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 60836db72..08fc830b8 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -1592,6 +1592,27 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet": """ return self.map(partial(_perturb_volume, factor=factor, affix_id=affix_id)) + def narrowband( + self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True + ) -> "CutSet": + """ + Return a new :class:`~lhotse.cut.CutSet` that contains narrowband effect cuts. + It requires the recording manifests to be present. + If the feature manifests are attached, they are dropped. + The supervision manifests are remaining the same. + + :param codec: Codec name. + :param restore_orig_sr: Restore original sampling rate. + :param affix_id: Should we modify the ID (useful if both versions of the same + cut are going to be present in a single manifest). + :return: a modified copy of the ``CutSet``. + """ + return self.map( + lambda cut: cut.narrowband( + codec=codec, restore_orig_sr=restore_orig_sr, affix_id=affix_id + ) + ) + def normalize_loudness( self, target: float, mix_first: bool = True, affix_id: bool = True ) -> "CutSet": diff --git a/lhotse/supervision.py b/lhotse/supervision.py index 5208593ab..631348cbf 100644 --- a/lhotse/supervision.py +++ b/lhotse/supervision.py @@ -331,6 +331,24 @@ def perturb_volume( else self.recording_id, ) + def narrowband(self, codec: str, affix_id: bool = True) -> "SupervisionSegment": + """ + Return a ``SupervisionSegment`` with modified ids. + + :param codec: Codec name. + :param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields + by affixing it with "_nb_{codec}". + :return: a modified copy of the current ``SupervisionSegment``. + """ + + return fastcopy( + self, + id=f"{self.id}_nb_{codec}" if affix_id else self.id, + recording_id=f"{self.recording_id}_nb_{codec}" + if affix_id + else self.recording_id, + ) + def reverb_rir( self, affix_id: bool = True, channel: Optional[Union[int, List[int]]] = None ) -> "SupervisionSegment": diff --git a/test/augmentation/test_torchaudio.py b/test/augmentation/test_torchaudio.py index 89d109f7b..6b6a49a24 100644 --- a/test/augmentation/test_torchaudio.py +++ b/test/augmentation/test_torchaudio.py @@ -12,6 +12,7 @@ from lhotse import MonoCut, Recording, Seconds from lhotse.augmentation import ( AudioTransform, + Narrowband, Resample, ReverbWithImpulseResponse, Speed, @@ -266,3 +267,11 @@ def test_augmentation_chain_randomized( recording=recording_aug, ) assert cut_aug.load_audio().shape[1] == cut_aug.num_samples + + +def test_narrowband(mono_audio): + narrowband = Narrowband( + codec="mulaw", source_sampling_rate=SAMPLING_RATE, restore_orig_sr=True + ) + nb = narrowband(mono_audio, SAMPLING_RATE) + assert nb.shape == mono_audio.shape From 6a177213b1ad7e1d6b8233aa3db0d86cba3cdff4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 18 Jul 2024 19:41:03 -0400 Subject: [PATCH 43/69] bump dev version to 1.26.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index ad2191947..5ff8c4f5d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.25.0 +1.26.0 From fa8cbfe6ede9633fb065433122026c7568708d49 Mon Sep 17 00:00:00 2001 From: Sofian Mejjoute <77058942+Ryu1845@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:39:50 +0200 Subject: [PATCH 44/69] Add EARS recipe (#1375) * Add EARS recipe * Add download and fix cli for the EARS dataset * Fix formatting for EARS recipe --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/ears.py | 41 +++++ lhotse/recipes/__init__.py | 3 + lhotse/recipes/ears.py | 226 +++++++++++++++++++++++++++ 5 files changed, 273 insertions(+) create mode 100644 lhotse/bin/modes/recipes/ears.py create mode 100644 lhotse/recipes/ears.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 79299eac1..89deeb7ab 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -99,6 +99,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_earnings21` * - Earnings'22 - :func:`lhotse.recipes.prepare_earnings22` + * - EARS + - :func:`lhotse.recipes.prepare_ears` * - The Edinburgh International Accents of English Corpus - :func:`lhotse.recipes.prepare_edacc` * - English Broadcast News 1997 diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index aafc871e3..174b1036a 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -29,6 +29,7 @@ from .dipco import * from .earnings21 import * from .earnings22 import * +from .ears import * from .edacc import * from .eval2000 import * from .fisher_english import * diff --git a/lhotse/bin/modes/recipes/ears.py b/lhotse/bin/modes/recipes/ears.py new file mode 100644 index 000000000..6cdefaa68 --- /dev/null +++ b/lhotse/bin/modes/recipes/ears.py @@ -0,0 +1,41 @@ +from typing import Optional, Sequence, Union + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.ears import download_ears, prepare_ears +from lhotse.utils import Pathlike + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +def ears( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + num_jobs: int = 1, +): + """EARS data preparation.""" + prepare_ears( + corpus_dir=corpus_dir, + output_dir=output_dir, + num_jobs=num_jobs, + ) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +def ears( + target_dir: Pathlike, +): + """EARS data download.""" + download_ears( + target_dir=target_dir, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 2b5ec8338..9e06dfbb9 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -26,6 +26,7 @@ from .dipco import download_dipco, prepare_dipco from .earnings21 import download_earnings21, prepare_earnings21 from .earnings22 import download_earnings22, prepare_earnings22 +from .ears import download_ears, prepare_ears from .edacc import download_edacc, prepare_edacc from .eval2000 import prepare_eval2000 from .fisher_english import prepare_fisher_english @@ -131,6 +132,8 @@ "prepare_earnings21", "download_earnings22", "prepare_earnings22", + "download_ears", + "prepare_ears", "download_edacc", "prepare_edacc", "prepare_eval2000", diff --git a/lhotse/recipes/ears.py b/lhotse/recipes/ears.py new file mode 100644 index 000000000..dd03fcbc5 --- /dev/null +++ b/lhotse/recipes/ears.py @@ -0,0 +1,226 @@ +""" +Description taken from the abstract of the paper: +"EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation" +https://arxiv.org/abs/2406.06185 + +We release the EARS (Expressive Anechoic Recordings of Speech) dataset, a high-quality speech dataset comprising +107 speakers from diverse backgrounds, totaling in 100 hours of clean, anechoic speech data. The dataset covers +a large range of different speaking styles, including emotional speech, different reading styles, non-verbal sounds, +and conversational freeform speech. We benchmark various methods for speech enhancement and dereverberation on the +dataset and evaluate their performance through a set of instrumental metrics. In addition, we conduct a listening +test with 20 participants for the speech enhancement task, where a generative method is preferred. We introduce +a blind test set that allows for automatic online evaluation of uploaded data. Dataset download links and automatic +evaluation server can be found online. +""" + + +import json +import logging +import re +import shutil +import zipfile +from collections import defaultdict +from pathlib import Path +from typing import Dict, Iterable, Optional, Sequence, Union + +from tqdm import tqdm + +from lhotse import ( + RecordingSet, + SupervisionSegment, + SupervisionSet, + fix_manifests, + validate_recordings_and_supervisions, +) +from lhotse.recipes.utils import ( + DEFAULT_DETECTED_MANIFEST_TYPES, + TYPES_TO_CLASSES, + load_manifest, + manifests_exist, +) +from lhotse.utils import Pathlike, resumable_download + + +def _read_manifests_if_cached_no_parts( + output_dir: Optional[Pathlike], + prefix: str = "", + suffix: str = "jsonl.gz", + types: Iterable[str] = DEFAULT_DETECTED_MANIFEST_TYPES, + lazy: bool = False, +) -> Optional[Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]: + """ + Loads manifests from the disk, or a subset of them if only some exist. + The manifests are searched for using the pattern ``output_dir / f'{prefix}_{manifest}_{part}.json'``, + where `manifest` is one of ``["recordings", "supervisions"]`` and ``part`` is specified in ``dataset_parts``. + This function is intended to speedup data preparation if it has already been done before. + + :param output_dir: Where to look for the files. + :param prefix: Optional common prefix for the manifest files (underscore is automatically added). + :param suffix: Optional common suffix for the manifest files ("json" by default). + :param types: Which types of manifests are searched for (default: 'recordings' and 'supervisions'). + :return: A dict with manifest (``d[dataset_part]['recording'|'manifest']``) or ``None``. + """ + if output_dir is None: + return None + if prefix and not prefix.endswith("_"): + prefix = f"{prefix}_" + if suffix.startswith("."): + suffix = suffix[1:] + if lazy and not suffix.startswith("jsonl"): + raise ValueError( + f"Only JSONL manifests can be opened lazily (got suffix: '{suffix}')" + ) + manifests = defaultdict(dict) + output_dir = Path(output_dir) + for manifest in types: + path = output_dir / f"{prefix}{manifest}.{suffix}" + if not path.is_file(): + continue + if lazy: + manifests[manifest] = TYPES_TO_CLASSES[manifest].from_jsonl_lazy(path) + else: + manifests[manifest] = load_manifest(path) + return dict(manifests) + + +def download_ears( + target_dir: Pathlike = ".", + force_download: bool = False, +) -> Path: + """ + Download and unzip the EARS dataset. + + :param target_dir: Pathlike, the path of the dir to storage the dataset. + :param force_download: Bool, if True, download the tars no matter if the tars exist. + :return: the path to downloaded and extracted directory with data. + """ + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + resumable_download( + "https://raw.githubusercontent.com/facebookresearch/ears_dataset/main/speaker_statistics.json", + filename=target_dir / "speaker_statistics.json", + force_download=force_download, + ) + resumable_download( + "https://raw.githubusercontent.com/facebookresearch/ears_dataset/main/transcripts.json", + filename=target_dir / "transcripts.json", + force_download=force_download, + ) + for part in tqdm( + range(1, 108), desc="Downloading the 107 speakers of the EARS dataset" + ): + part = f"p{part:03d}" + url = f"https://github.com/facebookresearch/ears_dataset/releases/download/dataset" + zip_name = f"{part}.zip" + zip_path = target_dir / zip_name + part_dir = target_dir / part + completed_detector = part_dir / ".completed" + if completed_detector.is_file(): + logging.info(f"Skipping {part} because {completed_detector} exists.") + continue + full_url = f"{url}/{zip_name}" + resumable_download(full_url, filename=zip_path, force_download=force_download) + shutil.rmtree(part_dir, ignore_errors=True) + with zipfile.ZipFile(zip_path) as zf: + zf.extractall(path=target_dir) + completed_detector.touch() + + return target_dir + + +def prepare_ears( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + num_jobs: int = 1, +) -> Dict[str, Union[RecordingSet, SupervisionSet]]: + """ + Returns the manifests which consist of the Recordings and Supervisions. + When all the manifests are available in the ``output_dir``, it will simply read and return them. + + :param corpus_dir: Pathlike, the path of the data dir. + :param output_dir: Pathlike, the path where to write the manifests. + :param num_jobs: the number of parallel workers parsing the data. + :return: a Dict whose keys are 'recordings' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + dataset_parts = [f"p{spk:03d}" for spk in range(1, 108)] + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # Maybe the manifests already exist: we can read them and save a bit of preparation time. + manifests = _read_manifests_if_cached_no_parts( + output_dir=output_dir, prefix="ears" + ) + + # Contents of the file + # { + # "p001": { + # "age": "36-45", + # "ethnicity": "white or caucasian", + # "gender": "male", + # "weight": "160 - 180 lbs", + # "native language": "german", + # "height": "6' - 6'3" + # }, + # ... + # } + spk2meta = json.loads((corpus_dir / "speaker_statistics.json").read_text()) + + # Contents of the file + # { + # "emo_adoration_sentences": "You're just the sweetest person I know and I am so happy to call you my friend. I had the best time with you, I just adore you. I love this gift, thank you!", + # "emo_amazement_sentences": "I just love how you can play guitar. You're so impressive. I admire your abilities so much.", + # ... + # } + utt2transcript = json.loads((corpus_dir / "transcripts.json").read_text()) + supervisions = [] + recordings_list = [] + for part in tqdm(dataset_parts, desc="Preparing EARS speakers"): + if manifests_exist(part=part, output_dir=output_dir, prefix="ears"): + logging.info(f"EARS subset: {part} already prepared - skipping.") + continue + spk_id = part + part_path = corpus_dir / part + recordings = RecordingSet.from_dir( + part_path, + "*.wav", + num_jobs=num_jobs, + recording_id=lambda path: f"{spk_id}_{path.stem}", + ) + recordings_list.append(recordings) + for rec in recordings: + utt = rec.id.split("_")[1] + meta = spk2meta[spk_id].copy() + supervisions.append( + SupervisionSegment( + id=rec.id, + recording_id=rec.id, + start=0.0, + duration=rec.duration, + channel=0, + text=utt2transcript.get(utt), + language="English", + speaker=spk_id, + gender=meta.pop("gender", None), + custom=meta, + ) + ) + + recordings = [] + for recs in recordings_list: + recordings += list(recs) + recordings = RecordingSet.from_recordings(recordings) + supervisions = SupervisionSet.from_segments(supervisions) + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions(recordings, supervisions) + + if output_dir is not None: + supervisions.to_file(output_dir / f"ears_supervisions.jsonl.gz") + recordings.to_file(output_dir / f"ears_recordings.jsonl.gz") + + manifests = {"recordings": recordings, "supervisions": supervisions} + + return manifests From bd12d5d6e2ceb70898c53b1297487abd162b0231 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 22 Jul 2024 19:15:00 -0400 Subject: [PATCH 45/69] Concurrent dynamic bucketing (#1373) * Concurrent reads in dynamic bucketing for faster start time. * Don't exceed the buffer_size; eliminate some race conditions * Missing flag * use a proper queue for concurrency * disable concurrency by default * Add a test for the concurrent implementation --- lhotse/dataset/sampling/dynamic_bucketing.py | 99 ++++++++++++++++--- .../sampling/test_dynamic_bucketing.py | 7 +- 2 files changed, 89 insertions(+), 17 deletions(-) diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index 9b9a41f1c..4f7d9d32b 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -1,9 +1,12 @@ import random +import threading +import time import warnings from bisect import bisect_right from collections import deque from dataclasses import asdict, dataclass from itertools import islice +from queue import Queue from typing import ( Any, Callable, @@ -95,6 +98,7 @@ def __init__( rank: Optional[int] = None, seed: Union[int, Literal["randomized", "trng"]] = 0, sync_buckets: bool = True, + concurrent: bool = False, strict=None, shuffle_buffer_size=None, ) -> None: @@ -131,6 +135,10 @@ def __init__( when models have quadratic input complexity. Set between 15 and 40 for transformers. :param sync_buckets: When set, we'll try to make each DDP rank sample from as close duration buckets as possible to minimize the tail worker effect. + :param concurrent: Enabling concurrency eliminates most of the waiting to pre-populate the + bucketing buffers before the sampler starts yielding examples. For tarred/Lhotse Shar data + this can speed up the start of the training. Note that enabling concurrency will cause the + sampling results to be non-deterministic. This feature is experimental. :param world_size: Total number of distributed nodes. We will try to infer it by default. :param rank: Index of distributed node. We will try to infer it by default. :param seed: Random seed used to consistently shuffle the dataset across different processes. @@ -154,6 +162,7 @@ def __init__( self.buffer_size = buffer_size self.quadratic_duration = quadratic_duration self.sync_buckets = sync_buckets + self.concurrent = concurrent self.rng = None check_constraint(constraint, max_duration, max_cuts) @@ -282,6 +291,7 @@ def __iter__(self) -> "DynamicBucketingSampler": shuffle=self.shuffle, rng=self.rng, bucket_rng=bucket_rng, + concurrent=self.concurrent, diagnostics=self.diagnostics, ) self.cuts_iter = iter(cuts_iter) @@ -516,6 +526,7 @@ def __init__( shuffle: bool = False, rng: random.Random = None, bucket_rng: random.Random = None, + concurrent: bool = False, diagnostics: Optional[SamplingDiagnostics] = None, ) -> None: self.cuts = cuts @@ -533,6 +544,7 @@ def __init__( self.rng = rng self.bucket_rng = bucket_rng self.shuffle = shuffle + self.concurrent = concurrent assert duration_bins == sorted(duration_bins), ( f"Argument list for 'duration_bins' is expected to be in " @@ -561,14 +573,19 @@ def __init__( ) # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`). - self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [ - deque() for _ in range(len(duration_bins) + 1) - ] + self.buckets: List[Queue] = [Queue() for _ in range(len(duration_bins) + 1)] + + self._producer_thread = None def __iter__(self) -> Generator[CutSet, None, None]: # Init: sample `buffer_size` cuts and assign them to the right buckets. self.cuts_iter = iter(self.cuts) - self._collect_cuts_in_buckets(self.buffer_size) + + if self.concurrent: + self._start_data_producer_thread() + self._maybe_wait_for_producer() + else: + self._collect_cuts_in_buckets(self.buffer_size) state = BucketSelectionState( bucket_rng=self.bucket_rng, @@ -588,6 +605,9 @@ def __iter__(self) -> Generator[CutSet, None, None]: maybe_shuffled = pick_at_random( maybe_shuffled, rng=self.rng, out_indexes_used=indexes_used ) + else: + with sampling_bucket.mutex: + maybe_shuffled = list(sampling_bucket.queue) # Sample one batch from that bucket and yield it to the caller. batcher = DurationBatcher( maybe_shuffled, @@ -604,21 +624,26 @@ def __iter__(self) -> Generator[CutSet, None, None]: if indexes_used: # Shuffling, sort indexes of yielded elements largest -> smallest and remove them indexes_used.sort(reverse=True) - for idx in indexes_used: - del sampling_bucket[idx] + with sampling_bucket.mutex: + _q = sampling_bucket.queue + for idx in indexes_used: + del _q[idx] else: # No shuffling, remove first N for _ in range(batch_size): - sampling_bucket.popleft() + sampling_bucket.get() # Fetch new cuts and add them to appropriate buckets. - self._collect_cuts_in_buckets(batch_size) + if self.concurrent: + self._maybe_wait_for_producer() + else: + self._collect_cuts_in_buckets(batch_size) except StopIteration: pass # Cleanup. self.cuts_iter = None - def _select_bucket(self, state: BucketSelectionState) -> Deque[Cut]: + def _select_bucket(self, state: BucketSelectionState) -> Queue: if self.bucket_rng is None: # Bucket selection algo 1: # * there is just one RNG for choosing buckets and choosing samples randomly from the buckets @@ -646,7 +671,7 @@ def _select_bucket(self, state: BucketSelectionState) -> Deque[Cut]: # it will scan the neighbouring buckets until it finds one that's ready # * if no bucket is ready, we end iteration - def scan_buckets(predicate: Callable[[Deque[Cut]], bool]) -> int: + def scan_buckets(predicate: Callable[[Queue], bool]) -> int: bucket_idx = state.select_bucket_idx() def valid_idx() -> bool: @@ -689,22 +714,55 @@ def valid_idx() -> bool: # which may yield partial batches. try: state.restore(ckpt) - selected_bucket_idx = scan_buckets(lambda b: len(b) > 0) + selected_bucket_idx = scan_buckets(lambda b: b.qsize() > 0) except BucketsDontHaveEnoughData: # We exhausted the full dataset. raise StopIteration() return self.buckets[selected_bucket_idx] - def _is_ready(self, bucket: Deque[Cut]) -> bool: + def _is_ready(self, bucket: Queue) -> bool: tot = self.constraint.copy() - for c in bucket: + with bucket.mutex: + contents = list(bucket.queue) + for c in contents: tot.add(c[0] if isinstance(c, tuple) else c) if tot.close_to_exceeding(): return True return False + def _start_data_producer_thread(self): + """Start concurrent filling of the bucket buffer in a background thread.""" + + def producer(): + try: + self._source_exhausted = False + while not self._source_exhausted: + if sum(b.qsize() for b in self.buckets) == self.buffer_size: + time.sleep(0.1) + continue + cuts = next(self.cuts_iter) + duration = self.constraint.measure_length( + cuts[0] if isinstance(cuts, tuple) else cuts + ) + bucket_idx = bisect_right(self.duration_bins, duration) + self.buckets[bucket_idx].put(cuts) + except StopIteration: + self._source_exhausted = True + + self._producer_thread = threading.Thread(target=producer) + self._producer_thread.start() + + def _maybe_wait_for_producer(self): + """Triggers wait for producer if the bucket buffers are less than 10% utilized.""" + while ( + sum(b.qsize() for b in self.buckets) < self.buffer_size / 10 + and not self._source_exhausted + ): + time.sleep(1.0) + def _collect_cuts_in_buckets(self, n_cuts: int) -> None: + """Fetches ``n_cuts`` from the input data iterable. Doesn't use concurrency.""" try: for _ in range(n_cuts): cuts = next(self.cuts_iter) @@ -712,13 +770,22 @@ def _collect_cuts_in_buckets(self, n_cuts: int) -> None: cuts[0] if isinstance(cuts, tuple) else cuts ) bucket_idx = bisect_right(self.duration_bins, duration) - self.buckets[bucket_idx].append(cuts) + self.buckets[bucket_idx].put(cuts) except StopIteration: pass + def __del__(self): + if ( + self.concurrent + and self._producer_thread is not None + and self._producer_thread.is_alive() + ): + self._source_exhausted = True + self._producer_thread.join() + def pick_at_random( - bucket: Sequence[Union[Cut, Tuple[Cut, ...]]], + bucket: Queue, rng: random.Random, out_indexes_used: list, ) -> Generator[Union[Cut, Tuple[Cut, ...]], None, None]: @@ -726,6 +793,8 @@ def pick_at_random( Generator which will yield items in a sequence in a random order. It will append the indexes of items yielded during iteration via ``out_used_indexes``. """ + with bucket.mutex: + bucket = list(bucket.queue) indexes = list(range(len(bucket))) rng.shuffle(indexes) for idx in indexes: diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index 454bff58f..e7d2db019 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -115,7 +115,8 @@ def test_dynamic_bucketing_drop_last_true(): assert sum(c.duration for c in batches[2]) == 5 -def test_dynamic_bucketing_sampler(): +@pytest.mark.parametrize("concurrent", [False, True]) +def test_dynamic_bucketing_sampler(concurrent): cuts = DummyManifest(CutSet, begin_id=0, end_id=10) for i, c in enumerate(cuts): if i < 5: @@ -123,7 +124,9 @@ def test_dynamic_bucketing_sampler(): else: c.duration = 2 - sampler = DynamicBucketingSampler(cuts, max_duration=5, num_buckets=2, seed=0) + sampler = DynamicBucketingSampler( + cuts, max_duration=5, num_buckets=2, seed=0, concurrent=concurrent + ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] From 21b102ca02f3b747b4d1daca2f286d095e1ac4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 24 Jul 2024 12:25:27 -0400 Subject: [PATCH 46/69] Refactor bucket selection for customization (#1377) * Refactor bucket selection to allow customization * Extend the API further * Prune imports --- lhotse/dataset/sampling/base.py | 21 +++++++++++++++++++- lhotse/dataset/sampling/dynamic_bucketing.py | 18 ++++++++--------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index ca43b36fc..6545b1671 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -2,6 +2,7 @@ import os import warnings from abc import ABCMeta, abstractmethod +from bisect import bisect_right from copy import deepcopy from dataclasses import asdict, dataclass from math import isclose @@ -15,7 +16,7 @@ from lhotse.cut.text import TextExample from lhotse.lazy import Dillable from lhotse.manipulation import combine -from lhotse.utils import Seconds, ifnone, is_none_or_gt +from lhotse.utils import Seconds, exactly_one_not_null, ifnone, is_none_or_gt class CutSampler(Sampler, Dillable): @@ -407,6 +408,24 @@ def measure_length(self, example: Any) -> float: """ pass + def select_bucket( + self, buckets: Any, example: Any = None, example_len: Any = None + ) -> int: + """ + Given a list of buckets and an example, assign the example to the correct bucket. + This is leveraged by bucketing samplers. + + Default implementation assumes that buckets are expressed in the same units as + the output of :meth:`SamplingConstraint.measure_length` and returns the index + of the first bucket that has a larger length than the example. + """ + assert exactly_one_not_null( + example, example_len + ), f"select_bucket requires either example= or example_len= as the input (we received {example=} and {example_len=})." + if example_len is None: + example_len = self.measure_length(example) + return bisect_right(buckets, example_len) + def copy(self) -> "SamplingConstraint": """Return a shallow copy of this constraint.""" return copy.copy(self) diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index 4f7d9d32b..1cf8061a5 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -2,8 +2,6 @@ import threading import time import warnings -from bisect import bisect_right -from collections import deque from dataclasses import asdict, dataclass from itertools import islice from queue import Queue @@ -350,7 +348,9 @@ def add(self, example: Cut) -> None: selecting the right property from the input ``cut`` object. """ seqlen = self.measure_length(example) - bucket_idx = bisect_right(self.max_seq_len_buckets, seqlen) + bucket_idx = self.select_bucket( + buckets=self.max_seq_len_buckets, example_len=seqlen + ) assert bucket_idx < len(self.max_seq_len_buckets), ( f"Received example with sequence length {seqlen} that exceeds " f"the highest allowed length {self.max_seq_len_buckets[-1]}." @@ -742,10 +742,10 @@ def producer(): time.sleep(0.1) continue cuts = next(self.cuts_iter) - duration = self.constraint.measure_length( - cuts[0] if isinstance(cuts, tuple) else cuts + bucket_idx = self.constraint.select_bucket( + buckets=self.duration_bins, + example=cuts[0] if isinstance(cuts, tuple) else cuts, ) - bucket_idx = bisect_right(self.duration_bins, duration) self.buckets[bucket_idx].put(cuts) except StopIteration: self._source_exhausted = True @@ -766,10 +766,10 @@ def _collect_cuts_in_buckets(self, n_cuts: int) -> None: try: for _ in range(n_cuts): cuts = next(self.cuts_iter) - duration = self.constraint.measure_length( - cuts[0] if isinstance(cuts, tuple) else cuts + bucket_idx = self.constraint.select_bucket( + buckets=self.duration_bins, + example=cuts[0] if isinstance(cuts, tuple) else cuts, ) - bucket_idx = bisect_right(self.duration_bins, duration) self.buckets[bucket_idx].put(cuts) except StopIteration: pass From 2b756225a3a543ec48f0bb8dd637433918049275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 26 Jul 2024 11:59:30 -0400 Subject: [PATCH 47/69] Bump dev version to 1.27.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 5ff8c4f5d..5db08bf2d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.26.0 +1.27.0 From bcd1e22332ecd3b67628f10204953c45d5cb329b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 2 Aug 2024 18:18:46 -0400 Subject: [PATCH 48/69] Cap the 'trng' random seeds to 2**31 avoiding numpy error (#1379) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- lhotse/dataset/dataloading.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/lhotse/dataset/dataloading.py b/lhotse/dataset/dataloading.py index 88e15d386..2afb66d2d 100644 --- a/lhotse/dataset/dataloading.py +++ b/lhotse/dataset/dataloading.py @@ -1,6 +1,7 @@ import os import random import secrets +import sys from functools import partial from typing import Callable, Literal, Optional, Union @@ -68,7 +69,7 @@ def worker_init_fn( os.environ["WORLD_SIZE"] = str(world_size) -def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int: +def resolve_seed(seed: Union[int, Literal["trng", "randomized"], None]) -> int: """ Resolves the special values of random seed supported in Lhotse. @@ -83,16 +84,25 @@ def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int: If we are not in a dataloading worker (or ``num_workers`` was set to ``0``), we'll return Python's ``random`` module global seed. """ + + # Specific number provided: use it. if isinstance(seed, int): return seed + # No request for a specific type of random seed resolution: return Python's global random seed. + if seed is None: + return random.getstate()[1][0] + + # Deterministic randomized random seed resolution: + # Each dataloading worker and DDP rank gets a separate random seed. + # If we're not in a dataloading worker, use global RNG's current seed. if seed == "randomized": worker_info = torch.utils.data.get_worker_info() if worker_info is None: - # not in a dataloader sub-process: get python global random seed + # Not in a dataloader sub-process: get Python's global random seed. return random.getstate()[1][0] else: - # in a dataloader sub-process: read out the seed we assigned to it + # In a dataloader sub-process: read out the seed we assigned to it. assert LHOTSE_PROCESS_SEED in os.environ, ( "Requested seed='randomized' for shuffling shards differently " "on each DataLoader node and worker, " @@ -100,12 +110,16 @@ def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int: ) return int(os.environ[LHOTSE_PROCESS_SEED]) + # True-random number generator requested for seed generation ("complete randomness"). if seed == "trng": - return secrets.randbelow(2**32) + # 2**32 may trigger the following exception if you add anything: + # File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding + # ValueError: Seed must be between 0 and 2**32 - 1 + return secrets.randbelow(2**31) raise ValueError( f"Unexpected type or value of seed: {type(seed)=} {seed=}. " - f"Supported values are: int, 'trng', and 'randomized'." + f"Supported values are: None, int, 'trng', and 'randomized'." ) From 748cd50fe51786dbb4fe22b2ac5cfe949f722238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 2 Aug 2024 18:30:35 -0400 Subject: [PATCH 49/69] `CutSet`.prefetch() for background cuts loading during iteration (#1380) --- lhotse/cut/set.py | 39 ++++++++++++++++++++++++++++++++++++--- test/cut/test_cut_set.py | 6 ++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 08fc830b8..927558bce 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -45,7 +45,6 @@ from lhotse.lazy import ( AlgorithmMixin, Dillable, - LazyFilter, LazyFlattener, LazyIteratorChain, LazyManifestIterator, @@ -63,7 +62,6 @@ Seconds, compute_num_frames, compute_num_samples, - deprecated, exactly_one_not_null, fastcopy, ifnone, @@ -2033,7 +2031,6 @@ def compute_and_store_features_batch( """ from concurrent.futures import ThreadPoolExecutor - import torch from torch.utils.data import DataLoader from lhotse.dataset import SimpleCutSampler, UnsupervisedWaveformDataset @@ -2523,6 +2520,36 @@ def transform_text(self, transform_fn: Callable[[str], str]) -> "CutSet": partial(_transform_text, transform_fn=transform_fn) ) + def prefetch(self, buffer_size: int = 10) -> "CutSet": + """ + Pre-fetches the CutSet elements in a background process. + Useful for enabling concurrent reading/processing/writing in ETL-style tasks. + + .. caution:: This method internally uses a PyTorch DataLoader with a single worker. + It is not suitable for use in typical PyTorch training scripts. + + .. caution:: If you run into pickling issues when using this method, you're also likely + using .filter/.map methods with a lambda function. + Please set ``lhotse.set_dill_enabled(True)`` to resolve these issues, or convert lambdas + to regular functions + ``functools.partial`` + + """ + from torch.utils.data import DataLoader + + from lhotse.dataset import DynamicCutSampler, IterableDatasetWrapper + + return CutSet( + DataLoader( + dataset=IterableDatasetWrapper( + _BackgroundCutFetcher(), + DynamicCutSampler(self, max_cuts=1, rank=0, world_size=1), + ), + batch_size=None, + num_workers=1, + prefetch_factor=buffer_size, + ) + ) + def __repr__(self) -> str: try: len_val = len(self) @@ -2554,6 +2581,12 @@ def __iter__(self) -> Iterable[Cut]: yield from self.cuts +class _BackgroundCutFetcher(torch.utils.data.Dataset): + def __getitem__(self, cuts: CutSet): + assert len(cuts) == 1 + return cuts[0] + + def mix( reference_cut: Cut, mixed_in_cut: Cut, diff --git a/test/cut/test_cut_set.py b/test/cut/test_cut_set.py index 209978d6e..a8a5b866f 100644 --- a/test/cut/test_cut_set.py +++ b/test/cut/test_cut_set.py @@ -81,6 +81,12 @@ def test_cut_set_iteration(cut_set_with_mixed_cut): assert len(cuts) == 3 +def test_cut_set_prefetch_iteration(cut_set_with_mixed_cut): + cuts = list(cut_set_with_mixed_cut.prefetch()) + assert len(cut_set_with_mixed_cut) == 3 + assert len(cuts) == 3 + + def test_cut_set_holds_both_simple_and_mixed_cuts(cut_set_with_mixed_cut): simple_cuts = cut_set_with_mixed_cut.simple_cuts assert all(isinstance(c, MonoCut) for c in simple_cuts) From bf37599aea5b880404df2444800beef64f347c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 7 Aug 2024 21:17:59 -0400 Subject: [PATCH 50/69] Include a copyright NOTICE listing major copyright holders (#1381) * Include a copyright NOTICE * Include a copyright NOTICE --- NOTICE | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 NOTICE diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..46664aefc --- /dev/null +++ b/NOTICE @@ -0,0 +1,20 @@ +Lhotse +Copyright 2020-2024 Piotr Żelasko +Copyright 2020-2024 Johns Hopkins University +Copyright 2020-2024 Xiaomi Corporation +Copyright 2022-2023 Meaning.Team Inc. +Copyright 2023-2024 NVIDIA Corporation + +This repository includes software developed by: +- Johns Hopkins University +- Xiaomi Corporation +- Meaning.Team Inc. +- NVIDIA Corporation +- other organizations and individuals. + +This project includes contributions from various organizations and individuals. +Only major copyright holders are listed here. +For a complete list of contributors, please refer to the project's version control history. + +Licensed under the Apache License, Version 2.0 (the "License"). +See the LICENSE file for the full contents of the license. From 9bea2dbb88c3a4adec60c031d50f54fa124aecc6 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:41:32 -0700 Subject: [PATCH 51/69] Added has_custom to MixedCut (#1383) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- lhotse/cut/mixed.py | 8 ++++++++ test/cut/test_custom_attrs.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/lhotse/cut/mixed.py b/lhotse/cut/mixed.py index 19439d84a..01acf248d 100644 --- a/lhotse/cut/mixed.py +++ b/lhotse/cut/mixed.py @@ -290,6 +290,14 @@ def __getattr__(self, name: str) -> Any: f"when a MixedCut consists of more than one MonoCut with that attribute)." ) + def has_custom(self, name: str) -> bool: + ( + non_padding_idx, + mono_cut, + ) = self._assert_one_data_cut_with_attr_and_return_it_with_track_index(name) + + return hasattr(mono_cut, name) + def load_custom(self, name: str) -> np.ndarray: """ Load custom data as numpy array. The custom data is expected to have diff --git a/test/cut/test_custom_attrs.py b/test/cut/test_custom_attrs.py index 693dfa4eb..794b01b9d 100644 --- a/test/cut/test_custom_attrs.py +++ b/test/cut/test_custom_attrs.py @@ -441,3 +441,32 @@ def test_multi_cut_custom_multi_recording_channel_selector(): audio = two_channel_out.load_target_recording() assert audio.shape == (2, 16000) np.testing.assert_allclose(ref_tgt_audio[::3], audio) + + +def test_padded_cut_custom_recording(): + original_duration = 1.0 # seconds + padded_duration = 2.0 # seconds + + # prepare cut + cut = dummy_cut(unique_id=0, with_data=True, duration=original_duration) + cut.target_recording = dummy_recording( + unique_id=1, duration=cut.duration, with_data=True + ) + target_recording = cut.load_target_recording() + + # prepare padded cut (MixedCut) + padded_cut = cut.pad(duration=padded_duration) + + # check the padded cut (MixedCut) has the custom attribute + assert padded_cut.has_custom("target_recording") + + # load the audio from the padded cut + padded_target_recording = padded_cut.load_target_recording() + + # check the non-padded component is matching + np.testing.assert_allclose( + padded_target_recording[:, : cut.num_samples], target_recording + ) + + # check the padded component is zero + assert np.all(padded_target_recording[:, cut.num_samples :] == 0) From e78add505f7182f24076f80f7f3eb12a1651dae9 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 14 Aug 2024 19:10:09 +0800 Subject: [PATCH 52/69] [Recipe] Wenetspeech4tts (#1384) * add wenetspeech4tts recipe * fix wenetspeech4tts recipe * fix wenetspeech4tts recipe float * fix wenetspeech4tts recipe typo * fix wenetspeech4tts recipe typo * add wenetspeech4tts doc --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/wenetspeech4tts.py | 43 +++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/wenetspeech4tts.py | 182 ++++++++++++++++++++ 5 files changed, 229 insertions(+) create mode 100644 lhotse/bin/modes/recipes/wenetspeech4tts.py create mode 100644 lhotse/recipes/wenetspeech4tts.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 89deeb7ab..5f4f31bea 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -201,6 +201,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_voxpopuli` * - WenetSpeech - :func:`lhotse.recipes.prepare_wenet_speech` + * - WenetSpeech4TTS + - :func:`lhotse.recipes.prepare_wenetspeech4tts` * - YesNo - :func:`lhotse.recipes.prepare_yesno` * - Eval2000 diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index 174b1036a..e9af9185b 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -86,5 +86,6 @@ from .voxconverse import * from .voxpopuli import * from .wenet_speech import * +from .wenetspeech4tts import * from .xbmu_amdo31 import * from .yesno import * diff --git a/lhotse/bin/modes/recipes/wenetspeech4tts.py b/lhotse/bin/modes/recipes/wenetspeech4tts.py new file mode 100644 index 000000000..42b412f18 --- /dev/null +++ b/lhotse/bin/modes/recipes/wenetspeech4tts.py @@ -0,0 +1,43 @@ +from typing import Sequence + +import click + +from lhotse.bin.modes import prepare +from lhotse.recipes import prepare_wenetspeech4tts +from lhotse.utils import Pathlike + +__all__ = ["wenetspeech4tts"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many jobs to use (can give good speed-ups with slow disks).", +) +@click.option( + "-p", + "--dataset-parts", + type=str, + default=["all"], + multiple=True, + help="List of dataset parts to prepare. To prepare multiple parts, pass each with `-p` " + "Example: `-p Basic -p Premium`", +) +def wenetspeech4tts( + corpus_dir: Pathlike, + output_dir: Pathlike, + dataset_parts: Sequence[str], + num_jobs: int, +): + """WenetSpeech4TTS data preparation.""" + prepare_wenetspeech4tts( + corpus_dir, + output_dir=output_dir, + num_jobs=num_jobs, + dataset_parts=dataset_parts, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 9e06dfbb9..307d508bd 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -85,6 +85,7 @@ from .voxconverse import download_voxconverse, prepare_voxconverse from .voxpopuli import download_voxpopuli, prepare_voxpopuli from .wenet_speech import prepare_wenet_speech +from .wenetspeech4tts import prepare_wenetspeech4tts from .xbmu_amdo31 import download_xbmu_amdo31, prepare_xbmu_amdo31 from .yesno import download_yesno, prepare_yesno diff --git a/lhotse/recipes/wenetspeech4tts.py b/lhotse/recipes/wenetspeech4tts.py new file mode 100644 index 000000000..2b7ad0ddb --- /dev/null +++ b/lhotse/recipes/wenetspeech4tts.py @@ -0,0 +1,182 @@ +""" +This recipe supports Chinese TTS corpora: WenetSpeech4TTS. + +Paper: https://arxiv.org/abs/2406.05763v3 +HuggingFace Dataset: https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS + +Download using huggingface-cli: +huggingface-cli login +huggingface-cli download --repo-type dataset --local-dir $DATA_DIR Wenetspeech4TTS/WenetSpeech4TTS + +Extract the downloaded data: +for folder in Standard Premium Basic; do + for file in "$folder"/*.tar.gz; do + tar -xzvf "$file" -C "$folder" + done +done +""" +import logging +import re +import shutil +import tarfile +from pathlib import Path +from typing import Dict, Optional, Sequence, Union + +from tqdm import tqdm + +from lhotse import ( + SupervisionSegment, + SupervisionSet, + fix_manifests, + validate_recordings_and_supervisions, +) +from lhotse.audio import Recording, RecordingSet +from lhotse.recipes.utils import manifests_exist, read_manifests_if_cached +from lhotse.utils import Pathlike, resumable_download, safe_extract + +WENETSPEECH4TTS = ( + "Basic", + "Premium", + "Standard", +) + + +def prepare_wenetspeech4tts( + corpus_dir: Pathlike, + dataset_parts: Union[str, Sequence[str]] = "Basic", + output_dir: Optional[Pathlike] = None, + num_jobs: int = 1, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns the manifests which consist of the Recordings and Supervisions. + When all the manifests are available in the ``output_dir``, it will simply read and return them. + + :param corpus_dir: Pathlike, the path of the data dir. + :param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'Basic', 'Premium'. + By default we will prepare all parts. + :param output_dir: Pathlike, the path where to write the manifests. + :param num_jobs: the number of parallel workers parsing the data. + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + if dataset_parts == "all" or dataset_parts[0] == "all": + dataset_parts = WENETSPEECH4TTS + elif isinstance(dataset_parts, str): + assert ( + dataset_parts in WENETSPEECH4TTS + ), f"Unsupported dataset part: {dataset_parts}" + dataset_parts = [dataset_parts] + + manifests = {} + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + # Maybe the manifests already exist: we can read them and save a bit of preparation time. + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, output_dir=output_dir, prefix="wenetspeech4tts" + ) + + basic_wav_scp_dict = {} + premium_wav_scp_dict = {} + standard_wav_scp_dict = {} + with open(corpus_dir / "filelists" / "Basic_filelist.lst") as f: + for line in f: + line = line.strip().split() + basic_wav_scp_dict[line[0]] = line[1] + if "Basic" not in line[1]: + standard_wav_scp_dict[line[0]] = line[1] + if "Premium" in line[1]: + premium_wav_scp_dict[line[0]] = line[1] + + basic_dns_mos_dict = {} + premium_dns_mos_dict = {} + standard_dns_mos_dict = {} + with open(corpus_dir / "DNSMOS_P808Scores" / "Basic_DNSMOS.lst") as f: + for line in f: + line = line.strip().split() + basic_dns_mos_dict[line[0]] = float(line[1]) + with open(corpus_dir / "DNSMOS_P808Scores" / "Premium_DNSMOS.lst") as f: + for line in f: + line = line.strip().split() + premium_dns_mos_dict[line[0]] = float(line[1]) + with open(corpus_dir / "DNSMOS_P808Scores" / "Standard_DNSMOS.lst") as f: + for line in f: + line = line.strip().split() + standard_dns_mos_dict[line[0]] = float(line[1]) + + for part in dataset_parts: + if manifests_exist(part=part, output_dir=output_dir, prefix="wenetspeech4tts"): + logging.info(f"WenetSpeech4TTS subset: {part} already prepared - skipping.") + continue + recordings = [] + supervisions = [] + if part == "Premium": + wav_scp_dict = premium_wav_scp_dict + dns_mos_dict = premium_dns_mos_dict + elif part == "Standard": + wav_scp_dict = standard_wav_scp_dict + dns_mos_dict = standard_dns_mos_dict + else: + wav_scp_dict = basic_wav_scp_dict + dns_mos_dict = basic_dns_mos_dict + for wav_name, wav_path in tqdm( + wav_scp_dict.items(), desc=f"Preparing WenetSpeech4TTS {part}" + ): + # get the actual wav path, remove the prefix '../' + # e.g. ../Premium/WenetSpeech4TTS_Premium_9/wavs/X0000015306_83500032_S00110-S00112.wav -> Premium/WenetSpeech4TTS_Premium_9/wavs/X0000015306_83500032_S00110-S00112.wav + assert wav_path.startswith("../") + wav_path = corpus_dir / wav_path[3:] + if not wav_path.is_file(): + logging.warning(f"No such file: {wav_path}") + continue + recording = Recording.from_file(wav_path) + recordings.append(recording) + + # get the text path + # e.g. ../Premium/WenetSpeech4TTS_Premium_9/txts/X0000015306_83500032_S00110-S00112.txt + txt_path = ( + wav_path.parent.parent + / "txts" + / wav_path.name.replace("wavs", "txts").replace(".wav", ".txt") + ) + if not txt_path.is_file(): + logging.warning(f"No such file: {txt_path}") + continue + with open(txt_path, "r") as f: + lines = f.readlines() + text = lines[0].strip().split("\t")[1] + timestamp = lines[1].strip() + supervisions.append( + SupervisionSegment( + id=wav_name, + recording_id=wav_name, + start=0.0, + duration=recording.duration, + channel=0, + language="Chinese", + text=text, + custom={ + "timestamp": timestamp, + "dns_mos": dns_mos_dict.get(wav_name, None), + }, + ) + ) + recordings = RecordingSet.from_recordings(recordings) + supervisions = SupervisionSet.from_segments(supervisions) + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions(recordings, supervisions) + + if output_dir is not None: + supervisions.to_file( + output_dir / f"wenetspeech4tts_supervisions_{part}.jsonl.gz" + ) + recordings.to_file( + output_dir / f"wenetspeech4tts_recordings_{part}.jsonl.gz" + ) + + manifests[part] = {"recordings": recordings, "supervisions": supervisions} + + return manifests From 66b95bae6eef1be7394b88c790c9b6ee43b5ff79 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 14 Aug 2024 19:12:16 +0800 Subject: [PATCH 53/69] [Recipe] Spatial LibriSpeech (#1386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init commit * added dependencies for unit_tests * fixed compatibility for python 3.8 * fixed base_url * fixed metadata_url * Update spatial_librispeech.py * Update spatial_librispeech.py * minor fixes * multi-threaded 🪢 * Update spatial_librispeech.py * finalize the recipe * minor updates * fixed missing import cmd --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + .../bin/modes/recipes/spatial_librispeech.py | 88 ++++++ lhotse/recipes/__init__.py | 4 + lhotse/recipes/spatial_librispeech.py | 269 ++++++++++++++++++ 5 files changed, 364 insertions(+) create mode 100644 lhotse/bin/modes/recipes/spatial_librispeech.py create mode 100644 lhotse/recipes/spatial_librispeech.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 5f4f31bea..6a5be4f97 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -173,6 +173,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_reazonspeech` * - RIRs and Noises Corpus (OpenSLR 28) - :func:`lhotse.recipes.prepare_rir_noise` + * - Spatial-LibriSpeech + - :func:`lhotse.recipes.prepare_spatial_librispeech` * - Speech Commands - :func:`lhotse.recipes.prepare_speechcommands` * - SpeechIO diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index e9af9185b..d0ddd4c84 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -68,6 +68,7 @@ from .reazonspeech import * from .rir_noise import * from .slu import * +from .spatial_librispeech import * from .speechcommands import * from .speechio import * from .spgispeech import * diff --git a/lhotse/bin/modes/recipes/spatial_librispeech.py b/lhotse/bin/modes/recipes/spatial_librispeech.py new file mode 100644 index 000000000..838dbf740 --- /dev/null +++ b/lhotse/bin/modes/recipes/spatial_librispeech.py @@ -0,0 +1,88 @@ +from typing import Sequence + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.spatial_librispeech import ( + download_spatial_librispeech, + prepare_spatial_librispeech, +) +from lhotse.utils import Pathlike + +__all__ = ["spatial_librispeech"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-p", + "--dataset-parts", + type=str, + default=["all"], + multiple=True, + help="List of dataset parts to prepare. To prepare multiple parts, pass each with `-p` " + "Example: `-p train -p test`", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +@click.option( + "--normalize-text", + type=click.Choice(["none", "lower"], case_sensitive=False), + default="none", + help="Conversion of transcripts to lower-case (originally in upper-case).", + show_default=True, +) +def spatial_librispeech( + corpus_dir: Pathlike, + output_dir: Pathlike, + dataset_parts: Sequence[str], + normalize_text: str, + num_jobs: int, +): + """Spatial-LibriSpeech ASR data preparation.""" + if len(dataset_parts) == 1: + dataset_parts = dataset_parts[0] + prepare_spatial_librispeech( + corpus_dir, + output_dir=output_dir, + dataset_parts=dataset_parts, + normalize_text=normalize_text, + num_jobs=num_jobs, + ) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +@click.option( + "-p", + "--dataset-parts", + type=str, + default=["all"], + multiple=True, + help="List of dataset parts to download. To prepare multiple parts, pass each with `-p` " + "Example: `-p train -p test`", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +def spatial_librispeech( + target_dir: Pathlike, + dataset_parts: Sequence[str], + num_jobs: int, +): + """Spatial-LibriSpeech download.""" + if len(dataset_parts) == 1: + dataset_parts = dataset_parts[0] + download_spatial_librispeech( + target_dir, dataset_parts=dataset_parts, num_jobs=num_jobs + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 307d508bd..b0c909d99 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -69,6 +69,10 @@ from .reazonspeech import download_reazonspeech, prepare_reazonspeech from .rir_noise import download_rir_noise, prepare_rir_noise from .slu import prepare_slu +from .spatial_librispeech import ( + download_spatial_librispeech, + prepare_spatial_librispeech, +) from .speechcommands import download_speechcommands, prepare_speechcommands from .speechio import prepare_speechio from .spgispeech import download_spgispeech, prepare_spgispeech diff --git a/lhotse/recipes/spatial_librispeech.py b/lhotse/recipes/spatial_librispeech.py new file mode 100644 index 000000000..d33ec3478 --- /dev/null +++ b/lhotse/recipes/spatial_librispeech.py @@ -0,0 +1,269 @@ +import logging +from pathlib import Path +from typing import Dict, Optional, Sequence, Union + +from tqdm.auto import tqdm + +from lhotse import fix_manifests, validate_recordings_and_supervisions +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike, resumable_download + +SPATIAL_LIBRISPEECH = ("train", "test") +BASE_URL = "https://docs-assets.developer.apple.com/ml-research/datasets/spatial-librispeech/v1" +META_DATA_URL = "https://docs-assets.developer.apple.com/ml-research/datasets/spatial-librispeech/v1/metadata.parquet" + + +def _download_and_save_audio(target_file: Pathlike, url: str): + # Implementation from https://github.com/apple/ml-spatial-librispeech/pull/1/ + # Use the requests module to avoid the 403 forbidden error + def _download_file(url: str) -> bytes: + """This function downloads and returns the content of the given url + Args: + url (str): the url of the file to be downloaded + Raises: + e: The exception that is raised by the request module + Returns: + file_content (bytes): The file content downloaded from the url + """ + + try: + import requests + except ImportError: + raise ImportError( + "The Spatial LibriSpeech recipe requires requests dependency to download the dataset. You can install the dependency using: pip install requests" + ) + + try: + file_content = requests.get(url, allow_redirects=True).content + return file_content + except requests.exceptions.RequestException as e: + raise e + + # Implementation from https://github.com/apple/ml-spatial-librispeech/pull/1/ + def _save_audio_content(target_file: str, file_content: bytes): + """This function saves the downloaded content passed via `file_content' in the `target_file' + Args: + target_file (str): the target path for the file content to be saved to + file_content (bytes): the content to be saved + + Raises: + e: the IOError raised by the writing operation + """ + try: + with open(target_file, "wb") as file: + file.write(file_content) + except IOError as e: + raise e + + file_content = _download_file(url) + _save_audio_content(target_file, file_content) + + +def download_spatial_librispeech( + target_dir: Pathlike = ".", + dataset_parts: Union[str, Sequence[str]] = SPATIAL_LIBRISPEECH, + force_download: bool = False, + base_url: str = BASE_URL, + num_jobs: int = 1, +) -> Path: + """ + Download the Spatial-LibriSpeech dataset. + + :param target_dir: Pathlike, the path of the dir to storage the dataset. + :param dataset_parts: "all" or a list of splits (e.g. ["train", "test"]) to download. + :param force_download: Bool, if True, download the tars no matter if the tars exist. + :param base_url: str, the url of the resource. + :return: the path to downloaded and extracted directory with data. + """ + + try: + import pandas as pd + except ImportError: + raise ImportError( + "The Spatial LibriSpeech recipe requires pandas, pyarrow and fastparquet dependency to parse parquet formatted metadata. You can install the dependencies using: pip install pandas pyarrow fastparquet" + ) + + def _download_spatial_librispeech_audio_files( + target_dir: Pathlike, + dataset_parts: Sequence[str], + metadata: pd.DataFrame, + base_url: str, + force_download: bool = False, + num_jobs: int = 1, + ): + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + audio_url = f"{base_url}/ambisonics" + from concurrent.futures.thread import ThreadPoolExecutor + + for part in dataset_parts: + part_dir = target_dir / part + part_dir.mkdir(parents=True, exist_ok=True) + + with ThreadPoolExecutor(num_jobs) as ex: + for sample_id, split in tqdm( + zip(metadata["sample_id"], metadata["split"]), + total=len(metadata["sample_id"]), + ): + if split not in dataset_parts: + continue + recording_path = target_dir / split / f"{sample_id:06}.flac" + recording_url = f"{audio_url}/{sample_id:06}.flac" + if not recording_path.exists() or force_download: + ex.submit(_download_and_save_audio, recording_path, recording_url) + + target_dir = Path(target_dir) + target_dir.mkdir(parents=True, exist_ok=True) + if dataset_parts == "all": + dataset_parts = SPATIAL_LIBRISPEECH + else: + dataset_parts = ( + [dataset_parts] if isinstance(dataset_parts, str) else dataset_parts + ) + for part in dataset_parts: + assert part in SPATIAL_LIBRISPEECH, f"Unknown dataset part: {part}" + + corpus_dir = target_dir / "Spatial-LibriSpeech" + corpus_dir.mkdir(parents=True, exist_ok=True) + + completed_detector = corpus_dir / ".completed" + if completed_detector.is_file(): + logging.info(f"Skipping download, found {completed_detector}.") + return corpus_dir + + metadata_path = corpus_dir / "metadata.parquet" + if not metadata_path.is_file() or force_download: + resumable_download(META_DATA_URL, metadata_path, force_download=force_download) + elif metadata_path.is_file(): + logging.info(f"Skipping download, found {metadata_path}.") + + metadata = pd.read_parquet(metadata_path) + try: + _download_spatial_librispeech_audio_files( + target_dir=corpus_dir / "audio_files", + dataset_parts=dataset_parts, + metadata=metadata, + base_url=base_url, + force_download=force_download, + num_jobs=num_jobs, + ) + except Exception as e: + logging.error(f"Failed to download audio files: {e}") + raise e + + completed_detector.touch() + return corpus_dir + + +def prepare_spatial_librispeech( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + dataset_parts: Union[str, Sequence[str]] = SPATIAL_LIBRISPEECH, + normalize_text: str = "none", + num_jobs: int = 1, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns the manifests which consist of the Recordings and Supervisions. + When all the manifests are available in the ``output_dir``, it will simply read and return them. + + :param corpus_dir: Pathlike, the path of the data dir. + :param output_dir: Pathlike, the path where to write the manifests. + :param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'train', 'test'. + By default we will infer which parts are available in ``corpus_dir``. + :param normalize_text: str, "none" or "lower", + for "lower" the transcripts are converted to lower-case. + :param num_jobs: int, number of parallel threads used for 'parse_utterance' calls. + :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. + """ + + try: + import pandas as pd + except ImportError: + raise ImportError( + "The Spatial LibriSpeech recipe requires pandas, pyarrow and fastparquet dependency to parse parquet formatted metadata. You can install the dependencies using: pip install pandas pyarrow fastparquet" + ) + + corpus_dir = Path(corpus_dir) + output_dir = Path(output_dir) if output_dir is not None else corpus_dir + output_dir.mkdir(parents=True, exist_ok=True) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + if dataset_parts == "all": + dataset_parts = SPATIAL_LIBRISPEECH + else: + dataset_parts = ( + [dataset_parts] if isinstance(dataset_parts, str) else dataset_parts + ) + for part in dataset_parts: + assert part in SPATIAL_LIBRISPEECH, f"Unknown dataset part: {part}" + + metadata_path = corpus_dir / "metadata.parquet" + assert metadata_path.is_file(), f"{metadata_path} not found" + metadata = pd.read_parquet(metadata_path) + + manifests = {} + + for part in dataset_parts: + assert part in SPATIAL_LIBRISPEECH, f"Unknown dataset part: {part}" + logging.info(f"Processing {part} split...") + part_dir = corpus_dir / "audio_files" / part + recording_set = RecordingSet.from_dir( + part_dir, + pattern="*.flac", + num_jobs=num_jobs, + recording_id=lambda x: x.stem, + ) + + supervision_segments = [] + part_metadata = metadata[metadata["split"] == part] + for _, row in tqdm( + part_metadata.iterrows(), + total=len(part_metadata["sample_id"]), + desc=f"Processing supervision segments for split: {part}", + ): + recording_id = f"{row['sample_id']:06}" + start = 0 + duration = recording_set[recording_id].duration + channel = recording_set[recording_id].channel_ids + text = row["speech/librispeech_metadata/transcription"] + speaker = row["speech/librispeech_metadata/reader_id"] + gender = row["speech/librispeech_metadata/reader_sex"] + segment = SupervisionSegment( + id=recording_id, + recording_id=recording_id, + start=start, + duration=duration, + channel=channel, + text=text, + gender=gender, + speaker=speaker, + ) + supervision_segments.append(segment) + supervision_set = SupervisionSet.from_segments(supervision_segments) + + # Normalize text to lowercase + if normalize_text == "lower": + to_lower = lambda text: text.lower() + supervision_set = SupervisionSet.from_segments( + [s.transform_text(to_lower) for s in supervision_set] + ) + + recording_set, supervision_set = fix_manifests(recording_set, supervision_set) + validate_recordings_and_supervisions(recording_set, supervision_set) + + if output_dir is not None: + recording_set.to_file( + output_dir / f"spatial-librispeech_recordings_{part}.jsonl.gz" + ) + supervision_set.to_file( + output_dir / f"spatial-librispeech_supervisions_{part}.jsonl.gz" + ) + + manifests[part] = { + "recordings": recording_set, + "supervisions": supervision_set, + } + + return manifests From 170046fd54853eb03d277b550f9b83ab229d608a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 11:21:26 -0400 Subject: [PATCH 54/69] =?UTF-8?q?Fix=20to=20fixed=20batch=20size=20bucketi?= =?UTF-8?q?ng=20and=20audio=20loading=20network=20connectio=E2=80=A6=20(#1?= =?UTF-8?q?387)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix to fixed batch size bucketing and audio loading network connection resets * Fix tests and add more 'paranoia' tests --- lhotse/audio/utils.py | 2 + lhotse/dataset/sampling/base.py | 4 +- lhotse/dataset/sampling/stateless.py | 19 ++++- test/audio/test_audio_reads.py | 26 +++++++ .../sampling/test_dynamic_bucketing.py | 77 +++++++++++++++++-- test/dataset/sampling/test_sampling.py | 24 ++++++ .../sampling/test_stateless_sampler.py | 6 +- 7 files changed, 145 insertions(+), 13 deletions(-) diff --git a/lhotse/audio/utils.py b/lhotse/audio/utils.py index c4a604234..b5f7debd5 100644 --- a/lhotse/audio/utils.py +++ b/lhotse/audio/utils.py @@ -125,6 +125,7 @@ def suppress_audio_loading_errors(enabled: bool = True): AudioLoadingError, DurationMismatchError, NonPositiveEnergyError, + ConnectionResetError, # when reading from object stores / network sources enabled=enabled, ): yield @@ -141,6 +142,7 @@ def suppress_video_loading_errors(enabled: bool = True): AudioLoadingError, DurationMismatchError, NonPositiveEnergyError, + ConnectionResetError, # when reading from object stores / network sources enabled=enabled, ): yield diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index 6545b1671..26adda779 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -2,7 +2,7 @@ import os import warnings from abc import ABCMeta, abstractmethod -from bisect import bisect_right +from bisect import bisect_left from copy import deepcopy from dataclasses import asdict, dataclass from math import isclose @@ -424,7 +424,7 @@ def select_bucket( ), f"select_bucket requires either example= or example_len= as the input (we received {example=} and {example_len=})." if example_len is None: example_len = self.measure_length(example) - return bisect_right(buckets, example_len) + return bisect_left(buckets, example_len) def copy(self) -> "SamplingConstraint": """Return a shallow copy of this constraint.""" diff --git a/lhotse/dataset/sampling/stateless.py b/lhotse/dataset/sampling/stateless.py index 6667242b6..91f9395f1 100644 --- a/lhotse/dataset/sampling/stateless.py +++ b/lhotse/dataset/sampling/stateless.py @@ -1,7 +1,17 @@ import logging import random from pathlib import Path -from typing import Callable, Dict, Generator, Iterable, Optional, Sequence, Tuple, Union +from typing import ( + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) import torch from cytoolz import compose_left @@ -89,6 +99,8 @@ class StatelessSampler(torch.utils.data.Sampler, Dillable): :param max_duration: Maximum total number of audio seconds in a mini-batch (dynamic batch size). :param max_cuts: Maximum number of examples in a mini-batch (static batch size). :param num_buckets: If set, enables bucketing (each mini-batch has examples of a similar duration). + :param duration_bins: A list of floats (seconds); when provided, we'll skip the initial + estimation of bucket duration bins (useful to speed-up the launching of experiments). :param quadratic_duration: If set, adds a penalty term for longer duration cuts. Works well with models that have quadratic time complexity to keep GPU utilization similar when using bucketing. Suggested values are between 30 and 45. @@ -102,6 +114,7 @@ def __init__( max_duration: Optional[Seconds] = None, max_cuts: Optional[int] = None, num_buckets: Optional[int] = None, + duration_bins: List[Seconds] = None, quadratic_duration: Optional[Seconds] = None, ) -> None: super().__init__(data_source=None) @@ -146,6 +159,7 @@ def __init__( self.max_duration = max_duration self.max_cuts = max_cuts self.num_buckets = num_buckets + self.duration_bins = duration_bins self.quadratic_duration = quadratic_duration self.base_seed = base_seed assert any( @@ -216,12 +230,13 @@ def _inner(): yield cut n += 1 - if self.num_buckets is not None and self.num_buckets > 1: + if self.num_buckets is not None or self.duration_bins is not None: inner_sampler = DynamicBucketingSampler( _inner(), max_duration=self.max_duration, max_cuts=self.max_cuts, num_buckets=self.num_buckets, + duration_bins=self.duration_bins, shuffle=False, drop_last=False, quadratic_duration=self.quadratic_duration, diff --git a/test/audio/test_audio_reads.py b/test/audio/test_audio_reads.py index 2a4b9c7ed..d841ca9f9 100644 --- a/test/audio/test_audio_reads.py +++ b/test/audio/test_audio_reads.py @@ -2,6 +2,7 @@ from io import BytesIO from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory +from unittest.mock import Mock import numpy as np import pytest @@ -10,6 +11,7 @@ import lhotse from lhotse import AudioSource, Recording +from lhotse.audio import suppress_audio_loading_errors from lhotse.audio.backend import ( info, read_opus_ffmpeg, @@ -260,3 +262,27 @@ def test_set_audio_backend(): ) audio2 = recording.load_audio() np.testing.assert_array_almost_equal(audio1, audio2) + + +def test_fault_tolerant_audio_network_exception(): + def _mock_load_audio(*args, **kwargs): + raise ConnectionResetError() + + source = Mock() + source.load_audio = _mock_load_audio + source.has_video = False + + recording = Recording( + id="irrelevant", + sources=[source], + sampling_rate=16000, + num_samples=16000, + duration=1.0, + channel_ids=[0], + ) + + with pytest.raises(ConnectionResetError): + recording.load_audio() # does raise + + with suppress_audio_loading_errors(True): + recording.load_audio() # is silently caught diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index e7d2db019..94bb40cc7 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -53,7 +53,7 @@ def test_dynamic_bucketing_drop_last_false(): rng = random.Random(0) sampler = DynamicBucketer( - cuts, duration_bins=[2], max_duration=5, rng=rng, world_size=1 + cuts, duration_bins=[1.5], max_duration=5, rng=rng, world_size=1 ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -90,7 +90,7 @@ def test_dynamic_bucketing_drop_last_true(): rng = random.Random(0) sampler = DynamicBucketer( - cuts, duration_bins=[2], max_duration=5, rng=rng, drop_last=True, world_size=1 + cuts, duration_bins=[1.5], max_duration=5, rng=rng, drop_last=True, world_size=1 ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -125,7 +125,7 @@ def test_dynamic_bucketing_sampler(concurrent): c.duration = 2 sampler = DynamicBucketingSampler( - cuts, max_duration=5, num_buckets=2, seed=0, concurrent=concurrent + cuts, max_duration=5, duration_bins=[1.5], seed=0, concurrent=concurrent ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -231,7 +231,9 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled(): c.duration = 2 # 10 cuts with 30s total are not enough to satisfy max_duration of 100 with 2 buckets - sampler = DynamicBucketingSampler(cuts, max_duration=100, num_buckets=2, seed=0) + sampler = DynamicBucketingSampler( + cuts, max_duration=100, duration_bins=[1.5], seed=0 + ) batches = [b for b in sampler] sampled_cuts = [c for b in batches for c in b] @@ -249,6 +251,35 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled(): assert len(b) == 5 +def test_dynamic_bucketing_sampler_much_less_data_than_ddp_ranks(): + world_size = 128 + orig_cut = dummy_cut(0) + cuts = CutSet([orig_cut]) + samplers = [ + DynamicBucketingSampler( + cuts, + max_duration=2000.0, + duration_bins=[1.5, 3.7, 15.2, 27.9, 40.0], + drop_last=False, + concurrent=False, + world_size=world_size, + rank=i, + ) + for i in range(world_size) + ] + # None of the ranks drops anything, all of them return the one cut we have. + for sampler in samplers: + (batch,) = [b for b in sampler] + assert len(batch) == 1 + (sampled_cut,) = batch + assert ( + sampled_cut.id[: len(orig_cut.id)] == orig_cut.id + ) # same stem, possibly added '_dupX' suffix + # otherwise the cuts are identical + sampled_cut.id = orig_cut.id + assert sampled_cut == orig_cut + + def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches(): cuts = DummyManifest(CutSet, begin_id=0, end_id=10) for i, c in enumerate(cuts): @@ -337,7 +368,9 @@ def test_dynamic_bucketing_sampler_cut_pairs(): else: c.duration = 2 - sampler = DynamicBucketingSampler(cuts, cuts, max_duration=5, num_buckets=2, seed=0) + sampler = DynamicBucketingSampler( + cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0 + ) batches = [b for b in sampler] sampled_cut_pairs = [cut_pair for b in batches for cut_pair in zip(*b)] source_cuts = [sc for sc, tc in sampled_cut_pairs] @@ -473,7 +506,7 @@ def test_dynamic_bucketing_sampler_cut_triplets(): c.duration = 2 sampler = DynamicBucketingSampler( - cuts, cuts, cuts, max_duration=5, num_buckets=2, seed=0 + cuts, cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0 ) batches = [b for b in sampler] sampled_cut_triplets = [cut_triplet for b in batches for cut_triplet in zip(*b)] @@ -542,7 +575,7 @@ def test_dynamic_bucketing_quadratic_duration(): # quadratic_duration=30 sampler = DynamicBucketingSampler( - cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=30 + cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=30 ) batches = [b for b in sampler] assert len(batches) == 6 @@ -556,7 +589,7 @@ def test_dynamic_bucketing_quadratic_duration(): # quadratic_duration=None (disabled) sampler = DynamicBucketingSampler( - cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=None + cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=None ) batches = [b for b in sampler] assert len(batches) == 4 @@ -731,3 +764,31 @@ def test_dynamic_bucketing_sampler_fixed_batch_constraint(): assert len(batches[7]) == 1 assert sum(c.duration for c in batches[7]) == 1 + + +def test_select_bucket_includes_upper_bound_in_bin(): + constraint = FixedBucketBatchSizeConstraint( + max_seq_len_buckets=[2.0, 4.0], batch_sizes=[2, 1] + ) + + # within bounds + assert ( + constraint.select_bucket(constraint.max_seq_len_buckets, example_len=1.0) == 0 + ) + assert ( + constraint.select_bucket(constraint.max_seq_len_buckets, example_len=2.0) == 0 + ) + assert ( + constraint.select_bucket(constraint.max_seq_len_buckets, example_len=3.0) == 1 + ) + assert ( + constraint.select_bucket(constraint.max_seq_len_buckets, example_len=4.0) == 1 + ) + constraint.add(dummy_cut(0, duration=4.0)) # can add max duration without exception + + # out of bounds + assert ( + constraint.select_bucket(constraint.max_seq_len_buckets, example_len=5.0) == 2 + ) + with pytest.raises(AssertionError): + constraint.add(dummy_cut(0, duration=5.0)) diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py index 74736794e..686b472b9 100644 --- a/test/dataset/sampling/test_sampling.py +++ b/test/dataset/sampling/test_sampling.py @@ -1232,3 +1232,27 @@ def test_sampler_map(): b = batches[1] assert len(b) == 1 assert b[0].duration == 5.0 + + +def test_sampler_much_less_data_than_ddp_ranks(): + world_size = 128 + orig_cut = dummy_cut(0) + cuts = CutSet([orig_cut]) + + samplers = [ + DynamicCutSampler( + cuts, max_cuts=256, drop_last=False, world_size=world_size, rank=i + ) + for i in range(world_size) + ] + # None of the ranks drops anything, all of them return the one cut we have. + for sampler in samplers: + (batch,) = [b for b in sampler] + assert len(batch) == 1 + (sampled_cut,) = batch + assert ( + sampled_cut.id[: len(orig_cut.id)] == orig_cut.id + ) # same stem, possibly added '_dupX' suffix + # otherwise the cuts are identical + sampled_cut.id = orig_cut.id + assert sampled_cut == orig_cut diff --git a/test/dataset/sampling/test_stateless_sampler.py b/test/dataset/sampling/test_stateless_sampler.py index 416e7e2b2..f0e32cbc3 100644 --- a/test/dataset/sampling/test_stateless_sampler.py +++ b/test/dataset/sampling/test_stateless_sampler.py @@ -189,7 +189,11 @@ def test_stateless_sampler_in_dataloader_with_iterable_dataset( def test_stateless_sampler_bucketing(cuts_files: Tuple[Path]): index_path = cuts_files[0].parent / "cuts.idx" sampler = StatelessSampler( - cuts_files, index_path=index_path, num_buckets=2, max_duration=4, base_seed=0 + cuts_files, + index_path=index_path, + duration_bins=[1.5], + max_duration=4, + base_seed=0, ) for idx, batch in enumerate(sampler): From 4ca97dc3e5da0a8ad96f3811c387b62b90d62422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 11:26:54 -0400 Subject: [PATCH 55/69] Bump dev version to 1.28.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 5db08bf2d..cfc730712 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.27.0 +1.28.0 From bc2c0a294b1437b90d1581d4f214348d2f8bfc12 Mon Sep 17 00:00:00 2001 From: jianyou Date: Tue, 17 Sep 2024 22:25:12 +0800 Subject: [PATCH 56/69] [spgispeech] Fix durations object is null issue (#1390) [spgispeech] Fix durations are null issue --- lhotse/recipes/spgispeech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/recipes/spgispeech.py b/lhotse/recipes/spgispeech.py index 7c679a172..a3002c0a7 100644 --- a/lhotse/recipes/spgispeech.py +++ b/lhotse/recipes/spgispeech.py @@ -121,7 +121,6 @@ def prepare_spgispeech( def audio_read_worker(p: Path) -> Recording: r = Recording.from_file(p, recording_id=f"{p.parent.stem}_{p.stem}") - durations[r.id] = r.duration return r with RecordingSet.open_writer( @@ -135,6 +134,7 @@ def audio_read_worker(p: Path) -> Recording: ), desc="Processing SPGISpeech recordings", ): + durations[recording.id] = recording.duration rec_writer.write(recording) # Read supervisions and write them to manifest From a31a5322a0d60ac6cc70060e5023865f6debeac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E9=9C=87=E4=B8=9C?= Date: Wed, 2 Oct 2024 06:36:49 +0800 Subject: [PATCH 57/69] Fix backend to None while ffmpeg is unavailable. (#1392) --- lhotse/audio/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lhotse/audio/backend.py b/lhotse/audio/backend.py index 97efacb29..278f2b00b 100644 --- a/lhotse/audio/backend.py +++ b/lhotse/audio/backend.py @@ -808,7 +808,8 @@ def torchaudio_info( if torchaudio_ffmpeg_backend_available(): # Torchaudio 2.1 with official "ffmpeg" backend should solve all the special cases below. - info = torchaudio.info(path_or_fileobj, backend="ffmpeg") + backend = "ffmpeg" if "ffmpeg" in torchaudio.list_audio_backends() else None + info = torchaudio.info(path_or_fileobj, backend=backend) return LibsndfileCompatibleAudioInfo( channels=info.num_channels, frames=info.num_frames, From 82b313f5cbce3ac05198e74f4ce375032211f70a Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 2 Oct 2024 06:37:34 +0800 Subject: [PATCH 58/69] Fix ksponspeech recipe (#1394) * fix ksponspeech.py * fix black --- lhotse/recipes/ksponspeech.py | 45 +++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/lhotse/recipes/ksponspeech.py b/lhotse/recipes/ksponspeech.py index 6dde6ed9b..381dc8c47 100644 --- a/lhotse/recipes/ksponspeech.py +++ b/lhotse/recipes/ksponspeech.py @@ -1,16 +1,16 @@ """ KsponSpeech is a large-scale spontaneous speech corpus of Korean. -This corpus contains 969 hours of open-domain dialog utterances, +This corpus contains 969 hours of open-domain dialogue utterances, spoken by about 2,000 native Korean speakers in a clean environment. -All data were constructed by recording the dialogue of two people +All data were constructed by recording the dialogue between two people freely conversing on a variety of topics and manually transcribing the utterances. The transcription provides a dual transcription consisting of orthography and pronunciation, -and disfluency tags for spontaneity of speech, such as filler words, repeated words, and word fragments. +and disfluency tags for the spontaneity of speech, such as filler words, repeated words, and word fragments. -The original audio data has a pcm extension. -During preprocessing, it is converted into a file in the flac extension and saved anew. +The original audio data has a PCM extension. +During preprocessing, it is converted into a file in the FLAC extension and saved anew. KsponSpeech is publicly available on an open data hub site of the Korea government. The dataset must be downloaded manually. @@ -52,14 +52,14 @@ def normalize( Normalizing KsponSpeech text datasets with '.trn' extension. Perform the following processing. - 1. Separate file name and text labeling from raw content using separator '::'. - 2. Remove noise labeling characters. (e.g. `o/`, `b/`...) - 3. Remove the actual pronunciation from the text labeling, Use the spelling content. + 1. Separate file name and text labeling from raw content using separator ' :: '; + 2. Remove noise labeling characters (e.g. `o/`, `b/`...); + 3. Remove the actual pronunciation from the text labeling; use the spelling content; 4. Remove other special characters and double spaces from text labeling. - :param raw_content: A raw text labeling content containing file name and text labeling. - :param normalize_text: str, the text normalization type. Available options: "default", "none". - :return: A tuple with file name and normalized text labeling. + :param raw_content: a raw text labeling content containing file name and text labeling. + :param normalize_text: str, the text normalization type, "default" or "none". + :return: a tuple with file name and normalized text labeling. """ if len(raw_content) == 0: return "" @@ -75,8 +75,7 @@ def normalize( content = content.replace("*", "") content = content.replace("+", "") content = content.replace("/", "") - while " " in content: - content = content.replace(" ", " ") + content = re.sub(r"\s+", " ", content) return original_content_id, content.strip() @@ -93,11 +92,11 @@ def prepare_ksponspeech( When all the manifests are available in the ``output_dir``, it will simply read and return them. :param corpus_dir: Pathlike, the path of the data dir. - :param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'train', 'test'. - By default we will infer which parts are available in ``corpus_dir``. + :param dataset_parts: string or sequence of strings representing dataset part names, e.g. 'train', 'dev'. + By default, we will infer all parts. :param output_dir: Pathlike, the path where to write the manifests. :param num_jobs: int, number of parallel threads used for 'parse_utterance' calls. - :param normalize_text: str, the text normalization type. Available options: "default", "none". + :param normalize_text: str, the text normalization type, "default" or "none". :return: a Dict whose key is the dataset part, and the value is Dicts with the keys 'audio' and 'supervisions'. """ corpus_dir = Path(corpus_dir) @@ -116,15 +115,25 @@ def prepare_ksponspeech( output_dir.mkdir(parents=True, exist_ok=True) # Maybe the manifests already exist: we can read them and save a bit of preparation time. manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, output_dir=output_dir + dataset_parts=dataset_parts, + output_dir=output_dir, + prefix="ksponspeech", + suffix="jsonl.gz", + lazy=True, ) with ThreadPoolExecutor(num_jobs) as ex: for part in tqdm(dataset_parts, desc="Dataset parts"): logging.info(f"Processing KsponSpeech subset: {part}") - if manifests_exist(part=part, output_dir=output_dir): + if manifests_exist( + part=part, + output_dir=output_dir, + prefix="ksponspeech", + suffix="jsonl.gz", + ): logging.info(f"KsponSpeech subset: {part} already prepared - skipping.") continue + recordings = [] supervisions = [] futures = [] From c8ba6d019ddb1c6c79c463685f34020b4a122bad Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 2 Oct 2024 06:38:07 +0800 Subject: [PATCH 59/69] Fix cli for ksponspeech (#1393) fix ksponspeech.py --- lhotse/bin/modes/recipes/ksponspeech.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lhotse/bin/modes/recipes/ksponspeech.py b/lhotse/bin/modes/recipes/ksponspeech.py index 4f4a9d5bd..c90272f70 100644 --- a/lhotse/bin/modes/recipes/ksponspeech.py +++ b/lhotse/bin/modes/recipes/ksponspeech.py @@ -39,6 +39,7 @@ def ksponspeech( output_dir: Pathlike, dataset_parts: Sequence[str], num_jobs: int, + normalize_text: str, ): """KsponSpeech ASR data preparation.""" if len(dataset_parts) == 1: @@ -48,4 +49,5 @@ def ksponspeech( output_dir=output_dir, num_jobs=num_jobs, dataset_parts=dataset_parts, + normalize_text=normalize_text, ) From d1b078b99ca14a4aff94d0d6ea3100486bed2ea1 Mon Sep 17 00:00:00 2001 From: Matthew Maciejewski Date: Fri, 4 Oct 2024 20:51:08 +0300 Subject: [PATCH 60/69] Add recipe for the Santa Barbara Corpus of Spoken American English (SBCSAE) (#1395) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial commit * transcript fixes * added SBCSAE download * Updates sbcsae to properly process mono_channel audio and adds speaker origin as geolocations for speakers * Fixes a few 0-width segments by adding 0.02 s of padding * small fix * Add alignment export option Exports aligned supervisions along with the original supervisions with or without changing the text after manual inspections and corrections. * update to cli flags and docs * added sbcsae to docs and fixed python compatibility * more python3.8 fixes --------- Co-authored-by: Matthew Wiesner Co-authored-by: Dominik Klement Co-authored-by: Piotr Żelasko --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/sbcsae.py | 58 ++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/sbcsae.py | 1155 ++++++++++++++++++++++++++ 5 files changed, 1217 insertions(+) create mode 100644 lhotse/bin/modes/recipes/sbcsae.py create mode 100644 lhotse/recipes/sbcsae.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 6a5be4f97..bc8d71bfb 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -173,6 +173,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_reazonspeech` * - RIRs and Noises Corpus (OpenSLR 28) - :func:`lhotse.recipes.prepare_rir_noise` + * - SBCSAE + - :func:`lhotse.recipes.prepare_sbcsae` * - Spatial-LibriSpeech - :func:`lhotse.recipes.prepare_spatial_librispeech` * - Speech Commands diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index d0ddd4c84..caa904cb8 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -67,6 +67,7 @@ from .primewords import * from .reazonspeech import * from .rir_noise import * +from .sbcsae import * from .slu import * from .spatial_librispeech import * from .speechcommands import * diff --git a/lhotse/bin/modes/recipes/sbcsae.py b/lhotse/bin/modes/recipes/sbcsae.py new file mode 100644 index 000000000..d6eece516 --- /dev/null +++ b/lhotse/bin/modes/recipes/sbcsae.py @@ -0,0 +1,58 @@ +from typing import Optional, Sequence + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.sbcsae import download_sbcsae, prepare_sbcsae +from lhotse.utils import Pathlike + +__all__ = ["sbcsae"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "--geolocation", + type=bool, + is_flag=True, + default=False, + help="Include geographic coordinates of speakers' hometowns in the manifests.", +) +@click.option( + "--omit-realignments", + type=bool, + is_flag=True, + default=False, + help="Only output the original corpus segmentation without boundary improvements.", +) +def sbcsae( + corpus_dir: Pathlike, + output_dir: Pathlike, + geolocation: bool, + omit_realignments: bool, +): + """SBCSAE data preparation.""" + prepare_sbcsae( + corpus_dir, + output_dir=output_dir, + geolocation=geolocation, + omit_realignments=omit_realignments, + ) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +@click.option( + "--force-download", + type=bool, + is_flag=True, + default=False, + help="Force download.", +) +def sbcsae( + target_dir: Pathlike, + force_download: bool, +): + """SBCSAE download.""" + download_sbcsae(target_dir, force_download=force_download) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index b0c909d99..1eef6623b 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -68,6 +68,7 @@ from .peoples_speech import prepare_peoples_speech from .reazonspeech import download_reazonspeech, prepare_reazonspeech from .rir_noise import download_rir_noise, prepare_rir_noise +from .sbcsae import download_sbcsae, prepare_sbcsae from .slu import prepare_slu from .spatial_librispeech import ( download_spatial_librispeech, diff --git a/lhotse/recipes/sbcsae.py b/lhotse/recipes/sbcsae.py new file mode 100644 index 000000000..4927548ba --- /dev/null +++ b/lhotse/recipes/sbcsae.py @@ -0,0 +1,1155 @@ +""" +This script downloads and prepares the data directory for the Santa Barbara +Corpus of Spoken American English. + +The Santa Barbara Corpus of Spoken American English is based on a large body of +recordings of naturally occurring spoken interaction from all over the United +States. The Santa Barbara Corpus represents a wide variety of people of +different regional origins, ages, occupations, genders, and ethnic and social +backgrounds. The predominant form of language use represented is face-to-face +conversation, but the corpus also documents many other ways that that people use +language in their everyday lives: telephone conversations, card games, food +preparation, on-the-job talk, classroom lectures, sermons, story-telling, town +hall meetings, tour-guide spiels, and more. + +The Santa Barbara Corpus was compiled by researchers in the Linguistics +Department of the University of California, Santa Barbara. The Director of the +Santa Barbara Corpus is John W. Du Bois, working with Associate Editors Wallace +L. Chafe and Sandra A. Thompson (all of UC Santa Barbara), and Charles Meyer +(UMass, Boston). For the publication of Parts 3 and 4, the authors are John W. +Du Bois and Robert Englebretson. + +If you use the corpus or our data preparation scripts, please cite the following: +@misc{dubois_2005, + author={Du Bois, John W. and Chafe, Wallace L. and Meyer, Charles and Thompson, Sandra A. and Englebretson, Robert and Martey, Nii}, + year={2000--2005}, + title={{S}anta {B}arbara corpus of spoken {A}merican {E}nglish, {P}arts 1--4}, + address={Philadelphia}, + organization={Linguistic Data Consortium}, +} +@inproceedings{maciejewski24_interspeech, + author={Matthew Maciejewski and Dominik Klement and Ruizhe Huang and Matthew Wiesner and Sanjeev Khudanpur}, + title={Evaluating the {Santa Barbara} Corpus: Challenges of the Breadth of Conversational Spoken Language}, + year=2024, + booktitle={Proc. Interspeech 2024} +} +""" +import logging +import re +import tarfile +from copy import deepcopy +from dataclasses import dataclass +from math import inf +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +from tqdm import tqdm + +from lhotse import ( + Recording, + RecordingSet, + SupervisionSegment, + SupervisionSet, + fix_manifests, +) +from lhotse.utils import ( + Pathlike, + fastcopy, + is_module_available, + resumable_download, + safe_extract, +) + +SBCSAE_TAR_URL = "https://www.openslr.org/resources/155/SBCSAE.tar.gz" + + +lang_iterators = { + "SBC004": iter(["Spanish"] * 17), + "SBC006": iter(["French"] * 2), + "SBC010": iter(["Spanish"]), + "SBC012": iter(["Greek"] * 2), + "SBC015": iter(["Spanish"] * 10), + "SBC025": iter(["German"] * 2 + ["Latin"]), + "SBC027": iter(["Spanish"] * 6 + ["French"] * 2), + "SBC031": iter(["French"] * 2), + "SBC033": iter(["French"]), + "SBC034": iter(["French"] * 3), + "SBC036": iter(["Spanish"] * 36), + "SBC037": iter(["Spanish"] * 60), + "SBC047": iter(["Spanish"]), + "SBC057": iter(["Japanese"] * 62), + "SBC058": iter(["Spanish"] + ["Italian"] * 2), +} + + +# These corrections to the participant metadata were needed to get geolocations +# from the geopy package. +annotation_corrections = { + "metro St.L. IL": "Saint Louis MO", # Use the MO side of the city + "middle Wes MO": "Missouri", # Just use the state location + "S.E.Texas TX": "South East Texas", # The geo package seems to parse this + "South Alabama mostly AL": "Andalusia Alabama", # Arbitrarily chosen nearby town + "South FL": "South Bay Florida", # Arbitrarily chosen nearby town + "Walnut Cre CA": "Walnut Creek CA", # Spelling error + "San Leandr CA": "San Leandro CA", + "Boston/Santa Fe MA/NM": "Boston/Santa Fe\tMA/NM", # Handle this specially + "Boston/New Mexico MA/NM": "Boston/Santa Fe\tMA/NM", + "Millstad IL": "Millstadt IL", # Spelling error + "Cleveland/San Francisco OH/CA": "Cleveland/San Fransisco\tOH/CA", # Handle specially + "Jamesville WI": "Janesville WI", # Spelling error + "Falls Church/Albuquerque VA/NM": "Falls Church/Albuquerque\tVA/NM", # Handle specially + "Southern Florida": "South Bay Florida", # Arbitarily chosen nearby town + "Massachusetts MA": "Massachusetts", + "New Zealand n/a": "New Zealand", + "French n/a": "France", +} + + +bad_stereo = ["SBC020", "SBC021", "SBC027", "SBC028"] + + +class Dummy_Spk_Iterator: + def __init__(self): + self.ind = 213 + + def next(self, spk="SBCXXX_X"): + self.ind = self.ind + 1 + name = "_".join(spk.split("_")[1:]) + if name.startswith("X") or name.startswith("AUD"): + name = "UNK" + return f"{self.ind:04d}_{name}" + + +dummy_spk_iterator = Dummy_Spk_Iterator() + + +def download_sbcsae( + target_dir: Pathlike = ".", + force_download: Optional[bool] = False, +) -> Path: + """ + Download and untar the dataset. + + :param: target_dir: Pathlike, the path of the directory where the SBCSAE + dataset will be downloaded. + :param force_download: bool, if True, download the archive even if it already exists. + :return: The path to the directory with the data. + """ + target_dir = Path(target_dir) + corpus_dir = target_dir / "SBCSAE" + corpus_dir.mkdir(parents=True, exist_ok=True) + tar_path = target_dir / "SBCSAE.tar.gz" + + completed_detector = target_dir / ".sbcsae_completed" + if completed_detector.is_file(): + logging.info(f"Skipping download because {completed_detector} exists.") + return corpus_dir + + resumable_download(SBCSAE_TAR_URL, filename=tar_path, force_download=force_download) + with tarfile.open(tar_path) as tar: + safe_extract(tar, path=corpus_dir) + completed_detector.touch() + + return corpus_dir + + +def prepare_sbcsae( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + geolocation: Optional[bool] = False, + omit_realignments: Optional[bool] = False, +) -> Dict[str, Union[RecordingSet, SupervisionSet]]: + """ + Prepares manifest for SBCSAE dataset. + + :param: corpus_dir: Path to the root where SBCSAE data was downloaded. It + should be called SBCSAE. There is no consistent formatting between + releases of the data. Check script comments for details if using an + existing corpus download rather than Lhotse's download script. + :param: output_dir: Root directory where .json manifests are stored. + :param: geolocation: Include geographic coordinates of speakers' hometowns + in the manifests. + :param: omit_realignments: Only output original corpus segmentation. + :return: The manifests. + """ + # Resolve corpus_dir type + if isinstance(corpus_dir, str): + corpus_dir = Path(corpus_dir) + + # Resolve output_dir type + if isinstance(output_dir, str): + output_dir = Path(output_dir) + + audio_dir = corpus_dir / "WAV" + recordings = RecordingSet.from_recordings( + Recording.from_file(p) for p in audio_dir.glob("*.wav") + ) + if len(recordings) == 0: + logging.warning(f"No .wav files found in {audio_dir}") + + doc_dir = corpus_dir / "docs" + spk2gen_dict, spk2glob_dict = generate_speaker_map_dicts(doc_dir) + + spk_coords = {} + if geolocation: + spk_coords = generate_geolocations(corpus_dir, spk2glob_dict) + + supervisions = [] + trn_dir = corpus_dir / "TRN" + for p in tqdm( + list(trn_dir.glob("*.trn")), "Collecting and normalizing transcripts ..." + ): + for supervision in _filename_to_supervisions(p, spk2gen_dict, spk2glob_dict): + supervisions.append(supervision) + + if len(supervisions) == 0: + logging.warning(f"No supervisions found in {trn_dir}") + + supervisions_ = [] + for s in supervisions: + if s.duration < 0.02: + # Just pad with a minimum 0.02 duration + s_reco = recordings[s.recording_id] + new_start = max(0, s.start - 0.01) + s_ = fastcopy( + s, + start=new_start, + duration=min(new_start + 0.02, s_reco.duration), + ) + else: + s_ = s + + if s_.speaker in spk_coords: + s_.custom = { + "lat": spk_coords[s.speaker][0][0], + "lon": spk_coords[s.speaker][0][1], + } + + if ( + not isinstance(recordings[s.recording_id].channel_ids, list) + or len(recordings[s.recording_id].channel_ids) < 2 + or s.recording_id in bad_stereo + ): + s_.channel = recordings[s.recording_id].channel_ids[0] + supervisions_.append(s_) + + supervisions = SupervisionSet.from_segments(supervisions_) + recordings, supervisions = fix_manifests(recordings, supervisions) + + if output_dir is not None: + if isinstance(output_dir, str): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + recordings.to_file(output_dir / "sbcsae_recordings.jsonl.gz") + supervisions.to_file(output_dir / "sbcsae_supervisions.jsonl.gz") + + manifests = {"recordings": recordings, "supervisions": supervisions} + + if not omit_realignments: + asr_supervisions, diar_supervisions = apply_aligned_stms( + list(recordings.ids), supervisions + ) + _, asr_supervisions = fix_manifests(recordings, asr_supervisions) + _, diar_supervisions = fix_manifests(recordings, diar_supervisions) + + asr_supervisions.to_file( + output_dir / "sbcsae_supervisions_asr_aligned.jsonl.gz" + ) + diar_supervisions.to_file( + output_dir / "sbcsae_supervisions_diar_aligned.jsonl.gz" + ) + + manifests = { + "asr_supervisions": asr_supervisions, + "diar_supervisions": diar_supervisions, + **manifests, + } + + return manifests + + +def generate_geolocations(corpus: Path, spk2glob_dict: dict): + if not is_module_available("geopy"): + raise ImportError( + "geopy package not found. Please install..." " (pip install geopy)" + ) + else: + from geopy import geocoders + from geopy.geocoders import Nominatim + + speakers = corpus.rglob("docs/Part_*/speaker.tbl") + # This geolocator object is repsonsible for generating a + # latitiude and longitude from a textual description of a location, i.e., + # CHICAGO IL --> (41,-87) + geolocator = Nominatim(user_agent="myapplication") + spk_coords = {} + for spk in tqdm(list(speakers), "Generating speaker geolocations..."): + with open(spk) as f: + for l in f: + vals = l.strip().split(",") + if len(vals) < 5: + continue + # Check non-empty + empty_hometown = vals[4] in ("", "?") + empty_state = vals[5] in ("", "?") + if empty_hometown and not empty_state: + loc = vals[5] + ", United States" + elif not empty_hometown: + orig_loc = vals[4] + " " + vals[5] + loc = annotation_corrections.get(orig_loc, orig_loc) + else: + continue + if "/" in loc: + try: + hometowns, states = loc.split("\t", 1) + hometowns = hometowns.split("/") + states = states.split("/") + coords = [] + for h, s in zip(hometowns, states): + coords.append( + geolocator.geocode(f"{h} {s}", timeout=None)[1] + ) + except ValueError: + states, country = loc.split(",", 1) + coords = [] + for s in states.split("/"): + coords.append( + geolocator.geocode(f"{s}, {country}", timeout=None)[1] + ) + else: + coords = [geolocator.geocode(loc, timeout=None)[1]] + spk_coords[vals[0]] = coords + spknum2spk_name = {n.split("_")[0]: n for s, n in spk2glob_dict.items()} + spk_coords_ = {} + for s in spk_coords: + if s in spknum2spk_name: + spk_coords_[spknum2spk_name[s]] = spk_coords[s] + return spk_coords_ + + +def generate_speaker_map_dicts(doc_dir: Path): + spk2gen_dict = dict() + spk2glob_dict = dict() + + spk_num_to_reco_ids = dict() + for part in ["Part_1", "Part_2", "Part_4"]: + filename = doc_dir / part / "segment.tbl" + for line in filename.read_text().split("\n"): + if "speaker:" in line: + line = line.replace(" 0", "\t0") + reco_id = re.sub(r"sbc0?([0-9]{3})\s.*", r"SBC\1", line) + spk_num = line.split("\t")[-1][:4] + if spk_num not in spk_num_to_reco_ids: + spk_num_to_reco_ids[spk_num] = [] + if reco_id not in spk_num_to_reco_ids[spk_num]: + spk_num_to_reco_ids[spk_num].append(reco_id) + + for part in ["Part_1", "Part_2", "Part_4"]: + filename = doc_dir / part / "speaker.tbl" + for line in filename.read_text().split("\n"): + if "," not in line: + continue + line = line.replace("0163,Dan,m", "0166,Dan,M") + spk_num, name, gen = line.split(",")[:3] + name = ( + name.replace(" (extra-corpus)", "").upper().split(" ")[-1].split("/")[0] + ) + gen = gen.upper() + if not gen: + gen = None + + if spk_num in ["0069", "0091", "0092", "0097"]: + continue + for reco in spk_num_to_reco_ids[spk_num]: + spk2gen_dict[reco + "_" + name] = gen + spk2glob_dict[reco + "_" + name] = spk_num + "_" + name + + for part in ["Part_3"]: + seg_list = [] + filename = doc_dir / part / "segment.tbl" + for line in filename.read_text().split("\n"): + if "speaker:" in line: + reco_id = re.sub(r"sbc0?([0-9]{3})\s.*", r"SBC\1", line) + name = line.split(" ")[-1].upper().split("/")[0] + seg_list.append([name, reco_id]) + + spk_list = [] + filename = doc_dir / part / "speaker.tbl" + for line in filename.read_text().split("\n"): + if "," not in line: + continue + spk_num, name, gen = line.split(",")[:3] + name = name.upper().split("/")[0] + spk_list.append([name, spk_num, gen]) + + for seg_info, spk_info in zip(seg_list, spk_list): + assert seg_info[0] == spk_info[0], f"{seg_info[0]} != {spk_info[0]}" + spk2gen_dict[seg_info[1] + "_" + seg_info[0]] = spk_info[2] + spk2glob_dict[seg_info[1] + "_" + seg_info[0]] = ( + spk_info[1] + "_" + spk_info[0] + ) + + for spk_key in [ + "SBC006_ALL", + "SBC008_ALL", + "SBC012_MANY", + "SBC020_AUD", + "SBC021_MANY", + "SBC023_MANY", + "SBC025_AUD", + "SBC026_AUD", + "SBC027_MANY", + "SBC027_AUD", + "SBC028_BOTH", + "SBC030_AUD", + "SBC038_AUD", + "SBC053_RADIO", + "SBC054_AUD", + "SBC054_MANY", + "SBC055_AUD", + ]: + spk2gen_dict[spk_key] = None + spk2glob_dict[spk_key] = spk_key + + return spk2gen_dict, spk2glob_dict + + +def _filename_to_supervisions(filename: Path, spk2gen_dict: dict, spk2glob_dict: dict): + reco_id = filename.stem.split(".")[0] + lines = filename.read_text(encoding="latin1") + supervisions = [] + + #### Transcript fix + lines = lines.replace("\x92", "'") + lines = lines.replace("\u007f", "") + lines = lines.replace("\u0000", "c") + + if reco_id == "SBC002": + lines = lines.replace("(TSK ", "(TSK) ") + elif reco_id == "SBC004": + lines = lines.replace("KATE", "KATHY") + lines = lines.replace("sen~orita", "se\xf1orita") + elif reco_id == "SBC005": + lines = lines.replace("good_/god/", "good") + lines = lines.replace("(H)@>", "(H) @>") + lines = lines.replace("[@@ <@Mm@>]", "[@@ <@ Mm @>]") + elif reco_id == "SBC006": + lines = lines.replace("/pub/", "pub") + lines = lines.replace("", "") + lines = lines.replace("[2(H)2]1", "[2(H)2]") + elif reco_id == "SBC007": + lines = lines.replace( + "\\000000000 000000000 MARY: 1182.90 1186.92\t ", + "\n1182.90 1186.92\tMARY: ", + ) + lines = lines.replace("(YAWN0", "(YAWN)") + elif reco_id == "SBC008": + lines = lines.replace("[", "[") + elif reco_id == "SBC010": + lines = lines.replace("366.87 366.87", "366.16 366.87") + elif reco_id == "SBC012": + lines = lines.replace( + "\n".join(["807.02 807.92\tFRANK: \t.. Mhm."] * 2), + "807.02 807.92\tFRANK: \t.. Mhm.", + ) + lines = lines.replace("MONTOYA", "MONTOYO") + elif reco_id == "SBC013": + lines = lines.replace("[8<@She8]", "[8<@ She8]") + lines = lines.replace("[2(H) cou_ couch@>2]", "[2(H) cou_ couch @>2]") + lines = lines.replace("[4<@No=4]", "[4<@ No=4]") + lines = lines.replace("VOX2]", "VOX>2]") + elif reco_id == "SBC014": + lines = lines.replace("\\000000000 000000000 ", "\n") + lines = lines.replace("<@he thought", "<@ he thought") + elif reco_id == "SBC015": + lines = lines.replace( + "243.055\t244.080\tKEN:\t(H)] the little,", + "243.465\t244.670\tKEN:\t(H)] the little,", + ) + lines = lines.replace("\u0000urch things.", "church things.") + lines = lines.replace("2(H]=2", "2(H)=2") + lines = lines.replace(" 0.000000e+00", "e") + lines = lines.replace("0m=,", "um=,") + lines = lines.replace("0eople", "people") + lines = lines.replace("0id", "did") + lines = lines.replace("X 0ne %tho", "X uh line %tho") + lines = lines.replace("and 0t [was]", "and it [was]") + lines = lines.replace("0t was like", "it was like") + elif reco_id == "SBC016": + lines = lines.replace("/sed ai/", "sed ai") + elif reco_id == "SBC017": + lines = lines.replace("a\tand names the] na=me,", "and names the] na=me,") + lines = lines.replace(" 0.000000e+00", "e") + lines = lines.replace("[2I mean2", "[2I mean2]") + lines = lines.replace("no2.", "no.") + lines = lines.replace("0rganisms", "organisms") + lines = lines.replace("0ttle", "little") + elif reco_id == "SBC018": + lines = lines.replace("0f", "if") + lines = lines.replace( + "129.916\t130.324\tLINDSEY:\tYeah.\n129.915\t130.325\t\t[Mhm.]\n", + "129.915\t130.325\tLINDSEY:\t[Mhm.] Yeah.\n", + ) + elif reco_id == "SBC019": + lines = lines.replace("cello_(/cheller/)", "cheller") + lines = lines.replace("(sigh)", "(SIGH)") + lines = lines.replace(" Mo=m", "]", "[]") + lines = lines.replace("5]", "X>5]") + lines = lines.replace("0nly", "uh only") + lines = lines.replace("[50r5]", "[5Or5]") + elif reco_id == "SBC024": + lines = lines.replace(" >ENV: ", ">ENV:\t") + lines = lines.replace(" 0.000000irst", "First") + lines = lines.replace("2[cause", "[2cause") + lines = lines.replace(" 0oes", "does") + lines = lines.replace("0id]", "did]") + elif reco_id == "SBC025": + lines = lines.replace("", "<@ Oh[2= @>") + lines = lines.replace(" 0.000000", " ") + lines = lines.replace("i 0f", "i- if") + lines = lines.replace("0f we", "if we") + lines = lines.replace("th- 0t's", "th- that's") + lines = lines.replace("0t's", "it's") + lines = lines.replace("0f", "if") + elif reco_id == "SBC029": + lines = lines.replace("96.230\t98.240\t>ENV: ", "96.230\t98.240\t>ENV:\t") + lines = lines.replace("(H )", "(H)") + lines = lines.replace("<0h=,", "<% Oh=,") + lines = lines.replace("knowX>]", "know X>]") + lines = lines.replace("0verheating", "overheating") + elif reco_id == "SBC030": + lines = lines.replace("DANNY", "BRADLEY") + lines = lines.replace("AUD:\tYes", "X:\tYes") + elif reco_id == "SBC034": + lines = lines.replace("13548.02 ", "1354.802") + elif reco_id == "SBC036": + lines = lines.replace( + "1558.463\t1558.906\t\t[thought he was,", + "1558.906\t1558.923\t\t[thought he was,", + ) + elif reco_id == "SBC038": + lines = lines.replace("AUD:\t... What's", "X_2:\t... What's") + lines = lines.replace("AUD:\t... U", "X_3:\t... U") + lines = lines.replace("AUD:\t... How far", "X_2:\t... How far") + lines = lines.replace("AUD:\t", "") + lines = lines.replace("ANNETTE", "ANETTE") + elif reco_id == "SBC048": + lines = lines.replace("<@in San[2ta", "<@ in San[2ta") + elif reco_id == "SBC052": + lines = lines.replace("~Janine\t said", "~Janine said") + elif reco_id == "SBC054": + lines = lines.replace("", "") + lines = lines.replace("AUD:\tX", "X:\tX") + lines = lines.replace("AUD:\t") + lines = lines.replace("sensei", "") + lines = lines.replace("ippon", "Ippon") + lines = lines.replace("Ippon", "") + lines = re.sub(r"gi([^a-z])", r"\1", lines) + lines = re.sub(r"Makikomi([^-])", r"\1", lines) + lines = lines.replace("Hane-goshi", "") + lines = lines.replace("Sode-makikomi", "") + lines = lines.replace("shiai", "") + lines = lines.replace("randori", "") + lines = re.sub(r"Sode([^-])", r"\1", lines) + lines = lines.replace("Ukemi", "") + lines = lines.replace("Ha-jime", "") + lines = lines.replace("Ude-garami", "") + lines = lines.replace("Hane-uchi-mata", "") + lines = lines.replace("Uchi-", "Uchi-mata") + lines = lines.replace("Uchi-mata", "") + lines = lines.replace("Hande-maki- \1", lines) + lines = lines.replace("%Sode-maki[komi]", "") + lines = lines.replace("Tsuri-komi", "") + lines = lines.replace("Uchi-komi", "") + lines = lines.replace("O-uchi", "") + lines = lines.replace("Goshi", "") + lines = lines.replace("Uchi]-mata", "") + lines = lines.replace("Komi", "") + lines = lines.replace("Tani-otoshi", "") + lines = lines.replace("Hane-maki][2komi=", "") + lines = lines.replace("Makikomi-waza", "") + lines = lines.replace("Seoi", "") + lines = lines.replace("uke", "") + elif reco_id == "SBC059": + lines = lines.replace("[]", "hour[6=6] F>") + + spk_buffer = "" + lang_buffer = "English" + for line in lines.split("\n"): + #### Transcript fixes + if line == "77.200\t77.540 :\t(H)": + continue + if line.startswith("000000000 000000000 ") or line.startswith("0.00 0.00"): + continue + if line.startswith("\t"): + line.lstrip("\t") + if "and in his pamphlet the Liber Arbetrio" in line: + continue + + line = line.strip() + line = re.sub(r" +", " ", line) + line = re.sub(r"\t+", "\t", line) + fields = line.strip().split("\t") + if len(fields) == 4: + spk_field, raw_trans = fields[2:] + start, end = [float(time.rstrip()) for time in fields[:2]] + elif len(fields) == 3: + if len(fields[0].rstrip().split(" ")) > 1: + spk_field, raw_trans = fields[1:] + start, end = [float(time) for time in fields[0].split(" ")[:2]] + raw_trans = fields[-1] + else: + start, end = [float(time.rstrip()) for time in fields[:2]] + spk_field_candidate = fields[2].split(" ")[0] + if re.fullmatch(r"[A-Z]+:", spk_field_candidate): + spk_field = spk_field_candidate + raw_trans = " ".join(fields[2].split(" ")[1:]) + else: + spk_field = "" + raw_trans = fields[2] + elif len(fields) == 2: + timesish = fields[0].rstrip().split(" ") + if len(timesish) == 1: + continue + start, end = [float(time) for time in timesish[:2]] + if len(timesish) > 2: + spk_field = timesish[2] + raw_trans = fields[1] + else: + spk_field_candidate = fields[1].split(" ")[0] + if re.fullmatch(r"[A-Z]+:", spk_field_candidate): + spk_field = spk_field_candidate + raw_trans = " ".join(fields[1].split(" ")[1:]) + else: + spk_field = "" + raw_trans = fields[1] + else: + split = line.split(" ") + if re.fullmatch(r"[0-9]+\.[0-9]+", split[0]) and re.fullmatch( + r"[0-9]+\.[0-9]+", split[1] + ): + start, end = [float(time.rstrip()) for time in split[:2]] + if re.fullmatch(r"[A-Z]+:", split[2]): + spk_field = split[2] + raw_trans = " ".join(split[3:]) + else: + spk_field = "" + raw_trans = " ".join(split[2:]) + else: + continue + + #### Transcript fixes + if raw_trans == "[2ENV", "ENV", ">MAC", ">DOG", ">HORSE", ">CAT", ">BABY"]: + continue + elif spk_field == "#READ": + spk_field = "WALT" + + if spk_field: + spk_field = re.sub(r"^[^A-Z]", "", spk_field) + spk_buffer = spk_field + + utt_id = f"{reco_id}_{int(start*1000):07}_{int(end*1000):07}_{spk_buffer}" + + text, lang_tag = _parse_raw_transcript(raw_trans) + + if "l" in lang_tag: + for _ in range(lang_tag.count("l")): + new_lang = next(lang_iterators[reco_id]) + if "c" in lang_tag: + lang_buffer = f"English-{new_lang}" + else: + lang_buffer = new_lang + elif "c" in lang_tag: + lang_buffer = f"English-{lang_buffer.split('-')[-1]}" + + spk_key = reco_id + "_" + spk_buffer + if spk_key not in spk2glob_dict and reco_id != "SBC021": + spk2gen_dict[spk_key] = None + spk2glob_dict[spk_key] = dummy_spk_iterator.next(spk_key) + + if spk_key in spk2glob_dict: + speaker = spk2glob_dict[spk_key] + gender = spk2gen_dict[spk_key] + else: + speaker = dummy_spk_iterator.next(spk_key) + gender = None + + if re.search(r"[A-Za-z]", text): + supervisions.append( + SupervisionSegment( + id=utt_id, + recording_id=reco_id, + start=start, + duration=end - start, + channel=[0, 1], + text=text, + language=lang_buffer, + speaker=speaker, + gender=gender, + ) + ) + + if lang_tag: + if lang_tag[-1] == "r": + lang_buffer = "English" + if lang_tag[-1] == "l": + lang_buffer = lang_buffer.split("-")[-1] + + return supervisions + + +def _parse_raw_transcript(transcript: str): + + transcript = transcript.replace("0h", "oh") + transcript = transcript.replace("s@so", "s- so") + transcript = transcript.replace("la@ter", "later") + transcript = transcript.replace("you@.", "you @.") + transcript = transcript.replace("[N=]", "N") + transcript = transcript.replace("[2C2]=", "C") + transcript = transcript.replace("[MM=]", "MM") + transcript = transcript.replace("[I=]", "I") + + transcript = transcript.replace("(YELL)", "") + + transcript = transcript.replace("_", "-") + + transcript = transcript.replace("=", "") + transcript = transcript.replace("%", "") + + # Process overlapped UNKs before they get removed by the following step + transcript = re.sub(r"\[([2-9]?)([A-Z])+\1\]", r"\2", transcript) + + # Paired parenthetical/bracket annotation remover + paren_matches = re.findall(r"\([^a-z@ ]*\)", transcript) + for paren_match in paren_matches: + transcript = transcript.replace( + paren_match, re.sub(r"[^\[\]]", "", paren_match) + ) + brack_matches = re.findall(r"\[[^a-z@ ]+\]", transcript) + for brack_match in brack_matches: + transcript = transcript.replace( + brack_match, re.sub(r"[^\(\)]", "", brack_match) + ) + + transcript = re.sub(r"<<[^a-z@ ]+>>", "", transcript) + transcript = re.sub(r"<<[^a-z@ ]+", "", transcript) + transcript = re.sub(r"[^a-z@ ]+>>", "", transcript) + + transcript = re.sub(r"<[^a-z@ ]+>", "", transcript) + transcript = re.sub(r"<[^a-z2 ]*[^2 ]([ <])", r"\1", transcript) + transcript = re.sub(r"([ >])[^a-z2 ]*[^a-z 2]>", r"\1", transcript) + + transcript = re.sub(r"\[[2-9]?", "", transcript) + transcript = re.sub(r"[2-9]?\]", "", transcript) + + transcript = transcript.replace("(Hx)", " ") + transcript = transcript.replace("(hx)", " ") + transcript = transcript.replace("(@Hx)", "@") + + transcript = transcript.replace("(COUGH COUGH)", " ") + transcript = transcript.replace("(SNIFF", "") + + transcript = transcript.replace("(", "") + transcript = transcript.replace(")", "") + + transcript = transcript.replace("< ", " ") + transcript = transcript.replace(" >", " ") + + transcript = re.sub(r"[^A-Za-z-]-+", "", transcript) + transcript = re.sub(r"\.\.+", "", transcript) + + transcript = transcript.replace("+", "") + transcript = transcript.replace("&", "") + transcript = transcript.replace("#", "") + transcript = transcript.replace("*", "") + + transcript = re.sub(r"!([A-Za-z])", r"\1", transcript) + + # Deal with extra white space + transcript = re.sub(r" +", " ", transcript) + + # Merge X's + transcript = re.sub(r"X+", "X", transcript) + + # Parse laughter + transcript = transcript.replace("on@,", "on @,") + transcript = re.sub(r"([a-z-])@([a-z])", r"\1\2", transcript) + transcript = re.sub(r"@+", "@", transcript) + transcript = re.sub(r"(^| )@([^ ])", r" @ \2", transcript) + transcript = re.sub(r"([^ ])@( |$)", r"\1 @ ", transcript) + transcript = transcript.replace("@ @", "@").replace("@ @", "@") + + transcript = re.sub(r"(^| )X([ ,.?']|$)", r"\1\2", transcript) + transcript = re.sub(r"(^| )X([ ,.?']|$)", r"\1\2", transcript) + transcript = re.sub(r"X-($| )", r"\1", transcript) + + transcript = re.sub(r"^ ", "", transcript) + transcript = re.sub(r" $", "", transcript) + + transcript = transcript.replace(" .", ".") + transcript = transcript.replace(" ,", ",") + transcript = transcript.replace(" ?", "?") + + transcript = re.sub(r"^\. ", "", transcript) + transcript = re.sub(r"^\.$", "", transcript) + + if ( + len(transcript.split(" 1 + and re.search(r"[A-Za-z]", transcript.split("")) > 1 + and re.search(r"[A-Za-z]", transcript.split("L2>")[-1]) + ): + lang_tag = "c" + else: + lang_tag = "" + + transcript = transcript.replace("@", "") + transcript = transcript.replace("", "") + + if "L2" in transcript: + lang_tag = lang_tag + re.sub( + r"()(?!.*()).*$", + r"\1", + re.sub(r".*?()", r"\1", transcript), + ) + lang_tag = lang_tag.replace("", "r") + + # We choose to leave the language tags in, but uncommenting this would remove them. + # transcript = transcript.replace("", "") + + return transcript, lang_tag + + +@dataclass +class StmSegment: + recording_id: str + speaker: str + start: float + end: float + text: str + channel: str = "1" + + +def parse_stm_file(data: str) -> List[StmSegment]: + lines = data.split("\n") + stm_segments = [] + + for line in lines: + if not line: + continue + + fields = line.strip().split() + reco_id, channel, speaker = fields[:3] + start, end = [float(time) for time in fields[3:5]] + text = " ".join(fields[5:]) + + stm_segments.append( + StmSegment( + recording_id=reco_id, + speaker=speaker, + start=start, + end=end, + text=text, + channel=channel, + ) + ) + + return stm_segments + + +def retrieve_stm_file(url) -> List[StmSegment]: + import urllib.request + + response = urllib.request.urlopen(url) + data = response.read().decode("utf-8") + + return parse_stm_file(data) + + +def norm_txt(text: str): + text = text.strip() + text = text.lower() + return text + + +def compute_iou(seg1: SupervisionSegment, seg2: StmSegment) -> float: + start = max(seg1.start, seg2.start) + end = min(seg1.end, seg2.end) + + intersection = max(0.0, end - start) + union = (seg1.end - seg1.start) + (seg2.end - seg2.start) - intersection + + return intersection / union + + +def apply_stm( + recording_ids: List[str], + supervisions: SupervisionSet, + aligned_stm_segs: List[StmSegment], +) -> SupervisionSet: + + if not is_module_available("intervaltree"): + raise ImportError( + "intervaltree package not found. Please install..." + " (pip install intervaltree)" + ) + else: + from intervaltree import IntervalTree + + if not is_module_available("jiwer"): + raise ImportError( + "jiwer package not found. Please install..." " (pip install jiwer==3.0.4)" + ) + else: + from jiwer import cer + + sset = deepcopy(supervisions) + + per_rec_its = {} + for rid in recording_ids: + per_rec_its[rid] = IntervalTree() + for stm_seg in tqdm(aligned_stm_segs, desc="Building interval tree..."): + per_rec_its[stm_seg.recording_id][stm_seg.start : stm_seg.end] = stm_seg + + for s in tqdm(sset, desc="Applying STM..."): + # We need to find the closest and best-matching segment. + # Some labeled segments were misplaced a lot and fixed by manual post-processing. + # Hence, in order to find a good match, we tuned collar value to find all matches. + # Example: 451 seconds, SBC027 recording. + collar = 2.0 + matching_segments = list( + filter( + lambda x: x.data.speaker == s.speaker, + per_rec_its[s.recording_id][s.start - collar : s.end + collar], + ) + ) + # Alignments used slightly different speaker IDs for UNK speakers, so we relax the speaker ID matching. + if not matching_segments: + matching_segments = per_rec_its[s.recording_id][ + s.start - collar : s.end + collar + ] + + best_cer = inf + best_cer_res = None + best_matching_seg = None + best_iou = 0.0 + + for matching_seg in matching_segments: + cer_res = cer( + norm_txt(s.text), norm_txt(matching_seg.data.text), return_dict=True + ) + cer_val = cer_res["cer"] + + if cer_val < best_cer: + best_cer = cer_val + best_cer_res = cer_res + best_matching_seg = matching_seg + best_iou = compute_iou(s, matching_seg.data) + + # There's been an update between the alignments and the lhotse recipe, so some UNK speakers have shifted IDs. + # It's enough to match the speaker names (or UNK). + if ( + cer_val == best_cer + and matching_seg.data.speaker.split("_")[1] == s.speaker.split("_")[1] + ): + current_iou = compute_iou(s, matching_seg.data) + if current_iou >= best_iou: + best_matching_seg = matching_seg + best_cer_res = cer_res + best_iou = current_iou + + if ( + s.speaker.split("_")[1] == best_matching_seg.data.speaker.split("_")[1] + and best_cer_res["substitutions"] == best_cer_res["deletions"] == 0 + and (best_cer < 0.5 or len(s.text) < 3) + ): + s.start = best_matching_seg.data.start + s.duration = best_matching_seg.data.end - best_matching_seg.data.start + s.text = best_matching_seg.data.text + + per_rec_its[s.recording_id].remove(best_matching_seg) + + return sset + + +def apply_aligned_stms( + recording_ids: List[str], processed_supervisions: SupervisionSet +) -> Tuple[SupervisionSet, SupervisionSet]: + aligned_for_asr_stm = retrieve_stm_file( + "https://raw.githubusercontent.com/domklement/SBCSAE_alignments/main/alignments/stm/aligned_for_asr.stm" + ) + aligned_for_diar_stm = retrieve_stm_file( + "https://raw.githubusercontent.com/domklement/SBCSAE_alignments/main/alignments/stm/aligned_for_diar.stm" + ) + + asr_sup = apply_stm(recording_ids, processed_supervisions, aligned_for_asr_stm) + diar_sup = apply_stm(recording_ids, processed_supervisions, aligned_for_diar_stm) + + return asr_sup, diar_sup From e2b149dc70b74532329e04dc1e6e6ff8ecc1cce9 Mon Sep 17 00:00:00 2001 From: Dominik Klement Date: Mon, 7 Oct 2024 15:17:21 -0400 Subject: [PATCH 61/69] Implement conversion from CutSet to HuggingFace dataset (#1398) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement conversion from CutSet to HuggingFace dataset So far, conversion from CutSet containing MonoCut and single-source audio to HuggingFace dataset. * Refactor * Add docs to set.py --------- Co-authored-by: Piotr Żelasko --- lhotse/cut/set.py | 82 +++++++++++++ lhotse/hf.py | 303 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 385 insertions(+) create mode 100644 lhotse/hf.py diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 927558bce..2a7afd16c 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -2550,6 +2550,88 @@ def prefetch(self, buffer_size: int = 10) -> "CutSet": ) ) + def to_huggingface_dataset(self): + """ + Converts a CutSet to a HuggingFace Dataset. Currently, only MonoCut with one recording source is supported. + Other cut types will be supported in the future. + + Currently, two formats are supported: + 1. If each cut has one supervision (e.g. LibriSpeech), each cut is represented as a single row (entry) + in the HuggingFace dataset with all the supervision information stored along the cut information. + The final HuggingFace dataset format is: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ id ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ audio ║ Audio() ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ duration ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ num_channels ║ Value(dtype='uint16') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ text ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ speaker ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ language ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ {x}_alignment ║ Sequence(Alignment) ║ + ╚═══════════════════╩═══════════════════════════════╝ + where x stands for the alignment type (commonly used: "word", "phoneme"). + + Alignment is represented as: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ symbol ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ start ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ end ║ Value(dtype='float32') ║ + ╚═══════════════════╩═══════════════════════════════╝ + + + 2. If each cut has multiple supervisions (e.g. AMI), each cut is represented as a single row (entry) + while all the supervisions are stored in a separate list of dictionaries under the 'segments' key. + The final HuggingFace dataset format is: + ╔══════════════╦════════════════════════════════════╗ + ║ Feature ║ Type ║ + ╠══════════════╬════════════════════════════════════╣ + ║ id ║ Value(dtype='string') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ audio ║ Audio() ║ + ╠══════════════╬════════════════════════════════════╣ + ║ duration ║ Value(dtype='float32') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ num_channels ║ Value(dtype='uint16') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ segments ║ Sequence(Segment) ║ + ╚══════════════╩════════════════════════════════════╝ + where one Segment is represented as: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ text ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ start ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ end ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ channel ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ speaker ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ language ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ {x}_alignment ║ Sequence(Alignment) ║ + ╚═══════════════════╩═══════════════════════════════╝ + :return: A HuggingFace Dataset. + """ + from lhotse.hf import export_cuts_to_hf + + return export_cuts_to_hf(self) + def __repr__(self) -> str: try: len_val = len(self) diff --git a/lhotse/hf.py b/lhotse/hf.py new file mode 100644 index 000000000..493087ad9 --- /dev/null +++ b/lhotse/hf.py @@ -0,0 +1,303 @@ +""" +╔══════════════════════════════════════╗ +║ Export CutSet to HuggingFace Dataset ║ +╚══════════════════════════════════════╝ +""" +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from lhotse.cut import CutSet, MonoCut +from lhotse.utils import is_module_available + + +def contains_only_mono_cuts(cutset: CutSet) -> bool: + return all(isinstance(cut, MonoCut) for cut in cutset) + + +def has_one_supervision_per_cut(cutset: CutSet) -> bool: + return all(len(cut.supervisions) == 1 for cut in cutset) + + +def has_one_audio_source(cutset: CutSet) -> bool: + return all(len(cut.recording.sources) == 1 for cut in cutset) + + +def convert_cuts_info_to_hf(cutset: CutSet) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Converts the cut information into a dictionary compatible with HuggingFace datasets format. + + :param cutset: A CutSet object. + :return: A tuple where the first element is a dictionary + representing the cut attributes and the second element is a dictionary describing the + format of the HuggingFace dataset. + """ + from datasets import Audio, Value + + cut_info = { + "id": [cut.id for cut in cutset], + "audio": [cut.recording.sources[0].source for cut in cutset], + "duration": [cut.duration for cut in cutset], + "num_channels": [len(cut.recording.channel_ids) for cut in cutset], + } + cut_info_description = { + "id": Value("string"), + "audio": Audio(mono=False), + "duration": Value("float"), + "num_channels": Value("uint16"), + } + return cut_info, cut_info_description + + +def convert_supervisions_info_to_hf( + cutset: CutSet, + exclude_attributes: Optional[Union[List[str], Set[str]]] = None, +) -> Tuple[List[List[Dict[str, Any]]], Dict[str, Any]]: + """ + Converts cut supervisions into a dictionary compatible with HuggingFace datasets format. + + :param cutset: A CutSet object. + :param exclude_attributes: A list|set of attributes to exclude from the supervisions dicts. + :return: A tuple where the first element is a dictionary + representing the cut attributes and the second element is a dictionary describing the + format of the HuggingFace dataset. + """ + + from datasets import Features, Sequence, Value + + has_speaker = any( + ( + hasattr(cut.supervisions[0], "speaker") + and cut.supervisions[0].speaker is not None + ) + for cut in cutset + ) + has_language = any( + ( + hasattr(cut.supervisions[0], "language") + and cut.supervisions[0].language is not None + ) + for cut in cutset + ) + alignment_types = [ + s.alignment.keys() + for c in cutset + for s in c.supervisions + if s.alignment is not None + ] + alignment_types = set([item for sublist in alignment_types for item in sublist]) + + sup_dicts = [] + for c in cutset: + cut_sup_dicts = [] + for s in c.supervisions: + sup_dict = { + "text": s.text, + } + + if exclude_attributes is None or "start" not in exclude_attributes: + sup_dict["start"] = s.start + + if exclude_attributes is None or "end" not in exclude_attributes: + sup_dict["end"] = s.end + + if exclude_attributes is None or "channel" not in exclude_attributes: + if isinstance(s.channel, list): + sup_dict["channel"] = ",".join(map(str, s.channel)) + else: + sup_dict["channel"] = str(s.channel) + + if has_speaker and ( + exclude_attributes is None or "speaker" not in exclude_attributes + ): + sup_dict["speaker"] = str(s.speaker) + + if has_language and ( + exclude_attributes is None or "language" not in exclude_attributes + ): + sup_dict["language"] = str(s.language) + + if alignment_types and ( + exclude_attributes is None or "alignments" not in exclude_attributes + ): + alignments = {} + for alignment_type in alignment_types: + alignments[alignment_type + "_alignment"] = list( + map( + lambda item: { + "symbol": item.symbol, + "start": item.start, + "end": item.end, + }, + s.alignment[alignment_type], + ) + ) + + sup_dict = {**sup_dict, **alignments} + + cut_sup_dicts.append(sup_dict) + sup_dicts.append(cut_sup_dicts) + + sup_dicts_info = {"text": Value("string")} + + if exclude_attributes is None or "start" not in exclude_attributes: + sup_dicts_info["start"] = Value("float") + + if exclude_attributes is None or "end" not in exclude_attributes: + sup_dicts_info["end"] = Value("float") + + if exclude_attributes is None or "channel" not in exclude_attributes: + sup_dicts_info["channel"] = Value("string") + + if has_speaker and ( + exclude_attributes is None or "speaker" not in exclude_attributes + ): + sup_dicts_info["speaker"] = Value("string") + + if has_language and ( + exclude_attributes is None or "language" not in exclude_attributes + ): + sup_dicts_info["language"] = Value("string") + + if alignment_types and ( + exclude_attributes is None or "alignments" not in exclude_attributes + ): + alignment_info = { + "symbol": Value("string"), + "start": Value("float"), + "end": Value("float"), + } + for alignment_type in alignment_types: + sup_dicts_info[alignment_type + "_alignment"] = Sequence( + Features(**alignment_info) + ) + + return sup_dicts, sup_dicts_info + + +def lod_to_dol(lod: List[Dict[str, Any]]) -> Dict[str, List]: + """ + Converts List of Dicts to Dict of Lists. + """ + return {k: [d[k] for d in lod] for k in lod[0].keys()} + + +def export_cuts_to_hf(cutset: CutSet): + """ + Converts a CutSet to a HuggingFace Dataset. Currently, only MonoCut with one recording source is supported. + Other cut types will be supported in the future. + + Currently, two formats are supported: + 1. If each cut has one supervision (e.g. LibriSpeech), each cut is represented as a single row (entry) + in the HuggingFace dataset with all the supervision information stored along the cut information. + The final HuggingFace dataset format is: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ id ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ audio ║ Audio() ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ duration ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ num_channels ║ Value(dtype='uint16') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ text ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ speaker ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ language ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ {x}_alignment ║ Sequence(Alignment) ║ + ╚═══════════════════╩═══════════════════════════════╝ + where x stands for the alignment type (commonly used: "word", "phoneme"). + + Alignment is represented as: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ symbol ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ start ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ end ║ Value(dtype='float32') ║ + ╚═══════════════════╩═══════════════════════════════╝ + + + 2. If each cut has multiple supervisions (e.g. AMI), each cut is represented as a single row (entry) + while all the supervisions are stored in a separate list of dictionaries under the 'segments' key. + The final HuggingFace dataset format is: + ╔══════════════╦════════════════════════════════════╗ + ║ Feature ║ Type ║ + ╠══════════════╬════════════════════════════════════╣ + ║ id ║ Value(dtype='string') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ audio ║ Audio() ║ + ╠══════════════╬════════════════════════════════════╣ + ║ duration ║ Value(dtype='float32') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ num_channels ║ Value(dtype='uint16') ║ + ╠══════════════╬════════════════════════════════════╣ + ║ segments ║ Sequence(Segment) ║ + ╚══════════════╩════════════════════════════════════╝ + where one Segment is represented as: + ╔═══════════════════╦═══════════════════════════════╗ + ║ Feature ║ Type ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ text ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ start ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ end ║ Value(dtype='float32') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ channel ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ speaker ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ language ║ Value(dtype='string') ║ + ╠═══════════════════╬═══════════════════════════════╣ + ║ {x}_alignment ║ Sequence(Alignment) ║ + ╚═══════════════════╩═══════════════════════════════╝ + + :param cutset: A CutSet object. + :return: A HuggingFace Dataset. + """ + + assert has_one_audio_source( + cutset + ), "Only CutSets with one audio source per cut are supported. MultiSource cuts coming soon." + + if not is_module_available("datasets"): + raise ImportError( + "Please install the 'datasets' package (pip install datasets)." + ) + from datasets import Dataset, Features, Sequence + + # We don't need start and end attribute if we have only one supervision/segment per cut, + # as start=0 and end=duration. + cut_info, cut_info_description = convert_cuts_info_to_hf(cutset) + sup_dicts, sup_dicts_info = convert_supervisions_info_to_hf( + cutset, + exclude_attributes={"start", "end", "channel"} + if has_one_supervision_per_cut(cutset) + else None, + ) + + if has_one_supervision_per_cut(cutset): + dataset_dict = { + **cut_info, + **lod_to_dol([x[0] for x in sup_dicts]), + } + dataset_info = Features( + **cut_info_description, + **sup_dicts_info, + ) + else: + dataset_dict = { + **cut_info, + "segments": sup_dicts, + } + dataset_info = Features( + segments=Sequence(Features(**sup_dicts_info)), + **cut_info_description, + ) + + return Dataset.from_dict(dataset_dict, features=dataset_info) From 25475d41870f6a8ebbdeb9eba124c72eb1ab51ac Mon Sep 17 00:00:00 2001 From: Matthew Wiesner Date: Mon, 21 Oct 2024 15:14:07 -0400 Subject: [PATCH 62/69] Adds radio data recipe (#1400) * Adds radio data recipe * Makes some small formatting changes * Fixing black and isort formatting * Fixes disable_ffmpeg_torchaudio_info to use contextmanager * Removes what appears to be an unnecessary set_ffmpeg_torchaudio_info_enabled call. The recipe runs fine without it. --- lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/radio.py | 41 ++++++++ lhotse/recipes/__init__.py | 2 + lhotse/recipes/radio.py | 139 +++++++++++++++++++++++++++ 4 files changed, 183 insertions(+) create mode 100644 lhotse/bin/modes/recipes/radio.py create mode 100755 lhotse/recipes/radio.py diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index caa904cb8..facfd2eaa 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -65,6 +65,7 @@ from .nsc import * from .peoples_speech import * from .primewords import * +from .radio import * from .reazonspeech import * from .rir_noise import * from .sbcsae import * diff --git a/lhotse/bin/modes/recipes/radio.py b/lhotse/bin/modes/recipes/radio.py new file mode 100644 index 000000000..e0fe3be3a --- /dev/null +++ b/lhotse/bin/modes/recipes/radio.py @@ -0,0 +1,41 @@ +from typing import List, Optional, Sequence, Tuple, Union + +import click + +from lhotse.bin.modes import prepare +from lhotse.recipes.radio import prepare_radio +from lhotse.utils import Pathlike + +__all__ = ["radio"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(dir_okay=True)) +@click.argument("output_dir", type=click.Path(dir_okay=True)) +@click.option( + "-d", + "--min-seg-dur", + type=float, + default=0.5, + help="The minimum segment duration", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=4, + help="The number of parallel threads to use for data preparation", +) +def radio( + corpus_dir: Pathlike, + output_dir: Pathlike, + min_seg_dur: float = 0.5, + num_jobs: int = 4, +): + """Data preparation""" + prepare_radio( + corpus_dir, + output_dir=output_dir, + num_jobs=num_jobs, + min_segment_duration=min_seg_dur, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 1eef6623b..85092263e 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -66,6 +66,7 @@ from .musan import download_musan, prepare_musan from .nsc import prepare_nsc from .peoples_speech import prepare_peoples_speech +from .radio import prepare_radio from .reazonspeech import download_reazonspeech, prepare_reazonspeech from .rir_noise import download_rir_noise, prepare_rir_noise from .sbcsae import download_sbcsae, prepare_sbcsae @@ -194,6 +195,7 @@ "prepare_peoples_speech", "download_reazonspeech", "prepare_reazonspeech", + "prepare_radio", "download_rir_noise", "prepare_rir_noise", "prepare_slu", diff --git a/lhotse/recipes/radio.py b/lhotse/recipes/radio.py new file mode 100755 index 000000000..e667ff845 --- /dev/null +++ b/lhotse/recipes/radio.py @@ -0,0 +1,139 @@ +""" +This recipe prepares data collected from radio streamed on the web. The data +have some metadata attached to them, including the geographic location of +broadcast, date and time of the recorded clip, as well as a unique station +identifier. + +Obtaining the data +----------------------------------------------------------- +If you want to use this corpus please email: wiesner@jhu.edu + +As the data are collected from radio stream, they cannot be broadly +disseminated or used for commercial purposes. In the email, include your +affiliated academic institution and the intended use for the data and we will +the data to you if it is indeed for non-commercial, academic purporses. + +Description +------------------------------------------------------------ +The data consist of ∼4000 hours of speech collected between +September 27, 2023 to October 1, 2023, in 9449 locations all over the world, +from 17171 stations. + +These data were used for Geolocation of speech in order to answer the question, +Where are you from? in the paper + +Where are you from? Geolocating Speech and Applications to Language +Identification, presented at NAACL 2024. Please read for a full descrption +and please cite as + +@inproceedings{foley2024you, + title={Where are you from? Geolocating Speech and Applications to Language Identification}, + author={Foley, Patrick and Wiesner, Matthew and Odoom, Bismarck and Perera, Leibny Paola Garcia and Murray, Kenton and Koehn, Philipp}, + booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)}, + pages={5114--5126}, + year={2024} +} +""" +import json +import re +from functools import partial +from pathlib import Path +from typing import Dict, Optional, Union + +from tqdm import tqdm + +from lhotse.audio import Recording, RecordingSet +from lhotse.parallel import parallel_map +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike + + +def _make_reco_and_sups_from_file(sf: str, msd: float = 0.5): + corpus_dir = sf.parents[2] + audio_dir = corpus_dir / "recos" + fname = sf.with_suffix(".flac").stem + + # E.g. 2023_10_01_09h_02m_54s_dur30_ZnpbY9Zx_lat3.17_long113.04 + chunk_idx = int(sf.parent.suffix.strip(".")) + reco_file = audio_dir / f"recos.{chunk_idx}" / f"{fname}.flac" + reco = Recording.from_file(reco_file, recording_id=fname) + reco.channel_ids = [0] + sups = [] + total = 0 + with open(sf) as f: + segments = json.load(f) + + # Parse the file format, shown in the comment above, to get: + # date, station, latitude, longitude, and the estimated gender + lat, lon = re.search(r"lat[^_]+_long[^_]+", Path(sf).stem).group(0).split("_") + lat = float(lat.replace("lat", "")) + lon = float(lon.replace("long", "")) + station = re.search(r"s_dur[0-9]+_(.*)_lat[^_]+_long[^_]+", fname).groups()[0] + fname_vals = fname.split("_") + date = [int(i.strip("hms")) for i in fname_vals[0:6]] # YY MM DD hh mm ss + for seg in segments: + start, end = float(seg[1]), float(seg[2]) + dur = end - start + if seg[0] in ("male", "female") and dur > msd: + sups.append( + SupervisionSegment( + id=f"{fname}_{int(100*start):04}", + recording_id=fname, + start=start, + duration=round(dur, 4), + channel=0, + custom={ + "date": date, + "lat": lat, + "lon": lon, + "station": station, + "est_gender": seg[0], + }, + ) + ) + return sups, reco + + +def prepare_radio( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + min_segment_duration: float = 0.5, + num_jobs: int = 4, +) -> Dict[str, Union[RecordingSet, SupervisionSet]]: + """ + Return the manifests which consist of recordings and supervisions + :param corpus_dir: Path to the collected radio samples + :param output_dir: Pathlike, the path where manifests are written + :return: A Dict whose key is the dataset part and the value is a Dict with + keys 'recordings' and 'supervisions'. + """ + corpus_dir = Path(corpus_dir) + segment_files = corpus_dir.rglob("segs/*/*.json") + supervisions, recordings = [], [] + fun = partial(_make_reco_and_sups_from_file, msd=min_segment_duration) + output_dir = Path(output_dir) if output_dir is not None else None + output_dir.mkdir(mode=511, parents=True, exist_ok=True) + with RecordingSet.open_writer( + output_dir / "radio_recordings.jsonl.gz" + ) as rec_writer: + with SupervisionSet.open_writer( + output_dir / "radio_supervisions.jsonl.gz" + ) as sup_writer: + for sups, reco in tqdm( + parallel_map( + fun, + segment_files, + num_jobs=num_jobs, + ), + desc=f"Making recordings and supervisions", + ): + rec_writer.write(reco) + for sup in sups: + sup_writer.write(sup) + + manifests = { + "recordings": RecordingSet.from_jsonl_lazy(rec_writer.path), + "supervisions": SupervisionSet.from_jsonl_lazy(sup_writer.path), + } + + return manifests From a30720b8329676a92ced850d941d45a352df5bb7 Mon Sep 17 00:00:00 2001 From: Matthew Wiesner Date: Mon, 21 Oct 2024 15:14:37 -0400 Subject: [PATCH 63/69] Fleurs (#1402) * Adds fleurs recipe * Black formatting * Removes useless num_jobs argument in the download cli, and ran isort and black again on *recipes/fleurs.py * Removes what appears to be an unnecessary set_ffmpeg_torchaudio_info call * isort and black fix * Fixes remaining black issues due to trailing space in recipes/__init__.py * Adds FLEURS entry in docs/corpus.rst --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/fleurs.py | 68 +++++ lhotse/recipes/__init__.py | 3 + lhotse/recipes/fleurs.py | 410 +++++++++++++++++++++++++++ 5 files changed, 484 insertions(+) create mode 100644 lhotse/bin/modes/recipes/fleurs.py create mode 100755 lhotse/recipes/fleurs.py diff --git a/docs/corpus.rst b/docs/corpus.rst index bc8d71bfb..4a50be2d8 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -109,6 +109,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_fisher_english` * - Fisher Spanish - :func:`lhotse.recipes.prepare_fisher_spanish` + * - FLEURS + - :func:`lhotse.recipes.prepare_fleurs` * - Fluent Speech Commands - :func:`lhotse.recipes.slu` * - GALE Arabic Broadcast Speech diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index facfd2eaa..bb331620d 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -34,6 +34,7 @@ from .eval2000 import * from .fisher_english import * from .fisher_spanish import * +from .fleurs import * from .gale_arabic import * from .gale_mandarin import * from .gigaspeech import * diff --git a/lhotse/bin/modes/recipes/fleurs.py b/lhotse/bin/modes/recipes/fleurs.py new file mode 100644 index 000000000..cf6bca3e1 --- /dev/null +++ b/lhotse/bin/modes/recipes/fleurs.py @@ -0,0 +1,68 @@ +from typing import Optional, Sequence, Union + +import click + +from lhotse.bin.modes import download, prepare +from lhotse.recipes.fleurs import download_fleurs, prepare_fleurs +from lhotse.utils import Pathlike + +__all__ = ["fleurs"] + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +@click.option( + "-l", + "--lang", + multiple=True, + default=["all"], + help="Specify which languages to prepare, e.g., " + " lhoste prepare librispeech mtedx_corpus data -l de -l fr -l es ", +) +def fleurs( + corpus_dir: Pathlike, + output_dir: Pathlike, + num_jobs: int, + lang: Optional[Union[str, Sequence[str]]], +): + """Fleurs ASR data preparation.""" + prepare_fleurs(corpus_dir, output_dir=output_dir, num_jobs=num_jobs, languages=lang) + + +@download.command(context_settings=dict(show_default=True)) +@click.argument("target_dir", type=click.Path()) +@click.option( + "-l", + "--lang", + multiple=True, + default=["all"], + help="Specify which languages to download, e.g., " + " lhotse download fleurs . -l hi_in -l en_us " + " lhotse download fleurs", +) +@click.option( + "--force-download", + type=bool, + is_flag=True, + default=False, + help="Specify whether to overwrite an existing archive", +) +def fleurs( + target_dir: Pathlike, + lang: Optional[Union[str, Sequence[str]]], + force_download: bool = False, +): + """FLEURS download.""" + download_fleurs( + target_dir, + languages=lang, + force_download=force_download, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 85092263e..fc7d3670a 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -31,6 +31,7 @@ from .eval2000 import prepare_eval2000 from .fisher_english import prepare_fisher_english from .fisher_spanish import prepare_fisher_spanish +from .fleurs import download_fleurs, prepare_fleurs from .gale_arabic import prepare_gale_arabic from .gale_mandarin import prepare_gale_mandarin from .gigaspeech import prepare_gigaspeech @@ -146,6 +147,8 @@ "prepare_eval2000", "prepare_fisher_english", "prepare_fisher_spanish", + "download_fleurs", + "prepare_fleurs", "prepare_gale_arabic", "prepare_gale_mandarin", "prepare_gigaspeech", diff --git a/lhotse/recipes/fleurs.py b/lhotse/recipes/fleurs.py new file mode 100755 index 000000000..d87513880 --- /dev/null +++ b/lhotse/recipes/fleurs.py @@ -0,0 +1,410 @@ +""" +This recipe provides functionality for downloading and preparing the fleurs +corpus. The data is hosted on huggingface and to enable more control of the +download format, we use the streaming download interface and save each audio +file as it is streamed. The download can take quite some time. + +The fleurs corpus consist of data in 102 languages spoken by multiple speakers. +There is about 10 hrs of trainign data in each language with smaller +accompanying dev and test sets. Full details can be found in + +@inproceedings{conneau2023fleurs, + title={Fleurs: Few-shot learning evaluation of universal representations of speech}, + author={Conneau, Alexis and Ma, Min and Khanuja, Simran and Zhang, Yu and Axelrod, Vera and Dalmia, Siddharth and Riesa, Jason and Rivera, Clara and Bapna, Ankur}, + booktitle={2022 IEEE Spoken Language Technology Workshop (SLT)}, + pages={798--805}, + year={2023}, + organization={IEEE} +} +""" +import logging +from pathlib import Path +from typing import Dict, Optional, Sequence, Union + +from tqdm import tqdm + +from lhotse import ( + Recording, + RecordingSet, + SupervisionSegment, + SupervisionSet, + audio, + fix_manifests, + get_ffmpeg_torchaudio_info_enabled, + set_ffmpeg_torchaudio_info_enabled, +) +from lhotse.parallel import parallel_map +from lhotse.utils import Pathlike, is_module_available + +# The FLEURS languages are indicated by 2-letter ISO-codes followed by a +# country code, i.e., +# +# en_us, fr_fr, ml_in +# +# for American English, French French and Indian Malayalam respectively. + +DEFAULT_LANGUAGES = [ + "af_za", + "am_et", + "ar_eg", + "as_in", + "ast_es", + "az_az", + "be_by", + "bg_bg", + "bn_in", + "bs_ba", + "ca_es", + "ceb_ph", + "ckb_iq", + "cmn_hans_cn", + "cs_cz", + "cy_gb", + "da_dk", + "de_de", + "el_gr", + "en_us", + "es_419", + "et_ee", + "fa_ir", + "ff_sn", + "fi_fi", + "fil_ph", + "fr_fr", + "ga_ie", + "gl_es", + "gu_in", + "ha_ng", + "he_il", + "hi_in", + "hr_hr", + "hu_hu", + "hy_am", + "id_id", + "ig_ng", + "is_is", + "it_it", + "ja_jp", + "jv_id", + "ka_ge", + "kam_ke", + "kea_cv", + "kk_kz", + "km_kh", + "kn_in", + "ko_kr", + "ky_kg", + "lb_lu", + "lg_ug", + "ln_cd", + "lo_la", + "lt_lt", + "luo_ke", + "lv_lv", + "mi_nz", + "mk_mk", + "ml_in", + "mn_mn", + "mr_in", + "ms_my", + "mt_mt", + "my_mm", + "nb_no", + "ne_np", + "nl_nl", + "nso_za", + "ny_mw", + "oc_fr", + "om_et", + "or_in", + "pa_in", + "pl_pl", + "ps_af", + "pt_br", + "ro_ro", + "ru_ru", + "sd_in", + "sk_sk", + "sl_si", + "sn_zw", + "so_so", + "sr_rs", + "sv_se", + "sw_ke", + "ta_in", + "te_in", + "tg_tj", + "th_th", + "tr_tr", + "uk_ua", + "umb_ao", + "ur_pk", + "uz_uz", + "vi_vn", + "wo_sn", + "xh_za", + "yo_ng", + "yue_hant_hk", + "zu_za", +] + + +def download_fleurs( + target_dir: Pathlike = ".", + languages: Optional[Union[str, Sequence[str]]] = "all", + force_download: Optional[bool] = False, +) -> Path: + """ + Download the specified fleurs datasets. + + :param target_dir: The path to which the corpus will be downloaded. + :type target_dir: Pathlike + :param languages: Optional list of str or str specifying which + languages to download. The str specifier for a language has the + ISOCODE_COUNTRYCODE format, and is all lower case. By default + this is set to "all", which will download the entire set of + languages. + :type languages: Optional[Union[str, Sequence[str]]] + :param force_download: Specifies whether to overwrite an existing + archive. + :type force_download: bool + :return: The root path of the downloaded data + :rtype: Path + """ + target_dir = Path(target_dir) + corpus_dir = target_dir / "fleurs" + metadata_dir = corpus_dir / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + + if isinstance(languages, str) and languages == "all" or languages[0] == "all": + languages = DEFAULT_LANGUAGES + + if isinstance(languages, str): + languages = [languages] + + for lang in tqdm(languages): + # Download one language at a time + lang_dir = corpus_dir / lang + download_single_fleurs_language( + lang_dir, + lang, + force_download, + ) + return corpus_dir + + +def download_single_fleurs_language( + target_dir: Pathlike, + language: str, + force_download: bool = False, +) -> Path: + """ + Download a single fleurs language + + :param target_dir: The path to which one langauge will be downloaded + :type target_dir: Pathlike + :param language: The code for the specified language + :type language: str + :param force_download: Specifies whether to overwrite an existing + archive. + :type force_download: bool + :return: The path to the downloaded data for the specified language + :rtype: Path + """ + if not is_module_available("datasets"): + raise ImportError( + "The huggingface datasets package is not installed. Please install" + " ...(pip install datasets)" + ) + else: + from datasets import load_dataset + + def _identity(x): + return x + + target_dir = Path(target_dir) + metadata_dir = target_dir.parents[0] / "metadata" / language + target_dir.mkdir(parents=True, exist_ok=True) + metadata_dir.mkdir(parents=True, exist_ok=True) + + completed_detector = target_dir / f".{language}_completed" + if completed_detector.is_file() and not force_download: + logging.info("Skipping dowload because {completed_detector} exists.") + return target_dir + + for split in tqdm(["train", "validation", "test"]): + fleurs = load_dataset( + "google/fleurs", + language, + cache_dir="/expscratch/mwiesner/geolocation/test", + streaming=True, + split=split, + ) + metadata = [] + osplit = "dev" if split == "validation" else split + split_dir = target_dir / osplit + split_dir.mkdir(parents=True, exist_ok=True) + for data in tqdm(fleurs, desc=f"Downloading data from {language}-{osplit}"): + audio.save_audio( + f"{split_dir}/{Path(data['audio']['path']).name}", + data["audio"]["array"], + data["audio"]["sampling_rate"], + ) + metadata_ = [ + str(data["id"]), # ID + Path(data["audio"]["path"]).name, # filename + data["raw_transcription"], # raw transcript + data["transcription"], # transcript + " ".join("|".join(data["transcription"].split())) + " |", # chars + str(data["num_samples"]), # number of audio samples + "FEMALE" if data["gender"] == 1 else "MALE", # gender + ] + metadata.append(metadata_) + with open(metadata_dir / f"{osplit}.tsv", "w") as f: + for md in metadata: + print("\t".join(md), file=f) + + completed_detector.touch() + return target_dir + + +def prepare_fleurs( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + languages: Optional[Union[str, Sequence[str]]] = "all", + num_jobs: int = 1, +) -> Dict[str, Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]: + """ + Prepares the manifest for all of the FLEURS languages requested. + + :param corpus_dir: Path to the root where the FLEURS data are stored. + :type corpus_dir: Pathlike, + :param output_dir: The directory where the .jsonl.gz manifests will be written. + :type output_dir: Pathlike, + :param langauges: str or str sequence specifying the languages to prepare. + The str 'all' prepares all 102 languages. + :return: The manifest + :rtype: Dict[str, Dict[str, Union[RecordingSet, Supervisions]]]] + """ + + if isinstance(corpus_dir, str): + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + + if isinstance(output_dir, str): + output_dir = Path(output_dir) + + output_dir.mkdir(mode=511, parents=True, exist_ok=True) + + langs_list = DEFAULT_LANGUAGES + if isinstance(languages, str) and languages != "all": + langs_list = [languages] + elif isinstance(languages, list) or isinstance(languages, tuple): + if languages[0] != "all": + langs_list = languages + + # Start buildings the recordings and supervisions + manifests = {} + for lang in langs_list: + corpus_dir_lang = corpus_dir / f"{lang}" + if not corpus_dir_lang.is_dir(): + logging.info(f"Skipping {lang}. No directory {corpus_dir_lang} found.") + continue + output_dir_lang = output_dir / f"{lang}" + output_dir_lang.mkdir(mode=511, parents=True, exist_ok=True) + manifests[lang] = prepare_single_fleurs_language( + corpus_dir_lang, + output_dir_lang, + language=lang, + num_jobs=num_jobs, + ) + + if output_dir is not None: + for l in manifests: + for dset in ("train", "dev", "test"): + manifests[l][dset]["supervisions"].to_file( + output_dir / f"{l}" / f"fleurs-{l}_supervisions_{dset}.jsonl.gz" + ) + manifests[l][dset]["recordings"].to_file( + output_dir / f"{l}" / f"fleurs-{l}_recordings_{dset}.jsonl.gz" + ) + return manifests + + +def _make_recording(path): + return Recording.from_file(path, recording_id=Path(path).stem) + + +def prepare_single_fleurs_language( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + language: str = "language", + num_jobs: int = 1, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Prepares manifests using a single FLEURS language. + + :param corpus_dir: Path to the root where the FLEURS data are stored. + :type corpus_dir: Pathlike, + :param output_dir: The directory where the .jsonl.gz manifests will be written. + :type output_dir: Pathlike, + :param langauge: str specifying the language to prepare. + + :return: The manifest + :rtype: Dict[str, Dict[str, Union[RecordingSet, Supervisions]]]] + """ + + if isinstance(corpus_dir, str): + corpus_dir = Path(corpus_dir) + + recordings = {"train": [], "dev": [], "test": []} + supervisions = {"train": [], "dev": [], "test": []} + + # First prepare the supervisions + for dset in ("train", "dev", "test"): + print(f"Preparing {dset} ...") + prompt_ids = {} + with open( + corpus_dir.parents[0] / "metadata" / corpus_dir.stem / f"{dset}.tsv" + ) as f: + for l in f: + vals = l.strip().split("\t") + prompt_id, fname, raw_text, text, _, nsamples, gender = vals + if prompt_id not in prompt_ids: + prompt_ids[prompt_id] = 0 + prompt_ids[prompt_id] += 1 + fname = Path(fname).stem + supervisions[dset].append( + SupervisionSegment( + id=f"{prompt_id}_{prompt_ids[prompt_id]}_{fname}", + recording_id=fname, + start=0.0, + duration=round(int(nsamples) / 16000, 4), + channel=0, + text=text, + language=language, + speaker=f"{prompt_id}_{prompt_ids[prompt_id]}", + gender=gender, + custom={"raw_text": raw_text}, + ) + ) + for dset in ("train", "dev", "test"): + for reco in tqdm( + parallel_map( + _make_recording, + ( + corpus_dir / f"{dset}/{s.recording_id}.wav" + for s in supervisions[dset] + ), + num_jobs=num_jobs, + ), + desc=f"Making recordings from {language} {dset}", + ): + recordings[dset].append(reco) + manifests = {} + for dset in ("train", "dev", "test"): + sups = SupervisionSet.from_segments(supervisions[dset]) + recos = RecordingSet.from_recordings(recordings[dset]) + recos, sups = fix_manifests(recos, sups) + manifests[dset] = {"supervisions": sups, "recordings": recos} + return manifests From 41269ff1f86e2fab6831d9b638ea922409b6b166 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 22 Oct 2024 03:27:41 +0800 Subject: [PATCH 64/69] Add the Emilia corpus (#1404) * Add the Emilia corpus. * Return cutset instead * fix style issues --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/emilia.py | 36 ++++++ lhotse/recipes/emilia.py | 165 +++++++++++++++++++++++++++ 4 files changed, 204 insertions(+) create mode 100644 lhotse/bin/modes/recipes/emilia.py create mode 100644 lhotse/recipes/emilia.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 4a50be2d8..7182f5224 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -211,6 +211,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_wenetspeech4tts` * - YesNo - :func:`lhotse.recipes.prepare_yesno` + * - Emilia + - :func:`lhotse.recipes.prepare_emilia` * - Eval2000 - :func:`lhotse.recipes.prepare_eval2000` * - MGB2 diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index bb331620d..913ed56a0 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -31,6 +31,7 @@ from .earnings22 import * from .ears import * from .edacc import * +from .emilia import * from .eval2000 import * from .fisher_english import * from .fisher_spanish import * diff --git a/lhotse/bin/modes/recipes/emilia.py b/lhotse/bin/modes/recipes/emilia.py new file mode 100644 index 000000000..6382a164c --- /dev/null +++ b/lhotse/bin/modes/recipes/emilia.py @@ -0,0 +1,36 @@ +import click + +from lhotse.bin.modes import prepare +from lhotse.recipes.emilia import prepare_emilia +from lhotse.utils import Pathlike + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-l", + "--lang", + type=str, + help="The language to process. Valid values: zh, en, ja, ko, de, fr", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +def emilia( + corpus_dir: Pathlike, + output_dir: Pathlike, + lang: str, + num_jobs: int = 1, +): + """Prepare the Emilia corpus manifests.""" + prepare_emilia( + corpus_dir=corpus_dir, + output_dir=output_dir, + lang=lang, + num_jobs=num_jobs, + ) diff --git a/lhotse/recipes/emilia.py b/lhotse/recipes/emilia.py new file mode 100644 index 000000000..e43ce44d4 --- /dev/null +++ b/lhotse/recipes/emilia.py @@ -0,0 +1,165 @@ +""" +The Emilia dataset is constructed from a vast collection of speech data sourced +from diverse video platforms and podcasts on the Internet, covering various +content genres such as talk shows, interviews, debates, sports commentary, and +audiobooks. This variety ensures the dataset captures a wide array of real +human speaking styles. The initial version of the Emilia dataset includes a +total of 101,654 hours of multilingual speech data in six different languages: +English, French, German, Chinese, Japanese, and Korean. + +See also +https://emilia-dataset.github.io/Emilia-Demo-Page/ + +Please note that Emilia does not own the copyright to the audio files; the +copyright remains with the original owners of the videos or audio. Users are +permitted to use this dataset only for non-commercial purposes under the +CC BY-NC-4.0 license. + +Please refer to +https://huggingface.co/datasets/amphion/Emilia-Dataset +or +https://openxlab.org.cn/datasets/Amphion/Emilia +to download the dataset. + +Note that you need to apply for downloading. + +""" + +from concurrent.futures.thread import ThreadPoolExecutor +from pathlib import Path +from typing import Optional, Tuple + +from tqdm.auto import tqdm + +from lhotse import CutSet, MonoCut +from lhotse.audio import Recording +from lhotse.serialization import load_jsonl +from lhotse.supervision import SupervisionSegment +from lhotse.utils import Pathlike + + +def _parse_utterance( + data_dir: Path, + line: dict, +) -> Optional[Tuple[Recording, SupervisionSegment]]: + """ + :param data_dir: Path to the data directory + :param line: dict, it looks like below:: + + { + "id": "DE_B00000_S00000_W000029", + "wav": "DE_B00000/DE_B00000_S00000/mp3/DE_B00000_S00000_W000029.mp3", + "text": " Und es gibt auch einen Stadtplan von Tegun zu sehen.", + "duration": 3.228, + "speaker": "DE_B00000_S00000", + "language": "de", + "dnsmos": 3.3697 + } + + :return: a tuple of "recording" and "supervision" + """ + full_path = data_dir / line["wav"] + + if not full_path.is_file(): + return None + + recording = Recording.from_file( + path=full_path, + recording_id=full_path.stem, + ) + segment = SupervisionSegment( + id=recording.id, + recording_id=recording.id, + start=0.0, + duration=recording.duration, + channel=0, + text=line["text"], + language=line["language"], + speaker=line["speaker"], + custom={"dnsmos": line["dnsmos"]}, + ) + + return recording, segment + + +def prepare_emilia( + corpus_dir: Pathlike, + lang: str, + num_jobs: int, + output_dir: Optional[Pathlike] = None, +) -> CutSet: + """ + Returns the manifests which consist of the Recordings and Supervisions + + :param corpus_dir: Pathlike, the path of the data dir. + We assume the directory has the following structure: + corpus_dir/raw/openemilia_all.tar.gz, + corpus_dir/raw/DE, + corpus_dir/raw/DE/DE_B00000.jsonl, + corpus_dir/raw/DE/DE_B00000/DE_B00000_S00000/mp3/DE_B00000_S00000_W000000.mp3, + corpus_dir/raw/EN, etc. + :param lang: str, one of en, zh, de, ko, ja, fr + :param num_jobs: int, number of threads for processing jsonl files + :param output_dir: Pathlike, the path where to write the manifests. + :return: The CutSet containing the data for the given language. + """ + if lang is None: + raise ValueError("Please provide --lang") + + lang_uppercase = lang.upper() + if lang_uppercase not in ("DE", "EN", "FR", "JA", "KO", "ZH"): + raise ValueError( + "Please provide a valid language. " + f"Choose from de, en, fr, ja, ko, zh. Given: {lang}" + ) + + corpus_dir = Path(corpus_dir) + assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}" + data_dir = corpus_dir / "raw" / lang_uppercase + assert data_dir.is_dir(), f"No such directory: {data_dir}" + + jsonl_files = data_dir.glob("*.jsonl") + + cuts = [] + futures = [] + + with ThreadPoolExecutor(num_jobs) as ex: + for jsonl_file in jsonl_files: + for item in tqdm( + # Note: People's Speech manifest.json is really a JSONL. + load_jsonl(jsonl_file), + desc=f"Processing {jsonl_file} with {num_jobs} jobs", + ): + futures.append( + ex.submit( + _parse_utterance, + data_dir, + item, + ) + ) + + for future in tqdm(futures, desc="Collecting futures"): + result = future.result() + if result is None: + continue + + recording, segment = result + + cuts.append( + MonoCut( + id=recording.id, + recording=recording, + start=0, + duration=recording.duration, + supervisions=[segment], + channel=0, + ) + ) + + cut_set = CutSet.from_cuts(cuts) + + if output_dir is not None: + output_dir = Path(output_dir) + cut_set.to_file(output_dir / f"emilia_cuts_{lang_uppercase}.jsonl.gz") + + return cut_set From 8b6d6f58ce76b2763da9807216d3d01d16406509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E9=9C=87=E4=B8=9C?= Date: Wed, 23 Oct 2024 19:39:14 +0800 Subject: [PATCH 65/69] [fix] fisher_english recipe (#1410) --- lhotse/recipes/fisher_english.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/lhotse/recipes/fisher_english.py b/lhotse/recipes/fisher_english.py index f128a6ac3..5d57b8295 100644 --- a/lhotse/recipes/fisher_english.py +++ b/lhotse/recipes/fisher_english.py @@ -148,19 +148,13 @@ def prepare_fisher_english( for audio_dir in audio_dirs: audio_dir_path = corpus_dir / audio_dir for audio_partition_dir in audio_dir_path.iterdir(): - audio_partition_dir_path = audio_dir_path / audio_partition_dir / "audio" - audio_subdir_paths += [ - audio_partition_dir_path / audio_subdir - for audio_subdir in audio_partition_dir_path.iterdir() - ] + audio_partition_dir_path = audio_partition_dir / "audio" + audio_subdir_paths += audio_partition_dir_path.iterdir() transcript_subdir_paths = [] for transcript_dir in transcript_dirs: transcript_dir_path = corpus_dir / transcript_dir / "data" / "trans" - transcript_subdir_paths += [ - transcript_dir_path / transcript_subdir - for transcript_subdir in transcript_dir_path.iterdir() - ] + transcript_subdir_paths += transcript_dir_path.iterdir() audio_paths = walk_dirs_parallel( audio_subdir_paths, "*.sph", "Parsing audio sub-dirs" From aff1188f9a8ae1bc47304e9cdd46a74a5320ee86 Mon Sep 17 00:00:00 2001 From: annapovey <109050627+annapovey@users.noreply.github.com> Date: Wed, 23 Oct 2024 04:41:53 -0700 Subject: [PATCH 66/69] downgrading sphinx version from 7.2.6 to 7.1.2 (#1409) Co-authored-by: npovey --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 18864ae23..7bf056ad6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ numpy>=1.18.1 sphinx_rtd_theme==2.0.0 -sphinx==7.2.6 +sphinx==7.1.2 sphinx-click==5.1.0 sphinx-autodoc-typehints==2.0.0 From 96485163a3ac581b485765cfa79faaaf82883f7f Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 29 Oct 2024 02:49:50 +0800 Subject: [PATCH 67/69] Add workflow: annotate DNSMOS P.835 (#1406) * add workflow: dnsmos * add cli for dnsmos workflow * fix and test * fix --------- Co-authored-by: Your Name --- lhotse/bin/modes/workflows.py | 79 +++++++++++++ lhotse/workflows/__init__.py | 1 + lhotse/workflows/dnsmos.py | 213 ++++++++++++++++++++++++++++++++++ 3 files changed, 293 insertions(+) create mode 100644 lhotse/workflows/dnsmos.py diff --git a/lhotse/bin/modes/workflows.py b/lhotse/bin/modes/workflows.py index dec178c7c..ef629e4fe 100644 --- a/lhotse/bin/modes/workflows.py +++ b/lhotse/bin/modes/workflows.py @@ -569,3 +569,82 @@ def activity_detection( supervisions.to_file(str(sups_path)) print("Results saved to:", str(sups_path), sep="\n") + + +@workflows.command() +@click.argument("out_cuts", type=click.Path(allow_dash=True)) +@click.option( + "-m", + "--recordings-manifest", + type=click.Path(exists=True, dir_okay=False, allow_dash=True), + help="Path to an existing recording manifest.", +) +@click.option( + "-r", + "--recordings-dir", + type=click.Path(exists=True, file_okay=False), + help="Directory with recordings. We will create a RecordingSet for it automatically.", +) +@click.option( + "-c", + "--cuts-manifest", + type=click.Path(exists=True, dir_okay=False, allow_dash=True), + help="Path to an existing cuts manifest.", +) +@click.option( + "-e", + "--extension", + default="wav", + help="Audio file extension to search for. Used with RECORDINGS_DIR.", +) +@click.option( + "-p", + "--is-personalized-mos", + default=False, + help="Flag to indicate if personalized MOS score is needed or regular.", +) +@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.") +def annotate_dnsmos( + out_cuts: str, + recordings_manifest: Optional[str], + recordings_dir: Optional[str], + cuts_manifest: Optional[str], + extension: str, + is_personalized_mos: str, + jobs: int, +): + """ + Use Microsoft DNSMOS P.835 prediction model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. + It will predict DNSMOS P.835 score including SIG, NAK, and OVRL. + + See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS + + RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive. If CUTS_MANIFEST + is provided, its supervisions will be overwritten with the results of the inference. + """ + from lhotse import annotate_dnsmos as annotate_dnsmos_ + + assert exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest), ( + "Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive " + "and at least one is required." + ) + + if recordings_manifest is not None: + manifest = RecordingSet.from_file(recordings_manifest) + elif recordings_dir is not None: + manifest = RecordingSet.from_dir( + recordings_dir, pattern=f"*.{extension}", num_jobs=jobs + ) + else: + manifest = CutSet.from_file(cuts_manifest).to_eager() + + with CutSet.open_writer(out_cuts) as writer: + for cut in tqdm( + annotate_dnsmos_( + manifest, + is_personalized_mos=is_personalized_mos, + ), + total=len(manifest), + desc="Annotating with DNSMOS P.835 prediction model", + ): + writer.write(cut, flush=True) diff --git a/lhotse/workflows/__init__.py b/lhotse/workflows/__init__.py index ccce27bf6..953aa803e 100644 --- a/lhotse/workflows/__init__.py +++ b/lhotse/workflows/__init__.py @@ -1,4 +1,5 @@ from .activity_detection import * +from .dnsmos import annotate_dnsmos from .forced_alignment import align_with_torchaudio from .meeting_simulation import * from .whisper import annotate_with_whisper diff --git a/lhotse/workflows/dnsmos.py b/lhotse/workflows/dnsmos.py new file mode 100644 index 000000000..1270212c3 --- /dev/null +++ b/lhotse/workflows/dnsmos.py @@ -0,0 +1,213 @@ +import logging +import os +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Generator, List, Optional, Union + +import numpy as np +from tqdm import tqdm + +from lhotse import CutSet, MonoCut, RecordingSet, SupervisionSegment +from lhotse.utils import fastcopy, is_module_available, resumable_download + + +class ComputeScore: + def __init__(self, primary_model_path) -> None: + import onnxruntime as ort + + self.onnx_sess = ort.InferenceSession(primary_model_path) + self.SAMPLING_RATE = 16000 + self.INPUT_LENGTH = 9.01 + + def audio_melspec( + self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True + ): + import librosa + + mel_spec = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels + ) + if to_db: + mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 + return mel_spec.T + + def get_polyfit_val(self, sig, bak, ovr, is_personalized_mos): + if is_personalized_mos: + p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) + p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) + else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + sig_poly = p_sig(sig) + bak_poly = p_bak(bak) + ovr_poly = p_ovr(ovr) + + return sig_poly, bak_poly, ovr_poly + + def __call__(self, manifest, is_personalized_mos): + fs = self.SAMPLING_RATE + audio = manifest.resample(fs).load_audio() + len_samples = int(self.INPUT_LENGTH * fs) + while len(audio) < len_samples: + audio = np.append(audio, audio) + + num_hops = int(np.floor(len(audio) / fs) - self.INPUT_LENGTH) + 1 + hop_len_samples = fs + predicted_mos_sig_seg = [] + predicted_mos_bak_seg = [] + predicted_mos_ovr_seg = [] + + for idx in range(num_hops): + audio_seg = audio[ + int(idx * hop_len_samples) : int( + (idx + self.INPUT_LENGTH) * hop_len_samples + ) + ] + if len(audio_seg) < len_samples: + continue + + input_features = np.array(audio_seg).astype("float32")[np.newaxis, :] + oi = {"input_1": input_features} + mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0] + mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( + mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_mos + ) + predicted_mos_sig_seg.append(mos_sig) + predicted_mos_bak_seg.append(mos_bak) + predicted_mos_ovr_seg.append(mos_ovr) + + return manifest, { + "OVRL": np.mean(predicted_mos_ovr_seg), + "SIG": np.mean(predicted_mos_sig_seg), + "BAK": np.mean(predicted_mos_bak_seg), + } + + +def download_model( + is_personalized_mos: bool = False, + download_root: Optional[str] = None, +) -> str: + download_root = download_root if download_root is not None else "/tmp" + url = ( + "https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx" + if is_personalized_mos + else "https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx" + ) + filename = os.path.join(download_root, "sig_bak_ovr.onnx") + resumable_download(url, filename=filename) + return filename + + +def annotate_dnsmos( + manifest: Union[RecordingSet, CutSet], + is_personalized_mos: bool = False, + download_root: Optional[str] = None, +) -> Generator[MonoCut, None, None]: + """ + Use Microsoft DNSMOS P.835 prediction model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. + It will predict DNSMOS P.835 score including SIG, NAK, and OVRL. + + See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS + + :param manifest: a ``RecordingSet`` or ``CutSet`` object. + :param is_personalized_mos: flag to indicate if personalized MOS score is needed or regular. + :param download_root: if specified, the model will be downloaded to this directory. Otherwise, + it will be downloaded to /tmp. + :return: a generator of cuts (use ``CutSet.open_writer()`` to write them). + """ + assert is_module_available("librosa"), ( + "This function expects librosa to be installed. " + "You can install it via 'pip install librosa'" + ) + + assert is_module_available("onnxruntime"), ( + "This function expects onnxruntime to be installed. " + "You can install it via 'pip install onnxruntime'" + ) + + if isinstance(manifest, RecordingSet): + yield from _annotate_recordings( + manifest, + is_personalized_mos, + download_root, + ) + elif isinstance(manifest, CutSet): + yield from _annotate_cuts( + manifest, + is_personalized_mos, + download_root, + ) + else: + raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.") + + +def _annotate_recordings( + recordings: RecordingSet, + is_personalized_mos: bool = False, + download_root: Optional[str] = None, +): + """ + Helper function that annotates a RecordingSet with DNSMOS P.835 prediction model. + """ + primary_model_path = download_model(is_personalized_mos, download_root) + compute_score = ComputeScore(primary_model_path) + + with ThreadPoolExecutor() as ex: + futures = [] + for recording in tqdm(recordings, desc="Distributing tasks"): + if recording.num_channels > 1: + logging.warning( + f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, " + f"but we currently only support mono input." + ) + continue + futures.append(ex.submit(compute_score, recording, is_personalized_mos)) + + for future in tqdm(futures, desc="Processing"): + recording, result = future.result() + supervision = SupervisionSegment( + id=recording.id, + recording_id=recording.id, + start=0, + duration=recording.duration, + ) + cut = MonoCut( + id=recording.id, + start=0, + duration=recording.duration, + channel=0, + recording=recording, + supervisions=[supervision], + custom=result, + ) + yield cut + + +def _annotate_cuts( + cuts: CutSet, + is_personalized_mos: bool = False, + download_root: Optional[str] = None, +): + """ + Helper function that annotates a CutSet with DNSMOS P.835 prediction model. + """ + primary_model_path = download_model(is_personalized_mos, download_root) + compute_score = ComputeScore(primary_model_path) + + with ThreadPoolExecutor() as ex: + futures = [] + for cut in tqdm(cuts, desc="Distributing tasks"): + if cut.num_channels > 1: + logging.warning( + f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, " + f"but we currently only support mono input." + ) + continue + futures.append(ex.submit(compute_score, cut, is_personalized_mos)) + + for future in tqdm(futures, desc="Processing"): + cut, result = future.result() + new_cut = fastcopy(cut, custom=result) + yield new_cut From 3ab39176b3e87f8bd40632b7b3874e5335ec6407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E9=9C=87=E4=B8=9C?= Date: Wed, 6 Nov 2024 18:33:53 +0800 Subject: [PATCH 68/69] Update lhotse.py (#1414) Remove the deprecated usage. --- lhotse/bin/lhotse.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/lhotse/bin/lhotse.py b/lhotse/bin/lhotse.py index b241a643a..1944bde31 100755 --- a/lhotse/bin/lhotse.py +++ b/lhotse/bin/lhotse.py @@ -1,22 +1,6 @@ #!/usr/bin/env python3 """ -Use this script like: - -$ lhotse --help -$ lhotse make-feats --help -$ lhotse make-feats --compressed recording_manifest.yml mfcc_dir/ -$ lhotse write-default-feature-config feat-conf.yml -$ lhotse kaldi import data/train 16000 train_manifests/ -$ lhotse split 3 audio.yml split_manifests/ -$ lhotse combine feature.1.yml feature.2.yml combined_feature.yml -$ lhotse recipe --help -$ lhotse recipe librimix-dataprep path/to/librimix.csv output_manifests_dir/ -$ lhotse recipe librimix-obtain target_dir/ -$ lhotse recipe mini-librispeech-dataprep corpus_dir/ output_manifests_dir/ -$ lhotse recipe mini-librispeech-obtain target_dir/ -$ lhotse cut --help -$ lhotse cut simple supervisions.yml features.yml simple_cuts.yml -$ lhotse cut stereo-mixed supervisions.yml features.yml mixed_cuts.yml +Use this script like: https://lhotse.readthedocs.io/en/latest/cli.html """ # Note: we import all the CLI modes here so they get auto-registered From 54bb42fc7be95e2a6ebf705bdc98c95fe1917df5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 6 Nov 2024 05:50:47 -0500 Subject: [PATCH 69/69] Make torchaudio an optional dependency (#1382) * Make torchaudio an optional dependency * Remove torchaudio from some CI tests --- .github/workflows/unit_tests.yml | 12 ++++++------ README.md | 3 ++- docs/conf.py | 2 +- docs/getting-started.rst | 5 ++++- lhotse/audio/recording.py | 4 ++-- setup.py | 32 +++----------------------------- 6 files changed, 18 insertions(+), 40 deletions(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 33310c313..da5d652af 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -22,14 +22,14 @@ jobs: - python-version: "3.9" torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.10" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.10" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.11" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.11" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.12" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.12" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" fail-fast: false diff --git a/README.md b/README.md index 3d4bb17f6..61febb6c4 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,8 @@ Lhotse uses several environment variables to customize it's behavior. They are a ### Optional dependencies -**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package like this: `pip install lhotse[package_name]`. The supported optional packages include: +**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package: +- `torchaudio` used to be a core dependency in Lhotse, but is now optional. Refer to [official PyTorch documentation for installation](https://pytorch.org/get-started/locally/). - `pip install lhotse[kaldi]` for a maximal feature set related to Kaldi compatibility. It includes libraries such as `kaldi_native_io` (a more efficient variant of `kaldi_io`) and `kaldifeat` that port some of Kaldi functionality into Python. - `pip install lhotse[orjson]` for up to 50% faster reading of JSONL manifests. - `pip install lhotse[webdataset]`. We support "compiling" your data into WebDataset tarball format for more effective IO. You can still interact with the data as if it was a regular lazy CutSet. To learn more, check out the following tutorial: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/02-webdataset-integration.ipynb) diff --git a/docs/conf.py b/docs/conf.py index d674bcacc..8a50287b9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -78,4 +78,4 @@ "exclude-members": "__weakref__", } -autodoc_mock_imports = ["torchaudio", "SoundFile", "soundfile"] +autodoc_mock_imports = ["SoundFile", "soundfile"] diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 9a299c973..89072397f 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -143,7 +143,9 @@ Lhotse uses several environment variables to customize it's behavior. They are a Optional dependencies ********************* -**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package like this: ``pip install lhotse[package_name]``. The supported optional packages include: +**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package: + +* ``torchaudio`` used to be a core dependency in Lhotse, but is now optional. Refer to official PyTorch documentation for installation at `official Pytorch documentation for installation`_. * ``pip install lhotse[kaldi]`` for a maximal feature set related to Kaldi compatibility. It includes libraries such as ``kaldi_native_io`` (a more efficient variant of ``kaldi_io``) and ``kaldifeat`` that port some of Kaldi functionality into Python. @@ -230,3 +232,4 @@ the speech starts roughly at the first second (100 frames): .. _Icefall recipes: https://github.com/k2-fsa/icefall .. _orjson: https://pypi.org/project/orjson/ .. _AIStore: https://aiatscale.org +.. _official Pytorch documentation for installation: https://pytorch.org/get-started/locally/ diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index 063084af7..ec0f605f2 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -8,7 +8,7 @@ import torch from _decimal import ROUND_HALF_UP -from lhotse.audio.backend import info, save_audio, torchaudio_info +from lhotse.audio.backend import get_current_audio_backend, info, save_audio from lhotse.audio.source import AudioSource from lhotse.audio.utils import ( AudioLoadingError, @@ -260,7 +260,7 @@ def from_bytes( :return: a new ``Recording`` instance that owns the byte string data. """ stream = BytesIO(data) - audio_info = torchaudio_info(stream) + audio_info = get_current_audio_backend().info(stream) return Recording( id=recording_id, sampling_rate=audio_info.samplerate, diff --git a/setup.py b/setup.py index b96a3e36e..831786fa4 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ ) # False = public release, True = otherwise -LHOTSE_REQUIRE_TORCHAUDIO = os.environ.get("LHOTSE_REQUIRE_TORCHAUDIO", "1") in ( +LHOTSE_REQUIRE_TORCHAUDIO = os.environ.get("LHOTSE_REQUIRE_TORCHAUDIO", "0") in ( "1", "True", "true", @@ -157,6 +157,7 @@ def mark_lhotse_version(version: str) -> None: "packaging", "pyyaml>=5.3.1", "tabulate>=0.8.1", + "torch", "tqdm", ] @@ -167,30 +168,6 @@ def mark_lhotse_version(version: str) -> None: else: install_requires.append("lilcom>=1.1.0") -try: - # If the user already installed PyTorch, make sure he has torchaudio too. - # Otherwise, we'll just install the latest versions from PyPI for the user. - import torch - - if LHOTSE_REQUIRE_TORCHAUDIO: - try: - import torchaudio - except ImportError: - raise ValueError( - "We detected that you have already installed PyTorch, but haven't installed torchaudio. " - "Unfortunately we can't detect the compatible torchaudio version for you; " - "you will have to install it manually. " - "For instructions, please refer either to https://pytorch.org/get-started/locally/ " - "or https://github.com/pytorch/audio#dependencies " - "You can also disable torchaudio dependency by setting the following environment variable: " - "LHOTSE_USE_TORCHAUDIO=0" - ) -except ImportError: - extras = ["torch"] - if LHOTSE_REQUIRE_TORCHAUDIO: - extras.append("torchaudio") - install_requires.extend(extras) - docs_require = (project_root / "docs" / "requirements.txt").read_text().splitlines() tests_require = [ "pytest==7.1.3", @@ -222,13 +199,10 @@ def mark_lhotse_version(version: str) -> None: all_requires = sorted(dev_requires) if os.environ.get("READTHEDOCS", False): - # When building documentation, omit torchaudio installation and mock it instead. - # This works around the inability to install libsoundfile1 in read-the-docs env, - # which caused the documentation builds to silently crash. install_requires = [ req for req in install_requires - if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"]) + if not any(req.startswith(dep) for dep in ["SoundFile"]) ] setup(