Skip to content

Commit c83640f

Browse files
committed
Move the faded time mask transformation from a separate class (SmoothFadeTimeMask) into a `fade` parameter of the regular TimeMask
1 parent f7d3277 commit c83640f

File tree

4 files changed

+41
-70
lines changed

4 files changed

+41
-70
lines changed

audiomentations/augmentations/transforms.py

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,19 @@ def apply(self, samples, sample_rate):
8484
class TimeMask(BasicTransform):
8585
"""Mask some time band on the spectrogram. Inspired by https://arxiv.org/pdf/1904.08779.pdf """
8686

87-
def __init__(self, min_band_part=0.0, max_band_part=0.5, p=0.5):
87+
def __init__(self, min_band_part=0.0, max_band_part=0.5, fade=False, p=0.5):
8888
"""
8989
:param min_band_part: Minimum length of the silent part as a fraction of the
9090
total sound length. Float.
9191
:param max_band_part: Maximum length of the silent part as a fraction of the
9292
total sound length. Float.
93+
:param fade: Bool, Add linear fade in and fade out of the silent part.
9394
:param p:
9495
"""
9596
super().__init__(p)
9697
self.min_band_part = min_band_part
9798
self.max_band_part = max_band_part
99+
self.fade = fade
98100

99101
def apply(self, samples, sample_rate):
100102
new_samples = samples.copy()
@@ -103,41 +105,12 @@ def apply(self, samples, sample_rate):
103105
int(new_samples.shape[0] * self.max_band_part),
104106
)
105107
_t0 = random.randint(0, new_samples.shape[0] - _t)
106-
new_samples[_t0 : _t0 + _t] = 0
107-
return new_samples
108-
109-
110-
class SmoothFadeTimeMask(BasicTransform):
111-
"""Mask some time band on the spectrogram with fade in and fade out.
112-
113-
Same transformation as TimeMask but with linear smoothing"""
114-
115-
def __init__(self, min_band_part=0.0, max_band_part=0.5, p=0.5):
116-
"""
117-
:param min_band_part: Minimum length of the silent part as a fraction of the
118-
total sound length. Float.
119-
:param max_band_part: Maximum length of the silent part as a fraction of the
120-
total sound length. Float.
121-
:param p:
122-
"""
123-
super().__init__(p)
124-
self.min_band_part = min_band_part
125-
self.max_band_part = max_band_part
126-
127-
def apply(self, samples, sample_rate):
128-
new_samples = samples.copy()
129-
_t = random.randint(
130-
int(new_samples.shape[0] * self.min_band_part),
131-
int(new_samples.shape[0] * self.max_band_part),
132-
)
133-
_t0 = random.randint(0, new_samples.shape[0] - _t)
134-
# fade length is 10 ms or 10% of silent part if silent part is less than 10 ms
135-
fade_length = min(int(sample_rate * 0.01), int(_t * 0.1))
136-
linear_fade_in = np.linspace(0, 1, num=fade_length)
137-
linear_fade_out = np.linspace(1, 0, num=fade_length)
138-
new_samples[_t0 : _t0 + fade_length] *= linear_fade_out
139-
new_samples[_t0 + _t - fade_length : _t0 + _t] *= linear_fade_in
140-
new_samples[_t0 + fade_length : _t0 + _t - fade_length] = 0
108+
mask = np.zeros(_t)
109+
if self.fade:
110+
fade_length = min(int(sample_rate * 0.01), int(_t * 0.1))
111+
mask[0:fade_length] = np.linspace(1, 0, num=fade_length)
112+
mask[-fade_length:] = np.linspace(0, 1, num=fade_length)
113+
new_samples[_t0 : _t0 + _t] *= mask
141114
return new_samples
142115

143116

demo/demo.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
AddImpulseResponse,
1414
FrequencyMask,
1515
TimeMask,
16-
SmoothFadeTimeMask,
1716
AddGaussianSNR,
1817
Resample,
1918
ClippingDistortion,
@@ -75,15 +74,6 @@ def load_wav_file(sound_file_path):
7574
augmented_samples = augmenter(samples=samples, sample_rate=SAMPLE_RATE)
7675
wavfile.write(output_file_path, rate=SAMPLE_RATE, data=augmented_samples)
7776

78-
# SmoothFadeTimeMask
79-
augmenter = Compose([SmoothFadeTimeMask(p=1.0)])
80-
for i in range(5):
81-
output_file_path = os.path.join(
82-
output_dir, "SmoothFadeTimeMask_{:03d}.wav".format(i)
83-
)
84-
augmented_samples = augmenter(samples=samples, sample_rate=SAMPLE_RATE)
85-
wavfile.write(output_file_path, rate=SAMPLE_RATE, data=augmented_samples)
86-
8777
# AddGaussianSNR
8878
augmenter = Compose([AddGaussianSNR(p=1.0)])
8979
for i in range(5):

tests/test_smooth_fade_time_mask.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/test_time_mask.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,35 @@ def test_dynamic_length(self):
2020
std_in = np.mean(np.abs(samples_in))
2121
std_out = np.mean(np.abs(samples_out))
2222
self.assertLess(std_out, std_in)
23+
24+
def test_dynamic_length_with_fade(self):
25+
sample_len = 1024
26+
samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
27+
sample_rate = 16000
28+
augmenter = Compose(
29+
[TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
30+
)
31+
32+
samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
33+
self.assertEqual(samples_out.dtype, np.float32)
34+
self.assertEqual(len(samples_out), sample_len)
35+
36+
std_in = np.mean(np.abs(samples_in))
37+
std_out = np.mean(np.abs(samples_out))
38+
self.assertLess(std_out, std_in)
39+
40+
def test_dynamic_length_with_fade_short_signal(self):
41+
sample_len = 100
42+
samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
43+
sample_rate = 16000
44+
augmenter = Compose(
45+
[TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
46+
)
47+
48+
samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
49+
self.assertEqual(samples_out.dtype, np.float32)
50+
self.assertEqual(len(samples_out), sample_len)
51+
52+
std_in = np.mean(np.abs(samples_in))
53+
std_out = np.mean(np.abs(samples_out))
54+
self.assertLess(std_out, std_in)

0 commit comments

Comments
 (0)