diff --git a/README.md b/README.md index 7a0e694..9c16b89 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,9 @@ center = 'half-hop' # (Optional) Linearly interpolate unvoiced regions below periodicity threshold interp_unvoiced_at = .065 +# (Optional) Select a decoding method. One of ['argmax', 'pyin', 'viterbi']. +decoder = 'viterbi' + # Infer pitch and periodicity pitch, periodicity = penn.from_audio( audio, @@ -85,6 +88,7 @@ pitch, periodicity = penn.from_audio( checkpoint=checkpoint, batch_size=batch_size, center=center, + decoder=decoder, interp_unvoiced_at=interp_unvoiced_at, gpu=gpu) ``` @@ -96,16 +100,17 @@ pitch, periodicity = penn.from_audio( ``` def from_audio( - audio: torch.Tensor, - sample_rate: int = penn.SAMPLE_RATE, - hopsize: float = penn.HOPSIZE_SECONDS, - fmin: float = penn.FMIN, - fmax: float = penn.FMAX, - checkpoint: Optional[Path] = None, - batch_size: Optional[int] = None, - center: str = 'half-window', - interp_unvoiced_at: Optional[float] = None, - gpu: Optional[int] = None + audio: torch.Tensor, + sample_rate: int = penn.SAMPLE_RATE, + hopsize: float = penn.HOPSIZE_SECONDS, + fmin: float = penn.FMIN, + fmax: float = penn.FMAX, + checkpoint: Optional[Path] = None, + batch_size: Optional[int] = None, + center: str = 'half-window', + decoder: str = penn.DECODER, + interp_unvoiced_at: Optional[float] = None, + gpu: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform pitch and periodicity estimation @@ -134,15 +139,16 @@ Returns: ``` def from_file( - file: Path, - hopsize: float = penn.HOPSIZE_SECONDS, - fmin: float = penn.FMIN, - fmax: float = penn.FMAX, - checkpoint: Optional[Path] = None, - batch_size: Optional[int] = None, - center: str = 'half-window', - interp_unvoiced_at: Optional[float] = None, - gpu: Optional[int] = None + file: Path, + hopsize: float = penn.HOPSIZE_SECONDS, + fmin: float = penn.FMIN, + fmax: float = penn.FMAX, + checkpoint: Optional[Path] = None, + batch_size: Optional[int] = None, + center: str = 'half-window', + decoder: str = penn.DECODER, + interp_unvoiced_at: Optional[float] = None, + gpu: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform pitch and periodicity estimation from audio on disk @@ -168,16 +174,17 @@ Returns: ``` def from_file_to_file( - file: Path, - output_prefix: Optional[Path] = None, - hopsize: float = penn.HOPSIZE_SECONDS, - fmin: float = penn.FMIN, - fmax: float = penn.FMAX, - checkpoint: Optional[Path] = None, - batch_size: Optional[int] = None, - center: str = 'half-window', - interp_unvoiced_at: Optional[float] = None, - gpu: Optional[int] = None + file: Path, + output_prefix: Optional[Path] = None, + hopsize: float = penn.HOPSIZE_SECONDS, + fmin: float = penn.FMIN, + fmax: float = penn.FMAX, + checkpoint: Optional[Path] = None, + batch_size: Optional[int] = None, + center: str = 'half-window', + decoder: str = penn.DECODER, + interp_unvoiced_at: Optional[float] = None, + gpu: Optional[int] = None ) -> None: """Perform pitch and periodicity estimation from audio on disk and save @@ -208,6 +215,7 @@ def from_files_to_files( checkpoint: Optional[Path] = None, batch_size: Optional[int] = None, center: str = 'half-window', + decoder: str = penn.DECODER, interp_unvoiced_at: Optional[float] = None, num_workers: int = penn.NUM_WORKERS, gpu: Optional[int] = None @@ -244,7 +252,9 @@ python -m penn [--checkpoint CHECKPOINT] [--batch_size BATCH_SIZE] [--center {half-window,half-hop,zero}] + [--decoder {argmax,pyin,viterbi}] [--interp_unvoiced_at INTERP_UNVOICED_AT] + [--num_workers NUM_WORKERS] [--gpu GPU] required arguments: @@ -271,8 +281,12 @@ optional arguments: The number of frames per batch. Defaults to 2048. --center {half-window,half-hop,zero} Padding options - --interp_unvoiced_at INTERP_UNVOICED_AT + --decoder {argmax,pyin,viterbi} + Posteriorgram decoder + --interp_unvoiced_at INTERP_UNVOICED_AT Specifies voicing threshold for interpolation. Defaults to 0.1625. + --num_workers + Number of CPU threads for async data I/O --gpu GPU The index of the gpu to perform inference on. Defaults to CPU. ``` diff --git a/config/crepe++.py b/config/crepe++.py index 79efec6..b162399 100644 --- a/config/crepe++.py +++ b/config/crepe++.py @@ -6,5 +6,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # The name of the model to use for training MODEL = 'crepe' diff --git a/config/deepf0++.py b/config/deepf0++.py index aede086..ebb6d1a 100644 --- a/config/deepf0++.py +++ b/config/deepf0++.py @@ -6,5 +6,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # The name of the model to use for training MODEL = 'deepf0' diff --git a/config/fcnf0++-ablate-batchsize.py b/config/fcnf0++-ablate-batchsize.py index a003d38..507cb86 100644 --- a/config/fcnf0++-ablate-batchsize.py +++ b/config/fcnf0++-ablate-batchsize.py @@ -8,3 +8,6 @@ # The decoder to use for postprocessing DECODER = 'argmax' + +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False diff --git a/config/fcnf0++-ablate-chunkviterbi-normal.py b/config/fcnf0++-ablate-chunkviterbi-normal.py new file mode 100644 index 0000000..87e9a95 --- /dev/null +++ b/config/fcnf0++-ablate-chunkviterbi-normal.py @@ -0,0 +1,10 @@ +MODULE = 'penn' + +# Configuration name +CONFIG = 'fcnf0++-ablate-chunkviterbi-normal' + +# The decoder to use for postprocessing +DECODER = 'viterbi' + +# Maximum chunk size for chunked Viterbi decoding +VITERBI_MIN_CHUNK_SIZE = 64 diff --git a/config/fcnf0++-ablate-chunkviterbi.py b/config/fcnf0++-ablate-chunkviterbi.py new file mode 100644 index 0000000..e563a5e --- /dev/null +++ b/config/fcnf0++-ablate-chunkviterbi.py @@ -0,0 +1,13 @@ +MODULE = 'penn' + +# Configuration name +CONFIG = 'fcnf0++-ablate-chunkviterbi' + +# The decoder to use for postprocessing +DECODER = 'viterbi' + +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + +# Maximum chunk size for chunked Viterbi decoding +VITERBI_MIN_CHUNK_SIZE = 8 diff --git a/config/fcnf0++-ablate-decoder.py b/config/fcnf0++-ablate-decoder.py index fc2950f..8e2920a 100644 --- a/config/fcnf0++-ablate-decoder.py +++ b/config/fcnf0++-ablate-decoder.py @@ -3,5 +3,8 @@ # Configuration name CONFIG = 'fcnf0++-ablate-decoder' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # The decoder to use for postprocessing DECODER = 'argmax' diff --git a/config/fcnf0++-ablate-earlystop.py b/config/fcnf0++-ablate-earlystop.py index 34ad2f8..14e67be 100644 --- a/config/fcnf0++-ablate-earlystop.py +++ b/config/fcnf0++-ablate-earlystop.py @@ -9,5 +9,8 @@ # Whether to stop training when validation loss stops improving EARLY_STOPPING = True +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Number of steps between logging to Tensorboard LOG_INTERVAL = 500 # steps diff --git a/config/fcnf0++-ablate-inputnorm.py b/config/fcnf0++-ablate-inputnorm.py index 135cfba..c97e8be 100644 --- a/config/fcnf0++-ablate-inputnorm.py +++ b/config/fcnf0++-ablate-inputnorm.py @@ -6,5 +6,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Whether to normalize input audio to mean zero and variance one NORMALIZE_INPUT = True diff --git a/config/fcnf0++-ablate-layernorm.py b/config/fcnf0++-ablate-layernorm.py index 25c27f4..493ccd3 100644 --- a/config/fcnf0++-ablate-layernorm.py +++ b/config/fcnf0++-ablate-layernorm.py @@ -6,5 +6,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Type of model normalization NORMALIZATION = 'batch' diff --git a/config/fcnf0++-ablate-loss.py b/config/fcnf0++-ablate-loss.py index 9375956..679437e 100644 --- a/config/fcnf0++-ablate-loss.py +++ b/config/fcnf0++-ablate-loss.py @@ -6,5 +6,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Loss function LOSS = 'binary_cross_entropy' diff --git a/config/fcnf0++-ablate-quantization.py b/config/fcnf0++-ablate-quantization.py index 82fd558..13c25c2 100644 --- a/config/fcnf0++-ablate-quantization.py +++ b/config/fcnf0++-ablate-quantization.py @@ -9,5 +9,8 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Number of pitch bins to predict PITCH_BINS = 486 diff --git a/config/fcnf0++-ablate-unvoiced.py b/config/fcnf0++-ablate-unvoiced.py index e372a8d..563eec0 100644 --- a/config/fcnf0++-ablate-unvoiced.py +++ b/config/fcnf0++-ablate-unvoiced.py @@ -6,6 +6,9 @@ # The decoder to use for postprocessing DECODER = 'argmax' +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Whether to only use voiced start frames VOICED_ONLY = True diff --git a/config/fcnf0++-mdb.py b/config/fcnf0++-mdb.py index e54f54c..40676b2 100644 --- a/config/fcnf0++-mdb.py +++ b/config/fcnf0++-mdb.py @@ -5,3 +5,6 @@ # The decoder to use for postprocessing DECODER = 'argmax' + +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False diff --git a/config/fcnf0++-ptdb.py b/config/fcnf0++-ptdb.py index 0431c26..0882ed6 100644 --- a/config/fcnf0++-ptdb.py +++ b/config/fcnf0++-ptdb.py @@ -5,3 +5,6 @@ # The decoder to use for postprocessing DECODER = 'argmax' + +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False diff --git a/config/fcnf0.py b/config/fcnf0.py index 8954cf6..5f75d97 100644 --- a/config/fcnf0.py +++ b/config/fcnf0.py @@ -18,6 +18,9 @@ # Minimum representable frequency FMIN = 30. # Hz +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Number of steps between logging to Tensorboard LOG_INTERVAL = 500 # steps diff --git a/config/pyin.py b/config/pyin.py index 5429fff..827edbc 100644 --- a/config/pyin.py +++ b/config/pyin.py @@ -4,7 +4,7 @@ CONFIG = 'pyin' # The decoder to use for postprocessing -DECODER = 'argmax' +DECODER = 'pyin' # Distance between adjacent frames HOPSIZE = 160 # samples diff --git a/config/torchcrepe.py b/config/torchcrepe.py index 8b743a8..fe824b4 100644 --- a/config/torchcrepe.py +++ b/config/torchcrepe.py @@ -32,6 +32,9 @@ # Distance between adjacent frames HOPSIZE = 160 # samples +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = False + # Number of steps between logging to Tensorboard LOG_INTERVAL = 500 # steps diff --git a/penn/__main__.py b/penn/__main__.py index b287f54..1ddf30e 100644 --- a/penn/__main__.py +++ b/penn/__main__.py @@ -64,10 +64,20 @@ def parse_args(): choices=['half-window', 'half-hop', 'zero'], default='half-window', help='Padding options') + parser.add_argument( + '--decoder', + choices=['argmax', 'pyin', 'viterbi'], + default=penn.DECODER, + help='Posteriorgram decoder') parser.add_argument( '--interp_unvoiced_at', type=float, help='Specifies voicing threshold for interpolation') + parser.add_argument( + '--num_workers', + type=int, + default=0, + help='Number of CPU threads for async data I/O') parser.add_argument( '--gpu', type=int, diff --git a/penn/config/defaults.py b/penn/config/defaults.py index eb33231..87c720c 100644 --- a/penn/config/defaults.py +++ b/penn/config/defaults.py @@ -28,12 +28,6 @@ # Distance between adjacent frames HOPSIZE = 80 # samples -# The size of the window used for locally normal pitch decoding -LOCAL_PITCH_WINDOW_SIZE = 19 - -# Pitch velocity constraint for viterbi decoding -MAX_OCTAVES_PER_SECOND = 35.92 - # Whether to normalize input audio to mean zero and variance one NORMALIZE_INPUT = False @@ -53,6 +47,27 @@ WINDOW_SIZE = 1024 # samples +############################################################################### +# Decoder parameters +############################################################################### + + +# The decoder to use for postprocessing. One of ['argmax', 'pyin', 'viterbi']. +DECODER = 'viterbi' + +# Whether to perform local expected value decoding of pitch +LOCAL_EXPECTED_VALUE = True + +# The size of the window used for local expected value pitch decoding +LOCAL_PITCH_WINDOW_SIZE = 19 + +# Pitch velocity constraint for viterbi decoding +MAX_OCTAVES_PER_SECOND = 6. + +# Maximum chunk size for chunked Viterbi decoding +VITERBI_MIN_CHUNK_SIZE = None + + ############################################################################### # Directories ############################################################################### @@ -115,9 +130,6 @@ ############################################################################### -# The decoder to use for postprocessing -DECODER = 'local_expected_value' - # The dropout rate. Set to None to turn off dropout. DROPOUT = None diff --git a/penn/core.py b/penn/core.py index 9c9fffd..6cd9c56 100644 --- a/penn/core.py +++ b/penn/core.py @@ -28,6 +28,7 @@ def from_audio( checkpoint: Optional[Path] = None, batch_size: Optional[int] = None, center: str = 'half-window', + decoder: str = penn.DECODER, interp_unvoiced_at: Optional[float] = None, gpu: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -42,6 +43,7 @@ def from_audio( checkpoint: The checkpoint file batch_size: The number of frames per batch center: Padding options. One of ['half-window', 'half-hop', 'zero']. + decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. interp_unvoiced_at: Specifies voicing threshold for interpolation gpu: The index of the gpu to run inference on @@ -51,7 +53,14 @@ def from_audio( periodicity: torch.tensor( shape=(1, int(samples // penn.seconds_to_sample(hopsize)))) """ - pitch, periodicity = [], [] + device = 'cpu' if gpu is None else f'cuda:{gpu}' + + # Storage for batching + if batch_size is not None: + if decoder == 'argmax': + pitch, periodicity = [], [] + else: + logits = [] # Preprocess audio for frames in preprocess( @@ -64,19 +73,50 @@ def from_audio( # Copy to device with torchutil.time.context('copy-to'): - frames = frames.to('cpu' if gpu is None else f'cuda:{gpu}') + frames = frames.to(device) # Infer - logits = infer(frames, checkpoint).detach() + inferred = infer(frames, checkpoint).detach() + + if batch_size is None: + + # Postprocess full file + with torchutil.time.context('postprocess'): + _, pitch, periodicity = postprocess( + inferred, + fmin, + fmax, + decoder) - # Postprocess - with torchutil.time.context('postprocess'): - result = postprocess(logits, fmin, fmax) - pitch.append(result[1]) - periodicity.append(result[2]) + elif decoder == 'argmax': + + # Postprocess partial file + with torchutil.time.context('postprocess'): + result = postprocess(inferred, fmin, fmax, decoder) + pitch.append(result[1]) + periodicity.append(result[2]) + + else: + + # Save logits off GPU for later decoding + logits.append(inferred.cpu()) + + if batch_size is not None: + + if decoder == 'argmax': + + # Concatenate results + pitch = torch.cat(pitch, 1) + periodicity = torch.cat(periodicity, 1) + + else: - # Concatenate results - pitch, periodicity = torch.cat(pitch, 1), torch.cat(periodicity, 1) + # Postprocess full file + _, pitch, periodicity = postprocess( + torch.cat(logits, 0).to(device), + fmin, + fmax, + decoder) # Maybe interpolate unvoiced regions if interp_unvoiced_at is not None: @@ -96,6 +136,7 @@ def from_file( checkpoint: Optional[Path] = None, batch_size: Optional[int] = None, center: str = 'half-window', + decoder: str = penn.DECODER, interp_unvoiced_at: Optional[float] = None, gpu: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -109,6 +150,7 @@ def from_file( checkpoint: The checkpoint file batch_size: The number of frames per batch center: Padding options. One of ['half-window', 'half-hop', 'zero']. + decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. interp_unvoiced_at: Specifies voicing threshold for interpolation gpu: The index of the gpu to run inference on @@ -130,6 +172,7 @@ def from_file( checkpoint, batch_size, center, + decoder, interp_unvoiced_at, gpu) @@ -143,6 +186,7 @@ def from_file_to_file( checkpoint: Optional[Path] = None, batch_size: Optional[int] = None, center: str = 'half-window', + decoder: str = penn.DECODER, interp_unvoiced_at: Optional[float] = None, gpu: Optional[int] = None ) -> None: @@ -157,6 +201,7 @@ def from_file_to_file( checkpoint: The checkpoint file batch_size: The number of frames per batch center: Padding options. One of ['half-window', 'half-hop', 'zero']. + decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. interp_unvoiced_at: Specifies voicing threshold for interpolation gpu: The index of the gpu to run inference on """ @@ -169,6 +214,7 @@ def from_file_to_file( checkpoint, batch_size, center, + decoder, interp_unvoiced_at, gpu) @@ -197,8 +243,9 @@ def from_files_to_files( checkpoint: Optional[Path] = None, batch_size: Optional[int] = None, center: str = 'half-window', + decoder: str = penn.DECODER, interp_unvoiced_at: Optional[float] = None, - num_workers: int = penn.NUM_WORKERS, + num_workers: int = 0, gpu: Optional[int] = None ) -> None: """Perform pitch and periodicity estimation from files on disk and save @@ -212,6 +259,7 @@ def from_files_to_files( checkpoint: The checkpoint file batch_size: The number of frames per batch center: Padding options. One of ['half-window', 'half-hop', 'zero']. + decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. interp_unvoiced_at: Specifies voicing threshold for interpolation num_workers: Number of CPU threads for async data I/O gpu: The index of the gpu to run inference on @@ -240,6 +288,7 @@ def from_files_to_files( checkpoint, batch_size, center, + decoder, interp_unvoiced_at, gpu) @@ -271,14 +320,22 @@ def from_files_to_files( try: + device = 'cpu' if gpu is None else f'cuda:{gpu}' + # Track residual to fill batch residual_files = [] residual_frames = torch.zeros((0, 1, 1024)) residual_lengths = torch.zeros((0,), dtype=torch.long) - # Iterate over data - pitch, periodicity = torch.zeros((1, 0)), torch.zeros((1, 0)) + # Storage for batching within files + if batch_size is not None: + if decoder == 'argmax': + pitch, periodicity = torch.zeros((1, 0)), torch.zeros((1, 0)) + else: + logits = torch.zeros((1, 0, 0)) + # Iterate over data + num_inferred_unsaved = 0 for frames, lengths, input_files in loader: # Prepend residual @@ -292,40 +349,77 @@ def from_files_to_files( # Copy to device size = len(frames) if batch_size is None else batch_size - batch_frames = frames[i:i + size].to( - 'cpu' if gpu is None else f'cuda:{gpu}') + batch_frames = frames[i:i + size].to(device) # Infer - logits = infer(batch_frames, checkpoint).detach() + inferred = infer(batch_frames, checkpoint).detach() + i += len(batch_frames) + num_inferred_unsaved += len(batch_frames) - # Postprocess - results = postprocess(logits, fmin, fmax) + if batch_size is None: - # Append to residual - pitch = torch.cat((pitch, results[1].cpu()), dim=1) - periodicity = torch.cat( - (periodicity, results[2].cpu()), - dim=1) + # Postprocess full file + _, pitch, periodicity = postprocess( + inferred, + fmin, + fmax, + decoder) + break - i += len(batch_frames) + elif decoder == 'argmax': + + # Postprocess partial file + results = postprocess(inferred, fmin, fmax, decoder) + pitch = torch.cat((pitch, results[1].cpu()), dim=1) + periodicity = torch.cat( + (periodicity, results[2].cpu()), + dim=1) + + else: + + # Save logits for later decoding + # NOTE - This differs from from_audio and does not + # handle large files that do not fit on GPU. + # However, it saves a GPU -> CPU -> GPU copy. + logits = torch.cat((logits, inferred), dim=0) # Save to disk j, k = 0, 0 for length, file in zip(lengths, input_files): # Slice and save in another process - if j + length <= pitch.shape[-1]: + if j + length <= num_inferred_unsaved: + + if batch_size is not None: + + if decoder == 'argmax': + + # Slice results + save_pitch = pitch[:, j:j + length] + save_periodicity = periodicity[:, j:j + length] + + else: + + # Postprocess full file + _, save_pitch, save_periodicity = postprocess( + logits[j:j + length], + fmin, + fmax, + decoder) + + # Async save futures.append( pool.apply_async( save_worker, args=( output_prefixes[file], - pitch[:, j:j + length], - periodicity[:, j:j + length], + save_pitch, + save_periodicity, interp_unvoiced_at))) while len(futures) > 100: futures = [f for f in futures if not f.ready()] time.sleep(.1) + j += length k += 1 progress.update() @@ -333,8 +427,10 @@ def from_files_to_files( break # Setup residual for next iteration + num_inferred_unsaved -= j pitch = pitch[:, j:] periodicity = periodicity[:, j:] + logits = logits[j:] residual_files = input_files[k:] residual_lengths = lengths[k:] residual_frames = frames[i:] @@ -343,32 +439,57 @@ def from_files_to_files( if residual_frames.numel(): # Copy to device - batch_frames = residual_frames.to( - 'cpu' if gpu is None else f'cuda:{gpu}') + batch_frames = residual_frames.to(device) # Infer - logits = infer(batch_frames, checkpoint).detach() + inferred = infer(batch_frames, checkpoint).detach() + num_inferred_unsaved += len(batch_frames) + + if decoder == 'argmax': + + # Postprocess partial file + results = postprocess(inferred, fmin, fmax, decoder) + pitch = torch.cat((pitch, results[1].cpu()), dim=1) + periodicity = torch.cat( + (periodicity, results[2].cpu()), + dim=1) - # Postprocess - results = postprocess(logits, fmin, fmax) + else: - # Append to residual - pitch = torch.cat((pitch, results[1].cpu()), dim=1) - periodicity = torch.cat((periodicity, results[2].cpu()), dim=1) + # Save logits for later decoding + # NOTE - This differs from from_audio and does not + # handle large files that do not fit on GPU. + # However, it saves a GPU -> CPU -> GPU copy. + logits = torch.cat((logits, inferred), dim=0) # Save i = 0 for length, file in zip(residual_lengths, residual_files): + if decoder == 'argmax': + + # Slice results + save_pitch = pitch[:, i:i + length] + save_periodicity = periodicity[:, i:i + length] + + else: + + # Postprocess full file + _, save_pitch, save_periodicity = postprocess( + logits[i:i + length], + fmin, + fmax, + decoder) + # Slice and save in another process - if i + length <= pitch.shape[-1]: + if i + length <= num_inferred_unsaved: futures.append( pool.apply_async( save_worker, args=( output_prefixes[file], - pitch[:, i:i + length], - periodicity[:, i:i + length], + save_pitch, + save_periodicity, interp_unvoiced_at))) while len(futures) > 100: futures = [f for f in futures if not f.ready()] @@ -443,8 +564,23 @@ def infer(frames, checkpoint=None): return logits -def postprocess(logits, fmin=penn.FMIN, fmax=penn.FMAX): +def postprocess(logits, fmin=penn.FMIN, fmax=penn.FMAX, decoder=penn.DECODER): """Convert model output to pitch and periodicity""" + # Cache decoder + if ( + not hasattr(postprocess, 'decoder') or + postprocess.decoder_name != decoder + ): + if decoder == 'argmax': + postprocess.decoder = penn.decode.Argmax() + elif decoder == 'pyin': + postprocess.decoder = penn.decode.PYIN() + elif decoder == 'viterbi': + postprocess.decoder = penn.decode.Viterbi() + else: + raise ValueError(f'Decoder method {decoder} is not defined') + postprocess.decoder_name = decoder + # Turn off gradients with torch.inference_mode(): @@ -459,14 +595,7 @@ def postprocess(logits, fmin=penn.FMIN, fmax=penn.FMAX): logits[:, maxidx:] = -float('inf') # Decode pitch from logits - if penn.DECODER == 'argmax': - bins, pitch = penn.decode.argmax(logits) - elif penn.DECODER.startswith('viterbi'): - bins, pitch = penn.decode.viterbi(logits) - elif penn.DECODER == 'local_expected_value': - bins, pitch = penn.decode.local_expected_value(logits) - else: - raise ValueError(f'Decoder method {penn.DECODER} is not defined') + bins, pitch = postprocess.decoder(logits) # Decode periodicity from logits if penn.PERIODICITY == 'entropy': diff --git a/penn/decode.py b/penn/decode.py index bc787eb..2cd4fb1 100644 --- a/penn/decode.py +++ b/penn/decode.py @@ -1,114 +1,164 @@ +import abc +import functools + import numpy as np +import torbi import torch import penn ############################################################################### -# Decode pitch contour from logits of pitch posteriorgram +# Base pitch posteriorgram decoder +############################################################################### + + +class Decoder(abc.ABC): + """Base decoder""" + + def __init__(self, local_expected_value=True): + self.local_expected_value = local_expected_value + + @abc.abstractmethod + def __call__(self, logits): + """Perform decoding""" + pass + + +############################################################################### +# Derived pitch posteriorgram decoders ############################################################################### -def argmax(logits): +class Argmax(Decoder): """Decode pitch using argmax""" - # Get pitch bins - bins = logits.argmax(dim=1) - # Convert to hz - pitch = penn.convert.bins_to_frequency(bins) + def __init__(self, local_expected_value=penn.LOCAL_EXPECTED_VALUE): + super().__init__(local_expected_value) - return bins, pitch + def __call__(self, logits): + # Get pitch bins + bins = logits.argmax(dim=1) + # Convert to frequency in Hz + if self.local_expected_value: -def viterbi(logits): - """Decode pitch using viterbi decoding (from librosa)""" - import librosa + # Decode using an assumption of normality around the argmax path + pitch = local_expected_value_from_bins(bins, logits) - # Normalize and convert to numpy - if penn.METHOD == 'pyin': + else: + + # Linearly interpolate unvoiced regions + pitch = penn.convert.bins_to_frequency(bins) + + return bins, pitch + + +class PYIN(Decoder): + """Decode pitch via peak picking + Viterbi. Used by PYIN.""" + + def __init__(self, local_expected_value=False): + super().__init__(local_expected_value) + + def __call__(self, logits): + """PYIN decoding""" periodicity = penn.periodicity.sum(logits).T unvoiced = ( (1 - periodicity) / penn.PITCH_BINS).repeat(penn.PITCH_BINS, 1) distributions = torch.cat( (torch.exp(logits.permute(2, 1, 0)), unvoiced[None]), - dim=1).numpy() - else: + dim=1) - # Viterbi REQUIRES a categorical distribution, even if the loss was BCE - distributions = torch.nn.functional.softmax(logits, dim=1) - distributions = distributions.permute(2, 1, 0) - distributions = distributions.to( - device=torch.device('cpu'), - dtype=torch.float32 - ).numpy() - - # Cache viterbi probabilities - if not hasattr(viterbi, 'transition'): - # Get number of bins per frame - bins_per_octave = penn.OCTAVE / penn.CENTS_PER_BIN - max_octaves_per_frame = \ - penn.MAX_OCTAVES_PER_SECOND * penn.HOPSIZE / penn.SAMPLE_RATE - max_bins_per_frame = max_octaves_per_frame * bins_per_octave + 1 - - # Construct the within voicing transition probabilities - viterbi.transition = librosa.sequence.transition_local( - penn.PITCH_BINS, - max_bins_per_frame, - window='triangle', - wrap=False) - - if penn.METHOD == 'pyin': - - # Add unvoiced probabilities - viterbi.transition = np.kron( - librosa.sequence.transition_loop(2, .99), - viterbi.transition) - - # Uniform initial probabilities - viterbi.initial = np.zeros(2 * penn.PITCH_BINS) - viterbi.initial[penn.PITCH_BINS:] = 1 / penn.PITCH_BINS + # Viterbi decoding + gpu = ( + None if distributions.device.type == 'cpu' + else distributions.device.index) + bins = torbi.from_probabilities( + observation=distributions[0].T.unsqueeze(dim=0), + transition=self.transition, + initial=self.initial, + gpu=gpu) + + # Convert to frequency in Hz + if self.local_expected_value: + + # Decode using an assumption of normality around the viterbi path + pitch = local_expected_value_from_bins(bins.T, logits).T else: - # Uniform initial probabilities - viterbi.initial = np.full(penn.PITCH_BINS, 1 / penn.PITCH_BINS) + # Argmax decoding + pitch = penn.convert.bins_to_frequency(bins) - # Viterbi decoding - bins = librosa.sequence.viterbi( - distributions, - viterbi.transition, - p_init=viterbi.initial) - bins = torch.from_numpy(bins.astype(np.int32)) + # Linearly interpolate unvoiced regions + pitch[bins >= penn.PITCH_BINS] = 0 + pitch = penn.data.preprocess.interpolate_unvoiced(pitch.numpy())[0] + pitch = torch.from_numpy(pitch).to(logits.device) + bins = bins.to(logits.device) - # Convert to frequency in Hz - if penn.DECODER.endswith('normal'): + return bins.T, pitch.T - # Decode using an assumption of normality around to the viterbi path - pitch = local_expected_value_from_bins( - bins.T.to(logits.device), - logits).T + @functools.cached_property + def initial(self): + """Create initial probability matrix for PYIN""" + initial = torch.zeros(2 * penn.PITCH_BINS) + initial[penn.PITCH_BINS:] = 1 / penn.PITCH_BINS + return initial - else: + @functools.cached_property + def transition(self): + """Create the Viterbi transition matrix for PYIN""" + transition = triangular_transition_matrix() - # Argmax decoding - pitch = penn.convert.bins_to_frequency(bins) + # Add unvoiced probabilities + transition = torch.kron( + torch.tensor([[.99, .01], [.01, .99]]), + transition) + return transition - if penn.METHOD == 'pyin': - # Linearly interpolate unvoiced regions - pitch[bins >= penn.PITCH_BINS] = 0 - pitch = penn.data.preprocess.interpolate_unvoiced(pitch.numpy())[0] - pitch = torch.from_numpy(pitch) +class Viterbi(Decoder): - return bins.T, pitch.T + def __init__(self, local_expected_value=True): + super().__init__(local_expected_value) + def __call__(self, logits): + """Decode pitch using viterbi decoding (from librosa)""" + distributions = torch.nn.functional.softmax(logits, dim=1) + distributions = distributions.permute(2, 1, 0) # F x C x 1 -> 1 x C x F -def local_expected_value(logits, window=penn.LOCAL_PITCH_WINDOW_SIZE): - """Decode pitch using a normal assumption around the argmax""" - # Get center bins - bins = logits.argmax(dim=1) + # Viterbi decoding + gpu = ( + None if distributions.device.type == 'cpu' + else distributions.device.index) + bins = torbi.from_probabilities( + observation=distributions[0].T.unsqueeze(dim=0), + transition=self.transition, + initial=self.initial, + gpu=gpu) - return bins, local_expected_value_from_bins(bins, logits, window) + # Convert to frequency in Hz + if self.local_expected_value: + + # Decode using an assumption of normality around the viterbi path + pitch = local_expected_value_from_bins(bins.T, logits).T + + else: + + # Argmax decoding + pitch = penn.convert.bins_to_frequency(bins) + + return bins.T, pitch.T + + @functools.cached_property + def initial(self): + """Create uniform initial probabilities""" + return torch.full((penn.PITCH_BINS,), 1 / penn.PITCH_BINS) + + @functools.cached_property + def transition(self): + """Create Viterbi transition probability matrix""" + return triangular_transition_matrix() ############################################################################### @@ -157,3 +207,17 @@ def local_expected_value_from_bins( # Decode using local expected value return expected_value(torch.gather(padded, 1, indices), cents) + + +def triangular_transition_matrix(): + """Create a triangular distribution transition matrix""" + xx, yy = torch.meshgrid( + torch.arange(penn.PITCH_BINS), + torch.arange(penn.PITCH_BINS), + indexing='ij') + bins_per_octave = penn.OCTAVE / penn.CENTS_PER_BIN + max_octaves_per_frame = \ + penn.MAX_OCTAVES_PER_SECOND * penn.HOPSIZE / penn.SAMPLE_RATE + max_bins_per_frame = max_octaves_per_frame * bins_per_octave + 1 + transition = torch.clip(max_bins_per_frame - (xx - yy).abs(), 0) + return transition / transition.sum(dim=1, keepdims=True) diff --git a/setup.py b/setup.py index bdbb025..642ff32 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='penn', description='Pitch Estimating Neural Networks (PENN)', - version='0.0.14', + version='0.1.0', author='Max Morrison, Caedon Hsieh, Nathan Pruyne, and Bryan Pardo', author_email='interactiveaudiolab@gmail.com', url='https://github.com/interactiveaudiolab/penn', @@ -28,6 +28,7 @@ install_requires=[ 'huggingface_hub', # 0.11.1 'numpy', # 1.23.4 + 'torbi', # 0.0.1 'torch', # 1.12.1+cu113 'torchaudio', # 0.12.1+cu113 'torchutil', # 0.0.7