From 9ff9c952b802776ade3d65a8ff7937894b88bbe0 Mon Sep 17 00:00:00 2001
From: Abhinav Goel
Date: Tue, 30 Jan 2024 11:35:44 -0800
Subject: [PATCH] Add support for checkpoint policies in MoE models

---
 praxis/layers/glam.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/praxis/layers/glam.py b/praxis/layers/glam.py
index 220a330a..023d8442 100644
--- a/praxis/layers/glam.py
+++ b/praxis/layers/glam.py
@@ -24,6 +24,7 @@
 from praxis.layers import normalizations
 from praxis.layers import transformer_models
 from praxis.layers import transformers
+from praxis.layers import AutodiffCheckpointType
 
 LanguageModelType = transformer_models.LanguageModelType
 
@@ -211,6 +212,7 @@ def GlamUniTransformerLmHParams(
     num_pipeline_stages=1,
     num_pipeline_microbatches=1,
     model_type=LanguageModelType.CAUSAL,
+    checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
 ) -> pax_fiddle.Config[transformer_models.TransformerLm]:
   """Common setup for GLaM Decoder-only Transformer Model.
 
@@ -263,6 +265,7 @@
     num_pipeline_microbatches: Number of pipeline microbatches.
     model_type: Type of the Language Model. Either `CAUSAL`, `PREFIX`, or
       `BIDIRECTIONAL`.
+    checkpoint_policy: Activation rematerialization (checkpointing) policy.
 
   Returns:
     A Params object to set up a StackedTransformer.
@@ -326,6 +329,7 @@
         unroll_in_decode=True,
         block=glam_p,
         x_times=num_blocks,
+        checkpoint_policy=checkpoint_policy,
     )
   else:
     assert num_blocks % num_pipeline_stages == 0
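
Usage sketch (editorial note, not part of the patch): the new
`checkpoint_policy` argument threads an `AutodiffCheckpointType` through to
the repeated stacked-transformer config, letting callers trade backward-pass
recompute for activation memory. A minimal sketch under the patch's own
imports; only the keyword arguments visible in the diff are shown, and the
function's other required (model-size) arguments are elided rather than
invented.

    from praxis.layers import glam
    from praxis.layers import transformer_models
    from praxis.layers import AutodiffCheckpointType

    # Build a GLaM decoder-only LM config that saves no intermediate
    # activations and instead rematerializes them in the backward pass.
    lm_p = glam.GlamUniTransformerLmHParams(
        # ... required model-size arguments elided ...
        num_pipeline_stages=1,
        num_pipeline_microbatches=1,
        model_type=transformer_models.LanguageModelType.CAUSAL,
        checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,  # new argument
    )

Defaulting the new argument to SAVE_NOTHING presumably matches the stacked
transformer's pre-existing rematerialization behavior, so the patch should be
behavior-preserving unless a caller opts into a different policy.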