@@ -18,6 +18,8 @@
 
 import logging
 
+TORCH_AVAILABLE = not isinstance(torch, NoSuchModule)
+
 
 # Example 2D source
 def example_2d_source(array_key: ArrayKey):
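For context: gunpowder's `ext` module binds optional dependencies that fail to import to a `NoSuchModule` placeholder, which is why an `isinstance` check can stand in for "is torch installed". Note the flag is defined with `not` so that `True` means torch is present; a minimal sketch of the pattern (the placeholder's exact behavior in gunpowder may differ):

# Sketch of the optional-import pattern this test file relies on.
class NoSuchModule:
    """Bound in place of a module whose import failed."""

    def __init__(self, name):
        self._name = name

    def __getattr__(self, item):
        # Any attribute access reports the missing dependency.
        raise ImportError(f"module {self._name!r} is not installed")


try:
    import torch
except ImportError:
    torch = NoSuchModule("torch")

# True exactly when the real torch module was imported.
TORCH_AVAILABLE = not isinstance(torch, NoSuchModule)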
@@ -52,7 +54,7 @@ def example_train_source(a_key, b_key, c_key):
     return (source_a, source_b, source_c) + MergeProvider()
 
 
-if not isinstance(torch, NoSuchModule):
+if TORCH_AVAILABLE:
 
     class ExampleLinearModel(torch.nn.Module):
         def __init__(self):
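The guard around the class matters: the `class` statement evaluates its base `torch.nn.Module` at import time, so with torch absent the placeholder would raise before `@skipIf` on any test could take effect. Continuing the sketch above (model body elided; the real layers are in the hunk):

# Importing this file must not touch torch.nn unless torch exists:
#
#     class ExampleLinearModel(torch.nn.Module):  # raises if torch is missing
#
# so the definition is only executed under the availability guard.
if TORCH_AVAILABLE:

    class ExampleLinearModel(torch.nn.Module):
        pass  # parameters and forward() omitted here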
@@ -68,15 +70,16 @@ def forward(self, a, b):
             return d_pred
 
 
-@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
+@skipIf(not TORCH_AVAILABLE, "torch is not installed")
 @pytest.mark.parametrize(
     "device",
     [
         "cpu",
         pytest.param(
             "cuda:0",
             marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
+                not TORCH_AVAILABLE or not torch.cuda.is_available(),
+                reason="CUDA not available",
             ),
         ),
     ],
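The reworked CUDA mark is the substantive fix here: arguments to `pytest.mark.skipif` are evaluated when the module is collected, so the old condition called `torch.cuda.is_available()` even on machines where `torch` is only the placeholder, failing at collection instead of skipping. Putting the availability flag first lets `or` short-circuit before the placeholder is touched. A self-contained illustration (hypothetical test name; the try/except stands in for gunpowder's ext machinery):

import pytest

try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    torch = None
    TORCH_AVAILABLE = False


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda:0",
            # Evaluated left to right at collection time: the flag
            # short-circuits before torch.cuda is ever looked up.
            marks=pytest.mark.skipif(
                not TORCH_AVAILABLE or not torch.cuda.is_available(),
                reason="CUDA not available",
            ),
        ),
    ],
)
def test_device_smoke(device):
    # Hypothetical smoke test; the real tests build gunpowder pipelines.
    assert device in ("cpu", "cuda:0")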
@@ -143,7 +146,7 @@ def test_loss_drops(tmpdir, device):
     assert loss2 < loss1
 
 
-@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
+@skipIf(not TORCH_AVAILABLE, "torch is not installed")
 @pytest.mark.parametrize(
     "device",
     [
@@ -152,7 +155,8 @@ def test_loss_drops(tmpdir, device):
             "cuda:0",
             marks=[
                 pytest.mark.skipif(
-                    not torch.cuda.is_available(), reason="CUDA not available"
+                    not TORCH_AVAILABLE or not torch.cuda.is_available(),
+                    reason="CUDA not available",
                 ),
                 pytest.mark.xfail(
                     reason="failing to move model to device when using a subprocess"
@@ -207,7 +211,7 @@ def test_output(device):
     assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3))
 
 
-if not isinstance(torch, NoSuchModule):
+if TORCH_AVAILABLE:
 
     class Example2DModel(torch.nn.Module):
         def __init__(self):
@@ -222,7 +226,7 @@ def forward(self, a):
             return pred
 
 
-@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
+@skipIf(not TORCH_AVAILABLE, "torch is not installed")
 @pytest.mark.parametrize(
     "device",
     [
@@ -231,7 +235,8 @@ def forward(self, a):
             "cuda:0",
             marks=[
                 pytest.mark.skipif(
-                    not torch.cuda.is_available(), reason="CUDA not available"
+                    not TORCH_AVAILABLE or not torch.cuda.is_available(),
+                    reason="CUDA not available",
                 ),
                 pytest.mark.xfail(
                     reason="failing to move model to device in multiprocessing context"
@@ -275,7 +280,7 @@ def test_scan(device):
     assert pred in batch
 
 
-@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
+@skipIf(not TORCH_AVAILABLE, "torch is not installed")
 @pytest.mark.parametrize(
     "device",
     [
@@ -284,7 +289,8 @@ def test_scan(device):
             "cuda:0",
             marks=[
                 pytest.mark.skipif(
-                    not torch.cuda.is_available(), reason="CUDA not available"
+                    not TORCH_AVAILABLE or not torch.cuda.is_available(),
+                    reason="CUDA not available",
                 ),
                 pytest.mark.xfail(
                     reason="failing to move model to device in multiprocessing context"