diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 89c9ac0c..3bc58344 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -18,10 +18,10 @@ class Predict(GenericPredict): The model to use for prediction. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): - Dictionary from the names of input tensors (argument names of the - ``forward`` method) in the model to array keys. + Dictionary from the position (for args) and names (for kwargs) of input + tensors (argument names of the ``forward`` method) in the model to array keys. outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): @@ -58,7 +58,7 @@ class Predict(GenericPredict): def __init__( self, model, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[str, int], ArrayKey], array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint: Optional[str] = None, @@ -111,18 +111,24 @@ def start(self): self.register_hooks() def predict(self, batch, request): - inputs = self.get_inputs(batch) + input_args, input_kwargs = self.get_inputs(batch) with torch.no_grad(): - out = self.model.forward(**inputs) + out = self.model.forward(*input_args, **input_kwargs) outputs = self.get_outputs(out, request) self.update_batch(batch, request, outputs) def get_inputs(self, batch): - model_inputs = { + model_args = [ + torch.as_tensor(batch[self.inputs[ii]].data, device=self.device) + for ii in range(len(self.inputs)) + if ii in self.inputs + ] + model_kwargs = { key: torch.as_tensor(batch[value].data, device=self.device) for key, value in self.inputs.items() + if isinstance(key, str) } - return model_inputs + return model_args, model_kwargs def register_hooks(self): for key in self.outputs: