Skip to content

Commit

Permalink
[release/ray-lightning] adjust the release test of ray lightning master
Browse files Browse the repository at this point in the history
First of all, sorry i messed up with the previous pr when sync with the master (#27374). This PR is the duplicate of previous pr until we update the changes (change: adding the version check for the ray_lightning for the compatibility). Also, apology for the massive review requests on the previous PR.
  • Loading branch information
JiahaoYao authored Aug 3, 2022
1 parent 20119c7 commit 1c1cca2
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions release/ml_user_tests/ray-lightning/simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
from torchvision import transforms
import pytorch_lightning as pl

from ray_lightning import RayPlugin
from importlib_metadata import version
from packaging.version import parse as v_parse

rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0")

if rlt_use_master:
# ray_lightning >= 0.3.0
from ray_lightning import RayStrategy
else:
# ray_lightning < 0.3.0
from ray_lightning import RayPlugin


class LitAutoEncoder(pl.LightningModule):
Expand Down Expand Up @@ -47,10 +57,16 @@ def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(
plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
max_steps=max_steps,
)
if rlt_use_master:
trainer = pl.Trainer(
strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
max_steps=max_steps,
)
else:
trainer = pl.Trainer(
plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
max_steps=max_steps,
)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))


Expand Down

0 comments on commit 1c1cca2

Please sign in to comment.