Skip to content

Commit 514e2f8

Browse files
pranavvp16 and vfdev-5 authored
Fix failing mps tests (#3145)
* Fix failing mps tests
* Fix failing tests
* tests
* remove unnecessary changes
* Fix flake8 errors

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 8fb3ae2 commit 514e2f8

File tree

4 files changed

+9
-18
lines changed

4 files changed

+9
-18
lines changed

tests/ignite/distributed/comp_models/test_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from ignite.distributed.comp_models.base import _SerialModel, _torch_version_le_112, ComputationModel
55

66

7-
@pytest.mark.skipif(
8-
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
9-
)
107
def test_serial_model():
118
_SerialModel.create_from_backend()
129
model = _SerialModel.create_from_context()
@@ -19,6 +16,8 @@ def test_serial_model():
1916
assert model.get_node_rank() == 0
2017
if torch.cuda.is_available():
2118
assert model.device().type == "cuda"
19+
elif _torch_version_le_112 and torch.backends.mps.is_available():
20+
assert model.device().type == "mps"
2221
else:
2322
assert model.device().type == "cpu"
2423
assert model.backend() is None

tests/ignite/distributed/test_auto.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import ignite.distributed as idist
1414
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
15-
from ignite.distributed.comp_models.base import _torch_version_le_112
1615

1716

1817
class DummyDS(Dataset):
@@ -180,16 +179,13 @@ def _test_auto_model_optimizer(ws, device):
180179
assert optimizer.backward_passes_per_step == backward_passes_per_step
181180

182181

183-
@pytest.mark.skipif(
184-
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
185-
)
186182
def test_auto_methods_no_dist():
187183
_test_auto_dataloader(1, 1, batch_size=1)
188184
_test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
189185
_test_auto_dataloader(1, 1, batch_size=10, sampler_name="WeightedRandomSampler")
190186
_test_auto_dataloader(1, 1, batch_size=10, sampler_name="DistributedSampler")
191-
192-
_test_auto_model_optimizer(1, "cuda" if torch.cuda.is_available() else "cpu")
187+
device = idist.device()
188+
_test_auto_model_optimizer(1, device)
193189

194190

195191
@pytest.mark.distributed

tests/ignite/distributed/test_launcher.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from packaging.version import Version
99

1010
import ignite.distributed as idist
11-
from ignite.distributed.comp_models.base import _torch_version_le_112
1211
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
1312

1413

@@ -258,11 +257,8 @@ def test_idist_parallel_n_procs_native(init_method, backend, get_fixed_dirname,
258257

259258

260259
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
261-
@pytest.mark.skipif(
262-
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
263-
)
264260
def test_idist_parallel_no_dist():
265-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
261+
device = idist.device()
266262
with idist.Parallel(backend=None) as parallel:
267263
parallel.run(_test_func, ws=1, device=device, backend=None, true_init_method=None)
268264

tests/ignite/distributed/utils/test_serial.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import torch
32

43
import ignite.distributed as idist
@@ -15,13 +14,12 @@
1514
)
1615

1716

18-
@pytest.mark.skipif(
19-
_torch_version_le_112 and torch.backends.mps.is_available(), reason="Temporary skip if MPS is available"
20-
)
2117
def test_no_distrib(capsys):
2218
assert idist.backend() is None
2319
if torch.cuda.is_available():
2420
assert idist.device().type == "cuda"
21+
elif _torch_version_le_112 and torch.backends.mps.is_available():
22+
assert idist.device().type == "mps"
2523
else:
2624
assert idist.device().type == "cpu"
2725
assert idist.get_rank() == 0
@@ -43,6 +41,8 @@ def test_no_distrib(capsys):
4341
assert "ignite.distributed.utils INFO: backend: None" in out[-1]
4442
if torch.cuda.is_available():
4543
assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
44+
elif _torch_version_le_112 and torch.backends.mps.is_available():
45+
assert "ignite.distributed.utils INFO: device: mps" in out[-1]
4646
else:
4747
assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
4848
assert "ignite.distributed.utils INFO: rank: 0" in out[-1]

0 commit comments

Comments (0)