Commit 8df8015

Fix XGLM loss computation (PyTorch and TensorFlow)
1 parent 72d1a4c commit 8df8015

File tree: 4 files changed, +49 -11 lines

src/transformers/models/xglm/modeling_tf_xglm.py

Lines changed: 3 additions & 5 deletions
@@ -968,11 +968,9 @@ def call(
         loss = None
         if labels is not None:
             # shift labels to the left and cut last logit token
-            labels = tf.concat(
-                [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
-                axis=-1,
-            )
-            loss = self.hf_compute_loss(labels, lm_logits)
+            shifted_logits = lm_logits[:, :-1]
+            labels = labels[:, 1:]
+            loss = self.hf_compute_loss(labels, shifted_logits)
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
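
The old TensorFlow path shifted the labels left and appended pad_token_id as the final target; since only targets of -100 are ignored by the loss helper, that extra position was scored against the full-length logits and fed into the loss. The fix instead trims the last logit and shifts the labels, so every remaining logit is scored against the token that follows it. Below is a minimal sketch of the -100 masking behaviour this relies on, assuming the usual ignore-index convention rather than the exact hf_compute_loss implementation:

import tensorflow as tf

def masked_sparse_ce(labels, logits):
    # Cross-entropy that skips every position whose label is -100.
    mask = tf.cast(labels != -100, logits.dtype)
    safe_labels = tf.where(labels == -100, tf.zeros_like(labels), labels)
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=safe_labels, logits=logits)
    return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)

# Toy shapes: batch of 2, sequence length 8, vocabulary of 32, last 2 positions padded.
lm_logits = tf.random.normal((2, 8, 32))
labels = tf.concat([tf.random.uniform((2, 6), maxval=32, dtype=tf.int32), tf.fill((2, 2), -100)], axis=-1)
# As in the fixed code: drop the last logit and shift the labels left before the loss.
loss = masked_sparse_ce(labels[:, 1:], lm_logits[:, :-1])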

src/transformers/models/xglm/modeling_xglm.py

Lines changed: 7 additions & 6 deletions
@@ -778,13 +778,14 @@ def forward(
 
         loss = None
         if labels is not None:
-            # shift labels and add a pad token to the end
-            shift_labels = labels.new_zeros(labels.shape)
-            shift_labels[:, :-1] = labels[:, 1:].clone()
-            shift_labels[:, -1] = self.config.pad_token_id
-
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[1:]
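
The PyTorch change drops the pad-filled shift_labels buffer in favour of the standard causal-LM shift: the logit at each position < n is scored against token n, and the final logit (which has no target) is discarded, so no artificial pad_token_id target enters the loss. A self-contained sketch of that pattern with toy shapes, not the XGLM forward itself:

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes: batch of 2, sequence length 8, vocabulary of 32.
logits = torch.randn(2, 8, 32)
labels = torch.randint(0, 32, (2, 8))
labels[:, 6:] = -100  # e.g. right-padded positions masked out

# Shift so that tokens < n predict n; the last logit has no target and is dropped.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))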

tests/models/xglm/test_modeling_tf_xglm.py

Lines changed: 19 additions & 0 deletions
@@ -238,3 +238,22 @@ def test_batch_generation(self):
         ]
         self.assertListEqual(expected_output_sentence, batch_out_sentence)
         self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
+    def test_loss_with_padding(self):
+        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
+        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
+
+        tokenizer.padding_side = "right"
+
+        sequence = "Sequence"
+
+        tokenized_non_padded = tokenizer(sequence, return_tensors="tf")
+        labels_non_padded = tokenized_non_padded.input_ids
+        loss_non_padded = model(tokenized_non_padded, labels=labels_non_padded).loss
+
+        tokenized_padded = tokenizer(sequence, padding="max_length", max_length=16, return_tensors="tf")
+        labels_padded = tokenized_padded.input_ids
+        labels_padded = tf.where(labels_padded == tokenizer.pad_token_id, -100, labels_padded)
+        loss_padded = model(tokenized_padded, labels=labels_padded).loss
+
+        tf.debugging.assert_near(loss_non_padded, loss_padded, atol=1e-3)

tests/models/xglm/test_modeling_xglm.py

Lines changed: 20 additions & 0 deletions
@@ -494,3 +494,23 @@ def test_batched_nan_fp16(self):
         self.assertFalse(
             torch.isnan(outputs.logits[0]).any().item()
         ) # the first logits could contain NaNs if it fails
+
+    def test_loss_with_padding(self):
+        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
+        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
+        model.to(torch_device)
+
+        tokenizer.padding_side = "right"
+
+        sequence = "Sequence"
+
+        tokenized_non_padded = tokenizer(sequence, return_tensors="pt")
+        labels_non_padded = tokenized_non_padded.input_ids.clone()
+        loss_non_padded = model(**tokenized_non_padded, labels=labels_non_padded).loss
+
+        tokenized_padded = tokenizer(sequence, padding="max_length", max_length=16, return_tensors="pt")
+        labels_padded = tokenized_padded.input_ids.clone()
+        labels_padded[labels_padded == tokenizer.pad_token_id] = -100
+        loss_padded = model(**tokenized_padded, labels=labels_padded).loss
+
+        self.assertTrue(torch.allclose(loss_non_padded, loss_padded, atol=1e-3))
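
Both new tests check the same invariant: with right padding and the pad positions in the labels replaced by -100, the loss should match the unpadded loss to within atol=1e-3. Under the previous implementations the shifted-in pad_token_id target was not ignored, so padding could change the reported loss; with the shift-and-trim computation above, the padded positions drop out entirely.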
