Skip to content

Commit

Permalink
Streamline params usage
Browse files Browse the repository at this point in the history
  • Loading branch information
anfals committed Apr 2, 2024
1 parent 33dd61e commit 6498e14
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 28 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ jobs:
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4'
- name: Test generate_param_only_checkpoint with int8 quantization
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8'
- name: Test grain checkpoint determinism
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
Expand Down
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"python.testing.pytestArgs": [],
"python.testing.cwd": "${workspaceFolder}/MaxText",
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
7 changes: 5 additions & 2 deletions MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,13 @@ def map_to_pspec(data):
max_logging.log(f"restoring params from {load_parameters_from_path=}")
p = epath.Path(load_parameters_from_path)
ckptr = orbax.checkpoint.PyTreeCheckpointer()
# This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
# Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
# memory, we instead specify here that we are just restoring the params field of the checkpoint
# (which itself may be a dictionary containing a key named 'params').
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params)
restored = ckptr.restore(p, item = {'params': abstract_unboxed_pre_state.params}, transforms={},
restore_args = {'params': restore_args})

restore_args = {'params': restore_args})
return None, restored['params']

elif load_full_state_from_path != "":
Expand Down
2 changes: 1 addition & 1 deletion MaxText/convert_gemma_chkpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def astype_fn(x):
state_new = train_state.TrainState(
step=0,
apply_fn=None,
params=jax_weights,
params={'params': jax_weights},
tx=None, # type: ignore
opt_state={}
)
Expand Down
6 changes: 3 additions & 3 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ def get_layer_prefix(keystr_pax):

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
# model variable
state_map[f".params{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
# first momentum in optimizer state
state_map[f".opt_state.mu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn)
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn)
# second momentum in optimizer state
state_map[f".opt_state.nu{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn)
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
Expand Down
12 changes: 6 additions & 6 deletions MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def _possibly_unroll_params(config, training_state, training_state_annotations,
if not config.scan_layers or not config.force_unroll:
return

training_state_layers = training_state.params['decoder']['layers']
training_state_annotations_layers = training_state_annotations.params['decoder']['layers']
training_state_layers = training_state.params['params']['decoder']['layers']
training_state_annotations_layers = training_state_annotations.params['params']['decoder']['layers']

def new_pspec(x):
return jax.sharding.PartitionSpec(*x[0:config.param_scan_axis] + x[config.param_scan_axis+1:])
Expand All @@ -62,11 +62,11 @@ def slice_ith(input_layers):

new_layer = jax.jit(slice_ith, out_shardings = new_per_layer_state_sharding)(training_state_layers)

training_state.params['decoder'][f'layers_{i}'] = new_layer
training_state_annotations.params['decoder'][f'layers_{i}'] = new_per_layer_state_annotation
training_state.params['params']['decoder'][f'layers_{i}'] = new_layer
training_state_annotations.params['params']['decoder'][f'layers_{i}'] = new_per_layer_state_annotation

del training_state.params['decoder']['layers']
del training_state_annotations.params['decoder']['layers']
del training_state.params['params']['decoder']['layers']
del training_state_annotations.params['params']['decoder']['layers']

jax.tree_map(lambda x : x.delete(), training_state_layers)

Expand Down
2 changes: 1 addition & 1 deletion MaxText/llama_or_mistral_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def checkpoint_device_put(arr):
state_new = train_state.TrainState(
step=0,
apply_fn=None,
params=jax_weights,
params={'params': jax_weights},
tx=None, # type: ignore
opt_state={}
)
Expand Down
7 changes: 3 additions & 4 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,12 @@ def init_initial_state(model, tx, config, is_training, key):
jnp.ones(input_shape, dtype=jnp.int32),
jnp.ones(input_shape, dtype=jnp.int32))
if is_training:
return init_training_state(model.apply, model_vars['params'], tx)
return init_decode_state(model.apply, model_vars['params'])
return init_training_state(model.apply, model_vars, tx)
return init_decode_state(model.apply, model_vars)

def load_decode_model_vars(model, config, rng, mesh):
state, _ = setup_decode_state(model, config, rng, mesh, None)
model_vars = {'params': state.params}
return model_vars
return state.params

def setup_decode_state(model, config, rng, mesh, checkpoint_manager):
is_training = False
Expand Down
12 changes: 4 additions & 8 deletions MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,16 @@ def load_params(self, *args, **kwargs) -> Params:
self.kv_cache_shardings = jax.tree_map(lambda x : jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)

if not self.model.quant:
params = {"params" : state.params}
self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
params)
return params
state.params)
return state.params
else:
self.model.quant.quant_mode = quantizations.get_quant_mode('convert')

@jax.jit
def model_apply(_p, _rng):
return self.model.apply(
{
"params": _p,
"aqt" : {}
},
_p | {"aqt": {}},
jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32),
jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32),
decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
Expand All @@ -115,7 +111,7 @@ def model_apply(_p, _rng):
params = {}
params['aqt'] = new_vars['aqt']
# Remove param values which have corresponding qtensors in aqt to save memory.
params['params'] = quantizations.remove_quantized_params(state.params, new_vars['aqt'])
params['params'] = quantizations.remove_quantized_params(state.params['params'], new_vars['aqt'])

self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
params)
Expand Down
54 changes: 54 additions & 0 deletions MaxText/tests/max_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,60 @@ def test_init_training_state(self):
max_utils.calculate_num_params_from_pytree(self.params)
)

class ModelWithMultipleCollections(nn.Module):
"""
A simple model that has variables in multiple collections - "params" and "special_variables"
"""
def setup(self):
self.dense = nn.Dense(4)
self.kernel = self.variable(
"special_variables", "my_first_kernel", lambda: jnp.ones((4, 5))
)

def __call__(self, x, y):
x = self.dense(x)
x = x @ self.kernel.value
return x

class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase):

def setUp(self):
pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
self.config = pyconfig.config
self.model = ModelWithMultipleCollections()
self.key1, self.key2, self.key3 = random.split(random.key(0), num=3)
self.input = random.normal(self.key1,
(self.config.global_batch_size_to_load, self.config.max_target_length))
self.params = self.model.init(self.key2, self.input, self.input)
self.tx = optax.adam(learning_rate=0.001)

def _test_init_initial_state_driver(self, is_training):
state_under_test = max_utils.init_initial_state(self.model, self.tx, self.config, is_training, self.key3)
self.assertEqual(state_under_test.apply_fn, self.model.apply)
if is_training:
self.assertEqual(state_under_test.tx, self.tx)
self.assertNotEqual(state_under_test.opt_state, {})
else:
self.assertIsNone(state_under_test.tx)
self.assertEqual(state_under_test.opt_state, {})
self.assertEqual(
max_utils.calculate_num_params_from_pytree(state_under_test.params),
max_utils.calculate_num_params_from_pytree(self.params)
)
self.assertEqual(
len(self.params),
len(state_under_test.params)
)
self.assertIn("special_variables", state_under_test.params)
self.assertIn("params", state_under_test.params)

def test_initial_train_state(self):
self._test_init_initial_state_driver(True)

def test_initial_decode_state(self):
self._test_init_initial_state_driver(False)


class MaxUtilsInitTransformerState(unittest.TestCase):
"""Tests initialization of transformer states in max_utils.py"""

Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
for k, v in data.items():
data[k] = v[:config.global_batch_size_to_train_on,:]

logits, intermediate_outputs = model.apply({'params': params},
logits, intermediate_outputs = model.apply(params,
data['inputs'],
data['inputs_position'],
decoder_segment_ids=data['inputs_segmentation'],
Expand Down
10 changes: 8 additions & 2 deletions end_to_end/test_generate_param_only_checkpoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ helpFunction()
echo -e "\t-o output_path: gs://test-maxtext-output"
echo -e "\t-i ici_tensor_parallelism: 8"
echo -e "\t-a attention: flash"
echo -e "\t-q quantization: int8"
exit 1 # Exit script after printing help
}

Expand All @@ -22,8 +23,9 @@ dataset_path=gs://test-maxtext-dataset
base_output_directory=gs://test-maxtext-output
ici_tensor_parallelism=8
attention=flash
quantization=""

while getopts "nr:d:o:t:i:a:" opt
while getopts "nr:d:o:t:i:a:q:" opt
do
case "$opt" in
n ) dry_run=true ;;
Expand All @@ -32,14 +34,15 @@ do
o ) base_output_directory="$OPTARG";;
i ) ici_tensor_parallelism="$OPTARG" ;;
a ) attention="$OPTARG" ;;
q ) quantization="int8" ;;
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
esac
done

echo
echo "Running: ./$0 dataset_path=${dataset_path} base_output_directory=${base_output_directory}"
echo " dry_run=${dry_run} run_id=${run_id} "
echo " ici_tensor_parallelism=${ici_tensor_parallelism} attention=${attention}"
echo " ici_tensor_parallelism=${ici_tensor_parallelism} attention=${attention} quantization=${quantization}"
echo

if "$dry_run"; then
Expand All @@ -60,6 +63,7 @@ run_name=${training_ckpt_run_id} \
base_output_directory=${base_output_directory} \
dataset_path=${dataset_path} attention=${attention} \
steps=5 checkpoint_period=3 async_checkpointing=false \
quantization=${quantization} \
${model_params} \


Expand All @@ -83,6 +87,7 @@ run_name=${decode_ckpt_run_id} attention=${attention} \
base_output_directory=${base_output_directory} \
dataset_path=${dataset_path} async_checkpointing=false \
load_full_state_path=${base_output_directory}/${training_ckpt_run_id}/checkpoints/3/items \
quantization=${quantization} \
${model_params} \


Expand All @@ -106,6 +111,7 @@ dataset_path=${dataset_path} \
load_parameters_path=${base_output_directory}/${decode_ckpt_run_id}/checkpoints/0/items \
attention=dot_product ici_tensor_parallelism=${ici_tensor_parallelism} steps=50 \
metrics_file=/tmp/${run_id}_metrics.txt async_checkpointing=false max_target_length=128 per_device_batch_size=1 \
quantization=${quantization} \
${model_params} \

if [ $? -eq 0 ]
Expand Down

0 comments on commit 6498e14

Please sign in to comment.