Skip to content

Commit

Permalink
Option to save audio in the original format when exporting to shar (lhotse-speech#1422)
Browse files Browse the repository at this point in the history

* Option to save audio in the original format when exporting to shar

* Removed duplicate test
  • Loading branch information
anteju authored Nov 23, 2024
1 parent 9c1330a commit 36ce63e
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 64 deletions.
17 changes: 17 additions & 0 deletions lhotse/audio/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def is_placeholder(self) -> bool:
def num_channels(self) -> int:
    """Return the number of channels, i.e. the length of ``self.channel_ids``."""
    return len(self.channel_ids)

@property
def source_format(self) -> str:
    """Infer the audio format shared by all sources.

    Returns:
        The ``format`` value that every item in ``self.sources`` reports.

    Raises:
        NotImplementedError: if there are no sources to infer a format from,
            or if the sources report more than one distinct format
            (resolving mixed formats to a single one is not supported).
    """
    # A set comprehension de-duplicates directly, without a throwaway list.
    source_formats = {s.format for s in self.sources}

    if not source_formats:
        # Report the real problem instead of claiming a format mismatch.
        raise NotImplementedError(
            "Recording has no sources. Cannot infer the source format."
        )

    if len(source_formats) > 1:
        # at the moment, we don't resolve different formats
        raise NotImplementedError(
            "Sources have different formats. Resolving to a single format not implemented."
        )

    # all sources have the same format, return it
    return next(iter(source_formats))

@staticmethod
def from_file(
path: Pathlike,
Expand Down
28 changes: 28 additions & 0 deletions lhotse/audio/source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import io
import os
import warnings
from dataclasses import dataclass
from io import BytesIO, FileIO
Expand All @@ -6,6 +8,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import soundfile as sf
import torch

from lhotse.audio.backend import read_audio
Expand Down Expand Up @@ -64,6 +67,10 @@ class AudioSource:
def has_video(self) -> bool:
    """Return True when ``self.video`` is set, i.e. is not ``None``."""
    return self.video is not None

@property
def format(self) -> str:
    """Format of this audio source, as resolved by ``self._get_format()``."""
    return self._get_format()

def load_audio(
self,
offset: Seconds = 0.0,
Expand Down Expand Up @@ -316,3 +323,24 @@ def _prepare_for_reading(
)

return source

def _get_format(self) -> str:
    """Get format for the audio source.

    If using 'file' or 'url' types, the format is inferred from the file
    extension (lower-cased, without the leading dot), as in soundfile.
    For 'url' sources only the path component is inspected, so a query
    string or fragment does not leak into the result.
    If using 'memory' type, the format is inferred from the binary data.

    Returns:
        The lower-cased format name; an empty string when a 'file'/'url'
        source has no extension.

    Raises:
        NotImplementedError: for any other source type.
    """
    if self.type in ("file", "url"):
        location = self.source
        if self.type == "url":
            # Local import keeps this block self-contained.
            from urllib.parse import urlparse

            # Drop "?query" and "#fragment": they may contain dots and would
            # otherwise be mistaken for the extension.
            location = urlparse(location).path
        # Resolve audio format based on the filename
        extension = os.path.splitext(location)[-1][1:]
        return extension.lower()

    if self.type == "memory":
        sf_info = sf.info(io.BytesIO(self.source))
        if sf_info.format == "OGG" and sf_info.subtype == "OPUS":
            # soundfile describes opus as ogg container with opus coding
            return "opus"
        return sf_info.format.lower()

    raise NotImplementedError(
        f"Getting format not implemented for source type {self.type}"
    )
4 changes: 2 additions & 2 deletions lhotse/bin/modes/shar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def shar():
"-a",
"--audio",
default="none",
type=click.Choice(["none", "wav", "flac", "mp3", "opus"]),
help="Format in which to export audio (disabled by default, enabling will make a copy of the data)",
type=click.Choice(["none", "wav", "flac", "mp3", "opus", "original"]),
help="Format in which to export audio. Original will save in the same format as the original audio (disabled by default, enabling will make a copy of the data)",
)
@click.option(
"-f",
Expand Down
18 changes: 15 additions & 3 deletions lhotse/shar/writers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def close(self):
def output_paths(self) -> List[str]:
    """Return the output paths reported by the underlying ``self.tar_writer``."""
    return self.tar_writer.output_paths

def resolve_format(self, original_format: Optional[str]) -> str:
    """Pick the format used to encode an item.

    Args:
        original_format: format of the input audio; only consulted when this
            writer was configured with ``format="original"``.

    Returns:
        ``original_format`` when ``self.format == "original"``, otherwise
        ``self.format``.

    Raises:
        ValueError: if ``self.format == "original"`` but no usable
            ``original_format`` was supplied.
    """
    if self.format != "original":
        # save using the format specified at initialization
        return self.format

    if not original_format:
        # Fail here with a clear message instead of handing None (or an empty
        # string) to the encoder.
        raise ValueError(
            "format='original' requires the original format of the audio, "
            f"but got original_format={original_format!r}."
        )

    # save using the original format of the input audio
    return original_format

def write_placeholder(self, key: str) -> None:
    """Write empty placeholder entries for ``key``.

    Two zero-length members are written through ``self.tar_writer``:
    ``{key}.nodata`` and ``{key}.nometa``; the second one is written with
    ``count=False``.
    """
    self.tar_writer.write(f"{key}.nodata", BytesIO())
    self.tar_writer.write(f"{key}.nometa", BytesIO(), count=False)
Expand All @@ -76,15 +84,18 @@ def write(
value: np.ndarray,
sampling_rate: int,
manifest: Recording,
original_format: Optional[str] = None,
) -> None:
save_format = self.resolve_format(original_format)

value, manifest, sampling_rate = self._maybe_resample(
value, manifest, sampling_rate
value, manifest, sampling_rate, format=save_format
)

# Write binary data
stream = BytesIO()
save_audio(
dest=stream, src=value, sampling_rate=sampling_rate, format=self.format
dest=stream, src=value, sampling_rate=sampling_rate, format=save_format
)
self.tar_writer.write(f"{key}.{self.format}", stream)

Expand All @@ -103,13 +114,14 @@ def _maybe_resample(
audio: Union[torch.Tensor, np.ndarray],
manifest: Recording,
sampling_rate: int,
format: str,
) -> Tuple[Union[np.ndarray, torch.Tensor], Recording, int]:
# Resampling is required for some versions of OPUS encoders.
# First resample the manifest which only adjusts the metadata;
# then resample the audio array to 48kHz.
OPUS_DEFAULT_SAMPLING_RATE = 48000
if (
self.format == "opus"
format == "opus"
and is_torchaudio_available()
and not isinstance(get_current_audio_backend(), LibsndfileBackend)
and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE
Expand Down
7 changes: 6 additions & 1 deletion lhotse/shar/writers/shar.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def write(self, cut: Cut) -> None:
recording.sources[0].channels = cut_channels
recording.channel_ids = cut_channels
self.writers["recording"].write(
cut.id, data, cut.sampling_rate, manifest=recording
cut.id,
data,
cut.sampling_rate,
manifest=recording,
original_format=cut.recording.source_format,
)
cut = fastcopy(cut, recording=recording)
else:
Expand Down Expand Up @@ -224,6 +228,7 @@ def resolve_writer(name: str) -> Tuple[FieldWriter, str]:
"flac": (partial(AudioTarWriter, format="flac"), ".tar"),
"mp3": (partial(AudioTarWriter, format="mp3"), ".tar"),
"opus": (partial(AudioTarWriter, format="opus"), ".tar"),
"original": (partial(AudioTarWriter, format="original"), ".tar"),
"lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"),
"numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"),
"jsonl": (JsonlShardWriter, ".jsonl.gz"),
Expand Down
40 changes: 31 additions & 9 deletions lhotse/testing/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def dummy_recording(
duration: float = 1.0,
sampling_rate: int = 16000,
with_data: bool = False,
source_format: str = "wav",
) -> Recording:
num_samples = compute_num_samples(duration, sampling_rate)
return Recording(
Expand All @@ -72,6 +73,7 @@ def dummy_recording(
sampling_rate=sampling_rate,
num_samples=num_samples,
with_data=with_data,
format=source_format,
)
],
sampling_rate=sampling_rate,
Expand All @@ -85,6 +87,7 @@ def dummy_audio_source(
sampling_rate: int = 16000,
channels: Optional[List[int]] = None,
with_data: bool = False,
format: str = "wav",
) -> AudioSource:
if channels is None:
channels = [0]
Expand All @@ -95,21 +98,40 @@ def dummy_audio_source(
else:
import soundfile

# 1kHz sine wave
data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples))
# generate 1kHz sine wave
f_sine = 1000
assert (
f_sine < sampling_rate / 2
), f"Sine wave frequency {f_sine} exceeds Nyquist frequency {sampling_rate/2} for sampling rate {sampling_rate}"
data = torch.sin(2 * np.pi * f_sine / sampling_rate * torch.arange(num_samples))

# prepare multichannel data
if len(channels) > 1:
data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1)
# ensure each channel has different data for channel selection testing
mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)])
data = data * mults

# prepare source with the selected format
binary_data = BytesIO()
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format="wav",
closefd=False,
)
if format == "opus":
# workaround for OPUS: soundfile supports OPUS as a subtype of OGG format
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format="OGG",
subtype="OPUS",
closefd=False,
)
else:
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format=format,
closefd=False,
)
binary_data.seek(0)
return AudioSource(
type="memory", channels=channels, source=binary_data.getvalue()
Expand Down
102 changes: 53 additions & 49 deletions test/shar/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,55 +66,6 @@ def test_tar_writer_pipe(tmp_path: Path):
assert f2.read() == b"test"


# NOTE(review): another test named ``test_audio_tar_writer`` (taking an extra
# ``backend`` argument) appears further down this page; if both live in one
# module the later definition replaces this one - confirm which should remain.
@pytest.mark.parametrize(
    "format",
    [
        "wav",
        pytest.param(
            "flac",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("0.12.1"),
                reason="Torchaudio v0.12.1 or greater is required.",
            ),
        ),
        # "mp3",  # apparently doesn't work in CI, mp3 encoder is missing
        pytest.param(
            "opus",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("2.1.0"),
                reason="Torchaudio v2.1.0 or greater is required.",
            ),
        ),
    ],
)
def test_audio_tar_writer(tmp_path: Path, format: str):
    """Round-trip a dummy recording through ``AudioTarWriter`` in ``format``.

    The audio read back from the tar file (resampled to the original sampling
    rate) must stay within an RMSE of 0.5 of the audio that was written.
    """
    from lhotse.testing.dummies import dummy_recording

    recording = dummy_recording(0, with_data=True)
    audio = recording.load_audio()

    with AudioTarWriter(
        str(tmp_path / "test.tar"), shard_size=None, format=format
    ) as writer:
        writer.write(
            key="my-recording",
            value=audio,
            sampling_rate=recording.sampling_rate,
            manifest=recording,
        )

    # shard_size=None: exactly one output path is expected.
    (path,) = writer.output_paths

    ((deserialized_recording, inner_path),) = list(TarIterator(path))

    deserialized_audio = deserialized_recording.resample(
        recording.sampling_rate
    ).load_audio()

    rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2))
    assert rmse < 0.5


@pytest.mark.parametrize(
["format", "backend"],
[
Expand Down Expand Up @@ -175,6 +126,59 @@ def test_audio_tar_writer(tmp_path: Path, format: str, backend: str):
assert rmse < 0.5


@pytest.mark.parametrize(
    ["original_format", "rmse_threshold"],
    [("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)],
)
def test_audio_tar_writer_original_format(
    tmp_path: Path, original_format: str, rmse_threshold: float
):
    """Test using AudioTarWriter to write the audio signal in the exact same format
    as it was loaded from the source.

    The recording read back from the tar file must report the same source
    format, and its audio must be within ``rmse_threshold`` (RMSE) of the input.
    """
    from lhotse.testing.dummies import dummy_recording

    backend = "default"  # use the default backend for reading the audio
    writer_format = "original"  # write the audio in the same format as it was loaded

    recording = dummy_recording(0, with_data=True, source_format=original_format)
    audio = recording.load_audio()

    # Resolve once; reused for the check, its message, and the writer call.
    input_format = recording.source_format
    assert (
        input_format == original_format
    ), f"Recording source format ({input_format}) not matching the expected original format ({original_format})"

    with audio_backend(backend):
        with AudioTarWriter(
            str(tmp_path / "test.tar"), shard_size=None, format=writer_format
        ) as writer:
            writer.write(
                key="my-recording",
                value=audio,
                sampling_rate=recording.sampling_rate,
                manifest=recording,
                original_format=input_format,
            )
        # shard_size=None: exactly one output path is expected.
        (path,) = writer.output_paths
        ((deserialized_recording, inner_path),) = list(TarIterator(path))

        # make sure the deserialized audio is in the same format as the original
        output_format = deserialized_recording.source_format
        assert (
            output_format == original_format
        ), f"Deserialized recording source format ({output_format}) not matching the expected original format ({original_format})"

        # load audio
        deserialized_audio = deserialized_recording.resample(
            recording.sampling_rate
        ).load_audio()

        # check difference between original and deserialized audio
        rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2))
        assert (
            rmse <= rmse_threshold
        ), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}"


def test_shar_writer(tmp_path: Path):
# Prepare data
cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)
Expand Down

0 comments on commit 36ce63e

Please sign in to comment.