diff --git a/huggingface_pipelines/audio.py b/huggingface_pipelines/audio.py index 8df4931..8eb0bac 100644 --- a/huggingface_pipelines/audio.py +++ b/huggingface_pipelines/audio.py @@ -140,6 +140,8 @@ class HFAudioToEmbeddingPipeline(Pipeline): pipeline = HFAudioToEmbeddingPipeline(pipeline_config) """ + config: HFAudioToEmbeddingPipelineConfig + def __init__(self, config: HFAudioToEmbeddingPipelineConfig): """ Initialize the HFAudioToEmbeddingPipeline. @@ -238,14 +240,16 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: try: for column in self.config.columns: if column not in batch: - logger.warning(f"Column {column} not found in batch. Skipping.") + logger.warning( + f"Column {column} not found in batch. Skipping.") continue audio_inputs = self.collect_valid_audio_inputs(batch[column]) if not audio_inputs: - raise ValueError(f"No valid audio inputs found in column {column}/") + raise ValueError( + f"No valid audio inputs found in column {column}/") try: @@ -255,7 +259,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: batch_inputs = [ tensor.to(self.config.device) - for tensor in audio_inputs[i : i + self.config.batch_size] + for tensor in audio_inputs[i: i + self.config.batch_size] ] batch_embeddings = self.model.predict(