Skip to content

Commit

Permalink
add support for args as inputs to predict.py
Browse files Browse the repository at this point in the history
Its often not so straightforward to know the key word argument name for the forward function of your model. Especially if you use something like `torch.nn.Sequential`
  • Loading branch information
pattonw committed May 30, 2024
1 parent b5ceb57 commit 832e475
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions gunpowder/torch/nodes/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class Predict(GenericPredict):
The model to use for prediction.
inputs (``dict``, ``string`` -> :class:`ArrayKey`):
inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
Dictionary from the names of input tensors (argument names of the
``forward`` method) in the model to array keys.
Dictionary from the position (for args) and names (for kwargs) of input
tensors (argument names of the ``forward`` method) in the model to array keys.
outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
Expand Down Expand Up @@ -58,7 +58,7 @@ class Predict(GenericPredict):
def __init__(
self,
model,
inputs: Dict[str, ArrayKey],
inputs: Dict[Union[str, int], ArrayKey],
outputs: Dict[Union[str, int], ArrayKey],
array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None,
checkpoint: Optional[str] = None,
Expand Down Expand Up @@ -111,18 +111,24 @@ def start(self):
self.register_hooks()

def predict(self, batch, request):
inputs = self.get_inputs(batch)
input_args, input_kwargs = self.get_inputs(batch)
with torch.no_grad():
out = self.model.forward(**inputs)
out = self.model.forward(*input_args, **input_kwargs)
outputs = self.get_outputs(out, request)
self.update_batch(batch, request, outputs)

def get_inputs(self, batch):
model_inputs = {
model_args = [
torch.as_tensor(batch[self.inputs[ii]].data, device=self.device)
for ii in range(len(self.inputs))
if ii in self.inputs
]
model_kwargs = {
key: torch.as_tensor(batch[value].data, device=self.device)
for key, value in self.inputs.items()
if isinstance(key, str)
}
return model_inputs
return model_args, model_kwargs

def register_hooks(self):
for key in self.outputs:
Expand Down

0 comments on commit 832e475

Please sign in to comment.