1
- from typing import Optional , TypeVar
1
+ from math import floor
2
2
3
3
import torch
4
- from typing_extensions import TypeGuard
4
+ import torch .nn .functional as F
5
+ from torch import Tensor , nn
5
6
6
- T = TypeVar ("T" )
7
7
8
+ def get_center_frequencies (
9
+ num_octaves : int , num_bins_per_octave : int , sample_rate : int # C # B # Xi_s
10
+ ) -> Tensor : # Xi_k for k in [1, 2*K+1]
11
+ """Compute log scaled center frequencies tensor"""
12
+ frequency_nyquist = sample_rate / 2
13
+ frequency_min = frequency_nyquist / (2 ** num_octaves )
14
+ num_bins = num_octaves * num_bins_per_octave # K
15
+ # Exponential increase from min to Nyquist
16
+ frequencies = frequency_min * (2 ** (torch .arange (num_bins ) / num_bins_per_octave ))
17
+ frequencies_all = torch .cat (
18
+ [
19
+ frequencies ,
20
+ torch .tensor ([frequency_nyquist ]),
21
+ # sample_rate - torch.flip(frequencies, dims=[0]) # not necessary
22
+ ],
23
+ dim = 0 ,
24
+ )
25
+ return frequencies_all
8
26
9
- """
10
- Utils
11
- """
12
27
28
+ def get_bandwidths (
29
+ num_octaves : int , # C
30
+ num_bins_per_octave : int , # B
31
+ sample_rate : int , # Xi_s
32
+ frequencies : Tensor , # Xi_k for k in [1, 2*K+1]
33
+ ) -> Tensor : # Omega_k for k in [1, 2*K+1]
34
+ """Compute bandwidths tensor from center frequencies"""
35
+ num_bins = num_octaves * num_bins_per_octave # K
36
+ q_factor = 1.0 / (
37
+ 2 ** (1.0 / num_bins_per_octave ) - 2 ** (- 1.0 / num_bins_per_octave )
38
+ )
39
+ bandwidths = frequencies [1 : num_bins + 1 ] / q_factor
40
+ bandwidths_symmetric = (
41
+ torch .flip (frequencies [1 : num_bins + 1 ], dims = [0 ]) / q_factor
42
+ )
43
+ bandwidths_all = torch .cat (
44
+ [
45
+ bandwidths ,
46
+ torch .tensor ([sample_rate - 2 * frequencies [num_bins ]]),
47
+ bandwidths_symmetric ,
48
+ ],
49
+ dim = 0 ,
50
+ )
51
+ return bandwidths_all
13
52
14
- def exists (val : Optional [T ]) -> TypeGuard [T ]:
15
- return val is not None
16
53
54
+ def get_windows_range_indices (lengths : Tensor , positions : Tensor ) -> Tensor :
55
+ """Compute windowing tensor of indices"""
56
+ num_bins = lengths .shape [0 ] // 2
57
+ max_length = lengths .max ()
58
+ ranges = []
59
+ for i in range (num_bins ):
60
+ start = positions [i ] - max_length
61
+ ranges += [torch .arange (start = start , end = start + max_length )] # type: ignore
62
+ return torch .stack (ranges , dim = 0 ).long ()
17
63
18
- """
19
- CQT
20
- """
21
64
22
- class CQT (nn .Module ):
65
+ def get_windows (lengths : Tensor ) -> Tensor :
66
+ """Compute tensor of stacked (centered) windows"""
67
+ num_bins = lengths .shape [0 ] // 2
68
+ max_length = lengths .max ()
69
+ windows = []
70
+ for length in lengths [:num_bins ]:
71
+ # Pad windows left and right to center them
72
+ pad_left = floor (max_length / 2 - length / 2 )
73
+ pad_right = int (max_length - length - pad_left )
74
+ windows += [F .pad (torch .hann_window (int (length )), pad = (pad_left , pad_right ))]
75
+ return torch .stack (windows , dim = 0 )
76
+
77
+
78
+ def get_windows_inverse (windows : Tensor , lengths : Tensor ) -> Tensor :
79
+ num_bins = windows .shape [0 ]
80
+ return torch .einsum ("k m, k -> k m" , windows ** 2 , lengths [:num_bins ])
23
81
82
+
83
+ class CQT (nn .Module ):
24
84
def __init__ (
25
85
self ,
86
+ num_octaves : int ,
87
+ num_bins_per_octave : int ,
88
+ sample_rate : int ,
89
+ block_length : int ,
26
90
):
27
- super ().__init__ ()
91
+ super ().__init__ ()
92
+ self .block_length = block_length
93
+
94
+ frequencies = get_center_frequencies (
95
+ num_octaves = num_octaves ,
96
+ num_bins_per_octave = num_bins_per_octave ,
97
+ sample_rate = sample_rate ,
98
+ )
99
+
100
+ bandwidths = get_bandwidths (
101
+ num_octaves = num_octaves ,
102
+ num_bins_per_octave = num_bins_per_octave ,
103
+ sample_rate = sample_rate ,
104
+ frequencies = frequencies ,
105
+ )
106
+
107
+ window_lengths = torch .round (bandwidths * block_length / sample_rate )
108
+
109
+ self .register_buffer (
110
+ "windows_range_indices" ,
111
+ get_windows_range_indices (
112
+ lengths = window_lengths ,
113
+ positions = torch .round (frequencies * block_length / sample_rate ),
114
+ ),
115
+ )
28
116
117
+ self .register_buffer ("windows" , get_windows (lengths = window_lengths ))
29
118
30
- def encode (self , x : Tensor ) -> Tensor :
31
- pass
119
+ self .register_buffer (
120
+ "windows_inverse" ,
121
+ get_windows_inverse (windows = self .windows , lengths = window_lengths ), # type: ignore # noqa
122
+ )
32
123
124
+ def encode (self , waveform : Tensor ) -> Tensor :
125
+ frequencies = torch .fft .fft (waveform )
126
+ crops = frequencies [:, :, self .windows_range_indices ]
127
+ crops_windowed = torch .einsum ("... t k, t k -> ... t k" , crops , self .windows )
128
+ transform = torch .fft .ifft (crops_windowed )
129
+ return transform
33
130
34
- def decode (self , x : Tensor ) -> Tensor :
35
- pass
131
+ def decode (self , transform : Tensor ) -> Tensor :
132
+ b , c , length = * transform .shape [0 :2 ], self .block_length
133
+ crops_windowed = torch .fft .fft (transform )
134
+ crops_unwindowed = crops_windowed # TODO crops_unwindowed = torch.einsum('... t k, t k -> ... t k', transformed, self.windows_inverse) # noqa
135
+ frequencies = torch .zeros (b , c , length ).to (transform )
136
+ frequencies .scatter_add_ (
137
+ dim = - 1 ,
138
+ index = self .windows_range_indices .view (- 1 ).expand (b , c , - 1 ) % l , # type: ignore # noqa
139
+ src = crops_unwindowed .view (b , c , - 1 ),
140
+ )
141
+ waveform = torch .fft .ifft (frequencies )
142
+ return waveform
0 commit comments