Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,15 @@ def forward(
):
r"""
Args:
attention_mask (`torch.Tensor`)`, *optional*):
Qwen2Audio does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
`numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`), *optional*):
attention mask used in the encoder stack (after the convolutional layers).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
Expand Down Expand Up @@ -765,7 +771,7 @@ def forward(
feature_attention_mask.sum(-1)
)
batch_size, _, max_mel_seq_len = input_features.shape
max_seq_len = (max_mel_seq_len - 2) // 2 + 1
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (
torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)
Expand Down
Loading