From 832e475cc6da6f3e3901a0120e038ef8da5f9835 Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 30 May 2024 09:57:53 -0700 Subject: [PATCH] add support for args as inputs to predict.py Its often not so straightforward to know the key word argument name for the forward function of your model. Especially if you use something like `torch.nn.Sequential` --- gunpowder/torch/nodes/predict.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 89c9ac0c..3bc58344 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -18,10 +18,10 @@ class Predict(GenericPredict): The model to use for prediction. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): - Dictionary from the names of input tensors (argument names of the - ``forward`` method) in the model to array keys. + Dictionary from the position (for args) and names (for kwargs) of input + tensors (argument names of the ``forward`` method) in the model to array keys. outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): @@ -58,7 +58,7 @@ class Predict(GenericPredict): def __init__( self, model, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[str, int], ArrayKey], array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint: Optional[str] = None, @@ -111,18 +111,24 @@ def start(self): self.register_hooks() def predict(self, batch, request): - inputs = self.get_inputs(batch) + input_args, input_kwargs = self.get_inputs(batch) with torch.no_grad(): - out = self.model.forward(**inputs) + out = self.model.forward(*input_args, **input_kwargs) outputs = self.get_outputs(out, request) self.update_batch(batch, request, outputs) def get_inputs(self, batch): - model_inputs = { + model_args = [ + torch.as_tensor(batch[self.inputs[ii]].data, device=self.device) + for ii in range(len(self.inputs)) + if ii in self.inputs + ] + model_kwargs = { key: torch.as_tensor(batch[value].data, device=self.device) for key, value in self.inputs.items() + if isinstance(key, str) } - return model_inputs + return model_args, model_kwargs def register_hooks(self): for key in self.outputs: