diff --git a/release/ml_user_tests/ray-lightning/simple_example.py b/release/ml_user_tests/ray-lightning/simple_example.py index 8672729d0776..bbb2f691ae86 100644 --- a/release/ml_user_tests/ray-lightning/simple_example.py +++ b/release/ml_user_tests/ray-lightning/simple_example.py @@ -9,7 +9,17 @@ from torchvision import transforms import pytorch_lightning as pl -from ray_lightning import RayPlugin +from importlib_metadata import version +from packaging.version import parse as v_parse + +rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0") + +if rlt_use_master: + # ray_lightning >= 0.3.0 + from ray_lightning import RayStrategy +else: + # ray_lightning < 0.3.0 + from ray_lightning import RayPlugin class LitAutoEncoder(pl.LightningModule): @@ -47,10 +57,16 @@ def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10): train, val = random_split(dataset, [55000, 5000]) autoencoder = LitAutoEncoder() - trainer = pl.Trainer( - plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)], - max_steps=max_steps, - ) + if rlt_use_master: + trainer = pl.Trainer( + strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu), + max_steps=max_steps, + ) + else: + trainer = pl.Trainer( + plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)], + max_steps=max_steps, + ) trainer.fit(autoencoder, DataLoader(train), DataLoader(val))