@@ -69,7 +69,7 @@ repository.
69
69
+# Copy the dataset to $SLURM_TMPDIR so it is close to the GPUs for
70
70
+# faster training
71
71
+srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
72
- + time -p bash data.sh "/network/datasets/inat" ${_DATA_PREP_WORKERS}
72
+ + time -p python3 data.py "/network/datasets/inat" ${_DATA_PREP_WORKERS}
73
73
74
74
75
75
# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
@@ -293,67 +293,64 @@ repository.
293
293
main()
294
294
295
295
296
- **data.sh **
296
+ **data.py **
297
297
298
- .. code :: bash
298
+ .. code :: python
299
299
300
- #! /bin/bash
301
- set -o errexit
300
+ """ Make sure the data is available"""
301
+ import os
302
+ import shutil
303
+ import sys
304
+ import time
305
+ from multiprocessing import Pool
306
+ from pathlib import Path
302
307
303
- function ln_files {
304
- # Clone the dataset structure of `src` to `dest` with symlinks and using
305
- # `workers` numbre of workers (defaults to 4)
306
- local src=$1
307
- local dest=$2
308
- local workers=${3:- 4}
308
+ from torchvision.datasets import INaturalist
309
309
310
- (cd " ${src} " && find -L * -type f) | while read f
311
- do
312
- mkdir --parents " ${dest} /$( dirname " $f " ) "
313
- # echo source first so it is matched to the ln's '-T' argument
314
- readlink --canonicalize " ${src} /$f "
315
- # echo output last so ln understands it's the output file
316
- echo " ${dest} /$f "
317
- done | xargs -n2 -P${workers} ln --symbolic --force -T
318
- }
319
310
320
- _SRC=$1
321
- _WORKERS=$2
322
- # Referencing $SLURM_TMPDIR here instead of job.sh makes sure that the
323
- # environment variable will only be resolved on the worker node (i.e. not
324
- # referencing the $SLURM_TMPDIR of the master node)
325
- _DEST=$SLURM_TMPDIR /data
311
+ def link_file (src :str , dest :str ):
312
+ Path(dest).symlink_to(src)
326
313
327
- ln_files " ${_SRC} " " ${_DEST} " ${_WORKERS}
328
314
329
- # Reorganise the files if needed
330
- (
331
- cd " ${_DEST} "
332
- # Torchvision expects these names
333
- mv train.tar.gz 2021_train.tgz
334
- mv val.tar.gz 2021_valid.tgz
335
- )
315
+ def link_files (src :str , dest :str , workers = 4 ):
316
+ src = Path(src)
317
+ dest = Path(dest)
318
+ os.makedirs(dest, exist_ok = True )
319
+ with Pool(processes = workers) as pool:
320
+ for path, dnames, fnames in os.walk(str (src)):
321
+ rel_path = Path(path).relative_to(src)
322
+ fnames = map (lambda _f : rel_path / _f, fnames)
323
+ dnames = map (lambda _d : rel_path / _d, dnames)
324
+ for d in dnames:
325
+ os.makedirs(str (dest / d), exist_ok = True )
326
+ pool.starmap(
327
+ link_file,
328
+ [(src / _f, dest / _f) for _f in fnames]
329
+ )
336
330
337
- # Extract and prepare the data
338
- python3 data.py " ${_DEST} "
339
331
332
+ if __name__ == " __main__" :
333
+ src = Path(sys.argv[1 ])
334
+ workers = int (sys.argv[2 ])
335
+ # Referencing $SLURM_TMPDIR here instead of job.sh makes sure that the
336
+ # environment variable will only be resolved on the worker node (i.e. not
337
+ # referencing the $SLURM_TMPDIR of the master node)
338
+ dest = Path(os.environ[" SLURM_TMPDIR" ]) / " dest"
340
339
341
- ** data.py **
340
+ start_time = time.time()
342
341
343
- .. code :: python
342
+ link_files(src, dest, workers)
344
343
345
- """ Make sure the data is available """
346
- import sys
347
- import time
344
+ # Torchvision expects these names
345
+ shutil.move(dest / " train.tar.gz " , dest / " 2021_train.tgz " )
346
+ shutil.move(dest / " val.tar.gz " , dest / " 2021_valid.tgz " )
348
347
349
- from torchvision.datasets import INaturalist
348
+ INaturalist(root = dest, version = " 2021_train" , download = True )
349
+ INaturalist(root = dest, version = " 2021_valid" , download = True )
350
350
351
+ seconds_spent = time.time() - start_time
351
352
352
- start_time = time.time()
353
- INaturalist(root = sys.argv[1 ], version = " 2021_train" , download = True )
354
- INaturalist(root = sys.argv[1 ], version = " 2021_valid" , download = True )
355
- seconds_spent = time.time() - start_time
356
- print (f " Prepared data in { seconds_spent/ 60 :.2f } m " )
353
+ print (f " Prepared data in { seconds_spent/ 60 :.2f } m " )
357
354
358
355
359
356
**Running this example **
0 commit comments