Skip to content

Commit

Permalink
Improve PyTorch Performance with Multiple Workers and Shuffling (#207)
Browse files Browse the repository at this point in the history
* Removed row shuffling; kept only batch shuffling.

* Added the persistent_workers attribute to the PyTorch reader unit tests.

* Updated the PyTorch dense reader example.
  • Loading branch information
georgeSkoumas authored Apr 18, 2023
1 parent 03a8602 commit 858177a
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 34 deletions.
166 changes: 141 additions & 25 deletions examples/readers/pytorch_data_api_tiledb_dense.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,110 @@
"name": "#%%\n"
}
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0.3%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
"Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"19.9%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"data_home = os.path.join(os.path.pardir, \"data\")\n",
"data = torchvision.datasets.MNIST(root=data_home, train=False, download=True)"
Expand Down Expand Up @@ -143,7 +246,16 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:448: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
" warnings.warn(\n"
]
}
],
"source": [
"data_dir = os.path.join(data_home, 'readers', 'pytorch', 'dense')\n",
"os.makedirs(data_dir, exist_ok=True)\n",
Expand Down Expand Up @@ -212,14 +324,6 @@
")\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/george/PycharmProjects/TileDB-ML/.venv/lib/python3.9/site-packages/tiledb/ctx.py:410: UserWarning: tiledb.default_ctx and scope_ctx will not function correctly due to bug in IPython contextvar support. You must supply a Ctx object to each function for custom configuration options. Please consider upgrading to ipykernel >= 6!Please see https://github.com/TileDB-Inc/TileDB-Py/issues/667 for more information.\n",
" warnings.warn(\n"
]
}
],
"source": [
Expand Down Expand Up @@ -254,7 +358,7 @@
"outputs": [
{
"data": {
"text/plain": "<matplotlib.image.AxesImage at 0x12de28bb0>"
"text/plain": "<matplotlib.image.AxesImage at 0x123788ca0>"
},
"execution_count": 7,
"metadata": {},
Expand Down Expand Up @@ -305,20 +409,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 Batch: 0 Loss: 2.299262\n",
"Train Epoch: 1 Batch: 100 Loss: 2.262452\n",
"Train Epoch: 1 Batch: 200 Loss: 2.162849\n",
"Train Epoch: 1 Batch: 300 Loss: 1.927302\n",
"Train Epoch: 1 Batch: 400 Loss: 1.646087\n",
"Train Epoch: 2 Batch: 0 Loss: 1.446454\n",
"Train Epoch: 2 Batch: 100 Loss: 1.314963\n",
"Train Epoch: 2 Batch: 200 Loss: 1.376722\n",
"Train Epoch: 2 Batch: 300 Loss: 1.400400\n",
"Train Epoch: 2 Batch: 400 Loss: 1.291488\n"
"Train Epoch: 1 Batch: 0 Loss: 2.304748\n",
"Train Epoch: 1 Batch: 100 Loss: 2.277155\n",
"Train Epoch: 1 Batch: 200 Loss: 2.203359\n",
"Train Epoch: 1 Batch: 300 Loss: 1.895098\n",
"Train Epoch: 1 Batch: 400 Loss: 1.497304\n",
"Train Epoch: 2 Batch: 0 Loss: 1.435658\n",
"Train Epoch: 2 Batch: 100 Loss: 1.305221\n",
"Train Epoch: 2 Batch: 200 Loss: 0.990590\n",
"Train Epoch: 2 Batch: 300 Loss: 1.103210\n",
"Train Epoch: 2 Batch: 400 Loss: 0.903957\n"
]
}
],
"source": [
"import multiprocessing\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
Expand Down Expand Up @@ -349,11 +454,15 @@
" img = np.clip(img,0,1)\n",
" return img\n",
"\n",
"\n",
"ctx = tiledb.Ctx({'sm.memory_budget': 1024**2})\n",
"with tiledb.open(training_images, ctx=ctx) as x, tiledb.open(training_labels, ctx=ctx) as y:\n",
" # Because of this issue (https://github.com/pytorch/pytorch/issues/59451#issuecomment-854883855) we avoid using multiple workers on Jupyter.\n",
" train_loader = PyTorchTileDBDataLoader(\n",
" ArrayParams(x, fn=do_random_noise), ArrayParams(y), batch_size=128,\n",
" ArrayParams(x, fn=do_random_noise),\n",
" ArrayParams(y),\n",
" batch_size=128,\n",
" num_workers=0,\n",
" shuffle_buffer_size=256,\n",
" )\n",
"\n",
" net = Net(shape=(28, 28))\n",
Expand All @@ -374,11 +483,18 @@
" print('Train Epoch: {} Batch: {} Loss: {:.6f}'.format(\n",
" epoch, batch_idx, loss.item()))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -392,7 +508,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
"version": "3.9.9"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions tests/readers/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ def test_dataloader(
):
def test(*all_array_params):
try:
persistent_workers = num_workers > 0

dataloader = PyTorchTileDBDataLoader(
*all_array_params,
shuffle_buffer_size=shuffle_buffer_size,
batch_size=batch_size,
num_workers=num_workers,
persistent_workers=persistent_workers,
)

except NotImplementedError:
assert num_workers and (
torchdata.__version__ < "0.4"
Expand Down
9 changes: 0 additions & 9 deletions tiledb/ml/readers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,8 @@ def PyTorchTileDBDataLoader(

# shuffle the unbatched rows if shuffle_buffer_size > 0
if shuffle_buffer_size:
# load the rows to be shuffled
# don't batch them (batch_size=None) or collate them (collate_fn=_identity)
row_loader = DataLoader(
datapipe, num_workers=num_workers, batch_size=None, collate_fn=_identity
)
# create a new datapipe for these rows
datapipe = DeferredIterableIterDataPipe(iter, row_loader)
# shuffle the datapipe items
datapipe = datapipe.shuffle(buffer_size=shuffle_buffer_size)
# run the shuffling on this process, not on workers
kwargs["num_workers"] = 0

# construct an appropriate collate function
collator = Collator.from_schemas(*schemas)
Expand Down

0 comments on commit 858177a

Please sign in to comment.