diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 33310c313..da5d652af 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -22,14 +22,14 @@ jobs: - python-version: "3.9" torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.10" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.10" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.11" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.11" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" - - python-version: "3.12" - torch-install-cmd: "pip install torch==2.3 torchaudio==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" + - python-version: "3.12" # note: no torchaudio + torch-install-cmd: "pip install torch==2.3 --extra-index-url https://download.pytorch.org/whl/cpu" extra_deps: "" fail-fast: false diff --git a/README.md b/README.md index 3d4bb17f6..61febb6c4 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,8 @@ Lhotse uses several environment variables to customize it's behavior. They are a ### Optional dependencies -**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package like this: `pip install lhotse[package_name]`. The supported optional packages include: +**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package: +- `torchaudio` used to be a core dependency in Lhotse, but is now optional. 
Refer to [official PyTorch documentation for installation](https://pytorch.org/get-started/locally/). - `pip install lhotse[kaldi]` for a maximal feature set related to Kaldi compatibility. It includes libraries such as `kaldi_native_io` (a more efficient variant of `kaldi_io`) and `kaldifeat` that port some of Kaldi functionality into Python. - `pip install lhotse[orjson]` for up to 50% faster reading of JSONL manifests. - `pip install lhotse[webdataset]`. We support "compiling" your data into WebDataset tarball format for more effective IO. You can still interact with the data as if it was a regular lazy CutSet. To learn more, check out the following tutorial: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/02-webdataset-integration.ipynb) diff --git a/docs/conf.py b/docs/conf.py index d674bcacc..8a50287b9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -78,4 +78,4 @@ "exclude-members": "__weakref__", } -autodoc_mock_imports = ["torchaudio", "SoundFile", "soundfile"] +autodoc_mock_imports = ["SoundFile", "soundfile"] diff --git a/docs/datasets.rst b/docs/datasets.rst index 03df15ea1..d3609e955 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -28,7 +28,7 @@ It allows for interesting collation methods - e.g. **padding the speech with noi The items for mini-batch creation are selected by the ``Sampler``. Lhotse defines ``Sampler`` classes that are initialized with :class:`~lhotse.cut.CutSet`'s, so that they can look up specific properties of an utterance to stratify the sampling. -For example, :class:`~lhotse.dataset.sampling.SimpleCutSampler` has a defined ``max_frames`` attribute, and it will keep sampling cuts for a batch until they do not exceed the specified number of frames. 
+For example, :class:`~lhotse.dataset.sampling.SimpleCutSampler` has a defined ``max_duration`` attribute, and it will keep sampling cuts for a batch until they do not exceed the specified number of seconds. Another strategy — used in :class:`~lhotse.dataset.sampling.BucketingSampler` — will first group the cuts of similar durations into buckets, and then randomly select a bucket to draw the whole batch from. For tasks where both input and output of the model are speech utterances, we can use the :class:`~lhotse.dataset.sampling.CutPairsSampler`, which accepts two :class:`~lhotse.cut.CutSet`'s and will match the cuts in them by their IDs. @@ -38,11 +38,11 @@ A typical Lhotse's dataset API usage might look like this: .. code-block:: from torch.utils.data import DataLoader - from lhotse.dataset import SpeechRecognitionDataset, SimpleCutSampler + from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler cuts = CutSet(...) - dset = SpeechRecognitionDataset(cuts) - sampler = SimpleCutSampler(cuts, max_frames=50000) + dset = K2SpeechRecognitionDataset(cuts) + sampler = SimpleCutSampler(cuts, max_duration=500) # Dataset performs batching by itself, so we have to indicate that # to the DataLoader with batch_size=None dloader = DataLoader(dset, sampler=sampler, batch_size=None, num_workers=1) diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 9a299c973..89072397f 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -143,7 +143,9 @@ Lhotse uses several environment variables to customize it's behavior. They are a Optional dependencies ********************* -**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package like this: ``pip install lhotse[package_name]``. 
The supported optional packages include: +**Other pip packages.** You can leverage optional features of Lhotse by installing the relevant supporting package: + +* ``torchaudio`` used to be a core dependency in Lhotse, but is now optional. Refer to the `official PyTorch documentation for installation`_. * ``pip install lhotse[kaldi]`` for a maximal feature set related to Kaldi compatibility. It includes libraries such as ``kaldi_native_io`` (a more efficient variant of ``kaldi_io``) and ``kaldifeat`` that port some of Kaldi functionality into Python. @@ -230,3 +232,4 @@ the speech starts roughly at the first second (100 frames): .. _Icefall recipes: https://github.com/k2-fsa/icefall .. _orjson: https://pypi.org/project/orjson/ .. _AIStore: https://aiatscale.org +.. _official PyTorch documentation for installation: https://pytorch.org/get-started/locally/ diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index 063084af7..cd8a41fb2 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -8,7 +8,7 @@ import torch from _decimal import ROUND_HALF_UP -from lhotse.audio.backend import info, save_audio, torchaudio_info +from lhotse.audio.backend import get_current_audio_backend, info, save_audio from lhotse.audio.source import AudioSource from lhotse.audio.utils import ( AudioLoadingError, @@ -168,6 +168,23 @@ def is_placeholder(self) -> bool: def num_channels(self) -> int: return len(self.channel_ids) + @property + def source_format(self) -> str: + """Infer format of the audio sources. + If all sources have the same format, return it. + If sources have different formats, raise an error. + """ + source_formats = list(set([s.format for s in self.sources])) + + if len(source_formats) == 1: + # if all sources have the same format, return it + return source_formats[0] + else: + # at the moment, we don't resolve different formats + raise NotImplementedError( + "Sources have different formats. 
Resolving to a single format not implemented." + ) + @staticmethod def from_file( path: Pathlike, @@ -260,7 +277,7 @@ def from_bytes( :return: a new ``Recording`` instance that owns the byte string data. """ stream = BytesIO(data) - audio_info = torchaudio_info(stream) + audio_info = get_current_audio_backend().info(stream) return Recording( id=recording_id, sampling_rate=audio_info.samplerate, diff --git a/lhotse/audio/source.py b/lhotse/audio/source.py index 459881ea2..88bb9743d 100644 --- a/lhotse/audio/source.py +++ b/lhotse/audio/source.py @@ -1,3 +1,5 @@ +import io +import os import warnings from dataclasses import dataclass from io import BytesIO, FileIO @@ -6,6 +8,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +import soundfile as sf import torch from lhotse.audio.backend import read_audio @@ -64,6 +67,10 @@ class AudioSource: def has_video(self) -> bool: return self.video is not None + @property + def format(self) -> str: + return self._get_format() + def load_audio( self, offset: Seconds = 0.0, @@ -316,3 +323,24 @@ def _prepare_for_reading( ) return source + + def _get_format(self) -> str: + """Get format for the audio source. + If using 'file' or 'url' types, the format is inferred from the file extension, as in soundfile. + If using 'memory' type, the format is inferred from the binary data. 
+ """ + if self.type in ("file", "url"): + # Resolve audio format based on the filename + format = os.path.splitext(self.source)[-1][1:] + return format.lower() + elif self.type == "memory": + sf_info = sf.info(io.BytesIO(self.source)) + if sf_info.format == "OGG" and sf_info.subtype == "OPUS": + # soundfile describes opus as ogg container with opus coding + return "opus" + else: + return sf_info.format.lower() + else: + raise NotImplementedError( + f"Getting format not implemented for source type {self.type}" + ) diff --git a/lhotse/bin/lhotse.py b/lhotse/bin/lhotse.py index b241a643a..1944bde31 100755 --- a/lhotse/bin/lhotse.py +++ b/lhotse/bin/lhotse.py @@ -1,22 +1,6 @@ #!/usr/bin/env python3 """ -Use this script like: - -$ lhotse --help -$ lhotse make-feats --help -$ lhotse make-feats --compressed recording_manifest.yml mfcc_dir/ -$ lhotse write-default-feature-config feat-conf.yml -$ lhotse kaldi import data/train 16000 train_manifests/ -$ lhotse split 3 audio.yml split_manifests/ -$ lhotse combine feature.1.yml feature.2.yml combined_feature.yml -$ lhotse recipe --help -$ lhotse recipe librimix-dataprep path/to/librimix.csv output_manifests_dir/ -$ lhotse recipe librimix-obtain target_dir/ -$ lhotse recipe mini-librispeech-dataprep corpus_dir/ output_manifests_dir/ -$ lhotse recipe mini-librispeech-obtain target_dir/ -$ lhotse cut --help -$ lhotse cut simple supervisions.yml features.yml simple_cuts.yml -$ lhotse cut stereo-mixed supervisions.yml features.yml mixed_cuts.yml +Use this script like: https://lhotse.readthedocs.io/en/latest/cli.html """ # Note: we import all the CLI modes here so they get auto-registered diff --git a/lhotse/bin/modes/shar.py b/lhotse/bin/modes/shar.py index 95b670529..cffdf596b 100644 --- a/lhotse/bin/modes/shar.py +++ b/lhotse/bin/modes/shar.py @@ -27,8 +27,8 @@ def shar(): "-a", "--audio", default="none", - type=click.Choice(["none", "wav", "flac", "mp3", "opus"]), - help="Format in which to export audio (disabled by default, 
enabling will make a copy of the data)", + type=click.Choice(["none", "wav", "flac", "mp3", "opus", "original"]), + help="Format in which to export audio. Original will save in the same format as the original audio (disabled by default, enabling will make a copy of the data)", ) @click.option( "-f", diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index ad47ca381..a939db5a2 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -723,7 +723,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. diff --git a/lhotse/cut/mixed.py b/lhotse/cut/mixed.py index 01acf248d..cd83d29e0 100644 --- a/lhotse/cut/mixed.py +++ b/lhotse/cut/mixed.py @@ -622,7 +622,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. diff --git a/lhotse/cut/padding.py b/lhotse/cut/padding.py index c535bde2b..a95be6062 100644 --- a/lhotse/cut/padding.py +++ b/lhotse/cut/padding.py @@ -236,7 +236,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. 
- The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 2a7afd16c..5a62ba21c 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -2821,7 +2821,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param cut: DataCut to be padded. diff --git a/lhotse/dataset/audio_tagging.py b/lhotse/dataset/audio_tagging.py index 0ca44a687..fbf370fd6 100644 --- a/lhotse/dataset/audio_tagging.py +++ b/lhotse/dataset/audio_tagging.py @@ -78,7 +78,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ self.hdf5_fix.update() diff --git a/lhotse/dataset/sampling/bucketing.py b/lhotse/dataset/sampling/bucketing.py index dd53551cc..b869185b6 100644 --- a/lhotse/dataset/sampling/bucketing.py +++ b/lhotse/dataset/sampling/bucketing.py @@ -30,7 +30,7 @@ class BucketingSampler(CutSampler): ... # BucketingSampler specific args ... sampler_type=SimpleCutSampler, num_buckets=20, ... # Args passed into SimpleCutSampler - ... max_frames=20000 + ... max_duration=200 ... 
) Bucketing sampler with 20 buckets, sampling pairs of source-target cuts:: @@ -40,7 +40,7 @@ class BucketingSampler(CutSampler): ... # BucketingSampler specific args ... sampler_type=CutPairsSampler, num_buckets=20, ... # Args passed into CutPairsSampler - ... max_source_frames=20000, max_target_frames=15000 + ... max_source_duration=200, max_target_duration=150 ... ) """ diff --git a/lhotse/dataset/sampling/cut_pairs.py b/lhotse/dataset/sampling/cut_pairs.py index 1582158d2..cd13353d8 100644 --- a/lhotse/dataset/sampling/cut_pairs.py +++ b/lhotse/dataset/sampling/cut_pairs.py @@ -12,10 +12,10 @@ class CutPairsSampler(CutSampler): It expects that both CutSet's strictly consist of Cuts with corresponding IDs. It behaves like an iterable that yields lists of strings (cut IDs). - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_source_duration`, :attr:`max_target_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. - Padding required to collate the batch does not contribute to max frames/samples/duration. + Padding required to collate the batch does not contribute to max source_duration/target_duration. """ def __init__( @@ -229,7 +229,7 @@ def _next_batch(self) -> Tuple[CutSet, CutSet]: self.source_constraints.add(next_source_cut) self.target_constraints.add(next_target_cut) - # Did we exceed the max_source_frames and max_cuts constraints? + # Did we exceed the max_source_duration and max_cuts constraints? if ( not self.source_constraints.exceeded() and not self.target_constraints.exceeded() @@ -249,7 +249,7 @@ def _next_batch(self) -> Tuple[CutSet, CutSet]: # and return the cut anyway. warnings.warn( "The first cut drawn in batch collection violates one of the max_... constraints" - "we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc." + "we'll return it anyway. 
Consider increasing max_source_duration/max_cuts/etc." ) source_cuts.append(next_source_cut) target_cuts.append(next_target_cut) diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index 2d36b4130..dc5858010 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -335,7 +335,7 @@ def detuplify( else next_cut_or_tpl ) - # Did we exceed the max_frames and max_cuts constraints? + # Did we exceed the max_duration and max_cuts constraints? if self.constraint.close_to_exceeding(): # Yes. Finish sampling this batch. if self.constraint.exceeded() and len(cuts) == 1: diff --git a/lhotse/dataset/sampling/simple.py b/lhotse/dataset/sampling/simple.py index 66b56dae2..a8ca079c4 100644 --- a/lhotse/dataset/sampling/simple.py +++ b/lhotse/dataset/sampling/simple.py @@ -11,10 +11,10 @@ class SimpleCutSampler(CutSampler): Samples cuts from a CutSet to satisfy the input constraints. It behaves like an iterable that yields lists of strings (cut IDs). - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. - Padding required to collate the batch does not contribute to max frames/samples/duration. + Padding required to collate the batch does not contribute to max duration. Example usage:: @@ -197,10 +197,10 @@ def _next_batch(self) -> CutSet: self.diagnostics.discard_single(next_cut) continue - # Track the duration/frames/etc. constraints. + # Track the duration/etc. constraints. self.time_constraint.add(next_cut) - # Did we exceed the max_frames and max_cuts constraints? + # Did we exceed the max_duration and max_cuts constraints? if not self.time_constraint.exceeded(): # No - add the next cut to the batch, and keep trying. cuts.append(next_cut) @@ -215,9 +215,9 @@ def _next_batch(self) -> CutSet: # and return the cut anyway. 
warnings.warn( "The first cut drawn in batch collection violates " - "the max_frames, max_cuts, or max_duration constraints - " + "the max_duration, or max_cuts constraints - " "we'll return it anyway. " - "Consider increasing max_frames/max_cuts/max_duration." + "Consider increasing max_duration/max_cuts." ) cuts.append(next_cut) diff --git a/lhotse/dataset/sampling/weighted_simple.py b/lhotse/dataset/sampling/weighted_simple.py index 7c3f76034..4a3191b02 100644 --- a/lhotse/dataset/sampling/weighted_simple.py +++ b/lhotse/dataset/sampling/weighted_simple.py @@ -15,7 +15,7 @@ class WeightedSimpleCutSampler(SimpleCutSampler): When performing sampling, it avoids having duplicated cuts in the same batch. The sampler terminates if the number of sampled cuts reach :attr:`num_samples` - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Example usage: diff --git a/lhotse/dataset/speech_recognition.py b/lhotse/dataset/speech_recognition.py index 4c9919f99..4a3520b37 100644 --- a/lhotse/dataset/speech_recognition.py +++ b/lhotse/dataset/speech_recognition.py @@ -94,7 +94,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ validate_for_asr(cuts) diff --git a/lhotse/dataset/speech_translation.py b/lhotse/dataset/speech_translation.py index 672d27069..1def4475b 100644 --- a/lhotse/dataset/speech_translation.py +++ b/lhotse/dataset/speech_translation.py @@ -97,7 +97,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. 
""" validate_for_asr(cuts) self.hdf5_fix.update() diff --git a/lhotse/dataset/surt.py b/lhotse/dataset/surt.py index 8eda83b5f..5e424353c 100644 --- a/lhotse/dataset/surt.py +++ b/lhotse/dataset/surt.py @@ -170,7 +170,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ validate_for_asr(cuts) diff --git a/lhotse/parallel.py b/lhotse/parallel.py index f9ab8e55f..dd882ab4e 100644 --- a/lhotse/parallel.py +++ b/lhotse/parallel.py @@ -88,7 +88,7 @@ class ParallelExecutor: >>> class MyRunner: ... def __init__(self): - ... self.name = name + ... pass ... def __call__(self, x): ... return f'processed: {x}' ... diff --git a/lhotse/serialization.py b/lhotse/serialization.py index c11390fb4..76822ae08 100644 --- a/lhotse/serialization.py +++ b/lhotse/serialization.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, Iterable, Optional, Type, Union import yaml +from packaging.version import parse as parse_version from lhotse.utils import Pathlike, SmartOpen, is_module_available, is_valid_url from lhotse.workarounds import gzip_open_robust @@ -105,17 +106,23 @@ def get_aistore_client(): raise ValueError( "Set a valid URL as AIS_ENDPOINT environment variable's value to read data from AIStore." ) - from aistore import Client + import aistore endpoint_url = os.environ["AIS_ENDPOINT"] - return Client(endpoint_url) + version = parse_version(aistore.__version__) + return aistore.Client(endpoint_url), version def open_aistore(uri: str, mode: str): assert "r" in mode, "We only support reading from AIStore at this time." 
- client = get_aistore_client() + client, version = get_aistore_client() object = client.fetch_object_by_url(uri) - return object.get().raw() + request = object.get() + if version >= parse_version("1.9.1"): + # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency + return request.as_file() + else: + return request.raw() def save_to_yaml(data: Any, path: Pathlike) -> None: diff --git a/lhotse/shar/writers/audio.py b/lhotse/shar/writers/audio.py index b3a855f59..22bdf4cf1 100644 --- a/lhotse/shar/writers/audio.py +++ b/lhotse/shar/writers/audio.py @@ -66,6 +66,14 @@ def close(self): def output_paths(self) -> List[str]: return self.tar_writer.output_paths + def resolve_format(self, original_format: str): + if self.format == "original": + # save using the original format of the input audio + return original_format + else: + # save using the format specified at initialization + return self.format + def write_placeholder(self, key: str) -> None: self.tar_writer.write(f"{key}.nodata", BytesIO()) self.tar_writer.write(f"{key}.nometa", BytesIO(), count=False) @@ -76,15 +84,18 @@ def write( value: np.ndarray, sampling_rate: int, manifest: Recording, + original_format: Optional[str] = None, ) -> None: + save_format = self.resolve_format(original_format) + value, manifest, sampling_rate = self._maybe_resample( - value, manifest, sampling_rate + value, manifest, sampling_rate, format=save_format ) # Write binary data stream = BytesIO() save_audio( - dest=stream, src=value, sampling_rate=sampling_rate, format=self.format + dest=stream, src=value, sampling_rate=sampling_rate, format=save_format ) self.tar_writer.write(f"{key}.{self.format}", stream) @@ -103,13 +114,14 @@ def _maybe_resample( audio: Union[torch.Tensor, np.ndarray], manifest: Recording, sampling_rate: int, + format: str, ) -> Tuple[Union[np.ndarray, torch.Tensor], Recording, int]: # Resampling is required for some versions of OPUS encoders. 
# First resample the manifest which only adjusts the metadata; # then resample the audio array to 48kHz. OPUS_DEFAULT_SAMPLING_RATE = 48000 if ( - self.format == "opus" + format == "opus" and is_torchaudio_available() and not isinstance(get_current_audio_backend(), LibsndfileBackend) and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE diff --git a/lhotse/shar/writers/shar.py b/lhotse/shar/writers/shar.py index 229a073b6..2c1a88442 100644 --- a/lhotse/shar/writers/shar.py +++ b/lhotse/shar/writers/shar.py @@ -135,7 +135,11 @@ def write(self, cut: Cut) -> None: recording.sources[0].channels = cut_channels recording.channel_ids = cut_channels self.writers["recording"].write( - cut.id, data, cut.sampling_rate, manifest=recording + cut.id, + data, + cut.sampling_rate, + manifest=recording, + original_format=cut.recording.source_format, ) cut = fastcopy(cut, recording=recording) else: @@ -224,6 +228,7 @@ def resolve_writer(name: str) -> Tuple[FieldWriter, str]: "flac": (partial(AudioTarWriter, format="flac"), ".tar"), "mp3": (partial(AudioTarWriter, format="mp3"), ".tar"), "opus": (partial(AudioTarWriter, format="opus"), ".tar"), + "original": (partial(AudioTarWriter, format="original"), ".tar"), "lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"), "numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"), "jsonl": (JsonlShardWriter, ".jsonl.gz"), diff --git a/lhotse/testing/dummies.py b/lhotse/testing/dummies.py index aec6a7581..0999906aa 100644 --- a/lhotse/testing/dummies.py +++ b/lhotse/testing/dummies.py @@ -63,6 +63,7 @@ def dummy_recording( duration: float = 1.0, sampling_rate: int = 16000, with_data: bool = False, + source_format: str = "wav", ) -> Recording: num_samples = compute_num_samples(duration, sampling_rate) return Recording( @@ -72,6 +73,7 @@ def dummy_recording( sampling_rate=sampling_rate, num_samples=num_samples, with_data=with_data, + format=source_format, ) ], sampling_rate=sampling_rate, @@ -85,6 +87,7 @@ def dummy_audio_source( 
sampling_rate: int = 16000, channels: Optional[List[int]] = None, with_data: bool = False, + format: str = "wav", ) -> AudioSource: if channels is None: channels = [0] @@ -95,21 +98,40 @@ def dummy_audio_source( else: import soundfile - # 1kHz sine wave - data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples)) + # generate 1kHz sine wave + f_sine = 1000 + assert ( + f_sine < sampling_rate / 2 + ), f"Sine wave frequency {f_sine} exceeds Nyquist frequency {sampling_rate/2} for sampling rate {sampling_rate}" + data = torch.sin(2 * np.pi * f_sine / sampling_rate * torch.arange(num_samples)) + + # prepare multichannel data if len(channels) > 1: data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1) # ensure each channel has different data for channel selection testing mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)]) data = data * mults + + # prepare source with the selected format binary_data = BytesIO() - soundfile.write( - binary_data, - data.numpy(), - sampling_rate, - format="wav", - closefd=False, - ) + if format == "opus": + # workaround for OPUS: soundfile supports OPUS as a subtype of OGG format + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format="OGG", + subtype="OPUS", + closefd=False, + ) + else: + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format=format, + closefd=False, + ) binary_data.seek(0) return AudioSource( type="memory", channels=channels, source=binary_data.getvalue() diff --git a/setup.py b/setup.py index b96a3e36e..831786fa4 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ ) # False = public release, True = otherwise -LHOTSE_REQUIRE_TORCHAUDIO = os.environ.get("LHOTSE_REQUIRE_TORCHAUDIO", "1") in ( +LHOTSE_REQUIRE_TORCHAUDIO = os.environ.get("LHOTSE_REQUIRE_TORCHAUDIO", "0") in ( "1", "True", "true", @@ -157,6 +157,7 @@ def mark_lhotse_version(version: str) -> None: "packaging", "pyyaml>=5.3.1", "tabulate>=0.8.1", + "torch", "tqdm", ] @@ -167,30 +168,6 
@@ def mark_lhotse_version(version: str) -> None: else: install_requires.append("lilcom>=1.1.0") -try: - # If the user already installed PyTorch, make sure he has torchaudio too. - # Otherwise, we'll just install the latest versions from PyPI for the user. - import torch - - if LHOTSE_REQUIRE_TORCHAUDIO: - try: - import torchaudio - except ImportError: - raise ValueError( - "We detected that you have already installed PyTorch, but haven't installed torchaudio. " - "Unfortunately we can't detect the compatible torchaudio version for you; " - "you will have to install it manually. " - "For instructions, please refer either to https://pytorch.org/get-started/locally/ " - "or https://github.com/pytorch/audio#dependencies " - "You can also disable torchaudio dependency by setting the following environment variable: " - "LHOTSE_USE_TORCHAUDIO=0" - ) -except ImportError: - extras = ["torch"] - if LHOTSE_REQUIRE_TORCHAUDIO: - extras.append("torchaudio") - install_requires.extend(extras) - docs_require = (project_root / "docs" / "requirements.txt").read_text().splitlines() tests_require = [ "pytest==7.1.3", @@ -222,13 +199,10 @@ def mark_lhotse_version(version: str) -> None: all_requires = sorted(dev_requires) if os.environ.get("READTHEDOCS", False): - # When building documentation, omit torchaudio installation and mock it instead. - # This works around the inability to install libsoundfile1 in read-the-docs env, - # which caused the documentation builds to silently crash. 
install_requires = [ req for req in install_requires - if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"]) + if not any(req.startswith(dep) for dep in ["SoundFile"]) ] setup( diff --git a/test/shar/test_write.py b/test/shar/test_write.py index cee35de1c..7d88bd3d0 100644 --- a/test/shar/test_write.py +++ b/test/shar/test_write.py @@ -66,55 +66,6 @@ def test_tar_writer_pipe(tmp_path: Path): assert f2.read() == b"test" -@pytest.mark.parametrize( - "format", - [ - "wav", - pytest.param( - "flac", - marks=pytest.mark.skipif( - not check_torchaudio_version_gt("0.12.1"), - reason="Torchaudio v0.12.1 or greater is required.", - ), - ), - # "mp3", # apparently doesn't work in CI, mp3 encoder is missing - pytest.param( - "opus", - marks=pytest.mark.skipif( - not check_torchaudio_version_gt("2.1.0"), - reason="Torchaudio v2.1.0 or greater is required.", - ), - ), - ], -) -def test_audio_tar_writer(tmp_path: Path, format: str): - from lhotse.testing.dummies import dummy_recording - - recording = dummy_recording(0, with_data=True) - audio = recording.load_audio() - - with AudioTarWriter( - str(tmp_path / "test.tar"), shard_size=None, format=format - ) as writer: - writer.write( - key="my-recording", - value=audio, - sampling_rate=recording.sampling_rate, - manifest=recording, - ) - - (path,) = writer.output_paths - - ((deserialized_recording, inner_path),) = list(TarIterator(path)) - - deserialized_audio = deserialized_recording.resample( - recording.sampling_rate - ).load_audio() - - rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2)) - assert rmse < 0.5 - - @pytest.mark.parametrize( ["format", "backend"], [ @@ -175,6 +126,59 @@ def test_audio_tar_writer(tmp_path: Path, format: str, backend: str): assert rmse < 0.5 +@pytest.mark.parametrize( + ["original_format", "rmse_threshold"], + [("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)], +) +def test_audio_tar_writer_original_format( + tmp_path: Path, original_format: str, rmse_threshold: 
float +): + """Test using AudioTarWriter to write the audio signal in the exact same format + as it was loaded from the source. + """ + from lhotse.testing.dummies import dummy_recording + + backend = "default" # use the default backend for reading the audio + writer_format = "original" # write the audio in the same format as it was loaded + + recording = dummy_recording(0, with_data=True, source_format=original_format) + audio = recording.load_audio() + + assert ( + recording.source_format == original_format + ), f"Recording source format ({recording.source_format}) not matching the expected original format ({original_format})" + + with audio_backend(backend): + with AudioTarWriter( + str(tmp_path / "test.tar"), shard_size=None, format=writer_format + ) as writer: + writer.write( + key="my-recording", + value=audio, + sampling_rate=recording.sampling_rate, + manifest=recording, + original_format=recording.source_format, + ) + (path,) = writer.output_paths + ((deserialized_recording, inner_path),) = list(TarIterator(path)) + + # make sure the deserialized audio is in the same format as the original + assert ( + deserialized_recording.source_format == original_format + ), f"Deserialized recording source format ({deserialized_recording.source_format}) not matching the expected original format ({original_format})" + + # load audio + deserialized_audio = deserialized_recording.resample( + recording.sampling_rate + ).load_audio() + + # check difference between original and deserialized audio + rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2)) + assert ( + rmse <= rmse_threshold + ), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}" + + def test_shar_writer(tmp_path: Path): # Prepare data cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)