Skip to content

Commit

Permalink
Cache the _check_sharding check in device_put. If aval and sharding…
Browse files Browse the repository at this point in the history
… are the same, no need to check multiple times

PiperOrigin-RevId: 626244240
  • Loading branch information
yashk2810 authored and jax authors committed Apr 19, 2024
1 parent 8fec8a6 commit 837f0bb
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import collections
from collections.abc import Generator, Hashable, Iterable, Sequence
from functools import partial
from functools import partial, lru_cache
import inspect
import math
import typing
Expand Down Expand Up @@ -2451,9 +2451,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:

# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
# to check if shardings are compatible with the input.
def _check_sharding(x, s):
@lru_cache(maxsize=2048)
def _check_sharding(aval, s):
if isinstance(s, Sharding):
aval = shaped_abstractify(x)
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
Expand Down Expand Up @@ -2494,7 +2494,7 @@ def device_put(
(src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
for leaf in tree_leaves(x):
_check_sharding(leaf, s=device)
_check_sharding(shaped_abstractify(leaf), s=device)
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
Expand All @@ -2503,7 +2503,7 @@ def device_put(
device_flat = flatten_axes("device_put device", treedef, device)
src_flat = flatten_axes("device_put source", treedef, src)
for x_leaf, device_leaf in zip(x_flat, device_flat):
_check_sharding(x_leaf, device_leaf)
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
out_flat = [
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
for xf, d, s in zip(x_flat, device_flat, src_flat)
Expand Down

0 comments on commit 837f0bb

Please sign in to comment.