
Commit 03103f1

Refactored STFT class and added unit tests
1 parent 36f468d commit 03103f1

3 files changed: +289 -44 lines changed

audio_separator/separator/stft.py (+111 -44)
@@ -9,50 +9,117 @@ def __init__(self, logger, n_fft, hop_length, dim_f, device):
         self.logger = logger
         self.n_fft = n_fft
         self.hop_length = hop_length
-        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
         self.dim_f = dim_f
         self.device = device
+        # Create a Hann window tensor for use in the STFT.
+        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
 
-    def __call__(self, x):
-        x_is_mps = not x.device.type in ["cuda", "cpu"]
-        if x_is_mps:
-            x = x.cpu()
-
-        initial_shape = x.shape
-        window = self.window.to(x.device)
-        batch_dims = x.shape[:-2]
-        c, t = x.shape[-2:]
-        x = x.reshape([-1, t])
-        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
-        x = x.permute([0, 3, 1, 2])
-        x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
-
-        if x_is_mps:
-            x = x.to(self.device)
-
-        # self.logger.debug(f"STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}")
-        return x[..., : self.dim_f, :]
-
-    def inverse(self, x):
-        x_is_mps = not x.device.type in ["cuda", "cpu"]
-        if x_is_mps:
-            x = x.cpu()
-
-        initial_shape = x.shape
-        window = self.window.to(x.device)
-        batch_dims = x.shape[:-3]
-        c, f, t = x.shape[-3:]
-        n = self.n_fft // 2 + 1
-        f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
-        x = torch.cat([x, f_pad], -2)
-        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
-        x = x.permute([0, 2, 3, 1])
-        x = x[..., 0] + x[..., 1] * 1.0j
-        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
-        x = x.reshape([*batch_dims, 2, -1])
-
-        if x_is_mps:
-            x = x.to(self.device)
-
-        # self.logger.debug(f"Inverse STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}")
-        return x
+    def __call__(self, input_tensor):
+        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
+        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
+
+        # If on a non-standard device, temporarily move the tensor to CPU for processing.
+        if is_non_standard_device:
+            input_tensor = input_tensor.cpu()
+
+        # Transfer the pre-defined window tensor to the same device as the input tensor.
+        stft_window = self.hann_window.to(input_tensor.device)
+
+        # Extract batch dimensions (all dimensions except the last two, which are channel and time).
+        batch_dimensions = input_tensor.shape[:-2]
+
+        # Extract channel and time dimensions (last two dimensions of the tensor).
+        channel_dim, time_dim = input_tensor.shape[-2:]
+
+        # Reshape the tensor to merge batch and channel dimensions for STFT processing.
+        reshaped_tensor = input_tensor.reshape([-1, time_dim])
+
+        # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
+        stft_output = torch.stft(
+            reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False
+        )
+
+        # Rearrange the dimensions of the STFT output to move the real/imaginary component dimension forward.
+        permuted_stft_output = stft_output.permute([0, 3, 1, 2])
+
+        # Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions.
+        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape(
+            [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
+        )
+
+        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
+        if is_non_standard_device:
+            final_output = final_output.to(self.device)
+
+        # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
+        return final_output[..., : self.dim_f, :]
+
+    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
+        """
+        Adds zero padding to the frequency dimension of the input tensor.
+        """
+        # Create a padding tensor for the frequency dimension.
+        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)
+
+        # Concatenate the padding to the input tensor along the frequency dimension.
+        padded_tensor = torch.cat([input_tensor, freq_padding], -2)
+
+        return padded_tensor
+
+    def calculate_inverse_dimensions(self, input_tensor):
+        # Extract the batch dimensions and the channel, frequency, and time dimensions.
+        batch_dimensions = input_tensor.shape[:-3]
+        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
+
+        # Calculate the number of frequency bins for the inverse STFT.
+        num_freq_bins = self.n_fft // 2 + 1
+
+        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins
+
+    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
+        """
+        Prepares the tensor for the Inverse Short-Time Fourier Transform (ISTFT) by reshaping it
+        and creating a complex tensor from the real and imaginary parts.
+        """
+        # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
+        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
+
+        # Flatten batch dimensions and rearrange for ISTFT.
+        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
+
+        # Rearrange the dimensions of the tensor to move the real/imaginary component dimension to the end.
+        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
+
+        # Combine real and imaginary parts into a complex tensor.
+        complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
+
+        return complex_tensor
+
+    def inverse(self, input_tensor):
+        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
+        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
+
+        # If on a non-standard device, temporarily move the tensor to CPU for processing.
+        if is_non_standard_device:
+            input_tensor = input_tensor.cpu()
+
+        # Transfer the pre-defined Hann window tensor to the same device as the input tensor.
+        stft_window = self.hann_window.to(input_tensor.device)
+
+        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
+
+        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
+
+        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)
+
+        # Perform the Inverse Short-Time Fourier Transform (ISTFT).
+        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)
+
+        # Reshape the ISTFT result to restore the original batch and channel dimensions.
+        final_output = istft_result.reshape([*batch_dimensions, 2, -1])
+
+        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
+        if is_non_standard_device:
+            final_output = final_output.to(self.device)
+
+        return final_output
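
For context, a minimal usage sketch of the refactored class (illustrative only, not part of the commit), assuming the constructor signature shown in the diff above and the parameters used by the new unit tests:

    import logging
    import torch
    from audio_separator.separator.stft import STFT

    stft = STFT(logger=logging.getLogger(__name__), n_fft=2048, hop_length=512, dim_f=1025, device="cpu")

    audio = torch.rand(1, 2, 16000)  # (batch, channels, samples)
    spec = stft(audio)               # (1, 4, 1025, 32): real/imag parts stacked into 2 * channels, dim_f bins, 32 frames
    waveform = stft.inverse(spec)    # (1, 2, 15872): istft reconstructs (frames - 1) * hop_length samples per channel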

pytest.ini (+1)

@@ -1,4 +1,5 @@
 # Used by PyDub, which uses a pure-python fallback when needed already, not an issue
 [pytest]
 filterwarnings =
+    ignore:stft with return_complex=False is deprecated:UserWarning
     ignore:'audioop' is deprecated:DeprecationWarning
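
The new filterwarnings entry is needed because torch.stft(..., return_complex=False), as called in the STFT class above, emits a deprecation UserWarning on recent PyTorch versions, which would otherwise clutter the new test output. A minimal way to observe the warning (illustrative sketch, not part of the commit):

    import warnings
    import torch

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.stft(torch.rand(1, 4096), n_fft=2048, hop_length=512,
                   window=torch.hann_window(2048), center=True, return_complex=False)
    print([str(w.message) for w in caught])  # includes "stft with return_complex=False is deprecated" on affected versions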

tests/unit/test_stft.py (+177)
@@ -0,0 +1,177 @@
+import unittest
+import numpy as np
+import torch
+from unittest.mock import Mock, patch
+from audio_separator.separator.stft import STFT
+
+# Short-Time Fourier Transform (STFT) Process Overview:
+#
+# STFT transforms a time-domain signal into a frequency-domain representation.
+# This transformation is achieved by dividing the signal into short frames (or segments) and applying the Fourier Transform to each frame.
+#
+# n_fft: The number of points used in the Fourier Transform, which determines the resolution of the frequency-domain representation.
+# Essentially, it dictates how many frequency bins we get in our STFT.
+#
+# hop_length: The number of samples by which we shift each frame of the signal.
+# It affects the overlap between consecutive frames. If the hop_length is less than n_fft, we get overlapping frames.
+#
+# Windowing: Each frame of the signal is multiplied by a window function (e.g. a Hann window) before applying the Fourier Transform.
+# This is done to minimize discontinuities at the borders of each frame.
+
+
+class TestSTFT(unittest.TestCase):
+    def setUp(self):
+        self.n_fft = 2048
+        self.hop_length = 512
+        self.dim_f = 1025
+        self.device = torch.device("cpu")
+        self.stft = STFT(logger=Mock(), n_fft=self.n_fft, hop_length=self.hop_length, dim_f=self.dim_f, device=self.device)
+
+    def create_mock_tensor(self, shape, device=None):
+        tensor = torch.rand(shape)
+        if device:
+            tensor = tensor.to(device)
+        return tensor
+
+    def test_stft_initialization(self):
+        self.assertEqual(self.stft.n_fft, self.n_fft)
+        self.assertEqual(self.stft.hop_length, self.hop_length)
+        self.assertEqual(self.stft.dim_f, self.dim_f)
+        self.assertEqual(self.stft.device.type, "cpu")
+        self.assertIsInstance(self.stft.hann_window, torch.Tensor)
+
+    def test_stft_call(self):
+        input_tensor = self.create_mock_tensor((1, 16000))
+
+        # Apply STFT
+        stft_result = self.stft(input_tensor)
+
+        # Test conditions
+        self.assertIsNotNone(stft_result)
+        self.assertIsInstance(stft_result, torch.Tensor)
+
+        # Calculate the expected shape based on the input parameters:
+
+        # Frequency dimension (dim_f): This corresponds to the number of frequency bins kept from the STFT output.
+        # For a real-valued input signal (like audio), the Fourier Transform of each frame is conjugate-symmetric.
+        # Hence, for an n_fft of 2048, only n_fft / 2 + 1 = 1025 bins (from 0 Hz up to the Nyquist frequency) are unique.
+        # We rarely need more than this one-sided spectrum,
+        # so dim_f is used to specify how many frequency bins we are interested in.
+        # In this test, it's set to 1025, i.e. exactly n_fft / 2 + 1, the full one-sided spectrum.
+
+        # Time dimension: This corresponds to how many frames (or segments) the input signal has been divided into.
+        # It depends on the length of the input signal and the hop_length.
+        # The formula for calculating the number of frames is derived from how we stride the window across the signal:
+        # Length of input signal: Let's denote it as L. In this test, the input tensor has a shape of [1, 16000], so L is 16000 (ignoring the batch dimension for simplicity).
+        # Number of frames: For each frame, we move the window forward by hop_length samples.
+        # Therefore, the number of frames N_frames can be roughly estimated by dividing the length of the signal by the hop_length.
+        # Because the transform is centered (the signal is padded at both ends), one extra frame is added. This gives us N_frames = (L // hop_length) + 1.
+
+        # Putting it all together:
+        # expected_shape thus becomes (dim_f, N_frames), which is (1025, (16000 // 512) + 1) in this test case.
+
+        expected_shape = (self.dim_f, (input_tensor.shape[1] // self.hop_length) + 1)
+
+        self.assertEqual(stft_result.shape[-2:], expected_shape)
+
+    def test_calculate_inverse_dimensions(self):
+        # Create a sample input tensor
+        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
+        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)
+
+        # Expected values
+        expected_num_freq_bins = self.n_fft // 2 + 1
+
+        # Assertions
+        self.assertEqual(batch_dims, sample_input.shape[:-3])
+        self.assertEqual(channel_dim, 2)
+        self.assertEqual(freq_dim, 500)
+        self.assertEqual(time_dim, 32)
+        self.assertEqual(num_freq_bins, expected_num_freq_bins)
+
+    def test_pad_frequency_dimension(self):
+        # Create a sample input tensor
+        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
+        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)
+
+        # Apply padding
+        padded_output = self.stft.pad_frequency_dimension(sample_input, batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins)
+
+        # Expected frequency dimension after padding
+        expected_freq_dim = num_freq_bins
+
+        # Assertions
+        self.assertEqual(padded_output.shape[-2], expected_freq_dim)
+
+    def test_prepare_for_istft(self):
+        # Create a sample input tensor
+        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
+        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)
+        padded_output = self.stft.pad_frequency_dimension(sample_input, batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins)
+
+        # Apply prepare_for_istft
+        complex_tensor = self.stft.prepare_for_istft(padded_output, batch_dims, channel_dim, num_freq_bins, time_dim)
+
+        # Calculate the expected flattened batch size (flattening batch and channel dimensions)
+        expected_flattened_batch_size = batch_dims[0] * (channel_dim // 2)
+
+        # Expected shape of the complex tensor
+        expected_shape = (expected_flattened_batch_size, num_freq_bins, time_dim)
+
+        # Assertions
+        self.assertEqual(complex_tensor.shape, expected_shape)
+
+    def test_inverse_device_handling(self):
+        # Create a mock tensor with the correct input shape
+        input_tensor = torch.rand(1, 2, 1025, 32)  # shape matching the STFT output for a single-channel signal
+
+        # Initialize STFT
+        stft = STFT(logger=MockLogger(), n_fft=2048, hop_length=512, dim_f=1025, device="cpu")
+
+        # Apply inverse STFT
+        output_tensor = stft.inverse(input_tensor)
+
+        # Check if the output tensor is on the CPU
+        self.assertEqual(output_tensor.device.type, "cpu")
+
+    def test_inverse_output_shape(self):
+        # Create a mock tensor
+        input_tensor = torch.rand(1, 2, 1025, 32)  # shape matching the STFT output for a single-channel signal
+
+        # Initialize STFT
+        stft = STFT(logger=MockLogger(), n_fft=2048, hop_length=512, dim_f=1025, device="cpu")
+
+        # Apply inverse STFT
+        output_tensor = stft.inverse(input_tensor)
+
+        # Expected output shape: (Batch size, Channel dimension, Time dimension)
+        expected_shape = (1, 2, 7936)  # Calculated from the STFT parameters: (32 - 1) * 512 samples reshaped across 2 channels
+
+        # Check if the output tensor has the expected shape
+        self.assertEqual(output_tensor.shape, expected_shape)
+
+    def test_stft_with_mps_device(self):
+        mps_device = torch.device("mps")
+        self.stft.device = mps_device
+        input_tensor = self.create_mock_tensor((1, 16000), device=mps_device)
+        stft_result = self.stft(input_tensor)
+        self.assertIsNotNone(stft_result)
+        self.assertIsInstance(stft_result, torch.Tensor)
+
+    def test_inverse_with_mps_device(self):
+        mps_device = torch.device("mps")
+        self.stft.device = mps_device
+        input_tensor = self.create_mock_tensor((1, 2, 1025, 32), device=mps_device)
+        istft_result = self.stft.inverse(input_tensor)
+        self.assertIsNotNone(istft_result)
+        self.assertIsInstance(istft_result, torch.Tensor)
+
+
+# Mock logger to use in tests
+class MockLogger:
+    def debug(self, message):
+        pass
+
+
+if __name__ == "__main__":
+    unittest.main()
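
A quick sanity check of the shape arithmetic behind the expected values in these tests (illustrative sketch, not part of the commit, assuming the same parameters as setUp):

    n_fft, hop_length, signal_length = 2048, 512, 16000

    num_freq_bins = n_fft // 2 + 1                 # 1025 one-sided frequency bins
    num_frames = signal_length // hop_length + 1   # 16000 // 512 + 1 = 32 frames with center=True
    istft_samples = (num_frames - 1) * hop_length  # (32 - 1) * 512 = 15872 samples reconstructed by istft

    print((num_freq_bins, num_frames))    # (1025, 32) -> expected_shape in test_stft_call
    print((1, 2, istft_samples // 2))     # (1, 2, 7936) -> expected_shape in test_inverse_output_shape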
