Pytorch Train: let users specify model inputs as args instead of kwargs
pattonw committed Jun 14, 2024
1 parent fcdee74 commit 89a0354
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions gunpowder/torch/nodes/train.py
@@ -29,7 +29,7 @@ class Train(GenericTrain):
             The torch optimizer to use.
 
-        inputs (``dict``, ``string`` -> :class:`ArrayKey`):
+        inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`):
 
             Dictionary from the names of input tensors (argument names of the
             ``forward`` method) in the model to array keys.
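The user-facing effect of this change: integer keys in ``inputs`` map arrays to positional arguments of the model's ``forward`` method, while string keys continue to map to keyword arguments. A minimal sketch of the new calling convention (the array keys, ``model``, ``loss``, and ``optimizer`` are hypothetical placeholders, not taken from this diff):

```python
import gunpowder as gp
from gunpowder.torch import Train

raw = gp.ArrayKey("RAW")
mask = gp.ArrayKey("MASK")
prediction = gp.ArrayKey("PREDICTION")
gt = gp.ArrayKey("GT")

train = Train(
    model=model,          # a torch.nn.Module, assumed defined elsewhere
    loss=loss,            # e.g. torch.nn.MSELoss()
    optimizer=optimizer,  # e.g. torch.optim.Adam(model.parameters())
    # 0 -> first positional argument of model.forward(...);
    # "mask" -> keyword argument model.forward(..., mask=...)
    inputs={0: raw, "mask": mask},
    outputs={0: prediction},
    loss_inputs={0: prediction, 1: gt},
)
```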
@@ -92,7 +92,7 @@ def __init__(
         model,
         loss,
         optimizer,
-        inputs: Dict[str, ArrayKey],
+        inputs: Dict[Union[str, int], ArrayKey],
         outputs: Dict[Union[int, str], ArrayKey],
         loss_inputs: Dict[Union[int, str], ArrayKey],
         gradients: Dict[Union[int, str], ArrayKey] = {},
@@ -112,11 +112,11 @@ def __init__(
 
         # not yet implemented
         gradients = gradients
-        all_inputs = {
-            k: v
-            for k, v in itertools.chain(inputs.items(), loss_inputs.items())
-            if v not in outputs.values()
-        }
+        loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()}
+        all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()}
+        all_inputs.update(
+            {k: v for k, v in loss_inputs.items() if v not in outputs.values()}
+        )
 
         super(Train, self).__init__(
             all_inputs,
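The rewritten key handling above replaces the old ``itertools.chain`` merge: loss-input keys get a ``loss_`` prefix and all input keys are stringified, so an integer loss-input key can no longer shadow an integer model-input key. A standalone sketch of that normalization, using hypothetical keys and plain strings in place of array keys:

```python
# Hypothetical stand-ins for ArrayKeys.
inputs = {0: "RAW", "mask": "MASK"}
loss_inputs = {0: "PREDICTION", 1: "GT"}
outputs = {0: "PREDICTION"}

# Same normalization as in __init__ above.
loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()}
all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()}
all_inputs.update({k: v for k, v in loss_inputs.items() if v not in outputs.values()})

# Arrays that the network itself produces are excluded from the requests:
print(all_inputs)  # {'0': 'RAW', 'mask': 'MASK', 'loss_1': 'GT'}
```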
@@ -208,16 +208,22 @@ def start(self):
 
     def train_step(self, batch, request):
         inputs = self.__collect_provided_inputs(batch)
+        inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()}
         requested_outputs = self.__collect_requested_outputs(request)
 
-        # keys are argument names of model forward pass
-        device_inputs = {
-            k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()
-        }
+        device_input_args = []
+        for i in range(len(inputs)):
+            key = f"{i}"
+            if key in inputs:
+                device_input_args.append(inputs.pop(key))
+            else:
+                break
+        device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)}
 
         # get outputs. Keys are tuple indices or model attr names as in self.outputs
         self.optimizer.zero_grad()
-        model_outputs = self.model(**device_inputs)
+        model_outputs = self.model(*device_input_args, **device_input_kwargs)
         if isinstance(model_outputs, tuple):
             outputs = {i: model_outputs[i] for i in range(len(model_outputs))}
         elif isinstance(model_outputs, torch.Tensor):
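The new ``train_step`` logic splits the provided inputs into positional and keyword arguments: keys ``"0"``, ``"1"``, … are consumed in order until the sequence breaks, and whatever remains is passed by name. A toy illustration of just that split (the tensors here are stand-in strings):

```python
inputs = {"0": "raw_tensor", "1": "mask_tensor", "weights": "weight_tensor"}

# Consume consecutive stringified indices as positional arguments.
device_input_args = []
for i in range(len(inputs)):
    key = f"{i}"
    if key in inputs:
        device_input_args.append(inputs.pop(key))
    else:
        break
# Everything left over is passed as keyword arguments.
device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)}

print(device_input_args)    # ['raw_tensor', 'mask_tensor']
print(device_input_kwargs)  # {'weights': 'weight_tensor'}
```

The loss-argument loop further down applies the same pattern, but looks up keys of the form ``loss_0``, ``loss_1``, … to match the prefix added in ``__init__``.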
@@ -247,8 +253,9 @@ def train_step(self, batch, request):
 
         device_loss_args = []
         for i in range(len(device_loss_inputs)):
-            if i in device_loss_inputs:
-                device_loss_args.append(device_loss_inputs.pop(i))
+            key = f"loss_{i}"
+            if key in device_loss_inputs:
+                device_loss_args.append(device_loss_inputs.pop(key))
             else:
                 break
         device_loss_kwargs = {}
@@ -327,7 +334,12 @@ def __collect_requested_outputs(self, request):
 
     def __collect_provided_inputs(self, batch):
         return self.__collect_provided_arrays(
-            {k: v for k, v in self.inputs.items() if k not in self.loss_inputs}, batch
+            {
+                k: v
+                for k, v in self.inputs.items()
+                if (isinstance(k, int) or k not in self.loss_inputs)
+            },
+            batch,
         )
 
     def __collect_provided_loss_inputs(self, batch):
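In ``__collect_provided_inputs``, the extra ``isinstance(k, int)`` guard means integer keys are always kept; only string keys are filtered against ``loss_inputs``. A toy run of just that comprehension, with hypothetical key dictionaries in place of the node's attributes:

```python
# Hypothetical stand-ins for self.inputs and self.loss_inputs.
self_inputs = {0: "RAW", "pred": "PREDICTION"}
self_loss_inputs = {"loss_0": "PREDICTION", "pred": "PREDICTION"}

filtered = {
    k: v
    for k, v in self_inputs.items()
    if (isinstance(k, int) or k not in self_loss_inputs)
}
# Integer key 0 passes unconditionally; string key "pred" is dropped
# because it also appears among the loss inputs.
print(filtered)  # {0: 'RAW'}
```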
@@ -353,8 +365,9 @@ def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=False):
                 arrays[array_name] = getattr(batch, array_key)
             else:
                 raise Exception(
-                    "Unknown network array key {}, can't be given to "
-                    "network".format(array_key)
+                    "Unknown network array key {}, can't be given to " "network".format(
+                        array_key
+                    )
                 )
 
         return arrays
