
Commit 0cf0f83

add pytorch lightning ddp elastic example (#1671)
* add pytorch lightning ddp elastic example
* copy updates, refactor module
* update default arg
* fix requirements, update example
* fix imagespec
* update formatting
* update deps
* remove custom image name
* update image spec
* fix formatting
* update deps
* update imagespec
* add back cuda

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
1 parent 374093a commit 0cf0f83

File tree

5 files changed: +238 -7 lines changed

.github/workflows/monodocs_build.yml

Lines changed: 1 addition & 1 deletion

@@ -51,4 +51,4 @@ jobs:
       FLYTESNACKS_LOCAL_PATH: ${{ github.workspace }}/flytesnacks
     run: |
       conda activate monodocs-env
-      make -C docs html SPHINXOPTS="-W -vvv"
+      make -C docs html SPHINXOPTS="-W"

examples/kfpytorch_plugin/README.md

Lines changed: 11 additions & 3 deletions

@@ -21,13 +21,21 @@ To enable the plugin in the backend, follow instructions outlined in the {std:re
 
 ## Run the example on the Flyte cluster
 
-To run the provided example on the Flyte cluster, use the following command:
+To run the provided examples on the Flyte cluster, use the following commands:
+
+Distributed PyTorch training:
+
+```
+pyflyte run --remote pytorch_mnist.py pytorch_training_wf
+```
+
+PyTorch Lightning training:
 
 ```
-pyflyte run --remote pytorch_mnist.py \
-  pytorch_training_wf
+pyflyte run --remote pytorch_lightning_mnist_autoencoder.py train_workflow
 ```
 
 ```{auto-examples-toc}
 pytorch_mnist
+pytorch_lightning_mnist_autoencoder
 ```
examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_lightning_mnist_autoencoder.py

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
# %% [markdown]
# # Use PyTorch Lightning to Train an MNIST Autoencoder
#
# This notebook demonstrates how to use PyTorch Lightning with Flyte's `Elastic`
# task config, which is exposed by the `flytekitplugins-kfpytorch` plugin.
#
# First, we import all of the relevant packages.

import os

import lightning as L
from flytekit import ImageSpec, PodTemplate, Resources, task, workflow
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch.task import Elastic
from kubernetes.client.models import (
    V1Container,
    V1EmptyDirVolumeSource,
    V1PodSpec,
    V1Volume,
    V1VolumeMount,
)
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# %% [markdown]
# ## Image and Pod Template Configuration
#
# For this task, we're going to use a custom image that has all of the
# necessary dependencies installed.

custom_image = ImageSpec(
    packages=[
        "adlfs==2024.4.1",
        "gcsfs==2024.3.1",
        "torch==2.2.1",
        "torchvision",
        "flytekitplugins-kfpytorch",
        "kubernetes",
        "lightning==2.2.4",
        "networkx==3.2.1",
        "s3fs==2024.3.1",
    ],
    cuda="12.1.0",
    python_version="3.10",
    registry="ghcr.io/flyteorg",
)

# %% [markdown]
# :::{important}
# Replace `ghcr.io/flyteorg` with a container registry you have access to
# publish to. To upload the image to the local registry in the demo cluster,
# indicate the registry as `localhost:30000`.
# :::
#
# :::{note}
# You can activate GPU support by either using the base image that includes
# the necessary GPU dependencies or by specifying the `cuda` parameter in
# the {py:class}`~flytekit.image_spec.ImageSpec`, for example:
#
# ```python
# custom_image = ImageSpec(
#     packages=[...],
#     cuda="12.1.0",
#     ...
# )
# ```
# :::

# %% [markdown]
# We're also going to define a custom pod template that mounts a shared memory
# volume to `/dev/shm`. This is necessary for distributed data parallel (DDP)
# training so that state can be shared across workers.

container = V1Container(name=custom_image.name, volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")])
volume = V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
custom_pod_template = PodTemplate(
    primary_container_name=custom_image.name,
    pod_spec=V1PodSpec(containers=[container], volumes=[volume]),
)

# %% [markdown]
# ## Define a `LightningModule`
#
# Then we create a PyTorch Lightning module, which defines an autoencoder that
# will learn how to create compressed embeddings of MNIST images.


class MNISTAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
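As an editorial aside (not part of the committed file): the encoder above maps a flattened 28×28 image to a 3-dimensional embedding, and the decoder maps it back. A minimal local smoke test of that round trip, assuming only `torch` is installed:

```python
import torch
from torch import nn

# Same architecture as in the example above.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

x = torch.randn(4, 28 * 28)  # a batch of 4 flattened "images"
z = encoder(x)               # compressed embeddings
x_hat = decoder(z)           # reconstructions
print(z.shape)               # torch.Size([4, 3])
print(x_hat.shape)           # torch.Size([4, 784])
```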
# %% [markdown]
# ## Define a `LightningDataModule`
#
# Then we define a PyTorch Lightning data module, which defines how to prepare
# and set up the training data.


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, root_dir, batch_size=64, dataloader_num_workers=0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def prepare_data(self):
        MNIST(self.root_dir, train=True, download=True)

    def setup(self, stage=None):
        self.dataset = MNIST(
            self.root_dir,
            train=True,
            download=False,
            transform=ToTensor(),
        )

    def train_dataloader(self):
        persistent_workers = self.dataloader_num_workers > 0
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.dataloader_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=True,
            shuffle=True,
        )
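For a sense of what this data module yields, here is a sketch of driving it by hand outside a trainer (not part of the committed file; assumes the `MNISTDataModule` class above is in scope and MNIST can be downloaded):

```python
# Hypothetical local usage of the MNISTDataModule defined above.
dm = MNISTDataModule(root_dir=".", batch_size=64)
dm.prepare_data()        # downloads MNIST into root_dir once
dm.setup()
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```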
# %% [markdown]
# ## Creating the PyTorch `Elastic` task
#
# With the model architecture defined, we now create a Flyte task that assumes
# a world size of 16: 2 nodes with 8 devices each. We also set `max_restarts`
# to `3` so that the task can be retried up to 3 times in case it fails for
# whatever reason, and we set `rdzv_configs` to have a generous timeout so that
# the head and worker nodes have enough time to connect to each other.
#
# This task will output a {ref}`FlyteDirectory <folder>`, which will contain the
# model checkpoint that will result from training.

NUM_NODES = 2
NUM_DEVICES = 8


@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
    pod_template=custom_pod_template,
)
def train_model(dataloader_num_workers: int) -> FlyteDirectory:
    """Train an autoencoder model on the MNIST dataset."""

    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    autoencoder = MNISTAutoEncoder(encoder, decoder)

    root_dir = os.getcwd()
    data = MNISTDataModule(
        root_dir,
        batch_size=4,
        dataloader_num_workers=dataloader_num_workers,
    )

    model_dir = os.path.join(root_dir, "model")
    trainer = L.Trainer(
        default_root_dir=model_dir,
        max_epochs=3,
        num_nodes=NUM_NODES,
        devices=NUM_DEVICES,
        accelerator="gpu",
        strategy="ddp",
        precision="16-mixed",
    )
    trainer.fit(model=autoencoder, datamodule=data)
    return FlyteDirectory(path=str(model_dir))


# %% [markdown]
# Finally, we wrap it all up in a workflow.


@workflow
def train_workflow(dataloader_num_workers: int = 1) -> FlyteDirectory:
    return train_model(dataloader_num_workers=dataloader_num_workers)
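As the README change above shows, this workflow can then be run on a Flyte cluster with:

```
pyflyte run --remote pytorch_lightning_mnist_autoencoder.py train_workflow
```

The `dataloader_num_workers` input is optional and defaults to `1`.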

examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py

Lines changed: 13 additions & 3 deletions

@@ -45,9 +45,19 @@
 from torchvision import datasets, transforms
 
 # %% [markdown]
-# You can activate GPU support by either using the base image that includes the necessary GPU dependencies
-# or by initializing the [CUDA parameters](https://github.com/flyteorg/flytekit/blob/master/flytekit/image_spec/image_spec.py#L34-L35)
-# within the `ImageSpec`.
+# :::{note}
+# You can activate GPU support by either using the base image that includes
+# the necessary GPU dependencies or by specifying the `cuda` parameter in
+# the {py:class}`~flytekit.image_spec.ImageSpec`, for example:
+#
+# ```python
+# custom_image = ImageSpec(
+#     packages=[...],
+#     cuda="12.1.0",
+#     ...
+# )
+# ```
+# :::
 #
 # Adjust memory, GPU usage and storage settings based on whether you are
 # registering against the demo cluster or not.
Lines changed: 3 additions & 0 deletions

@@ -1,6 +1,9 @@
 flytekit
 flytekitplugins-kfpytorch
+kubernetes
+lightning
 matplotlib
 torch
 tensorboardX
 torchvision
+lightning
