Add support for keras.layers.Embedding in _get_word_embedding_weight
This resolves:

- An issue with the input embeddings in
  `TFT5ModelTest.test_resize_token_embeddings` and
  `TFT5ModelTest.test_save_load_after_resize_token_embeddings` when
  `model.config.tie_word_embeddings` is set to `False`:

  ```
  ValueError: Attempt to convert a value
  (<tf_keras.src.layers.core.embedding.Embedding object at 0x32b0747a0>)
  with an unsupported type
  (<class 'tf_keras.src.layers.core.embedding.Embedding'>) to a Tensor.
  ```

- An issue with the input embeddings in
  `TFMistralModelTest.test_resize_token_embeddings`, where resizing the
  token embeddings raised the following error:

  ```
  ValueError: Attempt to convert a value (None) with an unsupported type
  (<class 'NoneType'>) to a Tensor.
  ```
damianoamatruda committed Jan 28, 2025
1 parent a41faf1 commit fbf492c
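
The first error above occurs because `_get_word_embedding_weight` may receive the `keras.layers.Embedding` layer itself rather than a plain weight tensor. A minimal sketch of the mechanism the fix relies on (the layer and its sizes are illustrative, not taken from the diff):

```python
import tensorflow as tf

# An Embedding *layer* is not convertible to a tensor; converting it directly
# raises the first ValueError quoted above.
layer = tf.keras.layers.Embedding(input_dim=32, output_dim=8)
layer.build((None,))  # creates the (32, 8) embedding matrix

# tf.convert_to_tensor(layer)  # ValueError: unsupported type ... Embedding

# The fix reads the layer's first weight instead: for a built Embedding,
# `weights[0]` is the (num_tokens, dim) matrix.
matrix = layer.weights[0]
print(matrix.shape)  # (32, 8)
```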
Showing 2 changed files with 19 additions and 2 deletions.
src/transformers/modeling_tf_utils.py (8 additions, 0 deletions)
```diff
@@ -2083,6 +2083,10 @@ def _get_word_embedding_weight(model, embedding_layer):
             # value has a shape (num_tokens, dim) then needs to be transposed
             return tf.Variable(tf.transpose(embeds))
 
+        embeds = getattr(embedding_layer, "weights", None)
+        if embeds is not None and len(embeds) > 0:
+            return embeds[0]
+
         # The reason why the attributes don't exist might be
         # because the model is not built, so retry getting
         # the argument after building the model
@@ -2102,6 +2106,10 @@ def _get_word_embedding_weight(model, embedding_layer):
             # value has a shape (num_tokens, dim) then needs to be transposed
             return tf.Variable(tf.transpose(embeds))
 
+        embeds = getattr(embedding_layer, "weights", None)
+        if embeds is not None and len(embeds) > 0:
+            return embeds[0]
+
         return None
 
     def _resize_token_embeddings(self, new_num_tokens):
```
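
A note on the new branch: a Keras layer's `weights` attribute exists even before the layer is built, but it is an empty list, which is why the guard checks `len(embeds) > 0` instead of only `is not None`. An unbuilt layer therefore falls through to the build-and-retry path. A quick illustration (hypothetical layer, not from the patch):

```python
import tensorflow as tf

layer = tf.keras.layers.Embedding(input_dim=32, output_dim=8)
print(layer.weights)  # [] -> empty list, so the new branch falls through

layer.build((None,))  # building creates the embedding matrix
print(layer.weights[0].shape)  # (32, 8) -> the branch now returns weights[0]
```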
src/transformers/models/mistral/modeling_tf_mistral.py (11 additions, 2 deletions)
```diff
@@ -824,8 +824,17 @@ def set_input_embeddings(self, value):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+    def set_output_embeddings(self, value):
+        self.lm_head = keras.layers.Dense(
+            shape_list(value)[0],
+            use_bias=False,
+            kernel_initializer=get_initializer(self.config.initializer_range),
+            name="lm_head",
+        )
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        transposed_value = tf.transpose(value)
+        self.lm_head.kernel = tf.Variable(transposed_value)
 
     def set_decoder(self, decoder):
         self.model = decoder
```
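
The transpose mirrors how Keras `Dense` stores its kernel: `(last_dim, units)`, here `(dim, num_tokens)`, whereas the embedding matrix passed to `set_output_embeddings` is `(num_tokens, dim)`. A standalone sketch of that shape relationship under Keras 2 / `tf_keras`, where `kernel` is a plain attribute as the patch assumes (all names and sizes are made up for illustration):

```python
import tensorflow as tf

num_tokens, dim = 32, 8
value = tf.random.normal((num_tokens, dim))  # new output embeddings

# A Dense head with num_tokens units stores its kernel as (dim, num_tokens),
# so the incoming matrix must be transposed before it can serve as the kernel.
lm_head = tf.keras.layers.Dense(num_tokens, use_bias=False)
lm_head.build((None, dim))                         # kernel shape: (dim, num_tokens)
lm_head.kernel = tf.Variable(tf.transpose(value))  # matches the kernel layout

hidden = tf.random.normal((1, dim))
logits = lm_head(hidden)
print(logits.shape)  # (1, 32) -> one logit per token
```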
