Skip to content

Commit

Permalink
Merge pull request #463 from MannLabs/fix-torch-conflict-in-tests
Browse files Browse the repository at this point in the history
fix tests by adapting to changes in pytorch 2.6
  • Loading branch information
anna-charlotte authored Jan 30, 2025
2 parents 817495f + ca0ddae commit b0af1aa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion alphadia/workflow/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,9 @@ def load_classifier_store(self, path: None | str = None):

if classifier_hash not in self.classifier_store:
classifier = deepcopy(self.classifier_base)
classifier.from_state_dict(torch.load(os.path.join(path, file)))
classifier.from_state_dict(
torch.load(os.path.join(path, file), weights_only=False)
)
self.classifier_store[classifier_hash].append(classifier)

def get_classifier(self, available_columns: list, version: int = -1):
Expand Down
4 changes: 3 additions & 1 deletion tests/unit_tests/test_fdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def test_feed_forward_save():

new_classifier = fdrx.BinaryClassifierLegacy()
new_classifier.from_state_dict(
torch.load(os.path.join(tempfolder, "test_feed_forward_save.pth"))
torch.load(
os.path.join(tempfolder, "test_feed_forward_save.pth"), weights_only=False
)
)

y_pred = new_classifier.predict(x) # noqa: F841 # TODO fix this test
Expand Down

0 comments on commit b0af1aa

Please sign in to comment.