Skip to content

Commit aa7edca

Browse files
author
Damien Sileo
committed
fixed issue with local_rank on colab
1 parent a6c7489 commit aa7edca

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

src/tasknet/models.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -452,12 +452,8 @@ def get_single_train_dataloader(self, task_name, train_dataset):
452452
"""
453453
if self.train_dataset is None:
454454
raise ValueError("Trainer: training requires a train_dataset.")
455-
train_sampler = (
456-
RandomSampler(train_dataset)
457-
if self.args.local_rank == -1
458-
else DistributedSampler(train_dataset)
459-
)
460-
455+
train_sampler = (RandomSampler(train_dataset) if torch.cuda.device_count()<2 or self.args.local_rank == -1 else DistributedSampler(train_dataset))
456+
461457
data_loader = DataLoaderWithTaskname(
462458
task_name=task_name,
463459
data_loader=DataLoader(

0 commit comments

Comments
 (0)