Skip to content

Commit

Permalink
Parameterize tests for CUDA devices
Browse files Browse the repository at this point in the history
A few of these tests are currently failing; some are expected failures.
  • Loading branch information
pattonw committed Dec 19, 2023
1 parent 076661f commit b6c425f
Showing 1 changed file with 72 additions and 5 deletions.
77 changes: 72 additions & 5 deletions tests/cases/torch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,20 @@ def forward(self, a, b):
return d_pred


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda:0",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
def test_loss_drops(tmpdir):
def test_loss_drops(tmpdir, device):
checkpoint_basename = str(tmpdir / "model")

a_key = ArrayKey("A")
Expand All @@ -80,7 +92,7 @@ def test_loss_drops(tmpdir):

model = ExampleLinearModel()
loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.999)

source = example_train_source(a_key, b_key, c_key)
train = Train(
Expand All @@ -98,6 +110,7 @@ def test_loss_drops(tmpdir):
checkpoint_basename=checkpoint_basename,
save_every=100,
spawn_subprocess=False,
device=device,
)
pipeline = source + train

Expand Down Expand Up @@ -130,8 +143,25 @@ def test_loss_drops(tmpdir):
assert loss2 < loss1


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda:0",
marks=[
pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
pytest.mark.xfail(
reason="failing to move model to device when using a subprocess"
),
],
),
],
)
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
def test_output():
def test_output(device):
logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)

a_key = ArrayKey("A")
Expand All @@ -153,6 +183,7 @@ def test_output():
d_pred: ArraySpec(nonspatial=True),
},
spawn_subprocess=True,
device=device,
)
pipeline = source + predict

Expand Down Expand Up @@ -191,8 +222,25 @@ def forward(self, a):
return pred


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda:0",
marks=[
pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
pytest.mark.xfail(
reason="failing to move model to device in multiprocessing context"
),
],
),
],
)
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
def test_scan():
def test_scan(device):
logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)

a_key = ArrayKey("A")
Expand All @@ -210,6 +258,7 @@ def test_scan():
inputs={"a": a_key},
outputs={0: pred},
array_specs={pred: ArraySpec()},
device=device,
)
pipeline = source + predict + Scan(reference_request, num_workers=2)

Expand All @@ -226,8 +275,25 @@ def test_scan():
assert pred in batch


@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda:0",
marks=[
pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
pytest.mark.xfail(
reason="failing to move model to device in multiprocessing context"
),
],
),
],
)
@skipIf(isinstance(torch, NoSuchModule), "torch is not installed")
def test_precache():
def test_precache(device):
logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)

a_key = ArrayKey("A")
Expand All @@ -245,6 +311,7 @@ def test_precache():
inputs={"a": a_key},
outputs={0: pred},
array_specs={pred: ArraySpec()},
device=device,
)
pipeline = source + predict + PreCache(cache_size=3, num_workers=2)

Expand Down

0 comments on commit b6c425f

Please sign in to comment.