Skip to content

Commit

Permalink
Use int.bit_length() in next_power_of_2.
Browse files Browse the repository at this point in the history
- Added docstring explaining its behaviour.
- Check for negative inputs.

See https://docs.python.org/3/library/stdtypes.html#int.bit_length.

PiperOrigin-RevId: 615721128
  • Loading branch information
chr1sj0nes authored and jax authors committed Mar 14, 2024
1 parent 4a35c12 commit c5bee4b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Pallas utility functions."""

from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
Expand Down Expand Up @@ -44,7 +45,10 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:


def next_power_of_2(x: int) -> int:
return 2**x.bit_length()
"""Returns the next power of two greater than or equal to `x`."""
if x < 0:
raise ValueError("`next_power_of_2` requires a non-negative integer.")
return 1 if (x == 0) else (2 ** (x - 1).bit_length())


def pattern_match_scan_to_fori_loop(
Expand Down

0 comments on commit c5bee4b

Please sign in to comment.