From 0fa7d4511b0fca3f96f59fa37f7dcaed555990d6 Mon Sep 17 00:00:00 2001 From: Weiqi Gao Date: Fri, 22 Sep 2023 20:53:25 +0800 Subject: [PATCH] APIs - 4.1.0a1 (#499) * Allow manually run workflow * Add HTDemucs to AnyModel * Set minimum Python version to 3.8 * Add type hints for save_audio * Update segment machenism * Update help message that htdemucs is the default * Add segment test * Fix linter * API * Update separate.py to use api * Version 4.1.0a1 * Update api.md * Make parameter `repo` clearer * Add doc for loading existing audio in quick start * Fix typo in docs * Add flac output * Fix linter * Fix conflicts * Fix wrong indent * import htdemucs * import htdemucs * rename variable * Fix variable name * Fix mypy linting * Fix mypy linting again * flac output & max_allowed_segment only for HTDemucs * Fix codes forgot to change * Update release.md * Use in-place operation to save memory * fix linter * Make API simpler * Optimise code and fix separate api returns different wave * Fixes according to review Use _NotProvided to allow passing `None` Add returing "break" to stop separation Make callback thread-safe * Allow changing remote_root when listing models * revert changes * return file path for list_models * fix Typing * Minor fixes with linter Directly calls audio.save_audio Changing parameter `sr` (samplerate) to optional argument Fixes linter * Fix lock; use `KeyboardInterrupt` to abort * Fix Linter * Update doc * List model argument --- .gitignore | 1 + Makefile | 2 + README.md | 2 + demucs/__init__.py | 2 +- demucs/api.py | 392 +++++++++++++++++++++++++++++++++++++++++++ demucs/apply.py | 80 ++++++++- demucs/audio.py | 2 +- demucs/pretrained.py | 4 +- demucs/repo.py | 18 ++ demucs/separate.py | 191 ++++++++++----------- demucs/utils.py | 13 +- docs/api.md | 204 ++++++++++++++++++++++ docs/release.md | 14 ++ 13 files changed, 817 insertions(+), 108 deletions(-) create mode 100644 demucs/api.py create mode 100644 docs/api.md diff --git a/.gitignore b/.gitignore index 9501beb1..6f73669b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ Session.vim /trash /misc /mdx +.mypy_cache \ No newline at end of file diff --git a/Makefile b/Makefile index 5b58bf45..0474d587 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,8 @@ test_eval: python3 -m demucs -n demucs_unittest --flac --int24 test.mp3 python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3 python3 -m demucs -n demucs_unittest --segment 8 test.mp3 + python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3 + python3 -m demucs --list-models tests/musdb: test -e tests || mkdir tests diff --git a/README.md b/README.md index e22f0ddd..a93e294d 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,8 @@ import shlex demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"')) ``` +To use more complicated APIs, see [API docs](docs/api.md) + ## Training Demucs If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md). diff --git a/demucs/__init__.py b/demucs/__init__.py index 5a8a6f70..ef5cd6f3 100644 --- a/demucs/__init__.py +++ b/demucs/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -__version__ = "4.0.1a3" +__version__ = "4.1.0a1" diff --git a/demucs/api.py b/demucs/api.py new file mode 100644 index 00000000..fc254fb2 --- /dev/null +++ b/demucs/api.py @@ -0,0 +1,392 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""API methods for demucs
+
+Classes
+-------
+`demucs.api.Separator`: The base separator class
+
+Functions
+---------
+`demucs.api.save_audio`: Save audio to a file
+`demucs.api.list_models`: Get the list of available models
+
+Examples
+--------
+See the end of this module (if __name__ == "__main__")
+"""
+
+import subprocess
+
+import torch as th
+import torchaudio as ta
+
+from dora.log import fatal
+from pathlib import Path
+from typing import Optional, Callable, Dict, Tuple, Union
+
+from .apply import apply_model, _replace_dict
+from .audio import AudioFile, convert_audio, save_audio
+from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
+from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo
+
+
+class LoadAudioError(Exception):
+    pass
+
+
+class LoadModelError(Exception):
+    pass
+
+
+class _NotProvided:
+    pass
+
+
+NotProvided = _NotProvided()
+
+
+class Separator:
+    def __init__(
+        self,
+        model: str = "htdemucs",
+        repo: Optional[Path] = None,
+        device: str = "cuda" if th.cuda.is_available() else "cpu",
+        shifts: int = 1,
+        overlap: float = 0.25,
+        split: bool = True,
+        segment: Optional[int] = None,
+        jobs: int = 0,
+        progress: bool = False,
+        callback: Optional[Callable[[dict], None]] = None,
+        callback_arg: Optional[dict] = None,
+    ):
+        """
+        `class Separator`
+        =================
+
+        Parameters
+        ----------
+        model: Pretrained model name or signature. Default is htdemucs.
+        repo: Folder containing all pre-trained models for use.
+        segment: Length (in seconds) of each segment (only used if `split` is `True`). If not \
+            specified, the model's default segment length is used.
+        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
+            apply the opposite shift to the output. This is repeated `shifts` times and all \
+            predictions are averaged. This effectively makes the model time equivariant and \
+            improves SDR by up to 0.2 points.
+        split: If True, the input will be broken down into small chunks (length set by `segment`) \
+            and predictions will be performed individually on each and concatenated. Useful for \
+            models with large memory footprints like Tasnet.
+        overlap: The overlap between the splits.
+        device (torch.device, str, or None): If provided, device on which to execute the \
+            computation, otherwise `wav.device` is assumed. When `device` is different from \
+            `wav.device`, only local computations will be on `device`, while the entire tracks \
+            will be stored on `wav.device`.
+        jobs: Number of jobs. This can increase memory usage but will be much faster when \
+            multiple cores are available.
+        callback: A function that will be called when the separation of a chunk starts or \
+            finishes. The argument passed to the function will be a dict. For more information, \
+            please see the Callback section.
+        callback_arg: A dict containing private parameters to be passed to the callback \
+            function. For more information, please see the Callback section.
+        progress: If True, show a progress bar.
+
+        Callback
+        --------
+        The function will be called with only one positional parameter whose type is `dict`.
+        The `callback_arg` will be combined with the information of the current separation
+        progress. The progress information will override the values in `callback_arg` if the
+        same key is used. To abort the separation, raise `KeyboardInterrupt`.
+
+        Progress information contains several keys (these keys are always present):
+        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+        - `shift_idx`: The index of shifts. Starts from 0.
+        - `segment_offset`: The offset of the current segment, measured in frames of the tensor
+            (so 441000 means frame 441000 of the audio, not second 441000).
+        - `state`: Either `"start"` or `"end"`.
+        - `audio_length`: Length of the audio, in frames of the tensor.
+        - `models`: Count of submodels in the model.
+        """
+        self._name = model
+        self._repo = repo
+        self._load_model()
+        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
+                              segment=segment, jobs=jobs, progress=progress, callback=callback,
+                              callback_arg=callback_arg)
+
+    def update_parameter(
+        self,
+        device: Union[str, _NotProvided] = NotProvided,
+        shifts: Union[int, _NotProvided] = NotProvided,
+        overlap: Union[float, _NotProvided] = NotProvided,
+        split: Union[bool, _NotProvided] = NotProvided,
+        segment: Optional[Union[int, _NotProvided]] = NotProvided,
+        jobs: Union[int, _NotProvided] = NotProvided,
+        progress: Union[bool, _NotProvided] = NotProvided,
+        callback: Optional[
+            Union[Callable[[dict], None], _NotProvided]
+        ] = NotProvided,
+        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
+    ):
+        """
+        Update the parameters of separation. Parameters that are not provided keep their
+        current values.
+
+        Parameters
+        ----------
+        segment: Length (in seconds) of each segment (only used if `split` is `True`).
+        shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and \
+            apply the opposite shift to the output. This is repeated `shifts` times and all \
+            predictions are averaged. This effectively makes the model time equivariant and \
+            improves SDR by up to 0.2 points.
+        split: If True, the input will be broken down into small chunks (length set by `segment`) \
+            and predictions will be performed individually on each and concatenated. Useful for \
+            models with large memory footprints like Tasnet.
+        overlap: The overlap between the splits.
+        device (torch.device, str, or None): If provided, device on which to execute the \
+            computation, otherwise `wav.device` is assumed. When `device` is different from \
+            `wav.device`, only local computations will be on `device`, while the entire tracks \
+            will be stored on `wav.device`.
+        jobs: Number of jobs. This can increase memory usage but will be much faster when \
+            multiple cores are available.
+        callback: A function that will be called when the separation of a chunk starts or \
+            finishes. The argument passed to the function will be a dict. For more information, \
+            please see the Callback section.
+        callback_arg: A dict containing private parameters to be passed to the callback \
+            function. For more information, please see the Callback section.
+        progress: If True, show a progress bar.
+
+        Callback
+        --------
+        The function will be called with only one positional parameter whose type is `dict`.
+        The `callback_arg` will be combined with the information of the current separation
+        progress. The progress information will override the values in `callback_arg` if the
+        same key is used. To abort the separation, raise `KeyboardInterrupt`.
+
+        Progress information contains several keys (these keys are always present):
+        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+        - `shift_idx`: The index of shifts. Starts from 0.
+        - `segment_offset`: The offset of the current segment, measured in frames of the tensor
+            (so 441000 means frame 441000 of the audio, not second 441000).
+        - `state`: Either `"start"` or `"end"`.
+        - `audio_length`: Length of the audio, in frames of the tensor.
+        - `models`: Count of submodels in the model.
+        """
+        if not isinstance(device, _NotProvided):
+            self._device = device
+        if not isinstance(shifts, _NotProvided):
+            self._shifts = shifts
+        if not isinstance(overlap, _NotProvided):
+            self._overlap = overlap
+        if not isinstance(split, _NotProvided):
+            self._split = split
+        if not isinstance(segment, _NotProvided):
+            self._segment = segment
+        if not isinstance(jobs, _NotProvided):
+            self._jobs = jobs
+        if not isinstance(progress, _NotProvided):
+            self._progress = progress
+        if not isinstance(callback, _NotProvided) and (callback is None or callable(callback)):
+            self._callback = callback
+        if not isinstance(callback_arg, _NotProvided):
+            self._callback_arg = callback_arg
+
+    def _load_model(self):
+        self._model = get_model(name=self._name, repo=self._repo)
+        if self._model is None:
+            raise LoadModelError("Failed to load model")
+        self._audio_channels = self._model.audio_channels
+        self._samplerate = self._model.samplerate
+
+    def _load_audio(self, track: Path):
+        errors = {}
+        wav = None
+
+        try:
+            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
+                                        channels=self._audio_channels)
+        except FileNotFoundError:
+            errors["ffmpeg"] = "FFmpeg is not installed."
+        except subprocess.CalledProcessError:
+            errors["ffmpeg"] = "FFmpeg could not read the file."
+
+        if wav is None:
+            try:
+                wav, sr = ta.load(str(track))
+            except RuntimeError as err:
+                errors["torchaudio"] = err.args[0]
+            else:
+                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
+
+        if wav is None:
+            raise LoadAudioError(
+                "\n".join(
+                    "When trying to load using {}, got the following error: {}".format(
+                        backend, error
+                    )
+                    for backend, error in errors.items()
+                )
+            )
+        return wav
+
+    def separate_tensor(
+        self, wav: th.Tensor, sr: Optional[int] = None
+    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
+        """
+        Separate a loaded tensor.
+
+        Parameters
+        ----------
+        wav: Waveform of the audio. Should have 2 dimensions: the first is the audio channel \
+            and the second is the waveform within each channel. Type should be float32. \
+            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
+        sr: Sample rate of the original audio; the waveform will be resampled if it doesn't \
+            match the model's.
+
+        Returns
+        -------
+        A tuple whose first element is the original wave and second element is a dict whose keys
+        are the names of the stems and values are the separated waves. The original wave will
+        have already been resampled.
+
+        Notes
+        -----
+        Use this function with caution: it does not validate the input data.
+        """
+        if sr is not None and sr != self.samplerate:
+            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
+        ref = wav.mean(0)
+        wav -= ref.mean()
+        wav /= ref.std()
+        out = apply_model(
+            self._model,
+            wav[None],
+            segment=self._segment,
+            shifts=self._shifts,
+            split=self._split,
+            overlap=self._overlap,
+            device=self._device,
+            num_workers=self._jobs,
+            callback=self._callback,
+            callback_arg=_replace_dict(
+                self._callback_arg, ("audio_length", wav.shape[1])
+            ),
+            progress=self._progress,
+        )
+        if out is None:
+            raise KeyboardInterrupt
+        out *= ref.std()
+        out += ref.mean()
+        wav *= ref.std()
+        wav += ref.mean()
+        return (wav, dict(zip(self._model.sources, out[0])))
+
+    def separate_audio_file(self, file: Path):
+        """
+        Separate an audio file. The method will automatically read the file.
+
+        Parameters
+        ----------
+        file: Path of the file to be separated.
+
+        Returns
+        -------
+        A tuple whose first element is the original wave and second element is a dict whose keys
+        are the names of the stems and values are the separated waves. The original wave will
+        have already been resampled.
+        """
+        return self.separate_tensor(self._load_audio(file), self.samplerate)
+
+    @property
+    def samplerate(self):
+        return self._samplerate
+
+    @property
+    def audio_channels(self):
+        return self._audio_channels
+
+    @property
+    def model(self):
+        return self._model
+
+
+def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
+    """
+    List the available models. Please remember that not all the returned models can be
+    successfully loaded.
+
+    Parameters
+    ----------
+    repo: The repo whose models are to be listed.
+
+    Returns
+    -------
+    A dict with two keys ("single" for single models and "bag" for bags of models). The values
+    are dicts whose keys are model names or signatures and whose values are the URL (remote
+    repo) or the path (local repo) of the model.
+    """
+    model_repo: ModelOnlyRepo
+    if repo is None:
+        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
+        model_repo = RemoteRepo(models)
+        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
+    else:
+        if not repo.is_dir():
+            fatal(f"{repo} must exist and be a directory.")
+        model_repo = LocalRepo(repo)
+        bag_repo = BagOnlyRepo(repo, model_repo)
+    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}
+
+
+if __name__ == "__main__":
+    # Test API functions
+    # two-stem not supported
+
+    from .separate import get_parser
+
+    args = get_parser().parse_args()
+    separator = Separator(
+        model=args.name,
+        repo=args.repo,
+        device=args.device,
+        shifts=args.shifts,
+        overlap=args.overlap,
+        split=args.split,
+        segment=args.segment,
+        jobs=args.jobs,
+        callback=print
+    )
+    out = args.out / args.name
+    out.mkdir(parents=True, exist_ok=True)
+    for file in args.tracks:
+        separated = separator.separate_audio_file(file)[1]
+        if args.mp3:
+            ext = "mp3"
+        elif args.flac:
+            ext = "flac"
+        else:
+            ext = "wav"
+        kwargs = {
+            "samplerate": separator.samplerate,
+            "bitrate": args.mp3_bitrate,
+            "clip": args.clip_mode,
+            "as_float": args.float32,
+            "bits_per_sample": 24 if args.int24 else 16,
+        }
+        for stem, source in separated.items():
+            stem = out / args.filename.format(
+                track=Path(file).name.rsplit(".", 1)[0],
+                trackext=Path(file).name.rsplit(".", 1)[-1],
+                stem=stem,
+                ext=ext,
+            )
+            stem.parent.mkdir(parents=True, exist_ok=True)
+            save_audio(source, str(stem), **kwargs)
diff --git a/demucs/apply.py b/demucs/apply.py
index 54ad9256..180db7fe 100644
--- a/demucs/apply.py
+++ b/demucs/apply.py
@@ -8,7 +8,9 @@
 inteprolation between chunks, as well as the "shift trick".
""" from concurrent.futures import ThreadPoolExecutor +import copy import random +from threading import Lock import typing as tp import torch as th @@ -49,7 +51,8 @@ def __init__(self, models: tp.List[Model], assert other.samplerate == first.samplerate assert other.audio_channels == first.audio_channels if segment is not None: - other.segment = segment + if not isinstance(other, HTDemucs) and segment > other.segment: + other.segment = segment self.audio_channels = first.audio_channels self.samplerate = first.samplerate @@ -129,13 +132,25 @@ def tensor_chunk(tensor_or_chunk): return TensorChunk(tensor_or_chunk) +def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict: + if _dict is None: + _dict = {} + else: + _dict = copy.copy(_dict) + for key, value in subs: + _dict[key] = value + return _dict + + def apply_model(model: tp.Union[BagOfModels, Model], mix: tp.Union[th.Tensor, TensorChunk], shifts: int = 1, split: bool = True, overlap: float = 0.25, transition_power: float = 1., progress: bool = False, device=None, num_workers: int = 0, segment: tp.Optional[float] = None, - pool=None) -> th.Tensor: + pool=None, lock=None, + callback: tp.Optional[tp.Callable[[dict], None]] = None, + callback_arg: tp.Optional[dict] = None) -> tp.Optional[th.Tensor]: """ Apply model to a given mixture. @@ -165,6 +180,11 @@ def apply_model(model: tp.Union[BagOfModels, Model], pool = ThreadPoolExecutor(num_workers) else: pool = DummyPoolExecutor() + if lock is None: + lock = Lock() + callback_arg = _replace_dict( + callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items() + ) kwargs: tp.Dict[str, tp.Any] = { 'shifts': shifts, 'split': split, @@ -174,31 +194,49 @@ def apply_model(model: tp.Union[BagOfModels, Model], 'device': device, 'pool': pool, 'segment': segment, + 'lock': lock, } out: tp.Union[float, th.Tensor] + res: tp.Union[float, th.Tensor, None] if isinstance(model, BagOfModels): # Special treatment for bag of model. # We explicitely apply multiple times `apply_model` so that the random shifts # are different for each model. estimates: tp.Union[float, th.Tensor] = 0. totals = [0.] * len(model.sources) + callback_arg["models"] = len(model.models) + kwargs["callback"] = ( + ( + lambda d, i=callback_arg["model_idx_in_bag"]: callback( + _replace_dict(d, ("model_idx_in_bag", i)) + ) + ) + if callable(callback) + else None + ) for sub_model, model_weights in zip(model.models, model.weights): original_model_device = next(iter(sub_model.parameters())).device sub_model.to(device) - out = apply_model(sub_model, mix, **kwargs) + res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg) + if res is None: + return res + out = res sub_model.to(original_model_device) for k, inst_weight in enumerate(model_weights): out[:, k, :, :] *= inst_weight totals[k] += inst_weight estimates += out del out + callback_arg["model_idx_in_bag"] += 1 assert isinstance(estimates, th.Tensor) for k in range(estimates.shape[1]): estimates[:, k, :, :] /= totals[k] return estimates + if "models" not in callback_arg: + callback_arg["models"] = 1 model.to(device) model.eval() assert transition_power >= 1, "transition_power < 1 leads to weird behavior." @@ -210,10 +248,18 @@ def apply_model(model: tp.Union[BagOfModels, Model], assert isinstance(mix, TensorChunk) padded_mix = mix.padded(length + 2 * max_shift) out = 0. 
- for _ in range(shifts): + for shift_idx in range(shifts): offset = random.randint(0, max_shift) shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) - shifted_out = apply_model(model, shifted, **kwargs) + kwargs["callback"] = ( + (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))) + if callable(callback) + else None + ) + res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg) + if res is None: + return res + shifted_out = res out += shifted_out[..., max_shift - offset:] out /= shifts assert isinstance(out, th.Tensor) @@ -241,13 +287,19 @@ def apply_model(model: tp.Union[BagOfModels, Model], futures = [] for offset in offsets: chunk = TensorChunk(mix, offset, segment_length) - future = pool.submit(apply_model, model, chunk, **kwargs) + future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg, + callback=(lambda d, i=offset: + callback(_replace_dict(d, ("segment_offset", i)))) + if callable(callback) else None) futures.append((future, offset)) offset += segment_length if progress: futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') for future, offset in futures: - chunk_out = future.result() + chunk_out = future.result() # type: tp.Union[None, th.Tensor] + if chunk_out is None: + pool.shutdown(wait=False, cancel_futures=True) + return chunk_out chunk_length = chunk_out.shape[-1] out[..., offset:offset + segment_length] += ( weight[:chunk_length] * chunk_out).to(mix.device) @@ -267,7 +319,21 @@ def apply_model(model: tp.Union[BagOfModels, Model], mix = tensor_chunk(mix) assert isinstance(mix, TensorChunk) padded_mix = mix.padded(valid_length).to(device) + with lock: + try: + callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore + except KeyboardInterrupt: + raise + except Exception: + pass with th.no_grad(): out = model(padded_mix) + with lock: + try: + callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore + except KeyboardInterrupt: + raise + except Exception: + pass assert isinstance(out, th.Tensor) return center_trim(out, length) diff --git a/demucs/audio.py b/demucs/audio.py index 3f50e9b0..31b29b3c 100644 --- a/demucs/audio.py +++ b/demucs/audio.py @@ -166,7 +166,7 @@ def convert_audio_channels(wav, channels=2): return wav -def convert_audio(wav, from_samplerate, to_samplerate, channels): +def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor: """Convert audio from a given samplerate to a target one and target number of channels.""" wav = convert_audio_channels(wav, channels) return julius.resample_frac(wav, from_samplerate, to_samplerate) diff --git a/demucs/pretrained.py b/demucs/pretrained.py index 65851cb6..80ae49cb 100644 --- a/demucs/pretrained.py +++ b/demucs/pretrained.py @@ -32,7 +32,7 @@ def demucs_unittest(): def add_model_flags(parser): group = parser.add_mutually_exclusive_group(required=False) group.add_argument("-s", "--sig", help="Locally trained XP signature.") - group.add_argument("-n", "--name", default=None, + group.add_argument("-n", "--name", default="htdemucs", help="Pretrained model name or signature. 
Default is htdemucs.") parser.add_argument("--repo", type=Path, help="Folder containing all pre-trained models for use with -n.") @@ -45,6 +45,8 @@ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: line = line.strip() if line.startswith('#'): continue + elif len(line) == 0: + continue elif line.startswith('root:'): root = line.split(':', 1)[1].strip() else: diff --git a/demucs/repo.py b/demucs/repo.py index 811254ce..5e20ff51 100644 --- a/demucs/repo.py +++ b/demucs/repo.py @@ -49,6 +49,9 @@ def has_model(self, sig: str) -> bool: def get_model(self, sig: str) -> Model: raise NotImplementedError() + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + raise NotImplementedError() + class RemoteRepo(ModelOnlyRepo): def __init__(self, models: tp.Dict[str, str]): @@ -66,6 +69,9 @@ def get_model(self, sig: str) -> Model: url, map_location='cpu', check_hash=True) # type: ignore return load_model(pkg) + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models # type: ignore + class LocalRepo(ModelOnlyRepo): def __init__(self, root: Path): @@ -100,6 +106,9 @@ def get_model(self, sig: str) -> Model: check_checksum(file, self._checksums[sig]) return load_model(file) + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._models + class BagOnlyRepo: """Handles only YAML files containing bag of models, leaving the actual @@ -132,6 +141,9 @@ def get_model(self, name: str) -> BagOfModels: segment = bag.get('segment') return BagOfModels(models, weights, segment) + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + return self._bags + class AnyModelRepo: def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): @@ -146,3 +158,9 @@ def get_model(self, name_or_sig: str) -> AnyModel: return self.model_repo.get_model(name_or_sig) else: return self.bag_repo.get_model(name_or_sig) + + def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]: + models = self.model_repo.list_model() + for key, value in self.bag_repo.list_model().items(): + models[key] = value + return models diff --git a/demucs/separate.py b/demucs/separate.py index 8747970b..d5102ede 100644 --- a/demucs/separate.py +++ b/demucs/separate.py @@ -7,54 +7,24 @@ import argparse import sys from pathlib import Path -import subprocess from dora.log import fatal import torch as th -import torchaudio as ta - -from .apply import apply_model, BagOfModels -from .audio import AudioFile, convert_audio, save_audio -from .htdemucs import HTDemucs -from .pretrained import get_model_from_args, add_model_flags, ModelLoadingError +from .api import Separator, save_audio, list_models -def load_track(track, audio_channels, samplerate): - errors = {} - wav = None - - try: - wav = AudioFile(track).read( - streams=0, - samplerate=samplerate, - channels=audio_channels) - except FileNotFoundError: - errors['ffmpeg'] = 'FFmpeg is not installed.' - except subprocess.CalledProcessError: - errors['ffmpeg'] = 'FFmpeg could not read the file.' - - if wav is None: - try: - wav, sr = ta.load(str(track)) - except RuntimeError as err: - errors['torchaudio'] = err.args[0] - else: - wav = convert_audio(wav, sr, samplerate, audio_channels) - - if wav is None: - print(f"Could not load file {track}. " - "Maybe it is not a supported file format? 
") - for backend, error in errors.items(): - print(f"When trying to load using {backend}, got the following error: {error}") - sys.exit(1) - return wav +from .apply import BagOfModels +from .htdemucs import HTDemucs +from .pretrained import add_model_flags, ModelLoadingError def get_parser(): parser = argparse.ArgumentParser("demucs.separate", description="Separate the sources for the given tracks") - parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks') + parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks') add_model_flags(parser) + parser.add_argument("--list-models", action="store_true", help="List available models " + "from current repo and exit") parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("-o", "--out", @@ -96,12 +66,16 @@ def get_parser(): parser.add_argument("--two-stems", dest="stem", metavar="STEM", help="Only separate audio into {STEM} and no_{STEM}. ") - group = parser.add_mutually_exclusive_group() - group.add_argument("--int24", action="store_true", - help="Save wav output as 24 bits wav.") - group.add_argument("--float32", action="store_true", - help="Save wav output as float32 (2x bigger).") - parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp"], + parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"], + default="add", help='Decide how to get "no_{STEM}". "none" will not save ' + '"no_{STEM}". "add" will add all the other stems. "minus" will use the ' + "original track minus the selected stem.") + depth_group = parser.add_mutually_exclusive_group() + depth_group.add_argument("--int24", action="store_true", + help="Save wav output as 24 bits wav.") + depth_group.add_argument("--float32", action="store_true", + help="Save wav output as float32 (2x bigger).") + parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"], help="Strategy for avoiding clipping: rescaling entire signal " "if necessary (rescale) or hard clipping (clamp).") format_group = parser.add_mutually_exclusive_group() @@ -128,53 +102,64 @@ def get_parser(): def main(opts=None): parser = get_parser() args = parser.parse_args(opts) + if args.list_models: + models = list_models(args.repo) + print("Bag of models:", end="\n ") + print("\n ".join(models["bag"])) + print("Single models:", end="\n ") + print("\n ".join(models["single"])) + sys.exit(0) + if len(args.tracks) == 0: + print("error: the following arguments are required: tracks", file=sys.stderr) + sys.exit(1) try: - model = get_model_from_args(args) + separator = Separator(model=args.name, + repo=args.repo, + device=args.device, + shifts=args.shifts, + split=args.split, + overlap=args.overlap, + progress=True, + jobs=args.jobs, + segment=args.segment) except ModelLoadingError as error: fatal(error.args[0]) max_allowed_segment = float('inf') - if isinstance(model, HTDemucs): - max_allowed_segment = float(model.segment) - elif isinstance(model, BagOfModels): - max_allowed_segment = model.max_allowed_segment + if isinstance(separator.model, HTDemucs): + max_allowed_segment = float(separator.model.segment) + elif isinstance(separator.model, BagOfModels): + max_allowed_segment = separator.model.max_allowed_segment if args.segment is not None and args.segment > max_allowed_segment: fatal("Cannot use a Transformer model with a longer segment " f"than it was trained for. 
Maximum segment is: {max_allowed_segment}") - if isinstance(model, BagOfModels): - print(f"Selected model is a bag of {len(model.models)} models. " - "You will see that many progress bars per track.") - - model.cpu() - model.eval() + if isinstance(separator.model, BagOfModels): + print( + f"Selected model is a bag of {len(separator.model.models)} models. " + "You will see that many progress bars per track." + ) - if args.stem is not None and args.stem not in model.sources: + if args.stem is not None and args.stem not in separator.model.sources: fatal( - 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format( - stem=args.stem, sources=', '.join(model.sources))) + 'error: stem "{stem}" is not in selected model. ' + "STEM must be one of {sources}.".format( + stem=args.stem, sources=", ".join(separator.model.sources) + ) + ) out = args.out / args.name out.mkdir(parents=True, exist_ok=True) print(f"Separated tracks will be stored in {out.resolve()}") for track in args.tracks: if not track.exists(): - print( - f"File {track} does not exist. If the path contains spaces, " - "please try again after surrounding the entire path with quotes \"\".", - file=sys.stderr) + print(f"File {track} does not exist. If the path contains spaces, " + 'please try again after surrounding the entire path with quotes "".', + file=sys.stderr) continue print(f"Separating track {track}") - wav = load_track(track, model.audio_channels, model.samplerate) - ref = wav.mean(0) - wav -= ref.mean() - wav /= ref.std() - sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts, - split=args.split, overlap=args.overlap, progress=True, - num_workers=args.jobs, segment=args.segment)[0] - sources *= ref.std() - sources += ref.mean() + origin, res = separator.separate_audio_file(track) if args.mp3: ext = "mp3" @@ -183,36 +168,54 @@ def main(opts=None): else: ext = "wav" kwargs = { - 'samplerate': model.samplerate, - 'bitrate': args.mp3_bitrate, - 'preset': args.mp3_preset, - 'clip': args.clip_mode, - 'as_float': args.float32, - 'bits_per_sample': 24 if args.int24 else 16, + "samplerate": separator.samplerate, + "bitrate": args.mp3_bitrate, + "preset": args.mp3_preset, + "clip": args.clip_mode, + "as_float": args.float32, + "bits_per_sample": 24 if args.int24 else 16, } if args.stem is None: - for source, name in zip(sources, model.sources): - stem = out / args.filename.format(track=track.name.rsplit(".", 1)[0], - trackext=track.name.rsplit(".", 1)[-1], - stem=name, ext=ext) + for name, source in res.items(): + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem=name, + ext=ext, + ) stem.parent.mkdir(parents=True, exist_ok=True) save_audio(source, str(stem), **kwargs) else: - sources = list(sources) - stem = out / args.filename.format(track=track.name.rsplit(".", 1)[0], - trackext=track.name.rsplit(".", 1)[-1], - stem=args.stem, ext=ext) - stem.parent.mkdir(parents=True, exist_ok=True) - save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs) - # Warning : after poping the stem, selected stem is no longer in the list 'sources' - other_stem = th.zeros_like(sources[0]) - for i in sources: - other_stem += i - stem = out / args.filename.format(track=track.name.rsplit(".", 1)[0], - trackext=track.name.rsplit(".", 1)[-1], - stem="no_"+args.stem, ext=ext) + stem = out / args.filename.format( + track=track.name.rsplit(".", 1)[0], + trackext=track.name.rsplit(".", 1)[-1], + stem="minus_" + args.stem, 
+                    ext=ext,
+                )
+                if args.other_method == "minus":
+                    stem.parent.mkdir(parents=True, exist_ok=True)
+                    save_audio(origin - res[args.stem], str(stem), **kwargs)
+                stem = out / args.filename.format(
+                    track=track.name.rsplit(".", 1)[0],
+                    trackext=track.name.rsplit(".", 1)[-1],
+                    stem=args.stem,
+                    ext=ext,
+                )
+                stem.parent.mkdir(parents=True, exist_ok=True)
+                save_audio(res.pop(args.stem), str(stem), **kwargs)
+                # Warning: after popping the stem, the selected stem is no longer in the dict 'res'
+                if args.other_method == "add":
+                    other_stem = th.zeros_like(next(iter(res.values())))
+                    for i in res.values():
+                        other_stem += i
+                    stem = out / args.filename.format(
+                        track=track.name.rsplit(".", 1)[0],
+                        trackext=track.name.rsplit(".", 1)[-1],
+                        stem="no_" + args.stem,
+                        ext=ext,
+                    )
+                    stem.parent.mkdir(parents=True, exist_ok=True)
+                    save_audio(other_stem, str(stem), **kwargs)
 
 
 if __name__ == "__main__":
diff --git a/demucs/utils.py b/demucs/utils.py
index 96a2cc11..c80fc129 100755
--- a/demucs/utils.py
+++ b/demucs/utils.py
@@ -120,19 +120,24 @@ def random_subset(dataset, max_samples: int, seed: int = 42):
 
 class DummyPoolExecutor:
     class DummyResult:
-        def __init__(self, func, *args, **kwargs):
+        def __init__(self, func, _dict, *args, **kwargs):
             self.func = func
+            self._dict = _dict
             self.args = args
             self.kwargs = kwargs
 
         def result(self):
-            return self.func(*self.args, **self.kwargs)
+            if self._dict["run"]:
+                return self.func(*self.args, **self.kwargs)
 
     def __init__(self, workers=0):
-        pass
+        self._dict = {"run": True}
 
     def submit(self, func, *args, **kwargs):
-        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+        return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs)
+
+    def shutdown(self, *_, **__):
+        self._dict["run"] = False
 
     def __enter__(self):
         return self
diff --git a/docs/api.md b/docs/api.md
new file mode 100644
index 00000000..e6d9e873
--- /dev/null
+++ b/docs/api.md
@@ -0,0 +1,204 @@
+# Demucs APIs
+
+## Quick start
+
+Notes: Type hints have been added to all API functions. It is recommended to check them before passing parameters to a function, as some arguments only support a limited set of types (e.g. the `repo` parameter only supports type `pathlib.Path`).
+
+1. The first step is to import the api module:
+
+```python
+import demucs.api
+```
+
+2. Then initialize the `Separator`. Parameters that will serve as default values for the separation methods can be passed. The model should be specified.
+
+```python
+# Initialize with default parameters:
+separator = demucs.api.Separator()
+
+# Use another model and segment:
+separator = demucs.api.Separator(model="mdx_extra", segment=12)
+
+# You can also pass any of the other parameters documented below
+```
+
+3. Separate it!
+
+```python
+# Separating an audio file
+origin, separated = separator.separate_audio_file("file.mp3")
+
+# Separating a loaded audio
+origin, separated = separator.separate_tensor(audio)
+
+# If you encounter an error like CUDA out of memory, you can use this to change parameters like `segment`:
+separator.update_parameter(segment=smaller_segment)
+```
+
+4. Save audio
+
+```python
+# Remember to create the destination folder before calling `save_audio`,
+# or you are likely to receive a `FileNotFoundError`
+for stem, source in separated.items():
+    demucs.api.save_audio(source, f"{stem}_file.mp3", samplerate=separator.samplerate)
+```
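+
+Putting the quick start together, here is a minimal end-to-end sketch. The input name `song.mp3` and the output names are placeholders, and a stereo input file is assumed to exist:
+
+```python
+import demucs.api
+
+# Load the default pretrained model once and reuse it for every file.
+separator = demucs.api.Separator(model="htdemucs", progress=True)
+
+# `origin` is the (resampled) input mix; `separated` maps stem names to tensors.
+origin, separated = separator.separate_audio_file("song.mp3")
+
+# `save_audio` does not create missing folders, so save to the working directory.
+for stem, source in separated.items():
+    demucs.api.save_audio(source, f"{stem}_song.wav", samplerate=separator.samplerate)
+```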
+
+## API References
+
+The exact types of each parameter and return value are not listed in this document. To know them, please read the type hints in api.py (most modern code editors support inferring types from type hints).
+
+### `class Separator`
+
+The base separator class.
+
+##### Parameters
+
+model: Pretrained model name or signature. Default is htdemucs.
+
+repo: Folder containing all pre-trained models for use.
+
+segment: Length (in seconds) of each segment (only used if `split` is `True`). If not specified, the model's default segment length is used.
+
+shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points.
+
+split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with large memory footprints like Tasnet.
+
+overlap: The overlap between the splits.
+
+device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`.
+
+jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available.
+
+callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section.
+
+callback_arg: A dict containing private parameters to be passed to the callback function. For more information, please see the Callback section.
+
+progress: If True, show a progress bar.
+
+##### Notes for callback
+
+The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with the information of the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise `KeyboardInterrupt`.
+
+Progress information contains several keys (these keys are always present):
+- `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+- `shift_idx`: The index of shifts. Starts from 0.
+- `segment_offset`: The offset of the current segment, measured in frames of the tensor (so 441000 means frame 441000 of the audio, not second 441000).
+- `state`: Either `"start"` or `"end"`.
+- `audio_length`: Length of the audio, in frames of the tensor.
+- `models`: Count of submodels in the model.
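+
+As an illustration of the callback contract above, here is a hedged sketch of a progress callback. The printed fraction assumes a single (non-bagged) model without shifts, so it is only approximate:
+
+```python
+import demucs.api
+
+def on_progress(info: dict) -> None:
+    # `info` carries the keys documented above, merged with `callback_arg`.
+    if info["state"] == "end":
+        done = info["segment_offset"] / info["audio_length"]
+        print(f"model {info['model_idx_in_bag'] + 1}/{info['models']}: ~{done:.0%}")
+    # Raising KeyboardInterrupt here would abort the separation.
+
+separator = demucs.api.Separator(callback=on_progress)
+```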
+
+#### `property samplerate`
+
+A read-only property giving the sample rate the loaded model requires.
+
+#### `property audio_channels`
+
+A read-only property giving the number of audio channels the loaded model requires.
+
+#### `property model`
+
+A read-only property giving the loaded model.
+
+#### `method update_parameter()`
+
+Update the parameters of separation. Parameters that are not provided keep their current values.
+
+##### Parameters
+
+segment: Length (in seconds) of each segment (only used if `split` is `True`).
+
+shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points.
+
+split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with large memory footprints like Tasnet.
+
+overlap: The overlap between the splits.
+
+device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`.
+
+jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available.
+
+callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section.
+
+callback_arg: A dict containing private parameters to be passed to the callback function. For more information, please see the Callback section.
+
+progress: If True, show a progress bar.
+
+##### Notes for callback
+
+The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with the information of the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise `KeyboardInterrupt`.
+
+Progress information contains several keys (these keys are always present):
+- `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
+- `shift_idx`: The index of shifts. Starts from 0.
+- `segment_offset`: The offset of the current segment, measured in frames of the tensor (so 441000 means frame 441000 of the audio, not second 441000).
+- `state`: Either `"start"` or `"end"`.
+- `audio_length`: Length of the audio, in frames of the tensor.
+- `models`: Count of submodels in the model.
+
+#### `method separate_tensor()`
+
+Separate a loaded tensor.
+
+##### Parameters
+
+wav: Waveform of the audio. Should have 2 dimensions: the first is the audio channel and the second is the waveform within each channel. e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
+
+sr: Sample rate of the original audio; the waveform will be resampled if it doesn't match the model's.
+
+##### Returns
+
+A tuple whose first element is the original wave and second element is a dict whose keys are the names of the stems and values are the separated waves. The original wave will have already been resampled.
+
+##### Notes
+
+Use this function with caution: it does not validate the input data.
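+
+For example, a tensor loaded with torchaudio can be separated like this (a sketch: `input.wav` is a placeholder and is assumed to be a stereo float32 file):
+
+```python
+import torchaudio
+import demucs.api
+
+separator = demucs.api.Separator()
+wav, sr = torchaudio.load("input.wav")  # shape: (channels, frames)
+# Pass the original sample rate so the API resamples to the model's rate.
+origin, separated = separator.separate_tensor(wav, sr)
+vocals = separated["vocals"]  # stem names follow separator.model.sources
+```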
+
+#### `method separate_audio_file()`
+
+Separate an audio file. The method will automatically read the file.
+
+##### Parameters
+
+file: Path of the file to be separated.
+
+##### Returns
+
+A tuple whose first element is the original wave and second element is a dict whose keys are the names of the stems and values are the separated waves. The original wave will have already been resampled.
+
+### `function save_audio()`
+
+Save an audio file.
+
+##### Parameters
+
+wav: Audio to be saved.
+
+path: The file path to save to. The suffix must be one of `.mp3`, `.wav` and `.flac`.
+
+samplerate: File sample rate.
+
+bitrate: If the suffix of `path` is `.mp3`, it will be used to specify the bitrate of the mp3.
+
+clip: Clipping prevention strategy.
+
+bits_per_sample: If the suffix of `path` is `.wav`, it will be used to specify the bit depth of the wav.
+
+as_float: If True and the suffix of `path` is `.wav`, then `bits_per_sample` will be set to 32 and the wave file will be written in float format.
+
+##### Returns
+
+None
+
+### `function list_models()`
+
+List the available models. Please remember that not all the returned models can be successfully loaded.
+
+##### Parameters
+
+repo: The repo whose models are to be listed.
+
+##### Returns
+
+A dict with two keys ("single" for single models and "bag" for bags of models). The values are dicts whose keys are model names or signatures and whose values are the URL (remote repo) or the path (local repo) of the model.
\ No newline at end of file
diff --git a/docs/release.md b/docs/release.md
index 54099c01..1c8dd537 100644
--- a/docs/release.md
+++ b/docs/release.md
@@ -1,5 +1,19 @@
 # Release notes for Demucs
 
+## V4.1.0a1, TBD
+
+Added `list_models` and `--list-models` to get the list of available models
+
+Check the segment of each HTDemucs model inside a BagOfModels
+
+Added api.py so Demucs can be called from another program
+
+Use the api in separate.py
+
+Added `--other-method`: method to get `no_{STEM}`, add up all the other stems (add), subtract the selected stem from the original track (minus), or discard (none)
+
+Added type `HTDemucs` to type alias `AnyModel`.
+
 ## V4.0.1a1, TBD
 
 **From this version, Python 3.7 is no longer supported. This is not a problem since the latest PyTorch 2.0.0 no longer support it either.**