Skip to content

Commit

Permalink
Fix handling of output embeddings in TFSpeech2Text
Browse files Browse the repository at this point in the history
This resolves an issue of output embeddings in
`TFSpeech2TextModelTest.test_save_load_after_resize_token_embeddings`
where resizing token embeddings caused the following error:

```
ValueError: Attempt to convert a value (None) with an unsupported type
(<class 'NoneType'>) to a Tensor.
```
  • Loading branch information
damianoamatruda committed Jan 28, 2025
1 parent 7aba408 commit a41faf1
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1431,8 +1431,12 @@ def resize_token_embeddings(self, new_num_tokens: int) -> tf.Variable:
def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_output_embeddings(self, value):
self.lm_head = keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) then needs to be transposed
transposed_value = tf.transpose(value)
self.lm_head.kernel = tf.Variable(transposed_value)

@unpack_inputs
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
Expand Down

0 comments on commit a41faf1

Please sign in to comment.