Skip to content

Commit

Permalink
Raise an error if the names intersect in `save_and_offload_only_these…
Browse files Browse the repository at this point in the history
…_names` policy

PiperOrigin-RevId: 611666221
  • Loading branch information
yashk2810 authored and jax authors committed Mar 1, 2024
1 parent 32bb3b0 commit 48e6e0d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def save_and_offload_only_these_names(
offload_src, offload_dst):
names_which_can_be_saved = set(names_which_can_be_saved)
names_which_can_be_offloaded = set(names_which_can_be_offloaded)
intersection = names_which_can_be_saved.intersection(names_which_can_be_offloaded)
if intersection:
raise ValueError(
"The names should be exclusive and should not intersect in"
" `names_which_can_be_saved` and `names_which_can_be_offloaded`. Got"
f" names_which_can_be_saved={names_which_can_be_saved},"
f" names_which_can_be_offloaded={names_which_can_be_offloaded} and the"
f" intersection={intersection}")
def policy(prim, *_, **params):
if prim is name_p and params['name'] in names_which_can_be_saved:
return pe.Saveable
Expand Down
6 changes: 6 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,12 @@ def test_remat_scan_jaxpr_offloadable(self):
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)

with self.assertRaisesRegex(
ValueError, "The names should be exclusive and should not intersect"):
jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["y", "w"],
offload_src="device", offload_dst="pinned_host")

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
offload_src='device', offload_dst='pinned_host')
Expand Down

0 comments on commit 48e6e0d

Please sign in to comment.