Skip to content

Commit

Permalink
Reviewer feedback: One more test w/o monkeypatch
Browse files Browse the repository at this point in the history
Instead, rely on the installed torch version and skip if it doesn't fit.
  • Loading branch information
BenjaminBossan committed Sep 2, 2024
1 parent 9acfb84 commit ab9c536
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3064,6 +3064,32 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params
# but without monkeypatching get_torch_load_kwargs. There is no corresponding
# test for >= 2.6.0 since it's not clear yet if the switch will be made in that
# version.
# See discussion in 1063.
from skorch._version import Version

if Version(torch.__version__) >= Version('2.6.0'):
pytest.skip("Test only for torch < v2.6.0")

net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_custom_module_params_passed_to_optimizer(
self, net_custom_module_cls, module_cls):
# custom module parameters should automatically be passed to the optimizer
Expand Down

0 comments on commit ab9c536

Please sign in to comment.