diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 676b2c71..d5be0f9a 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -29,7 +29,7 @@ class Train(GenericTrain): The torch optimizer to use. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): Dictionary from the names of input tensors (argument names of the ``forward`` method) in the model to array keys. @@ -92,7 +92,7 @@ def __init__( model, loss, optimizer, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[int, str], ArrayKey], loss_inputs: Dict[Union[int, str], ArrayKey], gradients: Dict[Union[int, str], ArrayKey] = {}, @@ -112,11 +112,11 @@ def __init__( # not yet implemented gradients = gradients - all_inputs = { - k: v - for k, v in itertools.chain(inputs.items(), loss_inputs.items()) - if v not in outputs.values() - } + loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} + all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} + all_inputs.update( + {k: v for k, v in loss_inputs.items() if v not in outputs.values()} + ) super(Train, self).__init__( all_inputs, @@ -208,16 +208,22 @@ def start(self): def train_step(self, batch, request): inputs = self.__collect_provided_inputs(batch) + inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()} requested_outputs = self.__collect_requested_outputs(request) # keys are argument names of model forward pass - device_inputs = { - k: torch.as_tensor(v, device=self.device) for k, v in inputs.items() - } + device_input_args = [] + for i in range(len(inputs)): + key = f"{i}" + if key in inputs: + device_input_args.append(inputs.pop(key)) + else: + break + device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)} # get outputs. Keys are tuple indices or model attr names as in self.outputs self.optimizer.zero_grad() - model_outputs = self.model(**device_inputs) + model_outputs = self.model(*device_input_args, **device_input_kwargs) if isinstance(model_outputs, tuple): outputs = {i: model_outputs[i] for i in range(len(model_outputs))} elif isinstance(model_outputs, torch.Tensor): @@ -247,8 +253,9 @@ def train_step(self, batch, request): device_loss_args = [] for i in range(len(device_loss_inputs)): - if i in device_loss_inputs: - device_loss_args.append(device_loss_inputs.pop(i)) + key = f"loss_{i}" + if key in device_loss_inputs: + device_loss_args.append(device_loss_inputs.pop(key)) else: break device_loss_kwargs = {} @@ -327,7 +334,12 @@ def __collect_requested_outputs(self, request): def __collect_provided_inputs(self, batch): return self.__collect_provided_arrays( - {k: v for k, v in self.inputs.items() if k not in self.loss_inputs}, batch + { + k: v + for k, v in self.inputs.items() + if (isinstance(k, int) or k not in self.loss_inputs) + }, + batch, ) def __collect_provided_loss_inputs(self, batch): @@ -353,8 +365,9 @@ def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=Fals arrays[array_name] = getattr(batch, array_key) else: raise Exception( - "Unknown network array key {}, can't be given to " - "network".format(array_key) + "Unknown network array key {}, can't be given to " "network".format( + array_key + ) ) return arrays