diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index df77ca3317ef..72b61af668b7 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -123,6 +123,14 @@ def save_and_offload_only_these_names( offload_src, offload_dst): names_which_can_be_saved = set(names_which_can_be_saved) names_which_can_be_offloaded = set(names_which_can_be_offloaded) + intersection = names_which_can_be_saved.intersection(names_which_can_be_offloaded) + if intersection: + raise ValueError( + "The names should be exclusive and should not intersect in" + " `names_which_can_be_saved` and `names_which_can_be_offloaded`. Got" + f" names_which_can_be_saved={names_which_can_be_saved}," + f" names_which_can_be_offloaded={names_which_can_be_offloaded} and the" + f" intersection={intersection}") def policy(prim, *_, **params): if prim is name_p and params['name'] in names_which_can_be_saved: return pe.Saveable diff --git a/tests/memories_test.py b/tests/memories_test.py index d3784ab59209..5c0cff5039d2 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1124,6 +1124,12 @@ def test_remat_scan_jaxpr_offloadable(self): s = NamedSharding(mesh, P("x")) inp = jax.device_put(np_inp, s) + with self.assertRaisesRegex( + ValueError, "The names should be exclusive and should not intersect"): + jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=["y"], names_which_can_be_offloaded=["y", "w"], + offload_src="device", offload_dst="pinned_host") + policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"], offload_src='device', offload_dst='pinned_host')