Merge pull request deepchem#3704 from arunppsg/td-dataset
Added distributed loading of PyTorch DiskDataset
rbharath authored Dec 19, 2023
2 parents a8aee64 + bb454b8 commit 8130f62
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions deepchem/data/pytorch_datasets.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset
 from typing import Optional
@@ -87,16 +88,25 @@ def __init__(self,
         self.deterministic = deterministic
         self.batch_size = batch_size
 
+    def __len__(self):
+        return len(self.disk_dataset)
+
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
         n_shards = self.disk_dataset.get_number_shards()
         if worker_info is None:
-            first_shard = 0
-            last_shard = n_shards
+            process_id = 0
+            num_processes = 1
         else:
-            first_shard = worker_info.id * n_shards // worker_info.num_workers
-            last_shard = (worker_info.id +
-                          1) * n_shards // worker_info.num_workers
+            process_id = worker_info.id
+            num_processes = worker_info.num_workers
+
+        if dist.is_initialized():
+            process_id += dist.get_rank() * num_processes
+            num_processes *= dist.get_world_size()
+
+        first_shard = (process_id * n_shards) // num_processes
+        last_shard = ((process_id + 1) * n_shards) // num_processes
         if first_shard == last_shard:
             return
 
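The new __iter__ assigns each reader a contiguous, non-overlapping slice of shards: the DataLoader worker id is offset by the torch.distributed rank, so every (rank, worker) pair maps to a unique global process_id, and integer division splits the shard range evenly. A small standalone sketch of the same arithmetic (the helper name and the example numbers are illustrative, not part of the commit):

def shard_range(process_id, num_processes, n_shards):
    # Same integer arithmetic as the patched __iter__: reader p takes
    # shards [p * n // P, (p + 1) * n // P), which tiles 0..n exactly once.
    first_shard = (process_id * n_shards) // num_processes
    last_shard = ((process_id + 1) * n_shards) // num_processes
    return first_shard, last_shard

# 10 shards across 2 DDP ranks x 2 DataLoader workers = 4 global readers:
for pid in range(4):
    print(pid, shard_range(pid, 4, 10))
# -> 0 (0, 2)   1 (2, 5)   2 (5, 7)   3 (7, 10)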

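A minimal usage sketch of the distributed path, assuming a DiskDataset saved at 'my_dataset/' and a process group launched with torchrun (the path, backend choice, and worker counts are placeholders). make_pytorch_dataset is the DeepChem entry point that wraps a DiskDataset in the iterable dataset patched above:

import torch.distributed as dist
from torch.utils.data import DataLoader
import deepchem as dc

# torchrun supplies the env vars init_process_group needs (rank, world size).
dist.init_process_group(backend="gloo")

disk_dataset = dc.data.DiskDataset('my_dataset/')  # placeholder path
torch_dataset = disk_dataset.make_pytorch_dataset(batch_size=32)

# batch_size=None because the iterable dataset already yields batches;
# each (rank, worker) pair reads its own disjoint slice of shards.
loader = DataLoader(torch_dataset, batch_size=None, num_workers=2)
for X, y, w, ids in loader:
    pass  # this rank trains on its share of the data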