diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index d426d2a8..3e5ba8f1 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -104,7 +104,7 @@ def start(self): if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: - self.model.load_state_dict() + self.model.load_state_dict(checkpoint) def predict(self, batch, request): inputs = self.get_inputs(batch)