Skip to content

Commit

Permalink
limiting dependency on torchaudio for writing files
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed Oct 9, 2023
1 parent be7a4ff commit d45769c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).


## [1.0.1] - TBD

Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg.

## [1.0.0] - 2023-09-07

Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Expand Down
38 changes: 29 additions & 9 deletions audiocraft/data/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import torchaudio as ta

import av
import subprocess as sp

from .audio_utils import f32_pcm, i16_pcm, normalize_audio
from .audio_utils import f32_pcm, normalize_audio


_av_initialized = False
Expand Down Expand Up @@ -150,10 +151,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
return wav, sr


def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
# ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
assert wav.dim() == 2, wav.shape
command = [
'ffmpeg',
'-loglevel', 'error',
'-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
'-i', '-'] + flags + [str(out_path)]
input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
sp.run(command, input=input_, check=True)


def audio_write(stem_name: tp.Union[str, Path],
wav: torch.Tensor, sample_rate: int,
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False,
log_clipping: bool = True, make_parent_dir: bool = True,
Expand All @@ -164,8 +177,9 @@ def audio_write(stem_name: tp.Union[str, Path],
stem_name (str or Path): Filename without extension which will be added automatically.
wav (torch.Tensor): Audio data to save.
sample_rate (int): Sample rate of audio data.
format (str): Either "wav" or "mp3".
format (str): Either "wav", "mp3", "ogg", or "flac".
mp3_rate (int): kbps when using mp3s.
ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
Expand Down Expand Up @@ -193,14 +207,20 @@ def audio_write(stem_name: tp.Union[str, Path],
rms_headroom_db, loudness_headroom_db, loudness_compressor,
log_clipping=log_clipping, sample_rate=sample_rate,
stem_name=str(stem_name))
kwargs: dict = {}
if format == 'mp3':
suffix = '.mp3'
kwargs.update({"compression": mp3_rate})
flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
elif format == 'wav':
wav = i16_pcm(wav)
suffix = '.wav'
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
elif format == 'ogg':
suffix = '.ogg'
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
if ogg_rate is not None:
flags += ['-b:a', f'{ogg_rate}k']
elif format == 'flac':
suffix = '.flac'
flags = ['-f', 'flac']
else:
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
if not add_suffix:
Expand All @@ -209,7 +229,7 @@ def audio_write(stem_name: tp.Union[str, Path],
if make_parent_dir:
path.parent.mkdir(exist_ok=True, parents=True)
try:
ta.save(path, wav, sample_rate, **kwargs)
_piping_to_ffmpeg(path, wav, sample_rate, flags)
except Exception:
if path.exists():
# we do not want to leave half written files around.
Expand Down
15 changes: 6 additions & 9 deletions tests/common_utils/wav_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
import typing as tp

import torch
import torchaudio

from audiocraft.data.audio import audio_write


def get_white_noise(chs: int = 1, num_frames: int = 1):
Expand All @@ -22,11 +22,8 @@ def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):


def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
assert wav.dim() == 2, wav.shape
fp = Path(path)
kwargs: tp.Dict[str, tp.Any] = {}
if fp.suffix == '.wav':
kwargs['encoding'] = 'PCM_S'
kwargs['bits_per_sample'] = 16
elif fp.suffix == '.mp3':
kwargs['compression'] = 320
torchaudio.save(str(fp), wav, sample_rate, **kwargs)
assert fp.suffix in ['.mp3', '.ogg', '.wav', '.flac'], fp
audio_write(fp.parent / fp.stem, wav, sample_rate, fp.suffix[1:],
normalize=False, strategy='clip', peak_clip_headroom_db=0)

0 comments on commit d45769c

Please sign in to comment.