
Merge pull request #48 from abhinavgoel95:moe_checkpoint_policies
PiperOrigin-RevId: 608838552
pax authors committed Feb 21, 2024
2 parents 8f72ea7 + 9ff9c95 commit c096897
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion praxis/layers/BUILD
@@ -312,12 +312,12 @@ pytype_strict_library(
deps = [
":activations",
":attentions",
":checkpoint_policy",
":embedding_softmax",
":normalizations",
":transformer_models",
":transformers",
# Implicit fiddle dependency.
"//praxis:base_layer",
"//praxis:pax_fiddle",
],
)
5 changes: 4 additions & 1 deletion praxis/layers/glam.py
@@ -16,10 +16,10 @@
"""Helper function to config GLaM models."""

import fiddle as fdl
from praxis import base_layer
from praxis import pax_fiddle
from praxis.layers import activations
from praxis.layers import attentions
from praxis.layers import checkpoint_policy
from praxis.layers import embedding_softmax
from praxis.layers import normalizations
from praxis.layers import transformer_models
@@ -211,6 +211,7 @@ def GlamUniTransformerLmHParams(
num_pipeline_stages=1,
num_pipeline_microbatches=1,
model_type=LanguageModelType.CAUSAL,
checkpoint_policy=checkpoint_policy.AutodiffCheckpointType.SAVE_NOTHING,
) -> pax_fiddle.Config[transformer_models.TransformerLm]:
"""Common setup for GLaM Decoder-only Transformer Model.
@@ -263,6 +264,7 @@ def GlamUniTransformerLmHParams(
num_pipeline_microbatches: Number of pipeline microbatches.
model_type: Type of the Language Model. Either `CAUSAL`, `PREFIX`, or
`BIDIRECTIONAL`.
checkpoint_policy: Activation remat policy when pipelining is disabled.
Returns:
A Params object to set up a StackedTransformer.
@@ -326,6 +328,7 @@ def GlamUniTransformerLmHParams(
unroll_in_decode=True,
block=glam_p,
x_times=num_blocks,
checkpoint_policy=checkpoint_policy,
)
else:
assert num_blocks % num_pipeline_stages == 0
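For reference, the default `AutodiffCheckpointType.SAVE_NOTHING` keeps no intermediate activations and rematerializes them during the backward pass; the new keyword is only threaded through in the non-pipelined branch, which is why the docstring scopes it to "when pipelining is disabled." The snippet below is a minimal, self-contained JAX sketch (not the Praxis API) contrasting a "save nothing" remat policy with a "save everything" policy; the toy `mlp_block` function and its shapes are illustrative assumptions, not code from this change.

import jax
import jax.numpy as jnp

def mlp_block(x, w1, w2):
  # Toy two-matmul block standing in for a transformer feed-forward layer.
  h = jax.nn.relu(x @ w1)
  return h @ w2

# SAVE_NOTHING-style policy: keep no intermediates; recompute everything in
# the backward pass, trading compute for activation memory.
remat_nothing = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.nothing_saveable)

# SAVE_EVERYTHING-style policy: keep all intermediates (no rematerialization).
remat_everything = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.everything_saveable)

x = jnp.ones((8, 16))
w1 = jnp.ones((16, 32))
w2 = jnp.ones((32, 16))

# Both variants produce the same gradients; they differ only in how much
# activation memory stays live for the backward pass.
g_nothing = jax.grad(lambda x: remat_nothing(x, w1, w2).sum())(x)
g_everything = jax.grad(lambda x: remat_everything(x, w1, w2).sum())(x)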
