Add PyTorch 2.4.0 to CI #1063

Merged: 2 commits merged into master from pytorch-2.4 on Sep 20, 2024

Conversation

BenjaminBossan
Collaborator

Also:

- Remove 2.0.1
- Upgrade 2.3.0 to 2.3.1
- Use index https://download.pytorch.org/whl/torch, as torch_stable does
  not have 2.4.0 (yet?)
@BenjaminBossan
Collaborator Author

Okay, so I investigated the failing tests with PyTorch 2.4 a bit further. Right now, we get a warning:

FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

That this warning turns into an error is pure coincidence: for unrelated reasons, the affected tests install a filter that converts any warning into an error, and the new FutureWarning trips it:

with warnings.catch_warnings():
    # ensure that there is *no* warning, especially not about setting
    # the device because it is None
    warnings.simplefilter("error")
    net_loaded.load_params(f_params=f_name)
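
The effect can be reproduced without PyTorch. This is a minimal sketch, not taken from the skorch test suite: `fake_torch_load` is a hypothetical stand-in for a loading call that emits a FutureWarning, and `load_strictly` mimics the filter used in the test.

```python
import warnings


def fake_torch_load():
    # Hypothetical stand-in for a call that emits a FutureWarning, as
    # torch.load does in PyTorch 2.4 when weights_only is not set.
    warnings.warn("weights_only default will change", FutureWarning)
    return "loaded"


def load_strictly():
    # Same pattern as in the test: inside this block, every warning is
    # raised as an exception instead of being printed.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        return fake_torch_load()


try:
    load_strictly()
    outcome = "no error"
except FutureWarning as exc:
    outcome = f"raised: {exc}"
```

Outside of `load_strictly`, the same call would only print the warning and return normally.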

Anyway, it's good that we got an early indicator that this will break in the future. However, fixing the problem is not trivial. Here is why:

Since PyTorch 1.13, there is an option in torch.load called weights_only. If set to True, only a select few types can be loaded, with the intent of making torch.load more secure. As it uses pickle under the hood, there is the danger of malicious code being executed when loading torch checkpoints, so this is a sensible decision (note that alternatives like safetensors and gguf don't suffer from this).
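
To illustrate the idea behind an allowlist, here is a sketch built only on the standard library's `pickle.Unpickler.find_class` hook. It is an analogy, not PyTorch's implementation: `ALLOWED_GLOBALS`, `AllowlistUnpickler`, `restricted_loads` and `Payload` are hypothetical names, and the allowlist shown is not the one PyTorch uses.

```python
import io
import pickle
from collections import OrderedDict

# Illustrative allowlist of (module, name) pairs; NOT PyTorch's actual list.
ALLOWED_GLOBALS = {("collections", "OrderedDict")}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the pickle stream references a global object.
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")


def restricted_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()


class Payload:
    # Plain pickle.loads would call eval("1 + 1") while unpickling this
    # object; a real attack would use something harmful instead.
    def __reduce__(self):
        return (eval, ("1 + 1",))


# An allowlisted container with plain values loads fine.
state = restricted_loads(pickle.dumps(OrderedDict(weight=[1.0, 2.0])))

# The payload references builtins.eval, which is not allowlisted.
try:
    restricted_loads(pickle.dumps(Payload()))
    payload_blocked = False
except pickle.UnpicklingError:
    payload_blocked = True
```

The same mechanism that blocks the payload also blocks any harmless class that has not been allowlisted, which is where the friction described below comes from.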

Right now, the default is to set weights_only=False but as the warning indicates, this will be flipped to weights_only=True in the future. Here are some ideas to address this.

1. Defaulting to weights_only=True

My first idea was to fix the warning by switching to weights_only=True in skorch, where we use torch.load for deserialization. However, this makes a bunch of tests fail because they use types that are not considered to be secure.

As an example, each time we define a custom nn.Module, it is considered unsafe and we would have to call torch.serialization.add_safe_globals([MyModule]) if the test involves serialization. But that's not enough: Even builtin types like set and PyTorch types like nn.Linear are not considered secure, so all of these would have to be added too.

The latter could be done once inside of conftest and it would be fine but I really don't want to scatter the code with torch.serialization.add_safe_globals each time a custom class is defined. Moreover, if we make this switch, it means that a bunch of user code would start failing. Yes, this is bound to happen when PyTorch makes the switch, but still it's not a nice experience.

What's also annoying is that PyTorch reports these insecure types only one at a time, so we have to add them, run the tests again, get a new error, add the new type, etc.

2. Setting weights_only=False

We could hard-code this and thus ensure that all the code that used to work will continue working. Neither tests nor user code would require adjusting. This also wouldn't be any less secure than the status quo, but it defeats the whole idea of making PyTorch more secure.

If we take this route, we should allow users to override this by exposing the argument.

3. Not setting anything in torch.load

I.e. just leaving the code as is and using whatever default the installed PyTorch version uses. The failing test would still fail, but it could be fixed by exempting this FutureWarning from the filter. User code would work as before. When the PyTorch version with the flipped default is released, users will have to start dealing with this, same as other PyTorch users. Similarly, we will have to deal with this in skorch, as discussed in solution 1.
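
One way the test filter could let this specific warning through while still failing on everything else is sketched below. It uses only the standard `warnings` module; `strict_warnings_except_weights_only` and `fake_torch_load` are hypothetical names, and the message pattern is taken from the warning text quoted above.

```python
import warnings
from contextlib import contextmanager


@contextmanager
def strict_warnings_except_weights_only():
    # Hypothetical helper: turn every warning into an error, except the
    # FutureWarning about the weights_only default.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        # Filters added later are consulted first, so this "ignore" rule
        # takes precedence over the blanket "error" rule above.
        warnings.filterwarnings(
            "ignore",
            message=r"You are using torch\.load with weights_only=False",
            category=FutureWarning,
        )
        yield


def fake_torch_load():
    # Stand-in for torch.load on PyTorch 2.4 without weights_only set.
    warnings.warn(
        "You are using torch.load with weights_only=False (the current "
        "default value)",
        FutureWarning,
    )
    return "loaded"


with strict_warnings_except_weights_only():
    result = fake_torch_load()  # the exempted warning is ignored

try:
    with strict_warnings_except_weights_only():
        warnings.warn("device is None", UserWarning)
    other_warning_raised = False
except UserWarning:
    other_warning_raised = True
```

Any other warning, such as the one about the device that the test was written for, still raises.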

For now, I have reported this internally to PyTorch devs, let's see what comes out of it.

Input by others would be appreciated @ottonemo @thomasjpfan

@thomasjpfan
Member

In the long term, I'll want a way to allow weights_only=True even if it takes some time to get right with torch.serialization.add_safe_globals. For skorch, I propose:

  1. Use weights_only=False as the default
  2. Add torch_load_kwargs to NeuralNet.__init__ to allow users to override torch.load kwargs and set weights_only=True.

Concretely, whenever we call torch.load:

default_load_kwargs = {"weights_only": False}

torch_load_kwargs = {**default_load_kwargs, **self.torch_load_kwargs}
torch.load(..., **torch_load_kwargs)
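
A sketch of how this merge might sit inside a class follows. It is not the skorch implementation: `TinyNet`, `load_fn` and `fake_load` are hypothetical, and the loader is injected so that the example runs without PyTorch (in skorch, the call would go to torch.load). The default of weights_only=False follows item 1 of the proposal.

```python
class TinyNet:
    """Hypothetical stand-in for NeuralNet, reduced to the loading logic."""

    def __init__(self, torch_load_kwargs=None, load_fn=None):
        self.torch_load_kwargs = torch_load_kwargs
        self.load_fn = load_fn  # would be torch.load in real code

    def _get_load_kwargs(self):
        # Defaults first, then user-supplied kwargs, so the user wins
        # whenever both define the same key.
        defaults = {"weights_only": False}
        return {**defaults, **(self.torch_load_kwargs or {})}

    def load_params(self, f_params):
        return self.load_fn(f_params, **self._get_load_kwargs())


def fake_load(f, **kwargs):
    # Records what it was called with instead of reading a file.
    return {"file": f, "kwargs": kwargs}


default_call = TinyNet(load_fn=fake_load).load_params("params.pt")
strict_call = TinyNet({"weights_only": True}, fake_load).load_params("params.pt")
```

Because the kwargs are stored on the instance, they are also available in code paths that cannot take arguments, such as unpickling the net itself.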

@BenjaminBossan
Collaborator Author

BenjaminBossan commented Aug 12, 2024

Thanks for the input, this sounds reasonable. It's not pretty, but since we cannot directly pass arguments to __setstate__, I don't see a better way.

As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?

I found that there is also a context manager torch.serialization.safe_globals. For test-specific classes, we can use that, for the rest like set we can use add_safe_globals in conftest.py. Edit: This was only added recently, so it's not available for older releases.

Edit: Planned release is v2.6.0.

@ottonemo
Member

> As to the default: WDYT about using "auto" and then switching to whatever the default is for the given PyTorch version?

I like this. It would mean that we expose a way of handling model loading security to the user while keeping pytorch's defaults. Since this is a long-standing security issue I'd say we should at least follow the pytorch default as soon as they deem the ecosystem to be ready for it.

We could simply use the pytorch release version as a default indicator (might be better than using inspect?)
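
A version-based default could look roughly like the sketch below. The function names, the "auto" sentinel handling and the simple version parsing are assumptions made for illustration; the 2.6 threshold is the planned release mentioned in this thread and may need adjusting.

```python
import re

# Assumption taken from the discussion: the default is expected to flip in
# PyTorch 2.6.0. This threshold may need adjusting.
VERSION_DEFAULT_SWITCH = (2, 6)


def parse_major_minor(version):
    """Extract (major, minor) from strings like '2.4.0' or '2.6.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def default_torch_load_kwargs(torch_version):
    """Mirror the weights_only default of the given PyTorch version."""
    flipped = parse_major_minor(torch_version) >= VERSION_DEFAULT_SWITCH
    return {"weights_only": flipped}


def resolve_torch_load_kwargs(torch_load_kwargs, torch_version):
    """Use the version default for 'auto', else the user's kwargs."""
    if torch_load_kwargs == "auto":
        return default_torch_load_kwargs(torch_version)
    return dict(torch_load_kwargs)
```

In real code, `torch.__version__` would be passed as `torch_version`. Comparing integer tuples rather than strings keeps a version like 2.10 ordered after 2.6; note that this simple parser treats pre-releases such as 2.6.0.dev as 2.6.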

I assume that the need for a class variable for the load kwarg comes from the fact that we support pickling skorch models?

> I found that there is also a context manager torch.serialization.safe_globals. For test-specific classes, we can use that, for the rest like set we can use add_safe_globals in conftest.py.

I was going to say that it might be beneficial to have the tests look as close to user code as possible, so that we run into approximately the same issues (in terms of functionality but also in terms of 'design') as our users do. The context manager + whitelisting generic classes is probably a good middle ground.
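
The split between a global registration and a scoped one can be sketched with the standard library alone. This is an analogy to the PyTorch functions named above, not their implementation: `add_safe_globals_sketch`, `safe_globals_sketch`, `safe_loads` and `can_load` are hypothetical, and Counter and Fraction merely stand in for a generic type and a test-specific class.

```python
import io
import pickle
from collections import Counter
from contextlib import contextmanager
from fractions import Fraction

_SAFE_GLOBALS = set()  # module-level allowlist of (module, name) pairs


def add_safe_globals_sketch(classes):
    """Permanently allowlist classes, e.g. once in conftest.py."""
    _SAFE_GLOBALS.update((c.__module__, c.__qualname__) for c in classes)


@contextmanager
def safe_globals_sketch(classes):
    """Allowlist classes only inside the with-block, then restore."""
    saved = set(_SAFE_GLOBALS)
    add_safe_globals_sketch(classes)
    try:
        yield
    finally:
        _SAFE_GLOBALS.clear()
        _SAFE_GLOBALS.update(saved)


class _Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if (module, name) not in _SAFE_GLOBALS:
            raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")
        return super().find_class(module, name)


def safe_loads(data):
    return _Unpickler(io.BytesIO(data)).load()


def can_load(obj):
    try:
        safe_loads(pickle.dumps(obj))
        return True
    except pickle.UnpicklingError:
        return False


add_safe_globals_sketch([Counter])      # generic type, registered once
with safe_globals_sketch([Fraction]):   # test-specific type, scoped
    inside = can_load(Fraction(1, 2))
outside = can_load(Fraction(1, 2))
```

The scoped variant keeps a test from silently widening the allowlist for every test that runs after it.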

@thomasjpfan
Member

I'm happy with an "auto" option.

BenjaminBossan added a commit that referenced this pull request Aug 22, 2024
See discussion in #1063

Starting from PyTorch 2.4, there is a warning when torch.load is called
without setting the weights_only argument. This is because in the
future, the default will switch from False to True, which can result in
a lot of errors when trying to load torch files (which are pickle files
and thus insecure).

In this PR, we add a possibility for the user to influence the kwargs
passed to torch.load so that they can control that behavior. If not
further indicated by the user, we will use the same defaults as the
installed torch version. Therefore, users will only encounter this issue
via skorch if they would have encountered it via torch anyway.

Since it's not 100% certain if the default will switch in torch 2.6.0,
we may have to adjust the version check in the future.

Besides directly testing the kwargs being passed on, a test was also
added that net.load_params does not give any warnings. This is already
indirectly tested through some accelerate tests that are currently
failing with torch 2.4, but it's better to have an explicit test.

After this is merged, the CI should pass when using torch 2.4.0.
ottonemo pushed a commit that referenced this pull request Sep 19, 2024
* Fix warning from torch.load starting in torch 2.4

* Reviewer feedback: return kwargs directly

* Reviewer feedback: One more test w/o monkeypatch

Instead, rely on the installed torch version and skip if it doesn't fit.

* Reviewer feedback: rename function, fix typo
@ottonemo ottonemo merged commit 9ff9cfa into master Sep 20, 2024
16 checks passed
@BenjaminBossan BenjaminBossan deleted the pytorch-2.4 branch September 20, 2024 10:35