From d1f76513e666bce75915a9f3063c534c91670433 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 29 May 2024 14:20:25 +0200 Subject: [PATCH] FIX NeuralNetBinaryClassifier with torch.compile Fixes #1057 NeuralNetBinaryClassifier was not working with torch.compile because the non-linearity was not correctly inferred. This inference depends on the instance type of the criterion. However, when using torch.compile, the criterion is wrapped, resulting in the isinstance check to miss. Now, we unwrap the criterion before checking the instance type. --- skorch/tests/test_net.py | 33 +++++++++++++++++++++++++++++++++ skorch/utils.py | 2 ++ 2 files changed, 35 insertions(+) diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 7eec94e5a..17704e1bd 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -4159,3 +4159,36 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data): # compiled, we rely here on torch keeping this public attribute assert hasattr(net.module_, 'dynamo_ctx') assert hasattr(net.criterion_, 'dynamo_ctx') + + def test_binary_classifier_with_compile(self, data): + # issue 1057 the problem was that compile would wrap the optimizer, + # resulting in _infer_predict_nonlinearity to return the wrong result + # because of a failing isinstance check + from skorch import NeuralNetBinaryClassifier + + X, y = data[0], data[1].astype(np.float32) + + class MyNet(nn.Module): + def __init__(self): + super(MyNet, self).__init__() + self.linear = nn.Linear(20, 10) + self.output = nn.Linear(10, 1) + + def forward(self, input): + out = self.linear(input) + out = nn.functional.relu(out) + out = self.output(out) + return out.squeeze(-1) + + net = NeuralNetBinaryClassifier( + MyNet, + max_epochs=3, + compile=True, + ) + # check that no error is raised + net.fit(X, y) + + y_proba = net.predict_proba(X) + y_pred = net.predict(X) + assert y_proba.shape == (X.shape[0], 2) + assert y_pred.shape == (X.shape[0],) diff --git a/skorch/utils.py b/skorch/utils.py index 57ceaaa6f..de679ec35 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -660,6 +660,8 @@ def _infer_predict_nonlinearity(net): return _identity criterion = getattr(net, net._criteria[0] + '_') + # unwrap optimizer in case of torch.compile being used + criterion = getattr(criterion, '_orig_mod', criterion) if isinstance(criterion, CrossEntropyLoss): return partial(torch.softmax, dim=-1)