diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index d54d3f9e..b4f056c7 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -3064,6 +3064,32 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
         del call_kwargs['map_location']  # we're not interested in that
         assert call_kwargs == expected_kwargs
 
+    def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params
+        # but without monkeypatching get_torch_load_kwargs. There is no corresponding
+        # test for >= 2.6.0 since it's not clear yet if the switch will be made in that
+        # version.
+        # See discussion in 1063.
+        from skorch._version import Version
+
+        if Version(torch.__version__) >= Version('2.6.0'):
+            pytest.skip("Test only for torch < v2.6.0")
+
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": False}
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer