diff --git a/skorch/callbacks/logging.py b/skorch/callbacks/logging.py index 8ebef45eb..e07032934 100644 --- a/skorch/callbacks/logging.py +++ b/skorch/callbacks/logging.py @@ -669,7 +669,7 @@ def __getstate__(self): # don't save away the temporary pbar_ object which gets created on # epoch begin anew anyway. This avoids pickling errors with tqdm. state = self.__dict__.copy() - del state['pbar_'] + state.pop('pbar_', None) return state diff --git a/skorch/tests/callbacks/test_logging.py b/skorch/tests/callbacks/test_logging.py index 5b79c6cc8..088f1c1cb 100644 --- a/skorch/tests/callbacks/test_logging.py +++ b/skorch/tests/callbacks/test_logging.py @@ -770,6 +770,18 @@ def test_pickle(self, net_cls, progressbar_cls, data): net = pickle.loads(dump) net.fit(*data) + def test_pickle_without_fit(self, net_cls, progressbar_cls, data): + # pickling should work even if the net hasn't been fit. + # see https://github.com/skorch-dev/skorch/pull/1034. + import pickle + + net = net_cls(callbacks=[ + progressbar_cls(), + ]) + dump = pickle.dumps(net) + + net = pickle.loads(dump) + @pytest.mark.skipif( not tensorboard_installed, reason='tensorboard is not installed')