Skip to content

Commit

Permalink
Adding dtype specific tolerances (ivy-llc#21490)
Browse files Browse the repository at this point in the history
  • Loading branch information
mobley-trent authored Aug 23, 2023
1 parent 9878ab8 commit 8d1bd2a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
26 changes: 22 additions & 4 deletions ivy_tests/test_ivy/helpers/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def value_test(
ret_np_from_gt_flat,
rtol=None,
atol=1e-6,
specific_tolerance_dict=None,
backend: str,
ground_truth_backend="TensorFlow",
):
Expand All @@ -108,6 +109,8 @@ def value_test(
Relative Tolerance Value.
atol
Absolute Tolerance Value.
specific_tolerance_dict
(Optional) Dictionary of specific rtol and atol values according to the dtype.
ground_truth_backend
Ground Truth Backend Framework.
Expand All @@ -117,9 +120,9 @@ def value_test(
"""
assert_same_type_and_shape([ret_np_flat, ret_np_from_gt_flat])

if type(ret_np_flat) != list:
if type(ret_np_flat) != list: # noqa: E721
ret_np_flat = [ret_np_flat]
if type(ret_np_from_gt_flat) != list:
if type(ret_np_from_gt_flat) != list: # noqa: E721
ret_np_from_gt_flat = [ret_np_from_gt_flat]
assert len(
ret_np_flat
Expand All @@ -135,9 +138,23 @@ def value_test(
)
)
# value tests, iterating through each array in the flattened returns
if not rtol:
if specific_tolerance_dict is not None:
for ret_np, ret_np_from_gt in zip(ret_np_flat, ret_np_from_gt_flat):
dtype = str(ret_np_from_gt.dtype)
if specific_tolerance_dict.get(dtype) is not None:
rtol = specific_tolerance_dict.get(dtype)
else:
rtol = TOLERANCE_DICT.get(dtype, 1e-03) if rtol is None else rtol
assert_all_close(
ret_np,
ret_np_from_gt,
backend=backend,
rtol=rtol,
atol=atol,
ground_truth_backend=ground_truth_backend,
)
elif rtol is not None:
for ret_np, ret_np_from_gt in zip(ret_np_flat, ret_np_from_gt_flat):
rtol = TOLERANCE_DICT.get(str(ret_np_from_gt.dtype), 1e-03)
assert_all_close(
ret_np,
ret_np_from_gt,
Expand All @@ -148,6 +165,7 @@ def value_test(
)
else:
for ret_np, ret_np_from_gt in zip(ret_np_flat, ret_np_from_gt_flat):
rtol = TOLERANCE_DICT.get(str(ret_np_from_gt.dtype), 1e-03)
assert_all_close(
ret_np,
ret_np_from_gt,
Expand Down
21 changes: 21 additions & 0 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_function(
fn_name: str,
rtol_: float = None,
atol_: float = 1e-06,
tolerance_dict: dict = None,
test_values: bool = True,
xs_grad_idxs=None,
ret_grad_idxs=None,
Expand Down Expand Up @@ -126,6 +127,8 @@ def test_function(
relative tolerance value.
atol_
absolute tolerance value.
tolerance_dict
(Optional) dictionary of tolerance values for each dtype.
test_values
if True, test for the correctness of the resulting values.
xs_grad_idxs
Expand Down Expand Up @@ -462,6 +465,7 @@ def test_function(
test_flags=test_flags,
rtol_=rtol_,
atol_=atol_,
tolerance_dict=tolerance_dict,
xs_grad_idxs=xs_grad_idxs,
ret_grad_idxs=ret_grad_idxs,
ground_truth_backend=test_flags.ground_truth_backend,
Expand Down Expand Up @@ -497,6 +501,7 @@ def test_function(
ret_np_from_gt_flat=ret_np_from_gt_flat,
rtol=rtol_,
atol=atol_,
specific_tolerance_dict=tolerance_dict,
backend=backend_to_test,
ground_truth_backend=test_flags.ground_truth_backend,
)
Expand All @@ -513,6 +518,7 @@ def test_frontend_function(
gt_fn_tree: str = None,
rtol: float = None,
atol: float = 1e-06,
tolerance_dict: dict = None,
test_values: bool = True,
**all_as_kwargs_np,
):
Expand All @@ -537,6 +543,8 @@ def test_frontend_function(
relative tolerance value.
atol
absolute tolerance value.
tolerance_dict
dictionary of tolerance values for specific dtypes.
test_values
if True, test for the correctness of the resulting values.
all_as_kwargs_np
Expand Down Expand Up @@ -884,6 +892,7 @@ def arrays_to_numpy(x):
ret_np_from_gt_flat=frontend_ret_np_flat,
rtol=rtol,
atol=atol,
specific_tolerance_dict=tolerance_dict,
backend=backend_to_test,
ground_truth_backend=frontend,
)
Expand All @@ -903,6 +912,7 @@ def gradient_test(
test_compile: bool = False,
rtol_: float = None,
atol_: float = 1e-06,
tolerance_dict: dict = None,
xs_grad_idxs=None,
ret_grad_idxs=None,
backend_to_test: str,
Expand Down Expand Up @@ -1007,6 +1017,7 @@ def _gt_grad_fn(all_args):
ret_np_from_gt_flat=grads_np_from_gt_flat,
rtol=rtol_,
atol=atol_,
specific_tolerance_dict=tolerance_dict,
backend=backend_to_test,
ground_truth_backend=ground_truth_backend,
)
Expand All @@ -1026,6 +1037,7 @@ def test_method(
method_with_v: bool = False,
rtol_: float = None,
atol_: float = 1e-06,
tolerance_dict: dict = None,
test_values: Union[bool, str] = True,
test_gradients: bool = False,
xs_grad_idxs=None,
Expand Down Expand Up @@ -1085,6 +1097,8 @@ def test_method(
relative tolerance value.
atol_
absolute tolerance value.
tolerance_dict
dictionary of tolerance values for specific dtypes.
test_values
can be a bool or a string to indicate whether correctness of values should be
tested. If the value is `with_v`, shapes are tested but not values.
Expand Down Expand Up @@ -1368,6 +1382,7 @@ def test_method(
test_compile=test_compile,
rtol_=rtol_,
atol_=atol_,
tolerance_dict=tolerance_dict,
xs_grad_idxs=xs_grad_idxs,
ret_grad_idxs=ret_grad_idxs,
backend_to_test=backend_to_test,
Expand All @@ -1389,6 +1404,7 @@ def test_method(
test_compile=test_compile,
rtol_=rtol_,
atol_=atol_,
tolerance_dict=tolerance_dict,
xs_grad_idxs=xs_grad_idxs,
ret_grad_idxs=ret_grad_idxs,
backend_to_test=backend_to_test,
Expand Down Expand Up @@ -1426,6 +1442,7 @@ def test_method(
ret_np_from_gt_flat=ret_np_from_gt_flat,
rtol=rtol_,
atol=atol_,
specific_tolerance_dict=tolerance_dict,
)


Expand All @@ -1443,6 +1460,7 @@ def test_frontend_method(
on_device,
rtol_: float = None,
atol_: float = 1e-06,
tolerance_dict: dict = None,
test_values: Union[bool, str] = True,
):
"""
Expand Down Expand Up @@ -1474,6 +1492,8 @@ def test_frontend_method(
relative tolerance value.
atol_
absolute tolerance value.
tolerance_dict
dictionary of tolerance values for specific dtypes.
test_values
can be a bool or a string to indicate whether correctness of values should be
tested. If the value is `with_v`, shapes are tested but not values.
Expand Down Expand Up @@ -1736,6 +1756,7 @@ def test_frontend_method(
ret_np_from_gt_flat=frontend_ret_np_flat,
rtol=rtol_,
atol=atol_,
specific_tolerance_dict=tolerance_dict,
backend=backend_to_test,
ground_truth_backend=frontend,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_mean(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_dev
on_device=on_device,
rtol_=1e-1,
atol_=1e-1,
tolerance_dict={"bfloat16": 1e-1},
x=x[0],
axis=axis,
keepdims=keep_dims,
Expand Down

0 comments on commit 8d1bd2a

Please sign in to comment.