
Commit e07585a

Fix handling of input and output embeddings in TFXGLM
1 parent 38f7b42 commit e07585a

2 files changed: 26 additions, 4 deletions


src/transformers/modeling_tf_utils.py

Lines changed: 12 additions & 0 deletions
@@ -2069,6 +2069,12 @@ def _get_word_embedding_weight(model, embedding_layer):
         return embedding_layer
     # Otherwise, try to get them from the layer's attributes
 
+    embeds = getattr(embedding_layer, "kernel", None)
+    if embeds is not None:
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        return tf.Variable(tf.transpose(embeds))
+
     embeds = getattr(embedding_layer, "weight", None)
     if embeds is not None:
         return embeds
@@ -2082,6 +2088,12 @@ def _get_word_embedding_weight(model, embedding_layer):
     # the argument after building the model
     model.build_in_name_scope()
 
+    embeds = getattr(embedding_layer, "kernel", None)
+    if embeds is not None:
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        return tf.Variable(tf.transpose(embeds))
+
     embeds = getattr(embedding_layer, "weight", None)
     if embeds is not None:
         return embeds
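
For reference, a minimal sketch (not part of the commit) of the shape convention the new kernel branch relies on: a Keras Dense layer used as an LM head stores its kernel as (dim, num_tokens), while embedding weights are stored as (num_tokens, dim), so the kernel has to be transposed before it can be returned as an embedding-style variable. The sizes below are illustrative only.

import tensorflow as tf

num_tokens, dim = 32, 8  # illustrative sizes, not taken from any real checkpoint
lm_head = tf.keras.layers.Dense(num_tokens, use_bias=False)
lm_head.build((None, dim))  # creates the kernel with shape (dim, num_tokens)

print(lm_head.kernel.shape)  # (8, 32)  -> (dim, num_tokens)
embedding_like = tf.Variable(tf.transpose(lm_head.kernel))
print(embedding_like.shape)  # (32, 8)  -> (num_tokens, dim), the embedding layout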

src/transformers/models/xglm/modeling_tf_xglm.py

Lines changed: 14 additions & 4 deletions
@@ -490,8 +490,9 @@ def __init__(
     def get_input_embeddings(self) -> TFSharedEmbeddings:
         return self.embed_tokens
 
-    def set_input_embeddings(self, value: TFSharedEmbeddings) -> None:
-        self.embed_tokens = value
+    def set_input_embeddings(self, value) -> None:
+        self.embed_tokens.vocab_size = value.shape[0]
+        self.embed_tokens.weight = value
 
     def _prepare_decoder_attention_mask(
         self,
@@ -888,8 +889,17 @@ def __init__(
     def get_output_embeddings(self):
         return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+    def set_output_embeddings(self, value):
+        self.lm_head = keras.layers.Dense(
+            shape_list(value)[0],
+            use_bias=False,
+            kernel_initializer=get_initializer(self.config.init_std),
+            name="lm_head",
+        )
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        transposed_value = tf.transpose(value)
+        self.lm_head.kernel = transposed_value
 
     def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
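
A hedged usage sketch (not part of the commit) of how these paths are typically exercised: resize_token_embeddings goes through the get/set input- and output-embedding methods, so after this change the rebuilt lm_head should hold the transposed embedding matrix. The checkpoint name and the new vocabulary size below are assumptions for illustration only.

from transformers import TFXGLMForCausalLM

model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")  # assumed checkpoint
model.resize_token_embeddings(model.config.vocab_size + 8)  # illustrative new size

# Input embeddings keep the (num_tokens, dim) layout ...
print(model.get_input_embeddings().weight.shape)
# ... while the rebuilt lm_head Dense stores its kernel as (dim, num_tokens).
print(model.get_output_embeddings().kernel.shape)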
