diff --git a/audiomentations/augmentations/transforms.py b/audiomentations/augmentations/transforms.py
index 39a1ac77..c347375e 100644
--- a/audiomentations/augmentations/transforms.py
+++ b/audiomentations/augmentations/transforms.py
@@ -84,17 +84,19 @@ def apply(self, samples, sample_rate):
 class TimeMask(BasicTransform):
     """Mask some time band on the spectrogram. Inspired by https://arxiv.org/pdf/1904.08779.pdf """
 
-    def __init__(self, min_band_part=0.0, max_band_part=0.5, p=0.5):
+    def __init__(self, min_band_part=0.0, max_band_part=0.5, fade=False, p=0.5):
         """
         :param min_band_part: Minimum length of the silent part as a fraction of the
             total sound length. Float.
         :param max_band_part: Maximum length of the silent part as a fraction of the
             total sound length. Float.
+        :param fade: Bool. Add a linear fade in and fade out at the edges of the silent part.
         :param p:
         """
         super().__init__(p)
         self.min_band_part = min_band_part
         self.max_band_part = max_band_part
+        self.fade = fade
 
     def apply(self, samples, sample_rate):
         new_samples = samples.copy()
@@ -103,7 +105,16 @@ def apply(self, samples, sample_rate):
             int(new_samples.shape[0] * self.max_band_part),
         )
         _t0 = random.randint(0, new_samples.shape[0] - _t)
-        new_samples[_t0 : _t0 + _t] = 0
+        # Gain curve applied to the band; all zeros is a hard mute.
+        mask = np.zeros(_t)
+        if self.fade:
+            # Each ramp lasts at most 10 ms and at most 10% of the band.
+            fade_length = min(int(sample_rate * 0.01), int(_t * 0.1))
+            mask[0:fade_length] = np.linspace(1, 0, num=fade_length)
+            # Slice the tail from the front: when fade_length is 0, mask[-0:]
+            # would select the whole mask and the empty assignment would raise.
+            mask[_t - fade_length :] = np.linspace(0, 1, num=fade_length)
+        new_samples[_t0 : _t0 + _t] *= mask
         return new_samples
 
 
diff --git a/tests/test_time_mask.py b/tests/test_time_mask.py
index 009295db..2ac02cdc 100644
--- a/tests/test_time_mask.py
+++ b/tests/test_time_mask.py
@@ -20,3 +20,53 @@ def test_dynamic_length(self):
         std_in = np.mean(np.abs(samples_in))
         std_out = np.mean(np.abs(samples_out))
         self.assertLess(std_out, std_in)
+
+    def test_dynamic_length_with_fade(self):
+        sample_len = 1024
+        samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
+        sample_rate = 16000
+        augmenter = Compose(
+            [TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
+        )
+
+        samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
+        self.assertEqual(samples_out.dtype, np.float32)
+        self.assertEqual(len(samples_out), sample_len)
+
+        std_in = np.mean(np.abs(samples_in))
+        std_out = np.mean(np.abs(samples_out))
+        self.assertLess(std_out, std_in)
+
+    def test_dynamic_length_with_fade_short_signal(self):
+        sample_len = 100
+        samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
+        sample_rate = 16000
+        augmenter = Compose(
+            [TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
+        )
+
+        samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
+        self.assertEqual(samples_out.dtype, np.float32)
+        self.assertEqual(len(samples_out), sample_len)
+
+        std_in = np.mean(np.abs(samples_in))
+        std_out = np.mean(np.abs(samples_out))
+        self.assertLess(std_out, std_in)
+
+    def test_fade_with_band_too_short_for_ramp(self):
+        # The band is 2..5 samples long here, so the fade length is 0; the band
+        # must still be muted instead of raising.
+        sample_len = 10
+        samples_in = np.random.normal(0, 1, size=sample_len).astype(np.float32)
+        sample_rate = 16000
+        augmenter = Compose(
+            [TimeMask(min_band_part=0.2, max_band_part=0.5, fade=True, p=1.0)]
+        )
+
+        samples_out = augmenter(samples=samples_in, sample_rate=sample_rate)
+        self.assertEqual(samples_out.dtype, np.float32)
+        self.assertEqual(len(samples_out), sample_len)
+
+        std_in = np.mean(np.abs(samples_in))
+        std_out = np.mean(np.abs(samples_out))
+        self.assertLess(std_out, std_in)