diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 2e1739d55909..54c68811b732 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -637,11 +637,11 @@ def _share_fdo_profiles( backend: xc.Client, global_client: lib._jax.DistributedRuntimeClient, min_process_id -) -> bytes | None: +) -> bytes: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value fdo_profile = compile_options.executable_build_options.fdo_profile - if fdo_profile is None or len(fdo_profile) == 0: + if len(fdo_profile) == 0: return fdo_profile compile_options.executable_build_options.fdo_profile = b"" diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index a62466049147..f5324f5434c5 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -25,6 +25,7 @@ from jax._src import traceback_util from jax._src.lib import pytree +from jax._src.lib import version as jaxlib_version from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 @@ -123,7 +124,9 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef: See Also: - :func:`jax.tree_util.treedef_children` """ - return pytree.tuple(default_registry, list(treedefs)) + if jaxlib_version < (0, 8, 0): + return pytree.tuple(default_registry, list(treedefs)) # type: ignore + return pytree.treedef_tuple(default_registry, list(treedefs)) @export diff --git a/jaxlib/BUILD b/jaxlib/BUILD index fd3ac392bfa6..0dd9b323d6fa 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -323,8 +323,17 @@ nanobind_pywrap_extension( nanobind_pywrap_extension( name = "_jax", srcs = ["jax.cc"], + additional_stubgen_deps = [ + "//third_party/py/numpy", + "//jaxlib/mlir:ir", + ], + enable_stub_generation = True, pytype_deps = py_deps(["numpy"]), pytype_srcs = glob(["_jax/*.pyi"]), + stub_replacement_patterns = { + "jax.jaxlib._jax.Array$": "Array: Any", + "jax.jaxlib._jax.ArrayImpl$": "ArrayImpl: Any", + }, visibility = jax_visibility("jaxlib/_jax"), deps = [ ":config", diff 
--git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 1681ab62a7bd..b624291a4015 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -1,133 +1,167 @@ -# Copyright 2021 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -from __future__ import annotations - -import builtins -from collections.abc import Callable, Iterator, Mapping, Sequence, Set +from collections.abc import Callable, Iterator, Mapping, Sequence import enum import inspect +import traceback import types -from typing import Any, ClassVar, TypeVar, overload +from typing import Annotated, Any, TypeAlias, overload + +import numpy +from numpy.typing import NDArray +import typing_extensions + +from . import ( + config as config, + ffi as ffi, + guard_lib as guard_lib, + hlo_sharding_util as hlo_sharding_util, + ifrt_programs as ifrt_programs, + jax_jit as jax_jit, + mlir as mlir, + pmap_lib as pmap_lib, + pytree as pytree, +) +from .pmap_lib import PmapFunction as PmapFunction +from .pytree import (PyTreeDef as PyTreeDef, PyTreeRegistry as _PyTreeRegistry) -import numpy as np +class JaxRuntimeError(RuntimeError): + """Runtime errors thrown by the JAX runtime. -from . import config as config -from . import ffi as ffi -from . import guard_lib as guard_lib -from . import ifrt_programs as ifrt_programs -from . 
import jax_jit as jax_jit -from . import mlir as mlir -from . import pmap_lib as pmap_lib -from . import profiler as profiler -from . import pytree as pytree -from . import transfer_guard_lib as transfer_guard_lib + While the JAX runtime may raise other exceptions as well, most exceptions + thrown by the runtime are instances of this class. + """ -custom_call_targets = Any -hlo_sharding_util = Any +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID = 0 -_LiteralSlice = Any -_Status = Any -_Dtype = Any + PRED = 1 -ifrt_version_number: int + S4 = 21 -_T = TypeVar("_T") + S8 = 2 -class JaxRuntimeError(RuntimeError): - pass + S16 = 3 -class PrimitiveType(enum.IntEnum): - PRIMITIVE_TYPE_INVALID = ... - PRED = ... - S2 = ... - S4 = ... - S8 = ... - S16 = ... - S32 = ... - S64 = ... - U2 = ... - U4 = ... - U8 = ... - U16 = ... - U32 = ... - U64 = ... - F4E2M1FN = ... - F8E3M4 = ... - F8E4M3 = ... - F8E4M3FN = ... - F8E4M3B11FNUZ = ... - F8E4M3FNUZ = ... - F8E5M2 = ... - F8E5M2FNUZ = ... - F8E8M0FNU = ... - BF16 = ... - F16 = ... - F32 = ... - F64 = ... - C64 = ... - C128 = ... - TUPLE = ... - OPAQUE_TYPE = ... - TOKEN = ... - -# === BEGIN xla_compiler.cc + S32 = 4 -class ArrayCopySemantics(enum.IntEnum): - ALWAYS_COPY = ... - REUSE_INPUT = ... - DONATE_INPUT = ... + S64 = 5 + + U4 = 22 + + U8 = 6 + + U16 = 7 + + U32 = 8 + + U64 = 9 + + F16 = 10 + + F4E2M1FN = 32 + + F8E3M4 = 29 + + F8E4M3 = 28 + + F8E4M3FN = 20 + + F8E4M3B11FNUZ = 23 + + F8E4M3FNUZ = 25 + + F8E5M2 = 19 + + F8E5M2FNUZ = 24 + + F8E8M0FNU = 33 + + BF16 = 16 + + F32 = 11 + + F64 = 12 + + C64 = 15 + + C128 = 18 + + TUPLE = 13 + + OPAQUE_TYPE = 14 + + TOKEN = 17 class Layout: @overload - def __init__(self, minor_to_major: tuple[int, ...]): ... + def __init__(self, arg: Sequence[int], /) -> None: ... @overload def __init__( - self, - minor_to_major: tuple[int, ...], - tiling: tuple[tuple[int, ...], ...], - element_size_in_bits: int, - ): ... 
+ self, arg0: Sequence[int], arg1: Sequence[tuple[int, ...]], arg2: int, / + ) -> None: ... def minor_to_major(self) -> tuple[int, ...]: ... - def tiling(self) -> Sequence[tuple[int, ...]]: ... def element_size_in_bits(self) -> int: ... - def to_string(self) -> str: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... + def tiling(self) -> list[tuple[int, ...]]: ... + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... + def __str__(self) -> str: ... def __hash__(self) -> int: ... + def to_string(self) -> str: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... class Shape: - def __init__(self, s: str): ... + def __init__(self, arg: str, /) -> None: ... + @staticmethod + def tuple_shape(arg: Sequence[Shape], /) -> Shape: + """Constructs a tuple shape.""" + + @overload @staticmethod - def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... + def array_shape( + type: PrimitiveType, + dims: Sequence[int], + layout: Sequence[int] | None = None, + dynamic_dimensions: Sequence[bool] | None = None, + ) -> Shape: + """Constructs an array shape.""" + + @overload @staticmethod def array_shape( - type: np.dtype | PrimitiveType, - dims_seq: Any = ..., - layout_seq: Any = ..., - dynamic_dimensions: list[bool] | None = ..., + type: numpy.dtype, + dims: Sequence[int], + layout: Sequence[int] | None = None, + dynamic_dimensions: Sequence[bool] | None = None, ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... + @overload @staticmethod - def scalar_shape(type: np.dtype | PrimitiveType) -> Shape: ... + def scalar_shape(type: PrimitiveType) -> Shape: + """Constructs a scalar shape.""" + + @overload + @staticmethod + def scalar_shape(type: numpy.dtype) -> Shape: ... def dimensions(self) -> tuple[int, ...]: ... def layout(self) -> Layout: ... def xla_element_type(self) -> PrimitiveType: ... - def element_type(self) -> np.dtype: ... 
- def numpy_dtype(self) -> np.dtype: ... + def element_type(self) -> numpy.dtype: ... + def numpy_dtype(self) -> numpy.dtype: ... def is_tuple(self) -> bool: ... def is_array(self) -> bool: ... def is_token(self) -> bool: ... @@ -139,30 +173,33 @@ class Shape: def to_serialized_proto(self) -> bytes: ... def tuple_shapes(self) -> list[Shape]: ... def leaf_count(self) -> int: ... - def with_major_to_minor_layout_if_absent(self) -> Shape: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... class ProgramShape: - def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... + def __init__(self, arg0: Sequence[Shape], arg1: Shape, /) -> None: ... def parameter_shapes(self) -> list[Shape]: ... def result_shape(self) -> Shape: ... def __repr__(self) -> str: ... class Literal: - def __init__(self, shape: Shape) -> None: ... + def __init__(self, arg: Shape, /) -> None: ... def __repr__(self) -> str: ... def __array__( - self, dtype: np.dtype | None = None, copy: bool | None = None - ) -> np.ndarray: ... + self, dtype: object | None = None, copy: bool | None = None + ) -> NDArray: ... def shape(self) -> Shape: ... class XlaComputation: - def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... + def __init__(self, arg: bytes, /) -> None: ... def get_hlo_module(self) -> HloModule: ... def program_shape(self) -> ProgramShape: ... + def name(self) -> str: ... def as_serialized_hlo_module_proto(self) -> bytes: ... def as_hlo_text(self, print_large_constants: bool = False) -> str: ... def as_hlo_dot_graph(self) -> str: ... @@ -177,363 +214,728 @@ class HloPrintOptions: def canonical() -> HloPrintOptions: ... 
@staticmethod def fingerprint() -> HloPrintOptions: ... - print_large_constants: bool - print_metadata: bool - print_backend_config: bool - print_result_shape: bool - print_operand_shape: bool - print_operand_names: bool - print_ids: bool - print_extra_attributes: bool - print_program_shape: bool - print_percent: bool - print_control_dependencies: bool - compact_operands: bool - include_layout_in_shapes: bool - canonicalize_instruction_names: bool - canonicalize_computations: bool - indent_amount: int - is_in_nested_computation: bool + @property + def print_large_constants(self) -> bool: ... + @print_large_constants.setter + def print_large_constants(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_metadata(self) -> bool: ... + @print_metadata.setter + def print_metadata(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_backend_config(self) -> bool: ... + @print_backend_config.setter + def print_backend_config(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_result_shape(self) -> bool: ... + @print_result_shape.setter + def print_result_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_operand_shape(self) -> bool: ... + @print_operand_shape.setter + def print_operand_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_operand_names(self) -> bool: ... + @print_operand_names.setter + def print_operand_names(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_ids(self) -> bool: ... + @print_ids.setter + def print_ids(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_extra_attributes(self) -> bool: ... + @print_extra_attributes.setter + def print_extra_attributes(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_program_shape(self) -> bool: ... + @print_program_shape.setter + def print_program_shape(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_percent(self) -> bool: ... 
+ @print_percent.setter + def print_percent(self, arg: bool, /) -> HloPrintOptions: ... + @property + def print_control_dependencies(self) -> bool: ... + @print_control_dependencies.setter + def print_control_dependencies(self, arg: bool, /) -> HloPrintOptions: ... + @property + def compact_operands(self) -> bool: ... + @compact_operands.setter + def compact_operands(self, arg: bool, /) -> HloPrintOptions: ... + @property + def include_layout_in_shapes(self) -> bool: ... + @include_layout_in_shapes.setter + def include_layout_in_shapes(self, arg: bool, /) -> HloPrintOptions: ... + @property + def canonicalize_instruction_names(self) -> bool: ... + @canonicalize_instruction_names.setter + def canonicalize_instruction_names(self, arg: bool, /) -> HloPrintOptions: ... + @property + def canonicalize_computations(self) -> bool: ... + @canonicalize_computations.setter + def canonicalize_computations(self, arg: bool, /) -> HloPrintOptions: ... + @property + def indent_amount(self) -> int: ... + @indent_amount.setter + def indent_amount(self, arg: int, /) -> HloPrintOptions: ... + @property + def is_in_nested_computation(self) -> int: ... + @is_in_nested_computation.setter + def is_in_nested_computation(self, arg: bool, /) -> HloPrintOptions: ... class HloComputation: - def render_html(self) -> None: ... + @property + def name(self) -> str: ... + def render_html(self, arg: str, /) -> None: ... class HloModule: - spmd_output_sharding: OpSharding | None - spmd_parameters_shardings: list[OpSharding] | None @property def name(self) -> str: ... def to_string(self, options: HloPrintOptions = ...) -> str: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - @staticmethod - def from_serialized_hlo_module_proto( - serialized_hlo_module_proto: bytes, - ) -> HloModule: ... + def from_serialized_hlo_module_proto(self) -> HloModule: ... def computations(self) -> list[HloComputation]: ... + @property + def spmd_output_sharding(self) -> OpSharding | None: ... 
+ @property + def spmd_parameters_shardings(self) -> list[OpSharding] | None: ... - -def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... -def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... -def hlo_module_cost_analysis( - client: Client, module: HloModule -) -> dict[str, float]: ... +def hlo_module_to_dot_graph(arg: HloModule, /) -> str: ... +def hlo_module_cost_analysis(arg0: Client, arg1: HloModule, /) -> dict: ... +def hlo_module_from_text(arg: str, /) -> HloModule: ... class DeviceAssignment: @staticmethod - def create(array: np.ndarray) -> DeviceAssignment: ... + def create( + arg: Annotated[NDArray[numpy.int32], dict(shape=(None, None))], / + ) -> DeviceAssignment: ... def replica_count(self) -> int: ... def computation_count(self) -> int: ... def __repr__(self) -> str: ... def serialize(self) -> bytes: ... class CompileOptions: - @staticmethod - def ParseFromString(s: bytes) -> CompileOptions: ... def __init__(self) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... def SerializeAsString(self) -> bytes: ... - argument_layouts: list[Shape] | None - parameter_is_tupled_arguments: bool - executable_build_options: ExecutableBuildOptions - tuple_arguments: bool - num_replicas: int - num_partitions: int - profile_version: int - device_assignment: DeviceAssignment | None - compile_portable_executable: bool - env_option_overrides: list[tuple[str, str]] + @staticmethod + def ParseFromString(arg: bytes, /) -> CompileOptions: ... + @property + def argument_layouts(self) -> list[Shape] | None: ... + @argument_layouts.setter + def argument_layouts(self, arg: Sequence[Shape], /) -> None: ... + @property + def parameter_is_tupled_arguments(self) -> bool: ... + @parameter_is_tupled_arguments.setter + def parameter_is_tupled_arguments(self, arg: bool, /) -> None: ... + @property + def compile_portable_executable(self) -> bool: ... 
+ @compile_portable_executable.setter + def compile_portable_executable(self, arg: bool, /) -> None: ... + @property + def executable_build_options(self) -> ExecutableBuildOptions: ... + @property + def env_option_overrides( + self, + ) -> list[tuple[str, str | bool | int | float]]: ... + @env_option_overrides.setter + def env_option_overrides( + self, arg: Sequence[tuple[str, str | bool | int | float]], / + ) -> None: ... + @property + def tuple_arguments(self) -> bool: ... + @tuple_arguments.setter + def tuple_arguments(self, arg: bool, /) -> None: ... + @property + def num_replicas(self) -> int: ... + @num_replicas.setter + def num_replicas(self, arg: int, /) -> None: ... + @property + def num_partitions(self) -> int: ... + @num_partitions.setter + def num_partitions(self, arg: int, /) -> None: ... + @property + def profile_version(self) -> int: ... + @profile_version.setter + def profile_version(self, arg: int, /) -> None: ... + @property + def device_assignment(self) -> DeviceAssignment | None: ... + @device_assignment.setter + def device_assignment(self, arg: DeviceAssignment, /) -> None: ... def register_custom_call_target( - fn_name: str, - capsule: Any, + fn_name: object, + fn: object, platform: str, - api_version: int = ..., - traits: int = ..., -) -> _Status: ... -def register_custom_call_partitioner( - name: str, - prop_user_sharding: Callable, - partition: Callable, - infer_sharding_from_operands: Callable, - can_side_effecting_have_replicated_sharding: bool = ..., - c_api: Any | None = ..., + api_version: int = 0, + traits: int = 0, ) -> None: ... -def encode_inspect_sharding_callback(handler: Any) -> bytes: ... -def register_custom_call_as_batch_partitionable( - target_name: str, - c_api: Any | None = ..., -) -> None: ... -def register_custom_type_id(type_name: str, type_id: Any) -> None: ... +def custom_call_targets(platform: str) -> dict: ... 
+ +class AutotuneCacheMode(enum.Enum): + UNSPECIFIED = 0 + + UPDATE = 1 -class AutotuneCacheMode(enum.IntEnum): - UNSPECIFIED = ... - UPDATE = ... - READ = ... + READ = 2 + +def register_custom_type_id(type_name: str, type_id: object) -> None: ... class DebugOptions: def __repr__(self) -> str: ... - xla_cpu_enable_fast_math: bool - xla_cpu_fast_math_honor_infs: bool - xla_cpu_fast_math_honor_nans: bool - xla_cpu_fast_math_honor_division: bool - xla_cpu_fast_math_honor_functions: bool - xla_gpu_enable_fast_min_max: bool - xla_backend_optimization_level: int - xla_cpu_enable_xprof_traceme: bool - xla_llvm_disable_expensive_passes: bool - xla_test_all_input_layouts: bool - xla_disable_hlo_passes: str - xla_enable_hlo_passes_only: str - xla_force_host_platform_device_count: int - xla_dump_to: str - xla_dump_hlo_module_re: str - xla_dump_hlo_pass_re: str - xla_dump_hlo_as_text: bool - xla_dump_hlo_as_proto: bool - xla_dump_hlo_as_dot: bool - xla_dump_hlo_as_url: bool - xla_dump_hlo_as_html: bool - xla_dump_fusion_visualization: bool - xla_dump_hlo_snapshots: bool - xla_dump_max_hlo_modules: bool - xla_dump_module_metadata: bool - xla_dump_compress_protos: bool - xla_dump_hlo_as_long_text: bool - xla_dump_disable_metadata: bool - xla_dump_hlo_pipeline_re: str - xla_gpu_cuda_data_dir: str - xla_detailed_logging: bool - xla_enable_dumping: bool - xla_gpu_dump_autotune_results_to: str - xla_gpu_load_autotune_results_from: str - xla_gpu_dump_autotune_logs_to: str - xla_gpu_kernel_cache_file: str - xla_gpu_enable_llvm_module_compilation_parallelism: bool - xla_gpu_per_fusion_autotune_cache_dir: str - xla_gpu_experimental_autotune_cache_mode: AutotuneCacheMode - -class CompiledMemoryStats: - generated_code_size_in_bytes: int - argument_size_in_bytes: int - output_size_in_bytes: int - alias_size_in_bytes: int - temp_size_in_bytes: int - host_generated_code_size_in_bytes: int - host_argument_size_in_bytes: int - host_output_size_in_bytes: int - host_alias_size_in_bytes: int - 
host_temp_size_in_bytes: int - serialized_buffer_assignment_proto: bytes - def __str__(self) -> str: ... + @property + def xla_backend_optimization_level(self) -> int: ... + @xla_backend_optimization_level.setter + def xla_backend_optimization_level(self, arg: int, /) -> None: ... + @property + def xla_cpu_enable_fast_math(self) -> bool: ... + @xla_cpu_enable_fast_math.setter + def xla_cpu_enable_fast_math(self, arg: bool, /) -> None: ... + @property + def xla_cpu_enable_xprof_traceme(self) -> bool: ... + @xla_cpu_enable_xprof_traceme.setter + def xla_cpu_enable_xprof_traceme(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_infs(self) -> bool: ... + @xla_cpu_fast_math_honor_infs.setter + def xla_cpu_fast_math_honor_infs(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_nans(self) -> bool: ... + @xla_cpu_fast_math_honor_nans.setter + def xla_cpu_fast_math_honor_nans(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_division(self) -> bool: ... + @xla_cpu_fast_math_honor_division.setter + def xla_cpu_fast_math_honor_division(self, arg: bool, /) -> None: ... + @property + def xla_cpu_fast_math_honor_functions(self) -> bool: ... + @xla_cpu_fast_math_honor_functions.setter + def xla_cpu_fast_math_honor_functions(self, arg: bool, /) -> None: ... + @property + def xla_detailed_logging(self) -> bool: ... + @xla_detailed_logging.setter + def xla_detailed_logging(self, arg: bool, /) -> None: ... + @property + def xla_enable_dumping(self) -> bool: ... + @xla_enable_dumping.setter + def xla_enable_dumping(self, arg: bool, /) -> None: ... + @property + def xla_gpu_enable_fast_min_max(self) -> bool: ... + @xla_gpu_enable_fast_min_max.setter + def xla_gpu_enable_fast_min_max(self, arg: bool, /) -> None: ... + @property + def xla_gpu_dump_autotune_results_to(self) -> str: ... + @xla_gpu_dump_autotune_results_to.setter + def xla_gpu_dump_autotune_results_to(self, arg: str, /) -> None: ... 
+ @property + def xla_gpu_load_autotune_results_from(self) -> str: ... + @xla_gpu_load_autotune_results_from.setter + def xla_gpu_load_autotune_results_from(self, arg: str, /) -> None: ... + @property + def xla_gpu_cuda_data_dir(self) -> str: ... + @xla_gpu_cuda_data_dir.setter + def xla_gpu_cuda_data_dir(self, arg: str, /) -> None: ... + @property + def xla_llvm_disable_expensive_passes(self) -> bool: ... + @xla_llvm_disable_expensive_passes.setter + def xla_llvm_disable_expensive_passes(self, arg: bool, /) -> None: ... + @property + def xla_disable_hlo_passes(self) -> str: ... + @xla_disable_hlo_passes.setter + def xla_disable_hlo_passes(self, arg: str, /) -> None: ... + @property + def xla_enable_hlo_passes_only(self) -> str: ... + @xla_enable_hlo_passes_only.setter + def xla_enable_hlo_passes_only(self, arg: str, /) -> None: ... + @property + def xla_test_all_input_layouts(self) -> bool: ... + @xla_test_all_input_layouts.setter + def xla_test_all_input_layouts(self, arg: bool, /) -> None: ... + @property + def xla_force_host_platform_device_count(self) -> int: ... + @xla_force_host_platform_device_count.setter + def xla_force_host_platform_device_count(self, arg: int, /) -> None: ... + @property + def xla_dump_to(self) -> str: ... + @xla_dump_to.setter + def xla_dump_to(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_module_re(self) -> str: ... + @xla_dump_hlo_module_re.setter + def xla_dump_hlo_module_re(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_pass_re(self) -> str: ... + @xla_dump_hlo_pass_re.setter + def xla_dump_hlo_pass_re(self, arg: str, /) -> None: ... + @property + def xla_dump_hlo_as_text(self) -> bool: ... + @xla_dump_hlo_as_text.setter + def xla_dump_hlo_as_text(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_proto(self) -> bool: ... + @xla_dump_hlo_as_proto.setter + def xla_dump_hlo_as_proto(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_dot(self) -> bool: ... 
+ @xla_dump_hlo_as_dot.setter + def xla_dump_hlo_as_dot(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_url(self) -> bool: ... + @xla_dump_hlo_as_url.setter + def xla_dump_hlo_as_url(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_html(self) -> bool: ... + @xla_dump_hlo_as_html.setter + def xla_dump_hlo_as_html(self, arg: bool, /) -> None: ... + @property + def xla_dump_fusion_visualization(self) -> bool: ... + @xla_dump_fusion_visualization.setter + def xla_dump_fusion_visualization(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_snapshots(self) -> bool: ... + @xla_dump_hlo_snapshots.setter + def xla_dump_hlo_snapshots(self, arg: bool, /) -> None: ... + @property + def xla_dump_max_hlo_modules(self) -> int: ... + @xla_dump_max_hlo_modules.setter + def xla_dump_max_hlo_modules(self, arg: int, /) -> None: ... + @property + def xla_dump_module_metadata(self) -> bool: ... + @xla_dump_module_metadata.setter + def xla_dump_module_metadata(self, arg: bool, /) -> None: ... + @property + def xla_dump_compress_protos(self) -> bool: ... + @xla_dump_compress_protos.setter + def xla_dump_compress_protos(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_as_long_text(self) -> bool: ... + @xla_dump_hlo_as_long_text.setter + def xla_dump_hlo_as_long_text(self, arg: bool, /) -> None: ... + @property + def xla_dump_disable_metadata(self) -> bool: ... + @xla_dump_disable_metadata.setter + def xla_dump_disable_metadata(self, arg: bool, /) -> None: ... + @property + def xla_dump_hlo_pipeline_re(self) -> str: ... + @xla_dump_hlo_pipeline_re.setter + def xla_dump_hlo_pipeline_re(self, arg: str, /) -> None: ... + @property + def xla_gpu_dump_autotune_logs_to(self) -> str: ... + @xla_gpu_dump_autotune_logs_to.setter + def xla_gpu_dump_autotune_logs_to(self, arg: str, /) -> None: ... + @property + def xla_gpu_kernel_cache_file(self) -> str: ... 
+ @xla_gpu_kernel_cache_file.setter + def xla_gpu_kernel_cache_file(self, arg: str, /) -> None: ... + @property + def xla_gpu_enable_llvm_module_compilation_parallelism(self) -> bool: ... + @xla_gpu_enable_llvm_module_compilation_parallelism.setter + def xla_gpu_enable_llvm_module_compilation_parallelism( + self, arg: bool, / + ) -> None: ... + @property + def xla_gpu_per_fusion_autotune_cache_dir(self) -> str: ... + @xla_gpu_per_fusion_autotune_cache_dir.setter + def xla_gpu_per_fusion_autotune_cache_dir(self, arg: str, /) -> None: ... + @property + def xla_gpu_experimental_autotune_cache_mode(self) -> AutotuneCacheMode: ... + @xla_gpu_experimental_autotune_cache_mode.setter + def xla_gpu_experimental_autotune_cache_mode( + self, arg: AutotuneCacheMode, / + ) -> None: ... class ExecutableBuildOptions: def __init__(self) -> None: ... def __repr__(self) -> str: ... - result_layout: Shape | None - fdo_profile: bytes | None - num_replicas: int - num_partitions: int - debug_options: DebugOptions - device_assignment: DeviceAssignment | None - use_spmd_partitioning: bool - use_auto_spmd_partitioning: bool - auto_spmd_partitioning_mesh_shape: list[int] - auto_spmd_partitioning_mesh_ids: list[int] - use_shardy_partitioner: bool + @property + def fdo_profile(self) -> bytes: ... + @fdo_profile.setter + def fdo_profile(self, arg: bytes, /) -> None: ... + @property + def result_layout(self) -> Shape | None: ... + @result_layout.setter + def result_layout(self, arg: Shape, /) -> ExecutableBuildOptions: ... + @property + def num_replicas(self) -> int: ... + @num_replicas.setter + def num_replicas(self, arg: int, /) -> ExecutableBuildOptions: ... + @property + def num_partitions(self) -> int: ... + @num_partitions.setter + def num_partitions(self, arg: int, /) -> ExecutableBuildOptions: ... + @property + def debug_options(self) -> DebugOptions: ... + @property + def device_assignment(self) -> DeviceAssignment | None: ... 
+ @device_assignment.setter + def device_assignment( + self, arg: DeviceAssignment, / + ) -> ExecutableBuildOptions: ... def compilation_environments_from_serialized_proto( - self, serialized_proto: bytes + self, arg: bytes, / + ) -> None: ... + @property + def exec_time_optimization_effort(self) -> float: ... + @exec_time_optimization_effort.setter + def exec_time_optimization_effort( + self, arg: float, / + ) -> ExecutableBuildOptions: ... + @property + def memory_fitting_effort(self) -> float: ... + @memory_fitting_effort.setter + def memory_fitting_effort(self, arg: float, /) -> ExecutableBuildOptions: ... + @property + def optimization_level(self) -> int: ... + @optimization_level.setter + def optimization_level(self, arg: int, /) -> None: ... + @property + def memory_fitting_level(self) -> int: ... + @memory_fitting_level.setter + def memory_fitting_level(self, arg: int, /) -> None: ... + @property + def use_spmd_partitioning(self) -> bool: ... + @use_spmd_partitioning.setter + def use_spmd_partitioning(self, arg: bool, /) -> ExecutableBuildOptions: ... + @property + def use_auto_spmd_partitioning(self) -> bool: ... + @use_auto_spmd_partitioning.setter + def use_auto_spmd_partitioning( + self, arg: bool, / + ) -> ExecutableBuildOptions: ... + @property + def auto_spmd_partitioning_mesh_shape(self) -> list[int]: ... + @auto_spmd_partitioning_mesh_shape.setter + def auto_spmd_partitioning_mesh_shape( + self, arg: Sequence[int], / + ) -> ExecutableBuildOptions: ... + @property + def auto_spmd_partitioning_mesh_ids(self) -> list[int]: ... + @auto_spmd_partitioning_mesh_ids.setter + def auto_spmd_partitioning_mesh_ids( + self, arg: Sequence[int], / + ) -> ExecutableBuildOptions: ... + @property + def allow_spmd_sharding_propagation_to_parameters(self) -> list[bool]: ... + @allow_spmd_sharding_propagation_to_parameters.setter + def allow_spmd_sharding_propagation_to_parameters( + self, arg: Sequence[bool], / + ) -> None: ... 
+ @property + def allow_spmd_sharding_propagation_to_output(self) -> list[bool]: ... + @allow_spmd_sharding_propagation_to_output.setter + def allow_spmd_sharding_propagation_to_output( + self, arg: Sequence[bool], / ) -> None: ... + @property + def use_shardy_partitioner(self) -> bool: ... + @use_shardy_partitioner.setter + def use_shardy_partitioner(self, arg: bool, /) -> ExecutableBuildOptions: ... class OpSharding_Type(enum.IntEnum): - REPLICATED = ... - MAXIMAL = ... - TUPLE = ... - OTHER = ... - MANUAL = ... - UNREDUCED = ... - UNKNOWN = ... - -class OpSharding_ShardGroupType(enum.IntEnum): - AS = ... - LIKE = ... + REPLICATED = 0 + + MAXIMAL = 1 + + MANUAL = 4 + + UNREDUCED = 6 + + TUPLE = 2 + + OTHER = 3 + + UNKNOWN = 5 + +class OpSharding_ShardGroupType(enum.Enum): + AS = 0 + + LIKE = 1 class OpSharding: - Type: type[OpSharding_Type] - type: OpSharding_Type - replicate_on_last_tile_dim: bool - last_tile_dims: Sequence[OpSharding_Type] - tile_assignment_dimensions: Sequence[int] - tile_assignment_devices: Sequence[int] - iota_reshape_dims: Sequence[int] - iota_transpose_perm: Sequence[int] - tuple_shardings: Sequence[OpSharding] - is_shard_group: bool - shard_group_id: int - ShardGroupType: builtins.type[OpSharding_ShardGroupType] - shard_group_type: OpSharding_ShardGroupType - def ParseFromString(self, s: bytes) -> None: ... + def __init__(self) -> None: ... + + Type: TypeAlias = OpSharding_Type + + ShardGroupType: TypeAlias = OpSharding_ShardGroupType + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def type(self) -> OpSharding_Type: ... + @type.setter + def type(self, arg: OpSharding_Type, /) -> None: ... + @property + def replicate_on_last_tile_dim(self) -> bool: ... + @replicate_on_last_tile_dim.setter + def replicate_on_last_tile_dim(self, arg: bool, /) -> None: ... + @property + def is_shard_group(self) -> bool: ... 
+ @is_shard_group.setter + def is_shard_group(self, arg: bool, /) -> None: ... + @property + def shard_group_id(self) -> int: ... + @shard_group_id.setter + def shard_group_id(self, arg: int, /) -> None: ... + @property + def shard_group_type(self) -> OpSharding_ShardGroupType: ... + @shard_group_type.setter + def shard_group_type(self, arg: OpSharding_ShardGroupType, /) -> None: ... + def __repr__(self) -> str: ... + def ParseFromString(self, arg: bytes, /) -> None: ... def SerializeToString(self) -> bytes: ... def clone(self) -> OpSharding: ... + @property + def tile_assignment_dimensions(self) -> list[int]: ... + @tile_assignment_dimensions.setter + def tile_assignment_dimensions(self, arg: Sequence[int], /) -> None: ... + @property + def tile_assignment_devices(self) -> list[int]: ... + @tile_assignment_devices.setter + def tile_assignment_devices(self, arg: Sequence[int], /) -> None: ... + @property + def iota_reshape_dims(self) -> list[int]: ... + @iota_reshape_dims.setter + def iota_reshape_dims(self, arg: Sequence[int], /) -> None: ... + @property + def iota_transpose_perm(self) -> list[int]: ... + @iota_transpose_perm.setter + def iota_transpose_perm(self, arg: Sequence[int], /) -> None: ... + @property + def tuple_shardings(self) -> list[OpSharding]: ... + @tuple_shardings.setter + def tuple_shardings(self, arg: Sequence[OpSharding], /) -> None: ... + @property + def last_tile_dims(self) -> list[int]: ... + @last_tile_dims.setter + def last_tile_dims(self, arg: Sequence[int], /) -> None: ... class HloSharding: @staticmethod - def from_proto(proto: OpSharding) -> HloSharding: ... + def from_proto(arg: OpSharding, /) -> HloSharding: ... @staticmethod - def from_string(sharding: str) -> HloSharding: ... + def from_string(arg: str, /) -> HloSharding: ... @staticmethod def tuple_sharding( - shape: Shape, shardings: Sequence[HloSharding] - ) -> HloSharding: ... 
+ arg0: Shape, arg1: Sequence[HloSharding], / + ) -> HloSharding: + """Constructs a tuple sharding.""" + @staticmethod def iota_tile( dims: Sequence[int], - reshape_dims: Sequence[int], - transpose_perm: Sequence[int], - subgroup_types: Sequence[OpSharding_Type], + reshape_dims: Sequence[int] = [], + transpose_perm: Sequence[int] = [], + subgroup_types: Sequence[OpSharding_Type] = [], ) -> HloSharding: ... @staticmethod - def replicate() -> HloSharding: ... - @staticmethod def manual() -> HloSharding: ... @staticmethod + def replicate() -> HloSharding: ... + @staticmethod def unreduced() -> HloSharding: ... @staticmethod def unknown() -> HloSharding: ... @staticmethod def subgroup_with_device_ordering( - tile_assignment: np.ndarray, subgroup_types: Sequence[OpSharding_Type] + tile_assignment: Annotated[NDArray[numpy.int64], dict(order='C')], + subgroup_types: Sequence[OpSharding_Type] = [], ) -> HloSharding: ... - def __eq__(self, other: Any) -> bool: ... + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - def tile(self, shape: Shape) -> Shape: ... def is_replicated(self) -> bool: ... def is_manual(self) -> bool: ... def is_unreduced(self) -> bool: ... def is_unknown(self) -> bool: ... def is_tiled(self) -> bool: ... def is_maximal(self) -> bool: ... + def tile(self, arg: Shape, /) -> Shape: ... def tuple_elements(self) -> list[HloSharding]: ... def num_devices(self) -> int: ... def num_dimensions(self) -> int: ... def is_tile_assignment_iota(self) -> bool: ... def tile_assignment_dimensions(self) -> Sequence[int]: ... def tile_assignment_devices(self) -> Sequence[int]: ... - def subgroup_types(self) -> Sequence[OpSharding_Type]: ... def replicate_on_last_tile_dim(self) -> bool: ... + def subgroup_types(self) -> list[OpSharding_Type]: ... + def __repr__(self) -> str: ... def to_proto(self) -> OpSharding: ... def get_axis_sizes(self) -> list[int]: ... 
-# === END xla_compiler.cc - class Device: - id: int - host_id: int - process_index: int - platform: str - device_kind: str - client: Client - local_hardware_id: int | None - def __repr__(self) -> str: ... + """A descriptor of an available device. + + Subclasses are used to represent specific types of devices, e.g. CPUs, GPUs. + Subclasses may have additional properties specific to that device type. + """ + + @property + def id(self) -> int: + """Integer ID of this device. + + Unique across all available devices of this type, including remote devices + on multi-host platforms. + """ + + @property + def process_index(self) -> int: + """Integer index of this device's process. + + This is always 0 except on multi-process platforms. + """ + + @property + def host_id(self) -> int: + """Deprecated; please use process_index""" + + @property + def task_id(self) -> int: + """Deprecated; please use process_index""" + + @property + def platform(self) -> str: ... + @property + def device_kind(self) -> str: ... + @property + def client(self) -> Client: ... + @property + def local_hardware_id(self) -> int | None: + """Opaque hardware ID, e.g., the CUDA device number. + + In general, not guaranteed to be dense, and not guaranteed to be defined on + all platforms. + """ + def __str__(self) -> str: ... + def __repr__(self) -> str: ... def memory(self, kind: str) -> Memory: ... - def default_memory(self) -> Memory: ... - def addressable_memories(self) -> list[Memory]: ... - def live_buffers(self) -> list[Any]: ... - def memory_stats(self) -> dict[str, int] | None: ... + def default_memory(self) -> Memory: + """Returns the default memory of a device.""" + + def addressable_memories(self) -> list[Memory]: + """Returns all the memories that a device can address.""" + + def live_buffers(self) -> list: ... + def memory_stats(self) -> dict[str, int] | None: + """Returns memory statistics for this device keyed by name. 
+ + May not be implemented on all platforms, and different platforms may return + different stats, or -1 for unavailable stats. 'bytes_in_use' is usually + available. Intended for diagnostic use. + """ + def get_stream_for_external_ready_events(self) -> int: ... - def __getattr__(self, name: str) -> Any: ... -class Memory: - process_index: int - platform: str - kind: str - def __repr__(self) -> str: ... - def __str__(self) -> str: ... - def addressable_by_devices(self) -> list[Device]: ... + __getattr__: types.MethodDescriptorType = ... -class PjRtLayout: +class Memory: + @property + def process_index(self) -> int: ... + @property + def platform(self) -> str: ... + @property + def kind(self) -> str: ... def __str__(self) -> str: ... - def __eq__(self, other: Any) -> bool: ... - def __hash__(self) -> int: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, _: Any): ... - def _xla_layout(self) -> Layout: ... + def __repr__(self) -> str: ... + def addressable_by_devices(self) -> list[Device]: + """Returns devices that can address this memory.""" -class GpuAllocatorConfig: - class Kind(enum.IntEnum): - DEFAULT = ... - PLATFORM = ... - BFC = ... - CUDA_ASYNC = ... +class HostBufferSemantics(enum.Enum): + IMMUTABLE_ONLY_DURING_CALL = 0 - def __init__( - self, - kind: Kind = ..., - memory_fraction: float = ..., - preallocate: bool = ..., - collective_memory_size: int = ..., - ) -> None: ... + IMMUTABLE_UNTIL_TRANSFER_COMPLETES = 1 -class HostBufferSemantics(enum.IntEnum): - IMMUTABLE_ONLY_DURING_CALL = ... - IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ... - ZERO_COPY = ... + ZERO_COPY = 2 class Client: - platform: str - _raw_platform: str - platform_version: str - runtime_type: str + @property + def platform(self) -> str: ... + @property + def _raw_platform(self) -> str: ... + @property + def platform_version(self) -> str: ... + @property + def runtime_type(self) -> str: ... def device_count(self) -> int: ... def local_device_count(self) -> int: ... 
def devices(self) -> list[Device]: ... def local_devices(self) -> list[Device]: ... def _get_all_devices(self) -> list[Device]: ... - def device_from_local_hardware_id(self, int) -> Device: ... - def live_buffers(self) -> list[Any]: ... - def live_arrays(self) -> list[ArrayImpl]: ... + def device_from_local_hardware_id(self, arg: int, /) -> Device: ... def live_executables(self) -> list[LoadedExecutable]: ... - def host_id(self) -> int: ... + def live_arrays(self) -> list[Array]: ... + def live_buffers(self) -> list[Array]: ... def process_index(self) -> int: ... + def host_id(self) -> int: ... + def task_id(self) -> int: ... def buffer_from_pyval( self, - argument: Any, - device: Device | None = ..., - force_copy: bool = ..., - host_buffer_semantics: HostBufferSemantics = ..., - ) -> ArrayImpl: ... + argument: object, + device: Device | None = None, + force_copy: bool = False, + host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, + ) -> object: ... def compile( self, - computation: str | bytes, - executable_devices: DeviceList | Sequence[Device], + computation: object, + executable_devices: DeviceList, compile_options: CompileOptions = ..., ) -> Executable: ... + @overload def compile_and_load( self, - computation: str | bytes, - executable_devices: DeviceList | Sequence[Device], + computation: object, + executable_devices: DeviceList, compile_options: CompileOptions = ..., - host_callbacks: Sequence[Any] = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., ) -> LoadedExecutable: ... - def compile_ifrt_program( + @overload + def compile_and_load( self, - program: ifrt_programs.Program, - program_options: ifrt_programs.CompileOptions, + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Callable[..., Any]] = ..., ) -> LoadedExecutable: ... 
- def compile_and_load_ifrt_program( + @overload + def compile_and_load( + self, + computation: bytes, + executable_devices: Sequence, + compile_options: CompileOptions = ..., + ) -> LoadedExecutable: ... + @overload + def compile_and_load( self, - program: ifrt_programs.Program, - program_options: ifrt_programs.CompileOptions, + computation: str, + executable_devices: Sequence, + compile_options: CompileOptions = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, arg0: ifrt_programs.Program, arg1: ifrt_programs.CompileOptions, / ) -> LoadedExecutable: ... - def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... + def compile_and_load_ifrt_program( + self, arg0: ifrt_programs.Program, arg1: ifrt_programs.CompileOptions, / + ) -> LoadedExecutable: ... + def serialize_executable(self, arg: LoadedExecutable, /) -> bytes: ... + @overload def deserialize_executable( self, serialized: bytes, - executable_devices: DeviceList | Sequence[Device], - options: CompileOptions | None, - host_callbacks: Sequence[Any] = ..., + executable_devices: DeviceList, + compile_options: CompileOptions | None = None, + host_callbacks: Sequence[typing_extensions.CapsuleType] = [], ) -> LoadedExecutable: ... - def heap_profile(self) -> bytes: ... + @overload + def deserialize_executable( + self, + serialized: bytes, + executable_devices: Sequence, + compile_options: CompileOptions | None = None, + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def defragment(self) -> None: ... def make_python_callback_from_host_send_and_recv( self, callable: Callable, @@ -541,497 +943,654 @@ class Client: result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], recv_channel_ids: Sequence[int], - serializer: Callable | None = ..., - ) -> Any: ... + serializer: Callable | None = None, + ) -> object: ... 
def get_default_layout( - self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + self, dtype: numpy.dtype, shard_shape: Sequence, device: Device ) -> PjRtLayout: ... - def __getattr__(self, name: str) -> Any: ... + def __getattr__(self, arg: str, /) -> object: ... -class CompileOnlyPyClient(Client): - def compile( - self, - computation: str | bytes, - executable_devices: DeviceList | Sequence[Device], - compile_options: CompileOptions = ..., - ) -> Executable: ... +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY = 0 -class CpuCollectives: ... + REUSE_INPUT = 1 -def make_gloo_tcp_collectives( - distributed_client: DistributedRuntimeClient | None = ..., - hostname: str | None = ..., - interface: str | None = ..., -) -> CpuCollectives: ... + DONATE_INPUT = 2 + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + def _xla_layout(self) -> Layout: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... -class MpiCollectives(CpuCollectives): - def Init(self): ... - def Finalize(self): ... +class CpuCollectives: + def Init(self) -> None: ... + def Finalize(self) -> None: ... -def make_mpi_collectives() -> MpiCollectives: ... +def make_gloo_tcp_collectives( + distributed_client: DistributedRuntimeClient, + hostname: str | None = None, + interface: str | None = None, +) -> CpuCollectives: ... +def make_mpi_collectives() -> CpuCollectives: ... def get_tfrt_cpu_client( - asynchronous: bool = ..., - distributed_client: DistributedRuntimeClient | None = ..., - node_id: int = ..., - num_nodes: int = ..., - collectives: CpuCollectives | None = ..., - num_devices: int | None = ..., - get_local_topology_timeout_minutes: int | None = ..., - get_global_topology_timeout_minutes: int | None = ..., - transfer_server_factory: TransferServerInterfaceFactory | None = ..., -) -> Client: ... 
-def get_mock_gpu_client( - asynchronous: bool = ..., - allocator_config: GpuAllocatorConfig = ..., - distributed_client: DistributedRuntimeClient | None = ..., - node_id: int = ..., - allowed_devices: Any | None = ..., - platform_name: str | None = ..., + asynchronous: bool = True, + distributed_client: DistributedRuntimeClient | None = None, + node_id: int = 0, + num_nodes: int = 1, + collectives: CpuCollectives | None = None, + num_devices: int | None = None, + get_local_topology_timeout_minutes: int | None = None, + get_global_topology_timeout_minutes: int | None = None, + transfer_server_factory: TransferServerInterfaceFactory | None = None, ) -> Client: ... +def pjrt_plugin_loaded(arg: str, /) -> bool: ... +def load_pjrt_plugin( + platform_name: str, + library_path: str | None = None, + c_api: typing_extensions.CapsuleType | None = None, +) -> typing_extensions.CapsuleType: ... +def pjrt_plugin_initialized(arg: str, /) -> bool: ... +def initialize_pjrt_plugin(arg: str, /) -> None: ... def get_c_api_client( platform_name: str, - options: Mapping[str, str | int | list[int] | float | bool], - distributed_client: DistributedRuntimeClient | None = ..., - transfer_server_factory: TransferServerInterfaceFactory | None = ..., + options: Mapping[str, str | bool | int | Sequence[int] | float] = {}, + distributed_client: DistributedRuntimeClient | None = None, + transfer_server_factory: TransferServerInterfaceFactory | None = None, ) -> Client: ... def get_default_c_api_topology( - platform_name: str, - topology_name: str, - options: dict[str, str | int | list[int] | float], + arg0: str, + arg1: str, + arg2: Mapping[str, str | bool | int | Sequence[int] | float], + /, ) -> DeviceTopology: ... def get_c_api_topology( - c_api: Any, - topology_name: str, - options: dict[str, str | int | list[int] | float], + arg0: typing_extensions.CapsuleType, + arg1: str, + arg2: Mapping[str, str | bool | int | Sequence[int] | float], + /, ) -> DeviceTopology: ... 
-def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... -def load_pjrt_plugin( - platform_name: str, library_path: str | None, c_api: Any | None -) -> _Status: ... -def pjrt_plugin_loaded(plugin_name: str) -> bool: ... -def pjrt_plugin_initialized(plugin_name: str) -> bool: ... -def initialize_pjrt_plugin(platform_name: str) -> _Status: ... - -Array = Any -ArrayImpl = Any - -# TODO(phawkins): this type is problematic because it is not a subtype of -# jax.Array, and pytype notices. -# class ArrayImpl: -# def __init__(self, -# aval: Any, -# sharding: Any, -# arrays: Sequence[ArrayImpl], -# committed: bool, -# _skip_checks: bool = ...): ... -# def block_until_ready(self) -> ArrayImpl: ... -# def is_deleted(self) -> bool: ... -# def is_ready(self) -> bool: ... -# def delete(self): ... -# def unsafe_buffer_pointer(self) -> Any: ... -# def clone(self) -> ArrayImpl: ... -# def _copy_single_device_array_to_host_async(self): ... -# def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: ... -# def on_device_size_in_bytes(self) -> int: ... -# def _fully_replicated_shard(self) -> ArrayImpl: ... -# __cuda_array_interface__: Dict[str, Any] -# dtype: np.dtype -# shape: Tuple[int, ...] -# _arrays: Any -# _npy_value: Any -# traceback: Traceback -# _HAS_DYNAMIC_ATTRIBUTES: bool = ... +def get_topology_for_devices(arg: Sequence[Device], /) -> DeviceTopology: ... + +Array: Any + +ArrayImpl: Any def batched_copy_array_to_devices_with_sharding( - arrays: Sequence[ArrayImpl], - devices: Sequence[list[Device] | DeviceList], - sharding: Sequence[Any], - array_copy_semantics: Sequence[ArrayCopySemantics], -) -> Sequence[ArrayImpl]: ... -def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... -def batched_device_put( - aval: Any, - sharding: Any, - shards: Sequence[Any], - devices: list[Device], - committed: bool = ..., - force_copy: bool = ..., - host_buffer_semantics: Any = ..., - enable_x64: bool | None = ..., -) -> ArrayImpl: ... 
-def reorder_shards( - x: ArrayImpl, - dst_sharding: Any, - array_copy_semantics: ArrayCopySemantics, -) -> ArrayImpl: ... -def check_and_canonicalize_memory_kind( - memory_kind: str | None, device_list: DeviceList -) -> str | None: ... + arg0: Sequence[Array], + arg1: Sequence[DeviceList], + arg2: Sequence[object], + arg3: Sequence[ArrayCopySemantics], + /, +) -> list[Array]: ... def array_result_handler( - aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... -) -> Callable: ... + aval: object, sharding: object, committed: bool, _skip_checks: bool = False +) -> ResultHandler: ... -class Token: - def block_until_ready(self): ... +class ResultHandler: + def __call__(self, arg: Array | Sequence[Array], /) -> Array: ... -class ShardedToken: - def block_until_ready(self): ... - def get_token(self, device_id: int): ... +class DeviceList: + def __init__(self, arg: tuple[Device, ...], /) -> None: ... + def __hash__(self) -> int: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... + def __len__(self) -> int: ... + @overload + def __getitem__(self, index: int, /) -> Device: ... + @overload + def __getitem__(self, slice: slice, /) -> Sequence[Device]: ... + def __iter__(self) -> Iterator: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... + @property + def is_fully_addressable(self) -> bool: ... + @property + def addressable_device_list(self) -> DeviceList: ... + @property + def process_indices(self) -> set[int]: ... + @property + def default_memory_kind(self) -> str | None: ... + @property + def memory_kinds(self) -> tuple[str, ...]: ... + @property + def device_kind(self) -> str: ... + +class Sharding: + def __init__(self) -> None: ... 
+ +class NamedSharding(Sharding): + def __init__( + self, + mesh: object, + spec: PartitionSpec, + memory_kind: object | None = None, + _logical_device_ids: object | None = None, + ) -> None: ... + @property + def mesh(self) -> object: ... + @property + def spec(self) -> PartitionSpec: ... + @property + def _memory_kind(self) -> object: ... + @property + def _logical_device_ids(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... + def __eq__(self, arg: object) -> bool: ... + def __hash__(self) -> int: ... + +class SingleDeviceSharding(Sharding): + def __init__( + self, device: object, memory_kind: object | None = None + ) -> None: ... + @property + def _device(self) -> object: ... + @property + def _memory_kind(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... + +class PmapSharding(Sharding): + def __init__( + self, devices: object, sharding_spec: pmap_lib.ShardingSpec + ) -> None: ... + @property + def devices(self) -> numpy.ndarray: ... + @property + def sharding_spec(self) -> pmap_lib.ShardingSpec: ... + @property + def _internal_device_list(self) -> DeviceList: ... + +class GSPMDSharding(Sharding): + @overload + def __init__( + self, + devices: DeviceList, + op_sharding: OpSharding, + memory_kind: object | None = None, + ) -> None: ... + @overload + def __init__( + self, + devices: DeviceList, + op_sharding: HloSharding, + memory_kind: object | None = None, + ) -> None: ... + @overload + def __init__( + self, + devices: Sequence[Device], + op_sharding: OpSharding, + memory_kind: object | None = None, + ) -> None: ... + @overload + def __init__( + self, + devices: Sequence[Device], + op_sharding: HloSharding, + memory_kind: object | None = None, + ) -> None: ... + @property + def _devices(self) -> DeviceList: ... + @property + def _hlo_sharding(self) -> HloSharding: ... + @property + def _memory_kind(self) -> object: ... + @property + def _internal_device_list(self) -> DeviceList: ... 
+ +class CompiledMemoryStats: + @property + def generated_code_size_in_bytes(self) -> int: ... + @generated_code_size_in_bytes.setter + def generated_code_size_in_bytes(self, arg: int, /) -> None: ... + @property + def argument_size_in_bytes(self) -> int: ... + @argument_size_in_bytes.setter + def argument_size_in_bytes(self, arg: int, /) -> None: ... + @property + def output_size_in_bytes(self) -> int: ... + @output_size_in_bytes.setter + def output_size_in_bytes(self, arg: int, /) -> None: ... + @property + def alias_size_in_bytes(self) -> int: ... + @alias_size_in_bytes.setter + def alias_size_in_bytes(self, arg: int, /) -> None: ... + @property + def temp_size_in_bytes(self) -> int: ... + @temp_size_in_bytes.setter + def temp_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_generated_code_size_in_bytes(self) -> int: ... + @host_generated_code_size_in_bytes.setter + def host_generated_code_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_argument_size_in_bytes(self) -> int: ... + @host_argument_size_in_bytes.setter + def host_argument_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_output_size_in_bytes(self) -> int: ... + @host_output_size_in_bytes.setter + def host_output_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_alias_size_in_bytes(self) -> int: ... + @host_alias_size_in_bytes.setter + def host_alias_size_in_bytes(self, arg: int, /) -> None: ... + @property + def host_temp_size_in_bytes(self) -> int: ... + @host_temp_size_in_bytes.setter + def host_temp_size_in_bytes(self, arg: int, /) -> None: ... + @property + def serialized_buffer_assignment_proto(self) -> bytes: ... + @property + def peak_memory_in_bytes(self) -> int: ... + @peak_memory_in_bytes.setter + def peak_memory_in_bytes(self, arg: int, /) -> None: ... + def __str__(self) -> str: ... class ExecuteResults: def __len__(self) -> int: ... 
- def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]: ... + def disassemble_into_single_device_arrays(self) -> list[list[Array]]: ... def disassemble_prefix_into_single_device_arrays( - self, n: int - ) -> list[list[ArrayImpl]]: ... - def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]: ... + self, arg: int, / + ) -> list[list[Array]]: ... + def consume_with_handlers( + self, arg: Sequence[ResultHandler | object], / + ) -> list[object]: ... def consume_token(self) -> ShardedToken: ... def get_execution_stream_id() -> int: ... - -def set_execution_stream_id(new_id: int): ... +def set_execution_stream_id(arg: int, /) -> None: ... class LoadedExecutable: - client: Client + @property + def client(self) -> Client: ... def local_devices(self) -> list[Device]: ... def get_hlo_text(self) -> str: ... def size_of_generated_code_in_bytes(self) -> int: ... - def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ... - def execute_with_token( - self, arguments: Sequence[ArrayImpl] - ) -> tuple[list[ArrayImpl], Token]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def execute_sharded( - self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ... + self, + arguments: Sequence[Array | Sequence[Array]], + with_tokens: bool = False, ) -> ExecuteResults: ... def hlo_modules(self) -> list[HloModule]: ... def get_output_memory_kinds(self) -> list[list[str]]: ... - def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[PjRtLayout]: ... + def get_output_layouts(self) -> list[PjRtLayout]: ... def get_parameter_shardings(self) -> list[OpSharding] | None: ... - def get_parameter_layouts(self) -> list[Layout]: ... - def get_output_layouts(self) -> list[Layout]: ... - def keep_alive(self) -> None: ... - def cost_analysis(self) -> dict[str, Any]: ... 
- traceback: Traceback - fingerprint: bytes | None + def keep_alive(self, arg: object, /) -> None: ... + def cost_analysis( + self, + ) -> dict[str, str | bool | int | list[int] | float]: ... + @property + def traceback(self) -> Traceback | None: ... + @property + def fingerprint(self) -> object: ... -class Executable: - def hlo_modules(self) -> list[HloModule]: ... - def get_output_memory_kinds(self) -> list[list[str]]: ... - def get_output_shardings(self) -> list[OpSharding] | None: ... - def get_parameter_shardings(self) -> list[OpSharding] | None: ... - def get_parameter_layouts(self) -> list[Layout]: ... - def get_output_layouts(self) -> list[Layout]: ... - def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... - def serialize(self) -> str: ... - def cost_analysis(self) -> dict[str, Any]: ... +class Token: + def block_until_ready(self) -> None: ... -class DeviceTopology: - platform: str - platform_version: str - def _make_compile_only_devices(self) -> list[Device]: ... - def serialize(self) -> bytes: ... - def __getattr__(self, name: str) -> Any: ... +class ShardedToken: + def block_until_ready(self) -> None: ... + def get_token(self, arg: int, /) -> Token: ... def buffer_to_dlpack_managed_tensor( - buffer: ArrayImpl, stream: int | None = None -) -> Any: ... - + buffer: object, stream: int | None = None +) -> typing_extensions.CapsuleType: ... def dlpack_managed_tensor_to_buffer( - tensor: Any, device: Device, stream: int | None + dlpack: typing_extensions.CapsuleType, device: Device, stream: int | None ) -> ArrayImpl: ... - def cuda_array_interface_to_buffer( - cai: dict[ - str, - ( - str - | int - | None - | tuple[int, ...] - | tuple[int, bool] - | list[tuple[str, str]] - | list[tuple[str, str, tuple[int, ...]]] - ), - ], - gpu_backend: Client | None = ..., - device_id: int | None = None, -) -> ArrayImpl: ... + cai: dict, gpu_backend: Client | None = None, device_id: int | None = None +) -> object: ... 
+ +class PjitFunctionCache: + def __init__(self, capacity: int = 4096) -> None: ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self) -> None: ... + @staticmethod + def clear_all() -> None: ... + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + +class PjitFunction: + def __repr__(self, /): + """Return repr(self).""" + + def __call__(self, /, *args, **kwargs): + """Call self as a function.""" + + def __get__(self, instance, owner=None, /): + """Return an attribute of instance, which is of type owner.""" + __vectorcalloffset__: types.MemberDescriptorType = ... + + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + @property + def __signature__(self) -> inspect.Signature: ... + @property + def _cache_miss(self) -> Callable: ... + def _cache_size(self) -> int: ... + def _clear_cache(self) -> None: ... -# === BEGIN py_traceback.cc +def pjit( + function_name: str, + fun: Callable[..., Any] | None, + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: _PyTreeRegistry, + shard_arg_fallback: Callable[..., Any], + cache: PjitFunctionCache | None = ..., +) -> PjitFunction: ... class Frame: - file_name: str - function_name: str - function_line_start: int - line_num: int - def __init__( - self, - file_name: str, - function_name: str, - function_line_start: int, - line_num: int, - ): ... + def __init__(self, arg0: str, arg1: str, arg2: int, arg3: int, /) -> None: ... + @property + def file_name(self) -> str: ... + @property + def function_name(self) -> str: ... + @property + def function_start_line(self) -> int: ... + @property + def line_num(self) -> int: ... def __repr__(self) -> str: ... class Traceback: - enabled: ClassVar[bool] - @staticmethod - def get_traceback() -> Traceback: ... 
+ def __hash__(self, /): + """Return hash(self).""" + + def __str__(self, /): + """Return str(self).""" + + def __lt__(self, value, /): + """Return self<value.""" + + def __ge__(self, value, /): + """Return self>=value.""" + @staticmethod - def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... - frames: Sequence[Frame] - def __str__(self) -> str: ... - def as_python_traceback(self) -> Any: ... + def get_traceback() -> Traceback | None: + """Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` + object that describes the Python stack of the calling thread. Stack + trace collection has a small overhead, so it is disabled by default. If + traceback collection is disabled, returns ``None``. + """ + + @property + def frames(self) -> list[Frame]: ... def raw_frames(self) -> tuple[list[types.CodeType], list[int]]: ... + def as_python_traceback(self) -> traceback.TracebackType: ... @staticmethod - def code_addr2line(code: types.CodeType, lasti: int) -> int: ... + def traceback_from_frames(frames: list[Frame]) -> traceback.TracebackType: + """Creates a traceback from a list of frames.""" + + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: + """Python wrapper around the Python C API function PyCode_Addr2Line""" + @staticmethod def code_addr2location( code: types.CodeType, lasti: int - ) -> tuple[int, int, int, int]: ... + ) -> tuple[int, int, int, int]: + """Python wrapper around the Python C API function PyCode_Addr2Location""" def tracebacks_enabled() -> bool: ... -def set_tracebacks_enabled(enabled: bool) -> None: ... +def set_tracebacks_enabled(arg: bool, /) -> None: ...
+def register_custom_call_partitioner( + name: str, + prop_user_sharding: object, + partition: object, + infer_sharding_from_operands: object, + can_side_effecting_have_replicated_sharding: bool = False, + c_api: typing_extensions.CapsuleType | None = None, +) -> None: + """Registers a partitioner for a custom-call operation. + + Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. Takes result + sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns a + generated HLO and sharding specs. The spmd lowerer first reshards to match + the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension + """ + +def encode_inspect_sharding_callback(arg: object, /) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, c_api: typing_extensions.CapsuleType | None = None +) -> None: + """Registers a custom call as batch partitionable. + + If a custom call is "batch partitionable", it means that it can be trivially + partitioned on some number of (leading) dimensions, with the same call being + executed independently on each shard of data. If the data are sharded on + non-batch dimensions, partitioning will re-shard the data to be replicated on + the non-batch dimensions. + + Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. 
+ """ + +class TransferConnection: + def _testonly_inject_failure(self) -> None: ... + def _pull_flat( + self, arg0: int, arg1: Client, arg2: Sequence[object], / + ) -> list[Array]: ... + def _pull_into_flat( + self, arg0: int, arg1: Sequence[Array], arg2: Sequence[slice], / + ) -> list[Token]: ... + +class TransferServer: + def address(self) -> str: ... + def _await_pull_flat(self, arg0: int, arg1: Sequence[Array], /) -> None: ... + def _reset_rendevous_table(self) -> None: ... + def connect(self, arg: str, /) -> TransferConnection: ... + +def _make_error_array(arg0: Client, arg1: object, arg2: str, /) -> Array: ... +def start_transfer_server( + client: Client, + address: str = '[::]:0', + transport_addresses: Sequence[str] = [], + max_num_parallel_copies: int = 8, + transfer_size: int = 268435456, + supports_pinned_allocator: bool = False, + use_raw_buffers: bool = False, +) -> TransferServer: ... +def make_transfer_server_interface_factory( + transfer_size: int = 268435456, + cross_host_transfer_timeout_seconds: int = 60, + distributed_client: DistributedRuntimeClient | None = None, + socket_address: str = '[::]:0', + transport_addresses: Sequence[str] = [], +) -> TransferServerInterfaceFactory: ... + +class PreemptionSyncManager: + def initialize( + self, distributed_client: DistributedRuntimeClient + ) -> None: ... + def reached_sync_point(self, arg: int, /) -> bool: ... + def shutdown(self) -> None: ... -# === END py_traceback.cc +def create_preemption_sync_manager() -> PreemptionSyncManager: ... class DistributedRuntimeService: def shutdown(self) -> None: ... class DistributedRuntimeClient: - def connect(self) -> _Status: ... - def shutdown(self) -> _Status: ... - def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... + def connect(self) -> None: ... + def shutdown(self) -> None: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> str: ... 
def blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int - ) -> _Status: ... - def key_value_try_get(self, key: str) -> _Status: ... - def key_value_try_get_bytes(self, key: str) -> _Status: ... - def key_value_increment(self, key: str, increment: int) -> _Status: ... - def key_value_dir_get(self, key: str) -> _Status: ... - def key_value_dir_get_bytes(self, key: str) -> _Status: ... - def key_value_set( - self, key: str, value: str, allow_overwrite: bool = False - ) -> _Status: ... - def key_value_set_bytes( - self, key: str, value: bytes, allow_overwrite: bool = False - ) -> _Status: ... - def key_value_delete(self, key: str) -> _Status: ... + ) -> bytes: ... + def key_value_try_get(self, key: str) -> str: ... + def key_value_try_get_bytes(self, key: str) -> bytes: ... + def key_value_increment(self, key: str, increment: int) -> int: ... def wait_at_barrier( self, barrier_id: str, timeout_in_ms: int, - process_ids: list[int] | None = None, - ) -> _Status: ... - def get_live_nodes(self, process_ids: list[int]) -> _Status: ... + process_ids: Sequence[int] | None = None, + ) -> None: ... + def get_live_nodes(self, process_ids: Sequence[int]) -> list[int]: ... + def key_value_set( + self, key: str, value: str, allow_overwrite: bool = False + ) -> None: ... + def key_value_set_bytes( + self, key: str, value: bytes, allow_overwrite: bool = False + ) -> None: ... + def key_value_dir_get(self, key: str) -> list[tuple[str, str]]: ... + def key_value_dir_get_bytes(self, key: str) -> list[tuple[str, bytes]]: ... + def key_value_delete(self, key: str) -> None: ... def get_distributed_runtime_service( address: str, num_nodes: int, - heartbeat_timeout: int | None = ..., - cluster_register_timeout: int | None = ..., - shutdown_timeout: int | None = ..., + heartbeat_timeout: int | None = None, + cluster_register_timeout: int | None = None, + shutdown_timeout: int | None = None, ) -> DistributedRuntimeService: ... 
def get_distributed_runtime_client( address: str, node_id: int, - rpc_timeout: int | None = ..., - init_timeout: int | None = ..., - shutdown_timeout: int | None = ..., - heartbeat_timeout: int | None = ..., - missed_heartbeat_callback: Any | None = ..., - shutdown_on_destruction: bool | None = ..., - use_compression: bool | None = ..., - recoverable: bool | None = ..., + rpc_timeout: int | None = None, + init_timeout: int | None = None, + shutdown_timeout: int | None = None, + heartbeat_timeout: int | None = None, + missed_heartbeat_callback: Callable | None = None, + shutdown_on_destruction: bool | None = None, + use_compression: bool | None = None, + recoverable: bool | None = None, ) -> DistributedRuntimeClient: ... - -class PreemptionSyncManager: - def initialize(self, client: DistributedRuntimeClient) -> _Status: ... - def reached_sync_point(self, step_counter: int) -> bool: ... - def shutdown(self) -> None: ... - -def create_preemption_sync_manager() -> PreemptionSyncManager: ... def collect_garbage() -> None: ... def is_optimized_build() -> bool: ... -def json_to_pprof_profile(json: str) -> bytes: ... -def pprof_profile_to_json(proto: bytes) -> str: ... - -class PmapFunction: - def __call__(self, *args, **kwargs) -> Any: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, Any): ... - __signature__: inspect.Signature - def _cache_size(self) -> int: ... - def _cache_clear(self) -> None: ... - -class DeviceList: - def __init__(self, device_assignment: tuple[Device, ...]): ... - def __hash__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... - def __len__(self) -> int: ... - def __getitem__(self, index: Any) -> Any: ... - def __iter__(self) -> Iterator[Device]: ... - def __str__(self) -> str: ... - def __repr__(self) -> str: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - @property - def is_fully_addressable(self) -> bool: ... 
- @property - def addressable_device_list(self) -> DeviceList: ... - @property - def process_indices(self) -> set[int]: ... - @property - def default_memory_kind(self) -> str | None: ... - @property - def memory_kinds(self) -> tuple[str, ...]: ... - @property - def device_kind(self) -> str: ... +def json_to_pprof_profile(arg: str, /) -> bytes: + """Encodes the JSON representation of a pprof Profile into its binary protocol buffer encoding.""" -class Sharding: ... +def pprof_profile_to_json(arg: bytes, /) -> str: + """Decodes an uncompressed pprof Profile protocol buffer into a JSON representation""" -class NamedSharding(Sharding): - def __init__( - self, - mesh: Any, - spec: Any, - *, - memory_kind: str | None = None, - _logical_device_ids: tuple[int, ...] | None = None, - ): ... - mesh: Any - spec: Any - _memory_kind: str | None - _internal_device_list: DeviceList - _logical_device_ids: tuple[int, ...] | None - -class SingleDeviceSharding(Sharding): - def __init__(self, device: Device, *, memory_kind: str | None = None): ... - _device: Device - _memory_kind: str | None - _internal_device_list: DeviceList - -class PmapSharding(Sharding): - def __init__( - self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec - ): ... - devices: list[Any] - sharding_spec: pmap_lib.ShardingSpec - _internal_device_list: DeviceList - -class GSPMDSharding(Sharding): - def __init__( +class CompileOnlyPyClient(Client): + def compile( self, - devices: Sequence[Device], - op_sharding: OpSharding | HloSharding, - *, - memory_kind: str | None = None, - _device_list: DeviceList | None = None, - ): ... - _devices: tuple[Device, ...] - _hlo_sharding: HloSharding - _memory_kind: str | None - _internal_device_list: DeviceList - -class PjitFunction: - def __call__(self, *args, **kwargs) -> Any: ... - -class PjitFunctionCache: - def __init__(self, capacity: int = ...): ... - def __getstate__(self) -> Any: ... - def __setstate__(self, Any): ... - def size(self) -> int: ... 
- def capacity(self) -> int: ... - def clear(self): ... - @staticmethod - def clear_all(): ... - -def pjit( - function_name: str, - fun: Callable | None, - cache_miss: Callable, - static_argnums: Sequence[int], - static_argnames: Sequence[str], - global_cache_key: Any, - pytree_registry: pytree.PyTreeRegistry, - shard_arg_fallback: Callable, - cache: PjitFunctionCache | None = ..., -) -> PjitFunction: ... + computation: object, + executable_devices: DeviceList, + compile_options: CompileOptions = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., + ) -> Executable: ... -class WeakrefLRUCacheInfo: - @property - def hits(self) -> int: ... - @property - def misses(self) -> int: ... +class DeviceTopology: + def _make_compile_only_devices(self) -> list[Device]: ... @property - def maxsize(self) -> int: ... + def platform(self) -> str: ... @property - def currsize(self) -> int: ... + def platform_version(self) -> str: ... + def serialize(self) -> bytes: ... + def __getattr__(self, arg: str, /) -> object: ... -class WeakrefLRUCache: - def __call__(self, weakref_key: Any, *args, **kwargs) -> Any: ... - def cache_keys(self) -> list[Any]: ... - def cache_info(self) -> WeakrefLRUCacheInfo: ... - def cache_clear(self): ... +class TransferServerInterfaceFactory: + pass + +class Executable: + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[PjRtLayout]: ... + def get_output_layouts(self) -> list[PjRtLayout]: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> bytes: ... + def cost_analysis( + self, + ) -> dict[str, str | bool | int | list[int] | float]: ... def is_asan() -> bool: ... def is_msan() -> bool: ... def is_tsan() -> bool: ... def is_sanitized() -> bool: ... 
+def batched_device_put( + aval: object, + sharding: object, + xs: Sequence[object], + devices: Sequence[Device], + committed: bool = True, + force_copy: bool = False, + host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, + enable_x64: bool | None = None, +) -> object: ... +def reorder_shards( + x: Array, dst_sharding: object, array_copy_semantics: ArrayCopySemantics +) -> Array: ... +def batched_block_until_ready(arg: Sequence[object], /) -> None: ... +def check_and_canonicalize_memory_kind( + memory_kind: object | None, device_list: DeviceList +) -> object: ... -class TransferConnection: - def address(self) -> str: ... - def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... - -class TransferServer: - def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... - def connect(self, address: str) -> TransferConnection: ... - -def start_transfer_server( - client: Client, - address: str = "", - transport_addresses: list[str] = [], - max_num_parallel_copies: int = 0, - transfer_size: int = 0, -) -> TransferServer: ... - -class TransferServerInterfaceFactory: ... - -def make_transfer_server_interface_factory( - transfer_size: int = 0, - cross_host_transfer_timeout_seconds: int = 0, - distributed_client: DistributedRuntimeClient | None = None, - socket_address: str = "", - transport_addresses: list[str] = [], -) -> TransferServerInterfaceFactory: ... +ifrt_version_number: int = 33 def approx_top_k_reduction_output_size( input_size: int, rank: int, top_k: int, recall_target: float, - aggregate_to_topk: bool | None = ..., - input_size_override: int | None = ..., + aggregate_to_topk: bool = True, + input_size_override: int = -1, ) -> tuple[int, int]: ... def get_internal_device_put_info() -> dict[str, int]: ... class UnconstrainedSingleton: def __repr__(self) -> str: ... - def __reduce__(self) -> Any: ... + def __reduce__(self) -> str: ... 
-UNCONSTRAINED_PARTITION: UnconstrainedSingleton +UNCONSTRAINED_PARTITION: UnconstrainedSingleton = ... -class PartitionSpec: - def __init__(self, *partitions, unreduced: Set[Any] | None = None): ... - def __hash__(self): ... - def __eq__(self, other): ... - _HAS_DYNAMIC_ATTRIBUTES: bool = ... +def canonicalize_partition(arg: object, /) -> object: ... -def canonicalize_partition(partition: Any) -> Any: ... +class PartitionSpec(Any): + def __init__( + self, *partitions, unreduced: object = ..., reduced: object = ... + ) -> None: ... + @property + def _partitions(self) -> tuple: ... + @property + def unreduced(self) -> frozenset: ... + @property + def reduced(self) -> frozenset: ... + def __eq__(self, arg: object) -> bool: ... + def __hash__(self) -> int: ... -def set_typed_int_type(t: type) -> None: ... -def set_typed_float_type(t: type) -> None: ... -def set_typed_complex_type(t: type) -> None: ... -def set_typed_ndarray_type(t: type) -> None: ... +def set_typed_int_type(arg: object, /) -> None: ... +def set_typed_float_type(arg: object, /) -> None: ... +def set_typed_complex_type(arg: object, /) -> None: ... +def set_typed_ndarray_type(arg: object, /) -> None: ... diff --git a/jaxlib/_jax/config.pyi b/jaxlib/_jax/config.pyi index 72bfd1996278..85d3ba4c6bdb 100644 --- a/jaxlib/_jax/config.pyi +++ b/jaxlib/_jax/config.pyi @@ -1,23 +1,22 @@ -# Copyright 2024 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== from typing import Any, Generic, TypeVar -unset: object +unset: object = ... -_T = TypeVar('_T') +_T = TypeVar("_T") class Config(Generic[_T]): def __init__( @@ -25,18 +24,17 @@ class Config(Generic[_T]): name: str, value: _T, *, - include_in_jit_key: bool = False, - include_in_trace_context: bool = False - ): ... - @property - def name(self) -> str: ... - + include_in_jit_key: bool = ..., + include_in_trace_context: bool = ..., + ) -> None: ... @property def value(self) -> _T: ... + @property + def name(self) -> str: ... def get_local(self) -> Any: ... def get_global(self) -> _T: ... - def set_local(self, value: Any) -> None: ... - def swap_local(self, value: Any) -> Any: ... - def set_global(self, value: _T) -> None: ... + def set_local(self, value: Any | None) -> None: ... + def swap_local(self, value: Any | None) -> Any: ... + def set_global(self, value: Any | None) -> None: ... -def trace_context() -> Any: ... +def trace_context() -> tuple: ... diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi index b92575e77c96..0536647333ce 100644 --- a/jaxlib/_jax/ffi.pyi +++ b/jaxlib/_jax/ffi.pyi @@ -1,47 +1,55 @@ -# Copyright 2025 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================== import enum -from typing import Any +import numpy +import typing_extensions class Buffer: @property - def dtype(self) -> Any: ... + def dtype(self) -> numpy.dtype: ... @property def ndim(self) -> int: ... @property - def shape(self) -> tuple[int, ...]: ... + def shape(self) -> tuple: ... @property def writeable(self) -> bool: ... - def __array__(self, dtype: Any = None, copy: bool | None = None) -> Any: ... - def __cuda_array_interface__(self) -> Any: ... + def __array__( + self, dtype: object | None = None, copy: object | None = None + ) -> numpy.ndarray: ... + @property + def __cuda_array_interface__(self) -> dict: ... def __dlpack__( self, - stream: Any = None, - max_version: Any = None, - dl_device: Any = None, - copy: Any = None, - ) -> Any: ... - def __dlpack_device__(self) -> tuple[int, int]: ... + stream: object | None = None, + max_version: object | None = None, + dl_device: object | None = None, + copy: object | None = None, + ) -> typing_extensions.CapsuleType: ... + def __dlpack_device__(self) -> tuple: ... + +class ExecutionStage(enum.Enum): + INSTANTIATE = 0 + + PREPARE = 1 -class ExecutionStage(enum.IntEnum): - INSTANTIATE = ... - PREPARE = ... - INITIALIZE = ... - EXECUTE = ... + INITIALIZE = 2 + + EXECUTE = 3 class ExecutionContext: + @property def stage(self) -> ExecutionStage: ... + @property def stream(self) -> int: ... diff --git a/jaxlib/_jax/guard_lib.pyi b/jaxlib/_jax/guard_lib.pyi index 7f8896a4f75a..59ced73369ca 100644 --- a/jaxlib/_jax/guard_lib.pyi +++ b/jaxlib/_jax/guard_lib.pyi @@ -1,46 +1,64 @@ -# Copyright 2024 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -from typing import Any +import enum -class TransferGuardLevel: - ALLOW: Any - LOG: Any - DISALLOW: Any - LOG_EXPLICIT: Any - DISALLOW_EXPLICIT: Any +class TransferGuardLevel(enum.Enum): + ALLOW = 0 -class GarbageCollectionGuardLevel: - ALLOW: Any - LOG: Any - FATAL: Any + LOG = 1 -class GuardState: - host_to_device: TransferGuardLevel | None - device_to_device: TransferGuardLevel | None - device_to_host: TransferGuardLevel | None + DISALLOW = 2 + + LOG_EXPLICIT = 3 + + DISALLOW_EXPLICIT = 4 + +class GarbageCollectionGuardLevel(enum.Enum): + ALLOW = 0 - explicit_device_put: bool - explicit_device_get: bool + LOG = 1 - garbage_collect_array: GarbageCollectionGuardLevel | None + FATAL = 2 + +class GuardState: + @property + def host_to_device(self) -> TransferGuardLevel | None: ... + @host_to_device.setter + def host_to_device(self, arg: TransferGuardLevel | None) -> None: ... + @property + def device_to_device(self) -> TransferGuardLevel | None: ... + @device_to_device.setter + def device_to_device(self, arg: TransferGuardLevel | None) -> None: ... + @property + def device_to_host(self) -> TransferGuardLevel | None: ... + @device_to_host.setter + def device_to_host(self, arg: TransferGuardLevel | None) -> None: ... + @property + def explicit_device_put(self) -> bool: ... + @explicit_device_put.setter + def explicit_device_put(self, arg: bool, /) -> None: ... + @property + def explicit_device_get(self) -> bool: ... 
+ @explicit_device_get.setter + def explicit_device_get(self, arg: bool, /) -> None: ... + @property + def garbage_collect_array(self) -> GarbageCollectionGuardLevel | None: ... + @garbage_collect_array.setter + def garbage_collect_array( + self, arg: GarbageCollectionGuardLevel | None + ) -> None: ... def global_state() -> GuardState: ... def thread_local_state() -> GuardState: ... - -class _TestingScopedLogSink: - def __enter__(self) -> _TestingScopedLogSink: ... - def __exit__(self, *args, **kwargs) -> None: ... - def logs(self) -> list[str]: ... diff --git a/jaxlib/_jax/hlo_sharding_util.pyi b/jaxlib/_jax/hlo_sharding_util.pyi new file mode 100644 index 000000000000..bad23a1a736b --- /dev/null +++ b/jaxlib/_jax/hlo_sharding_util.pyi @@ -0,0 +1,21 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +from .import HloSharding as _HloSharding + +def PartiallyReplicateTiledShardingOnDims( + sharding: _HloSharding, dims: Sequence[int], / +) -> _HloSharding: ... diff --git a/jaxlib/_jax/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi index 5e426b070c21..ee74bc4bf877 100644 --- a/jaxlib/_jax/ifrt_programs.pyi +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -1,45 +1,52 @@ -# Copyright 2024 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -from typing import Any from collections.abc import Sequence +from typing import Any, overload -from jaxlib import _jax +from .import ( + CompileOptions as _CompileOptions, + DeviceList as _DeviceList, + Device as _Device, +) +import typing_extensions -class Program: ... +class Program: + pass -class CompileOptions: ... - -def make_hlo_program(mlir_module: str | bytes) -> Program: ... +class CompileOptions: + pass +@overload +def make_hlo_program(mlir_module: str) -> Program: ... +@overload +def make_hlo_program(mlir_module: bytes) -> Program: ... def make_colocated_python_program( - name : str, + name: str, picked_function: bytes, - devices: Sequence[_jax.Device] | _jax.DeviceList, + devices: Sequence[_Device] | _DeviceList, input_avals: Sequence[Any], output_avals: Sequence[Any], ) -> Program: ... - -def make_plugin_program(data: str | bytes) -> Program: ... - -def make_colocated_python_compile_options() -> CompileOptions: ... - +@overload +def make_plugin_program(data: str) -> Program: ... +@overload +def make_plugin_program(data: bytes) -> Program: ... def make_xla_compile_options( - compile_options: _jax.CompileOptions, - executable_devices: _jax.DeviceList, - host_callbacks: Sequence[Any] + options: _CompileOptions, + executable_devices: Sequence[_Device], + host_callbacks: Sequence[typing_extensions.CapsuleType], ) -> CompileOptions: ... - +def make_colocated_python_compile_options() -> CompileOptions: ... 
def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/_jax/jax_jit.pyi b/jaxlib/_jax/jax_jit.pyi index c542fba7792b..07e02d4731c8 100644 --- a/jaxlib/_jax/jax_jit.pyi +++ b/jaxlib/_jax/jax_jit.pyi @@ -1,62 +1,74 @@ -# Copyright 2021 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== from collections.abc import Callable, Sequence -from typing import Any +from .config import Config as _Config +from .pytree import ( + PyTreeDef as _PyTreeDef, + PyTreeRegistry as _PyTreeRegistry, +) +import numpy -from jaxlib import _jax -import numpy as np - -from . import pytree - -Client = _jax.Client -Device = _jax.Device - -def set_disable_jit_state(config: _jax.config.Config) -> None: ... -def set_enable_x64_state(config: _jax.config.Config) -> None: ... -def set_post_hook_state(config: _jax.config.Config) -> None: ... +def set_disable_jit_state(config: _Config) -> None: ... +def set_enable_x64_state(config: _Config) -> None: ... +def set_post_hook_state(config: _Config) -> None: ... def set_thread_local_state_initialization_callback( - function: Callable[[], None], -): ... + f: Callable[[], None], +) -> None: ... -class ArgSignature: - dtype: np.dtype - shape: tuple[int, ...] - weak_type: bool +class PyArgSignature: + @property + def dtype(self) -> numpy.dtype: ... 
+ @property + def shape(self) -> tuple[int, ...]: ... + @property + def weak_type(self) -> bool: ... -def _ArgSignatureOfValue( - __arg: Any, __jax_enable_x64: bool -) -> ArgSignature: ... +def _ArgSignatureOfValue(arg0: object, arg1: bool, /) -> PyArgSignature: ... class ArgumentSignature: - static_args: Sequence[Any] - static_arg_names: Sequence[str] - dynamic_arg_names: Sequence[str] - dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] - - def __eq__(self, value, /): ... - def __ne__(self, value, /): ... - def __hash__(self, /): ... - def __str__(self): ... - def __repr__(self): ... + @property + def static_args(self) -> list[object]: ... + @property + def static_arg_names(self) -> list[str]: ... + @property + def dynamic_arg_names(self) -> list[str]: ... + @property + def dynamic_arg_treedefs(self) -> Sequence[_PyTreeDef]: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def __hash__(self) -> int: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... def parse_arguments( - positional_args: Sequence[Any], - keyword_args: Sequence[Any], + positional_args: Sequence[object], + keyword_args: Sequence[object], kwnames: tuple[str, ...], static_argnums: Sequence[int], static_argnames: Sequence[str], - pytree_registry: pytree.PyTreeRegistry, -) -> tuple[ArgumentSignature, Sequence[Any]]: ... + pytree_registry: _PyTreeRegistry, +) -> tuple[ArgumentSignature, list[object]]: + """Parses the arguments to a function as jax.jit would. + + Returns a ArgumentSignature and the flattened dynamic arguments. + + Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. 
+ """ diff --git a/jaxlib/_jax/mlir.pyi b/jaxlib/_jax/mlir.pyi index 594ed67587e0..6730a9773160 100644 --- a/jaxlib/_jax/mlir.pyi +++ b/jaxlib/_jax/mlir.pyi @@ -1,34 +1,54 @@ -# Copyright 2021 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -from . import XlaComputation +from typing import overload + +from .import XlaComputation as _XlaComputation def hlo_to_stablehlo(computation: bytes) -> bytes: ... -def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def xla_computation_to_mlir_module(computation: _XlaComputation) -> str: ... +@overload +def mlir_module_to_xla_computation( + mlir_module: bytes, use_tuple_args: bool = ..., return_tuple: bool = ... +) -> _XlaComputation: ... +@overload def mlir_module_to_xla_computation( - mlir_module: bytes | str, - use_tuple_args: bool = ..., - return_tuple: bool = ..., -) -> XlaComputation: ... -def mhlo_to_stablehlo(mlir_module: bytes | str) -> bytes: ... -def stablehlo_to_mhlo(mlir_module: bytes | str) -> bytes: ... -def serialize_portable_artifact(mlir_module: bytes | str, target: str, use_mixed_serialization: bool = True) -> bytes: ... + mlir_module: str, use_tuple_args: bool = ..., return_tuple: bool = ... +) -> _XlaComputation: ... +@overload +def mhlo_to_stablehlo(mlir_module: bytes) -> bytes: ... 
+@overload +def mhlo_to_stablehlo(mlir_module: str) -> bytes: ... +@overload +def serialize_portable_artifact( + mlir_module: bytes, target: str, use_mixed_serialization: bool = False +) -> bytes: ... +@overload +def serialize_portable_artifact( + mlir_module: str, target: str, use_mixed_serialization: bool = False +) -> bytes: ... def deserialize_portable_artifact(mlir_module: bytes) -> str: ... def refine_polymorphic_shapes( - mlir_module: bytes | str, - enable_shape_assertions: bool = ..., - validate_static_shapes: bool = ..., - enable_shardy: bool = ..., -) -> bytes: ... + mlir_module: bytes, + enable_shape_assertions: bool = True, + validate_static_shapes: bool = True, + enable_shardy: bool = False, +) -> bytes: + """Refines the dynamic shapes for a module. + + The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + """ diff --git a/jaxlib/_jax/pmap_lib.pyi b/jaxlib/_jax/pmap_lib.pyi index 3e26e7e1da84..679e43f1cddc 100644 --- a/jaxlib/_jax/pmap_lib.pyi +++ b/jaxlib/_jax/pmap_lib.pyi @@ -1,84 +1,102 @@ -# Copyright 2021 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================== +from collections.abc import Callable, Iterable, Sequence import inspect +import types from typing import Any -from collections.abc import Callable, Sequence, Iterable -from . import pytree - -_AvalDimSharding = Any -_MeshDimAssignment = Any +from .pytree import PyTreeRegistry as _PyTreeRegistry class NoSharding: def __init__(self) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... + def __eq__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... class Chunked: + def __init__(self, arg: Sequence[int], /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... @property - def chunks(self) -> Sequence[int]: ... - def __init__(self, __chunks: Sequence[int]) -> None: ... + def chunks(self) -> list[int]: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... + def __eq__(self, arg: object, /) -> bool: ... class Unstacked: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... @property def size(self) -> int: ... - def __init__(self, __sz: int) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... + def __eq__(self, arg: object, /) -> bool: ... class ShardedAxis: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... @property def axis(self) -> int: ... - def __init__(self, __axis: int) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... + def __eq__(self, arg: object, /) -> bool: ... class Replicated: + def __init__(self, arg: int, /) -> None: ... + def __getstate__(self) -> tuple: ... 
+ def __setstate__(self, arg: tuple, /) -> None: ... @property def replicas(self) -> int: ... - def __init__(self, __replicas: int) -> None: ... def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... + def __eq__(self, arg: object, /) -> bool: ... -class ShardingSpec: - def __init__(self, - sharding: Iterable[_AvalDimSharding], - mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... +class ShardingSpec(Any): + def __init__(self, sharding: Iterable, mesh_mapping: Iterable) -> None: ... + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... @property - def sharding(self) -> tuple[_AvalDimSharding, ...]: ... + def sharding(self) -> tuple[NoSharding | Chunked | Unstacked, ...]: ... @property - def mesh_mapping(self) -> tuple[_MeshDimAssignment]: ... - def __eq__(self, __other: Any) -> bool: ... + def mesh_mapping(self) -> tuple[ShardedAxis | Replicated, ...]: ... + def __eq__(self, arg: object, /) -> bool: ... def __hash__(self) -> int: ... - _HAS_DYNAMIC_ATTRIBUTES = True - class PmapFunction: - def __call__(self, *args, **kwargs) -> Any: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, Any): ... - __signature__: inspect.Signature + def __call__(self, /, *args, **kwargs): + """Call self as a function.""" + + def __get__(self, instance, owner=None, /): + """Return an attribute of instance, which is of type owner.""" + __vectorcalloffset__: types.MemberDescriptorType = ... + + @property + def __signature__(self) -> inspect.Signature: ... + @property + def _cache_miss(self) -> Callable: ... + def __getstate__(self) -> dict: ... + def __setstate__(self, arg: dict, /) -> None: ... + @property def _cache_size(self) -> int: ... def _cache_clear(self) -> None: ... def _debug_cache_keys(self) -> str: ... 
-def pmap(fun: Callable[..., Any], - cache_miss: Callable[..., Any], - static_argnums: Sequence[int], - shard_arg_fallback: Callable[..., Any], - pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ... +def pmap( + fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: _PyTreeRegistry, +) -> PmapFunction: ... diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi index 9517a3ba6dcd..1cd5e2fe5515 100644 --- a/jaxlib/_jax/pytree.pyi +++ b/jaxlib/_jax/pytree.pyi @@ -1,149 +1,183 @@ -# Copyright 2021 The JAX Authors +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-# ============================================================================== -from builtins import tuple as Tuple -from typing import Any, TypeVar from collections.abc import Callable, Hashable, Iterable, Sequence +from typing import Any, TypeVar + +version: int = 3 _T = TypeVar("_T") -version: int +_Children = TypeVar("_Children", bound=Iterable[Any]) + +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=tuple[Any, Any]) + +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[tuple[Any, Any]]) + +_KeyPath = TypeVar("_KeyPath", bound=tuple[Any, ...]) + +_AuxData = TypeVar("_AuxData", bound=Hashable) class PyTreeRegistry: def __init__( self, - *, - enable_none: bool = ..., - enable_tuple: bool = ..., - enable_namedtuple: bool = ..., - enable_list: bool = ..., - enable_dict: bool = ... - ): ... + enable_none: bool = True, + enable_tuple: bool = True, + enable_namedtuple: bool = True, + enable_list: bool = True, + enable_dict: bool = True, + ) -> None: ... def flatten( self, - tree: Any, - leaf_predicate: Callable[[Any], bool] | None = ..., - ) -> Tuple[list[Any], PyTreeDef]: ... + tree: object | None, + leaf_predicate: Callable[[Any], bool] | None = None, + ) -> tuple[list[Any], PyTreeDef]: ... def flatten_one_level( - self, tree: Any - ) -> Tuple[Iterable[Any], Any] | None: ... + self, tree: object | None + ) -> tuple[Iterable[Any], Any] | None: ... def flatten_one_level_with_keys( - self, tree: Any - ) -> Tuple[Iterable[_KeyLeafPair], Any] | None: ... + self, tree: object | None + ) -> tuple[Iterable[_KeyLeafPair], Any] | None: ... def flatten_with_path( self, - tree: Any, - leaf_predicate: Callable[[Any, Any], bool] | None = ..., - ) -> Tuple[list[Tuple[_KeyPath, Any]], PyTreeDef]: ... + tree: object | None, + leaf_predicate: Callable[[Any, Any], bool] | None = None, + ) -> tuple[list[tuple[_KeyPath, Any]], PyTreeDef]: ...
def register_node( self, - __type: type[_T], - to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], + type: type[_T], + to_iterable: Callable[[_T], tuple[_Children, _AuxData]], from_iterable: Callable[[_AuxData, _Children], _T], to_iterable_with_keys: ( - Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None - ) = ..., + Callable[[_T], tuple[_KeyLeafPairs, _AuxData]] | None + ) = None, ) -> Any: ... def register_dataclass_node( - self, __type: type[_T], meta_fields: list[str], data_fields: list[str] + self, + type: type, + data_fields: Sequence[str], + meta_fields: Sequence[str], + /, ) -> Any: ... + def __reduce__(self) -> str: ... + +_default_registry: PyTreeRegistry = ... def default_registry() -> PyTreeRegistry: ... -def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ... -def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... +def treedef_tuple( + registry: PyTreeRegistry, arg0: Sequence[PyTreeDef], / +) -> PyTreeDef: ... +def all_leaves(arg0: PyTreeRegistry, arg1: Iterable, /) -> bool: ... + +class PyTreeDef: + def unflatten(self, arg: Iterable[Any], /) -> Any: ... + def flatten_up_to(self, tree: object | None) -> list: ... + def compose(self, arg: PyTreeDef, /) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Callable[[_T], Any] | None, + leaves: Iterable[Any], + /, + ) -> Any: + """Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf at leaves""" + + def from_iterable_tree(self, arg: object, /) -> object: ... + def children(self) -> list[PyTreeDef]: ... + @property + def num_leaves(self) -> int: ... + @property + def num_nodes(self) -> int: ... + def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... + def __ne__(self, arg: object, /) -> bool: ... + def __hash__(self) -> int: ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ...
+ def node_data(self) -> tuple[type, Any] | None: + """Returns None if a leaf-pytree, else (type, node_data)""" + + @staticmethod + def from_node_data_and_children( + registry: PyTreeRegistry, + node_data: tuple[type, Any] | None, + children: Iterable[PyTreeDef], + ) -> PyTreeDef: + """Reconstructs a pytree from `node_data()` and `children()`.""" + + def __getstate__(self) -> object: ... + def __setstate__(self, arg: object, /) -> None: ... class SequenceKey(Hashable): - idx: int - __match_args__: Tuple = ... - def __init__(self, idx: int): ... + def __init__(self, idx: int) -> None: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... def __hash__(self) -> int: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - def __eq__(self, __other: Any) -> bool: ... + @property + def idx(self) -> int: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... class DictKey(Hashable): - key: Hashable - __match_args__: Tuple = ... - def __init__(self, key: Hashable): ... + def __init__(self, key: object) -> None: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... def __hash__(self) -> int: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - def __eq__(self, __other: Any) -> bool: ... + @property + def key(self) -> object: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... class GetAttrKey(Hashable): - name: str - __match_args__: Tuple = ... - def __init__(self, name: str): ... + def __init__(self, name: str) -> None: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... def __hash__(self) -> int: ...
- def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - def __eq__(self, __other: Any) -> bool: ... + @property + def name(self) -> str: ... + + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" + + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... class FlattenedIndexKey(Hashable): - key: int - __match_args__: Tuple = ... - def __init__(self, key: int): ... + def __init__(self, key: int) -> None: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... + def __eq__(self, arg: object, /) -> bool: ... def __hash__(self) -> int: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - def __eq__(self, __other: Any) -> bool: ... + @property + def key(self) -> int: ... -class PyTreeDef: - def unflatten(self, __leaves: Iterable[Any]) -> Any: ... - def flatten_up_to(self, __xs: Any) -> list[Any]: ... - def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... - def walk( - self, - __f_node: Callable[[Any, Any], Any], - __f_leaf: Callable[[_T], Any] | None, - leaves: Iterable[Any], - ) -> Any: ... - def from_iterable_tree(self, __xs: Any): ... - def node_data(self) -> Tuple[type, Any] | None: ... - def children(self) -> list[PyTreeDef]: ... - @staticmethod - def from_node_data_and_children( - registry: PyTreeRegistry, - node_data: Tuple[type, Any] | None, - children: Iterable[PyTreeDef], - ) -> PyTreeDef: ... + __match_args__: tuple = ... + """(arg: object, /) -> tuple""" - num_leaves: int - num_nodes: int - def __repr__(self) -> str: ... - def __eq__(self, __other: Any) -> bool: ... - def __ne__(self, __other: Any) -> bool: ... - def __hash__(self) -> int: ... - def __getstate__(self) -> Any: ... - def __setstate__(self, state: Any): ... - def serialize_using_proto(self) -> bytes: ... - @staticmethod - def deserialize_using_proto( - registry: PyTreeRegistry, data: bytes - ) -> PyTreeDef: ...
- -_Children = TypeVar("_Children", bound=Iterable[Any]) -_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) -_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[Tuple[Any, Any]]) -_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) -_AuxData = TypeVar("_AuxData", bound=Hashable) + def __getstate__(self) -> tuple: ... + def __setstate__(self, arg: tuple, /) -> None: ... diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc index 875b93466357..b0a2d7788cb4 100644 --- a/jaxlib/pytree.cc +++ b/jaxlib/pytree.cc @@ -1779,8 +1779,6 @@ void BuildPytreeSubmodule(nb::module_& m) { ") -> PyTreeDef" // clang-format on )); - // TODO(slebedev): Remove once we migrate JAX to use the new name. - pytree.attr("tuple") = pytree.attr("treedef_tuple"); pytree.def("all_leaves", &PyTreeDef::AllLeaves); nb::class_ treedef(pytree, "PyTreeDef",