Commit 7aba408

Fix handling of output embeddings in TFGPTJ
This resolves an issue with output embeddings in `TFGPTJModelTest.test_save_load_after_resize_token_embeddings`, where resizing the token embeddings caused the following error:

```
ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
```
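For context, here is a minimal sketch of the kind of scenario the failing test exercises: a TF GPT-J model with tiny, made-up config values has its token embeddings resized, then is saved and reloaded. The config values, the new vocabulary size, and the save path are illustrative assumptions, not the test's actual fixture.

```python
import tensorflow as tf
from transformers import GPTJConfig, TFGPTJForCausalLM

# Tiny, made-up config so the example runs quickly (not the real GPT-J sizes).
config = GPTJConfig(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=4, rotary_dim=4)
model = TFGPTJForCausalLM(config)
model(tf.constant([[1, 2, 3]]))  # dummy forward pass to build the weights

# Grow the vocabulary; the resize path ends up calling set_output_embeddings on the lm_head,
# which is where the ValueError above used to surface.
model.resize_token_embeddings(120)

# Save and reload the resized model, mirroring what the test checks.
model.save_pretrained("/tmp/tfgptj-resized")
reloaded = TFGPTJForCausalLM.from_pretrained("/tmp/tfgptj-resized")
```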
1 parent 8efa519 commit 7aba408

1 file changed

src/transformers/models/gptj/modeling_tf_gptj.py

Lines changed: 8 additions & 2 deletions
@@ -756,8 +756,14 @@ def __init__(self, config, *inputs, **kwargs):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
+    def set_output_embeddings(self, value):
+        self.lm_head = keras.layers.Dense(
+            shape_list(value)[0], kernel_initializer=get_initializer(self.config.initializer_range), name="lm_head"
+        )
+        # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
+        # value has a shape (num_tokens, dim) then needs to be transposed
+        transposed_value = tf.transpose(value)
+        self.lm_head.kernel = tf.Variable(transposed_value)
 
     def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
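
To make the shape reasoning in the new inline comments concrete, here is a standalone sketch with made-up sizes (plain TensorFlow/Keras, not the model code): an output-embedding matrix of shape `(num_tokens, dim)` is copied into a `Dense` layer whose kernel is stored as `(dim, num_tokens)`, hence the transpose.

```python
import tensorflow as tf
from tensorflow import keras

dim, num_tokens = 4, 10  # illustrative sizes, not GPT-J's real dimensions

# An output-embedding matrix as passed to set_output_embeddings: (num_tokens, dim).
value = tf.random.normal((num_tokens, dim))

# A Dense layer mapping hidden states of width `dim` to `num_tokens` logits
# stores its kernel as (dim, num_tokens), so the matrix must be transposed.
lm_head = keras.layers.Dense(num_tokens)
lm_head.build((None, dim))
lm_head.kernel.assign(tf.transpose(value))  # (num_tokens, dim) -> (dim, num_tokens)

hidden_states = tf.random.normal((2, dim))
logits = lm_head(hidden_states)
print(logits.shape)  # (2, 10): one logit per token in the (resized) vocabulary
```

The patch itself replaces the old direct assignment of `new_embeddings` to `self.lm_head` with building a fresh `Dense` layer and setting its kernel from the transposed matrix; the sketch above only illustrates the shape relationship, using `assign` on an already-built kernel rather than replacing the variable.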
