diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py index 6aaa3b07..cf31986e 100644 --- a/audiocraft/quantization/core_vq.py +++ b/audiocraft/quantization/core_vq.py @@ -205,7 +205,6 @@ def forward(self, x): if self.training: # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. - self.expire_codes_(x) ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) embed_sum = x.t() @ embed_onehot ema_inplace(self.embed_avg, embed_sum.t(), self.decay) @@ -215,6 +214,7 @@ def forward(self, x): ) embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) return quantize, embed_ind diff --git a/audiocraft/solvers/musicgen2.py b/audiocraft/solvers/musicgen2.py new file mode 100644 index 00000000..b876d4e7 --- /dev/null +++ b/audiocraft/solvers/musicgen2.py @@ -0,0 +1,711 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import time +import typing as tp +import warnings + +import flashy +import math +import omegaconf +import torch +from torch.nn import functional as F + +from . import base, builders +from .compression import CompressionSolver +from .. import metrics as eval_metrics +from .. import models +from ..data.audio_dataset import AudioDataset +from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo +from ..data.audio_utils import normalize_audio +from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition +from ..utils.cache import CachedBatchWriter, CachedBatchLoader +from ..utils.samples.manager import SampleManager +from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once + + +class MusicGenSolver(base.StandardSolver): + """Solver for MusicGen training task. + + Used in: https://arxiv.org/abs/2306.05284 + """ + DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC + + def __init__(self, cfg: omegaconf.DictConfig): + super().__init__(cfg) + # easier access to sampling parameters + self.generation_params = { + 'use_sampling': self.cfg.generate.lm.use_sampling, + 'temp': self.cfg.generate.lm.temp, + 'top_k': self.cfg.generate.lm.top_k, + 'top_p': self.cfg.generate.lm.top_p, + } + self._best_metric_name: tp.Optional[str] = 'ce' + + self._cached_batch_writer = None + self._cached_batch_loader = None + if cfg.cache.path: + if cfg.cache.write: + self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path)) + if self.cfg.cache.write_num_shards: + self.logger.warning("Multiple shard cache, best_metric_name will be set to None.") + self._best_metric_name = None + else: + self._cached_batch_loader = CachedBatchLoader( + Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers, + min_length=self.cfg.optim.updates_per_epoch or 1) + self.dataloaders['original_train'] = self.dataloaders['train'] + self.dataloaders['train'] = self._cached_batch_loader # type: ignore + + @staticmethod + def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, + device: tp.Optional[str] = None, autocast: bool = True, + batch_size: tp.Optional[int] = None, + override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, + **kwargs): + """Mostly a convenience function around magma.train.get_solver_from_sig, + populating all the proper param, deactivating EMA, FSDP, loading the best state, + basically all you need to get a solver ready to "play" with in single GPU mode + and with minimal memory overhead. + + Args: + sig (str): signature to load. + dtype (str or None): potential dtype, as a string, i.e. 'float16'. + device (str or None): potential device, as a string, i.e. 'cuda'. + override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'. + """ + from audiocraft import train + our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} + our_override_cfg['autocast'] = autocast + if dtype is not None: + our_override_cfg['dtype'] = dtype + if device is not None: + our_override_cfg['device'] = device + if batch_size is not None: + our_override_cfg['dataset'] = {'batch_size': batch_size} + if override_cfg is None: + override_cfg = {} + override_cfg = omegaconf.OmegaConf.merge( + omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore + solver = train.get_solver_from_sig( + sig, override_cfg=override_cfg, + load_best=True, disable_fsdp=True, + ignore_state_keys=['optimizer', 'ema'], **kwargs) + solver.model.eval() + return solver + + def get_formatter(self, stage_name: str) -> flashy.Formatter: + return flashy.Formatter({ + 'lr': '.2E', + 'ce': '.3f', + 'ppl': '.3f', + 'grad_norm': '.3E', + }, exclude_keys=['ce_q*', 'ppl_q*']) + + @property + def best_metric_name(self) -> tp.Optional[str]: + return self._best_metric_name + + def build_model(self) -> None: + """Instantiate models and optimizer.""" + # we can potentially not use all quantizers with which the EnCodec model was trained + # (e.g. we trained the model with quantizers dropout) + self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( + self.cfg, self.cfg.compression_model_checkpoint, device=self.device) + assert self.compression_model.sample_rate == self.cfg.sample_rate, ( + f"Compression model sample rate is {self.compression_model.sample_rate} but " + f"Solver sample rate is {self.cfg.sample_rate}." + ) + # ensure we have matching configuration between LM and compression model + assert self.cfg.transformer_lm.card == self.compression_model.cardinality, ( + "Cardinalities of the LM and compression model don't match: ", + f"LM cardinality is {self.cfg.transformer_lm.card} vs ", + f"compression model cardinality is {self.compression_model.cardinality}" + ) + assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, ( + "Numbers of codebooks of the LM and compression models don't match: ", + f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ", + f"compression model numer of codebooks is {self.compression_model.num_codebooks}" + ) + self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d", + self.compression_model.num_codebooks, self.compression_model.cardinality, + self.compression_model.frame_rate) + # instantiate LM model + self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) + if self.cfg.fsdp.use: + assert not self.cfg.autocast, "Cannot use autocast with fsdp" + self.model = self.wrap_with_fsdp(self.model) + self.register_ema('model') + # initialize optimization + self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) + self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) + self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') + self.register_best_state('model') + self.autocast_dtype = { + 'float16': torch.float16, 'bfloat16': torch.bfloat16 + }[self.cfg.autocast_dtype] + self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None + if self.cfg.fsdp.use: + need_scaler = self.cfg.fsdp.param_dtype == 'float16' + else: + need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 + if need_scaler: + if self.cfg.fsdp.use: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + self.scaler = ShardedGradScaler() # type: ignore + else: + self.scaler = torch.cuda.amp.GradScaler() + self.register_stateful('scaler') + + def build_dataloaders(self) -> None: + """Instantiate audio dataloaders for each stage.""" + self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE) + + def show(self) -> None: + """Show the compression model and LM model.""" + self.logger.info("Compression model:") + self.log_model_summary(self.compression_model) + self.logger.info("LM model:") + self.log_model_summary(self.model) + + def load_state_dict(self, state: dict) -> None: + if 'condition_provider' in state: + model_state = state['model'] + condition_provider_state = state.pop('condition_provider') + prefix = 'condition_provider.' + for key, value in condition_provider_state.items(): + key = prefix + key + assert key not in model_state + model_state[key] = value + super().load_state_dict(state) + + def load_from_pretrained(self, name: str): + # TODO: support native HF versions of MusicGen. + lm_pkg = models.loaders.load_lm_model_ckpt(name) + state: dict = { + 'best_state': { + 'model': lm_pkg['best_state'], + }, + } + return state + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute focal loss between multi-codebook targets and model's logits. + The focal loss is computed per codebook to provide codebook-level focal loss. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + fl (torch.Tensor): Focal loss averaged over the codebooks + fl_per_codebook (list of torch.Tensor): Focal loss per codebook (detached). + """ + alpha = 0.05 + gamma = 1.5 + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + fl = torch.zeros([], device=targets.device) + fl_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + fl_targets = targets_k[mask_k] + fl_logits = logits_k[mask_k] + ce_loss = F.cross_entropy(fl_logits, fl_targets, reduction='none') + pt = torch.exp(-ce_loss) + focal_loss = alpha * (1 - pt) ** gamma * ce_loss + q_fl = focal_loss.mean() + fl += q_fl + fl_per_codebook.append(q_fl.detach()) + # average focal loss across codebooks + fl = fl / K + return fl, fl_per_codebook + + + def _prepare_tokens_and_attributes( + self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + check_synchronization_points: bool = False + ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]: + """Prepare input batchs for language model training. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T] + and corresponding metadata as SegmentWithAttributes (with B items). + check_synchronization_points (bool): Whether to check for synchronization points slowing down training. + Returns: + Condition tensors (dict[str, any]): Preprocessed condition attributes. + Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], + with B the batch size, K the number of codebooks, T_s the token timesteps. + Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. + """ + if self.model.training: + warnings.warn( + "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. " + "This is inconsistent with how model were trained in the MusicGen paper. We removed the " + "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. " + "Really sorry about that.") + if self._cached_batch_loader is None or self.current_stage != "train": + audio, infos = batch + audio = audio.to(self.device) + audio_tokens = None + assert audio.size(0) == len(infos), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(infos)})" + ) + else: + audio = None + # In that case the batch will be a tuple coming from the _cached_batch_writer bit below. + infos, = batch # type: ignore + assert all([isinstance(info, AudioInfo) for info in infos]) + assert all([info.audio_tokens is not None for info in infos]) # type: ignore + audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore + audio_tokens = audio_tokens.long() + for info in infos: + if isinstance(info, MusicInfo): + # Careful here, if you want to use this condition_wav (e.b. chroma conditioning), + # then you must be using the chroma cache! otherwise the code will try + # to use this segment and fail (by that I mean you will see NaN everywhere). + info.self_wav = WavCondition( + torch.full([1, info.channels, info.total_frames], float('NaN')), + length=torch.tensor([info.n_frames]), + sample_rate=[info.sample_rate], + path=[info.meta.path], + seek_time=[info.seek_time]) + dataset = get_dataset_from_loader(self.dataloaders['original_train']) + assert isinstance(dataset, MusicDataset), type(dataset) + if dataset.paraphraser is not None and info.description is not None: + # Hackingly reapplying paraphraser when using cache. + info.description = dataset.paraphraser.sample_paraphrase( + info.meta.path, info.description) + # prepare attributes + attributes = [info.to_condition_attributes() for info in infos] + attributes = self.model.cfg_dropout(attributes) + attributes = self.model.att_dropout(attributes) + tokenized = self.model.condition_provider.tokenize(attributes) + + # Now we should be synchronization free. + if self.device == "cuda" and check_synchronization_points: + torch.cuda.set_sync_debug_mode("warn") + + if audio_tokens is None: + with torch.no_grad(): + audio_tokens, scale = self.compression_model.encode(audio) + assert scale is None, "Scaled compression model not supported with LM." + + with self.autocast: + condition_tensors = self.model.condition_provider(tokenized) + + # create a padding mask to hold valid vs invalid positions + padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device) + # replace encodec tokens from padded audio with special_token_id + if self.cfg.tokens.padding_with_special_token: + audio_tokens = audio_tokens.clone() + padding_mask = padding_mask.clone() + token_sample_rate = self.compression_model.frame_rate + B, K, T_s = audio_tokens.shape + for i in range(B): + n_samples = infos[i].n_frames + audio_sample_rate = infos[i].sample_rate + # take the last token generated from actual audio frames (non-padded audio) + valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate) + audio_tokens[i, :, valid_tokens:] = self.model.special_token_id + padding_mask[i, :, valid_tokens:] = 0 + + if self.device == "cuda" and check_synchronization_points: + torch.cuda.set_sync_debug_mode("default") + + if self._cached_batch_writer is not None and self.current_stage == 'train': + assert self._cached_batch_loader is None + assert audio_tokens is not None + for info, one_audio_tokens in zip(infos, audio_tokens): + assert isinstance(info, AudioInfo) + if isinstance(info, MusicInfo): + assert not info.joint_embed, "joint_embed and cache not supported yet." + info.self_wav = None + assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item() + info.audio_tokens = one_audio_tokens.short().cpu() + self._cached_batch_writer.save(infos) + + return condition_tensors, audio_tokens, padding_mask + + def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: + """Perform one training or valid step on a given batch.""" + check_synchronization_points = idx == 1 and self.device == 'cuda' + + condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes( + batch, check_synchronization_points) + + self.deadlock_detect.update('tokens_and_conditions') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('warn') + + with self.autocast: + model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore + logits = model_output.logits + mask = padding_mask & model_output.mask + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + loss = ce + self.deadlock_detect.update('loss') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('default') + + if self.is_training: + metrics['lr'] = self.optimizer.param_groups[0]['lr'] + if self.scaler is not None: + loss = self.scaler.scale(loss) + self.deadlock_detect.update('scale') + if self.cfg.fsdp.use: + loss.backward() + flashy.distrib.average_tensors(self.model.buffers()) + elif self.cfg.optim.eager_sync: + with flashy.distrib.eager_sync_model(self.model): + loss.backward() + else: + # this should always be slower but can be useful + # for weird use cases like multiple backwards. + loss.backward() + flashy.distrib.sync_model(self.model) + self.deadlock_detect.update('backward') + + if self.scaler is not None: + self.scaler.unscale_(self.optimizer) + if self.cfg.optim.max_norm: + if self.cfg.fsdp.use: + metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore + else: + metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + if self.scaler is None: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + if self.lr_scheduler: + self.lr_scheduler.step() + self.optimizer.zero_grad() + self.deadlock_detect.update('optim') + if self.scaler is not None: + scale = self.scaler.get_scale() + metrics['grad_scale'] = scale + if not loss.isfinite().all(): + raise RuntimeError("Model probably diverged.") + + metrics['ce'] = ce + metrics['ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'ce_q{k + 1}'] = ce_q + metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) + + return metrics + + @torch.no_grad() + def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + gen_duration: float, prompt_duration: tp.Optional[float] = None, + remove_prompt: bool = False, + **generation_params) -> dict: + """Run generate step on a batch of optional audio tensor and corresponding attributes. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): + use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch. + gen_duration (float): Target audio duration for the generation. + prompt_duration (float, optional): Duration for the audio prompt to use for continuation. + remove_prompt (bool, optional): Whether to remove the prompt from the generated audio. + generation_params: Additional generation parameters. + Returns: + gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation + and the prompt along with additional information. + """ + bench_start = time.time() + audio, meta = batch + assert audio.size(0) == len(meta), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(meta)})" + ) + # prepare attributes + attributes = [x.to_condition_attributes() for x in meta] + # TODO: Add dropout for chroma? + + # prepare audio prompt + if prompt_duration is None: + prompt_audio = None + else: + assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration" + prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate) + prompt_audio = audio[..., :prompt_audio_frames] + + # get audio tokens from compression model + if prompt_audio is None or prompt_audio.nelement() == 0: + num_samples = len(attributes) + prompt_tokens = None + else: + num_samples = None + prompt_audio = prompt_audio.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt_audio) + assert scale is None, "Compression model in MusicGen should not require rescaling." + + # generate by sampling from the LM + with self.autocast: + total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate) + gen_tokens = self.model.generate( + prompt_tokens, attributes, max_gen_len=total_gen_len, + num_samples=num_samples, **self.generation_params) + + # generate audio from tokens + assert gen_tokens.dim() == 3 + gen_audio = self.compression_model.decode(gen_tokens, None) + + bench_end = time.time() + gen_outputs = { + 'rtf': (bench_end - bench_start) / gen_duration, + 'ref_audio': audio, + 'gen_audio': gen_audio, + 'gen_tokens': gen_tokens, + 'prompt_audio': prompt_audio, + 'prompt_tokens': prompt_tokens, + } + return gen_outputs + + def generate_audio(self) -> dict: + """Audio generation stage.""" + generate_stage_name = f'{self.current_stage}' + sample_manager = SampleManager(self.xp) + self.logger.info(f"Generating samples in {sample_manager.base_folder}") + loader = self.dataloaders['generate'] + updates = len(loader) + lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) + + dataset = get_dataset_from_loader(loader) + dataset_duration = dataset.segment_duration + assert dataset_duration is not None + assert isinstance(dataset, AudioDataset) + target_duration = self.cfg.generate.lm.gen_duration + prompt_duration = self.cfg.generate.lm.prompt_duration + if target_duration is None: + target_duration = dataset_duration + if prompt_duration is None: + prompt_duration = dataset_duration / 4 + assert prompt_duration < dataset_duration, ( + f"Specified prompt duration ({prompt_duration}s) is longer", + f" than reference audio duration ({dataset_duration}s)" + ) + + def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]): + hydrated_conditions = [] + for sample in [x.to_condition_attributes() for x in meta]: + cond_dict = {} + for cond_type in sample.__annotations__.keys(): + for cond_key, cond_val in getattr(sample, cond_type).items(): + if cond_key not in self.model.condition_provider.conditioners.keys(): + continue + if is_jsonable(cond_val): + cond_dict[cond_key] = cond_val + elif isinstance(cond_val, WavCondition): + cond_dict[cond_key] = cond_val.path + elif isinstance(cond_val, JointEmbedCondition): + cond_dict[cond_key] = cond_val.text # only support text at inference for now + else: + # if we reached this point, it is not clear how to log the condition + # so we just log the type. + cond_dict[cond_key] = str(type(cond_val)) + continue + hydrated_conditions.append(cond_dict) + return hydrated_conditions + + metrics: dict = {} + average = flashy.averager() + for batch in lp: + audio, meta = batch + # metadata for sample manager + hydrated_conditions = get_hydrated_conditions(meta) + sample_generation_params = { + **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()}, + **self.generation_params + } + if self.cfg.generate.lm.unprompted_samples: + if self.cfg.generate.lm.gen_gt_samples: + # get the ground truth instead of generation + self.logger.warn( + "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true") + gen_unprompted_audio = audio + rtf = 1. + else: + gen_unprompted_outputs = self.run_generate_step( + batch, gen_duration=target_duration, prompt_duration=None, + **self.generation_params) + gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu() + rtf = gen_unprompted_outputs['rtf'] + sample_manager.add_samples( + gen_unprompted_audio, self.epoch, hydrated_conditions, + ground_truth_wavs=audio, generation_args=sample_generation_params) + + if self.cfg.generate.lm.prompted_samples: + gen_outputs = self.run_generate_step( + batch, gen_duration=target_duration, prompt_duration=prompt_duration, + **self.generation_params) + gen_audio = gen_outputs['gen_audio'].cpu() + prompt_audio = gen_outputs['prompt_audio'].cpu() + sample_manager.add_samples( + gen_audio, self.epoch, hydrated_conditions, + prompt_wavs=prompt_audio, ground_truth_wavs=audio, + generation_args=sample_generation_params) + + metrics['rtf'] = rtf + metrics = average(metrics) + + flashy.distrib.barrier() + return metrics + + def generate(self) -> dict: + """Generate stage.""" + self.model.eval() + with torch.no_grad(): + return self.generate_audio() + + def run_epoch(self): + if self.cfg.cache.write: + if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard: + return + super().run_epoch() + + def train(self): + """Train stage. + """ + if self._cached_batch_writer is not None: + self._cached_batch_writer.start_epoch(self.epoch) + if self._cached_batch_loader is None: + dataset = get_dataset_from_loader(self.dataloaders['train']) + assert isinstance(dataset, AudioDataset) + dataset.current_epoch = self.epoch + else: + self._cached_batch_loader.start_epoch(self.epoch) + return super().train() + + def evaluate_audio_generation(self) -> dict: + """Evaluate audio generation with off-the-shelf metrics.""" + evaluate_stage_name = f'{self.current_stage}_generation' + # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation + fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None + kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None + text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None + chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None + should_run_eval = False + eval_chroma_wavs: tp.Optional[torch.Tensor] = None + if self.cfg.evaluate.metrics.fad: + fad = builders.get_fad(self.cfg.metrics.fad).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.kld: + kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.text_consistency: + text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device) + should_run_eval = True + if self.cfg.evaluate.metrics.chroma_cosine: + chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device) + # if we have predefind wavs for chroma we should purge them for computing the cosine metric + has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \ + self.model.condition_provider.conditioners['self_wav'].has_eval_wavs() + if has_predefined_eval_chromas: + warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! " + 'Resetting eval chromas to None for evaluation.') + eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore + self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore + should_run_eval = True + + def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: + audio_tokens, scale = self.compression_model.encode(audio.to(self.device)) + compressed_audio = self.compression_model.decode(audio_tokens, scale) + return compressed_audio[..., :audio.shape[-1]] + + metrics: dict = {} + if should_run_eval: + loader = self.dataloaders['evaluate'] + updates = len(loader) + lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) + average = flashy.averager() + dataset = get_dataset_from_loader(loader) + assert isinstance(dataset, AudioDataset) + self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples") + + for idx, batch in enumerate(lp): + audio, meta = batch + assert all([self.cfg.sample_rate == m.sample_rate for m in meta]) + + target_duration = audio.shape[-1] / self.cfg.sample_rate + if self.cfg.evaluate.fixed_generation_duration: + target_duration = self.cfg.evaluate.fixed_generation_duration + + gen_outputs = self.run_generate_step( + batch, gen_duration=target_duration, + **self.generation_params + ) + y_pred = gen_outputs['gen_audio'].detach() + y_pred = y_pred[..., :audio.shape[-1]] + + normalize_kwargs = dict(self.cfg.generate.audio) + normalize_kwargs.pop('format', None) + y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu() + y = audio.cpu() # should already be on CPU but just in case + sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding + sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples + audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta] + + if fad is not None: + if self.cfg.metrics.fad.use_gt: + y_pred = get_compressed_audio(y).cpu() + fad.update(y_pred, y, sizes, sample_rates, audio_stems) + if kldiv is not None: + if self.cfg.metrics.kld.use_gt: + y_pred = get_compressed_audio(y).cpu() + kldiv.update(y_pred, y, sizes, sample_rates) + if text_consistency is not None: + texts = [m.description for m in meta] + if self.cfg.metrics.text_consistency.use_gt: + y_pred = y + text_consistency.update(y_pred, texts, sizes, sample_rates) + if chroma_cosine is not None: + if self.cfg.metrics.chroma_cosine.use_gt: + y_pred = get_compressed_audio(y).cpu() + chroma_cosine.update(y_pred, y, sizes, sample_rates) + # restore chroma conditioner's eval chroma wavs + if eval_chroma_wavs is not None: + self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs) + + flashy.distrib.barrier() + if fad is not None: + metrics['fad'] = fad.compute() + if kldiv is not None: + kld_metrics = kldiv.compute() + metrics.update(kld_metrics) + if text_consistency is not None: + metrics['text_consistency'] = text_consistency.compute() + if chroma_cosine is not None: + metrics['chroma_cosine'] = chroma_cosine.compute() + metrics = average(metrics) + metrics = flashy.distrib.average_metrics(metrics, len(loader)) + + return metrics + + def evaluate(self) -> dict: + """Evaluate stage.""" + self.model.eval() + with torch.no_grad(): + metrics: dict = {} + if self.cfg.evaluate.metrics.base: + metrics.update(self.common_train_valid('evaluate')) + gen_metrics = self.evaluate_audio_generation() + return {**metrics, **gen_metrics} diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md index 9a6b1e74..92febb1b 100644 --- a/docs/MUSICGEN.md +++ b/docs/MUSICGEN.md @@ -173,7 +173,7 @@ Please find some example grids to train MusicGen at # text-to-music dora grid musicgen.musicgen_base_32khz --dry_run --init # melody-guided music generation -dora grid musicgen.musicgen_melody_base_32khz --dry_run --init +dora grid musicgen.musicgen_melody_32khz --dry_run --init # Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup. ```