Skip to content

Commit 76b2e4e

Browse files
Add support for keras.layers.Embedding in _get_word_embedding_weight
This resolves two issues:

1. An issue with input embeddings in `TFT5ModelTest.test_resize_token_embeddings` and `TFT5ModelTest.test_save_load_after_resize_token_embeddings` when `model.config.tie_word_embeddings` is set to `False`:
```
ValueError: Attempt to convert a value (<tf_keras.src.layers.core.embedding.Embedding object at 0x32b0747a0>) with an unsupported type (<class 'tf_keras.src.layers.core.embedding.Embedding'>) to a Tensor.
```

2. An issue with input embeddings in `TFMistralModelTest.test_resize_token_embeddings`, where resizing token embeddings caused the following error:
```
ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
```
1 parent 3762df4 commit 76b2e4e

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/transformers/modeling_tf_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,6 +2083,10 @@ def _get_word_embedding_weight(model, embedding_layer):
20832083
# value has a shape (num_tokens, dim) then needs to be transposed
20842084
return tf.Variable(tf.transpose(embeds))
20852085

2086+
embeds = getattr(embedding_layer, "weights", None)
2087+
if embeds is not None and len(embeds) > 0:
2088+
return embeds[0]
2089+
20862090
# The reason why the attributes don't exist might be
20872091
# because the model is not built, so retry getting
20882092
# the argument after building the model
@@ -2102,6 +2106,10 @@ def _get_word_embedding_weight(model, embedding_layer):
21022106
# value has a shape (num_tokens, dim) then needs to be transposed
21032107
return tf.Variable(tf.transpose(embeds))
21042108

2109+
embeds = getattr(embedding_layer, "weights", None)
2110+
if embeds is not None and len(embeds) > 0:
2111+
return embeds[0]
2112+
21052113
return None
21062114

21072115
def _resize_token_embeddings(self, new_num_tokens):

src/transformers/models/mistral/modeling_tf_mistral.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -824,8 +824,17 @@ def set_input_embeddings(self, value):
824824
def get_output_embeddings(self):
825825
return self.lm_head
826826

827-
def set_output_embeddings(self, new_embeddings):
828-
self.lm_head = new_embeddings
827+
def set_output_embeddings(self, value):
828+
self.lm_head = keras.layers.Dense(
829+
shape_list(value)[0],
830+
use_bias=False,
831+
kernel_initializer=get_initializer(self.config.initializer_range),
832+
name="lm_head",
833+
)
834+
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
835+
# value has a shape (num_tokens, dim) then needs to be transposed
836+
transposed_value = tf.transpose(value)
837+
self.lm_head.kernel = tf.Variable(transposed_value)
829838

830839
def set_decoder(self, decoder):
831840
self.model = decoder

0 commit comments

Comments
 (0)