
Align audio array shape #45

Open · wants to merge 11 commits into `dev`
3 changes: 3 additions & 0 deletions README.md
@@ -59,6 +59,9 @@ If the speech XAI functionalities are needed, then follow these steps:
2. install whisperX with `pip install git+https://github.com/m-bain/whisperx.git`
3. install system-wide [ffmpeg](https://ffmpeg.org/download.html). If you have no sudo rights, you can try with `conda install conda-forge::ffmpeg`

### Testing
For detailed instructions on setting up your environment and running tests, please see our [Testing Guidelines](TESTING.md).


### Explain & Benchmark

47 changes: 47 additions & 0 deletions TESTING.md
@@ -0,0 +1,47 @@
# Testing

To ensure the quality and functionality of our code, we use automated tests.

## Installation
Install pytest from PyPI:
```bash
pip install pytest
```

## Running Tests
We use pytest for our tests. Below are the commands for running tests in different scopes:


### All Tests
Run all tests for both speech and text with:
```bash
pytest
```

### Specific Test Files
Run tests for text processing only:
```bash
pytest tests/test_text.py
```

Run tests for speech processing only:
```bash
pytest tests/test_speech.py
```

### Specific Test Methods
Run a specific test method by specifying the test file and method name
(replace `test_text.py` with the desired test file and `test_method_name` with the desired test method):
```bash
pytest tests/test_text.py::test_method_name
```

### Clear Cache
We use some caching in our tests. If you encounter issues that might be related to cached test results or configurations, you can clear the pytest cache with:
```bash
pytest --cache-clear
```
This command removes all items from the cache, ensuring that your next test run is completely clean.

## Troubleshooting Common Issues
If tests behave unexpectedly or fail after changes, clear the pytest cache or re-run the tests to check whether the issue persists. Always make sure your environment matches the configuration specified in our setup guidelines.
3 changes: 3 additions & 0 deletions ferret/benchmark_speech.py
@@ -157,6 +157,9 @@ def explain(
"""
Explain the prediction of the model.
Returns the importance of each segment in the audio.

Note: the `target_class` argument specifies the ID of the target
class.
"""
explainer_args = dict()
# TODO UNIFY THE INPUT FORMAT
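
For context, a minimal usage sketch of `explain`, mirroring the new `tests/test_speech.py` added below; the `target_class` integer is optional and defaults to the predicted class:

```python
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
from ferret import SpeechBenchmark

model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
benchmark = SpeechBenchmark(model, feature_extractor)

# `target_class` (an integer class ID) is optional; when omitted, the
# explanation is computed for the predicted class.
explanations = benchmark.explain(
    audio_path_or_array="tests/data/sample_audio.wav",
    current_sr=16000,
    methodology="LOO",
)
```
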
@@ -106,7 +106,7 @@ def compute_explanation(

# if word_timestamps is None:
# # Transcribe audio
-word_timestamps = audio.transcription
# word_timestamps = audio.transcription

# Compute gradient importance for each target label
# This also handles the multilabel scenario as for FSC
7 changes: 7 additions & 0 deletions ferret/explainers/explanation_speech/loo_speech_explainer.py
@@ -31,6 +31,9 @@ def remove_words(
- silence
- white noise
- pink noise

Note: all of the manipulations preserve the sample rate of the
input `audio`.
"""

## Load audio as pydub.AudioSegment
@@ -61,6 +64,8 @@ def compute_explanation(
) -> ExplanationSpeech:
"""
Computes the importance of each word in the audio.

The `target_class` argument should be an integer identifying the class ID.
"""

## Get modified audio by leaving a single word out and the words
@@ -86,6 +91,8 @@
targets = target_class

else:
# If no target class is passed, the explanation is computed for
# the predicted class.
if n_labels > 1:
# Multilabel scenario as for FSC
targets = [
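
As a rough illustration of the leave-one-out logic above (`predict_proba` is a hypothetical stand-in, not ferret's API): a word's importance is the drop in the target class probability once that word is removed.

```python
from ferret.explainers.explanation_speech.utils_removal import remove_word

def loo_importance(predict_proba, audio, words, target: int, removal="silence"):
    # Baseline probability of the target class on the unmodified audio.
    base = predict_proba(audio)[target]
    # Importance of each word = drop in the target probability once that
    # word is replaced by silence (or noise, or removed outright).
    return [
        base - predict_proba(remove_word(audio, w, removal_type=removal))[target]
        for w in words
    ]
```
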
28 changes: 22 additions & 6 deletions ferret/explainers/explanation_speech/utils_removal.py
@@ -57,6 +57,11 @@ def remove_word(audio, word, removal_type: str = "nothing"):
- white noise
- pink noise

WARNING: if `word["start"] * 1000 - a` is negative, the slice end counts
from the END of the audio (just as `l[:-10]` keeps all but the last 10
entries of a list `l`), so the "before word" segment would wrongly cover
almost the whole clip. Therefore, if the difference is negative, we
effectively use `a = 0`.

Args:
audio (pydub.AudioSegment): audio
word: word to remove with its start and end times
@@ -65,22 +70,33 @@

a, b = 100, 40

-before_word_audio = audio[: word["start"] * 1000 - a]
-after_word_audio = audio[word["end"] * 1000 + b :]
-word_duration = (word["end"] * 1000 - word["start"] * 1000) + a + b
# Convert from seconds (as returned by WhisperX) to milliseconds (as
# required to index PyDub `AudioSegment` objects).
word_start_ms = word["start"] * 1000
word_end_ms = word["end"] * 1000

# If the slice end would count from the end of the audio (the
# difference is negative), set the offset `a` to zero to avoid that.
if word_start_ms - a < 0:
    a = 0

before_word_audio = audio[:word_start_ms - a]
after_word_audio = audio[word_end_ms + b :]
word_duration = (word_end_ms - word_start_ms) + a + b

if removal_type == "nothing":
replace_word_audio = AudioSegment.empty()

elif removal_type == "silence":
replace_word_audio = AudioSegment.silent(duration=word_duration)

elif removal_type == "white noise":
-sound_path = (os.path.join(os.path.dirname(__file__), "white_noise.mp3"),)
sound_path = os.path.join(os.path.dirname(__file__), "white_noise.mp3")

replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration]

# display(audio_removed)
elif removal_type == "pink noise":
-sound_path = (os.path.join(os.path.dirname(__file__), "pink_noise.mp3"),)
sound_path = os.path.join(os.path.dirname(__file__), "pink_noise.mp3")
replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration]

audio_removed = before_word_audio + replace_word_audio + after_word_audio
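
To see the slicing pitfall the clamp above guards against, a small sketch (assuming any local `speech.wav`); pydub segments slice like Python lists, in milliseconds:

```python
from pydub import AudioSegment

audio = AudioSegment.from_wav("speech.wav")  # say, 3000 ms long

word_start_ms, a = 50, 100
end = word_start_ms - a      # -50: a negative index counts from the END
wrong = audio[:end]          # ~2950 ms, nearly the whole clip, not a prefix!
right = audio[:max(end, 0)]  # 0 ms, the intended (empty) leading segment
```
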
15 changes: 14 additions & 1 deletion ferret/modeling/speech_model_helpers/model_helper_er.py
@@ -54,8 +54,21 @@ def _predict(

## Predict logits
with torch.no_grad():
# Some feature extractors return the input tensor(s) under the
# `input_values` key, others under the `input_features` key.
if 'input_values' in inputs.keys():
input_features = inputs['input_values'].to(self.device)
elif 'input_features' in inputs.keys():
input_features = inputs['input_features'].to(self.device)
else:
raise Exception(
'Input features not found in the inputs dict under either'
' the `input_values` or the `input_features` key'
)

logits = (
-self.model(inputs.input_values.to(self.device))
self.model(input_features)
.logits.detach()
.cpu()
# .numpy()
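
For reference, a small sketch of why both keys must be handled; `openai/whisper-tiny` is only an illustrative example, not a model this repo uses:

```python
import numpy as np
from transformers import AutoFeatureExtractor

wave = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

# Wav2Vec2-style extractors emit raw waveforms under `input_values` ...
fe = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
print(fe(wave, sampling_rate=16000, return_tensors="pt").keys())

# ... while Whisper-style extractors emit log-mel features under `input_features`.
fe = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
print(fe(wave, sampling_rate=16000, return_tensors="pt").keys())
```
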
8 changes: 4 additions & 4 deletions ferret/speechxai_utils.py
@@ -28,7 +28,7 @@ def __init__(

if isinstance(audio_path_or_array, str):
self.array, self.current_sr = librosa.load(
-audio_path_or_array, sr=None, dtype=np.float32
audio_path_or_array, sr=None, dtype=np.float32, mono=True
)
elif isinstance(audio_path_or_array, np.ndarray):
if current_sr is None:
@@ -65,8 +65,8 @@ def resample(self, target_sr: int):
Resample the audio to the target sampling rate. In place operation.
"""
self.array = librosa.resample(
-self.array, orig_sr=self.current_sr, target_sr=target_sr
-)
self.array.ravel(), orig_sr=self.current_sr, target_sr=target_sr
).reshape(-1, 1)
self.current_sr = target_sr

@staticmethod
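
A quick sketch (synthetic data) of the shape round-trip above: `librosa.resample` wants time along the last axis, hence the `ravel()`, while the class appears to keep audio as an `(n_samples, 1)` column vector, hence the `reshape`:

```python
import librosa
import numpy as np

array = np.random.randn(16000, 1).astype(np.float32)  # 1 s at 16 kHz
resampled = librosa.resample(array.ravel(), orig_sr=16000, target_sr=8000)
array = resampled.reshape(-1, 1)  # back to the (n_samples, 1) convention
print(array.shape)  # (8000, 1)
```
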
@@ -130,7 +130,7 @@ def transcribe_audio(
## Load whisperx model. TODO: we should definitely avoid loading the model for *every* sample to transcribe

device_type = device.type
-device_index = device.index
device_index = device.index if device.index is not None else 0

model_whisperx = whisperx.load_model(
model_name_whisper,
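
The fallback to index `0` is needed because a bare device specification carries no index; a quick check:

```python
import torch

print(torch.device("cuda").index)    # None: a bare device spec has no index
print(torch.device("cuda:0").index)  # 0
```
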
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ joblib = "^1.3.2"
pytreebank = "^0.2.7"
thermostat-datasets = "^1.1.0"
ipython = "^8.22.2"
pytest = "^7.4.4"
# Speech-XAI additional requirements to allow for `pip install ferret[speech]`.
pydub = { version = "0.25.1", optional = true }
audiomentations = { version = "0.34.1", optional = true }
Binary file added tests/data/sample_audio.wav
3 changes: 3 additions & 0 deletions tests/test_explainers.py
@@ -197,3 +197,6 @@ def test_gradient_ner(self):
explanation = exp(text, target="I-LOC", target_token="York")
self.assertTrue("york" in [token.lower() for token in explanation.tokens])
self.assertEqual(explanation.target_pos_idx, 6)

if __name__ == '__main__':
unittest.main()
145 changes: 145 additions & 0 deletions tests/test_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pytest
import torch
import os
import numpy as np
import pandas as pd
from pydub import AudioSegment
from ferret import SpeechBenchmark
from ferret.explainers.explanation_speech.loo_speech_explainer import LOOSpeechExplainer
from ferret.explainers.explanation_speech.gradient_speech_explainer import (
GradientSpeechExplainer,
)
from ferret.explainers.explanation_speech.lime_speech_explainer import (
LIMESpeechExplainer,
)
from ferret.explainers.explanation_speech.paraling_speech_explainer import (
ParalinguisticSpeechExplainer,
)
from scipy.io.wavfile import write
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor



# ================================================================
# = Fixtures creating the audio sample used throughout the tests =
# ================================================================
@pytest.fixture(scope="module")
def sample_audio_file():
return os.path.join(os.path.dirname(__file__), 'data', 'sample_audio.wav')


@pytest.fixture(scope="module")
def benchmark():
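# superb/wav2vec2-base-superb-ic performs intent classification on
# Fluent Speech Commands: each utterance gets an action, an object, and
# a location label, hence the three outputs asserted in test_prediction.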
model = Wav2Vec2ForSequenceClassification.from_pretrained(
"superb/wav2vec2-base-superb-ic"
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"superb/wav2vec2-base-superb-ic"
)
return SpeechBenchmark(model, feature_extractor)

# ==========
# = Tests =
# ==========

def test_initialization_benchmark(benchmark):
assert benchmark.model is not None
assert benchmark.feature_extractor is not None
assert isinstance(benchmark, SpeechBenchmark)

def test_explainer_types(benchmark):
for explainer_name, explainer in benchmark.explainers.items():
assert explainer is not None
assert explainer_name in ['LOO', 'Gradient', 'GradientXInput', 'LIME', 'perturb_paraling']
assert isinstance(explainer, (LOOSpeechExplainer, GradientSpeechExplainer, LIMESpeechExplainer, ParalinguisticSpeechExplainer))


def test_audio_transcription(benchmark, sample_audio_file):
audio = AudioSegment.from_wav(sample_audio_file)
sr = audio.frame_rate
transcription = benchmark.transcribe(sample_audio_file, current_sr=sr)

assert transcription[0] is not None
assert transcription[0] == ' Turn up the bedroom heat.'

def test_prediction(benchmark, sample_audio_file):
audio = AudioSegment.from_wav(sample_audio_file)
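# Convert pydub's integer samples to float32 and peak-normalize to [-1, 1].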
audio_array = np.array(audio.get_array_of_samples()).astype(np.float32)
audio_array /= np.max(np.abs(audio_array))
predictions = benchmark.predict([audio_array])

assert predictions is not None
assert len(predictions) == 3
action_probs, object_probs, location_probs = benchmark.predict([audio_array])

assert len(action_probs) == 1
assert len(object_probs) == 1
assert len(location_probs) == 1
assert action_probs[0].shape == (6,)
assert object_probs[0].shape == (14,)
assert location_probs[0].shape == (4,)

@pytest.mark.parametrize("methodology", ["LOO", "Gradient", "LIME", "perturb_paraling"])
def test_explain_method(benchmark, sample_audio_file, methodology):
explanations = benchmark.explain(
audio_path_or_array=sample_audio_file,
current_sr=16000,
methodology=methodology,
)

assert explanations is not None

if methodology != "perturb_paraling":
assert hasattr(explanations, 'scores')
assert hasattr(explanations, 'features')
assert len(explanations.scores) > 0
assert len(explanations.features) > 0
else:
assert isinstance(explanations, list)
assert len(explanations) > 0
for explanation in explanations:
assert hasattr(explanation, 'scores')
assert hasattr(explanation, 'features')


def test_explain_features(benchmark, sample_audio_file):
explanations = benchmark.explain(
audio_path_or_array=sample_audio_file,
current_sr=16000,
methodology='LOO',
)

expected_features = ['Turn', 'up', 'the', 'bedroom', 'heat.']
assert explanations.features == expected_features

def test_invalid_audio_file(benchmark):
with pytest.raises(Exception):
benchmark.explain(
audio_path_or_array='non_existent_file.wav',
current_sr=16000,
methodology='LOO',
)

def test_silence_audio(benchmark):
silent_audio = np.zeros(int(16000 * 1)) # 1 second of silent audio at 16kHz
explanations = benchmark.explain(
audio_path_or_array=silent_audio,
current_sr=16000,
methodology='LOO',
)
assert explanations is not None
assert explanations.scores.shape == (3, 0)
assert len(explanations.features) == 0

def test_explain_variations(benchmark, sample_audio_file):
perturbation_types = ['time stretching', 'pitch shifting', 'noise']
variations_table = benchmark.explain_variations(
audio_path_or_array=sample_audio_file,
current_sr=16000,
perturbation_types=perturbation_types
)
assert isinstance(variations_table, dict)
assert all(pt in variations_table for pt in perturbation_types)
for pt, df in variations_table.items():
assert isinstance(df, pd.DataFrame)
assert not df.empty