diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 1443bef80d09..3229288d3ad1 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Pallas utility functions.""" +import math from jax import lax from jax._src import core as jax_core from jax._src.util import split_list @@ -44,7 +45,9 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: def next_power_of_2(x: int) -> int: - return 2**x.bit_length() + if x == 0: + return 1 + return int(2 ** math.ceil(math.log2(x))) def pattern_match_scan_to_fori_loop(