From 8c5bf17624fb8eca1f23af3c6e45cafa1ddabaef Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Fri, 6 Sep 2024 12:01:25 +0100
Subject: [PATCH] #4 Implement AudioToTextPipeline (#34)

* Add speech pipeline to get text from audio
* Deprecate remove columns from pipeline
* Modify speech pipeline to better interface with SONAR
* Refactor to decompose Speech pipelines into SpeechToEmbedding and just use EmbeddingToText afterwards
* Add interface documentation to docs for speech to embedding pipeline
* Add unit tests for speech to embeddings pipeline
* Add optional dependencies for huggingface using tag
* Fix syntax errors
* Remove commas between configs
* Improve audio pipeline naming conventions to use Audio
* Keep tensors in memory to avoid writing to temp files
* Fix linting issues
* Fix black linter issues
* Create factory for audio to embedding pipeline
* Fix linting issues
* Implement functionality to cast audio column to config sampling rate
* Update tests to not raise for list[str]
* Fix linting issues
* Fix linting issues
* Add more specific typing
* Reformat audio.py file with black
* Type types in unit tests
* Add dtype to audio pipeline
* Switch from any to Any
* Improve typing for tests
* Fix linting issues
* Add compatibility for Python 3.10 using Union
* Fix linting issues
* Add more robust tests
* Add batching to audio pipeline
* Fix linting issues
* Add docs for huggingface method
* Do not move whole tensor at once in case it is too big
* Move batch inputs to device incrementally
* Fix linting issues
* Fix linting issues
* Remove tensor cast
* Cast every tensor in the batch rather than the batch object
* Fix child pipeline config initialization
* Fix linting issues
---
 huggingface_pipelines/audio.py                | 330 ++++++++++++++++++
 pyproject.toml                                |   1 +
 .../unit_tests/huggingface_pipelines/audio.py | 214 ++++++++++++
 3 files changed, 545 insertions(+)
 create mode 100644 huggingface_pipelines/audio.py
 create mode 100644 tests/unit_tests/huggingface_pipelines/audio.py

diff --git a/huggingface_pipelines/audio.py b/huggingface_pipelines/audio.py
new file mode 100644
index 0000000..3341653
--- /dev/null
+++ b/huggingface_pipelines/audio.py
@@ -0,0 +1,330 @@
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+from datasets import Audio  # type: ignore
+from numpy.typing import DTypeLike  # type: ignore
+
+from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline
+
+from .dataset import DatasetConfig  # type: ignore
+from .pipeline import Pipeline, PipelineConfig  # type: ignore
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioDatasetConfig(DatasetConfig):
+    """
+    Configuration for audio datasets.
+    This class inherits from DatasetConfig and includes
+    audio-specific attributes and processing.
+
+    Attributes:
+        sampling_rate (int): The target sampling rate for audio data.
+        audio_column (str): The column that contains the audio data.
+
+    Example:
+        dataset_config = AudioDatasetConfig(
+            dataset_name="librispeech_asr",
+            dataset_split="train.clean.100",
+            output_dir="/path/to/output",
+            config="clean",
+            trust_remote_code=True,
+            sampling_rate=16000,
+            audio_column="audio"
+        )
+    """
+
+    sampling_rate: int = 16000
+    audio_column: str = "audio"
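+
+    # Note: casting the audio column with datasets.Audio(sampling_rate=...) in
+    # process_audio_column below means the `datasets` library decodes and
+    # resamples each example lazily on access, e.g.:
+    #     dataset[0]["audio"]  # -> {"array": np.ndarray, "sampling_rate": 16000, ...}
+    # (Explanatory comment: this lazy decode/resample is standard `datasets`
+    # behavior, not logic added by this class.)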
+
+    def load_dataset(self):
+        """
+        Loads and optionally shards the dataset based on the configuration settings.
+        This method extends the base load_dataset method to include audio-specific processing.
+
+        Returns:
+            datasets.Dataset: The loaded, potentially sharded, and audio-processed dataset.
+
+        Raises:
+            ValueError: If the dataset cannot be loaded with the given configuration.
+            ImportError: If the 'datasets' library is not installed.
+        """
+        dataset = super().load_dataset()
+        return self.process_audio_column(dataset)
+
+    def process_audio_column(self, dataset):
+        """
+        Processes the audio column of the dataset.
+        This modifies the original column.
+
+        Args:
+            dataset (datasets.Dataset): The loaded dataset.
+
+        Returns:
+            datasets.Dataset: The dataset with processed audio column.
+
+        Raises:
+            ValueError: If the configured audio column is not found in the dataset.
+        """
+        if self.audio_column in dataset.column_names:
+            dataset = dataset.cast_column(
+                self.audio_column, Audio(sampling_rate=self.sampling_rate)
+            )
+        else:
+            raise ValueError(
+                f"Column '{self.audio_column}' not found in the dataset; cannot process audio."
+            )
+
+        return dataset
+
+
+@dataclass
+class HFAudioToEmbeddingPipelineConfig(PipelineConfig):
+    """
+    Configuration class for HFAudioToEmbeddingPipeline.
+
+    Attributes:
+        encoder_model (str): The name or path of the encoder model to use.
+        fbank_dtype (torch.dtype): The dtype for the fbank features. Defaults to torch.float32.
+        n_parallel (int): Number of parallel processes for audio processing. Defaults to 4.
+        pad_idx (int): The index used for padding. Defaults to 0.
+        dtype (DTypeLike): The data type of the output numpy embeddings. Defaults to np.float32.
+
+    Example:
+        pipeline_config = HFAudioToEmbeddingPipelineConfig(
+            encoder_model="sonar_speech_encoder_large",
+            fbank_dtype=torch.float16,
+            n_parallel=4,
+            pad_idx=0,
+            device="cuda",
+            batch_size=32,
+            columns=["audio", "audio2"],
+            output_path="/path/to/output",
+            output_column_suffix="embedding"
+        )
+    """
+
+    encoder_model: str = "sonar_speech_encoder"
+    fbank_dtype: torch.dtype = torch.float32
+    n_parallel: int = 4
+    pad_idx: int = 0
+    dtype: DTypeLike = np.float32
+
+
+class HFAudioToEmbeddingPipeline(Pipeline):
+    """
+    A pipeline for converting audio to embeddings using a HuggingFace model.
+
+    This pipeline processes batches of audio data, converting them to embeddings
+    using a specified encoder model. Audio tensors are kept in memory throughout
+    (no temporary files are written), and embedding shapes are kept consistent
+    across the batch.
+
+    Attributes:
+        config (HFAudioToEmbeddingPipelineConfig): The configuration for this pipeline.
+        model (SpeechToEmbeddingModelPipeline): The underlying model used for embedding generation.
+
+    Example:
+        pipeline_config = HFAudioToEmbeddingPipelineConfig(
+            encoder_model="sonar_speech_encoder",
+            device="cuda",
+            batch_size=16,
+            n_parallel=4,
+            pad_idx=0,
+            dtype=np.float32
+        )
+
+        pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    """
+
+    config: HFAudioToEmbeddingPipelineConfig
+
+    def __init__(self, config: HFAudioToEmbeddingPipelineConfig):
+        """
+        Initialize the HFAudioToEmbeddingPipeline.
+
+        Args:
+            config (HFAudioToEmbeddingPipelineConfig): The configuration for this pipeline.
+        """
+        super().__init__(config)
+        self.model = SpeechToEmbeddingModelPipeline(
+            encoder=self.config.encoder_model,
+            device=torch.device(self.config.device),
+            fbank_dtype=self.config.fbank_dtype,
+        )
+
+    def collect_valid_audio_inputs(
+        self, audio_data_list: List[Dict[str, Any]]
+    ) -> List[torch.Tensor]:
+        """
+        Collect and process valid audio inputs from a list of audio data dictionaries.
+
+        This method processes a list of audio data dictionaries, extracting valid audio inputs
+        and converting them to PyTorch tensors. It handles multi-channel audio by taking the
+        mean across channels and ensures that the output tensors are 2D with shape (1, num_samples).
+
+        Args:
+            audio_data_list (List[Dict[str, Any]]): A list of dictionaries containing audio data.
+                Each dictionary is expected to have 'array' and 'sampling_rate' keys.
+
+        Returns:
+            List[torch.Tensor]: A list of valid audio inputs as PyTorch tensors.
+
+        Raises:
+            ValueError: If the input is not a list, if any audio data has an invalid format,
+                or if the resulting tensor has an unexpected shape.
+        """
+        audio_inputs = []
+
+        # Ensure audio_data_list is always a list
+        if not isinstance(audio_data_list, list):
+            raise ValueError("Audio data must be in list format.")
+
+        for audio_data in audio_data_list:
+            if (
+                isinstance(audio_data, dict)
+                and "array" in audio_data
+                and "sampling_rate" in audio_data
+            ):
+                # Handle multi-channel audio by taking the mean across channels
+                audio_array = audio_data["array"]
+                if audio_array.ndim > 1:
+                    audio_array = np.mean(audio_array, axis=0)
+
+                # Convert numpy array to torch tensor
+                audio_tensor = torch.from_numpy(audio_array).float()
+
+                # Ensure the tensor is 2D with shape (1, num_samples)
+                if audio_tensor.dim() == 1:
+                    audio_tensor = audio_tensor.unsqueeze(0)
+                elif audio_tensor.dim() > 2:
+                    raise ValueError(
+                        f"Unexpected audio tensor shape: {audio_tensor.shape}"
+                    )
+
+                audio_inputs.append(audio_tensor)
+            else:
+                logger.error(f"Invalid audio data format in batch: {audio_data}")
+                raise ValueError(f"Invalid audio data format in column: {audio_data}")
+
+        return audio_inputs
+
+    def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Process a batch of audio data, converting it to embeddings.
+
+        This method keeps the audio tensors in memory, moves them to the target
+        device in sub-batches, generates embeddings using the model, and ensures
+        consistent embedding shapes across the batch.
+
+        Args:
+            batch (Dict[str, Any]): A dictionary containing the batch data. Each
+                configured column is expected to map to a list of audio data dictionaries.
+
+        Returns:
+            Dict[str, Any]: The input batch dictionary with an additional key
+            '{column}_{suffix}' containing the generated embeddings.
+
+        Raises:
+            ValueError: If there's an error during batch processing or embedding generation.
+        """
+        try:
+            for column in self.config.columns:
+                if column not in batch:
+                    logger.warning(f"Column {column} not found in batch. Skipping.")
Skipping.") + continue + + audio_inputs = self.collect_valid_audio_inputs(batch[column]) + + if not audio_inputs: + + raise ValueError(f"No valid audio inputs found in column {column}/") + + try: + + audio_embeddings: List[np.ndarray] = [] + + for i in range(0, len(audio_inputs), self.config.batch_size): + + batch_inputs = [ + tensor.to(self.config.device) + for tensor in audio_inputs[i : i + self.config.batch_size] + ] + + batch_embeddings = self.model.predict( + input=batch_inputs, + batch_size=self.config.batch_size, + n_parallel=self.config.n_parallel, + pad_idx=self.config.pad_idx, + ) + + final_embeddings: np.ndarray = ( + batch_embeddings.detach() + .cpu() + .numpy() + .astype(self.config.dtype) + ) + + audio_embeddings.extend(final_embeddings) + + batch[f"{column}_{self.config.output_column_suffix}"] = np.array( + audio_embeddings + ) + + except Exception as e: + logger.error( + f"Error in model.predict for column {column}: {str(e)}" + ) + raise ValueError( + f"Error in model.predict for column {column}: {str(e)}" + ) + + except Exception as e: + logger.error(f"Error processing batch: {str(e)}") + logger.error(f"Batch content: {batch}") + raise ValueError(f"Error processing batch: {str(e)}") + + return batch + + +class AudioToEmbeddingPipelineFactory: + """ + Factory class for creating AudioToEmbedding pipelines. + + This factory creates HFAudioToEmbeddingPipeline instances based on the provided configuration. + + Example: + factory = AudioToEmbeddingPipelineFactory() + config = { + "encoder_model": "sonar_speech_encoder_large", + "fbank_dtype": torch.float16, + "n_parallel": 4, + "pad_idx": 0, + "audio_column": "audio", + "device": "cuda", + "batch_size": 32, + "columns": ["audio"], + "output_path": "/path/to/output", + "output_column_suffix": "embedding" + } + pipeline = factory.create_pipeline(config) + """ + + def create_pipeline(self, config: Dict[str, Any]) -> Pipeline: + """ + Create an AudioToEmbedding pipeline based on the provided configuration. + + Returns: + Pipeline: An instance of HFAudioToEmbeddingPipeline. 
+        """
+        pipeline_config = HFAudioToEmbeddingPipelineConfig(**config)
+        return HFAudioToEmbeddingPipeline(pipeline_config)
diff --git a/pyproject.toml b/pyproject.toml
index c65329b..2078947 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@ dev = [
     "pylint>=2.8.0",
 ]
 
+
 hg = [
     "transformers>=4.44.0",
     "datasets>=2.20.0",
diff --git a/tests/unit_tests/huggingface_pipelines/audio.py b/tests/unit_tests/huggingface_pipelines/audio.py
new file mode 100644
index 0000000..e90cca6
--- /dev/null
+++ b/tests/unit_tests/huggingface_pipelines/audio.py
@@ -0,0 +1,214 @@
+from typing import Any, Dict
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import torch
+
+from huggingface_pipelines.audio import (  # type: ignore
+    HFAudioToEmbeddingPipeline,
+    HFAudioToEmbeddingPipelineConfig,
+    SpeechToEmbeddingModelPipeline,
+)
+
+
+@pytest.fixture
+def pipeline_config():
+    return HFAudioToEmbeddingPipelineConfig(
+        encoder_model="sonar_speech_encoder_eng",
+        device="cpu",
+        batch_size=2,
+        columns=["audio"],
+        output_path="test",
+        output_column_suffix="embedding",
+    )
+
+
+@pytest.fixture
+def sample_audio_data():
+    return {"array": np.random.rand(16000), "sampling_rate": 16000}
+
+
+@pytest.fixture
+def complex_audio_data():
+    return {
+        "short_audio": {"array": np.random.rand(8000), "sampling_rate": 16000},
+        "long_audio": {"array": np.random.rand(32000), "sampling_rate": 16000},
+        "multi_channel": {"array": np.random.rand(2, 16000), "sampling_rate": 16000},
+        "high_sample_rate": {"array": np.random.rand(48000), "sampling_rate": 48000},
+        "low_sample_rate": {"array": np.random.rand(8000), "sampling_rate": 8000},
+        "float64_audio": {
+            "array": np.random.rand(16000).astype(np.float64),
+            "sampling_rate": 16000,
+        },
+        "int16_audio": {
+            "array": (np.random.rand(16000) * 32767).astype(np.int16),
+            "sampling_rate": 16000,
+        },
+    }
+
+
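+# The tests below request a `mock_speech_to_embedding_model` fixture that is
+# not defined in this file. If it is not already supplied by a shared
+# conftest.py, the following minimal stand-in (an assumed sketch, not taken
+# from the original patch) makes the file self-contained: autospec keeps the
+# isinstance check in test_pipeline_initialization passing, and predict
+# returns one 3-dimensional embedding per input, matching the shapes the
+# assertions below expect.
+@pytest.fixture
+def mock_speech_to_embedding_model():
+    with patch(
+        "huggingface_pipelines.audio.SpeechToEmbeddingModelPipeline",
+        autospec=True,
+    ) as mock_model_cls:
+        mock_model_cls.return_value.predict.side_effect = (
+            lambda input, **kwargs: torch.zeros(len(input), 3)
+        )
+        yield mock_model_cls
+
+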
+def test_pipeline_initialization(pipeline_config, mock_speech_to_embedding_model):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    assert pipeline.config == pipeline_config
+    assert isinstance(pipeline.model, SpeechToEmbeddingModelPipeline)
+
+
+def test_process_batch_valid_input(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"audio": [sample_audio_data, sample_audio_data]}
+    result = pipeline.process_batch(batch)
+    assert "audio_embedding" in result
+    assert isinstance(result["audio_embedding"], np.ndarray)
+    assert result["audio_embedding"].shape == (
+        2,
+        3,
+    )  # 2 samples, 3 embedding dimensions
+
+
+def test_process_batch_empty_input(pipeline_config, mock_speech_to_embedding_model):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch: Dict[str, Any] = {"audio": []}
+    with pytest.raises(ValueError, match="No valid audio inputs found in column audio"):
+        pipeline.process_batch(batch)
+
+
+def test_process_batch_invalid_audio_data(
+    pipeline_config, mock_speech_to_embedding_model
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"audio": [{"invalid": "data"}]}
+    with pytest.raises(ValueError, match="Invalid audio data format in column"):
+        pipeline.process_batch(batch)
+
+
+def test_process_batch_mixed_valid_invalid_data(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"audio": [sample_audio_data, {"invalid": "data"}, sample_audio_data]}
+    with pytest.raises(ValueError, match="Invalid audio data format in column"):
+        pipeline.process_batch(batch)
+
+
+@patch("huggingface_pipelines.audio.SpeechToEmbeddingModelPipeline")
+def test_error_handling_in_model_predict(
+    mock_model, pipeline_config, sample_audio_data
+):
+    mock_model.return_value.predict.side_effect = Exception("Model prediction error")
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"audio": [sample_audio_data]}
+    with pytest.raises(
+        ValueError,
+        match="Error in model.predict for column audio: Model prediction error",
+    ):
+        pipeline.process_batch(batch)
+
+
+def test_process_large_batch(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    large_batch = {"audio": [sample_audio_data] * 100}  # 100 audio samples
+    result = pipeline.process_batch(large_batch)
+    assert "audio_embedding" in result
+    assert isinstance(result["audio_embedding"], np.ndarray)
+    assert result["audio_embedding"].shape == (
+        100,
+        3,
+    )  # 100 samples, 3 embedding dimensions
+
+
+def test_collect_valid_audio_inputs(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    audio_data_list = [sample_audio_data, sample_audio_data]
+    result = pipeline.collect_valid_audio_inputs(audio_data_list)
+    assert len(result) == 2
+    assert all(isinstance(tensor, torch.Tensor) for tensor in result)
+    assert all(tensor.shape == (1, 16000) for tensor in result)
+
+
+def test_collect_valid_audio_inputs_invalid_data(
+    pipeline_config, mock_speech_to_embedding_model
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    invalid_data = [{"invalid": "data"}]
+    with pytest.raises(ValueError, match="Invalid audio data format in column"):
+        pipeline.collect_valid_audio_inputs(invalid_data)
+
+
+def test_collect_valid_audio_inputs_multi_channel(
+    pipeline_config, mock_speech_to_embedding_model
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    multi_channel_data = {"array": np.random.rand(2, 16000), "sampling_rate": 16000}
+    result = pipeline.collect_valid_audio_inputs([multi_channel_data])
+    assert len(result) == 1
+    assert result[0].shape == (1, 16000)
+
+
+def test_process_complex_audio_data(
+    pipeline_config, mock_speech_to_embedding_model, complex_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"audio": list(complex_audio_data.values())}
+    result = pipeline.process_batch(batch)
+    assert "audio_embedding" in result
+    assert isinstance(result["audio_embedding"], np.ndarray)
+    assert result["audio_embedding"].shape == (len(complex_audio_data), 3)
+
+
+def test_collect_valid_audio_inputs_complex(
+    pipeline_config, mock_speech_to_embedding_model, complex_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    result = pipeline.collect_valid_audio_inputs(list(complex_audio_data.values()))
+    assert len(result) == len(complex_audio_data)
+    assert all(isinstance(tensor, torch.Tensor) for tensor in result)
+    assert all(tensor.dim() == 2 and tensor.size(0) == 1 for tensor in result)
+
+
+def test_process_batch_with_missing_column(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
+    batch = {"wrong_column": [sample_audio_data]}
+    result = pipeline.process_batch(batch)
+    assert "audio_embedding" not in result
+
+
+def test_process_batch_with_multiple_columns(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    config = HFAudioToEmbeddingPipelineConfig(
+        **{**pipeline_config.__dict__, "columns": ["audio1", "audio2"]}
+    )
+    pipeline = HFAudioToEmbeddingPipeline(config)
+    batch = {"audio1": [sample_audio_data], "audio2": [sample_audio_data]}
+    result = pipeline.process_batch(batch)
+    assert "audio1_embedding" in result
+    assert "audio2_embedding" in result
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_handling(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data, device
+):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    config = HFAudioToEmbeddingPipelineConfig(
+        **{**pipeline_config.__dict__, "device": device}
+    )
+    pipeline = HFAudioToEmbeddingPipeline(config)
+    batch = {"audio": [sample_audio_data]}
+    result = pipeline.process_batch(batch)
+    assert "audio_embedding" in result
+
+
+def test_batch_size_handling(
+    pipeline_config, mock_speech_to_embedding_model, sample_audio_data
+):
+    config = HFAudioToEmbeddingPipelineConfig(
+        **{**pipeline_config.__dict__, "batch_size": 1}
+    )
+    pipeline = HFAudioToEmbeddingPipeline(config)
+    batch = {"audio": [sample_audio_data] * 5}
+    result = pipeline.process_batch(batch)
+    assert "audio_embedding" in result
+    assert result["audio_embedding"].shape == (5, 3)
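
Usage sketch: the new config, factory, and pipeline compose end to end roughly as
below. This is a minimal sketch, assuming the base DatasetConfig loader returns a
datasets.Dataset (per the docstring above) and that process_batch is applied via
datasets.Dataset.map(batched=True), which matches its dict-of-lists signature;
the dataset name, output paths, and device choice are illustrative.

    import numpy as np
    import torch

    from huggingface_pipelines.audio import (
        AudioDatasetConfig,
        AudioToEmbeddingPipelineFactory,
    )

    # Load an audio dataset and cast its audio column to 16 kHz.
    dataset_config = AudioDatasetConfig(
        dataset_name="librispeech_asr",
        dataset_split="train.clean.100",
        output_dir="/path/to/output",
        config="clean",
        trust_remote_code=True,
        sampling_rate=16000,
        audio_column="audio",
    )
    dataset = dataset_config.load_dataset()

    # Build the audio-to-embedding pipeline via the factory.
    factory = AudioToEmbeddingPipelineFactory()
    pipeline = factory.create_pipeline(
        {
            "encoder_model": "sonar_speech_encoder",
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "batch_size": 16,
            "columns": ["audio"],
            "output_path": "/path/to/output",
            "output_column_suffix": "embedding",
            "dtype": np.float32,
        }
    )

    # process_batch maps {column: [audio dicts]} to {column + "_embedding": np.ndarray},
    # so it can be applied over the dataset in batches.
    dataset = dataset.map(pipeline.process_batch, batched=True, batch_size=16)
    # Embeddings are now available in dataset["audio_embedding"].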