Skip to content

Commit

Permalink
Many bug fixes for v1.1.0 (#335)
Browse files Browse the repository at this point in the history
* wip

* fix bug in sample storage

* fix bug in fsdp with pytorch 2.0.1

* fix kld

* revert back passt change

* fix n_quantizers

* changes

* adding warning

* extra warnings and tests

* missing changelog

* missing changes between audiogen and musicgen
  • Loading branch information
adefossez authored Oct 26, 2023
1 parent 5d8752d commit f73b7ae
Show file tree
Hide file tree
Showing 17 changed files with 69 additions and 20 deletions.
18 changes: 17 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,26 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).


## [1.0.1] - TBD
## [1.1.0a] - TBD

Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.

Fixed DAC support with non default number of codebooks.

Fixed bug when `two_step_cfg` was overriden when calling `generate()`.

Fixed samples being always prompted with audio, rather than having both prompted and unprompted.

**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners.
We removed it, so you might need to retrain models.

**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).

**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
retrained a model with this pattern, so hopefully this won't impact you!


## [1.0.0] - 2023-09-07

Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

__version__ = '1.0.0'
__version__ = '1.1.0a1'
4 changes: 4 additions & 0 deletions audiocraft/models/audiogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
self.name = name
self.compression_model = compression_model
self.lm = lm
# Just to be safe, let's put everything in eval mode.
self.compression_model.eval()
self.lm.eval()

if max_duration is None:
if hasattr(lm, 'cfg'):
max_duration = lm.cfg.dataset.segment_duration # type: ignore
Expand Down
5 changes: 2 additions & 3 deletions audiocraft/models/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
VALLEPattern,
CoarseFirstPattern,
)
from ..modules.conditioners import (
BaseConditioner,
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
'parallel': ParallelPatternProvider,
'delay': DelayedPatternProvider,
'unroll': UnrolledPatternProvider,
'valle': VALLEPattern,
'coarse_first': CoarseFirstPattern,
'musiclm': MusicLMPattern,
}
name = cfg.modeling
Expand All @@ -196,7 +196,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
'dimension': 32,
'ratios': ratios,
}
print(seanet_kwargs)
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/models/encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:

def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
codes = self.model.encode(x, self.n_quantizers)[1]
return codes, None
return codes[:, :self.n_quantizers], None

def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
assert scale is None
Expand Down
8 changes: 5 additions & 3 deletions audiocraft/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ def _sample_next_token(self,
temp: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
cfg_coef: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
"""Sample next token from the model given a sequence and a set of conditions. The model supports
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
Expand All @@ -335,7 +336,8 @@ def _sample_next_token(self,
B = sequence.shape[0]
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
model = self if self._fsdp is None else self._fsdp
if self.two_step_cfg and cfg_conditions != {}:
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
if two_step_cfg and cfg_conditions != {}:
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
condition_tensors, null_condition_tensors = cfg_conditions
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
Expand Down Expand Up @@ -493,7 +495,7 @@ def generate(self,
# sample next token from the model, next token shape is [B, K, 1]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
cfg_coef=cfg_coef)
cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
# ensure the tokens that should be masked are properly set to special_token_id
# as the model never output special_token_id
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
Expand Down
4 changes: 4 additions & 0 deletions audiocraft/models/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
self.name = name
self.compression_model = compression_model
self.lm = lm
# Just to be safe, let's put everything in eval mode.
self.compression_model.eval()
self.lm.eval()

if max_duration is None:
if hasattr(lm, 'cfg'):
max_duration = lm.cfg.dataset.segment_duration # type: ignore
Expand Down
11 changes: 8 additions & 3 deletions audiocraft/modules/codebooks_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,14 @@ def get_pattern(self, timesteps: int) -> Pattern:
return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class VALLEPattern(CodebooksPatternProvider):
"""Almost VALL-E style pattern.
We further allow some delays for the codebooks other than the first one.
class CoarseFirstPattern(CodebooksPatternProvider):
"""First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
potentially with delays.
..Warning:: You must always generate the full training duration at test time, for instance,
30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
location. This is due to the non causality of the remaining codebooks with respect to
the first ones.
Args:
n_q (int): Number of codebooks.
Expand Down
2 changes: 2 additions & 0 deletions audiocraft/modules/conditioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,8 @@ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
import laion_clap # type: ignore
except ImportError:
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
"Please retrain all models.")
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
Expand Down
4 changes: 2 additions & 2 deletions audiocraft/optim/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def _name_without_fsdp_prefix(name: str) -> str:
new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
return '.'.join(new_parts)

def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore
state = dict(super().state_dict())
def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type: ignore
state = dict(super().state_dict(*args, **kwargs))
for key, value in list(state.items()):
if is_sharded_tensor(value):
del state[key]
Expand Down
10 changes: 8 additions & 2 deletions audiocraft/solvers/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
import time
import typing as tp
import warnings

import flashy
import math
Expand Down Expand Up @@ -226,7 +227,6 @@ def _compute_cross_entropy(
ce = ce / K
return ce, ce_per_codebook

@torch.no_grad()
def _prepare_tokens_and_attributes(
self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
check_synchronization_points: bool = False
Expand All @@ -243,6 +243,12 @@ def _prepare_tokens_and_attributes(
with B the batch size, K the number of codebooks, T_s the token timesteps.
Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
"""
if self.model.training:
warnings.warn(
"Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
"This is inconsistent with how model were trained in the MusicGen paper. We removed the "
"`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
"Really sorry about that.")
if self._cached_batch_loader is None or self.current_stage != "train":
audio, infos = batch
audio = audio.to(self.device)
Expand Down Expand Up @@ -533,7 +539,7 @@ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
rtf = 1.
else:
gen_unprompted_outputs = self.run_generate_step(
batch, gen_duration=target_duration, prompt_duration=prompt_duration,
batch, gen_duration=target_duration, prompt_duration=None,
**self.generation_params)
gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
rtf = gen_unprompted_outputs['rtf']
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class EmbeddingCache:
specify the index corresponding to the current embedding in the object that can represent batch metadata.
If not specified, will return the full embedding unmodified.
"""
def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, torch.device],
def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
self.cache_path = Path(cache_path)
Expand Down
2 changes: 1 addition & 1 deletion config/conditioner/clapemb2music.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ conditioners:
checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
model_arch: 'HTSAT-base'
enable_fusion: false
sample_rate: 44100
sample_rate: 48000
max_audio_length: 10
audio_stride: 1
dim: 512
Expand Down
2 changes: 1 addition & 1 deletion config/model/lm/audiogen_lm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ codebooks_pattern:
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
valle:
coarse_first:
delays: [0, 0, 0]

transformer_lm:
Expand Down
2 changes: 1 addition & 1 deletion config/model/lm/musicgen_lm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ codebooks_pattern:
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
valle:
coarse_first:
delays: [0, 0, 0]

transformer_lm:
Expand Down
4 changes: 4 additions & 0 deletions docs/MUSICGEN.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ We provide a dummy dataset containing just a few examples for illustrative purpo

Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section.


**Warning:** As of version 1.1.0, a few breaking changes were introduced. Check the [CHANGELOG.md](../CHANGELOG.md)
file for more information. You might need to retrain some of your models.

### Example configurations and grids

We provide configurations to reproduce the released models and our research.
Expand Down
7 changes: 7 additions & 0 deletions tests/models/test_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,10 @@ def test_generate_long(self):
wav = mg.generate(
['youpi', 'lapin dort'])
assert list(wav.shape) == [2, 1, 32000 * 4]

def test_generate_two_step_cfg(self):
mg = self.get_musicgen()
mg.set_generation_params(duration=2.0, extend_stride=2., two_step_cfg=True)
wav = mg.generate(
['youpi', 'lapin dort'])
assert list(wav.shape) == [2, 1, 64000]

0 comments on commit f73b7ae

Please sign in to comment.