Skip to content

Commit

Permalink
Reverts 168f30a
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615715466
  • Loading branch information
chr1sj0nes authored and jax authors committed Mar 14, 2024
1 parent 4a35c12 commit 2fd80e3
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Pallas utility functions."""
import math
from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
Expand Down Expand Up @@ -44,7 +45,9 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:


def next_power_of_2(x: int) -> int:
return 2**x.bit_length()
if x == 0:
return 1
return int(2 ** math.ceil(math.log2(x)))


def pattern_match_scan_to_fori_loop(
Expand Down

0 comments on commit 2fd80e3

Please sign in to comment.