Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit bb625d4

Browse files
author
Sara Adkins
committed
update tests
1 parent 0f1a839 commit bb625d4

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/sparseml/transformers/compression/test_fp8.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TestQuantizationMatches(unittest.TestCase):
4949
dataset = "ultrachat-200k"
5050
output = "tiny_llama_out"
5151
max_seq_length = 512
52 (removed)  weight_dtype = torch.bfloat16
52 (added)    weight_dtype = torch.float16
5353
num_eval = 64
5454

5555
@classmethod
@@ -127,7 +127,7 @@ def test_quantization_reload(self):
127127
n_scale, n_zp, n_weight = reloaded_weights[name]
128128
assert o_scale.dtype == n_scale.dtype == self.weight_dtype
129129
assert torch.equal(o_scale, n_scale)
130 (removed)  assert o_zp.dtype == n_zp.dtype == self.weight_dtype
130 (added)    assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
131131
assert torch.equal(o_zp, n_zp)
132132

133133
# we don't expect an exact match here because o_weight still has the
@@ -138,7 +138,7 @@ def test_quantization_reload(self):
138138
n_scale, n_zp = reloaded_inputs[name]
139139
assert o_scale.dtype == n_scale.dtype == self.weight_dtype
140140
assert torch.equal(o_scale, n_scale)
141 (removed)  assert o_zp.dtype == n_zp.dtype == self.weight_dtype
141 (added)    assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
142142
assert torch.equal(o_zp, n_zp)
143143

144144
def _get_dataloader(self, data_args, tokenizer):

0 commit comments

Comments (0)