Skip to content

Commit 0619223

Browse files
committed
correctly send all batch features to DataCollatorForSpanClassification device
1 parent 9b05fcf commit 0619223

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tibert/bertcoref.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,21 +416,22 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
416416
document.tokens = tokens
417417
labels = [doc.document_labels(self.max_span_size) for doc in documents]
418418

419+
device = torch.device(self.device)
419420
del batch["coref_labels"]
420421
del batch["mention_labels"]
421422
batch = BatchEncoding(
422423
{
423-
k: torch.tensor(v, dtype=torch.int64, device=torch.device(self.device))
424+
k: torch.tensor(v, dtype=torch.int64, device=device)
424425
for k, v in batch.items()
425426
},
426427
encoding=batch.encodings,
427428
)
428429
batch["coref_labels"] = torch.stack(
429430
[coref_labels for coref_labels, _ in labels]
430-
)
431+
).to(device)
431432
batch["mention_labels"] = torch.stack(
432433
[mention_labels for _, mention_labels in labels]
433-
)
434+
).to(device)
434435

435436
return batch
436437

0 commit comments

Comments
 (0)