From 993a1e74ba06811e80c964deecbe7622de990340 Mon Sep 17 00:00:00 2001 From: Sameer Dudeja Date: Sun, 21 Jul 2024 11:37:01 +0530 Subject: [PATCH] Fix broken export links --- jax/_src/export/_export.py | 16 ++++++++-------- jax/_src/interpreters/mlir.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 0af4b06aeed8..5c83d40089df 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -64,7 +64,7 @@ HloSharding = xla_client.HloSharding # The minimum and maximum supported calling convention version. -# See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#calling-conventions-versions +# See https://jax.readthedocs.io/en/latest/export/export.html#export-calling-convention-version minimum_supported_calling_convention_version = 9 maximum_supported_calling_convention_version = 9 @@ -166,14 +166,14 @@ class Exported: add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention + This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. calling_convention_version: a version number for the calling convention of the exported module. - See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions. + See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. @@ -192,7 +192,7 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention). + See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention). """ fun_name: str in_tree: tree_util.PyTreeDef @@ -399,7 +399,7 @@ def export_back_compat( Note: this function exists only for internal usage by jax2tf and for backwards compatibility with jax.experimental.export. Use `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export.html + See https://jax.readthedocs.io/en/latest/export/export.html Args: fun_jax: the function to lower and serialize. @@ -409,7 +409,7 @@ def export_back_compat( the lowered code takes an argument specifying the platform. If None, then use the default JAX backend. The calling convention for multiple platforms is explained - at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. disabled_checks: the safety checks to disable. See docstring of `DisabledSafetyCheck`. @@ -484,7 +484,7 @@ def export( the exported code takes an argument specifying the platform. If None, then use the default JAX backend. The calling convention for multiple platforms is explained at - https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. lowering_platforms: DEPRECATED, use `platforms`. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -709,7 +709,7 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + See calling convention documentation https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. Args: module: the HLO module as obtained from lowering. diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f09848d96997..c247fdd5b49f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -567,7 +567,7 @@ class LoweringParameters: # Signals that we are lowering for exporting. for_export: bool = False - # See usage in https://jax.readthedocs.io/en/latest/export.html#ensuring-forward-and-backward-compatibility + # See usage in https://jax.readthedocs.io/en/latest/export/export.html#ensuring-forward-and-backward-compatibility # We have this here to ensure it is reflected in the cache keys export_ignore_forward_compatibility: bool = False