From be42a3398d15bb9776ba2e9eb63597e5b54e2bf3 Mon Sep 17 00:00:00 2001 From: yfyeung Date: Fri, 28 Jun 2024 01:54:51 -0700 Subject: [PATCH 1/8] add recipe for gigaspeech2 --- docs/corpus.rst | 2 + lhotse/bin/modes/recipes/__init__.py | 1 + lhotse/bin/modes/recipes/gigaspeech2.py | 38 +++++ lhotse/recipes/__init__.py | 1 + lhotse/recipes/gigaspeech2.py | 199 ++++++++++++++++++++++++ 5 files changed, 241 insertions(+) create mode 100644 lhotse/bin/modes/recipes/gigaspeech2.py create mode 100644 lhotse/recipes/gigaspeech2.py diff --git a/docs/corpus.rst b/docs/corpus.rst index 79299eac1..f76d58ff6 100644 --- a/docs/corpus.rst +++ b/docs/corpus.rst @@ -115,6 +115,8 @@ a CLI tool that create the manifests given a corpus directory. - :func:`lhotse.recipes.prepare_gale_mandarin` * - GigaSpeech - :func:`lhotse.recipes.prepare_gigaspeech` + * - GigaSpeech 2 + - :func:`lhotse.recipes.prepare_gigaspeech2` * - GigaST - :func:`lhotse.recipes.prepare_gigast` * - Heroico diff --git a/lhotse/bin/modes/recipes/__init__.py b/lhotse/bin/modes/recipes/__init__.py index aafc871e3..d6ec5adbe 100644 --- a/lhotse/bin/modes/recipes/__init__.py +++ b/lhotse/bin/modes/recipes/__init__.py @@ -36,6 +36,7 @@ from .gale_arabic import * from .gale_mandarin import * from .gigaspeech import * +from .gigaspeech2 import * from .gigast import * from .grid import * from .heroico import * diff --git a/lhotse/bin/modes/recipes/gigaspeech2.py b/lhotse/bin/modes/recipes/gigaspeech2.py new file mode 100644 index 000000000..1c1011b05 --- /dev/null +++ b/lhotse/bin/modes/recipes/gigaspeech2.py @@ -0,0 +1,38 @@ +from typing import Optional, Sequence, Union + +import click + +from lhotse.bin.modes import prepare +from lhotse.recipes.gigaspeech2 import prepare_gigaspeech2 +from lhotse.utils import Pathlike + + +@prepare.command(context_settings=dict(show_default=True)) +@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True)) +@click.argument("output_dir", type=click.Path()) +@click.option( + "-l", + "--languages", + default="auto", + help="Languages to prepare (scans CORPUS_DIR for language codes by default).", +) +@click.option( + "-j", + "--num-jobs", + type=int, + default=1, + help="How many threads to use (can give good speed-ups with slow disks).", +) +def gigaspeech2( + corpus_dir: Pathlike, + output_dir: Optional[Pathlike] = None, + languages: Union[str, Sequence[str]] = "auto", + num_jobs: int = 1, +): + """GigaSpeech 2 data preparation.""" + prepare_gigaspeech2( + corpus_dir=corpus_dir, + output_dir=output_dir, + languages=languages, + num_jobs=num_jobs, + ) diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py index 2b5ec8338..24fa5dc52 100644 --- a/lhotse/recipes/__init__.py +++ b/lhotse/recipes/__init__.py @@ -33,6 +33,7 @@ from .gale_arabic import prepare_gale_arabic from .gale_mandarin import prepare_gale_mandarin from .gigaspeech import prepare_gigaspeech +from .gigaspeech import prepare_gigaspeech2 from .gigast import download_gigast, prepare_gigast from .grid import download_grid, prepare_grid from .heroico import download_heroico, prepare_heroico diff --git a/lhotse/recipes/gigaspeech2.py b/lhotse/recipes/gigaspeech2.py new file mode 100644 index 000000000..514e59883 --- /dev/null +++ b/lhotse/recipes/gigaspeech2.py @@ -0,0 +1,199 @@ +""" +Description taken from the abstract of paper: +"GigaSpeech 2: An Evolving, Large-Scale and Multi-domain ASR Corpus for Low-Resource Languages with Automated Crawling, Transcription and Refinement" +https://arxiv.org/abs/2406.11546 + +The evolution of 
speech technology has been spurred by the rapid increase in dataset sizes. Traditional speech models generally depend on a large amount of labeled training data, which is scarce for low-resource languages. This paper presents GigaSpeech 2, a large-scale, multi-domain, multilingual speech recognition corpus. It is designed for low-resource languages and does not rely on paired speech and text data. GigaSpeech 2 comprises about 30,000 hours of automatically transcribed speech, including Thai, Indonesian, and Vietnamese, gathered from unlabeled YouTube videos. We also introduce an automated pipeline for data crawling, transcription, and label refinement. Specifically, this pipeline uses Whisper for initial transcription and TorchAudio for forced alignment, combined with multi-dimensional filtering for data quality assurance. A modified Noisy Student Training is developed to further refine flawed pseudo labels iteratively, thus enhancing model performance. Experimental results on our manually transcribed evaluation set and two public test sets from Common Voice and FLEURS confirm our corpus’s high quality and broad applicability. Notably, ASR models trained on GigaSpeech 2 can reduce the word error rate for Thai, Indonesian, and Vietnamese on our challenging and realistic YouTube test set by 25% to 40% compared to the Whisper large-v3 model, with merely 10% model parameters. Furthermore, our ASR models trained on Gigaspeech 2 yield superior performance compared to commercial services. We believe that our newly introduced corpus and pipeline will open a new avenue for low-resource speech recognition and significantly facilitate research in this area. +""" + +import logging +from collections import defaultdict +from concurrent.futures.process import ProcessPoolExecutor +from pathlib import Path +from typing import Dict, Optional, Sequence, Tuple, Union + +from tqdm.auto import tqdm + +from lhotse.audio import Recording, RecordingSet +from lhotse.qa import fix_manifests, validate_recordings_and_supervisions +from lhotse.supervision import SupervisionSegment, SupervisionSet +from lhotse.utils import Pathlike + +GIGASPEECH2_URL = "https://huggingface.co/datasets/speechcolab/gigaspeech2" + +GIGASPEECH2_LANGS = ("th", "id", "vi") +# GIGASPEECH2_SPLITS = ("train_raw", "train_refined", "dev", "test") +GIGASPEECH2_SPLITS = ("dev", "test") + + +def _read_manifests_if_cached( + output_dir: Optional[Pathlike], + language: str, +) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]: + """ + Returns: + { + "train_raw": {"recordings": ..., "supervisions": ...}, + "train_refined": ..., + "dev": ..., + "test": ..., + } + """ + if output_dir is None: + return {} + manifests = defaultdict(dict) + for part in ["train_raw", "train_refined", "dev", "test"]: + for manifest in ["recordings", "supervisions"]: + path = output_dir / f"gigaspeech2-{language}_{manifest}_{part}.jsonl.gz" + if not path.is_file(): + continue + manifests[part][manifest] = load_manifest(path) + return manifests + + +def _parse_utterance( + lang: str, + part_dir: Pathlike, + audio_info: Pathlike, +) -> Optional[Tuple[Recording, SupervisionSegment]]: + segment_id, text = audio_info.split("\t") + audio_path = part_dir.joinpath(*segment_id.split("-")[:-1]) / f"{segment_id}.wav" + audio_path = audio_path.resolve() + + if not audio_path.is_file(): + logging.warning(f"No such file: {audio_path}") + return None + + recording = Recording.from_file( + path=audio_path, + recording_id=segment_id, + ) + + segment = SupervisionSegment( + 
id=segment_id,
+        recording_id=segment_id,
+        start=0.0,
+        duration=recording.duration,
+        channel=0,
+        language=lang,
+        text=text.strip(),
+    )
+
+    return recording, segment
+
+
+def _prepare_subset(
+    lang: str,
+    part: str,
+    lang_dir: Pathlike,
+    num_jobs: int = 1,
+) -> Tuple[RecordingSet, SupervisionSet]:
+    """
+    Returns the RecordingSet and SupervisionSet given a dataset part.
+
+    :param lang: string language code (e.g., "th").
+    :param part: str, the name of the subset.
+    :param lang_dir: Pathlike, the path of the data dir for a specific language.
+    :param num_jobs: int, the number of workers used to parse utterances in parallel.
+    :return: the RecordingSet and SupervisionSet for the given subset.
+    """
+    lang_dir = Path(lang_dir)
+    part_dir = lang_dir / part
+    tsv_path = lang_dir / f"{part}.tsv"
+
+    with open(tsv_path) as f:
+        audio_infos = f.read().splitlines()
+
+    with ProcessPoolExecutor(num_jobs) as ex:
+        futures = []
+        recordings = []
+        supervisions = []
+        for audio_info in tqdm(audio_infos, desc="Distributing tasks"):
+            futures.append(ex.submit(_parse_utterance, lang, part_dir, audio_info))
+
+        for future in tqdm(futures, desc="Processing"):
+            result = future.result()
+            if result is None:
+                continue
+            recording, segment = result
+            recordings.append(recording)
+            supervisions.append(segment)
+
+        recording_set = RecordingSet.from_recordings(recordings)
+        supervision_set = SupervisionSet.from_segments(supervisions)
+
+    # Fix manifests
+    recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
+    validate_recordings_and_supervisions(recording_set, supervision_set)
+
+    return recording_set, supervision_set
+
+
+def prepare_gigaspeech2(
+    corpus_dir: Pathlike,
+    output_dir: Optional[Pathlike] = None,
+    languages: Union[str, Sequence[str]] = "auto",
+    num_jobs: int = 1,
+) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
+    """
+    Returns the manifests which consist of the Recordings and Supervisions.
+    :param corpus_dir: Path to the GigaSpeech 2 dataset.
+    :param output_dir: Pathlike, the path where to write the manifests.
+    :param languages: 'auto' (prepare all discovered data) or a list of language codes.
+    :param num_jobs: int, the number of workers used to parse utterances in parallel.
+    :return: a Dict whose keys are language codes, with values being Dicts that map each dataset part to Dicts with the keys 'recordings' and 'supervisions'.
+    """
+    corpus_dir = Path(corpus_dir)
+    assert corpus_dir.is_dir(), f"No such directory: {corpus_dir}"
+    corpus_dir = corpus_dir / "data"
+
+    if languages == "auto":
+        languages = set(GIGASPEECH2_LANGS).intersection(
+            path.name for path in corpus_dir.glob("*")
+        )
+        if not languages:
+            raise ValueError(
+                f"Could not find any of GigaSpeech 2 languages in: {corpus_dir}"
+            )
+    elif isinstance(languages, str):
+        languages = [languages]
+
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+    manifests = defaultdict(dict)
+
+    for lang in tqdm(languages, desc="Processing GigaSpeech 2 languages"):
+        logging.info(f"Language: {lang}")
+        lang_dir = corpus_dir / lang
+
+        # Maybe the manifests already exist: we can read them and save a bit of preparation time.
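+        # (Parts found in this cache are skipped by the loop below; delete the
+        # stale manifest files in output_dir to force re-preparation.)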
+        lang_manifests = _read_manifests_if_cached(output_dir=output_dir, language=lang)
+
+        for part in tqdm(GIGASPEECH2_SPLITS, desc="Processing GigaSpeech 2 subset"):
+            logging.info(f"Processing GigaSpeech 2 subset: {part}")
+            if part in lang_manifests:
+                logging.info(f"GigaSpeech 2 {lang} {part} already prepared - skipping.")
+                continue
+
+            recording_set, supervision_set = _prepare_subset(
+                lang=lang,
+                part=part,
+                lang_dir=lang_dir,
+                num_jobs=num_jobs,
+            )
+
+            if output_dir is not None:
+                supervision_set.to_file(
+                    output_dir / f"gigaspeech2-{lang}_supervisions_{part}.jsonl.gz"
+                )
+                recording_set.to_file(
+                    output_dir / f"gigaspeech2-{lang}_recordings_{part}.jsonl.gz"
+                )
+
+            lang_manifests[part] = {
+                "supervisions": supervision_set,
+                "recordings": recording_set,
+            }
+
+        manifests[lang] = lang_manifests
+
+    return manifests

From 2cc34cafe073d26f7ac664d56cd70faef5637a7c Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Fri, 28 Jun 2024 02:03:13 -0700
Subject: [PATCH 2/8] fix flake8 and isort

---
 lhotse/recipes/__init__.py    | 3 ++-
 lhotse/recipes/gigaspeech2.py | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/lhotse/recipes/__init__.py b/lhotse/recipes/__init__.py
index 24fa5dc52..1116a197f 100644
--- a/lhotse/recipes/__init__.py
+++ b/lhotse/recipes/__init__.py
@@ -33,7 +33,7 @@ from .gale_arabic import prepare_gale_arabic
 from .gale_mandarin import prepare_gale_mandarin
 from .gigaspeech import prepare_gigaspeech
-from .gigaspeech import prepare_gigaspeech2
+from .gigaspeech2 import prepare_gigaspeech2
 from .gigast import download_gigast, prepare_gigast
 from .grid import download_grid, prepare_grid
 from .heroico import download_heroico, prepare_heroico
@@ -140,6 +140,7 @@
     "prepare_gale_arabic",
     "prepare_gale_mandarin",
     "prepare_gigaspeech",
+    "prepare_gigaspeech2",
     "download_gigast",
     "prepare_gigast",
     "download_grid",
diff --git a/lhotse/recipes/gigaspeech2.py b/lhotse/recipes/gigaspeech2.py
index 514e59883..4e1a9a17c 100644
--- a/lhotse/recipes/gigaspeech2.py
+++ b/lhotse/recipes/gigaspeech2.py
@@ -14,6 +14,7 @@
 
 from tqdm.auto import tqdm
 
+from lhotse import load_manifest
 from lhotse.audio import Recording, RecordingSet
 from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
 from lhotse.supervision import SupervisionSegment, SupervisionSet

From 1d721b4a7c0e55157d7b5765d9e11b610c7e3e33 Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Fri, 28 Jun 2024 02:06:09 -0700
Subject: [PATCH 3/8] remove comments

---
 lhotse/recipes/gigaspeech2.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/lhotse/recipes/gigaspeech2.py b/lhotse/recipes/gigaspeech2.py
index 4e1a9a17c..38f7c8645 100644
--- a/lhotse/recipes/gigaspeech2.py
+++ b/lhotse/recipes/gigaspeech2.py
@@ -23,8 +23,7 @@
 GIGASPEECH2_URL = "https://huggingface.co/datasets/speechcolab/gigaspeech2"
 
 GIGASPEECH2_LANGS = ("th", "id", "vi")
-# GIGASPEECH2_SPLITS = ("train_raw", "train_refined", "dev", "test")
-GIGASPEECH2_SPLITS = ("dev", "test")
+GIGASPEECH2_SPLITS = ("train_raw", "train_refined", "dev", "test")

From 83aef17792b586fa40eaa0400abeb1af1438c76c Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Fri, 28 Jun 2024 02:25:16 -0700
Subject: [PATCH 4/8] small fix for train_raw & train_refined

---
 lhotse/recipes/gigaspeech2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lhotse/recipes/gigaspeech2.py b/lhotse/recipes/gigaspeech2.py
index 38f7c8645..8af3d35dd 100644
--- a/lhotse/recipes/gigaspeech2.py
+++ b/lhotse/recipes/gigaspeech2.py
@@ -97,7 +97,7 @@ def _prepare_subset(
:return: the RecordingSet and SupervisionSet for the given subset.
     """
     lang_dir = Path(lang_dir)
-    part_dir = lang_dir / part
+    part_dir = lang_dir / part.replace("_raw", "").replace("_refined", "")
     tsv_path = lang_dir / f"{part}.tsv"

From 3b720c4e4ed90c008507ccc55c547ef81bf0cd2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Thu, 7 Nov 2024 11:45:34 -0500
Subject: [PATCH 5/8] Support for AIStore ObjectFile
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 lhotse/serialization.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/lhotse/serialization.py b/lhotse/serialization.py
index c11390fb4..76822ae08 100644
--- a/lhotse/serialization.py
+++ b/lhotse/serialization.py
@@ -10,6 +10,7 @@
 from typing import Any, Dict, Generator, Iterable, Optional, Type, Union
 
 import yaml
+from packaging.version import parse as parse_version
 
 from lhotse.utils import Pathlike, SmartOpen, is_module_available, is_valid_url
 from lhotse.workarounds import gzip_open_robust
@@ -105,17 +106,23 @@ def get_aistore_client():
         raise ValueError(
             "Set a valid URL as AIS_ENDPOINT environment variable's value to read data from AIStore."
         )
-    from aistore import Client
+    import aistore
 
     endpoint_url = os.environ["AIS_ENDPOINT"]
-    return Client(endpoint_url)
+    version = parse_version(aistore.__version__)
+    return aistore.Client(endpoint_url), version
 
 
 def open_aistore(uri: str, mode: str):
     assert "r" in mode, "We only support reading from AIStore at this time."
-    client = get_aistore_client()
+    client, version = get_aistore_client()
     object = client.fetch_object_by_url(uri)
-    return object.get().raw()
+    request = object.get()
+    if version >= parse_version("1.9.1"):
+        # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency
+        return request.as_file()
+    else:
+        return request.raw()

From 1880fc1a1d3ce9de0184db6845dd8934d5cc95c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BD=AD=E9=9C=87=E4=B8=9C?=
Date: Tue, 19 Nov 2024 21:49:08 +0800
Subject: [PATCH 6/8] minor fix (#1418)

---
 lhotse/parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lhotse/parallel.py b/lhotse/parallel.py
index f9ab8e55f..dd882ab4e 100644
--- a/lhotse/parallel.py
+++ b/lhotse/parallel.py
@@ -88,7 +88,7 @@ class ParallelExecutor:
 
     >>> class MyRunner:
     ...     def __init__(self):
-    ...         self.name = name
+    ...         pass
    ...     def __call__(self, x):
     ...         return f'processed: {x}'
     ...
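The next patch sweeps the documentation and comments of the sampling module: the retired `max_frames`/`max_samples` batch constraints are replaced by `max_duration`, which is expressed in seconds. As orientation for the renamed knobs, here is a minimal sketch of dynamic-batch sampling after the change (the manifest path is hypothetical):

    from lhotse import CutSet
    from lhotse.dataset.sampling import SimpleCutSampler

    cuts = CutSet.from_file("data/cuts_train.jsonl.gz")  # hypothetical manifest
    # Cuts are accumulated into a batch until their total duration (in seconds)
    # would exceed max_duration, so the batch size is dynamic:
    sampler = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)
    for batch in sampler:
        ...  # each batch holds roughly up to 200 s of speech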
From 9c1330a8523a2cc28c5d983a1d59ae4d4b05f117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E9=9C=87=E4=B8=9C?= <275331498@qq.com> Date: Sat, 23 Nov 2024 08:41:47 +0800 Subject: [PATCH 7/8] change max_frames to max_duration in docs (#1419) * change max_frames to max_duration in docs * minor fix --- docs/datasets.rst | 8 ++++---- lhotse/cut/data.py | 2 +- lhotse/cut/mixed.py | 2 +- lhotse/cut/padding.py | 2 +- lhotse/cut/set.py | 2 +- lhotse/dataset/audio_tagging.py | 2 +- lhotse/dataset/sampling/bucketing.py | 4 ++-- lhotse/dataset/sampling/cut_pairs.py | 8 ++++---- lhotse/dataset/sampling/dynamic.py | 2 +- lhotse/dataset/sampling/simple.py | 12 ++++++------ lhotse/dataset/sampling/weighted_simple.py | 2 +- lhotse/dataset/speech_recognition.py | 2 +- lhotse/dataset/speech_translation.py | 2 +- lhotse/dataset/surt.py | 2 +- 14 files changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/datasets.rst b/docs/datasets.rst index 03df15ea1..d3609e955 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -28,7 +28,7 @@ It allows for interesting collation methods - e.g. **padding the speech with noi The items for mini-batch creation are selected by the ``Sampler``. Lhotse defines ``Sampler`` classes that are initialized with :class:`~lhotse.cut.CutSet`'s, so that they can look up specific properties of an utterance to stratify the sampling. -For example, :class:`~lhotse.dataset.sampling.SimpleCutSampler` has a defined ``max_frames`` attribute, and it will keep sampling cuts for a batch until they do not exceed the specified number of frames. +For example, :class:`~lhotse.dataset.sampling.SimpleCutSampler` has a defined ``max_duration`` attribute, and it will keep sampling cuts for a batch until they do not exceed the specified number of seconds. Another strategy — used in :class:`~lhotse.dataset.sampling.BucketingSampler` — will first group the cuts of similar durations into buckets, and then randomly select a bucket to draw the whole batch from. For tasks where both input and output of the model are speech utterances, we can use the :class:`~lhotse.dataset.sampling.CutPairsSampler`, which accepts two :class:`~lhotse.cut.CutSet`'s and will match the cuts in them by their IDs. @@ -38,11 +38,11 @@ A typical Lhotse's dataset API usage might look like this: .. code-block:: from torch.utils.data import DataLoader - from lhotse.dataset import SpeechRecognitionDataset, SimpleCutSampler + from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler cuts = CutSet(...) - dset = SpeechRecognitionDataset(cuts) - sampler = SimpleCutSampler(cuts, max_frames=50000) + dset = K2SpeechRecognitionDataset(cuts) + sampler = SimpleCutSampler(cuts, max_duration=500) # Dataset performs batching by itself, so we have to indicate that # to the DataLoader with batch_size=None dloader = DataLoader(dset, sampler=sampler, batch_size=None, num_workers=1) diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index ad47ca381..a939db5a2 100644 --- a/lhotse/cut/data.py +++ b/lhotse/cut/data.py @@ -723,7 +723,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. 
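The same one-word docstring correction recurs in each `pad` variant patched below (`mixed.py`, `padding.py`, `set.py`): the padding targets are `duration`, `num_frames`, and `num_samples`; `max_frames` belongs to the retired sampler API. A small sketch of the documented semantics (the manifest path and cut are hypothetical):

    from lhotse import CutSet

    cut = next(iter(CutSet.from_file("cuts.jsonl.gz")))  # hypothetical manifest
    # The three size targets are mutually exclusive - give only one of them:
    padded = cut.pad(duration=10.0)       # pad to at least 10 seconds
    padded = cut.pad(num_frames=1000)     # pad to a feature-frame count (needs features)
    padded = cut.pad(num_samples=160000)  # pad to an audio-sample count (needs audio)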
diff --git a/lhotse/cut/mixed.py b/lhotse/cut/mixed.py index 01acf248d..cd83d29e0 100644 --- a/lhotse/cut/mixed.py +++ b/lhotse/cut/mixed.py @@ -622,7 +622,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. diff --git a/lhotse/cut/padding.py b/lhotse/cut/padding.py index c535bde2b..a95be6062 100644 --- a/lhotse/cut/padding.py +++ b/lhotse/cut/padding.py @@ -236,7 +236,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param duration: The cut's minimal duration after padding. diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 2a7afd16c..5a62ba21c 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -2821,7 +2821,7 @@ def pad( """ Return a new MixedCut, padded with zeros in the recording, and ``pad_feat_value`` in each feature bin. - The user can choose to pad either to a specific `duration`; a specific number of frames `max_frames`; + The user can choose to pad either to a specific `duration`; a specific number of frames `num_frames`; or a specific number of samples `num_samples`. The three arguments are mutually exclusive. :param cut: DataCut to be padded. diff --git a/lhotse/dataset/audio_tagging.py b/lhotse/dataset/audio_tagging.py index 0ca44a687..fbf370fd6 100644 --- a/lhotse/dataset/audio_tagging.py +++ b/lhotse/dataset/audio_tagging.py @@ -78,7 +78,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ self.hdf5_fix.update() diff --git a/lhotse/dataset/sampling/bucketing.py b/lhotse/dataset/sampling/bucketing.py index dd53551cc..b869185b6 100644 --- a/lhotse/dataset/sampling/bucketing.py +++ b/lhotse/dataset/sampling/bucketing.py @@ -30,7 +30,7 @@ class BucketingSampler(CutSampler): ... # BucketingSampler specific args ... sampler_type=SimpleCutSampler, num_buckets=20, ... # Args passed into SimpleCutSampler - ... max_frames=20000 + ... max_duration=200 ... ) Bucketing sampler with 20 buckets, sampling pairs of source-target cuts:: @@ -40,7 +40,7 @@ class BucketingSampler(CutSampler): ... # BucketingSampler specific args ... sampler_type=CutPairsSampler, num_buckets=20, ... # Args passed into CutPairsSampler - ... max_source_frames=20000, max_target_frames=15000 + ... max_source_duration=200, max_target_duration=150 ... ) """ diff --git a/lhotse/dataset/sampling/cut_pairs.py b/lhotse/dataset/sampling/cut_pairs.py index 1582158d2..cd13353d8 100644 --- a/lhotse/dataset/sampling/cut_pairs.py +++ b/lhotse/dataset/sampling/cut_pairs.py @@ -12,10 +12,10 @@ class CutPairsSampler(CutSampler): It expects that both CutSet's strictly consist of Cuts with corresponding IDs. 
It behaves like an iterable that yields lists of strings (cut IDs). - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_source_duration`, :attr:`max_target_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. - Padding required to collate the batch does not contribute to max frames/samples/duration. + Padding required to collate the batch does not contribute to max source_duration/target_duration. """ def __init__( @@ -229,7 +229,7 @@ def _next_batch(self) -> Tuple[CutSet, CutSet]: self.source_constraints.add(next_source_cut) self.target_constraints.add(next_target_cut) - # Did we exceed the max_source_frames and max_cuts constraints? + # Did we exceed the max_source_duration and max_cuts constraints? if ( not self.source_constraints.exceeded() and not self.target_constraints.exceeded() @@ -249,7 +249,7 @@ def _next_batch(self) -> Tuple[CutSet, CutSet]: # and return the cut anyway. warnings.warn( "The first cut drawn in batch collection violates one of the max_... constraints" - "we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc." + "we'll return it anyway. Consider increasing max_source_duration/max_cuts/etc." ) source_cuts.append(next_source_cut) target_cuts.append(next_target_cut) diff --git a/lhotse/dataset/sampling/dynamic.py b/lhotse/dataset/sampling/dynamic.py index 2d36b4130..dc5858010 100644 --- a/lhotse/dataset/sampling/dynamic.py +++ b/lhotse/dataset/sampling/dynamic.py @@ -335,7 +335,7 @@ def detuplify( else next_cut_or_tpl ) - # Did we exceed the max_frames and max_cuts constraints? + # Did we exceed the max_duration and max_cuts constraints? if self.constraint.close_to_exceeding(): # Yes. Finish sampling this batch. if self.constraint.exceeded() and len(cuts) == 1: diff --git a/lhotse/dataset/sampling/simple.py b/lhotse/dataset/sampling/simple.py index 66b56dae2..a8ca079c4 100644 --- a/lhotse/dataset/sampling/simple.py +++ b/lhotse/dataset/sampling/simple.py @@ -11,10 +11,10 @@ class SimpleCutSampler(CutSampler): Samples cuts from a CutSet to satisfy the input constraints. It behaves like an iterable that yields lists of strings (cut IDs). - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. - Padding required to collate the batch does not contribute to max frames/samples/duration. + Padding required to collate the batch does not contribute to max duration. Example usage:: @@ -197,10 +197,10 @@ def _next_batch(self) -> CutSet: self.diagnostics.discard_single(next_cut) continue - # Track the duration/frames/etc. constraints. + # Track the duration/etc. constraints. self.time_constraint.add(next_cut) - # Did we exceed the max_frames and max_cuts constraints? + # Did we exceed the max_duration and max_cuts constraints? if not self.time_constraint.exceeded(): # No - add the next cut to the batch, and keep trying. cuts.append(next_cut) @@ -215,9 +215,9 @@ def _next_batch(self) -> CutSet: # and return the cut anyway. warnings.warn( "The first cut drawn in batch collection violates " - "the max_frames, max_cuts, or max_duration constraints - " + "the max_duration, or max_cuts constraints - " "we'll return it anyway. " - "Consider increasing max_frames/max_cuts/max_duration." + "Consider increasing max_duration/max_cuts." 
) cuts.append(next_cut) diff --git a/lhotse/dataset/sampling/weighted_simple.py b/lhotse/dataset/sampling/weighted_simple.py index 7c3f76034..4a3191b02 100644 --- a/lhotse/dataset/sampling/weighted_simple.py +++ b/lhotse/dataset/sampling/weighted_simple.py @@ -15,7 +15,7 @@ class WeightedSimpleCutSampler(SimpleCutSampler): When performing sampling, it avoids having duplicated cuts in the same batch. The sampler terminates if the number of sampled cuts reach :attr:`num_samples` - When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified, + When one of :attr:`max_duration`, or :attr:`max_cuts` is specified, the batch size is dynamic. Example usage: diff --git a/lhotse/dataset/speech_recognition.py b/lhotse/dataset/speech_recognition.py index 4c9919f99..4a3520b37 100644 --- a/lhotse/dataset/speech_recognition.py +++ b/lhotse/dataset/speech_recognition.py @@ -94,7 +94,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ validate_for_asr(cuts) diff --git a/lhotse/dataset/speech_translation.py b/lhotse/dataset/speech_translation.py index 672d27069..1def4475b 100644 --- a/lhotse/dataset/speech_translation.py +++ b/lhotse/dataset/speech_translation.py @@ -97,7 +97,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ validate_for_asr(cuts) self.hdf5_fix.update() diff --git a/lhotse/dataset/surt.py b/lhotse/dataset/surt.py index 8eda83b5f..5e424353c 100644 --- a/lhotse/dataset/surt.py +++ b/lhotse/dataset/surt.py @@ -170,7 +170,7 @@ def __init__( def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. + of max_duration and max_cuts. """ validate_for_asr(cuts) From 36ce63e6e3545b74222e7f5feec3e75d59cc5a35 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:11:04 -0800 Subject: [PATCH 8/8] Option to save audio in the original format when exporting to shar (#1422) * Option to save audio in the original format when exporting to shar * Removed duplicate test --- lhotse/audio/recording.py | 17 ++++++ lhotse/audio/source.py | 28 ++++++++++ lhotse/bin/modes/shar.py | 4 +- lhotse/shar/writers/audio.py | 18 +++++-- lhotse/shar/writers/shar.py | 7 ++- lhotse/testing/dummies.py | 40 ++++++++++---- test/shar/test_write.py | 102 ++++++++++++++++++----------------- 7 files changed, 152 insertions(+), 64 deletions(-) diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index ec0f605f2..cd8a41fb2 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -168,6 +168,23 @@ def is_placeholder(self) -> bool: def num_channels(self) -> int: return len(self.channel_ids) + @property + def source_format(self) -> str: + """Infer format of the audio sources. + If all sources have the same format, return it. + If sources have different formats, raise an error. 
+ """ + source_formats = list(set([s.format for s in self.sources])) + + if len(source_formats) == 1: + # if all sources have the same format, return it + return source_formats[0] + else: + # at the moment, we don't resolve different formats + raise NotImplementedError( + "Sources have different formats. Resolving to a single format not implemented." + ) + @staticmethod def from_file( path: Pathlike, diff --git a/lhotse/audio/source.py b/lhotse/audio/source.py index 459881ea2..88bb9743d 100644 --- a/lhotse/audio/source.py +++ b/lhotse/audio/source.py @@ -1,3 +1,5 @@ +import io +import os import warnings from dataclasses import dataclass from io import BytesIO, FileIO @@ -6,6 +8,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +import soundfile as sf import torch from lhotse.audio.backend import read_audio @@ -64,6 +67,10 @@ class AudioSource: def has_video(self) -> bool: return self.video is not None + @property + def format(self) -> str: + return self._get_format() + def load_audio( self, offset: Seconds = 0.0, @@ -316,3 +323,24 @@ def _prepare_for_reading( ) return source + + def _get_format(self) -> str: + """Get format for the audio source. + If using 'file' or 'url' types, the format is inferred from the file extension, as in soundfile. + If using 'memory' type, the format is inferred from the binary data. + """ + if self.type in ("file", "url"): + # Resolve audio format based on the filename + format = os.path.splitext(self.source)[-1][1:] + return format.lower() + elif self.type == "memory": + sf_info = sf.info(io.BytesIO(self.source)) + if sf_info.format == "OGG" and sf_info.subtype == "OPUS": + # soundfile describes opus as ogg container with opus coding + return "opus" + else: + return sf_info.format.lower() + else: + raise NotImplementedError( + f"Getting format not implemented for source type {self.type}" + ) diff --git a/lhotse/bin/modes/shar.py b/lhotse/bin/modes/shar.py index 95b670529..cffdf596b 100644 --- a/lhotse/bin/modes/shar.py +++ b/lhotse/bin/modes/shar.py @@ -27,8 +27,8 @@ def shar(): "-a", "--audio", default="none", - type=click.Choice(["none", "wav", "flac", "mp3", "opus"]), - help="Format in which to export audio (disabled by default, enabling will make a copy of the data)", + type=click.Choice(["none", "wav", "flac", "mp3", "opus", "original"]), + help="Format in which to export audio. 
Original will save in the same format as the original audio (disabled by default, enabling will make a copy of the data)", ) @click.option( "-f", diff --git a/lhotse/shar/writers/audio.py b/lhotse/shar/writers/audio.py index b3a855f59..22bdf4cf1 100644 --- a/lhotse/shar/writers/audio.py +++ b/lhotse/shar/writers/audio.py @@ -66,6 +66,14 @@ def close(self): def output_paths(self) -> List[str]: return self.tar_writer.output_paths + def resolve_format(self, original_format: str): + if self.format == "original": + # save using the original format of the input audio + return original_format + else: + # save using the format specified at initialization + return self.format + def write_placeholder(self, key: str) -> None: self.tar_writer.write(f"{key}.nodata", BytesIO()) self.tar_writer.write(f"{key}.nometa", BytesIO(), count=False) @@ -76,15 +84,18 @@ def write( value: np.ndarray, sampling_rate: int, manifest: Recording, + original_format: Optional[str] = None, ) -> None: + save_format = self.resolve_format(original_format) + value, manifest, sampling_rate = self._maybe_resample( - value, manifest, sampling_rate + value, manifest, sampling_rate, format=save_format ) # Write binary data stream = BytesIO() save_audio( - dest=stream, src=value, sampling_rate=sampling_rate, format=self.format + dest=stream, src=value, sampling_rate=sampling_rate, format=save_format ) self.tar_writer.write(f"{key}.{self.format}", stream) @@ -103,13 +114,14 @@ def _maybe_resample( audio: Union[torch.Tensor, np.ndarray], manifest: Recording, sampling_rate: int, + format: str, ) -> Tuple[Union[np.ndarray, torch.Tensor], Recording, int]: # Resampling is required for some versions of OPUS encoders. # First resample the manifest which only adjusts the metadata; # then resample the audio array to 48kHz. 
OPUS_DEFAULT_SAMPLING_RATE = 48000 if ( - self.format == "opus" + format == "opus" and is_torchaudio_available() and not isinstance(get_current_audio_backend(), LibsndfileBackend) and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE diff --git a/lhotse/shar/writers/shar.py b/lhotse/shar/writers/shar.py index 229a073b6..2c1a88442 100644 --- a/lhotse/shar/writers/shar.py +++ b/lhotse/shar/writers/shar.py @@ -135,7 +135,11 @@ def write(self, cut: Cut) -> None: recording.sources[0].channels = cut_channels recording.channel_ids = cut_channels self.writers["recording"].write( - cut.id, data, cut.sampling_rate, manifest=recording + cut.id, + data, + cut.sampling_rate, + manifest=recording, + original_format=cut.recording.source_format, ) cut = fastcopy(cut, recording=recording) else: @@ -224,6 +228,7 @@ def resolve_writer(name: str) -> Tuple[FieldWriter, str]: "flac": (partial(AudioTarWriter, format="flac"), ".tar"), "mp3": (partial(AudioTarWriter, format="mp3"), ".tar"), "opus": (partial(AudioTarWriter, format="opus"), ".tar"), + "original": (partial(AudioTarWriter, format="original"), ".tar"), "lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"), "numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"), "jsonl": (JsonlShardWriter, ".jsonl.gz"), diff --git a/lhotse/testing/dummies.py b/lhotse/testing/dummies.py index aec6a7581..0999906aa 100644 --- a/lhotse/testing/dummies.py +++ b/lhotse/testing/dummies.py @@ -63,6 +63,7 @@ def dummy_recording( duration: float = 1.0, sampling_rate: int = 16000, with_data: bool = False, + source_format: str = "wav", ) -> Recording: num_samples = compute_num_samples(duration, sampling_rate) return Recording( @@ -72,6 +73,7 @@ def dummy_recording( sampling_rate=sampling_rate, num_samples=num_samples, with_data=with_data, + format=source_format, ) ], sampling_rate=sampling_rate, @@ -85,6 +87,7 @@ def dummy_audio_source( sampling_rate: int = 16000, channels: Optional[List[int]] = None, with_data: bool = False, + format: str = "wav", ) -> AudioSource: if channels is None: channels = [0] @@ -95,21 +98,40 @@ def dummy_audio_source( else: import soundfile - # 1kHz sine wave - data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples)) + # generate 1kHz sine wave + f_sine = 1000 + assert ( + f_sine < sampling_rate / 2 + ), f"Sine wave frequency {f_sine} exceeds Nyquist frequency {sampling_rate/2} for sampling rate {sampling_rate}" + data = torch.sin(2 * np.pi * f_sine / sampling_rate * torch.arange(num_samples)) + + # prepare multichannel data if len(channels) > 1: data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1) # ensure each channel has different data for channel selection testing mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)]) data = data * mults + + # prepare source with the selected format binary_data = BytesIO() - soundfile.write( - binary_data, - data.numpy(), - sampling_rate, - format="wav", - closefd=False, - ) + if format == "opus": + # workaround for OPUS: soundfile supports OPUS as a subtype of OGG format + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format="OGG", + subtype="OPUS", + closefd=False, + ) + else: + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format=format, + closefd=False, + ) binary_data.seek(0) return AudioSource( type="memory", channels=channels, source=binary_data.getvalue() diff --git a/test/shar/test_write.py b/test/shar/test_write.py index cee35de1c..7d88bd3d0 100644 --- a/test/shar/test_write.py +++ b/test/shar/test_write.py @@ -66,55 
+66,6 @@ def test_tar_writer_pipe(tmp_path: Path):
     assert f2.read() == b"test"
 
 
-@pytest.mark.parametrize(
-    "format",
-    [
-        "wav",
-        pytest.param(
-            "flac",
-            marks=pytest.mark.skipif(
-                not check_torchaudio_version_gt("0.12.1"),
-                reason="Torchaudio v0.12.1 or greater is required.",
-            ),
-        ),
-        # "mp3", # apparently doesn't work in CI, mp3 encoder is missing
-        pytest.param(
-            "opus",
-            marks=pytest.mark.skipif(
-                not check_torchaudio_version_gt("2.1.0"),
-                reason="Torchaudio v2.1.0 or greater is required.",
-            ),
-        ),
-    ],
-)
-def test_audio_tar_writer(tmp_path: Path, format: str):
-    from lhotse.testing.dummies import dummy_recording
-
-    recording = dummy_recording(0, with_data=True)
-    audio = recording.load_audio()
-
-    with AudioTarWriter(
-        str(tmp_path / "test.tar"), shard_size=None, format=format
-    ) as writer:
-        writer.write(
-            key="my-recording",
-            value=audio,
-            sampling_rate=recording.sampling_rate,
-            manifest=recording,
-        )
-
-    (path,) = writer.output_paths
-
-    ((deserialized_recording, inner_path),) = list(TarIterator(path))
-
-    deserialized_audio = deserialized_recording.resample(
-        recording.sampling_rate
-    ).load_audio()
-
-    rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2))
-    assert rmse < 0.5
-
-
 @pytest.mark.parametrize(
     ["format", "backend"],
     [
@@ -175,6 +126,59 @@ def test_audio_tar_writer(tmp_path: Path, format: str, backend: str):
     assert rmse < 0.5
 
 
+@pytest.mark.parametrize(
+    ["original_format", "rmse_threshold"],
+    [("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)],
+)
+def test_audio_tar_writer_original_format(
+    tmp_path: Path, original_format: str, rmse_threshold: float
+):
+    """Test using AudioTarWriter to write the audio signal in the exact same format
+    as it was loaded from the source.
+    """
+    from lhotse.testing.dummies import dummy_recording
+
+    backend = "default"  # use the default backend for reading the audio
+    writer_format = "original"  # write the audio in the same format as it was loaded
+
+    recording = dummy_recording(0, with_data=True, source_format=original_format)
+    audio = recording.load_audio()
+
+    assert (
+        recording.source_format == original_format
+    ), f"Recording source format ({recording.source_format}) not matching the expected original format ({original_format})"
+
+    with audio_backend(backend):
+        with AudioTarWriter(
+            str(tmp_path / "test.tar"), shard_size=None, format=writer_format
+        ) as writer:
+            writer.write(
+                key="my-recording",
+                value=audio,
+                sampling_rate=recording.sampling_rate,
+                manifest=recording,
+                original_format=recording.source_format,
+            )
+        (path,) = writer.output_paths
+        ((deserialized_recording, inner_path),) = list(TarIterator(path))
+
+        # make sure the deserialized audio is in the same format as the original
+        assert (
+            deserialized_recording.source_format == original_format
+        ), f"Deserialized recording source format ({deserialized_recording.source_format}) not matching the expected original format ({original_format})"
+
+        # load audio
+        deserialized_audio = deserialized_recording.resample(
+            recording.sampling_rate
+        ).load_audio()
+
+        # check difference between original and deserialized audio
+        rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2))
+        assert (
+            rmse <= rmse_threshold
+        ), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}"
+
+
 def test_shar_writer(tmp_path: Path):
     # Prepare data
     cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)
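As a closing reference, a minimal sketch of exporting a manifest to Lhotse Shar with the `original` audio format introduced in PATCH 8: each recording is saved back in its source codec (wav/flac/mp3/opus, as resolved by `Recording.source_format`) rather than being transcoded to a single target format. The paths are hypothetical:

    from lhotse import CutSet

    cuts = CutSet.from_file("data/cuts_train.jsonl.gz")  # hypothetical manifest
    cuts.to_shar(
        output_dir="data/shar",
        # "original" maps to AudioTarWriter(format="original") via resolve_writer:
        fields={"recording": "original"},
        shard_size=1000,
    )

On the CLI side, the same patch simply adds `original` to the choices accepted by the `-a/--audio` option of `lhotse shar`.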