Skip to content

Commit 4be523e

Browse files
committed
Check if torch is installed before checking if CUDA is available
1 parent bcfb7d7 commit 4be523e

File tree

1 file changed: +16 additions, -10 deletions

tests/cases/torch_train.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import logging
2020

21+
TORCH_AVAILABLE = isinstance(torch, NoSuchModule)
22+
2123

2224
# Example 2D source
2325
def example_2d_source(array_key: ArrayKey):
@@ -52,7 +54,7 @@ def example_train_source(a_key, b_key, c_key):
5254
return (source_a, source_b, source_c) + MergeProvider()
5355

5456

55-
if not isinstance(torch, NoSuchModule):
57+
if not TORCH_AVAILABLE:
5658

5759
class ExampleLinearModel(torch.nn.Module):
5860
def __init__(self):
@@ -68,15 +70,16 @@ def forward(self, a, b):
6870
return d_pred
6971

7072

71-
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
73+
@skipIf(TORCH_AVAILABLE, "torch is not installed")
7274
@pytest.mark.parametrize(
7375
"device",
7476
[
7577
"cpu",
7678
pytest.param(
7779
"cuda:0",
7880
marks=pytest.mark.skipif(
79-
not torch.cuda.is_available(), reason="CUDA not available"
81+
TORCH_AVAILABLE or not torch.cuda.is_available(),
82+
reason="CUDA not available",
8083
),
8184
),
8285
],
@@ -143,7 +146,7 @@ def test_loss_drops(tmpdir, device):
143146
assert loss2 < loss1
144147

145148

146-
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
149+
@skipIf(TORCH_AVAILABLE, "torch is not installed")
147150
@pytest.mark.parametrize(
148151
"device",
149152
[
@@ -152,7 +155,8 @@ def test_loss_drops(tmpdir, device):
152155
"cuda:0",
153156
marks=[
154157
pytest.mark.skipif(
155-
not torch.cuda.is_available(), reason="CUDA not available"
158+
TORCH_AVAILABLE or not torch.cuda.is_available(),
159+
reason="CUDA not available",
156160
),
157161
pytest.mark.xfail(
158162
reason="failing to move model to device when using a subprocess"
@@ -207,7 +211,7 @@ def test_output(device):
207211
assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3))
208212

209213

210-
if not isinstance(torch, NoSuchModule):
214+
if not TORCH_AVAILABLE:
211215

212216
class Example2DModel(torch.nn.Module):
213217
def __init__(self):
@@ -222,7 +226,7 @@ def forward(self, a):
222226
return pred
223227

224228

225-
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
229+
@skipIf(TORCH_AVAILABLE, "torch is not installed")
226230
@pytest.mark.parametrize(
227231
"device",
228232
[
@@ -231,7 +235,8 @@ def forward(self, a):
231235
"cuda:0",
232236
marks=[
233237
pytest.mark.skipif(
234-
not torch.cuda.is_available(), reason="CUDA not available"
238+
TORCH_AVAILABLE or not torch.cuda.is_available(),
239+
reason="CUDA not available",
235240
),
236241
pytest.mark.xfail(
237242
reason="failing to move model to device in multiprocessing context"
@@ -275,7 +280,7 @@ def test_scan(device):
275280
assert pred in batch
276281

277282

278-
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
283+
@skipIf(TORCH_AVAILABLE, "torch is not installed")
279284
@pytest.mark.parametrize(
280285
"device",
281286
[
@@ -284,7 +289,8 @@ def test_scan(device):
284289
"cuda:0",
285290
marks=[
286291
pytest.mark.skipif(
287-
not torch.cuda.is_available(), reason="CUDA not available"
292+
TORCH_AVAILABLE or not torch.cuda.is_available(),
293+
reason="CUDA not available",
288294
),
289295
pytest.mark.xfail(
290296
reason="failing to move model to device in multiprocessing context"

0 commit comments

Comments (0)