Skip to content

Commit 4c8bb6f

Browse files
feat: add first working implementation
1 parent 9c6ab6e commit 4c8bb6f

File tree

3 files changed

+155
-21
lines changed

3 files changed

+155
-21
lines changed

README.md

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

2-
# CQT - PyTorch
2+
# CQT - PyTorch
33

4-
An invertible and differentiable implementation of the Constant-Q Transform (CQT), in PyTorch.
4+
An invertible and differentiable implementation of the Constant-Q Transform (CQT) using Non-stationary Gabor Transform (NSGT), in PyTorch.
55

66
```bash
7-
pip install cqt-pytorch
7+
pip install cqt-pytorch
88
```
99
[![PyPI - Python Version](https://img.shields.io/pypi/v/cqt-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/cqt-pytorch/)
1010

@@ -15,7 +15,33 @@ pip install cqt-pytorch
1515
from cqt_pytorch import CQT
1616

1717
transform = CQT(
18-
...
18+
num_octaves = 7,
19+
num_bins_per_octave = 65,
20+
sample_rate = 48000,
21+
block_length = 2 ** 18
1922
)
2023

24+
# (Random) audio waveform tensor x
25+
x = torch.randn(1, 2, 2**18) # [1, 1, 262144] = [batch_size, channels, timesteps]
26+
z = transform.encode(x) # [1, 2, 455, 2796] = [batch_size, channels, frequencies, time]
27+
y = transform.decode(z) # [1, 1, 262144]
28+
```
29+
30+
## TODO
31+
[ ] Understand why/if inverse window is necessary.
32+
[ ] Allow variable audio lengths by chunking.
33+
34+
## Appreciation
35+
Special thanks to [Eloi Moliner](https://github.com/eloimoliner) for taking the time to help me understand how CQT works. Check out his own implementation with interesting features at [eloimoliner/CQT_pytorch](https://github.com/eloimoliner/CQT_pytorch).
36+
37+
## Citations
38+
39+
```bibtex
40+
@article{1210.0084,
41+
Author = {Nicki Holighaus and Monika Dörfler and Gino Angelo Velasco and Thomas Grill},
42+
Title = {A framework for invertible, real-time constant-Q transforms},
43+
Year = {2012},
44+
Eprint = {arXiv:1210.0084},
45+
Doi = {10.1109/TASL.2012.2234114},
46+
}
2147
```

cqt_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .cqt import CQT

cqt_pytorch/cqt.py

Lines changed: 124 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,142 @@
1-
from typing import Optional, TypeVar
1+
from math import floor
22

33
import torch
4-
from typing_extensions import TypeGuard
4+
import torch.nn.functional as F
5+
from torch import Tensor, nn
56

6-
T = TypeVar("T")
77

8+
def get_center_frequencies(
9+
num_octaves: int, num_bins_per_octave: int, sample_rate: int # C # B # Xi_s
10+
) -> Tensor: # Xi_k for k in [1, 2*K+1]
11+
"""Compute log scaled center frequencies tensor"""
12+
frequency_nyquist = sample_rate / 2
13+
frequency_min = frequency_nyquist / (2**num_octaves)
14+
num_bins = num_octaves * num_bins_per_octave # K
15+
# Exponential increase from min to Nyquist
16+
frequencies = frequency_min * (2 ** (torch.arange(num_bins) / num_bins_per_octave))
17+
frequencies_all = torch.cat(
18+
[
19+
frequencies,
20+
torch.tensor([frequency_nyquist]),
21+
# sample_rate - torch.flip(frequencies, dims=[0]) # not necessary
22+
],
23+
dim=0,
24+
)
25+
return frequencies_all
826

9-
"""
10-
Utils
11-
"""
1227

28+
def get_bandwidths(
29+
num_octaves: int, # C
30+
num_bins_per_octave: int, # B
31+
sample_rate: int, # Xi_s
32+
frequencies: Tensor, # Xi_k for k in [1, 2*K+1]
33+
) -> Tensor: # Omega_k for k in [1, 2*K+1]
34+
"""Compute bandwidths tensor from center frequencies"""
35+
num_bins = num_octaves * num_bins_per_octave # K
36+
q_factor = 1.0 / (
37+
2 ** (1.0 / num_bins_per_octave) - 2 ** (-1.0 / num_bins_per_octave)
38+
)
39+
bandwidths = frequencies[1 : num_bins + 1] / q_factor
40+
bandwidths_symmetric = (
41+
torch.flip(frequencies[1 : num_bins + 1], dims=[0]) / q_factor
42+
)
43+
bandwidths_all = torch.cat(
44+
[
45+
bandwidths,
46+
torch.tensor([sample_rate - 2 * frequencies[num_bins]]),
47+
bandwidths_symmetric,
48+
],
49+
dim=0,
50+
)
51+
return bandwidths_all
1352

14-
def exists(val: Optional[T]) -> TypeGuard[T]:
15-
return val is not None
1653

54+
def get_windows_range_indices(lengths: Tensor, positions: Tensor) -> Tensor:
55+
"""Compute windowing tensor of indices"""
56+
num_bins = lengths.shape[0] // 2
57+
max_length = lengths.max()
58+
ranges = []
59+
for i in range(num_bins):
60+
start = positions[i] - max_length
61+
ranges += [torch.arange(start=start, end=start + max_length)] # type: ignore
62+
return torch.stack(ranges, dim=0).long()
1763

18-
"""
19-
CQT
20-
"""
2164

22-
class CQT(nn.Module):
65+
def get_windows(lengths: Tensor) -> Tensor:
66+
"""Compute tensor of stacked (centered) windows"""
67+
num_bins = lengths.shape[0] // 2
68+
max_length = lengths.max()
69+
windows = []
70+
for length in lengths[:num_bins]:
71+
# Pad windows left and right to center them
72+
pad_left = floor(max_length / 2 - length / 2)
73+
pad_right = int(max_length - length - pad_left)
74+
windows += [F.pad(torch.hann_window(int(length)), pad=(pad_left, pad_right))]
75+
return torch.stack(windows, dim=0)
76+
77+
78+
def get_windows_inverse(windows: Tensor, lengths: Tensor) -> Tensor:
79+
num_bins = windows.shape[0]
80+
return torch.einsum("k m, k -> k m", windows**2, lengths[:num_bins])
2381

82+
83+
class CQT(nn.Module):
2484
def __init__(
2585
self,
86+
num_octaves: int,
87+
num_bins_per_octave: int,
88+
sample_rate: int,
89+
block_length: int,
2690
):
27-
super().__init__()
91+
super().__init__()
92+
self.block_length = block_length
93+
94+
frequencies = get_center_frequencies(
95+
num_octaves=num_octaves,
96+
num_bins_per_octave=num_bins_per_octave,
97+
sample_rate=sample_rate,
98+
)
99+
100+
bandwidths = get_bandwidths(
101+
num_octaves=num_octaves,
102+
num_bins_per_octave=num_bins_per_octave,
103+
sample_rate=sample_rate,
104+
frequencies=frequencies,
105+
)
106+
107+
window_lengths = torch.round(bandwidths * block_length / sample_rate)
108+
109+
self.register_buffer(
110+
"windows_range_indices",
111+
get_windows_range_indices(
112+
lengths=window_lengths,
113+
positions=torch.round(frequencies * block_length / sample_rate),
114+
),
115+
)
28116

117+
self.register_buffer("windows", get_windows(lengths=window_lengths))
29118

30-
def encode(self, x: Tensor) -> Tensor:
31-
pass
119+
self.register_buffer(
120+
"windows_inverse",
121+
get_windows_inverse(windows=self.windows, lengths=window_lengths), # type: ignore # noqa
122+
)
32123

124+
def encode(self, waveform: Tensor) -> Tensor:
125+
frequencies = torch.fft.fft(waveform)
126+
crops = frequencies[:, :, self.windows_range_indices]
127+
crops_windowed = torch.einsum("... t k, t k -> ... t k", crops, self.windows)
128+
transform = torch.fft.ifft(crops_windowed)
129+
return transform
33130

34-
def decode(self, x: Tensor) -> Tensor:
35-
pass
131+
def decode(self, transform: Tensor) -> Tensor:
132+
b, c, length = *transform.shape[0:2], self.block_length
133+
crops_windowed = torch.fft.fft(transform)
134+
crops_unwindowed = crops_windowed # TODO crops_unwindowed = torch.einsum('... t k, t k -> ... t k', transformed, self.windows_inverse) # noqa
135+
frequencies = torch.zeros(b, c, length).to(transform)
136+
frequencies.scatter_add_(
137+
dim=-1,
138+
index=self.windows_range_indices.view(-1).expand(b, c, -1) % l, # type: ignore # noqa
139+
src=crops_unwindowed.view(b, c, -1),
140+
)
141+
waveform = torch.fft.ifft(frequencies)
142+
return waveform

0 commit comments

Comments
 (0)