diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 585ebc34..89c9ac0c 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -85,10 +85,13 @@ def __init__( self.intermediate_layers: dict[ArrayKey, Any] = {} def start(self): - self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda" - logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue #188 + self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__( + "cuda" + ) + logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") + self.device = torch.device(self.device_string if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) except RuntimeError as e: diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 3f688929..e913d8ad 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -79,6 +79,12 @@ class Train(GenericTrain): spawn_subprocess (``bool``, optional): Whether to run the ``train_step`` in a separate process. Default is false. + + device (``str``, optional): + + The device to train on. Pass a specific CUDA device (e.g. ``cuda:1``, ``cuda:2``) to select a GPU on multi-GPU systems. 
+ Defaults to ``cuda``; ``cpu`` is used if CUDA is not available. + """ def __init__( @@ -93,9 +99,10 @@ def __init__( array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint_basename: str = "model", save_every: int = 2000, - log_dir: Optional[str] = None, + log_dir: str = None, log_every: int = 1, spawn_subprocess: bool = False, + device: str = "cuda", ): if not model.training: logger.warning( @@ -125,6 +132,7 @@ def __init__( self.loss_inputs = loss_inputs self.checkpoint_basename = checkpoint_basename self.save_every = save_every + self.dev = device self.iteration = 0 @@ -167,7 +175,8 @@ def retain_gradients(self, request, outputs): def start(self): self.use_cuda = torch.cuda.is_available() - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue: #188 + self.device = torch.device(self.dev if self.use_cuda else "cpu") try: self.model = self.model.to(self.device)