Skip to content

Commit

Permalink
properly skip torch tests if torch not installed
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Nov 2, 2023
1 parent 865fe53 commit 2078036
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/cases/torch_train.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
from .helper_sources import ArraySource
from gunpowder import (
BatchProvider,
BatchRequest,
ArraySpec,
Roi,
Coordinate,
ArrayKeys,
ArrayKey,
Array,
Batch,
Scan,
PreCache,
MergeProvider,
build,
)
from gunpowder.ext import torch, NoSuchModule
from gunpowder.torch import Train, Predict
from unittest import skipIf, expectedFailure
from unittest import skipIf
import numpy as np
import pytest

import logging

Expand Down Expand Up @@ -55,8 +52,7 @@ def example_train_source(a_key, b_key, c_key):
return (source_a, source_b, source_c) + MergeProvider()


if torch is not NoSuchModule:

if not isinstance(torch, NoSuchModule):
class ExampleLinearModel(torch.nn.Module):
def __init__(self):
super(ExampleLinearModel, self).__init__()
Expand Down Expand Up @@ -179,7 +175,7 @@ def test_output():
assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3))


if torch is not NoSuchModule:
if not isinstance(torch, NoSuchModule):

class Example2DModel(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -229,6 +225,7 @@ def test_scan():
assert pred in batch


@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
def test_precache():
logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)

Expand Down

0 comments on commit 2078036

Please sign in to comment.