From fad34976cd8683560603f8dff9ec27129318a549 Mon Sep 17 00:00:00 2001
From: Alex Hagele
Date: Fri, 16 Aug 2024 10:06:41 +0000
Subject: [PATCH] sparse upcycling converter

---
 examples/xglm/README.md            |  31 +++--
 examples/xglm/convert_dense2moe.py | 179 +++++++++++++++++++++++++++++
 2 files changed, 199 insertions(+), 11 deletions(-)
 create mode 100644 examples/xglm/convert_dense2moe.py

diff --git a/examples/xglm/README.md b/examples/xglm/README.md
index 22765f52..8f62fc57 100644
--- a/examples/xglm/README.md
+++ b/examples/xglm/README.md
@@ -1,18 +1,27 @@
 # How to use XGLM?
 
 1. First, make sure to convert the weights from huggingface, for instance:
-    ```
-    torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M
-    ```
+```bash
+torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M
+```
 
-1. Now you are ready to use XGLM.
+2. Now you are ready to use XGLM.
    Make sure you use a .yaml configuration with proper GPT3 config and then run for instance:
-    ```
-    torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
-    ```
+```bash
+torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
+```
    If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`.
 
-1. If you want to convert your finetuned checkpoint back to huggingface use:
-    ```
-    torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
-    ```
+3. If you want to convert your finetuned checkpoint back to huggingface use:
+```bash
+torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
+```
+
+## Sparse Upcycling
+
+To create a sparse model from a dense model, you can use the `convert_dense2moe.py` script that goes from a GPT3 Nanotron model to a GPT3 MoE Nanotron model. For instance:
+```bash
+cd examples/xglm
+torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8
+```
+Note that this upcycling _drops_ the bias parameters of the MLP because the MegaBlocks implementation does not support bias parameters. While this is a limitation of the current implementation, the performance is quickly recovered after a few training steps.
diff --git a/examples/xglm/convert_dense2moe.py b/examples/xglm/convert_dense2moe.py
new file mode 100644
index 00000000..fa4d9af7
--- /dev/null
+++ b/examples/xglm/convert_dense2moe.py
@@ -0,0 +1,179 @@
+"""
+Converts a dense Nanotron GPT3 checkpoint into a Nanotron GPT3 MoE checkpoint (sparse upcycling).
+Command:
+    torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=nanotron_weights --save-path=nanotron_moe_weights --num-experts=8
+"""
+
+import dataclasses
+import json
+import warnings
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import Optional
+
+from torch import nn
+import torch
+import nanotron
+from nanotron.config.models_config import GPT3Config, GPT3MoEConfig
+from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock
+from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
+from nanotron.trainer import mark_tied_parameters
+
+from convert_utils import convert_generic, create_nt_model
+
+
+def convert_config(config: GPT3Config, num_experts=8) -> GPT3MoEConfig:
+    return GPT3MoEConfig(
+        **config.__dict__,
+        is_moe=True,
+        moe_num_experts=num_experts,
+        num_experts_per_tok=min(2, num_experts),  # arbitrarily chosen
+        moe_loss_weight=0.01,  # arbitrarily chosen
+        moe_z_loss_weight=0.001,  # arbitrarily chosen
+        moe_glu=False,
+    )
+
+
+def convert_dense_to_moe(ff_moe: nn.Module, dense_ff: nn.Module, num_experts: int):
+    with torch.no_grad():
+        # only copy the weight matrix and repeat it n_expert times
+        weight_1 = dense_ff.c_fc.weight.clone()
+        if num_experts == 1:
+            ff_moe.experts.mlp.w1.module.weight.data = weight_1.contiguous()
+        else:
+            # [intermediate_size, hidden_size] -> [hidden_size, intermediate_size * n_experts]
+            weight_1 = weight_1.T
+            ff_moe.experts.mlp.w1.module.weight.data = weight_1.repeat(1, num_experts)
+
+        weight_2 = dense_ff.c_proj.weight.clone()
+        if num_experts == 1:  # just a specific case for 1 expert
+            ff_moe.experts.mlp.w2.module.weight.data = weight_2.contiguous()
+        else:
+            # [hidden_size, intermediate_size] -> [intermediate_size * n_experts, hidden_size]
+            weight_2 = weight_2.T
+            ff_moe.experts.mlp.w2.module.weight.data = weight_2.repeat(num_experts, 1)
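+
+        # e.g. with hidden_size=1024 and intermediate_size=4096 (xglm-564M) and 8 experts,
+        # w1 ends up with shape [1024, 4096 * 8] and w2 with [4096 * 8, 1024], i.e. every
+        # expert starts out as an exact copy of the dense MLP (minus its biases).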
+
+        # -- we could copy the bias for the 2nd layer only, since that is supported by the
+        # -- MegaBlocks MoE implementation, but it likely won't make a big difference:
+        # ff_moe.experts.bias.copy_(dense_ff.c_proj.bias)
+
+        # init gating randomly
+        nn.init.normal_(ff_moe.gate.layer.weight, mean=0.0, std=0.02)
+
+
+def convert_decoder(block_moe: GPT3MoEBlock, block_nt: GPTBlock, num_experts: int):
+    convert_generic(block_moe.ln_1, block_nt.ln_1)
+    convert_generic(block_moe.attn, block_nt.attn)
+    convert_generic(block_moe.ln_2, block_nt.ln_2)
+    convert_dense_to_moe(block_moe.ff, block_nt.ff, num_experts)
+
+
+def convert(
+    model_moe: GPT3MoEForTraining, model_dense: GPT3ForTraining, num_experts: int
+):
+    convert_generic(
+        model_moe.model.token_embeddings.pp_block.token_embedding,
+        model_dense.model.token_embeddings.pp_block.token_embedding,
+    )
+    for layer_moe, layer_nt in zip(model_moe.model.decoder, model_dense.model.decoder):
+        convert_decoder(layer_moe.pp_block, layer_nt.pp_block, num_experts)
+    convert_generic(
+        model_moe.model.final_layer_norm.pp_block,
+        model_dense.model.final_layer_norm.pp_block,
+    )
+    convert_generic(
+        model_moe.model.lm_head.pp_block, model_dense.model.lm_head.pp_block
+    )
+
+
+def create_nt_moe_model(
+    model_config: Optional[GPT3MoEConfig] = None,
+    device: torch.device = torch.device("cuda"),
+    dtype: torch.dtype = torch.bfloat16,
+    checkpoint_path: Optional[Path] = None,
+):
+
+    if model_config is None:
+        assert checkpoint_path is not None
+        with open(checkpoint_path / "model_config.json") as f:
+            model_config = GPT3MoEConfig(**json.load(f))
+
+    parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1)
+    parallel_context = nanotron.parallel.ParallelContext(
+        data_parallel_size=parallel_config.dp,
+        pipeline_parallel_size=parallel_config.pp,
+        tensor_parallel_size=parallel_config.tp,
+    )
+    model_nt = nanotron.models.build_model(
+        model_builder=lambda: GPT3MoEForTraining(
+            config=model_config,
+            parallel_context=parallel_context,
+            parallel_config=parallel_config,
+            random_states=None,
+        ),
+        parallel_context=parallel_context,
+        dtype=dtype,
+        device=device,
+    )
+    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)
+
+    if checkpoint_path is not None:
+        nanotron.serialize.load_weights(
+            model=model_nt,
+            parallel_context=parallel_context,
+            root_folder=checkpoint_path,
+        )
+
+    return model_nt
+
+
+def main(
+    checkpoint_path: Path,
+    save_path: Path,
+    num_experts: int,
+):
+    # Load nanotron model.
+    model_dense = create_nt_model(checkpoint_path=checkpoint_path)
+
+    # Init moe model.
+    model_config_moe = convert_config(model_dense.config, num_experts)
+    model_moe = create_nt_moe_model(model_config=model_config_moe)
+
+    convert(model_moe, model_dense, num_experts)
+    nanotron.serialize.save_weights(
+        model=model_moe,
+        parallel_context=model_moe.parallel_context,
+        root_folder=save_path,
+    )
+    with open(save_path / "model_config.json", "w+") as f:
+        json.dump(dataclasses.asdict(model_config_moe), f)
+    print(f"Model saved to {save_path}")
+
+
+if __name__ == "__main__":
+    # fix all random seeds
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    torch.backends.cudnn.deterministic = True
+    parser = ArgumentParser(description="Convert dense weights to moe format")
+    parser.add_argument(
+        "--checkpoint-path",
+        type=Path,
+        default="checkpoints/xglm-7.5B",
+        help="Path to the nanotron dense checkpoint",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=Path,
+        default="checkpoints/xglm-moe-7.5B",
+        help="Path to save the nanotron moe model",
+    )
+    parser.add_argument(
+        "--num-experts",
+        type=int,
+        default=8,
+        help="Number of experts in the MoE model (duplicates of MLP layer)",
+    )
+    args = parser.parse_args()
+    main(args.checkpoint_path, args.save_path, args.num_experts)
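
A note on what the conversion preserves: after upcycling, every expert's weights are an exact copy of the dense MLP weights, so any single expert reproduces the dense MLP's output up to the dropped biases. Below is a minimal, self-contained sketch of that invariant in plain PyTorch (no nanotron or MegaBlocks imports; the sizes and the GELU activation are illustrative assumptions, and the flattened `w1`/`w2` layout follows the shape comments in `convert_dense_to_moe`):
```python
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, intermediate_size, num_experts = 16, 64, 4

# Stand-ins for the dense c_fc / c_proj projections; biases are omitted to mirror
# the conversion, which drops them.
c_fc = nn.Linear(hidden_size, intermediate_size, bias=False)
c_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

with torch.no_grad():
    # Flattened expert weights, laid out as in convert_dense_to_moe:
    #   w1: [hidden_size, intermediate_size * num_experts]
    #   w2: [intermediate_size * num_experts, hidden_size]
    w1 = c_fc.weight.T.repeat(1, num_experts)
    w2 = c_proj.weight.T.repeat(num_experts, 1)

    x = torch.randn(8, hidden_size)
    dense_out = c_proj(F.gelu(c_fc(x)))

    # Routing a token through any single expert reproduces the dense MLP output.
    for e in range(num_experts):
        w1_e = w1[:, e * intermediate_size : (e + 1) * intermediate_size]
        w2_e = w2[e * intermediate_size : (e + 1) * intermediate_size, :]
        expert_out = F.gelu(x @ w1_e) @ w2_e
        torch.testing.assert_close(expert_out, dense_out)
    print("all experts match the dense MLP")
```
Because all experts start out identical, the randomly initialized gate only chooses between equal copies at first; the experts differentiate over the course of training.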