Skip to content

Commit

Permalink
FIX NeuralNetBinaryClassifier with torch.compile
Browse files Browse the repository at this point in the history
Fixes #1057

NeuralNetBinaryClassifier was not working with torch.compile because the
non-linearity was not correctly inferred. This inference depends on the
instance type of the criterion. However, when using torch.compile, the
criterion is wrapped, resulting in the isinstance check to miss. Now, we
unwrap the criterion before checking the instance type.
  • Loading branch information
BenjaminBossan committed May 29, 2024
1 parent dd341d3 commit d1f7651
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
33 changes: 33 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4159,3 +4159,36 @@ def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
# compiled, we rely here on torch keeping this public attribute
assert hasattr(net.module_, 'dynamo_ctx')
assert hasattr(net.criterion_, 'dynamo_ctx')

def test_binary_classifier_with_compile(self, data):
# issue 1057 the problem was that compile would wrap the optimizer,
# resulting in _infer_predict_nonlinearity to return the wrong result
# because of a failing isinstance check
from skorch import NeuralNetBinaryClassifier

X, y = data[0], data[1].astype(np.float32)

class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.linear = nn.Linear(20, 10)
self.output = nn.Linear(10, 1)

def forward(self, input):
out = self.linear(input)
out = nn.functional.relu(out)
out = self.output(out)
return out.squeeze(-1)

net = NeuralNetBinaryClassifier(
MyNet,
max_epochs=3,
compile=True,
)
# check that no error is raised
net.fit(X, y)

y_proba = net.predict_proba(X)
y_pred = net.predict(X)
assert y_proba.shape == (X.shape[0], 2)
assert y_pred.shape == (X.shape[0],)
2 changes: 2 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,8 @@ def _infer_predict_nonlinearity(net):
return _identity

criterion = getattr(net, net._criteria[0] + '_')
# unwrap optimizer in case of torch.compile being used
criterion = getattr(criterion, '_orig_mod', criterion)

if isinstance(criterion, CrossEntropyLoss):
return partial(torch.softmax, dim=-1)
Expand Down

0 comments on commit d1f7651

Please sign in to comment.