Skip to content

Commit

Permalink
Rolling back #18980, because it is not backwards compatible and break…
Browse files Browse the repository at this point in the history
…s existing users.

Reverts 91faddd

PiperOrigin-RevId: 591200403
  • Loading branch information
superbobry authored and jax authors committed Dec 15, 2023
1 parent a0458e6 commit 4153112
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 113 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.24

* Changes

* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
Expand All @@ -18,9 +17,6 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`.

## jaxlib 0.4.24

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def func(...): ...

import jax
from jax import tree_util
from jax.experimental import export
from jax.experimental.export import export

from jax.experimental import pjit

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ load("@rules_python//python:defs.bzl", "py_library")

licenses(["notice"])

# Please add new users to :australis_users.
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
Expand All @@ -30,8 +31,7 @@ package(
py_library(
name = "export",
srcs = [
"__init__.py",
"_export.py",
"export.py",
"serialization.py",
"serialization_generated.py",
"shape_poly.py",
Expand Down
18 changes: 0 additions & 18 deletions jax/experimental/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from jax.experimental.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
export,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform,

symbolic_shape,
args_specs,
)
from jax.experimental.export.serialization import (
serialize,
deserialize,
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import itertools
import re
from typing import Any, Callable, Optional, TypeVar, Union
import warnings

from absl import logging
import numpy as np
Expand Down Expand Up @@ -56,20 +55,6 @@

DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]

# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]

# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9

_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9


class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
Expand Down Expand Up @@ -130,6 +115,19 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash(self._impl)

# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9

_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9

# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]

# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]

@dataclasses.dataclass(frozen=True)
class Exported:
Expand Down Expand Up @@ -1052,7 +1050,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:

### Calling the exported function

def call(exported: Exported) -> Callable[..., jax.Array]:
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
raise ValueError(
"The exported argument must be an export.Exported. "
Expand Down Expand Up @@ -1098,7 +1096,6 @@ def f_imported(*args, **kwargs):
return exported.out_tree.unflatten(res_flat)
return f_imported

call_exported = call

# A JAX primitive for invoking a serialized JAX function.
call_exported_p = core.Primitive("call_exported")
Expand Down Expand Up @@ -1299,30 +1296,3 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x
return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto())

# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return f(*args, **kwargs)
return wrapped_f

export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version
8 changes: 3 additions & 5 deletions jax/experimental/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
from jax._src import effects
from jax._src import tree_util
from jax._src.lib import xla_client
from jax.experimental.export import export
from jax.experimental.export import serialization_generated as ser_flatbuf
from jax.experimental.export import _export
from jax.experimental import export

import numpy as np

T = TypeVar("T")
Expand Down Expand Up @@ -355,7 +353,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue:


def _serialize_sharding(
builder: flatbuffers.Builder, s: _export.Sharding
builder: flatbuffers.Builder, s: export.Sharding
) -> int:
proto = None
if s is None:
Expand All @@ -372,7 +370,7 @@ def _serialize_sharding(
return ser_flatbuf.ShardingEnd(builder)


def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding:
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
kind = s.Kind()
if kind == ser_flatbuf.ShardingKind.unspecified:
return None
Expand Down
35 changes: 17 additions & 18 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
from jax import sharding
from jax.experimental import maps
from jax.experimental.export import shape_poly
from jax.experimental.export import _export
from jax.experimental import export
from jax.experimental.export import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla

Expand Down Expand Up @@ -515,14 +514,14 @@ def run_fun_tf(self,

def get_vjp_fun(self) -> tuple[Callable,
Sequence[core.AbstractValue]]:
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings,
out_avals=self.exported.out_avals,
out_shardings=self.exported.out_shardings,
nr_devices=self.exported.nr_devices,
apply_jit=True)
return export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings,
out_avals=self.exported.out_avals,
out_shardings=self.exported.out_shardings,
nr_devices=self.exported.nr_devices,
apply_jit=True)

class GraphSerializationImpl(SerializationImpl):
def __init__(self, fun_jax, *,
Expand Down Expand Up @@ -587,14 +586,14 @@ def get_vjp_fun(self) -> tuple[Callable,
# We reuse the code for native serialization to get the VJP functions,
# except we use unspecified shardings, and we do not apply a jit on the
# VJP. This matches the older behavior of jax2tf for graph serialization.
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree,
in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat),
out_avals=self.outs_avals,
out_shardings=(None,) * len(self.outs_avals),
nr_devices=1, # Does not matter for unspecified shardings
apply_jit=False)
return export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree,
in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat),
out_avals=self.outs_avals,
out_shardings=(None,) * len(self.outs_avals),
nr_devices=1, # Does not matter for unspecified shardings
apply_jit=False)


def dtype_of_val(val: TfVal) -> DType:
Expand Down
5 changes: 2 additions & 3 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@

import jax
from jax import lax
from jax.experimental import export
from jax.experimental.export import _export
from jax.experimental.export import export
from jax._src.internal_test_util import export_back_compat_test_util as bctu

from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
Expand Down Expand Up @@ -98,7 +97,7 @@ def test_detect_different_custom_calls(self):

def test_custom_call_coverage(self):
"""Tests that the back compat tests cover all the targets declared stable."""
targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf
from jax.experimental import export
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np

Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf
from jax.experimental import export
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from jax import tree_util

from jax.experimental import jax2tf
from jax.experimental import export
from jax.experimental.export import export
from jax._src import config
from jax._src import xla_bridge
import numpy as np
Expand Down
Loading

0 comments on commit 4153112

Please sign in to comment.