Commit 180e037

Fix XGLM loss computation (PyTorch and TensorFlow)
1 parent b912f5e commit 180e037

When shifting labels left for next-token prediction, both implementations previously filled the vacated last position with pad_token_id, so that position was scored by the loss. Filling it with -100 instead lets PyTorch's CrossEntropyLoss (whose ignore_index defaults to -100) and the TF loss helper hf_compute_loss skip it. The new tests check that a right-padded sequence yields the same loss as the unpadded one once pad labels are masked with -100.

4 files changed: 42 additions, 3 deletions

src/transformers/models/xglm/modeling_tf_xglm.py

Lines changed: 1 addition & 1 deletion
@@ -969,7 +969,7 @@ def call(
         if labels is not None:
             # shift labels to the left and cut last logit token
             labels = tf.concat(
-                [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
+                [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(-100, labels.dtype))],
                 axis=-1,
             )
             loss = self.hf_compute_loss(labels, lm_logits)
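For context, the shared TF loss helper masks out positions labelled -100 before reducing, which is why the sentinel works here. Below is a minimal sketch of that masking, assuming logits of shape (batch, seq_len, vocab); masked_loss_sketch is a hypothetical name, not the actual hf_compute_loss implementation:

import tensorflow as tf

def masked_loss_sketch(labels, logits):
    # Per-token cross-entropy with no reduction yet.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    flat_labels = tf.reshape(labels, [-1])
    flat_logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
    # Drop every position whose label is the -100 sentinel.
    active = tf.not_equal(flat_labels, -100)
    return loss_fn(
        tf.boolean_mask(flat_labels, active),
        tf.boolean_mask(flat_logits, active),
    )

Because masked positions are removed before the loss is computed, filling the trailing shifted label with -100 leaves the loss identical to what an unpadded sequence would produce.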

src/transformers/models/xglm/modeling_xglm.py

Lines changed: 2 additions & 2 deletions
@@ -778,10 +778,10 @@ def forward(

         loss = None
         if labels is not None:
-            # shift labels and add a pad token to the end
+            # shift labels to the left and cut last logit token
             shift_labels = labels.new_zeros(labels.shape)
             shift_labels[:, :-1] = labels[:, 1:].clone()
-            shift_labels[:, -1] = self.config.pad_token_id
+            shift_labels[:, -1] = -100

             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
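This works because CrossEntropyLoss ignores targets equal to its ignore_index, which defaults to -100, so the trailing shifted position drops out of the average instead of being scored against pad_token_id. A small self-contained check (shapes and values are illustrative only):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(4, 10)                # 4 token positions, vocab size 10
labels = torch.tensor([3, 7, -100, -100])  # last two positions are masked

loss_fct = CrossEntropyLoss()              # ignore_index defaults to -100
loss = loss_fct(logits, labels)

# The mean is taken over unmasked positions only:
assert torch.allclose(loss, CrossEntropyLoss()(logits[:2], labels[:2]))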

tests/models/xglm/test_modeling_tf_xglm.py

Lines changed: 19 additions & 0 deletions
@@ -238,3 +238,22 @@ def test_batch_generation(self):
         ]
         self.assertListEqual(expected_output_sentence, batch_out_sentence)
         self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
+    def test_loss_with_padding(self):
+        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
+        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
+
+        tokenizer.padding_side = "right"
+
+        sequence = "Sequence"
+
+        tokenized_non_padded = tokenizer(sequence, return_tensors="tf")
+        labels_non_padded = tokenized_non_padded.input_ids
+        loss_non_padded = model(tokenized_non_padded, labels=labels_non_padded).loss
+
+        tokenized_padded = tokenizer(sequence, padding="max_length", max_length=16, return_tensors="tf")
+        labels_padded = tokenized_padded.input_ids
+        labels_padded = tf.where(labels_padded == tokenizer.pad_token_id, -100, labels_padded)
+        loss_padded = model(tokenized_padded, labels=labels_padded).loss
+
+        tf.debugging.assert_near(loss_non_padded, loss_padded, atol=1e-3)

tests/models/xglm/test_modeling_xglm.py

Lines changed: 20 additions & 0 deletions
@@ -494,3 +494,23 @@ def test_batched_nan_fp16(self):
         self.assertFalse(
             torch.isnan(outputs.logits[0]).any().item()
         )  # the first logits could contain NaNs if it fails
+
+    def test_loss_with_padding(self):
+        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
+        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
+        model.to(torch_device)
+
+        tokenizer.padding_side = "right"
+
+        sequence = "Sequence"
+
+        tokenized_non_padded = tokenizer(sequence, return_tensors="pt")
+        labels_non_padded = tokenized_non_padded.input_ids.clone()
+        loss_non_padded = model(**tokenized_non_padded, labels=labels_non_padded).loss
+
+        tokenized_padded = tokenizer(sequence, padding="max_length", max_length=16, return_tensors="pt")
+        labels_padded = tokenized_padded.input_ids.clone()
+        labels_padded[labels_padded == tokenizer.pad_token_id] = -100
+        loss_padded = model(**tokenized_padded, labels=labels_padded).loss
+
+        torch.testing.assert_close(loss_non_padded, loss_padded, rtol=1e-3, atol=1e-3)
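Outside the test suite, the same pattern applies when training with padded batches: mask pad-token labels with -100 before passing them to the model. A minimal sketch following the PyTorch test above (CPU only, no device handling; the input strings are arbitrary):

import torch
from transformers import XGLMForCausalLM, XGLMTokenizer

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

batch = tokenizer(
    ["Sequence one", "A longer second sequence"],
    padding=True,
    return_tensors="pt",
)
labels = batch.input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100  # padding must not be scored
loss = model(**batch, labels=labels).loss        # scalar mean over real tokens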
