diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 232ebd205824..f0e4fd36df34 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -2083,6 +2083,10 @@ def _get_word_embedding_weight(model, embedding_layer):
             # value has a shape (num_tokens, dim) then needs to be transposed
             return tf.Variable(tf.transpose(embeds))
 
+        embeds = getattr(embedding_layer, "weights", None)
+        if embeds is not None and len(embeds) > 0:
+            return embeds[0]
+
         # The reason why the attributes don't exist might be
         # because the model is not built, so retry getting
         # the argument after building the model
@@ -2102,6 +2106,10 @@ def _get_word_embedding_weight(model, embedding_layer):
             # value has a shape (num_tokens, dim) then needs to be transposed
             return tf.Variable(tf.transpose(embeds))
 
+        embeds = getattr(embedding_layer, "weights", None)
+        if embeds is not None and len(embeds) > 0:
+            return embeds[0]
+
         return None
 
     def _resize_token_embeddings(self, new_num_tokens):
diff --git a/src/transformers/models/mistral/modeling_tf_mistral.py b/src/transformers/models/mistral/modeling_tf_mistral.py
index 5c21dd3c3f53..71056f62226d 100644
--- a/src/transformers/models/mistral/modeling_tf_mistral.py
+++ b/src/transformers/models/mistral/modeling_tf_mistral.py
@@ -824,8 +824,17 @@ def set_input_embeddings(self, value):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+    def set_output_embeddings(self, value):
+        self.lm_head = keras.layers.Dense(
+            shape_list(value)[0],
+            use_bias=False,
+            kernel_initializer=get_initializer(self.config.initializer_range),
+            name="lm_head",
+        )
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        transposed_value = tf.transpose(value)
+        self.lm_head.kernel = tf.Variable(transposed_value)
 
     def set_decoder(self, decoder):
         self.model = decoder
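
For context, a minimal sketch (not part of the patch) of the behaviour the new `weights` fallback in `_get_word_embedding_weight` relies on: every Keras layer exposes a `weights` list, which stays empty until the layer is built, so indexing `weights[0]` is only safe behind the length check the patch adds. The `Dense` layer below is illustrative, assuming tf.keras.

```python
import tensorflow as tf
from tensorflow import keras

layer = keras.layers.Dense(10, use_bias=False)

# Before build() the layer owns no variables, so `weights` is an empty list
# and an unguarded weights[0] would raise an IndexError
embeds = getattr(layer, "weights", None)
assert embeds is not None and len(embeds) == 0

# After building, weights[0] is the kernel variable the fallback returns
layer.build((None, 4))
embeds = getattr(layer, "weights", None)
if embeds is not None and len(embeds) > 0:
    print(embeds[0].shape)  # (4, 10), i.e. (last_dim, units)
```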
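
Similarly, a sketch of the shape convention the new `set_output_embeddings` depends on: a `keras.layers.Dense` kernel has shape `(last_dim, units)`, i.e. `(dim, num_tokens)`, while the embedding matrix passed in is stored as `(num_tokens, dim)`, hence the transpose before it is assigned as the kernel. Sizes here are illustrative, and the direct kernel reassignment assumes tf.keras (Keras 2.x), where `kernel` is a plain attribute, as in the patch itself.

```python
import tensorflow as tf
from tensorflow import keras

num_tokens, dim = 12, 4  # illustrative sizes

# An output-embedding matrix as usually stored: (num_tokens, dim)
value = tf.random.normal((num_tokens, dim))

# Dense(units=num_tokens) builds its kernel as (last_dim, units) = (dim, num_tokens),
# so the embedding matrix has to be transposed before it can serve as the kernel
lm_head = keras.layers.Dense(num_tokens, use_bias=False, name="lm_head")
lm_head.build((None, dim))  # materialize the kernel, as a first call would
lm_head.kernel = tf.Variable(tf.transpose(value))  # mirrors the patch

# Hidden states of shape (batch, dim) now project to logits over num_tokens
hidden = tf.random.normal((2, dim))
print(lm_head(hidden).shape)  # (2, 12)
```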