XAOT (#215)
Add Ahead of Time Compilation functionality to maxtext with train_compile.py
gobbleturk authored Oct 30, 2023
1 parent 1802b4b commit 07dc6ce
Showing 14 changed files with 528 additions and 75 deletions.
85 changes: 85 additions & 0 deletions MaxText/accelerator_to_spec_map.py
@@ -0,0 +1,85 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Static map of TPU names such as v4-8 to properties such as chip layout."""

""" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO
UserFacingNameToSystemCharacteristics in xpk/xpk.py !!!!! """

from dataclasses import dataclass

@dataclass
class SystemCharacteristics:
platform: str
topology_name: str
chip_config_name: str # 'megacore' or 'default'
chips_per_host_bounds: tuple
devices_per_slice: int

UserFacingNameToSystemCharacteristics = {
'v5e-16': SystemCharacteristics(
'tpu', 'v5e:4x4', 'default', (2, 2, 1), 16
),
'v5e-32': SystemCharacteristics(
'tpu', 'v5e:4x8', 'default', (2, 2, 1), 32
),
'v5e-64': SystemCharacteristics(
'tpu', 'v5e:8x8', 'default', (2, 2, 1), 64
),
'v5e-128': SystemCharacteristics(
'tpu', 'v5e:8x16', 'default', (2, 2, 1), 128
),
'v5e-256': SystemCharacteristics(
'tpu', 'v5e:16x16', 'default', (2, 2, 1), 256
),
'v4-8': SystemCharacteristics(
'tpu', 'v4:2x2x1', 'megacore', (2, 2, 1), 4
),
'v4-16': SystemCharacteristics(
'tpu', 'v4:2x2x2', 'megacore', (2, 2, 1), 8
),
'v4-32': SystemCharacteristics(
'tpu', 'v4:2x2x4', 'megacore', (2, 2, 1), 16
),
'v4-64': SystemCharacteristics(
'tpu', 'v4:2x4x4', 'megacore', (2, 2, 1), 32
),
'v4-128': SystemCharacteristics(
'tpu', 'v4:4x4x4', 'megacore', (2, 2, 1), 64
),
'v4-256': SystemCharacteristics(
'tpu', 'v4:4x4x8', 'megacore', (2, 2, 1), 128
),
'v4-512': SystemCharacteristics(
'tpu', 'v4:4x8x8', 'megacore', (2, 2, 1), 256
),
'v4-1024': SystemCharacteristics(
'tpu', 'v4:8x8x8', 'megacore', (2, 2, 1), 512
),
'v4-1536': SystemCharacteristics(
'tpu', 'v4:8x8x12', 'megacore', (2, 2, 1), 768
),
'v4-2048': SystemCharacteristics(
'tpu', 'v4:8x8x16', 'megacore', (2, 2, 1), 1024
),
'v4-4096': SystemCharacteristics(
'tpu', 'v4:8x16x16', 'megacore', (2, 2, 1), 2048
),
}

def get_system_characteristics(user_facing_name):
return UserFacingNameToSystemCharacteristics.get(user_facing_name)
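
Taken together with the compile_topology config key added below, this map lets the compile path resolve hardware properties without any TPUs attached. A minimal usage sketch follows; only get_system_characteristics and the SystemCharacteristics fields come from this file, and the surrounding values are illustrative:

# Illustrative lookup of a target topology by its user-facing name.
import accelerator_to_spec_map

target = accelerator_to_spec_map.get_system_characteristics('v5e-256')
assert target is not None, "unknown accelerator name"
print(target.platform)           # 'tpu'
print(target.topology_name)      # 'v5e:16x16'
print(target.devices_per_slice)  # 256

# Total device count for a multi-slice target (num_slices stands in for a
# hypothetical compile_topology_num_slices value).
num_slices = 2
total_devices = target.devices_per_slice * num_slices  # 512
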
6 changes: 6 additions & 0 deletions MaxText/configs/base.yml
@@ -186,3 +186,9 @@ stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds

# Use iota operator in Embed
use_iota_embed: False

# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
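
train_compile.py and the train.py changes that consume these keys are not shown in this excerpt, so the following is only a sketch of how a training script might branch on them; load_compiled is defined in MaxText/maxtext_utils.py later in this commit, while the surrounding variable names are assumptions:

# Sketch (assumed wiring, not the literal train.py diff): use a pre-compiled
# train step when compiled_trainstep_file is set, otherwise jit as usual.
if config.compiled_trainstep_file != "":
  # partial_train and state must match the shapes/shardings used at compile time.
  p_train_step = maxtext_utils.load_compiled(config, partial_train, state)
else:
  p_train_step = jax.jit(
      partial_train,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      static_argnums=static_argnums,
      donate_argnums=donate_argnums,
  )
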
14 changes: 14 additions & 0 deletions MaxText/input_pipeline.py
@@ -24,6 +24,7 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

import tokenizer
@@ -321,3 +322,16 @@ def create_data_iterator_with_tokenizer(config, mesh):
return make_c4_train_iterator_and_tokenizer(config, mesh)
else:
assert False, "dataset type not implemented"

def get_shaped_batch(config):
""" Return the shape of the batch - this is what eval_shape would return for the
output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078."""
batch_shape = (config.global_batch_size_to_load, config.max_target_length)
shaped_batch = {}
shaped_batch['inputs'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch['inputs_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch['inputs_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch['targets'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch['targets_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch['targets_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
return shaped_batch
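
The ShapeDtypeStruct placeholders returned by get_shaped_batch are exactly what JAX's ahead-of-time lowering accepts in place of real arrays. A self-contained toy showing that flow (toy_step is illustrative, not MaxText's train_step):

# Toy illustration of AOT lowering from shape/dtype placeholders alone.
import jax
import jax.numpy as jnp

def toy_step(batch):
  return jnp.sum(batch['inputs'] * batch['targets'])

shaped_batch = {
    'inputs':  jax.ShapeDtypeStruct((8, 16), jnp.int32),
    'targets': jax.ShapeDtypeStruct((8, 16), jnp.int32),
}
compiled = jax.jit(toy_step).lower(shaped_batch).compile()
# `compiled` is an XLA executable produced without materializing any input data.
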
20 changes: 8 additions & 12 deletions MaxText/max_utils.py
@@ -20,6 +20,7 @@
import functools

import max_logging
import maxtext_utils

import numpy as np
import jax
@@ -29,7 +30,6 @@
import json
import flax
from flax.training import train_state
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

import optax
@@ -112,9 +112,8 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ

return parallelism_vals

def create_device_mesh(config, logging=True):
def create_device_mesh(config, devices=jax.devices(), logging=True):
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas """
devices = jax.devices()
num_devices = len(devices)
try:
num_slices = 1 + max([d.slice_index for d in devices])
@@ -124,7 +123,7 @@ def create_device_mesh(config, logging=True):
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")
assert len(devices) > 1, "You must have at least two devices"

multi_slice_env = hasattr(jax.devices()[0], 'slice_index')
multi_slice_env = num_slices > 1

dcn_parallelism = [config.dcn_data_parallelism, config.dcn_fsdp_parallelism, config.dcn_tensor_parallelism]
ici_parallelism = [config.ici_data_parallelism, config.ici_fsdp_parallelism, config.ici_tensor_parallelism]
@@ -134,9 +133,9 @@

if multi_slice_env:
dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, 'DCN')
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism)
mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
else:
mesh = mesh_utils.create_device_mesh(ici_parallelism)
mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)

if logging:
max_logging.log(f"Decided on mesh: {mesh}")
@@ -196,15 +195,11 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
state: the initialized train state
state_mesh_annotations: the mesh annotations for the train state
"""
init_train_state_partial = functools.partial(init_train_state, model, tx,
config)
abstract_state = jax.eval_shape(init_train_state_partial, rng)
state_logical_annotations = nn.get_partition_spec(abstract_state)
unboxed_abstract_state = unbox_logicallypartioned_trainstate(abstract_state)

unboxed_abstract_state, state_mesh_annotations = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh)

# Initialization
with nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
config.load_parameters_path,
config.load_from_other_directory,
@@ -216,6 +211,7 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
if not state:
init_train_state_partial = functools.partial(init_train_state, model, tx, config)
state = jax.jit(
init_train_state_partial,
in_shardings=None,
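
The new devices parameter on create_device_mesh is what allows ahead-of-time compilation to build a mesh for hardware that is not attached to the current host. The exact call site is not visible in this excerpt; the sketch below assumes it is fed from jax.experimental.topologies.get_topology_desc, whose arguments line up with the SystemCharacteristics fields above, so treat the whole snippet as an assumption:

# Assumed wiring: describe an unattached TPU topology, then hand its devices
# to create_device_mesh instead of jax.devices().
from jax.experimental.topologies import get_topology_desc

target = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
topology_devices = get_topology_desc(
    platform=target.platform,                            # 'tpu'
    topology_name=target.topology_name,                  # e.g. 'v5e:16x16'
    chip_config_name=target.chip_config_name,            # 'megacore' or 'default'
    chips_per_host_bounds=target.chips_per_host_bounds,
    num_slices=config.compile_topology_num_slices,
).devices
devices_array = max_utils.create_device_mesh(config, devices=topology_devices)
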
100 changes: 100 additions & 0 deletions MaxText/maxtext_utils.py
@@ -0,0 +1,100 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=bare-except, consider-using-generator
"""Utils that are only interesting to MaxText. """

import jax
from jax.sharding import PartitionSpec as P
from jax.experimental.serialize_executable import deserialize_and_load


import max_utils
import pickle
import functools
import input_pipeline
import optax



from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config):
""" Get the shardings (both state and data) for train_step """
functional_train = get_functional_train_step(train_step, model, config)
data_pspec = P(*config.data_sharding)
state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
data_sharding = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shardings = (state_mesh_shardings, None, None) # State, metrics, rng
static_argnums = () # We partial out the static argnums of model and config
donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory.
return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums

def get_functional_train_step(train_step, model, config):
return functools.partial(train_step, model, config)

def get_optimizer(config, learning_rate_schedule):
""" Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 """
return optax.adamw(
learning_rate_schedule,
b1=config.adam_b1,
b2=config.adam_b2,
eps=config.adam_eps,
eps_root=config.adam_eps_root,
weight_decay=config.adam_weight_decay,
)

def load_compiled(config, partial_train, state):
""" # Loading a serialized compiled train step function."""
# Currently partial_train and state are needed to reconstruct
# input/output shapes to construct the in_trees and out_trees for load API
# Parker is working on serializing these
def load_serialized_compiled(save_name):
with open(save_name, "rb") as f:
serialized_compiled = pickle.load(f)
return serialized_compiled

def get_train_input_output_trees(func, input_args, input_kwargs):
_, in_tree_recreated = jax.tree_util.tree_flatten((input_args, input_kwargs))
out_shaped = jax.eval_shape(func, *input_args, **input_kwargs)
_, out_tree_recreated = jax.tree_util.tree_flatten(out_shaped)
return in_tree_recreated, out_tree_recreated

serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file)
shaped_batch = input_pipeline.get_shaped_batch(config)
example_rng = jax.random.PRNGKey(0)
shaped_input_args = (state, shaped_batch, example_rng)
shaped_input_kwargs = {}
in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs)
p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
return p_train_step

def get_abstract_state(model, tx, config, rng, mesh):
""" Get a shaped abstraction of the state (including optimizer)"""
init_train_state_partial = functools.partial(max_utils.init_train_state, model, tx,
config)
abstract_state = jax.eval_shape(init_train_state_partial, rng)
state_logical_annotations = nn.get_partition_spec(abstract_state)
unboxed_abstract_state = max_utils.unbox_logicallypartioned_trainstate(abstract_state)

# Initialization
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return unboxed_abstract_state, state_mesh_annotations
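
load_compiled above is the consumer side of serialization; the producer side lives in train_compile.py, which is not part of this excerpt. A hedged sketch of what that save path presumably looks like, given a Compiled executable such as the one in the toy lowering example earlier:

# Assumed save side: serialize a compiled executable and pickle it to
# compiled_trainstep_file so load_compiled can deserialize it later.
from jax.experimental.serialize_executable import serialize
import pickle

serialized, in_tree, out_tree = serialize(compiled)
with open(config.compiled_trainstep_file, "wb") as f:
  pickle.dump(serialized, f)  # load_compiled rebuilds in_tree/out_tree from shapes instead
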
38 changes: 18 additions & 20 deletions MaxText/pyconfig.py
@@ -17,6 +17,7 @@
# pylint: disable=missing-module-docstring
from collections import OrderedDict

import accelerator_to_spec_map
import math
import os
import sys
@@ -88,28 +89,19 @@ def user_init(raw_keys):
if raw_keys["run_name"] == "":
raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default
run_name = raw_keys["run_name"]
assert run_name, "Erroring out, need a real run_name"
base_output_directory = raw_keys["base_output_directory"]
validate_gcs_bucket_name(base_output_directory, "base_output_directory")
dataset_path = raw_keys["dataset_path"]
validate_gcs_bucket_name(dataset_path, "dataset_path")
assert ((raw_keys["load_parameters_path"]=="" and raw_keys["load_from_other_directory"]=="") or
raw_keys["enable_checkpointing"]), "You must set enable_checkpointing to load a checkpoint"
assert raw_keys["load_parameters_path"]=="" or raw_keys["load_from_other_directory"]=="" \
"At most one of load_parameters_path or load_from_other_directory should be set"
assert raw_keys["load_from_other_directory_step"]==-1 or raw_keys["load_from_other_directory"]!="", \
"You must specify the loading directory if you specify the loading step"
raw_keys["tensorboard_dir"] = os.path.join(base_output_directory, run_name, "tensorboard", "")
raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "")
raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "")
if run_name:
raw_keys["tensorboard_dir"] = os.path.join(base_output_directory, run_name, "tensorboard", "")
raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "")
raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "")

raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])

if raw_keys["learning_rate_schedule_steps"]==-1:
raw_keys["learning_rate_schedule_steps"] = raw_keys["steps"]
if raw_keys["steps"]==-1:
raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"]
assert raw_keys["steps"] > 0, "You must set steps or learning_rate_schedule_steps to a positive interger."

emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys['global_parameter_scale'])
raw_keys['emb_dim'] = 2**emb_scale * raw_keys['base_emb_dim']
@@ -118,11 +110,8 @@ def user_init(raw_keys):
raw_keys['num_decoder_layers'] = 2**layer_scale * raw_keys['base_num_decoder_layers']

raw_keys['global_batch_size_to_load'], raw_keys['global_batch_size_to_train_on'] = \
calculate_global_batch_sizes(raw_keys['per_device_batch_size'])
calculate_global_batch_sizes(raw_keys)

def validate_gcs_bucket_name(bucket_name, config_var):
assert bucket_name, f"Please set {config_var}."
assert len(bucket_name) > 5 and bucket_name[0:5]=='gs://', f"Erroring out, {config_var} should start with 'gs://' "

def get_individual_scales(scale):
'''Choose appropriate scales for individual dimensions based on global scale
@@ -145,8 +134,10 @@ def get_individual_scales(scale):
layer_scale = base_scale
return emb_scale, num_head_scale, mlp_dim_scale, layer_scale

def calculate_global_batch_sizes(per_device_batch_size):
num_devices = len(jax.devices())
def calculate_global_batch_sizes(raw_keys):
""" Calculates target global batch size from target devices and per_device_batch"""
per_device_batch_size = raw_keys['per_device_batch_size']
num_devices = get_num_target_devices(raw_keys)
if per_device_batch_size < 1:
# For per_device_batch_size<1, we load the data as if per_device_batch_size=1
global_batch_size_to_load = num_devices
@@ -156,6 +147,13 @@ def calculate_global_batch_sizes(per_device_batch_size):
global_batch_size_to_train_on = int(num_devices * per_device_batch_size)
return global_batch_size_to_load, global_batch_size_to_train_on

def get_num_target_devices(raw_keys):
if raw_keys['compile_topology'] != "":
devices_per_slice = accelerator_to_spec_map.get_system_characteristics(raw_keys['compile_topology']).devices_per_slice
return int(devices_per_slice * raw_keys['compile_topology_num_slices'])
else:
return len(jax.devices())

class HyperParameters(): # pylint: disable=missing-class-docstring
def __init__(self):
pass
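
A worked example of the batch-size arithmetic when the AOT keys are set, using values that appear in this diff (per_device_batch_size is hypothetical):

# compile_topology='v5e-256'  -> devices_per_slice = 256
# compile_topology_num_slices = 2
num_devices = 256 * 2                           # get_num_target_devices -> 512
per_device_batch_size = 4
global_batch_size_to_load = int(512 * 4)        # 2048
global_batch_size_to_train_on = int(512 * 4)    # 2048 (differs only when per_device_batch_size < 1)
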
(Remaining changed files in this commit are not shown here.)