Skip to content

Commit

Permalink
Add minimal remat policy for flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
michelle-yooh committed Apr 3, 2024
1 parent 5575702 commit 012e0a4
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions MaxText/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def __call__(self,
offload_src="device", offload_dst="pinned_host")
elif cfg.remat_policy == 'minimal_offloaded':
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
elif cfg.remat_policy == 'minimal_flash':
policy = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
jax.checkpoint_policies.save_only_these_names('context',),
)
else:
assert (
cfg.remat_policy == 'full'
Expand Down

0 comments on commit 012e0a4

Please sign in to comment.