Fixed failing tests for mps device #3143
Conversation
@@ -22,6 +22,8 @@ def test_no_distrib(capsys):
    assert idist.backend() is None
    if torch.cuda.is_available():
        assert idist.device().type == "cuda"
+   elif torch.backends.mps.is_available():
We have to put a guard `_torch_version_le_112` here, since we also run tests on older PyTorch versions where the MPS backend does not exist.
Also, to actually run the test, you need to remove the `@pytest.mark.skipif` decorator.
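To illustrate the guard the reviewer is asking for: on PyTorch <= 1.12 the attribute `torch.backends.mps` does not exist at all, so touching it raises `AttributeError`. Below is a minimal sketch of the pattern, with `has_mps`, `new_torch`, and `old_torch` as illustrative stand-ins (ignite's actual flag is `_torch_version_le_112` in its test suite, not shown here):

```python
from types import SimpleNamespace

def has_mps(backends) -> bool:
    # Guarded check: on PyTorch <= 1.12 there is no `mps` attribute on
    # torch.backends at all, so probe for it before calling is_available().
    mps = getattr(backends, "mps", None)
    return mps is not None and mps.is_available()

# Stand-ins for the two situations the reviewer describes (hypothetical,
# so the sketch runs without a specific PyTorch version installed):
new_torch = SimpleNamespace(mps=SimpleNamespace(is_available=lambda: True))
old_torch = SimpleNamespace()  # no `mps` attribute, like PyTorch <= 1.12

print(has_mps(new_torch), has_mps(old_torch))  # True False
```

In the real test, the same effect is achieved by combining the version flag with `torch.backends.mps.is_available()` in the `elif` condition, so older interpreters never evaluate the attribute access.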
@@ -43,6 +45,8 @@ def test_no_distrib(capsys):
    assert "ignite.distributed.utils INFO: backend: None" in out[-1]
    if torch.cuda.is_available():
        assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
+   elif torch.backends.mps.is_available():
Same here
-   def forward(self, x, bias=None):
-       if bias is None:
-           bias = 0.0
+   def forward(self, x):
Let's revert this change
@@ -69,8 +66,8 @@ def get_first_element(output):
    optimizer = SGD(model.parameters(), 0.1)

-   if trace:
-       example_inputs = (torch.randn(1), torch.randn(1)) if with_model_fn else torch.randn(1)
-       model = torch.jit.trace(model, example_inputs)
+   example_input = torch.randn(1)
Same here, let's revert this. You probably need to merge origin/master into your branch. Also, it is good practice to work on git branches rather than directly on master (pranavvp16:master).
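The workflow the reviewer suggests can be sketched as follows (branch name is illustrative; demonstrated in a throwaway repository so the commands are self-contained, assuming git is installed):

```shell
# In a real clone you would run:
#   git checkout -b fix-mps-tests        # work on a branch, not master
#   git fetch origin && git merge origin/master   # sync with upstream
# Demonstrated here in a temporary repository:
tmp=$(mktemp -d)
cd "$tmp"
git init -q
git config user.email "you@example.com"
git config user.name "you"
git commit -q --allow-empty -m "initial commit"
git checkout -q -b fix-mps-tests   # feature branch instead of master
git branch --show-current          # prints the active branch name
```

Working on a branch keeps your fork's master identical to upstream, which makes merging `origin/master` conflict-free.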
Draft PR for failing mps tests.