Skip to content

Commit cd08c87

Browse files
committed
Fixed tests in CI by skipping if MPS unavailable
1 parent 03103f1 commit cd08c87

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/unit/test_stft.py

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def test_inverse_output_shape(self):
150150
# Check if the output tensor has the expected shape
151151
self.assertEqual(output_tensor.shape, expected_shape)
152152

153+
@unittest.skipIf(not torch.backends.mps.is_available(), "MPS not available")
153154
def test_stft_with_mps_device(self):
154155
mps_device = torch.device("mps")
155156
self.stft.device = mps_device
@@ -158,6 +159,7 @@ def test_stft_with_mps_device(self):
158159
self.assertIsNotNone(stft_result)
159160
self.assertIsInstance(stft_result, torch.Tensor)
160161

162+
@unittest.skipIf(not torch.backends.mps.is_available(), "MPS not available")
161163
def test_inverse_with_mps_device(self):
162164
mps_device = torch.device("mps")
163165
self.stft.device = mps_device

0 commit comments

Comments
 (0)