convert-to-coreml (15 changes: 8 additions & 7 deletions)

@@ -26,26 +26,27 @@ def main():
     args = parser.parse_args()

     samplerate = 44100
-    estimator = Estimator(num_instruments=args.num_instruments, checkpoint_path=args.model)
+    estimator = Estimator(
+        num_instruments=args.num_instruments,
+        checkpoint_path=args.model,
+        use_torch_stft=False,
+    )
     estimator.eval()

     # Create sample 'audio' for tracing
     wav = torch.zeros(2, int(args.length * samplerate))

-    # Reproduce the STFT step (which we cannot convert to Core ML, unfortunately)
-    _, stft_mag = estimator.compute_stft(wav)
-
     print('==> Tracing model')
-    traced_model = torch.jit.trace(estimator.separator, stft_mag)
-    out = traced_model(stft_mag)
+    traced_model = torch.jit.trace(estimator, wav)
+    out = traced_model(wav)

     print('==> Converting to Core ML')
     mlmodel = ct.convert(
         traced_model,
         convert_to='mlprogram',
         # TODO: Investigate whether we'd want to make the input shape flexible
         # See https://coremltools.readme.io/docs/flexible-inputs
         inputs=[ct.TensorType(shape=wav.shape)]
     )

     output_dir: Path = args.output
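For a quick sanity check of the converted package, something like the following should work. This is a sketch: the `output/spleeter.mlpackage` path and the `wav` input name are assumptions (check the model's spec for the real names), and `predict` only runs on macOS. Since the TODO above notes the input shape is not yet flexible, the array shape must match the traced shape exactly.

```python
import numpy as np
import coremltools as ct

# Load the converted model package (path is hypothetical).
mlmodel = ct.models.MLModel('output/spleeter.mlpackage')
print(mlmodel.get_spec().description)  # prints the actual input/output names

# The traced input was (2, length * samplerate), i.e. fixed-size stereo audio.
wav = np.zeros((2, 5 * 44100), dtype=np.float32)
out = mlmodel.predict({'wav': wav})  # 'wav' is an assumed input name
```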
spleeter_pytorch/__init__.py (7 changes: 6 additions & 1 deletion)

@@ -12,10 +12,15 @@ def main():
     parser.add_argument('-n', '--num-instruments', type=int, default=2, help='The number of stems.')
     parser.add_argument('-m', '--model', type=Path, default=ROOT / 'checkpoints' / '2stems' / 'model', help='The path to the model to use.')
     parser.add_argument('-o', '--output', type=Path, default=ROOT / 'output' / 'stems', help='The path to the output directory.')
+    parser.add_argument('--torch-stft', default=True, action=argparse.BooleanOptionalAction, help="Whether to use PyTorch's native STFT.")
     parser.add_argument('input', type=Path, help='The path to the input file to process')

     args = parser.parse_args()
-    estimator = Estimator(num_instruments=args.num_instruments, checkpoint_path=args.model)
+    estimator = Estimator(
+        num_instruments=args.num_instruments,
+        checkpoint_path=args.model,
+        use_torch_stft=args.torch_stft,
+    )
     estimator.eval()

     # Load wav audio
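Note that `argparse.BooleanOptionalAction` (Python 3.9+) auto-generates the negated flag, so the manual STFT path is selected with `--no-torch-stft`. A minimal demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--torch-stft', default=True, action=argparse.BooleanOptionalAction)

print(parser.parse_args([]))                   # Namespace(torch_stft=True)
print(parser.parse_args(['--no-torch-stft']))  # Namespace(torch_stft=False)
```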
spleeter_pytorch/estimator.py (46 changes: 39 additions & 7 deletions)

@@ -5,33 +5,54 @@
 from torch import nn

 from spleeter_pytorch.separator import Separator
+from spleeter_pytorch.util import overlap_and_add

 class Estimator(nn.Module):
-    def __init__(self, num_instruments: int, checkpoint_path: Path):
+    def __init__(
+        self,
+        num_instruments: int,
+        checkpoint_path: Path,
+        use_torch_stft: bool = True,
+    ):
         super().__init__()

         # stft config
         self.F = 1024
         self.T = 512
-        self.win_length = 4096
+        self.win_length = 4096  # should be a power of two, see https://github.com/tensorflow/tensorflow/blob/6935c8f706dde1906e388b3142906c92cdcc36db/tensorflow/python/ops/signal/spectral_ops.py#L48-L49
         self.hop_length = 1024
         self.win = nn.Parameter(
             torch.hann_window(self.win_length),
             requires_grad=False
         )

         self.separator = Separator(num_instruments=num_instruments, checkpoint_path=checkpoint_path)
+        self.use_torch_stft = use_torch_stft

-    def compute_stft(self, wav):
+    def compute_stft(self, wav: torch.Tensor):
         """
         Computes the STFT feature from wav

         Args:
             wav (Tensor): B x L
         """

-        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
-                          center=True, return_complex=True, pad_mode='constant')
+        if self.use_torch_stft:
+            stft = torch.stft(
+                wav,
+                n_fft=self.win_length,
+                hop_length=self.hop_length,
+                window=self.win,
+                center=True,
+                return_complex=True,
+                pad_mode='constant'
+            )
+        else:
+            L = wav.shape[-1]
+            framed_wav = wav.unfold(-1, size=self.win_length, step=self.hop_length)
+            # out-of-place multiply: unfold returns an overlapping view, which
+            # cannot be written to in-place
+            framed_wav = framed_wav * self.win
+            stft = torch.fft.rfft(framed_wav, self.win_length)
+            stft = stft.transpose(1, 2)

         # only keep freqs smaller than self.F
         stft = stft[:, :self.F, :]
@@ -45,8 +66,19 @@ def inverse_stft(self, stft):
         pad = self.win_length // 2 + 1 - stft.size(1)
         stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
         stft = torch.view_as_complex(stft)
-        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-                          window=self.win)
+        if self.use_torch_stft:
+            wav = torch.istft(
+                stft,
+                self.win_length,
+                hop_length=self.hop_length,
+                center=True,
+                window=self.win
+            )
+        else:
+            stft = stft.transpose(1, 2)
+            wav: torch.Tensor = torch.fft.irfft(stft, self.win_length)
+            wav *= self.win
+            wav = overlap_and_add(wav, self.hop_length)
         return wav.detach()

     def forward(self, wav):
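One thing worth flagging: the manual path frames the signal without the `center=True` padding that the `torch.stft` branch applies, so the two branches will disagree near the signal edges. Away from centering, though, `unfold` plus `torch.fft.rfft` reproduces `torch.stft` exactly. A standalone sketch, independent of `Estimator` but using this PR's window and hop settings:

```python
import torch

win_length, hop_length = 4096, 1024
wav = torch.randn(2, 44100)
win = torch.hann_window(win_length)

# Reference: native STFT with center=False, to match the plain framing below
ref = torch.stft(wav, n_fft=win_length, hop_length=hop_length, window=win,
                 center=False, return_complex=True)

# Manual path: unfold into overlapping frames, window them, then a real FFT
frames = wav.unfold(-1, size=win_length, step=hop_length) * win
man = torch.fft.rfft(frames, win_length).transpose(1, 2)

print(torch.allclose(ref, man, atol=1e-4))  # True, up to float error
```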
spleeter_pytorch/unet.py (6 changes: 3 additions & 3 deletions)

@@ -5,7 +5,7 @@

 class CustomPad(nn.Module):
     def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomPad, self).__init__()
+        super().__init__()
         self.padding_setting = padding_setting

     def forward(self, x):
@@ -14,7 +14,7 @@ def forward(self, x):

 class CustomTransposedPad(nn.Module):
     def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomTransposedPad, self).__init__()
+        super().__init__()
         self.padding_setting = padding_setting

     def forward(self, x):
@@ -45,7 +45,7 @@ def up_block(in_filters, out_filters, dropout=False):

 class UNet(nn.Module):
     def __init__(self, in_channels=2):
-        super(UNet, self).__init__()
+        super().__init__()
         self.down1_conv, self.down1_act = down_block(in_channels, 16)
         self.down2_conv, self.down2_act = down_block(16, 32)
         self.down3_conv, self.down3_act = down_block(32, 64)
spleeter_pytorch/util.py (43 changes: 43 additions & 0 deletions)

@@ -1,5 +1,7 @@
+import math
 import numpy as np
 import tensorflow as tf
+import torch

 from pathlib import Path

@@ -76,3 +78,44 @@ def tf2pytorch(checkpoint_path: Path, num_instruments: int):
             conv_idx += 1

     return outputs
+
+# Source: https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
+# MIT-licensed, Copyright (c) 2018 Kaituo XU
+
+def overlap_and_add(signal: torch.Tensor, frame_step: int):
+    '''
+    Reconstructs a signal from a framed representation.
+    Adds potentially overlapping frames of a signal with shape
+    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
+    The resulting tensor has shape `[..., output_size]` where
+    output_size = (frames - 1) * frame_step + frame_length
+
+    Args:
+        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
+        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
+
+    Returns:
+        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
+    '''
+    outer_dimensions = signal.size()[:-2]
+    frames, frame_length = signal.size()[-2:]
+
+    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
+    subframe_step = frame_step // subframe_length
+    subframes_per_frame = frame_length // subframe_length
+    output_size = frame_step * (frames - 1) + frame_length
+    output_subframes = output_size // subframe_length
+
+    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
+
+    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
+    frame = signal.new_tensor(frame).long()  # signal may be on GPU or CPU
+    frame = frame.contiguous().view(-1)
+
+    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
+    result.index_add_(-2, frame, subframe_signal)
+    result = result.view(*outer_dimensions, -1)
+    return result