docs: fix typos, formatting
0xlws committed Oct 13, 2023
1 parent 9af6690 commit 92b9916
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions audiocraft/models/multibanddiffusion.py
@@ -30,8 +30,6 @@ class DiffusionProcess:
         noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
     """
     def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
-        """
-        """
         self.model = model
         self.schedule = noise_schedule
 
@@ -40,8 +38,8 @@ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
         """Perform one diffusion process to generate one of the bands.
 
         Args:
-            condition (tensor): The embeddings form the compression model.
-            initial_noise (tensor): The initial noise to start the process/
+            condition (torch.Tensor): The embeddings from the compression model.
+            initial_noise (torch.Tensor): The initial noise to start the process.
         """
         return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
                                                  condition=condition)
@@ -87,7 +85,6 @@ def get_mbd_24khz(bw: float = 3.0,
         Args:
             bw (float): Bandwidth of the compression model.
-            pretrained (bool): Whether to use / download if necessary the models.
             device (torch.device or str, optional): Device on which the models are loaded.
             n_q (int, optional): Number of quantizers to use within the compression model.
         """
@@ -114,10 +111,10 @@ def get_mbd_24khz(bw: float = 3.0,
 
     @torch.no_grad()
     def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
-        """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
+        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
         Args:
-            wav (torch.Tensor): The audio that we want to extract the conditioning from
-            sample_rate (int): sample rate of the audio"""
+            wav (torch.Tensor): The audio that we want to extract the conditioning from.
+            sample_rate (int): Sample rate of the audio."""
         if sample_rate != self.sample_rate:
             wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
         codes, scale = self.codec_model.encode(wav)
@@ -127,20 +124,20 @@ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
 
     @torch.no_grad()
     def get_emb(self, codes: torch.Tensor):
-        """Get latent representation from the discrete codes
-        Argrs:
-            codes (torch.Tensor): discrete tokens"""
+        """Get latent representation from the discrete codes.
+        Args:
+            codes (torch.Tensor): Discrete tokens."""
         emb = self.codec_model.decode_latent(codes)
         return emb
 
     def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                  step_list: tp.Optional[tp.List[int]] = None):
-        """Generate Wavform audio from the latent embeddings of the compression model
+        """Generate waveform audio from the latent embeddings of the compression model.
         Args:
-            emb (torch.Tensor): Conditioning embeddinds
-            size (none torch.Size): size of the output
-                if None this is computed from the typical upsampling of the model
-            step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
+            emb (torch.Tensor): Conditioning embeddings
+            size (None, torch.Size): Size of the output
+                if None this is computed from the typical upsampling of the model.
+            step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
         """
         if size is None:
             upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
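
Taken together, get_condition and generate form the round trip these docstrings describe: encode a waveform to codec latents, then decode them back with diffusion. A rough sketch, continuing from the mbd instance in the sketch above (the file path is a placeholder):

import julius
from audiocraft.data.audio import audio_read

wav, sr = audio_read("input.wav")                          # [channels, samples]
wav = wav.unsqueeze(0)                                     # add a batch dim -> [B, C, T]
wav = julius.resample_frac(wav, sr, mbd.sample_rate)       # match the model's 24 kHz rate
emb = mbd.get_condition(wav, sample_rate=mbd.sample_rate)  # codec latents used as conditioning
out = mbd.generate(emb, size=wav.size())                   # diffusion-decoded waveform, same shape
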
@@ -152,12 +149,12 @@ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
         return out
 
     def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
-        """match the eq to the encodec output by matching the standard deviation of some frequency bands
+        """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
         Args:
-            wav (torch.Tensor): audio to equalize
-            ref (torch.Tensor):refenrence audio from which we match the spectrogram.
-            n_bands (int): number of bands of the eq
-            strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
+            wav (torch.Tensor): Audio to equalize.
+            ref (torch.Tensor): Reference audio from which we match the spectrogram.
+            n_bands (int): Number of bands of the eq.
+            strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
         """
         split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
         bands = split(wav)
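
The equalization described here reduces to a per-band gain: each frequency band of the generated audio is scaled toward the standard deviation of the corresponding band of the reference, with strictness interpolating between no matching (0) and exact matching (1). A standalone sketch of that idea (hypothetical helper, not the method's actual body):

import julius
import torch

def match_band_levels(wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32,
                      strictness: float = 1.0, sample_rate: int = 24000) -> torch.Tensor:
    # Split both signals into frequency bands and rescale each band of `wav`
    # toward the standard deviation of the matching band of `ref`.
    split = julius.SplitBands(n_bands=n_bands, sample_rate=sample_rate).to(wav.device)
    bands, bands_ref = split(wav), split(ref)
    out = torch.zeros_like(wav)
    for band, band_ref in zip(bands, bands_ref):
        gain = (band_ref.std() / (band.std() + 1e-8)) ** strictness
        out += band * gain
    return out
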
@@ -168,10 +165,10 @@ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictn
         return out
 
     def regenerate(self, wav: torch.Tensor, sample_rate: int):
-        """Regenerate a wavform through compression and diffusion regeneration.
+        """Regenerate a waveform through compression and diffusion regeneration.
         Args:
-            wav (torch.Tensor): Original 'ground truth' audio
-            sample_rate (int): sample rate of the input (and output) wav
+            wav (torch.Tensor): Original 'ground truth' audio.
+            sample_rate (int): Sample rate of the input (and output) wav.
         """
         if sample_rate != self.codec_model.sample_rate:
             wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
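
regenerate bundles resampling, conditioning extraction, and diffusion decoding into one call; per the corrected docstring, the output keeps the input sample rate. A usage sketch reusing the mbd instance from the earlier sketch (paths and the loudness strategy are illustrative):

from audiocraft.data.audio import audio_read, audio_write

wav, sr = audio_read("input.wav")
regenerated = mbd.regenerate(wav.unsqueeze(0), sample_rate=sr)  # output at the input sample rate
audio_write("regenerated", regenerated.squeeze(0).cpu(), sr, strategy="loudness")
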
@@ -185,8 +182,8 @@ def regenerate(self, wav: torch.Tensor, sample_rate: int):
     def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
         """Generate Waveform audio with diffusion from the discrete codes.
         Args:
-            tokens (torch.Tensor): discrete codes
-            n_bands (int): bands for the eq matching.
+            tokens (torch.Tensor): Discrete codes.
+            n_bands (int): Bands for the eq matching.
         """
         wav_encodec = self.codec_model.decode(tokens)
         condition = self.get_emb(tokens)
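
tokens_to_wav is the piece used to pair MultiBandDiffusion with token-based generators such as MusicGen: the same discrete codes are decoded once by EnCodec and once by diffusion, with the EQ matching above applied against the EnCodec output. A sketch of that pairing (model name and prompt are illustrative; the return_tokens flag is assumed from the MusicGen API):

from audiocraft.models import MusicGen, MultiBandDiffusion

musicgen = MusicGen.get_pretrained("facebook/musicgen-small")
mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)

# Generate tokens with MusicGen, then decode the same tokens with diffusion.
wav_encodec, tokens = musicgen.generate(["lo-fi drum loop"], return_tokens=True)
wav_diffusion = mbd.tokens_to_wav(tokens)
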