list index out of range with wids.ShardListDataset #2

@ThisisBillhe

Description

@ThisisBillhe

I am training models with wids.ShardListDataset in a DDP environment:

def make_sample(sample):
    # print(sample)
    image = sample[".jpg"]
    label = sample[".json"]["prompt"]

    return transform(image), label

dataset = wids.ShardListDataset(trainset_url, keep=True)
dataset = dataset.add_transform(make_sample)
sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
loader = DataLoader(
    dataset,
    batch_size=local_batch_size,
    sampler=sampler,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)

But after several training iterations I get this error:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/mnt/petrelfs/heyefei/ZipAR-X/autoregressive/train/train_t2i_webdata.py", line 363, in <module>
[rank2]:     main(args)
[rank2]:   File "/mnt/petrelfs/heyefei/ZipAR-X/autoregressive/train/train_t2i_webdata.py", line 214, in main
[rank2]:     for x, caption in loader:
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
[rank2]:     data = self._next_data()
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
[rank2]:     return self._process_data(data)
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
[rank2]:     data.reraise()
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
[rank2]:     raise exception
[rank2]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank2]: Original Traceback (most recent call last):
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
[rank2]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
[rank2]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
[rank2]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/wids/wids.py", line 575, in __getitem__
[rank2]:     sample = shard[inner_idx]
[rank2]:   File "/mnt/petrelfs/heyefei/anaconda3/envs/llamagen/lib/python3.10/site-packages/wids/wids.py", line 265, in __getitem__
[rank2]:     indexes = self.samples[idx]
[rank2]: IndexError: list index out of range
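For context, the failing call `shard[inner_idx]` suggests that a global sample index was mapped into a shard at an inner offset that the shard does not actually contain. The sketch below is not wids's actual code, just a minimal illustration of the cumulative-offset indexing pattern such datasets use: if the index file declares more `nsamples` for a shard than the tar really holds, `inner_idx` can land past the end of the shard's sample list and raise exactly this IndexError.

```python
import bisect

def locate(global_idx, nsamples_per_shard):
    """Map a global sample index to (shard, inner_idx) using the
    per-shard sample counts declared in the index file.
    Hypothetical sketch, not wids's real implementation."""
    offsets = [0]
    for n in nsamples_per_shard:
        offsets.append(offsets[-1] + n)
    shard = bisect.bisect_right(offsets, global_idx) - 1
    inner_idx = global_idx - offsets[shard]
    return shard, inner_idx

# If the index claims 50000 samples but the shard really holds 49985,
# any global index in the overhang maps to an inner_idx >= 49985,
# which is out of range when the shard's sample list is accessed.
```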

The dataset is downloaded from Hugging Face, and trainset_url points to a JSON shard index like the following:

{
  "__kind__": "wids-shard-index-v1",
  "wids_version": 1,
  "shardlist": [
    {
      "url": "t2i_2M/data_000000.tar",
      "nsamples": 50000,
      "filesize": 7475189760
    },
    {
      "url": "t2i_2M/data_000001.tar",
      "nsamples": 50000,
      "filesize": 7457955840
    },
    {
      "url": "t2i_2M/data_000002.tar",
      "nsamples": 49985,
      "filesize": 11103385600
    },

Could this error be related to the different number of samples in each tar?
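One way to check that hypothesis is to compare the `nsamples` declared in the index against the actual sample count in each tar. The sketch below (my own helper, not part of wids; it assumes shards follow the WebDataset convention that files sharing a basename prefix before the first dot form one sample, and that the `url` fields are paths resolvable relative to `root`) reports any mismatch:

```python
import json
import os
import tarfile

def count_tar_samples(path):
    """Count samples in a WebDataset-style tar: members whose names share
    the same prefix before the first dot belong to one sample."""
    keys = set()
    with tarfile.open(path) as tar:
        for member in tar:
            if not member.isfile():
                continue
            base = os.path.basename(member.name)
            keys.add(base.split(".", 1)[0])
    return len(keys)

def check_index(index_path, root="."):
    """Compare declared nsamples in a wids shard index against each tar."""
    with open(index_path) as f:
        index = json.load(f)
    for shard in index["shardlist"]:
        path = os.path.join(root, shard["url"])
        actual = count_tar_samples(path)
        status = "OK" if actual == shard["nsamples"] else "MISMATCH"
        print(f'{shard["url"]}: index={shard["nsamples"]} actual={actual} [{status}]')
```

If any shard reports MISMATCH, regenerating the index from the tars themselves (rather than trusting the downloaded JSON) should make the counts consistent.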
