diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d1b2535c4c..c949d76abe 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -523,7 +523,7 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) + self.runner(attrs).test_forward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) @@ -531,7 +531,7 @@ def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) + self.runner(attrs).test_backward(data_shape, dtype) class TestEncoderLayer(BaseTester):