From 0fee320451738166c8e596dc63a57a4673671576 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Sat, 13 Jan 2024 00:59:17 -0800 Subject: [PATCH] [MLPerf GPT3] Model Migration --- MaxText/configs/base.yml | 9 +- MaxText/configs/models/gpt3-175b.yml | 34 +++ MaxText/configs/models/gpt3-52k.yml | 34 +++ MaxText/configs/models/llama2-7b.yml | 2 +- MaxText/configs/models/mistral-7b.yml | 2 +- MaxText/convert_gpt3_ckpt_from_paxml.py | 230 +++++++++++++++ MaxText/generate_param_only_checkpoint.py | 4 +- MaxText/layers/attentions.py | 46 ++- MaxText/layers/gamma.py | 1 + MaxText/layers/gpt3.py | 337 ++++++++++++++++++++++ MaxText/layers/initializers.py | 2 + MaxText/layers/linears.py | 48 ++- MaxText/layers/llama2.py | 4 +- MaxText/layers/models.py | 43 ++- MaxText/llama_or_mistral_ckpt.py | 2 +- MaxText/max_utils.py | 4 +- MaxText/maxtext_utils.py | 12 - MaxText/optimizers.py | 144 +++++++++ MaxText/pyconfig.py | 7 +- MaxText/tests/gpt3_test.py | 110 +++++++ MaxText/train.py | 5 +- MaxText/train_compile.py | 3 +- end_to_end/test_gpt3.sh | 16 + 23 files changed, 1051 insertions(+), 48 deletions(-) create mode 100644 MaxText/configs/models/gpt3-175b.yml create mode 100644 MaxText/configs/models/gpt3-52k.yml create mode 100644 MaxText/convert_gpt3_ckpt_from_paxml.py create mode 100644 MaxText/layers/gpt3.py create mode 100644 MaxText/optimizers.py create mode 100644 MaxText/tests/gpt3_test.py create mode 100644 end_to_end/test_gpt3.sh diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 5e91281e7..57f893226 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -17,7 +17,7 @@ run_name: "" model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this! -rms_norm_epsilon: 1.e-06 +normalization_layer_epsilon: 1.e-06 ################################## CHECKPOINTING ################################## # Checkpointing makes the following choices in the following order, starting with (1): @@ -71,6 +71,9 @@ head_dim: 256 mlp_activations: ["relu"] dropout_rate: 0 logits_via_embedding: True # NOTE: this is True just for testing. +normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true +logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability + # proj, minimal, full, or none remat_policy: 'full' scan_layers: True @@ -185,6 +188,7 @@ gradient_clipping_threshold: 1.0 # AdamW optimizer parameters # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 +opt_type: "adamw" # one of "adam_pax" or "adamw" adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients. adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. @@ -199,7 +203,8 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds # Use iota operator in Embed use_iota_embed: False # use positional embedding -use_positional_embedding: False +use_untrainable_positional_embedding: False +trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size # Ahead of time Compilation (aka AOT) # Only set these arguments if you are running train_compile or loading a compiled train step. diff --git a/MaxText/configs/models/gpt3-175b.yml b/MaxText/configs/models/gpt3-175b.yml new file mode 100644 index 000000000..4b55b6533 --- /dev/null +++ b/MaxText/configs/models/gpt3-175b.yml @@ -0,0 +1,34 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for gpt3-175b + +base_emb_dim: 12288 +base_num_query_heads: 96 +base_num_kv_heads: 96 +base_mlp_dim: 49152 +base_num_decoder_layers: 96 +head_dim: 128 +trainable_position_size: 16384 +mlp_activations: ["gelu"] +vocab_size: 50304 +enable_dropout: False +logits_via_embedding: True +normalize_embedding_logits: False +logits_dot_in_fp32: False +normalization_layer_epsilon: 1.e-05 +use_iota_embed: True +fused_qkv: True +opt_type: "adam_pax" +decoder_block: "gpt3" diff --git a/MaxText/configs/models/gpt3-52k.yml b/MaxText/configs/models/gpt3-52k.yml new file mode 100644 index 000000000..d507ffc31 --- /dev/null +++ b/MaxText/configs/models/gpt3-52k.yml @@ -0,0 +1,34 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for gpt3-52k, i.e. a fake and small model for testing purpose only + +base_emb_dim: 16 +base_num_query_heads: 2 +base_num_kv_heads: 2 +base_mlp_dim: 64 +base_num_decoder_layers: 1 +head_dim: 8 +trainable_position_size: 2048 +mlp_activations: ["gelu"] +vocab_size: 1024 +enable_dropout: False +logits_via_embedding: True +normalize_embedding_logits: False +logits_dot_in_fp32: False +normalization_layer_epsilon: 1.e-05 +use_iota_embed: True +fused_qkv: True +opt_type: "adam_pax" +decoder_block: "gpt3" diff --git a/MaxText/configs/models/llama2-7b.yml b/MaxText/configs/models/llama2-7b.yml index 8f985a542..751cffa07 100644 --- a/MaxText/configs/models/llama2-7b.yml +++ b/MaxText/configs/models/llama2-7b.yml @@ -25,5 +25,5 @@ vocab_size: 32000 enable_dropout: False vocab_relative_path: "tokenizer.llama2" logits_via_embedding: False -rms_norm_epsilon: 1.0e-5 +normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" \ No newline at end of file diff --git a/MaxText/configs/models/mistral-7b.yml b/MaxText/configs/models/mistral-7b.yml index a60d28f9a..b1c9a8ed3 100644 --- a/MaxText/configs/models/mistral-7b.yml +++ b/MaxText/configs/models/mistral-7b.yml @@ -25,5 +25,5 @@ vocab_size: 32000 enable_dropout: False vocab_relative_path: "tokenizer.model" logits_via_embedding: False -rms_norm_epsilon: 1.0e-5 +normalization_layer_epsilon: 1.0e-5 decoder_block: "mistral" diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py new file mode 100644 index 000000000..9c3c63d35 --- /dev/null +++ b/MaxText/convert_gpt3_ckpt_from_paxml.py @@ -0,0 +1,230 @@ +""" + Copyright 2023 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +# pylint: disable=line-too-long +"""Convert weights from a paxml gpt3 model to a MaxText one. + +Test cmd for gpt3-52k: +python MaxText/convert_gpt3_ckpt_from_paxml.py \ + --paxml-ckpt-path=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state \ + --maxtext-model-name=gpt3-52k \ + --run-name=$RUN_NAME \ + --base-output-directory=$BASE_OUTPUT_DIR + +True cmd for gpt3-175b: + +The script is memory demanding, requires at least 250 GiB in cpu and cumulative TPU memory of all devices should be + above ~4.2 TiB (175 billion param * 4 byte/param * 3 (model var and 2 opt momentums) * 2 copies in converting) + +python MaxText/convert_gpt3_ckpt_from_paxml.py \ + --paxml-ckpt-path=gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 \ + --maxtext-model-name=gpt3-175b \ + --run-name=$RUN_NAME \ + --base-output-directory=$BASE_OUTPUT_DIR +""" +import max_utils +import optimizers +import pyconfig +import os +from jax import random +from jax.sharding import Mesh +from layers.models import Transformer +import checkpointing + +import numpy as np +import tensorstore as ts + +import sys +import jax +import gc +import max_logging +from psutil import Process +import argparse + +def fmt_size(num_bytes: int) -> str: + assert num_bytes > 0 + for unit in ["B", "KiB", "MiB", "GiB"]: + if num_bytes < 1024.0: + break + num_bytes /= 1024.0 + return f"{num_bytes:.2f} {unit}" + +def check_memory(): + """print out cpu/tpu memory.""" + cpu_bytes = Process().memory_info().rss + max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}") + for d in jax.local_devices(): + stats = d.memory_stats() + used = stats['bytes_in_use'] + limit = stats['bytes_limit'] + max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}") + + +def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name): + """convert ckpt.""" + + base_args = [ + '', 'MaxText/configs/base.yml', # base arg + 'per_device_batch_size=1', + 'ici_fsdp_parallelism=-1', 'ici_tensor_parallelism=1', + f'model_name={maxtext_model_name}', + f'run_name={run_name}', f'base_output_directory={base_output_directory}', + 'checkpoint_period=1', + 'async_checkpointing=false', + ] + pyconfig.initialize(base_args) + cfg = pyconfig.config + init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2) + devices_array = max_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + + model = Transformer(config=cfg, mesh=mesh) + learning_rate_schedule = max_utils.create_learning_rate_schedule(cfg) + tx = optimizers.get_optimizer(cfg, learning_rate_schedule) + + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + cfg.checkpoint_dir, + cfg.enable_checkpointing, + cfg.async_checkpointing, + cfg.checkpoint_period, + ) + + state, _ = max_utils.setup_training_state(model, tx, cfg, init_rng, mesh, checkpoint_manager) + max_logging.log("start") + check_memory() + + # maxtext keystr: (paxml keystr, transform_fn) + keystr_map = { + "['token_embedder']['embedding']": (".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), + "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), + "['decoder']['layers']['pre_self_attention_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['pre_self_attention_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['query']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['query']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['key']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['key']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['value']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['value']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis])), + "['decoder']['layers']['self_attention']['qkv_proj']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['self_attention']['out']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1])), + "['decoder']['layers']['self_attention']['out']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['wi']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['wi']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['wo']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['layers']['mlp']['wo']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), + "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), + "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), + } + + state_map = { + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + } + + def get_layer_prefix(keystr_pax): + # different path format between decoder_layer variable + if "x_layers_0" in keystr_pax: + # string format for all variables in scanned decoder layer + prefix_pax_opt_state = f"p#{cfg.base_num_decoder_layers}#i-1_2" + else: + prefix_pax_opt_state = "no_prefix_2" + return prefix_pax_opt_state + + for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items(): + # model variable + state_map[f".params{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) + prefix_pax_opt_state = get_layer_prefix(keystr_pax) + # first momentum in optimizer state + state_map[f".opt_state.mu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn) + # second momentum in optimizer state + state_map[f".opt_state.nu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn) + + def verify_fn(key_path, _): + keystr = jax.tree_util.keystr(key_path) + assert keystr in state_map, f"{keystr} not found" + + jax.tree_util.tree_map_with_path(verify_fn, state) + + memory_metrics = {'max_cpu_bytes': 0} + + bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://"):].split('/', 1) + + def map_fn(key_path, value): + key_path_str = jax.tree_util.keystr(key_path) + file_path, transform_fn = state_map[key_path_str] + full_path = os.path.join(paxml_ckpt_prefix, file_path) + spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} + spec['kvstore'] = { + 'bucket': bucket_name, + 'driver': 'gcs', + 'path': full_path, + } + + arr = ts.open(ts.Spec(spec), open=True).result().read().result() + if transform_fn is not None: + arr = transform_fn(arr) + + assert value.shape == arr.shape, f"{key_path}, {value.shape}, {arr.shape}" + shape = value.shape + sharding = value.sharding + result = jax.make_array_from_single_device_arrays( + shape, + sharding, + [jax.device_put(np.array(arr[index]), d) + for d, index in sharding.addressable_devices_indices_map(shape).items()], + ) + + # log peak cpu memory + cpu_bytes = Process().memory_info().rss + memory_metrics["max_cpu_bytes"] = max(cpu_bytes, memory_metrics["max_cpu_bytes"]) + + # collect cpu memory back asap + arr = None + del arr + gc.collect() + max_logging.log(f"{key_path_str} finished") + check_memory() + return result + + converted_state = jax.tree_util.tree_map_with_path(map_fn, state) + max_logging.log("converted state finished") + check_memory() + + if checkpoint_manager.save(converted_state.step, converted_state): + max_logging.log(f"saved a checkpoint at step {converted_state.step}") + # Upon preemption, exit when and only when all ongoing saves are complete. + if checkpoint_manager.reached_preemption(converted_state.step): + checkpoint_manager.wait_until_finished() + sys.exit() + + max_logging.log(f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}") + max_logging.log("checkpoint converted and saved successfully.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--paxml-ckpt-path', + type=str, + default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", + required=True) + parser.add_argument('--maxtext-model-name', choices=['gpt3-175b', 'gpt3-52k'], type=str, required=True) + parser.add_argument('--base-output-directory', type=str, required=True) + parser.add_argument('--run-name', type=str, required=True) + + args = parser.parse_args() + if not args.paxml_ckpt_path.startswith("gs://"): + raise ValueError("--paxml-ckpt-path should be a gcs path starting with gs://") + + convert(args.paxml_ckpt_path, args.maxtext_model_name, args.base_output_directory, args.run_name) diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index 8d1f08768..0e7798c28 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -29,7 +29,7 @@ import jax import max_logging import max_utils -import maxtext_utils +import optimizers import pyconfig from absl import app @@ -77,7 +77,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): model = Transformer(config, mesh) rng = random.PRNGKey(0) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = maxtext_utils.get_optimizer(config, learning_rate_schedule) + tx = optimizers.get_optimizer(config, learning_rate_schedule) state, state_mesh_notations = max_utils.setup_training_state( model, tx, config, rng, mesh, checkpoint_manager ) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 5b381c7f0..f3ab5f40d 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -58,6 +58,30 @@ dynamic_vector_slice_in_dim = jax.vmap( lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) +def apply_mask_to_logits(logits: Array, mask: Array): + """Applies a floating-point mask to a set of logits. + + The mask is represented as a tensor with some dtype where 0 represents true and values + below a large negative number (here set to + get_large_negative_number(logits.dtype) / 2) represent false. Applying the mask + leaves the logits alone in the true case and replaces them by + get_large_negative_number(logits.dtype) in the false case. Previously, this was + done by adding the logits to the mask; however, this leads to a bad fusion + decision in the compiler that saves the values in memory rather than + just the predicate. This implementation avoids that problem. + + from https://github.com/google/praxis/blob/4712a6b9ee13e224b86e235ff55f7c6bab9fbab3/praxis/py_utils.py#L706 + + Args: + logits: A JTensor of logit values. + mask: A JTensor of mask values with the encoding described in the + function documentation. + + Returns: + Masked logits. + """ + return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) + def _maybe_aqt_einsum(int8_training, aqt_rng): """Maybe overwrite dot general with aqt_dot_general.""" if not int8_training: @@ -73,6 +97,7 @@ class AttentionOp(nn.Module): use_int8: bool num_query_heads: int num_kv_heads: int + float32_qk_product: bool = False float32_logits: bool = False flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) dtype: DType = jnp.float32 @@ -128,7 +153,7 @@ def generate_attention_mask( output_mask = None return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None - + def apply_attention(self, query: Array, key: Array, @@ -269,17 +294,20 @@ def apply_attention_dot( """Apply Attention.""" aqt_rng = self.make_rng('aqt') - # Casting logits and softmax computation for float32 for model stability. - if self.float32_logits: + # Casting qk_product and softmaxt computation for float32 for model stability. + if self.float32_qk_product: query = query.astype(jnp.float32) key = key.astype(jnp.float32) # QK Product, a.k.a `attn_weights`: [batch, num_kv_heads, num_query_heads_per_kv_head, q_length, kv_length] attn_weights = self.qk_product(query, key) + # Casting softmaxt computation for float32 for model stability. + if self.float32_logits: + attn_weights = attn_weights.astype(jnp.float32) attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) if attn_mask is not None: - attn_weights += attn_mask + attn_weights = apply_mask_to_logits(attn_weights, attn_mask) # Normalize the attention weights across `kv_length` dimension. attn_weights = jax.nn.softmax(attn_weights).astype(self.dtype) @@ -538,7 +566,9 @@ class Attention(nn.Module): dtype: the dtype of the computation. dropout_rate: dropout rate kernel_init: initializer for the kernel of the Dense layers. - float32_logits: bool, if True then compute logits in float32 to avoid + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid numerical issues with bfloat16. use_int8: bool, if true accelerate in int8 """ @@ -553,7 +583,8 @@ class Attention(nn.Module): dtype: DType = jnp.float32 dropout_rate: float = 0. kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') - float32_logits: bool = False # computes logits in float32 for stability. + float32_qk_product: bool = False # computes logits in float32 for stability. + float32_logits: bool = False # cast logits in float32 for stability. use_int8: bool = False @@ -617,7 +648,7 @@ def qkv_projection(self, inputs: Array, proj_name: str): features=(3, self.num_query_heads, self.head_dim), axis = -1, kernel_init=self.kernel_init, - kernel_axes=('embed','qkv', 'heads', 'kv'), + kernel_axes=('embed', 'qkv', 'heads', 'kv'), dtype=self.dtype, name=proj_name, use_int8=self.use_int8)(inputs) @@ -698,6 +729,7 @@ def __call__(self, attention_op = AttentionOp(mesh=self.mesh, attention_kernel=self.attention_kernel, max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, float32_logits=self.float32_logits, use_int8=self.use_int8, num_query_heads=self.num_query_heads, diff --git a/MaxText/layers/gamma.py b/MaxText/layers/gamma.py index 6010fd337..2ac556db3 100644 --- a/MaxText/layers/gamma.py +++ b/MaxText/layers/gamma.py @@ -84,6 +84,7 @@ def __call__(self, dtype=cfg.dtype, dropout_rate=cfg.dropout_rate, name='self_attention', + float32_qk_product = True, float32_logits = True, use_int8=cfg.int8_training) diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py new file mode 100644 index 000000000..a0061ddfc --- /dev/null +++ b/MaxText/layers/gpt3.py @@ -0,0 +1,337 @@ +""" + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +"""Transformer model definition.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + + +from typing import Any, Tuple + +from jax.sharding import Mesh +from jax import lax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name + +from flax import linen as nn + +from layers import attentions +from layers import initializers +from layers import linears +from layers import models + +AttentionOp = attentions.AttentionOp + +import common_types + +Array = common_types.Array +Config = common_types.Config +DType = common_types.DType +Mesh = common_types.Mesh +AxisNames = common_types.AxisNames +BATCH = common_types.BATCH +LENGTH = common_types.LENGTH +HEAD = common_types.HEAD +D_KV = common_types.D_KV + +DenseGeneral = linears.DenseGeneral +NdInitializer = initializers.NdInitializer +Initializer = initializers.Initializer +nd_dense_init = initializers.nd_dense_init + + +#----------------------------------------- +# The Normalization Layer specific for GPT3 +#----------------------------------------- + +class Gpt3LayerNorm(nn.Module): + """GPT3 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + kernel_axes: Tuple[str, ...] = () + scale_init: Initializer = nn.initializers.zeros + use_bias: bool = True + reductions_in_fp32: bool = False + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + if self.reductions_in_fp32: + x = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x, axis=[-1], keepdims=True) + var = jnp.mean(jnp.square(x - mean), axis=[-1], keepdims=True) + normed_inputs = (x - mean) * lax.rsqrt(var + self.epsilon) + if self.reductions_in_fp32: + normed_inputs = normed_inputs.astype(self.dtype) + + features = x.shape[-1] + scale = self.param( + 'scale', + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (features,), + jnp.float32, + ) + + scale = jnp.asarray(scale, self.dtype) + output = normed_inputs * (scale + 1) + + if self.use_bias: + bias = self.param( + 'bias', + nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), + (features,), + jnp.float32, + ) + bias = jnp.asarray(bias, self.dtype) + output += bias + return output + + +#----------------------------------------- +# The Attention Layer specific for GPT3 +#----------------------------------------- + +class Gpt3MultiHeadAttention(nn.Module): + """Multi-head attention in gpt3. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + fused_qkv: whether to fuse query, key and value into one projection. + use_int8: bool, if true accelerate in int8. + use_bias: whether to add bias in linear transformation. + """ + + config: Config + num_heads: int + head_dim: int + max_target_length: int + mesh: Mesh + attention_kernel: str + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + float32_qk_product: bool = False # computes logits in float32 for stability. + float32_logits: bool = True # cast logits in float32 for stability. + fused_qkv: bool = True + use_int8: bool = False + use_bias: bool = True + + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + + def qkv_projection(self, inputs: Array, proj_name: str): + """ Fused QKV projection""" + + qkv_proj = DenseGeneral( + features=(3, self.num_heads, self.head_dim), + axis = -1, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'qkv', 'heads', 'kv'), + dtype=self.dtype, + name=proj_name, + use_int8=self.use_int8, + use_bias=self.use_bias, + )(inputs) + query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + return query, key, value + + def projection(self, inputs: Array, proj_name: str) -> Array: + """individual projection for one of q, k and v.""" + proj = DenseGeneral( + features=(self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'heads', 'kv'), + dtype=self.dtype, + name=proj_name, + use_int8=self.use_int8, + use_bias=self.use_bias, + )(inputs) + return proj + + def out_projection(self, output_dim: int, out: Array) -> Array: + """output projection""" + out_proj = DenseGeneral( + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=('heads', 'kv', 'embed'), + dtype=self.dtype, + name='out', + use_int8=self.use_int8, + use_bias=self.use_bias, + )(out) + return out_proj + + @nn.compact + def __call__(self, + inputs_q: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False): + if self.fused_qkv: + query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + else: + query = self.projection(inputs_q, proj_name='query') + key = self.projection(inputs_q, proj_name='key') + value = self.projection(inputs_q, proj_name='value') + + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query /= depth_scaling + + # annotate with sharding constraint. + query = nn.with_logical_constraint(query, self.query_axis_names) + query = checkpoint_name(query, 'query_proj') + key = nn.with_logical_constraint(key, self.key_axis_names) + key = checkpoint_name(key, 'key_proj') + value = nn.with_logical_constraint(value, self.value_axis_names) + value = checkpoint_name(value, 'value_proj') + + attention_op = AttentionOp(mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + use_int8=self.use_int8, + num_query_heads=self.num_heads, + num_kv_heads=self.num_heads, + dtype=self.dtype) + + out = attention_op(query, key, value, decoder_segment_ids, model_mode) + + out = nn.with_logical_constraint(out, self.out_axis_names) + + # apply output projection, output dim is set to the input dim. + out = self.out_projection(inputs_q.shape[-1], out) + return out + + +#----------------------------------------- +# The Decoder Layer specific for GPT3 +#----------------------------------------- + +class Gpt3DecoderLayer(nn.Module): + """Transformer decoder layer that attends to the encoder.""" + config: models.Config + mesh: Mesh + + @nn.compact + def __call__(self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): + cfg = self.config + mesh = self.mesh + + inputs = nn.with_logical_constraint( + inputs, ('activation_batch', 'activation_length', 'activation_embed')) + + + lnx_layer_norm = Gpt3LayerNorm( + dtype=cfg.dtype, + name='pre_self_attention_norm', + kernel_axes=('embed',), + epsilon=cfg.normalization_layer_epsilon, + reductions_in_fp32=False, + use_bias=True, + ) + lnx = lnx_layer_norm(inputs) + + lnx = nn.with_logical_constraint( + lnx, ('activation_batch', 'activation_length', 'activation_embed')) + + # Self-attention block + assert cfg.num_query_heads == cfg.num_kv_heads, \ + f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" + attention_layer = Gpt3MultiHeadAttention( + config=cfg, + num_heads=cfg.num_query_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + attention_kernel=cfg.attention, + mesh=mesh, + dropout_rate=cfg.dropout_rate, + name='self_attention', + fused_qkv=cfg.fused_qkv, + use_bias=True, + use_int8=cfg.int8_training) + + attention_lnx = attention_layer( + lnx, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + deterministic=deterministic) + + attention_lnx = nn.with_logical_constraint( + attention_lnx, + ('activation_batch', 'activation_length', 'activation_embed')) + attention_lnx += inputs + + # MLP block. + mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + use_bias=True, + use_pre_norm=True, + config=cfg, + )(attention_lnx, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint( + mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + ) + + layer_output = attention_lnx + mlp_lnx + + layer_output = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + layer_output, deterministic=deterministic) + + layer_output = nn.with_logical_constraint( + layer_output, + ('activation_batch', 'activation_length', 'activation_embed'), + ) + + if cfg.record_internal_nn_metrics: + self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) + self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow( + 'intermediates', + 'activation_fraction_zero', + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py index 727bcb26a..6f0bb9c23 100644 --- a/MaxText/layers/initializers.py +++ b/MaxText/layers/initializers.py @@ -35,6 +35,8 @@ 1.0, 'fan_in', 'normal', out_axis=0 ) +default_bias_init = jax.nn.initializers.constant(0.0) + def nd_dense_init(scale, mode, distribution): """Initializer with in_axis, out_axis set at call time.""" diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 4812d86b2..da5df4311 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -16,13 +16,14 @@ import functools import operator -from typing import Any, Callable, Iterable, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, Sequence, Tuple, Union, Optional import flax.linen as nn from jax import lax import jax.numpy as jnp import common_types from layers import initializers +from layers import normalizations from layers import quantizations import numpy as np @@ -32,6 +33,9 @@ NdInitializer = initializers.NdInitializer nd_dense_init = initializers.nd_dense_init +bias_init = initializers.default_bias_init + +RMSNorm = normalizations.RMSNorm def _convert_to_activation_function( @@ -61,13 +65,14 @@ def _canonicalize_tuple(x): class DenseGeneral(nn.Module): - """A linear transformation (without bias) with flexible axes. + """A linear transformation with flexible axes. Attributes: features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation """ features: Union[Iterable[int], int] @@ -76,6 +81,7 @@ class DenseGeneral(nn.Module): kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') kernel_axes: Tuple[str, ...] = () use_int8: bool = False + use_bias: bool = False @nn.compact def __call__(self, inputs: Array) -> Array: @@ -119,7 +125,19 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): kernel = jnp.asarray(kernel, self.dtype) contract_ind = tuple(range(0, len(axis))) - return compute_dot_general(inputs, kernel, axis, contract_ind) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = self.kernel_axes[-len(features):], kernel_shape[-len(features):] + bias = self.param( + 'bias', + nn.with_logical_partitioning(bias_init, bias_axes), + bias_shape, + jnp.float32, + ) + bias = jnp.asarray(bias, self.dtype) + output += bias + return output class MlpBlock(nn.Module): @@ -133,6 +151,8 @@ class MlpBlock(nn.Module): deterministic: Whether the dropout layers should be deterministic. intermediate_dropout_rate: Dropout rate used after the intermediate layers. dtype: Type for the dense layer. + use_bias: whether to add bias in all feedforward layers. + use_pre_norm: whether to add pre layer norm in mlp layers. """ config: Config @@ -141,12 +161,31 @@ class MlpBlock(nn.Module): kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') intermediate_dropout_rate: float = 0.1 dtype: Any = jnp.float32 + use_bias: bool = False + use_pre_norm: bool = False + + def get_norm_layer(self): + if self.config.decoder_block in ("default", "llama2", "mistral", "gamma"): + return RMSNorm + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @nn.compact def __call__(self, inputs, decode: bool = False, deterministic: bool = False): """Applies Transformer MlpBlock module.""" cfg = self.config + if self.use_pre_norm: + inputs = self.get_norm_layer()( + name='mlp_layer_norm', + dtype=cfg.dtype, + kernel_axes=('embed',), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) + # Iterate over specified MLP input activation functions. # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. activations = [] @@ -158,6 +197,7 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): kernel_axes=('embed', 'num_activations', 'mlp'), name='wi', use_int8=cfg.int8_training, + use_bias=self.use_bias, )(inputs) for idx, act_fn in enumerate(self.activations): y = _convert_to_activation_function(act_fn)(x[:,:,idx,...]) @@ -172,6 +212,7 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): kernel_axes=('embed', 'mlp'), name=dense_name, use_int8=cfg.int8_training, + use_bias=self.use_bias, )(inputs) x = _convert_to_activation_function(act_fn)(x) activations.append(x) @@ -192,5 +233,6 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): kernel_axes=('mlp', 'embed'), name='wo', use_int8=cfg.int8_training, + use_bias=self.use_bias, )(x) return output diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index bbc5dae0f..b0486e9d4 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -76,7 +76,7 @@ def __call__(self, dtype=cfg.dtype, name='pre_self_attention_layer_norm', kernel_axes=('embed',), - epsilon=cfg.rms_norm_epsilon + epsilon=cfg.normalization_layer_epsilon, ) lnx = lnx_rms(inputs) @@ -113,7 +113,7 @@ def __call__(self, # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, name='post_self_attention_layer_norm', kernel_axes=('embed',), - epsilon=cfg.rms_norm_epsilon, + epsilon=cfg.normalization_layer_epsilon, )(intermediate_inputs) hidden_states = nn.with_logical_constraint(hidden_states, ('activation_batch', 'activation_length', 'activation_embed')) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 09c4b89c2..b8636b83f 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -20,6 +20,7 @@ from flax import linen as nn +import functools import jax import jax.numpy as jnp import common_types @@ -67,7 +68,7 @@ def __call__(self, lnx = RMSNorm( dtype=cfg.dtype, name='pre_self_attention_norm', - epsilon=cfg.rms_norm_epsilon, + epsilon=cfg.normalization_layer_epsilon, kernel_axes=('embed',))(inputs) lnx = nn.with_logical_constraint( lnx, ('activation_batch', 'activation_length', 'activation_embed')) @@ -160,9 +161,20 @@ def get_decoder_layer(self): elif self.config.decoder_block == "gamma": from layers import gamma return gamma.GammaDecoderLayer + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + return gpt3.Gpt3DecoderLayer else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + def get_norm_layer(self): + if self.config.decoder_block in ("default", "llama2", "mistral", "gamma"): + return RMSNorm + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @nn.compact def __call__(self, @@ -183,8 +195,17 @@ def __call__(self, y, deterministic=deterministic) y = y.astype(cfg.dtype) - if cfg.use_positional_embedding: - y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + if cfg.use_untrainable_positional_embedding: + y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + + if cfg.trainable_position_size > 0: + y += Embed( + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name='position_embedder', + config=cfg)(decoder_positions) BlockLayer = self.get_decoder_layer() @@ -249,7 +270,12 @@ def __call__(self, model_mode, ) - y = RMSNorm(dtype=cfg.dtype, name='decoder_norm', epsilon=cfg.rms_norm_epsilon,kernel_axes=('embed',))(y) + y = self.get_norm_layer()( + dtype=cfg.dtype, + name='decoder_norm', + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=('embed',), + )(y) y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( y, deterministic=deterministic ) @@ -258,12 +284,13 @@ def __call__(self, if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. logits = self.shared_embedding.attend(y) - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) + if self.config.normalize_embedding_logits: + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) else: logits = linears.DenseGeneral( cfg.vocab_size, - dtype=jnp.float32, # Use float32 for stabiliity. + dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability kernel_axes=('embed', 'vocab'), name='logits_dense', use_int8=cfg.int8_training)(y) @@ -287,7 +314,7 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.emb_dim, dtype=cfg.dtype, - attend_dtype=jnp.float32, # for logit training stability + attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability embedding_init=nn.initializers.normal(stddev=1.0), name='token_embedder', config=cfg, diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index e4ecb5b0c..faefab509 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -63,7 +63,7 @@ def permute_to_match_maxtext_rope(arr): 'dims_per_head': 128, 'vocab': 32000, 'num_gpus': 1, - 'combined_qkv': True, + 'fused_qkv': True, }, 'llama2-7b': { 'num_layers': 32, diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 5beb3ab27..bd1f8afa1 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -248,8 +248,8 @@ def init_initial_state(model, tx, config, is_training, key): config.max_target_length ) model_vars = model.init({'params': key, 'dropout': key, 'aqt': key}, - jnp.ones(input_shape), - jnp.ones(input_shape)) + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32)) if is_training: return init_training_state(model.apply, model_vars['params'], tx) return init_decode_state(model.apply, model_vars['params']) diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index d0694b26a..06dd566eb 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -25,7 +25,6 @@ import pickle import functools from input_pipeline import input_pipeline_interface -import optax @@ -47,17 +46,6 @@ def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations def get_functional_train_step(train_step, model, config): return functools.partial(train_step, model, config) -def get_optimizer(config, learning_rate_schedule): - """ Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 """ - return optax.adamw( - learning_rate_schedule, - b1=config.adam_b1, - b2=config.adam_b2, - eps=config.adam_eps, - eps_root=config.adam_eps_root, - weight_decay=config.adam_weight_decay, - ) - def load_compiled(config, partial_train, state): """ # Loading a serialized compiled train step function.""" # Currently partial_train and state are needed to reconstruct diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py new file mode 100644 index 000000000..0a8150293 --- /dev/null +++ b/MaxText/optimizers.py @@ -0,0 +1,144 @@ +""" + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +# pylint: disable=bare-except, consider-using-generator, ungrouped-imports +"""Utils that are only interesting to MaxText. """ + +import jax + + +import optax +import jax.numpy as jnp + + +def get_optimizer(config, learning_rate_schedule): + """create optimizer""" + if config.opt_type == "adamw": + # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 + return optax.adamw( + learning_rate_schedule, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + eps_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, + ) + elif config.opt_type == "adam_pax": + return adam_pax( + learning_rate_schedule, + beta1=config.adam_b1, + beta2=config.adam_b2, + epsilon=config.adam_eps, + epsilon_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, + ) + else: + raise ValueError(f"{config.opt_type=} is not a supported.") + +def adam_pax( + learning_rate_fn: optax.Schedule, + beta1: float, + beta2: float, + epsilon: float, + epsilon_root: float, + weight_decay: float, + ) -> optax.GradientTransformation: + """Standard Adam optimizer that supports weight decay. + + Follows the implemenation in pax/praxis sharded_adam + https://github.com/google/praxis/blob/545e00ab126b823265d70c715950d39333484f38/praxis/optimizers.py#L621 + + Args: + learning_rate_fn: a callable that given the current training step, returns + the learning rate to apply. + beta1: decay rate to track the first moment. + beta2: decay rate to track the second moment. + epsilon: Small constant applied to the denominator outside of the square + root to avoid dividing by zero when rescaling. + epsilon_root: Small constant applied to the denominator inside of the square + root to avoid dividing by zero when rescaling. + weight_decay: If > 0, weight decay to apply. + + Returns: + A `optax.GradientTransformation`. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + jnp.zeros_like, params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return optax.ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def bias_corrected_decay(step: jnp.int32, decay: float): + """Incorporates bias correction into decay. + + Please see section 7.1 in https://arxiv.org/pdf/1804.04235.pdf for the + derivation of the formulas below. With bias-corrected decay, we can simply + do + + m_{t} = decay1 * m_{t-1} + (1 - decay1) * g + v_{t} = decay2 * v_{t-1} + (1 - decay2) * g ^ 2 + + without further bias correction. + + Args: + step: current step, 0-based. + decay: the raw decay. As t -> infinity, bias corrected decay converges to + this value. + + Returns: + Bias corrected decay. + """ + t = step.astype(jnp.float32) + 1. + return decay * (1. - jnp.power(decay, t - 1.)) / (1. - jnp.power(decay, t)) + + def update_fn(updates, state, params=None): + # Sanitize updates just in case. + if weight_decay > 0: + assert params is not None + count = state.count + + class _slot_opt_state: + def __init__(self, mu, nu): + self.mu = mu + self.nu = nu + + def _update_momentum(update, mu, nu): + beta1_decay = bias_corrected_decay(count, beta1) + beta2_decay = bias_corrected_decay(count, beta2) + mu = (1.0 - beta1_decay) * update + beta1_decay * mu + nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu + return _slot_opt_state(mu=mu, nu=nu) + + updated_moments = jax.tree_map(_update_momentum, updates, state.mu, state.nu) + + mu = jax.tree_map(lambda x: x.mu, updated_moments) + nu = jax.tree_map(lambda x: x.nu, updated_moments) + + updates = jax.tree_map( + lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) + + if weight_decay > 0: + updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params) + + step_size = -1.0 * learning_rate_fn(count) + # Finally, fold in step size. + updates = jax.tree_map(lambda x: step_size * x, updates) + + updated_states = optax.ScaleByAdamState(count=count + 1, mu=mu, nu=nu) + return updates, updated_states + + return optax.GradientTransformation(init_fn, update_fn) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 4b21b8007..1d822adf5 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -61,7 +61,7 @@ def validate_keys(keys): def validate_model_name(s: str) -> bool: # currently supported models - valid_model_names= ('default', 'llama2-7b', 'mistral-7b', 'gamma-7b','gamma-2b') + valid_model_names = ('default', 'llama2-7b', 'mistral-7b', 'gamma-7b','gamma-2b', 'gpt3-175b', 'gpt3-52k') if s not in valid_model_names: raise ValueError( "Invalid model name was passed. Valid options ", valid_model_names @@ -202,8 +202,6 @@ def user_init(raw_keys): validate_keys(raw_keys) - validate_attention_type(raw_keys['attention']) - @staticmethod def update_model_vars(raw_keys) -> list[str]: ''' Update model config variables @@ -213,7 +211,8 @@ def update_model_vars(raw_keys) -> list[str]: updated_keys = [] if raw_keys['model_name'] != 'default': - file_path = f"MaxText/configs/models/{raw_keys['model_name']}.yml" + dir_path = os.path.dirname(os.path.realpath(__file__)) + file_path = os.path.join(dir_path, f"configs/models/{raw_keys['model_name']}.yml") with open(file_path, 'r', encoding="utf-8") as file: model_vars = yaml.safe_load(file) updated_keys = list(model_vars.keys()) diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py new file mode 100644 index 000000000..693d81d60 --- /dev/null +++ b/MaxText/tests/gpt3_test.py @@ -0,0 +1,110 @@ + +""" + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +""" Tests for GPT3 """ +import sys +import jax +import unittest +import max_utils +from jax.sharding import Mesh +from layers import models +from layers import embeddings + +import jax.numpy as jnp + +import pyconfig +import pytest + + +Mesh = jax.sharding.Mesh +Embed = embeddings.Embed + + +def init_random_model_vars(model, rng, example_batch): + """initialze random model vars.""" + model_vars = model.init( + {'params': rng, 'aqt': rng}, + example_batch['inputs'], + example_batch['inputs_position'], + enable_dropout=False, + ) + def _replace_initialization(key, value): + keystr = jax.tree_util.keystr(key) + # replace zero initializer to ensure strong test cases + # including Gpt3LayerNorm scale, Gpt3LayerNorm bias, and DenseGeneral bias + if "scale" in keystr or "bias" in keystr: + value = jax.nn.initializers.normal(1.0)(rng, value.shape, dtype=value.dtype) + return value + + model_vars = jax.tree_util.tree_map_with_path(_replace_initialization, model_vars) + return model_vars + + +class GPT3(unittest.TestCase): + """numerical tests for GPT3.""" + def setUp(self): + super().setUp() + pyconfig.initialize( + [sys.argv[0], 'configs/base.yml'], + attention="dot_product", + run_name='test', + enable_checkpointing=False, + model_name='gpt3-52k', + dtype='float32', + ) + + self.cfg = pyconfig.config + self.rng = jax.random.PRNGKey(1234) + + devices_array = max_utils.create_device_mesh(self.cfg) + mesh = Mesh(devices_array, self.cfg.mesh_axes) + self.model = models.Transformer(config = self.cfg, mesh = mesh) + self.example_batch = { + 'inputs': jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32), + 'inputs_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + 'inputs_segmentation': jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32), + 'targets': jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32), + 'targets_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + 'targets_segmentation': jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32), + } + self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch) + + @pytest.mark.tpu + def test_logits_numerically(self): + # ground truth values are calculated from paxml after loading above model_vars + # note we expect all xents are the same except the padding one since: + # paxml applies padding in mlp layer + # while maxtext implementaiton applies padding in attention mask instead + # the two implementation are equivalent in valid non-padding tokens + per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.]], dtype=jnp.float32) + logits, _ = self.model.apply(self.model_vars, + self.example_batch['inputs'], + self.example_batch['inputs_position'], + decoder_segment_ids=self.example_batch['inputs_segmentation'], + enable_dropout=self.cfg.enable_dropout, + rngs={'dropout': self.rng, 'aqt': self.rng}, mutable='intermediates') + + one_hot_targets = jax.nn.one_hot(self.example_batch['targets'], self.cfg.vocab_size) + per_example_xent = -jnp.sum(jax.nn.log_softmax(logits) * one_hot_targets, axis=-1, dtype=jnp.float32) + # Mask out paddings at the end of each example. + per_example_xent = per_example_xent * (self.example_batch['targets_segmentation'] != 0) + + self.assertTrue( + jax.numpy.allclose( + per_example_xent, per_example_xent_truth, rtol=1e-06, atol=1e-06 + ) + ) diff --git a/MaxText/train.py b/MaxText/train.py index 8718c6978..a8c7a1aec 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -36,6 +36,7 @@ import max_utils import maxtext_utils import max_logging +import optimizers import pyconfig from input_pipeline.input_pipeline_interface import create_data_iterator_with_tokenizer @@ -184,7 +185,7 @@ def loss_fn(params): xent = nn.with_logical_constraint(xent, ('activation_batch', 'activation_length')) # Mask out paddings at the end of each example. xent = xent * (data['targets_segmentation'] != 0) - return jnp.sum(xent)/jnp.sum((data['targets_segmentation'] != 0)), intermediate_outputs + return jnp.sum(xent)/jnp.sum(data['targets_segmentation'] != 0), intermediate_outputs grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, intermediate_outputs), raw_grads = grad_fn(state.params) @@ -233,7 +234,7 @@ def setup_mesh_and_model(config): # Model and Optimizer definition model = Transformer(config, mesh) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = maxtext_utils.get_optimizer(config, learning_rate_schedule) + tx = optimizers.get_optimizer(config, learning_rate_schedule) return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx def setup_train_loop(config): diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index eec0101f2..84512f1c4 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -29,6 +29,7 @@ from jax.experimental.serialize_executable import serialize from flax.linen import partitioning as nn_partitioning import maxtext_utils +import optimizers import max_utils import pyconfig from layers import models @@ -69,7 +70,7 @@ def get_shaped_inputs(topology_mesh, config): model = Transformer(config, topology_mesh) # The learning_rate_schedule is baked into the compiled object. learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = maxtext_utils.get_optimizer(config, learning_rate_schedule) + tx = optimizers.get_optimizer(config, learning_rate_schedule) # Shaped RNG keys _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) diff --git a/end_to_end/test_gpt3.sh b/end_to_end/test_gpt3.sh new file mode 100644 index 000000000..7dd0b315a --- /dev/null +++ b/end_to_end/test_gpt3.sh @@ -0,0 +1,16 @@ +set -euox pipefail + +TIMESTAMP=$(date +%Y%m%d-%H%M) +export PAXML_CKPT_PATH=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state +export OUTPUT_PATH=gs://maxtext-gpt3/tests +export RUN_NAME=test_${TIMESTAMP} + +# convert gpt3-52k model +python3 MaxText/convert_gpt3_ckpt_from_paxml.py --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH} + +# Run gpt3-52k with the converted ckpt +python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ + steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\ + enable_profiler=false remat_policy=full\ + max_target_length=2048 base_output_directory=${OUTPUT_PATH}\ + dataset_type=synthetic