From 92b9916aaf1ce0d888d2bb2c080d7ced74d9491f Mon Sep 17 00:00:00 2001 From: 0xlws Date: Fri, 13 Oct 2023 11:57:27 +0200 Subject: [PATCH] docs: fix typos, formatting --- audiocraft/models/multibanddiffusion.py | 49 ++++++++++++------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/audiocraft/models/multibanddiffusion.py b/audiocraft/models/multibanddiffusion.py index c68a4811..451b5862 100644 --- a/audiocraft/models/multibanddiffusion.py +++ b/audiocraft/models/multibanddiffusion.py @@ -30,8 +30,6 @@ class DiffusionProcess: noise_schedule (NoiseSchedule): Noise schedule for diffusion process. """ def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None: - """ - """ self.model = model self.schedule = noise_schedule @@ -40,8 +38,8 @@ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, """Perform one diffusion process to generate one of the bands. Args: - condition (tensor): The embeddings form the compression model. - initial_noise (tensor): The initial noise to start the process/ + condition (torch.Tensor): The embeddings from the compression model. + initial_noise (torch.Tensor): The initial noise to start the process. """ return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list, condition=condition) @@ -87,7 +85,6 @@ def get_mbd_24khz(bw: float = 3.0, Args: bw (float): Bandwidth of the compression model. - pretrained (bool): Whether to use / download if necessary the models. device (torch.device or str, optional): Device on which the models are loaded. n_q (int, optional): Number of quantizers to use within the compression model. """ @@ -114,10 +111,10 @@ def get_mbd_24khz(bw: float = 3.0, @torch.no_grad() def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: - """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform. + """Get the conditioning (i.e. latent representations of the compression model) from a waveform. Args: - wav (torch.Tensor): The audio that we want to extract the conditioning from - sample_rate (int): sample rate of the audio""" + wav (torch.Tensor): The audio that we want to extract the conditioning from. + sample_rate (int): Sample rate of the audio.""" if sample_rate != self.sample_rate: wav = julius.resample_frac(wav, sample_rate, self.sample_rate) codes, scale = self.codec_model.encode(wav) @@ -127,20 +124,20 @@ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: @torch.no_grad() def get_emb(self, codes: torch.Tensor): - """Get latent representation from the discrete codes - Argrs: - codes (torch.Tensor): discrete tokens""" + """Get latent representation from the discrete codes. + Args: + codes (torch.Tensor): Discrete tokens.""" emb = self.codec_model.decode_latent(codes) return emb def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, step_list: tp.Optional[tp.List[int]] = None): - """Generate Wavform audio from the latent embeddings of the compression model + """Generate waveform audio from the latent embeddings of the compression model. Args: - emb (torch.Tensor): Conditioning embeddinds - size (none torch.Size): size of the output - if None this is computed from the typical upsampling of the model - step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step. + emb (torch.Tensor): Conditioning embeddings + size (None, torch.Size): Size of the output + if None this is computed from the typical upsampling of the model. + step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step. """ if size is None: upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate) @@ -152,12 +149,12 @@ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, return out def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1): - """match the eq to the encodec output by matching the standard deviation of some frequency bands + """Match the eq to the encodec output by matching the standard deviation of some frequency bands. Args: - wav (torch.Tensor): audio to equalize - ref (torch.Tensor):refenrence audio from which we match the spectrogram. - n_bands (int): number of bands of the eq - strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching. + wav (torch.Tensor): Audio to equalize. + ref (torch.Tensor): Reference audio from which we match the spectrogram. + n_bands (int): Number of bands of the eq. + strictness (float): How strict the matching. 0 is no matching, 1 is exact matching. """ split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device) bands = split(wav) @@ -168,10 +165,10 @@ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictn return out def regenerate(self, wav: torch.Tensor, sample_rate: int): - """Regenerate a wavform through compression and diffusion regeneration. + """Regenerate a waveform through compression and diffusion regeneration. Args: - wav (torch.Tensor): Original 'ground truth' audio - sample_rate (int): sample rate of the input (and output) wav + wav (torch.Tensor): Original 'ground truth' audio. + sample_rate (int): Sample rate of the input (and output) wav. """ if sample_rate != self.codec_model.sample_rate: wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate) @@ -185,8 +182,8 @@ def regenerate(self, wav: torch.Tensor, sample_rate: int): def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): """Generate Waveform audio with diffusion from the discrete codes. Args: - tokens (torch.Tensor): discrete codes - n_bands (int): bands for the eq matching. + tokens (torch.Tensor): Discrete codes. + n_bands (int): Bands for the eq matching. """ wav_encodec = self.codec_model.decode(tokens) condition = self.get_emb(tokens)