Skip to content

Commit

Permalink
Merge pull request #22199 from gnecula:export_test_skip
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648705137
  • Loading branch information
jax authors committed Jul 2, 2024
2 parents 92ebb53 + cfa3c91 commit 2e0c100
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@

import numpy as np

# ruff: noqa: F401
try:
import flatbuffers
CAN_SERIALIZE = True
except (ModuleNotFoundError, ImportError):
CAN_SERIALIZE = False

config.parse_flags_with_absl()

_exit_stack = contextlib.ExitStack()
Expand Down Expand Up @@ -139,12 +146,15 @@ def _testing_multi_platform_fun_expected(x,


def get_exported(fun: Callable, vjp_order=0,
**export_kwargs):
**export_kwargs) -> Callable[[...], export.Exported]:
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = exp.serialize(vjp_order=vjp_order)
return export.deserialize(serialized)
if CAN_SERIALIZE:
serialized = exp.serialize(vjp_order=vjp_order)
return export.deserialize(serialized)
else:
return exp
return serde_exported


Expand Down Expand Up @@ -234,6 +244,8 @@ def test_export_error_no_jit(self):
@jtu.ignore_warning(category=DeprecationWarning,
message="The jax.experimental.export module is deprecated")
def test_export_experimental_back_compat(self):
if not CAN_SERIALIZE:
self.skipTest("serialization disabled")
from jax.experimental import export
# Can export a lambda, without jit
exp = export.export(lambda x: jnp.sin(x))(.1)
Expand Down Expand Up @@ -1328,8 +1340,9 @@ def f(x):
exp = export.export(pjit.pjit(f, in_shardings=shardings))(input)
exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)

_ = exp.serialize(vjp_order=1)
_ = exp_rev.serialize(vjp_order=1)
if CAN_SERIALIZE:
_ = exp.serialize(vjp_order=1)
_ = exp_rev.serialize(vjp_order=1)

g = jax.grad(exp_rev.call)(input_rev)
g_rev = jax.grad(exp.call)(input)
Expand Down Expand Up @@ -1725,6 +1738,9 @@ def f_jax(x):
)
])
def test_ordered_effects_error(self, *, name: str, expect_error: str):
if not CAN_SERIALIZE:
# These errors arise during serialization
self.skipTest("serialization is disabled")
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x):
return 10. + _testing_multi_platform_func(
Expand Down

0 comments on commit 2e0c100

Please sign in to comment.