Skip to content

Commit

Permalink
Merge pull request #24 from askskro/master
Browse files Browse the repository at this point in the history
Implement smooth fade-out and fade-in in TimeMask
  • Loading branch information
iver56 authored Nov 1, 2019
2 parents c682995 + c83640f commit f939ec3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
11 changes: 9 additions & 2 deletions audiomentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,19 @@ def apply(self, samples, sample_rate):
class TimeMask(BasicTransform):
"""Mask some time band on the spectrogram. Inspired by https://arxiv.org/pdf/1904.08779.pdf """

def __init__(self, min_band_part=0.0, max_band_part=0.5, p=0.5):
def __init__(self, min_band_part=0.0, max_band_part=0.5, fade=False, p=0.5):
"""
:param min_band_part: Minimum length of the silent part as a fraction of the
total sound length. Float.
:param max_band_part: Maximum length of the silent part as a fraction of the
total sound length. Float.
:param fade: Bool, Add linear fade in and fade out of the silent part.
:param p:
"""
super().__init__(p)
self.min_band_part = min_band_part
self.max_band_part = max_band_part
self.fade = fade

def apply(self, samples, sample_rate):
new_samples = samples.copy()
Expand All @@ -103,7 +105,12 @@ def apply(self, samples, sample_rate):
int(new_samples.shape[0] * self.max_band_part),
)
_t0 = random.randint(0, new_samples.shape[0] - _t)
new_samples[_t0 : _t0 + _t] = 0
mask = np.zeros(_t)
if self.fade:
fade_length = min(int(sample_rate * 0.01), int(_t * 0.1))
mask[0:fade_length] = np.linspace(1, 0, num=fade_length)
mask[-fade_length:] = np.linspace(0, 1, num=fade_length)
new_samples[_t0 : _t0 + _t] *= mask
return new_samples


Expand Down
32 changes: 32 additions & 0 deletions tests/test_time_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,35 @@ def test_dynamic_length(self):
std_in = np.mean(np.abs(samples_in))
std_out = np.mean(np.abs(samples_out))
self.assertLess(std_out, std_in)

def test_dynamic_length_with_fade(self):
sample_len = 1024
samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
sample_rate = 16000
augmenter = Compose(
[TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
)

samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
self.assertEqual(samples_out.dtype, np.float32)
self.assertEqual(len(samples_out), sample_len)

std_in = np.mean(np.abs(samples_in))
std_out = np.mean(np.abs(samples_out))
self.assertLess(std_out, std_in)

def test_dynamic_length_with_fade_short_signal(self):
sample_len = 100
samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
sample_rate = 16000
augmenter = Compose(
[TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
)

samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
self.assertEqual(samples_out.dtype, np.float32)
self.assertEqual(len(samples_out), sample_len)

std_in = np.mean(np.abs(samples_in))
std_out = np.mean(np.abs(samples_out))
self.assertLess(std_out, std_in)

0 comments on commit f939ec3

Please sign in to comment.