From 012e0a442231d265cd986e88b76356da94968549 Mon Sep 17 00:00:00 2001 From: michelle-yooh Date: Wed, 3 Apr 2024 05:58:35 +0000 Subject: [PATCH] Add minimal remat policy for flash attention --- MaxText/layers/models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index d15f76444..86af3b0ca 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -234,6 +234,11 @@ def __call__(self, offload_src="device", offload_dst="pinned_host") elif cfg.remat_policy == 'minimal_offloaded': policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") + elif cfg.remat_policy == 'minimal_flash': + policy = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names('context',), + ) else: assert ( cfg.remat_policy == 'full'