Merge pull request #293 from google:lizhiyu/mlperf_gpt3_migration
PiperOrigin-RevId: 602534669
maxtext authors committed Jan 30, 2024
2 parents a54a940 + 0fee320 commit ddcd1c4
Showing 23 changed files with 1,051 additions and 48 deletions.
9 changes: 7 additions & 2 deletions MaxText/configs/base.yml
@@ -17,7 +17,7 @@
run_name: ""

model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
rms_norm_epsilon: 1.e-06
normalization_layer_epsilon: 1.e-06

################################## CHECKPOINTING ##################################
# Checkpointing makes the following choices in the following order, starting with (1):
@@ -71,6 +71,9 @@ head_dim: 256
mlp_activations: ["relu"]
dropout_rate: 0
logits_via_embedding: True # NOTE: this is True just for testing.
normalize_embedding_logits: True # whether to normalize the pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability

# proj, minimal, full, or none
remat_policy: 'full'
scan_layers: True
@@ -185,6 +188,7 @@ gradient_clipping_threshold: 1.0

# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
opt_type: "adamw" # one of "adam_pax" or "adamw"
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
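The new opt_type knob selects between the two Adam variants named in the comment above. As a rough illustration only, not the actual optimizers.get_optimizer implementation added in this commit, such a switch could be dispatched with optax; the helper name and defaults below are hypothetical:

import optax

def make_optimizer_sketch(opt_type, learning_rate_schedule, b1=0.9, b2=0.95, eps=1e-8):
  # Hypothetical dispatch sketch; MaxText's real optimizers.get_optimizer takes (config, schedule).
  if opt_type == "adamw":
    return optax.adamw(learning_rate_schedule, b1=b1, b2=b2, eps=eps)
  if opt_type == "adam_pax":
    raise NotImplementedError("adam_pax (Pax-style Adam) is not sketched here")
  raise ValueError(f"unknown opt_type: {opt_type}")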
@@ -199,7 +203,8 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds
# Use iota operator in Embed
use_iota_embed: False
# use positional embedding
use_positional_embedding: False
use_untrainable_positional_embedding: False
trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size

# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
34 changes: 34 additions & 0 deletions MaxText/configs/models/gpt3-175b.yml
@@ -0,0 +1,34 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gpt3-175b

base_emb_dim: 12288
base_num_query_heads: 96
base_num_kv_heads: 96
base_mlp_dim: 49152
base_num_decoder_layers: 96
head_dim: 128
trainable_position_size: 16384
mlp_activations: ["gelu"]
vocab_size: 50304
enable_dropout: False
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
fused_qkv: True
opt_type: "adam_pax"
decoder_block: "gpt3"
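The model_name setting in base.yml is what pulls in a per-model file like this one. A minimal selection sketch, mirroring the pyconfig.initialize pattern used in convert_gpt3_ckpt_from_paxml.py further down; run_name and the output bucket here are hypothetical:

import pyconfig

pyconfig.initialize([
    "", "MaxText/configs/base.yml",              # base config
    "model_name=gpt3-175b",                      # applies MaxText/configs/models/gpt3-175b.yml overrides
    "run_name=my_gpt3_run",                      # hypothetical run name
    "base_output_directory=gs://my-bucket/out",  # hypothetical output bucket
])
cfg = pyconfig.config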
34 changes: 34 additions & 0 deletions MaxText/configs/models/gpt3-52k.yml
@@ -0,0 +1,34 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gpt3-52k, i.e., a small fake model for testing purposes only

base_emb_dim: 16
base_num_query_heads: 2
base_num_kv_heads: 2
base_mlp_dim: 64
base_num_decoder_layers: 1
head_dim: 8
trainable_position_size: 2048
mlp_activations: ["gelu"]
vocab_size: 1024
enable_dropout: False
logits_via_embedding: True
normalize_embedding_logits: False
logits_dot_in_fp32: False
normalization_layer_epsilon: 1.e-05
use_iota_embed: True
fused_qkv: True
opt_type: "adam_pax"
decoder_block: "gpt3"
2 changes: 1 addition & 1 deletion MaxText/configs/models/llama2-7b.yml
@@ -25,5 +25,5 @@ vocab_size: 32000
enable_dropout: False
vocab_relative_path: "tokenizer.llama2"
logits_via_embedding: False
rms_norm_epsilon: 1.0e-5
normalization_layer_epsilon: 1.0e-5
decoder_block: "llama2"
2 changes: 1 addition & 1 deletion MaxText/configs/models/mistral-7b.yml
@@ -25,5 +25,5 @@ vocab_size: 32000
enable_dropout: False
vocab_relative_path: "tokenizer.model"
logits_via_embedding: False
rms_norm_epsilon: 1.0e-5
normalization_layer_epsilon: 1.0e-5
decoder_block: "mistral"
230 changes: 230 additions & 0 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -0,0 +1,230 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=line-too-long
"""Convert weights from a paxml gpt3 model to a MaxText one.
Test cmd for gpt3-52k:
python MaxText/convert_gpt3_ckpt_from_paxml.py \
--paxml-ckpt-path=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state \
--maxtext-model-name=gpt3-52k \
--run-name=$RUN_NAME \
--base-output-directory=$BASE_OUTPUT_DIR
Actual cmd for gpt3-175b:
The script is memory demanding: it requires at least 250 GiB of CPU memory, and the cumulative TPU memory across all devices
should be above ~4.2 TiB (175 billion params * 4 bytes/param * 3 (model var and 2 opt momentums) * 2 copies during conversion)
python MaxText/convert_gpt3_ckpt_from_paxml.py \
--paxml-ckpt-path=gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000 \
--maxtext-model-name=gpt3-175b \
--run-name=$RUN_NAME \
--base-output-directory=$BASE_OUTPUT_DIR
"""
import max_utils
import optimizers
import pyconfig
import os
from jax import random
from jax.sharding import Mesh
from layers.models import Transformer
import checkpointing

import numpy as np
import tensorstore as ts

import sys
import jax
import gc
import max_logging
from psutil import Process
import argparse

def fmt_size(num_bytes: int) -> str:
assert num_bytes > 0
for unit in ["B", "KiB", "MiB", "GiB"]:
if num_bytes < 1024.0:
break
num_bytes /= 1024.0
return f"{num_bytes:.2f} {unit}"

def check_memory():
"""print out cpu/tpu memory."""
cpu_bytes = Process().memory_info().rss
max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}")
for d in jax.local_devices():
stats = d.memory_stats()
used = stats['bytes_in_use']
limit = stats['bytes_limit']
max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")


def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name):
"""convert ckpt."""

base_args = [
'', 'MaxText/configs/base.yml', # base arg
'per_device_batch_size=1',
'ici_fsdp_parallelism=-1', 'ici_tensor_parallelism=1',
f'model_name={maxtext_model_name}',
f'run_name={run_name}', f'base_output_directory={base_output_directory}',
'checkpoint_period=1',
'async_checkpointing=false',
]
pyconfig.initialize(base_args)
cfg = pyconfig.config
init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

model = Transformer(config=cfg, mesh=mesh)
learning_rate_schedule = max_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
cfg.checkpoint_dir,
cfg.enable_checkpointing,
cfg.async_checkpointing,
cfg.checkpoint_period,
)

state, _ = max_utils.setup_training_state(model, tx, cfg, init_rng, mesh, checkpoint_manager)
max_logging.log("start")
check_memory()

# maxtext keystr: (paxml keystr, transform_fn)
keystr_map = {
"['token_embedder']['embedding']": (".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T),
"['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None),
"['decoder']['layers']['pre_self_attention_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['pre_self_attention_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['query']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['query']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['key']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['key']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['value']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['value']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['qkv_proj']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis])),
"['decoder']['layers']['self_attention']['qkv_proj']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['self_attention']['out']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1])),
"['decoder']['layers']['self_attention']['out']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['wi']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['wi']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['wo']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['layers']['mlp']['wo']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)),
"['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T),
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
}

state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}

def get_layer_prefix(keystr_pax):
    # path format differs between variables inside the scanned decoder layer and the rest
if "x_layers_0" in keystr_pax:
# string format for all variables in scanned decoder layer
prefix_pax_opt_state = f"p#{cfg.base_num_decoder_layers}#i-1_2"
else:
prefix_pax_opt_state = "no_prefix_2"
return prefix_pax_opt_state

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
# model variable
state_map[f".params{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
# first momentum in optimizer state
state_map[f".opt_state.mu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn)
# second momentum in optimizer state
state_map[f".opt_state.nu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
assert keystr in state_map, f"{keystr} not found"

jax.tree_util.tree_map_with_path(verify_fn, state)

memory_metrics = {'max_cpu_bytes': 0}

bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://"):].split('/', 1)

def map_fn(key_path, value):
key_path_str = jax.tree_util.keystr(key_path)
file_path, transform_fn = state_map[key_path_str]
full_path = os.path.join(paxml_ckpt_prefix, file_path)
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'bucket': bucket_name,
'driver': 'gcs',
'path': full_path,
}

arr = ts.open(ts.Spec(spec), open=True).result().read().result()
if transform_fn is not None:
arr = transform_fn(arr)

assert value.shape == arr.shape, f"{key_path}, {value.shape}, {arr.shape}"
shape = value.shape
sharding = value.sharding
result = jax.make_array_from_single_device_arrays(
shape,
sharding,
[jax.device_put(np.array(arr[index]), d)
for d, index in sharding.addressable_devices_indices_map(shape).items()],
)

# log peak cpu memory
cpu_bytes = Process().memory_info().rss
memory_metrics["max_cpu_bytes"] = max(cpu_bytes, memory_metrics["max_cpu_bytes"])

# collect cpu memory back asap
arr = None
del arr
gc.collect()
max_logging.log(f"{key_path_str} finished")
check_memory()
return result

converted_state = jax.tree_util.tree_map_with_path(map_fn, state)
max_logging.log("converted state finished")
check_memory()

if checkpoint_manager.save(converted_state.step, converted_state):
max_logging.log(f"saved a checkpoint at step {converted_state.step}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(converted_state.step):
checkpoint_manager.wait_until_finished()
sys.exit()

max_logging.log(f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}")
max_logging.log("checkpoint converted and saved successfully.")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--paxml-ckpt-path',
type=str,
default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000",
required=True)
parser.add_argument('--maxtext-model-name', choices=['gpt3-175b', 'gpt3-52k'], type=str, required=True)
parser.add_argument('--base-output-directory', type=str, required=True)
parser.add_argument('--run-name', type=str, required=True)

args = parser.parse_args()
if not args.paxml_ckpt_path.startswith("gs://"):
raise ValueError("--paxml-ckpt-path should be a gcs path starting with gs://")

convert(args.paxml_ckpt_path, args.maxtext_model_name, args.base_output_directory, args.run_name)
4 changes: 2 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
@@ -29,7 +29,7 @@
import jax
import max_logging
import max_utils
import maxtext_utils
import optimizers
import pyconfig

from absl import app
@@ -77,7 +77,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
model = Transformer(config, mesh)
rng = random.PRNGKey(0)
learning_rate_schedule = max_utils.create_learning_rate_schedule(config)
tx = maxtext_utils.get_optimizer(config, learning_rate_schedule)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
state, state_mesh_notations = max_utils.setup_training_state(
model, tx, config, rng, mesh, checkpoint_manager
)