
Commit 0afd759

Move code to python

1 parent 19d1ebd · commit 0afd759

File tree: 5 files changed, +92 −98 lines


docs/examples/data/torchvision/README.rst

Lines changed: 44 additions & 47 deletions
@@ -69,7 +69,7 @@ repository.
 +# Copy the dataset to $SLURM_TMPDIR so it is close to the GPUs for
 +# faster training
 +srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
-+    time -p bash data.sh "/network/datasets/inat" ${_DATA_PREP_WORKERS}
++    time -p python3 data.py "/network/datasets/inat" ${_DATA_PREP_WORKERS}


 # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
@@ -293,67 +293,64 @@ repository.
        main()


-**data.sh**
+**data.py**

-.. code:: bash
+.. code:: python

-   #!/bin/bash
-   set -o errexit
+   """Make sure the data is available"""
+   import os
+   import shutil
+   import sys
+   import time
+   from multiprocessing import Pool
+   from pathlib import Path

-   function ln_files {
-       # Clone the dataset structure of `src` to `dest` with symlinks and using
-       # `workers` number of workers (defaults to 4)
-       local src=$1
-       local dest=$2
-       local workers=${3:-4}
+   from torchvision.datasets import INaturalist

-       (cd "${src}" && find -L * -type f) | while read f
-       do
-           mkdir --parents "${dest}/$(dirname "$f")"
-           # echo source first so it is matched to the ln's '-T' argument
-           readlink --canonicalize "${src}/$f"
-           # echo output last so ln understands it's the output file
-           echo "${dest}/$f"
-       done | xargs -n2 -P${workers} ln --symbolic --force -T
-   }

-   _SRC=$1
-   _WORKERS=$2
-   # Referencing $SLURM_TMPDIR here instead of job.sh makes sure that the
-   # environment variable will only be resolved on the worker node (i.e. not
-   # referencing the $SLURM_TMPDIR of the master node)
-   _DEST=$SLURM_TMPDIR/data
+   def link_file(src: str, dest: str):
+       Path(dest).symlink_to(src)

-   ln_files "${_SRC}" "${_DEST}" ${_WORKERS}

-   # Reorganise the files if needed
-   (
-       cd "${_DEST}"
-       # Torchvision expects these names
-       mv train.tar.gz 2021_train.tgz
-       mv val.tar.gz 2021_valid.tgz
-   )
+   def link_files(src: str, dest: str, workers=4):
+       src = Path(src)
+       dest = Path(dest)
+       os.makedirs(dest, exist_ok=True)
+       with Pool(processes=workers) as pool:
+           for path, dnames, fnames in os.walk(str(src)):
+               rel_path = Path(path).relative_to(src)
+               fnames = map(lambda _f: rel_path / _f, fnames)
+               dnames = map(lambda _d: rel_path / _d, dnames)
+               for d in dnames:
+                   os.makedirs(str(dest / d), exist_ok=True)
+               pool.starmap(
+                   link_file,
+                   [(src / _f, dest / _f) for _f in fnames]
+               )

-   # Extract and prepare the data
-   python3 data.py "${_DEST}"

+   if __name__ == "__main__":
+       src = Path(sys.argv[1])
+       workers = int(sys.argv[2])
+       # Referencing $SLURM_TMPDIR here instead of job.sh makes sure that the
+       # environment variable will only be resolved on the worker node (i.e. not
+       # referencing the $SLURM_TMPDIR of the master node)
+       dest = Path(os.environ["SLURM_TMPDIR"]) / "data"

-**data.py**
+       start_time = time.time()

-.. code:: python
+       link_files(src, dest, workers)

-   """Make sure the data is available"""
-   import sys
-   import time
+       # Torchvision expects these names
+       shutil.move(dest / "train.tar.gz", dest / "2021_train.tgz")
+       shutil.move(dest / "val.tar.gz", dest / "2021_valid.tgz")

-   from torchvision.datasets import INaturalist
+       INaturalist(root=dest, version="2021_train", download=True)
+       INaturalist(root=dest, version="2021_valid", download=True)

+       seconds_spent = time.time() - start_time

-   start_time = time.time()
-   INaturalist(root=sys.argv[1], version="2021_train", download=True)
-   INaturalist(root=sys.argv[1], version="2021_valid", download=True)
-   seconds_spent = time.time() - start_time
-   print(f"Prepared data in {seconds_spent/60:.2f}m")
+       print(f"Prepared data in {seconds_spent/60:.2f}m")


 **Running this example**
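
Note: the comment carried over into data.py about $SLURM_TMPDIR is the load-bearing detail of this move. The destination path must be resolved inside the srun task on each worker node, not interpolated in job.sh on the master node, or every node would write into the master's scratch path. A minimal sketch of the pattern (the names here are illustrative, not part of the commit):

    import os
    from pathlib import Path

    # Resolved in the task process, so each node sees its *own* node-local
    # scratch directory rather than the master node's path.
    dest = Path(os.environ["SLURM_TMPDIR"]) / "data"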

docs/examples/data/torchvision/_index.rst

Lines changed: 0 additions & 6 deletions
@@ -27,12 +27,6 @@ repository.
    :language: diff


-**data.sh**
-
-.. literalinclude:: examples/data/torchvision/data.sh
-   :language: bash
-
-
 **data.py**

 .. literalinclude:: examples/data/torchvision/data.py
docs/examples/data/torchvision/data.py

Lines changed: 47 additions & 5 deletions
@@ -1,12 +1,54 @@
 """Make sure the data is available"""
+import os
+import shutil
 import sys
 import time
+from multiprocessing import Pool
+from pathlib import Path

 from torchvision.datasets import INaturalist


-start_time = time.time()
-INaturalist(root=sys.argv[1], version="2021_train", download=True)
-INaturalist(root=sys.argv[1], version="2021_valid", download=True)
-seconds_spent = time.time() - start_time
-print(f"Prepared data in {seconds_spent/60:.2f}m")
+def link_file(src: str, dest: str):
+    Path(dest).symlink_to(src)
+
+
+def link_files(src: str, dest: str, workers=4):
+    src = Path(src)
+    dest = Path(dest)
+    os.makedirs(dest, exist_ok=True)
+    with Pool(processes=workers) as pool:
+        for path, dnames, fnames in os.walk(str(src)):
+            rel_path = Path(path).relative_to(src)
+            fnames = map(lambda _f: rel_path / _f, fnames)
+            dnames = map(lambda _d: rel_path / _d, dnames)
+            for d in dnames:
+                os.makedirs(str(dest / d), exist_ok=True)
+            pool.starmap(
+                link_file,
+                [(src / _f, dest / _f) for _f in fnames]
+            )
+
+
+if __name__ == "__main__":
+    src = Path(sys.argv[1])
+    workers = int(sys.argv[2])
+    # Referencing $SLURM_TMPDIR here instead of job.sh makes sure that the
+    # environment variable will only be resolved on the worker node (i.e. not
+    # referencing the $SLURM_TMPDIR of the master node)
+    dest = Path(os.environ["SLURM_TMPDIR"]) / "data"
+
+    start_time = time.time()
+
+    link_files(src, dest, workers)
+
+    # Torchvision expects these names
+    shutil.move(dest / "train.tar.gz", dest / "2021_train.tgz")
+    shutil.move(dest / "val.tar.gz", dest / "2021_valid.tgz")
+
+    INaturalist(root=dest, version="2021_train", download=True)
+    INaturalist(root=dest, version="2021_valid", download=True)
+
+    seconds_spent = time.time() - start_time
+
+    print(f"Prepared data in {seconds_spent/60:.2f}m")

docs/examples/data/torchvision/data.sh

Lines changed: 0 additions & 39 deletions
This file was deleted.

docs/examples/data/torchvision/job.sh

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ mkdir -p "$SLURM_TMPDIR/data"
 # Copy the dataset to $SLURM_TMPDIR so it is close to the GPUs for
 # faster training
 srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
-    time -p bash data.sh "/network/datasets/inat" ${_DATA_PREP_WORKERS}
+    time -p python3 data.py "/network/datasets/inat" ${_DATA_PREP_WORKERS}


 # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
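
Note: data.py consumes the same two positional arguments the shell script did: argv[1] is the read-only dataset root and argv[2] the worker count. To smoke-test the invocation outside of Slurm, one could fake the node-local scratch variable; the scratch path below is a placeholder, not from the commit:

    import os
    import subprocess

    # Simulate the node-local scratch directory Slurm would provide.
    env = {**os.environ, "SLURM_TMPDIR": "/tmp/scratch"}
    os.makedirs("/tmp/scratch", exist_ok=True)
    subprocess.run(
        ["python3", "data.py", "/network/datasets/inat", "4"],
        env=env,
        check=True,
    )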
