Remove struct and leverage ml_dtypes in helper tests (onnx#6631)
Remove the use of the struct module and leverage ml_dtypes to clean up the
helper tests. This should unblock onnx#6621.
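
As a rough sketch of the pattern being swapped (illustrative only; the array
values and variable names below are not taken from the tests themselves), the
old approach decoded hand-computed bfloat16 bit patterns with struct, while
the new approach lets ml_dtypes perform the rounding and compares raw bit
patterns directly:

import struct

import ml_dtypes
import numpy as np

values = np.array([1.0, 2.0, 0.1, np.inf], dtype=np.float32)

# Old style: decode a hand-written hex bit pattern back to float32 with the
# struct module, one literal per expected value.
old_expected_first = struct.unpack("!f", bytes.fromhex("3F800000"))[0]  # 1.0

# New style: let ml_dtypes do the float32 -> bfloat16 rounding, then view the
# result as uint16 so the comparison is exact at the bit level.
expected_bits = values.astype(ml_dtypes.bfloat16).view(np.uint16)
print(old_expected_first, expected_bits)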

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: seungwoo-ji <seungwoo.ji@nuvilab.com>
justinchuby authored and seungwoo-ji-03 committed Feb 17, 2025
1 parent ffe12b7 commit a59a8ca
Showing 4 changed files with 73 additions and 131 deletions.
137 changes: 41 additions & 96 deletions onnx/test/helper_test.py
@@ -6,10 +6,10 @@
import itertools
import math
import random
import struct
import unittest
from typing import Any

import ml_dtypes
import numpy as np
import parameterized
import pytest
@@ -29,8 +29,6 @@
helper,
numpy_helper,
)
from onnx.reference.op_run import to_array_extended
from onnx.reference.ops.op_cast import Cast_19 as Cast


class TestHelperAttributeFunctions(unittest.TestCase):
@@ -451,34 +449,6 @@ def test_make_bfloat16_tensor(self) -> None:
],
dtype=np.float32,
)
np_results = np.array(
[
[
struct.unpack("!f", bytes.fromhex("3F800000"))[0], # 1.0
struct.unpack("!f", bytes.fromhex("40000000"))[0],
], # 2.0
[
struct.unpack("!f", bytes.fromhex("40400000"))[0], # 3.0
struct.unpack("!f", bytes.fromhex("40800000"))[0],
], # 4.0
[
struct.unpack("!f", bytes.fromhex("3DCC0000"))[
0
], # round-to-nearest-even rounds down (0x8000)
struct.unpack("!f", bytes.fromhex("3DCC0000"))[0],
], # round-to-nearest-even rounds up (0x8000)
[
struct.unpack("!f", bytes.fromhex("3DCC0000"))[
0
], # round-to-nearest-even rounds down (0x7fff)
struct.unpack("!f", bytes.fromhex("3DCD0000"))[0],
], # round-to-nearest-even rounds up (0xCCCD)
[
struct.unpack("!f", bytes.fromhex("7FC00000"))[0], # NaN
struct.unpack("!f", bytes.fromhex("7F800000"))[0],
], # inf
]
)

tensor = helper.make_tensor(
name="test",
@@ -488,17 +458,17 @@
)
self.assertEqual(tensor.name, "test")
np.testing.assert_equal(
Cast.eval(np_results, to=TensorProto.BFLOAT16), # type: ignore[arg-type]
numpy_helper.to_array(tensor),
numpy_helper.to_array(tensor).view(np.uint16),
np_array.astype(ml_dtypes.bfloat16).view(np.uint16),
)

def test_make_float8e4m3fn_tensor(self) -> None:
y = helper.make_tensor(
"zero_point", TensorProto.FLOAT8E4M3FN, [5], [0, 0.5, 1, 50000, 10.1]
)
ynp = numpy_helper.to_array(y)
expected = np.array([0, 0.5, 1, 448, 10], dtype=np.float32)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E4M3FN), ynp) # type: ignore[arg-type]
expected = np.array([0, 0.5, 1, 448, 10], dtype=ml_dtypes.float8_e4m3fn)
np.testing.assert_equal(ynp.view(np.uint8), expected.view(np.uint8))

def test_make_float8e4m3fnuz_tensor(self) -> None:
y = helper.make_tensor(
@@ -508,16 +478,16 @@ def test_make_float8e4m3fnuz_tensor(self) -> None:
[0, 0.5, 1, 50000, 10.1, -0.00001, 0.00001],
)
ynp = numpy_helper.to_array(y)
expected = np.array([0, 0.5, 1, 240, 10, 0, 0], dtype=np.float32)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E4M3FNUZ), ynp) # type: ignore[arg-type]
expected = np.array([0, 0.5, 1, 240, 10, 0, 0], dtype=ml_dtypes.float8_e4m3fnuz)
np.testing.assert_equal(ynp.view(np.uint8), expected.view(np.uint8))

def test_make_float8e5m2_tensor(self) -> None:
y = helper.make_tensor(
"zero_point", TensorProto.FLOAT8E5M2, [5], [0, 0.5, 1, 50000, 96]
)
ynp = numpy_helper.to_array(y)
expected = np.array([0, 0.5, 1, 49152, 96], dtype=np.float32)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E5M2), ynp) # type: ignore[arg-type]
expected = np.array([0, 0.5, 1, 49152, 96], dtype=ml_dtypes.float8_e5m2)
np.testing.assert_equal(ynp.view(np.uint8), expected.view(np.uint8))

def test_make_float8e5m2fnuz_tensor(self) -> None:
y = helper.make_tensor(
@@ -527,69 +497,31 @@ def test_make_float8e5m2fnuz_tensor(self) -> None:
[0, 0.5, 1, 50000, 96, -0.0000001, 0.0000001],
)
ynp = numpy_helper.to_array(y)
expected = np.array([0, 0.5, 1, 49152, 96, 0, 0], dtype=np.float32)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E5M2FNUZ), ynp) # type: ignore[arg-type]
expected = np.array(
[0, 0.5, 1, 49152, 96, 0, 0], dtype=ml_dtypes.float8_e5m2fnuz
)
np.testing.assert_equal(ynp.view(np.uint8), expected.view(np.uint8))

@unittest.skipIf(
version_utils.numpy_older_than("1.26.0"),
"The test requires numpy 1.26.0 or later",
)
def test_make_bfloat16_tensor_raw(self) -> None:
# numpy doesn't support bf16, so we have to compute the correct result manually
np_array = np.array(
array = np.array(
[
[1.0, 2.0],
[3.0, 4.0],
[0.099853515625, 0.099365234375],
[0.0998535081744, 0.1],
[np.nan, np.inf],
],
dtype=np.float32,
)
np_results = np.array(
[
[
struct.unpack("!f", bytes.fromhex("3F800000"))[0], # 1.0
struct.unpack("!f", bytes.fromhex("40000000"))[0],
], # 2.0
[
struct.unpack("!f", bytes.fromhex("40400000"))[0], # 3.0
struct.unpack("!f", bytes.fromhex("40800000"))[0],
], # 4.0
[
struct.unpack("!f", bytes.fromhex("3DCC0000"))[0], # truncated
struct.unpack("!f", bytes.fromhex("3DCB0000"))[0],
], # truncated
[
struct.unpack("!f", bytes.fromhex("3DCC0000"))[0], # truncated
struct.unpack("!f", bytes.fromhex("3DCC0000"))[0],
], # truncated
[
struct.unpack("!f", bytes.fromhex("7FC00000"))[0], # NaN
struct.unpack("!f", bytes.fromhex("7F800000"))[0],
], # inf
]
)

# write out 16-bit of fp32 to create bf16 using truncation, no rounding
def truncate(x):
return x >> 16

values_as_ints = np_array.astype(np.float32).view(np.uint32).flatten()
packed_values = truncate(values_as_ints).astype(np.uint16).tobytes()
dtype=ml_dtypes.bfloat16,
).view(np.uint16)

tensor = helper.make_tensor(
name="test",
data_type=TensorProto.BFLOAT16,
dims=np_array.shape,
vals=packed_values,
dims=array.shape,
vals=array.tobytes(),
raw=True,
)
self.assertEqual(tensor.name, "test")
np.testing.assert_allclose(
Cast.eval(np_results, to=TensorProto.BFLOAT16), # type: ignore[arg-type]
numpy_helper.to_array(tensor),
rtol=1e-4, # truncate is not nearest even rounding
)
np.testing.assert_allclose(numpy_helper.to_array(tensor).view(np.uint16), array)

def test_make_float8e4m3fn_tensor_raw(self) -> None:
expected = np.array([0, 0.5, 1, 448, 10], dtype=np.float32)
@@ -605,7 +537,9 @@ def test_make_float8e4m3fn_tensor_raw(self) -> None:
raw=True,
)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E4M3FN), ynp) # type: ignore[arg-type]
np.testing.assert_equal(
ynp.view(np.uint8), expected.astype(ml_dtypes.float8_e4m3fn).view(np.uint8)
)

def test_make_float8e4m3fnuz_tensor_raw(self) -> None:
expected = np.array([0, 0.5, 1, 240, 10], dtype=np.float32)
@@ -621,7 +555,10 @@ def test_make_float8e4m3fnuz_tensor_raw(self) -> None:
raw=True,
)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E4M3FNUZ), ynp) # type: ignore[arg-type]
np.testing.assert_equal(
ynp.view(np.uint8),
expected.astype(ml_dtypes.float8_e4m3fnuz).view(np.uint8),
)

def test_make_float8e5m2_tensor_raw(self) -> None:
expected = np.array([0, 0.5, 1, 49152, 10], dtype=np.float32)
@@ -637,7 +574,9 @@ def test_make_float8e5m2_tensor_raw(self) -> None:
raw=True,
)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E5M2), ynp) # type: ignore[arg-type]
np.testing.assert_equal(
ynp.view(np.uint8), expected.astype(ml_dtypes.float8_e5m2).view(np.uint8)
)

def test_make_float8e5m2fnuz_tensor_raw(self) -> None:
expected = np.array([0, 0.5, 1, 49152, 10], dtype=np.float32)
Expand All @@ -654,7 +593,10 @@ def test_make_float8e5m2fnuz_tensor_raw(self) -> None:
raw=True,
)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT8E5M2FNUZ), ynp) # type: ignore[arg-type]
np.testing.assert_equal(
ynp.view(np.uint8),
expected.astype(ml_dtypes.float8_e5m2fnuz).view(np.uint8),
)

@parameterized.parameterized.expand(
itertools.product(
@@ -682,7 +624,7 @@ def test_make_4bit_tensor(self, dtype, dims) -> None:
np.testing.assert_equal(actual_data_size, expected_data_size)

# Check the expected data values.
ynp = to_array_extended(y)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(data, ynp)

@parameterized.parameterized.expand(
@@ -736,6 +678,7 @@ def test_make_4bit_raw_tensor(self, dtype, dims) -> None:
def test_make_float4e2m1_raw_tensor(self) -> None:
data = np.array([0, 0.5, 1, 240, 10, -2], dtype=np.float32)
packed_data = helper.pack_float32_to_float4e2m1(data)
expected = data.astype(ml_dtypes.float4_e2m1fn).view(np.uint8)
y = helper.make_tensor(
"packed_fp4e2m1",
TensorProto.FLOAT4E2M1,
@@ -744,7 +687,7 @@ def test_make_float4e2m1_raw_tensor(self) -> None:
raw=True,
)
ynp = numpy_helper.to_array(y)
np.testing.assert_equal(Cast.eval(data, to=TensorProto.FLOAT4E2M1), ynp) # type: ignore[arg-type]
np.testing.assert_equal(ynp, expected)

def test_make_float4e2m1_tensor(self) -> None:
y = helper.make_tensor(
@@ -754,8 +697,10 @@ def test_make_float4e2m1_tensor(self) -> None:
[0, 0.5, 1, 50000, -0.6, -100, -5],
)
ynp = numpy_helper.to_array(y)
expected = np.array([0, 0.5, 1, 6, -0.5, -6, -4], dtype=np.float32)
np.testing.assert_equal(Cast.eval(expected, to=TensorProto.FLOAT4E2M1), ynp) # type: ignore[arg-type]
expected = np.array(
[0, 0.5, 1, 6, -0.5, -6, -4], dtype=ml_dtypes.float4_e2m1fn
).view(np.uint8)
np.testing.assert_equal(ynp, expected)

def test_make_sparse_tensor(self) -> None:
values = [1.1, 2.2, 3.3, 4.4, 5.5]
65 changes: 30 additions & 35 deletions onnx/test/numpy_helper_test.py
@@ -6,6 +6,7 @@
import unittest
from typing import Any

import ml_dtypes
import numpy as np
import parameterized

@@ -15,17 +16,7 @@


def bfloat16_to_float32(ival: int) -> Any:
if ival == 0x7FC0:
return np.float32(np.nan)

expo = ival >> 7
prec = ival - (expo << 7)
sign = expo & 256
powe = expo & 255
fval = float(prec * 2 ** (-7) + 1) * 2.0 ** (powe - 127)
if sign:
fval = -fval
return np.float32(fval)
return np.int16(ival).view(ml_dtypes.bfloat16).astype(np.float32)


def float8e4m3_to_float32(ival: int) -> Any:
@@ -625,30 +616,34 @@ def _to_array_from_array(self, value: int, check_dtype: bool = True):
self.assertEqual(tp.SerializeToString(), again.SerializeToString())
self.assertEqual(tp.data_type, helper.np_dtype_to_tensor_dtype(back.dtype))

@parameterized.parameterized.expand([(att,) for att in dir(onnx.TensorProto)])
def test_to_array_from_array(self, att):
if att in {
"INT4",
"UINT4",
"STRING",
"UNDEFINED",
"DEFAULT",
"NAME_FIELD_NUMBER",
"FLOAT4E2M1",
}:
return
if att[0] < "A" or att[0] > "Z":
return
value = getattr(onnx.TensorProto, att)
if not isinstance(value, int):
return

self._to_array_from_array(value)

def test_to_array_from_array_subtype(self):
self._to_array_from_array(onnx.TensorProto.INT4)
self._to_array_from_array(onnx.TensorProto.UINT4)
self._to_array_from_array(onnx.TensorProto.FLOAT4E2M1)
@parameterized.parameterized.expand(
[
("FLOAT", onnx.TensorProto.FLOAT),
("UINT8", onnx.TensorProto.UINT8),
("INT8", onnx.TensorProto.INT8),
("UINT16", onnx.TensorProto.UINT16),
("INT16", onnx.TensorProto.INT16),
("INT32", onnx.TensorProto.INT32),
("INT64", onnx.TensorProto.INT64),
("BOOL", onnx.TensorProto.BOOL),
("FLOAT16", onnx.TensorProto.FLOAT16),
("DOUBLE", onnx.TensorProto.DOUBLE),
("UINT32", onnx.TensorProto.UINT32),
("UINT64", onnx.TensorProto.UINT64),
("COMPLEX64", onnx.TensorProto.COMPLEX64),
("COMPLEX128", onnx.TensorProto.COMPLEX128),
("BFLOAT16", onnx.TensorProto.BFLOAT16),
("FLOAT8E4M3FN", onnx.TensorProto.FLOAT8E4M3FN),
("FLOAT8E4M3FNUZ", onnx.TensorProto.FLOAT8E4M3FNUZ),
("FLOAT8E5M2", onnx.TensorProto.FLOAT8E5M2),
("FLOAT8E5M2FNUZ", onnx.TensorProto.FLOAT8E5M2FNUZ),
("UINT4", onnx.TensorProto.UINT4),
("INT4", onnx.TensorProto.INT4),
("FLOAT4E2M1", onnx.TensorProto.FLOAT4E2M1),
]
)
def test_to_array_from_array(self, _: str, data_type: onnx.TensorProto.DataType):
self._to_array_from_array(data_type)

def test_to_array_from_array_string(self):
self._to_array_from_array(onnx.TensorProto.STRING, False)
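
For context, a small sanity check (not part of the commit; the variable names
and sample bit patterns are made up for illustration) showing that the
ml_dtypes-based bfloat16_to_float32 rewrite above matches the manual bit
manipulation it replaces:

import ml_dtypes
import numpy as np

# A few bfloat16 bit patterns: 1.0, 2.0, and ~0.0996.
bits = np.array([0x3F80, 0x4000, 0x3DCC], dtype=np.uint16)

# Manual decode: shift the 16 bits into the top half of a float32 word and
# reinterpret the result.
manual = (bits.astype(np.uint32) << 16).view(np.float32)

# ml_dtypes decode: reinterpret the 16 bits as bfloat16, then widen to float32.
via_ml_dtypes = bits.view(ml_dtypes.bfloat16).astype(np.float32)

np.testing.assert_array_equal(manual, via_ml_dtypes)
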
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -10,6 +10,7 @@ pytest-xdist
setuptools
twine
wheel
ml_dtypes
# Dependencies for linting. Versions match those in setup.py.
lintrunner>=0.10.7
lintrunner-adapters>=0.12.3
1 change: 1 addition & 0 deletions requirements-release.txt
@@ -1,5 +1,6 @@
build
ipython
ml_dtypes
nbval
numpy==1.24.3; python_version<"3.9"
numpy==2.0.0; python_version >= "3.9" and python_version <= "3.12"
