Skip to content

Commit

Permalink
Make torch_stft_fb.py encoder onnxable (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Mar 12, 2021
1 parent 778e541 commit 3510292
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions asteroid_filterbanks/torch_stft_fb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import warnings
from typing import List
from typing import Tuple
import torch
import torch.nn.functional as F
from asteroid_filterbanks import STFTFB, Encoder, Decoder
from asteroid_filterbanks import STFTFB
from asteroid_filterbanks.scripting import script_if_tracing
import numpy as np


class TorchSTFTFB(STFTFB):
Expand Down Expand Up @@ -99,7 +98,7 @@ def pre_analysis(self, wav):
"""Centers the frames if `center` is True."""
if not self.center:
return wav
pad_shape = [self.kernel_size // 2, self.kernel_size // 2]
pad_shape = (self.kernel_size // 2, self.kernel_size // 2)
wav = pad_all_shapes(wav, pad_shape=pad_shape, mode=self.pad_mode)
return wav

Expand Down Expand Up @@ -178,12 +177,14 @@ def square_ola(window: torch.Tensor, kernel_size: int, stride: int, n_frame: int


@script_if_tracing
def pad_all_shapes(x: torch.Tensor, pad_shape: List[int], mode: str = "reflect") -> torch.Tensor:
def pad_all_shapes(
x: torch.Tensor, pad_shape: Tuple[int, int], mode: str = "reflect"
) -> torch.Tensor:
if x.ndim == 1:
return F.pad(x[None, None], pad=pad_shape, mode=mode).squeeze(0).squeeze(0)
if x.ndim == 2:
return F.pad(x[None], pad=pad_shape, mode=mode).squeeze(0)
if x.ndim == 3:
return F.pad(x, pad=pad_shape, mode=mode)
pad_shape = [pad_shape[0]] + [0 for _ in range(x.ndim - 1)]
pad_shape = (pad_shape[0],) + (0,) * (x.ndim - 1)
return F.pad(x, pad=pad_shape, mode=mode)

0 comments on commit 3510292

Please sign in to comment.