Skip to content

Commit

Permalink
Refine regions_with_inaccuracies to account for ARM numerics differences
Browse files Browse the repository at this point in the history
cc @pearu

PiperOrigin-RevId: 612419520
  • Loading branch information
apaszke authored and jax authors committed Mar 4, 2024
1 parent 7514d5c commit 57a8ef7
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import itertools
import math
import operator
import platform
import types
import unittest
from unittest import SkipTest
Expand Down Expand Up @@ -3365,22 +3366,28 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
def testOnComplexPlane(self, name, dtype, kind):
all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
is_cpu = jtu.test_device_matches(["cpu"])
is_arm_cpu = platform.machine().startswith('aarch')
is_cuda = jtu.test_device_matches(["cuda"])

# TODO(pearu): eliminate all items in the following lists:
# TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
regions_with_inaccuracies = dict(
absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
exp = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'],
exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf']
+ (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
log = ['q1', 'q2', 'q3', 'q4'],
log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
sinh = ['pos', 'neg', 'ninf', 'pinf'],
cosh = ['pos', 'neg', 'ninf', 'pinf'],
sinh = (['pos', 'neg', 'ninf', 'pinf']
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
cosh = (['pos', 'neg', 'ninf', 'pinf']
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
square = ((['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
+ ['ninf', 'pinf']
square = (['pinf']
+ (['ninfj', 'pinfj'] if is_arm_cpu else [])
+ (['ninf'] if not is_arm_cpu else [])
+ (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
sinc = ['q1', 'q2', 'q3', 'q4'],
sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
Expand All @@ -3390,6 +3397,9 @@ def testOnComplexPlane(self, name, dtype, kind):
arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [],
)

jnp_op = getattr(jnp, name)
Expand Down

0 comments on commit 57a8ef7

Please sign in to comment.