Skip to content

Commit

Permalink
Merge pull request #22547 from sameer-dudeja:dev-fix-broken-link
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654724736
  • Loading branch information
jax authors committed Jul 22, 2024
2 parents 433f66a + 993a1e7 commit 8ec0cc2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
HloSharding = xla_client.HloSharding

# The minimum and maximum supported calling convention version.
# See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#calling-conventions-versions
# See https://jax.readthedocs.io/en/latest/export/export.html#export-calling-convention-version
minimum_supported_calling_convention_version = 9
maximum_supported_calling_convention_version = 9

Expand Down Expand Up @@ -166,14 +166,14 @@ class Exported:
add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'.
See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export.
ordered_effects: the ordered effects present in the serialized module.
This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention
for the calling convention in presence of ordered effects.
unordered_effects: the unordered effects present in the serialized module.
This is present from serialization version 9.
mlir_module_serialized: the serialized lowered VHLO module.
calling_convention_version: a version number for the calling
convention of the exported module.
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions.
See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used.
Expand All @@ -192,7 +192,7 @@ class Exported:
for each primal output. It returns a tuple with the cotangents
corresponding to the flattened primal inputs.
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention).
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention).
"""
fun_name: str
in_tree: tree_util.PyTreeDef
Expand Down Expand Up @@ -399,7 +399,7 @@ def export_back_compat(
Note: this function exists only for internal usage by jax2tf and for
backwards compatibility with jax.experimental.export. Use
`jax.export` instead.
See https://jax.readthedocs.io/en/latest/export.html
See https://jax.readthedocs.io/en/latest/export/export.html
Args:
fun_jax: the function to lower and serialize.
Expand All @@ -409,7 +409,7 @@ def export_back_compat(
the lowered code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained
at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention.
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
Expand Down Expand Up @@ -484,7 +484,7 @@ def export(
the exported code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained at
https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention.
lowering_platforms: DEPRECATED, use `platforms`.
disabled_checks: the safety checks to disable. See documentation for
of `jax.export.DisabledSafetyCheck`.
Expand Down Expand Up @@ -709,7 +709,7 @@ def _wrap_main_func(
) -> ir.Module:
"""Wraps the lowered module with a new "main" handling dimension arguments.
See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
See calling convention documentation https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention.
Args:
module: the HLO module as obtained from lowering.
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class LoweringParameters:
# Signals that we are lowering for exporting.

for_export: bool = False
# See usage in https://jax.readthedocs.io/en/latest/export.html#ensuring-forward-and-backward-compatibility
# See usage in https://jax.readthedocs.io/en/latest/export/export.html#ensuring-forward-and-backward-compatibility
# We have this here to ensure it is reflected in the cache keys
export_ignore_forward_compatibility: bool = False

Expand Down

0 comments on commit 8ec0cc2

Please sign in to comment.