Skip to content

Commit

Permalink
Merge pull request #9 from ramppdev/fix/global-warnings-suppression
Browse files (browse the repository at this point in the history)
Fix Global Warnings Suppression
  • Loading branch information
KentoNishi authored Sep 25, 2024
2 parents 04cf7af + 81d529b commit a9e5a91
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
8 changes: 6 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import torch
import numpy as np
import torch
from scipy.io import wavfile

from torch_pitch_shift import *

# read an audio file
SAMPLE_RATE, sample = wavfile.read("./wavs/test.wav")

# convert to tensor of shape (batch_size, channels, samples)
dtype = sample.dtype

sample = torch.tensor(
[np.swapaxes(sample, 0, 1)], # (samples, channels) --> (channels, samples)
np.expand_dims(
np.swapaxes(sample, 0, 1), 0
), # (samples, channels) --> (channels, samples)
dtype=torch.float32,
device="cuda" if torch.cuda.is_available() else "cpu",
)
Expand Down
14 changes: 10 additions & 4 deletions torch_pitch_shift/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from collections import Counter
from fractions import Fraction
from functools import reduce
Expand All @@ -13,7 +12,6 @@
from primePy import primes
from torch.nn.functional import pad

warnings.simplefilter("ignore")

# https://stackoverflow.com/a/46623112/9325832
def _combinations_without_repetition(r, iterable=None, values=None, counts=None):
Expand Down Expand Up @@ -116,6 +114,7 @@ def pitch_shift(
bins_per_octave: Optional[int] = 12,
n_fft: Optional[int] = 0,
hop_length: Optional[int] = 0,
window: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Shift the pitch of a batch of waveforms by a given amount.
Expand All @@ -135,6 +134,8 @@ def pitch_shift(
Size of FFT. Default is `sample_rate // 64`.
hop_length: int [optional]
Size of hop length. Default is `n_fft // 32`.
window: torch.Tensor [optional]
Window tensor of length `n_fft`, used for both the STFT and the inverse STFT. Default is a tensor of ones (`torch.ones(n_fft)`).
Returns
-------
Expand All @@ -146,19 +147,24 @@ def pitch_shift(
n_fft = sample_rate // 64
if not hop_length:
hop_length = n_fft // 32
if window is None:
window = torch.ones(n_fft)
window = window.to(input.device)
batch_size, channels, samples = input.shape
if not isinstance(shift, Fraction):
shift = 2.0 ** (float(shift) / bins_per_octave)
resampler = T.Resample(sample_rate, int(sample_rate / shift)).to(input.device)
output = input
output = output.reshape(batch_size * channels, samples)
v011 = version.parse(torchaudio.__version__) >= version.parse("0.11.0")
output = torch.stft(output, n_fft, hop_length, return_complex=v011)[None, ...]
output = torch.stft(output, n_fft, hop_length, return_complex=v011, window=window)[
None, ...
]
stretcher = T.TimeStretch(
fixed_rate=float(1 / shift), n_freq=output.shape[2], hop_length=hop_length
).to(input.device)
output = stretcher(output)
output = torch.istft(output[0], n_fft, hop_length)
output = torch.istft(output[0], n_fft, hop_length, window=window)
output = resampler(output)
del resampler, stretcher
if output.shape[1] >= input.shape[2]:
Expand Down

0 comments on commit a9e5a91

Please sign in to comment.