diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 243a29534ac0..89b520d819a9 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -1431,8 +1431,12 @@ def resize_token_embeddings(self, new_num_tokens: int) -> tf.Variable: def get_output_embeddings(self): return self.lm_head - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + def set_output_embeddings(self, value): + self.lm_head = keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head") + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + transposed_value = tf.transpose(value) + self.lm_head.kernel = tf.Variable(transposed_value) @unpack_inputs @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)