Skip to content

Commit d5519a0

Browse files
Fix handling of output embeddings in TFSpeech2Text
This resolves an issue of output embeddings in `TFSpeech2TextModelTest.test_save_load_after_resize_token_embeddings` where resizing token embeddings caused the following error: ``` ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor. ```
1 parent 9902b11 commit d5519a0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,8 +1431,12 @@ def resize_token_embeddings(self, new_num_tokens: int) -> tf.Variable:
14311431
def get_output_embeddings(self):
14321432
return self.lm_head
14331433

1434-
def set_output_embeddings(self, new_embeddings):
1435-
self.lm_head = new_embeddings
1434+
def set_output_embeddings(self, value):
1435+
self.lm_head = keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
1436+
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
1437+
# value has a shape (num_tokens, dim) then needs to be transposed
1438+
transposed_value = tf.transpose(value)
1439+
self.lm_head.kernel = tf.Variable(transposed_value)
14361440

14371441
@unpack_inputs
14381442
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)

0 commit comments

Comments
 (0)