Skip to content

Commit

Permalink
Use int.bit_length() in next_power_of_2.
Browse files Browse the repository at this point in the history
- Added docstring explaining its behaviour.
- Check for negative inputs.

See https://docs.python.org/3/library/stdtypes.html#int.bit_length.

PiperOrigin-RevId: 615731376
  • Loading branch information
chr1sj0nes authored and jax authors committed Mar 14, 2024
1 parent 694df0d commit 64bd95d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""Pallas utility functions."""
import math

from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
Expand Down Expand Up @@ -45,9 +45,10 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:


def next_power_of_2(x: int) -> int:
if x == 0:
return 1
return int(2 ** math.ceil(math.log2(x)))
"""Returns the next power of two greater than or equal to `x`."""
if x < 0:
raise ValueError("`next_power_of_2` requires a non-negative integer.")
return 1 if x == 0 else 2 ** (x - 1).bit_length()


def pattern_match_scan_to_fori_loop(
Expand Down

0 comments on commit 64bd95d

Please sign in to comment.