diff --git a/whisperx/asr.py b/whisperx/asr.py index b23a54dc..ba6220bd 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -251,7 +251,10 @@ def data(audio, segments): def detect_language(self, audio: np.ndarray): - segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0) + if audio.shape[0] < N_SAMPLES: + print("Warning: audio is shorter than 30s, language detection may be inaccurate.") + segment = log_mel_spectrogram(audio[: N_SAMPLES], + padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) encoder_output = self.model.encode(segment) results = self.model.model.detect_language(encoder_output) language_token, language_probability = results[0][0]