Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 1686b949766b76960534ede1105751591fd91c9f
Author: William Patton <wllmpttn24@gmail.com>
Date:   Tue Dec 19 08:43:11 2023 -0700

    black reformatting

commit 26d2c7cfff3f2702f56a5bb4249a0811f54b45ef
Author: Mohinta2892 <samiamohinta2892@gmail.com>
Date:   Thu Nov 2 19:09:15 2023 +0000

    Revert "black reformatted"

    This reverts commit 66dd69b.

    Only format changed files, since black does not consider formatting history

commit a273fd3813fc16b516c2438ad5af0c4ee3f0686b
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 17:12:26 2023 +0000

    black reformatted

commit bb37769eec33af5921386f283e2579055bb34e6d
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 16:40:32 2023 +0000

    add device arg

    Allow passing cuda device to Predict. Issue #188

commit a3b3588a1406d609ae95370cf2c5339872616011
Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com>
Date:   Thu Nov 2 16:39:09 2023 +0000

    add device arg

    allow passing cuda device to Train
  • Loading branch information
pattonw committed Dec 19, 2023
1 parent 35fcd43 commit 076661f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
9 changes: 6 additions & 3 deletions gunpowder/torch/nodes/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ def __init__(
self.intermediate_layers: dict[ArrayKey, Any] = {}

def start(self):
self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda"
logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}")
self.device = torch.device("cuda" if self.use_cuda else "cpu")
# Issue #188
self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__(
"cuda"
)

logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}")
self.device = torch.device(self.device_string if self.use_cuda else "cpu")
try:
self.model = self.model.to(self.device)
except RuntimeError as e:
Expand Down
13 changes: 11 additions & 2 deletions gunpowder/torch/nodes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class Train(GenericTrain):
spawn_subprocess (``bool``, optional):
Whether to run the ``train_step`` in a separate process. Default is false.
device (``str``, optional):
Accepts a cuda gpu specifically to train on (e.g. `cuda:1`, `cuda:2`), helps in multi-card systems.
defaults to ``cuda``
"""

def __init__(
Expand All @@ -93,9 +99,10 @@ def __init__(
array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None,
checkpoint_basename: str = "model",
save_every: int = 2000,
log_dir: Optional[str] = None,
log_dir: str = None,
log_every: int = 1,
spawn_subprocess: bool = False,
device: str = "cuda",
):
if not model.training:
logger.warning(
Expand Down Expand Up @@ -125,6 +132,7 @@ def __init__(
self.loss_inputs = loss_inputs
self.checkpoint_basename = checkpoint_basename
self.save_every = save_every
self.dev = device

self.iteration = 0

Expand Down Expand Up @@ -167,7 +175,8 @@ def retain_gradients(self, request, outputs):

def start(self):
self.use_cuda = torch.cuda.is_available()
self.device = torch.device("cuda" if self.use_cuda else "cpu")
# Issue: #188
self.device = torch.device(self.dev if self.use_cuda else "cpu")

try:
self.model = self.model.to(self.device)
Expand Down

0 comments on commit 076661f

Please sign in to comment.