diff --git a/3.test_cases/XX.transformer-engine/.gitignore b/3.test_cases/XX.transformer-engine/.gitignore new file mode 100644 index 00000000..86692abd --- /dev/null +++ b/3.test_cases/XX.transformer-engine/.gitignore @@ -0,0 +1,2 @@ +checkpoints +slurm-*.out diff --git a/3.test_cases/XX.transformer-engine/0.transformer-engine.dockerfile b/3.test_cases/XX.transformer-engine/0.transformer-engine.dockerfile new file mode 100644 index 00000000..6b8149d0 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/0.transformer-engine.dockerfile @@ -0,0 +1,227 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +FROM nvcr.io/nvidia/pytorch:24.04-py3 +ENV DEBIAN_FRONTEND=noninteractive + +# The three must-be-built packages. +# Efa-installer>=1.29.1 required for nccl>=2.19.0 to avoid libfabric NCCL error. +ENV EFA_INSTALLER_VERSION=1.30.0 +ENV AWS_OFI_NCCL_VERSION=1.8.1-aws +ENV NCCL_TESTS_VERSION=master + +## Uncomment below when this Dockerfile builds a container image with efa-installer<1.29.1 and +# nccl>=2.19.0. See https://github.com/aws-samples/awsome-distributed-training/tree/main/1.architectures/efa-cheatsheet.md +#ENV FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +RUN apt-get update -y +RUN apt-get remove -y --allow-change-held-packages \ + libmlx5-1 ibverbs-utils libibverbs-dev libibverbs1 + +# We noticed that since 23.09, we can't just delete the whole /opt/hpcx/, otherwise `import torch` +# complains about missing libuc?.so. +RUN rm -rf /opt/hpcx/ompi \ + && rm -rf /usr/local/mpi \ + && rm -rf /opt/hpcx/nccl_rdma_sharp_plugin \ + && ldconfig +ENV OPAL_PREFIX= +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + git \ + gcc \ + vim \ + kmod \ + openssh-client \ + openssh-server \ + build-essential \ + curl \ + autoconf \ + libtool \ + gdb \ + automake \ + cmake \ + apt-utils \ + libhwloc-dev \ + aptitude && \ + DEBIAN_FRONTEND=noninteractive apt autoremove -y + +# EFA +RUN apt-get update && \ + cd /tmp && \ + curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \ + tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \ + cd aws-efa-installer && \ + # ONLY add `--skip-kmod`, `--no-verify` and `--skip-limit-conf` flags to container image. + # Those three flags must NOT be used on the host. + # + # Explanations: + # - to build EFA in the Dockerfile, we added --skip-kmod and --no-verify. Without these flags, + # the Dockerfile will fail to build. If installing EFA on the host and not in a container, + # please remove these flags. + # - The --skip-limit-conf can be retained in Dockerfile, but it's redundant as the host already + # has these limits set by efa_installer. + ./efa_installer.sh -y -g -d --skip-kmod --no-verify --skip-limit-conf && \ + ldconfig && \ + rm -rf /tmp/aws-efa-installer /var/lib/apt/lists/* +ENV LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH +ENV PATH=/opt/amazon/efa/bin:/opt/amazon/openmpi/bin:$PATH + + +#################################################################################################### +# [CUSTOM_NCCL_OPTION_1] Uncomment below stanza to install another NCCL version using the official +# binaries. +# +# NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building the +# aws-ofi-ccnl. 
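+#
+# Before overriding NCCL, it can help to confirm which version the base image already ships.
+# One illustrative way (not part of this build) to check from inside the running container:
+#   python -c "import torch; print(torch.cuda.nccl.version())"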
+#################################################################################################### +#ENV NCCL_VERSION=2.19.3-1 +#RUN cd /opt && \ +# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \ +# dpkg -i cuda-keyring_1.0-1_all.deb && \ +# apt update && \ +# apt install -y libnccl2==${NCCL_VERSION} libnccl-dev==${NCCL_VERSION} && \ +# echo NCCL_SOCKET_IFNAME=^docker0,lo >> /etc/nccl.conf + + +#################################################################################################### +# [CUSTOM_NCCL_OPTION_2] Install NCCL from source to the same location as the built-in ones. The +# benefits of installing to the same location as the built-in version are: +# +# 1. There's only ever a single libnccl version offered by this image, preventing application from +# mistakenly chooses a wrong version. +# 2. No longer needing extra settings for LD_LIBRARY_PATH or LD_PRELOAD. +# +# NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building the +# aws-ofi-ccnl. +#################################################################################################### +ENV NCCL_VERSION=2.19.3-1 +RUN apt-get remove -y libnccl2 libnccl-dev \ + && cd /tmp \ + && git clone https://github.com/NVIDIA/nccl.git -b v${NCCL_VERSION} \ + && cd nccl \ + && make -j src.build BUILDDIR=/usr \ + # Build for p4 & p5. + NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90, -gencode=arch=compute_80,code=sm_80" \ + && rm -rf /tmp/nccl \ + && echo NCCL_SOCKET_IFNAME=^docker0,lo >> /etc/nccl.conf + + +#################################################################################################### +# Rebuild OpenMPI with custom PMIX version. E.g., to match what host's Slurm is built with (see +# /opt/pmix/ on host, or run pmix_info on host). +# +# May be needed on rare occassions when `srun --mpi=pmix --container-image=... ` +# mysteriously crashes. +# +# NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building the +# aws-ofi-ccnl. +#################################################################################################### +ENV OPEN_MPI_PATH=/opt/amazon/openmpi + +# OpenMPI build script claims PMIX_VERSION, and complains if we use it. +ENV CUSTOM_PMIX_VERSION=4.2.6 +RUN apt-get update && apt-get install -y libevent-dev \ + && cd /tmp \ + && wget https://github.com/openpmix/openpmix/releases/download/v${CUSTOM_PMIX_VERSION}/pmix-${CUSTOM_PMIX_VERSION}.tar.gz \ + && tar -xzf pmix-${CUSTOM_PMIX_VERSION}.tar.gz \ + && rm pmix-${CUSTOM_PMIX_VERSION}.tar.gz \ + && cd pmix-${CUSTOM_PMIX_VERSION}/ \ + && ./autogen.pl \ + && ./configure --prefix=/opt/pmix \ + && make -j \ + && make install \ + && echo /opt/pmix/lib > /etc/ld.so.conf.d/pmix.conf \ + && ldconfig \ + && cd / \ + && rm -fr /tmp/pmix-${CUSTOM_PMIX_VERSION}/ +# To silence this runtime error message: +# [p4de-st-p4de-2:110912] PMIX ERROR: ERROR in file gds_ds12_lock_pthread.c at line 168 +ENV PMIX_GDS_MODULE=^ds12 \ + PMIX_MCA_gds=^ds12 + +# Rebuild openmpi with DLC style (which it remarks as "without libfabric"), with the above pmix. 
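+#
+# To check what the host side provides before picking CUSTOM_PMIX_VERSION above, these host-side
+# commands (run outside the container, assuming Slurm was built with PMIx support) are one option:
+#   srun --mpi=list                # lists the MPI/PMIx plugin types Slurm supports
+#   pmix_info | grep -i version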
+ENV OMPI_VERSION=4.1.6
+RUN rm -fr ${OPEN_MPI_PATH} \
+    && mkdir /tmp/openmpi \
+    && cd /tmp/openmpi \
+    && wget --quiet https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OMPI_VERSION}.tar.gz \
+    && tar zxf openmpi-${OMPI_VERSION}.tar.gz \
+    && rm openmpi-${OMPI_VERSION}.tar.gz \
+    && cd openmpi-${OMPI_VERSION} \
+    && ./configure --enable-orterun-prefix-by-default --prefix=$OPEN_MPI_PATH --with-cuda=${CUDA_HOME} --with-slurm --with-pmix=/opt/pmix \
+    && make -j $(nproc) all \
+    && make install \
+    && ldconfig \
+    && cd / \
+    && rm -rf /tmp/openmpi \
+    && ompi_info --parsable --all | grep mpi_built_with_cuda_support:value \
+    # Verify pmix from /opt/pmix/
+    && ldd /opt/amazon/openmpi/lib/openmpi/mca_pmix_ext3x.so | grep '/opt/pmix/lib/libpmix.so.* ' > /opt/amazon/openmpi-pmix.txt
+####################################################################################################
+
+
+# NCCL EFA Plugin
+RUN mkdir -p /tmp && \
+    cd /tmp && \
+    curl -LO https://github.com/aws/aws-ofi-nccl/archive/refs/tags/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
+    tar -xzf /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
+    rm /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
+    mv aws-ofi-nccl-${AWS_OFI_NCCL_VERSION} aws-ofi-nccl && \
+    cd /tmp/aws-ofi-nccl && \
+    ./autogen.sh && \
+    ./configure --prefix=/opt/amazon/efa \
+        --with-libfabric=/opt/amazon/efa \
+        --with-cuda=/usr/local/cuda \
+        --enable-platform-aws \
+        --with-mpi=/opt/amazon/openmpi && \
+    make -j$(nproc) install && \
+    rm -rf /tmp/aws-ofi-nccl
+
+# Do this to minimize the ld path env vars that users need to define when running this image.
+RUN echo "/usr/local/lib" >> /etc/ld.so.conf.d/local.conf && \
+    echo "/opt/amazon/openmpi/lib" >> /etc/ld.so.conf.d/efa.conf && \
+    ldconfig
+
+ENV OMPI_MCA_pml=^cm,ucx \
+    OMPI_MCA_btl=tcp,self \
+    OMPI_MCA_btl_tcp_if_exclude=lo,docker0 \
+    OPAL_PREFIX=/opt/amazon/openmpi \
+    # https://discuss.pytorch.org/t/nccl-network-is-unreachable-connection-refused-when-initializing-ddp/137352
+    # https://github.com/pytorch/pytorch/issues/68893
+    NCCL_SOCKET_IFNAME=^docker,lo
+
+ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
+
+# NCCL-tests: always good to include this as a diagnostic tool.
+RUN git clone https://github.com/NVIDIA/nccl-tests.git /opt/nccl-tests \
+    && cd /opt/nccl-tests \
+    && git checkout ${NCCL_TESTS_VERSION} \
+    && make MPI=1 \
+       MPI_HOME=/opt/amazon/openmpi \
+       CUDA_HOME=/usr/local/cuda \
+       NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_80,code=sm_80"
+
+
+####################################################################################################
+# Custom packages. Disable as you like. NOTE: always check `pip list` to see what's already
+# installed. For example, the base container comes pre-installed with Transformer Engine, flash
+# attention, triton (https://github.com/openai/triton/), etc.
+####################################################################################################
+# Install the xformers dependency from source, because pip install either breaks or tries to pull
+# its own PyTorch + CUDA.
+#
+# Prerequisite: the build node has enough memory to compile xformers. More info in the stanza below.
+RUN export TORCH_CUDA_ARCH_LIST="8.0;9.0+PTX" && \
+    # On p4de.24xlarge:
+    # - MAX_JOBS=16 => 145GB memory
+    # - MAX_JOBS=32 => 241GB memory
+    # - MAX_JOBS=48 => 243GB memory, 542.5s
+    #
+    # NOTE: must export MAX_JOBS. For some reason, `MAX_JOBS=16 pip install ...` doesn't seem to
+    # work to prevent OOM.
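+    #
+    # If the build host has less memory than a p4de.24xlarge, a smaller MAX_JOBS is the safer
+    # assumption (the figures above work out to roughly 8-9 GB of RAM per job), at the cost of a
+    # longer build.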
+ export MAX_JOBS=32 && \ + export NVCC_PREPEND_FLAGS="-t 32" && \ + pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers + +RUN pip install transformers datasets + +WORKDIR "/fsx" diff --git a/3.test_cases/XX.transformer-engine/1.train_llama.sbatch b/3.test_cases/XX.transformer-engine/1.train_llama.sbatch new file mode 100755 index 00000000..a0cf35b2 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/1.train_llama.sbatch @@ -0,0 +1,95 @@ +#!/bin/bash + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +#SBATCH --nodes=2 # number of nodes to use +#SBATCH --job-name=LlamaFP8 # name of your job +#SBATCH --exclusive # job has exclusive use of the resource, no sharing + +set -ex; + +########################### +###### User Variables ##### +########################### + +GPUS_PER_NODE=8 # 4 for G5.12x, 8 for P4/P5 + +########################### +## Environment Variables ## +########################### + +## Plenty of EFA level variables +## Comment out for non-efa instances (G4d, P3) +## For G5.12x, Comment out RDMA and Fork safe +## For G4dn and other G5, comment out all +export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_PROVIDER=efa +export NCCL_DEBUG=INFO +## Switching SYNC_MEMOPS to zero can boost throughput with FSDP +## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS +## Reduces memory synchronizations +## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html +export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +# default variables for Enroot +: "${IMAGE:=$(pwd)/transformer-engine.sqsh}" +: "${DATA_PATH:=/fsx}" +: "${FSX_MOUNT:=$(pwd):$DATA_PATH}" + +declare -a ARGS=( + --container-image $IMAGE + --container-mounts $FSX_MOUNT +) + +########################### +####### Torch Dist ####### +########################### + +declare -a TORCHRUN_ARGS=( + --nproc_per_node=$GPUS_PER_NODE + --nnodes=$SLURM_JOB_NUM_NODES + --rdzv_id=$SLURM_JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=$(hostname) +) + +export TORCHRUN=torchrun +export TRAIN_SCRIPT=./train.py + +############################ +# Llama 2 Training Params ## +############################ + +declare -a TRAINING_ARGS=( + --max_context_width=4096 + --num_key_value_heads=32 # 7b: 32 13b: 40 70b: 8 + --intermediate_size=11008 # 7b: 11008 13b: 13824 70b: 28672 + --hidden_width=4096 # 7b: 4096 13b: 5120 70b: 8192 + --num_layers=32 # 7b: 32 13b: 40 70b: 80 + --num_heads=32 # 7b: 32 13b: 40 70b: 64 + --model_type=llama_v2 + --tokenizer="hf-internal-testing/llama-tokenizer" + --checkpoint_freq=5000 + --validation_freq=100 + --max_steps=5000 + --checkpoint_dir=./checkpoints + --dataset='c4' + --dataset_config_name='en' + --resume_from_checkpoint=./checkpoints + --train_batch_size=1 + --val_batch_size=1 + --sharding_strategy="full" # https://pytorch.org/docs/stable/fsdp.html + --offload_activations=1 + --fp8=1 +) + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. 
enabling --auto-resume=1"
+    AUTO_RESUME="--auto-resume=1"
+fi
+
+srun ${AUTO_RESUME} -l "${ARGS[@]}" torchrun "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}"
diff --git a/3.test_cases/XX.transformer-engine/README.md b/3.test_cases/XX.transformer-engine/README.md
new file mode 100644
index 00000000..525dd2cc
--- /dev/null
+++ b/3.test_cases/XX.transformer-engine/README.md
@@ -0,0 +1,78 @@
+# Llama FSDP training with FP8 (NVIDIA Transformer Engine)
+
+[Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including 8-bit floating point (FP8) precision on Hopper GPUs, providing better performance with lower memory utilization in both training and inference.
+
+This example trains the [Hugging Face implementation](https://huggingface.co/docs/transformers/en/model_doc/llama2) of [Meta's Llama model](https://llama.meta.com/) with Fully Sharded Data Parallel using FP8. It combines all the [Llama 2 optimizations described in the Transformer Engine documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/te_llama/tutorial_accelerate_hf_llama_with_te.html) with the popular distributed training approach [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html).
+
+This example is derived from the FSDP example, with the following changes:
+* Run the training code in the nvidia/pytorch Docker container (required by Transformer Engine)
+* Restrict the training code to the Llama model only
+* Use [the optimized version of the Llama model](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py)
+* Apply Transformer Engine fp8_autocast
+
+This guide assumes that you have the following:
+
+- A functional Slurm cluster on AWS.
+- Docker, [Pyxis](https://github.com/NVIDIA/pyxis) and [Enroot](https://github.com/NVIDIA/enroot) installed.
+- An FSx for Lustre filesystem mounted on `/fsx`.
+
+It is recommended that you use the templates in the architectures [directory](../../1.architectures).
+
+You will also set up the following variable in your terminal environment.
+
+```bash
+export DATA_PATH=/fsx # FSx for Lustre shared file-system
+```
+
+Make sure that your current directory is under a shared filesystem such as `/fsx/` or the home directory when using [Parallel Cluster](../../1.architectures/aws-parallelcluster).
+
+1. Copy the file `0.transformer-engine.dockerfile` or its content to your head node.
+2. Build the container image with the command below:
+
+```bash
+docker build -t transformer-engine -f 0.transformer-engine.dockerfile .
+```
+
+3. Once the image is built, you can check if it is present with `docker images`. You should see an output similar to this one:
+
+```
+[ec2-user@ip-10-0-10-78 ~]$ docker images
+REPOSITORY           TAG      IMAGE ID       CREATED         SIZE
+transformer-engine   latest   91dbebf98269   9 seconds ago   22.6GB
+```
+
+4. Create the squash file with the command below.
+
+```bash
+enroot import -o transformer-engine.sqsh dockerd://transformer-engine:latest
+```
+
+5. Copy the file `1.train_llama.sbatch` to your cluster, then submit a training job with the command below:
+
+```bash
+sbatch 1.train_llama.sbatch
+```
+
+6. You will see a new file in your current working directory called `slurm-XY.out` where `XY` is a number. This is your output file; it captures the `STDOUT` and `STDERR` from your job. You can check how the job progresses with `tail -f slurm-XY.out`, substituting the relevant filename.
The file content will be similar to the following:
+
+```
+0: 2024-05-15 02:37:59 I [train.py:110] Batch 20 Loss: 8.42184, Speed: 47.06 samples/sec, lr: 0.000100
+0: 2024-05-15 02:37:59 I [train.py:110] Batch 21 Loss: 8.26941, Speed: 47.30 samples/sec, lr: 0.000100
+0: 2024-05-15 02:37:59 I [train.py:110] Batch 22 Loss: 8.19849, Speed: 47.12 samples/sec, lr: 0.000100
+0: 2024-05-15 02:38:00 I [train.py:110] Batch 23 Loss: 7.74492, Speed: 46.86 samples/sec, lr: 0.000100
+0: 2024-05-15 02:38:00 I [train.py:110] Batch 24 Loss: 8.46525, Speed: 47.35 samples/sec, lr: 0.000100
+0: 2024-05-15 02:38:00 I [train.py:110] Batch 25 Loss: 7.60201, Speed: 47.38 samples/sec, lr: 0.000100
+```
+
+7. Change `--fp8=1` to `--fp8=0` in `1.train_llama.sbatch` to turn off Transformer Engine FP8 precision and rerun `sbatch 1.train_llama.sbatch`. The output will look similar to the following:
+
+```
+0: 2024-05-15 02:44:55 I [train.py:110] Batch 20 Loss: 8.82996, Speed: 30.46 samples/sec, lr: 0.000100
+0: 2024-05-15 02:44:55 I [train.py:110] Batch 21 Loss: 8.17265, Speed: 30.71 samples/sec, lr: 0.000100
+0: 2024-05-15 02:44:56 I [train.py:110] Batch 22 Loss: 7.92729, Speed: 30.63 samples/sec, lr: 0.000100
+0: 2024-05-15 02:44:56 I [train.py:110] Batch 23 Loss: 7.75582, Speed: 30.64 samples/sec, lr: 0.000100
+0: 2024-05-15 02:44:57 I [train.py:110] Batch 24 Loss: 8.72653, Speed: 30.56 samples/sec, lr: 0.000100
+0: 2024-05-15 02:44:57 I [train.py:110] Batch 25 Loss: 7.79590, Speed: 30.78 samples/sec, lr: 0.000100
+```
+
+As the logs show, Transformer Engine FP8 precision gives a throughput speedup of more than 50% (roughly 47 vs. 30 samples/sec).
diff --git a/3.test_cases/XX.transformer-engine/model_utils/__init__.py b/3.test_cases/XX.transformer-engine/model_utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/3.test_cases/XX.transformer-engine/model_utils/arguments.py b/3.test_cases/XX.transformer-engine/model_utils/arguments.py
new file mode 100644
index 00000000..edd89495
--- /dev/null
+++ b/3.test_cases/XX.transformer-engine/model_utils/arguments.py
@@ -0,0 +1,167 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: MIT-0
+
+import argparse
+import os
+
+
+def parse_args():  # pylint: disable=too-many-statements
+    """Parse args."""
+    parser = argparse.ArgumentParser()
+
+    # hyperparameters sent by the client are passed as command-line arguments to the script.
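+    # Note: this function returns parser.parse_known_args(), i.e. an (args, unknown_args) tuple,
+    # so callers unpack it as `args, _ = parse_args()` and unrecognized flags are ignored rather
+    # than raising an error.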
+ + opt_grp = parser.add_argument_group( + title="optimization", description="arguments for optimization" + ) + opt_grp.add_argument( + "--train_batch_size", + type=int, + default=2, + help="batch size per dp rank", # pylint: disable=line-too-long + ) + opt_grp.add_argument("--val_batch_size", type=int, default=4) + opt_grp.add_argument("--max_steps", "--max_training_steps", type=int, default=5000) + opt_grp.add_argument("--seed", type=int, default=12345) + opt_grp.add_argument("--same_seed", type=int, default=0) + opt_grp.add_argument("--bf16", default=1, type=int, help="automatic mixed precision training with bf16") + opt_grp.add_argument("--fp8", default=0, type=int, help="automatic mixed precision training with fp8") + opt_grp.add_argument("--grad_clip", default=1.0, type=float, help="gradient clipping") + opt_grp.add_argument("--weight_decay", default=0.2, type=float, help="weight decay") + opt_grp.add_argument( + "--beta1", default=0.9, type=float, help="beta1 parameter for Adam optimizer" + ) + opt_grp.add_argument( + "--beta2", default=0.95, type=float, help="beta2 parameter for Adam optimizer" + ) + opt_grp.add_argument( + "--activation_checkpointing", + type=int, + default=1, + help="enable gradient checkpointing to reduce memory consumption", + ) + opt_grp.add_argument( + "--intermediate_size", + type=int, + default=11008, + help="intermediate_size, a dimension associated with MLP", + ) + opt_grp.add_argument( + "--num_key_value_heads", + type=int, + default=None, + help="num_key_value_heads for GQA", + ) + parser.add_argument( + "--logging_freq", type=int, default=1, help="number of iterations between logging" + ) + parser.add_argument("--tensorboard_dir", type=str, nargs="+", default=None) + + model_grp = parser.add_argument_group( + title="model", description="arguments to describe model configuration" + ) + model_grp.add_argument("--max_context_width", type=int, default=2048) + model_grp.add_argument("--vocab_size", type=int, default=32000) + model_grp.add_argument("--hidden_width", type=int, default=4096) + model_grp.add_argument("--num_layers", type=int, default=32) + model_grp.add_argument("--num_heads", type=int, default=32) + model_grp.add_argument("--resid_pdrop", type=float, default=0.1) + model_grp.add_argument("--embd_pdrop", type=float, default=0.1) + model_grp.add_argument("--attn_pdrop", type=float, default=0.1) + model_grp.add_argument("--summary_first_pdrop", type=float, default=0.1) + model_grp.add_argument("--initializer_range", type=float, default=0.02) + model_grp.add_argument("--model_type", type=str, default="gpt_neox") + model_grp.add_argument("--rotary_pct", type=float, default=0.25) + model_grp.add_argument("--rotary_emb_base", type=int, default=10000) + + fsdp_grp = parser.add_argument_group( + title="fsdp", description="arguments for fully sharded data parallel" + ) + fsdp_grp.add_argument("--offload_activations", type=int, default=0) + fsdp_grp.add_argument("--activation_loading_horizon", type=int, default=2) + fsdp_grp.add_argument("--limit_all_gathers", default=1, type=int) + fsdp_grp.add_argument( + "--sharding_strategy", + type=str, + default="full", + choices=["full", "hybrid"], + help="FSDP sharding strategy https://pytorch.org/docs/stable/fsdp.html", + ) + + # learning rate + lr_grp = parser.add_argument_group( + title="lr", description="arguments for learning rate schedule" + ) + lr_grp.add_argument("--lr", type=float, default=0.0001, help="Initial learning rate.") + lr_grp.add_argument( + "--lr_decay_style", + type=str, + default="cosine", 
+ choices=["constant", "linear", "cosine", "exponential", "plateau"], + help="Learning rate decay function.", + ) + lr_grp.add_argument( + "--lr_decay_iters", + type=int, + default=None, + help="number of iterations to decay learning rate over," " If None defaults to train iters", + ) + lr_grp.add_argument( + "--min_lr", + type=float, + default=1e-05, + help="Minumum value for learning rate. The scheduler" "clip values below this threshold.", + ) + lr_grp.add_argument( + "--warmup", + type=float, + default=0.0032, + help="Percentage of total iterations to warmup on " + "(.01 = 1 percent of all training iters).", + ) + lr_grp.add_argument( + "--plateau", + type=float, + default=0.0, + help="Percentage of total iterations to keep at max if using plateau lr", + ) + io_grp = parser.add_argument_group(title="io", description="location for input and output") + io_grp.add_argument("--dataset", type=str, default="c4") + io_grp.add_argument("--dataset_config_name", type=str, default=None) + io_grp.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") + io_grp.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="Checkpoint folder name to load from", + ) + io_grp.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="Saves partial checkpoints (model, optimizer) to this dir.", # pylint: disable=line-too-long + ) + io_grp.add_argument( + "--epochs", type=int, default=3, help="times of iterating over the training dataset" + ) + + parser.add_argument( + "--checkpoint_freq", + type=int, + default=1000, + help="number of iterations between checkpointing", + ) + parser.add_argument( + "--validation_freq", + type=int, + default=None, + help="number of iterations to print validation loss", + ) + parser.add_argument( + "--validation_batches", + type=int, + default=10, + help="number of batches to estimate validation loss", + ) + + return parser.parse_known_args() diff --git a/3.test_cases/XX.transformer-engine/model_utils/checkpoint.py b/3.test_cases/XX.transformer-engine/model_utils/checkpoint.py new file mode 100644 index 00000000..e7576132 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/model_utils/checkpoint.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: MIT-0 + +import os +import re +import pickle +import statistics +import time +import warnings +from pathlib import Path + +import torch +import torch.distributed as dist + +# pylint: disable=import-error,no-name-in-module +import torch.distributed.checkpoint as dist_cp +from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from model_utils.train_utils import get_logger + + +logger = get_logger() + +def save_checkpoint(model, optimizer, scheduler, user_content, root_dir, sub_dir): + torch.cuda.empty_cache() + + save_dir = os.path.join(root_dir, sub_dir) + if dist.get_rank() == 0: + logger.info("Writing checkpoint to {0}.".format(save_dir)) + + with FSDP.state_dict_type( + model, + StateDictType.SHARDED_STATE_DICT): + state_dict = { + "model": model.state_dict(), + "optim": FSDP.optim_state_dict(model, optimizer), + "scheduler": scheduler.state_dict(), + "total_steps": user_content["total_steps"], + "start_batch_index": user_content["start_batch_index"], + } + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(save_dir) + ) + dist.barrier() + if dist.get_rank() == 0: + logger.info("Completed checkpoint.") + +def get_last_checkpoint(checkpoint_paths): + steps = [int(re.findall(r'\d+steps', checkpoint.stem)[0].replace('steps','')) \ + for checkpoint in checkpoint_paths] + checkpoints = sorted([(step, path) for step,path in zip(steps, checkpoint_paths)]) + + # find last checkpoint, skipping incomplete ones + for step, path in reversed(checkpoints): + metadata_path = path.joinpath(".metadata") + if not metadata_path.exists(): + logger.warn(f"{metadata_path} not found. 
Skipping this incomplete checkpoint") + continue + return path.as_posix() + else: + return None + +def load_checkpoint(model, optimizer, scheduler, checkpoint_dir, model_type, device): + checkpoint_paths = list(Path(checkpoint_dir).glob(f"{model_type}-*steps")) + last_checkpoint = get_last_checkpoint(checkpoint_paths) + if last_checkpoint is None: + if dist.get_rank() == 0: + logger.info("No Checkpoints Found") + return( + model, + optimizer, + scheduler, + 0, + 0, + ) + if dist.get_rank() == 0: + logger.info("Loading checkpoint from %s ...", last_checkpoint) + with FSDP.state_dict_type( + model, + StateDictType.SHARDED_STATE_DICT, + ): + state_dict = { + "model": model.state_dict(), + "scheduler": scheduler.state_dict(), + "total_steps": 0, + "start_batch_index": 0, + # cannot load the optimizer state_dict together with the model state_dict + } + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(last_checkpoint), + ) + model.load_state_dict(state_dict["model"]) + scheduler.load_state_dict(state_dict["scheduler"]) + if dist.get_rank() == 0: + logger.info("Loaded model state from disk") + logger.info("Loading optimizer state from disk") + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=state_dict["model"], + optimizer_key="optim", + storage_reader=dist_cp.FileSystemReader(last_checkpoint), + ) + if dist.get_rank() == 0: + logger.info("Loaded and sharded optimizer state from disk") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + # UserWarning to replace all_gather_base with all_gather_into_tensor floods the logs + flattened_osd = FSDP.optim_state_dict_to_load( + model, optimizer, optim_state["optim"] + ) + + if dist.get_rank() == 0: + logger.info("Converted optimizer state dict for FSDP") + optimizer.load_state_dict(flattened_osd) + dist.barrier() + if dist.get_rank() == 0: + logger.info("Checkpoint loaded from %s.", last_checkpoint) + return ( + model, + optimizer, + scheduler, + state_dict["total_steps"], + state_dict["start_batch_index"], + ) diff --git a/3.test_cases/XX.transformer-engine/model_utils/concat_dataset.py b/3.test_cases/XX.transformer-engine/model_utils/concat_dataset.py new file mode 100644 index 00000000..687f683a --- /dev/null +++ b/3.test_cases/XX.transformer-engine/model_utils/concat_dataset.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: MIT-0 + +import os +import numpy as np +import datasets as hf_datasets +from torch.utils.data import IterableDataset +from typing import Dict, Iterable, Union +from transformers import PreTrainedTokenizerBase + +class ConcatTokensDataset(IterableDataset): + def __init__( + self, + hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + wrap: bool, + ): + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + self.hf_dataset = hf_dataset + self.tokenizer = tokenizer + self.max_length = max_length + self.should_wrap = wrap + + def __iter__(self) -> Iterable[Dict[str, bytes]]: + + buffer = [] + mask_buffer = [] + for sample in self.hf_dataset: + encoded = self.tokenizer(sample['text'], + truncation=False, + padding=False) + iids = encoded['input_ids'] + mask = encoded['attention_mask'] + buffer = buffer + iids + [self.tokenizer.eos_token_id] + mask_buffer = mask_buffer + mask + [1] + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self.max_length:] if self.should_wrap else [] + concat_sample_mask = mask_buffer[:self.max_length] + mask_buffer = mask_buffer[self.max_length:] if self.should_wrap else [] + yield np.array(concat_sample) diff --git a/3.test_cases/XX.transformer-engine/model_utils/train_utils.py b/3.test_cases/XX.transformer-engine/model_utils/train_utils.py new file mode 100644 index 00000000..d4d58e30 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/model_utils/train_utils.py @@ -0,0 +1,395 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +import os +import math +import functools +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from datetime import datetime +import tqdm +import logging +from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy +from transformers import AutoTokenizer +from datasets import load_dataset + +from model_utils.concat_dataset import ConcatTokensDataset + +from transformer_engine.pytorch import checkpoint as te_checkpoint + +g_gigabyte = 1024**3 + +def setup(): + # initialize the process group + dist.init_process_group("nccl") + + +def cleanup(): + dist.destroy_process_group() + +def get_date_of_run(): + """create date and time for file save uniqueness + example: 2022-05-07-08:31:12_PM' + """ + date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") + print(f"--> current date and time of run = {date_of_run}") + return date_of_run + + + +def format_metrics_to_gb(item): + """quick function to format numbers to gigabyte and round to 4 digit precision""" + metric_num = item / g_gigabyte + metric_num = round(metric_num, ndigits=4) + return metric_num + +def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None): + model.train() + local_rank = int(os.environ['LOCAL_RANK']) + fsdp_loss = torch.zeros(2).to(local_rank) + + if sampler: + sampler.set_epoch(epoch) + if rank==0: + inner_pbar = tqdm.tqdm( + range(len(train_loader)), colour="blue", desc="r0 Training Epoch" + ) + for batch in train_loader: + for key in batch.keys(): + batch[key] = batch[key].to(local_rank) + optimizer.zero_grad() + output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] ) + loss = output["loss"] + loss.backward() + optimizer.step() + fsdp_loss[0] += loss.item() + fsdp_loss[1] += len(batch) + if rank==0: + inner_pbar.update(1) 
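+    # fsdp_loss holds a local [loss sum, count] pair; the all_reduce below sums both entries
+    # across ranks so every rank computes the same globally averaged training loss.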
+ + dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) + train_accuracy = fsdp_loss[0] / fsdp_loss[1] + + + if rank == 0: + inner_pbar.close() + print( + f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}" + ) + return train_accuracy + + +def validation(model, rank, world_size, val_loader): + model.eval() + correct = 0 + local_rank = int(os.environ['LOCAL_RANK']) + fsdp_loss = torch.zeros(2).to(local_rank) + if rank == 0: + inner_pbar = tqdm.tqdm( + range(len(val_loader)), colour="green", desc="Validation Epoch" + ) + with torch.no_grad(): + for batch in val_loader: + for key in batch.keys(): + batch[key] = batch[key].to(local_rank) + output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"]) + fsdp_loss[0] += output["loss"].item() # sum up batch loss + fsdp_loss[1] += len(batch) + + if rank==0: + inner_pbar.update(1) + + dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM) + val_loss = fsdp_loss[0] / fsdp_loss[1] + if rank == 0: + inner_pbar.close() + print(f"Validation Loss: {val_loss:.4f}") + return val_loss + +def get_model_config(args): + from transformers import LlamaConfig + + model_config = LlamaConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_width, + intermediate_size=args.intermediate_size, + num_hidden_layers=args.num_layers, + num_attention_heads=args.num_heads, + num_key_value_heads=args.num_key_value_heads, + hidden_act="silu", + max_position_embeddings=args.max_context_width, + initializer_range=args.initializer_range, + rms_norm_eps=1e-5, + use_cache=False, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + ) + return model_config + +def compute_num_params(model): + """Get num params.""" + num_params = 0 + seen = set() + for p in model.parameters(): # pylint: disable=invalid-name + if p not in seen: + seen.add(p) + if hasattr(p, "ds_shape"): + num_params += np.prod(p.ds_shape) + else: + num_params += np.prod(p.size()) + + return num_params + +_logger = None +def get_logger(): + global _logger + if _logger is None: + logging.getLogger("torch.distributed.checkpoint._dedup_tensors").setLevel(logging.ERROR) + logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.ERROR) + _logger = logging.getLogger(__name__) + _logger.setLevel(logging.INFO) + _logger.handlers = [] + ch = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(levelname).1s " "[%(filename)s:%(lineno)d] %(message)s", + "%Y-%m-%d %H:%M:%S", + ) + ch.setFormatter(formatter) + _logger.addHandler(ch) + _logger.propagate = False + return _logger + + +def get_sharding_strategy(strategy: str): + """Get sharding strategy.""" + sharding_strategy = getattr(ShardingStrategy, strategy.upper()) + _logger.debug("Translating %s to %s.", strategy, sharding_strategy) + return sharding_strategy + + +def get_backward_fetch_policy(policy: str): + """Get backward fetch policy.""" + backward_fetch_policy = getattr(BackwardPrefetch, policy.upper()) + _logger.debug("Translating %s to %s.", policy, backward_fetch_policy) + return backward_fetch_policy + +def apply_activation_checkpoint(args, model): + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, + ) + + transformer_layer = type(model.model.layers[0]) + check_fn_gpt = lambda submodule: isinstance( + submodule, transformer_layer + ) + entrant_wrapper = functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + checkpoint_fn=te_checkpoint 
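+        # te_checkpoint is transformer_engine.pytorch.checkpoint (imported above), so activation
+        # recomputation goes through TE's FP8-aware checkpoint helper when fp8 is enabled; with
+        # fp8 off, checkpoint_fn=None keeps the wrapper's default checkpoint function.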
if args.fp8 else None, + use_reentrant=False + ) + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=entrant_wrapper, check_fn=check_fn_gpt + ) + +def get_param_groups_by_weight_decay(module): + """Get param groups.""" + weight_decay_params = {"params": []} + no_weight_decay_params = {"params": [], "weight_decay": 0.0} + param_ids = set() + + from torch.nn import LayerNorm + + for module_ in module.modules(): + # if isinstance(module_, FusedLayerNorm) or + if isinstance(module_, LayerNorm): + for p in list( + module_._parameters.values() + ): # pylint: disable=invalid-name,protected-access + if p is not None and id(p) not in param_ids: + no_weight_decay_params["params"].append(p) + param_ids.add(id(p)) + else: + for n, p in list( + module_._parameters.items() + ): # pylint: disable=invalid-name,protected-access + if p is not None and n != "bias" and id(p) not in param_ids: + weight_decay_params["params"].append(p) + param_ids.add(id(p)) + for n, p in list( + module_._parameters.items() + ): # pylint: disable=invalid-name,protected-access + if p is not None and n == "bias" and id(p) not in param_ids: + no_weight_decay_params["params"].append(p) + param_ids.add(id(p)) + return weight_decay_params, no_weight_decay_params + +class AnnealingLR: # pylint: disable=too-many-instance-attributes + """Anneals the learning rate.""" + + def __init__( # pylint: disable=too-many-arguments + self, + optimizer, + start_lr, + warmup_iter, + plateau_iter, + total_iters, + decay_style, + last_iter, + min_lr=0.0, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False, + ): + + # Class values. + self.optimizer = optimizer + self.start_lr = start_lr + self.min_lr = min_lr + self.warmup_iter = warmup_iter + self.plateau_iter = plateau_iter + self.num_iters = last_iter + self.end_iter = total_iters + assert self.end_iter > 0 + self.decay_style = decay_style + self.override_lr_scheduler = override_lr_scheduler + self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler + if self.override_lr_scheduler: + assert not self.use_checkpoint_lr_scheduler, ( + "both override and " "use-checkpoint are set." + ) + # Set the learning rate + self.step(self.num_iters) + self.rank = dist.get_rank() + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) + # Warmup. 
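+        # Linear ramp from 0 to start_lr over the first warmup_iter steps; past warmup,
+        # num_iters_ below counts decay iterations only.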
+ if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: + return float(self.start_lr) * num_iters_ / self.warmup_iter + + num_iters_ = num_iters_ - self.warmup_iter + if self.decay_style == "linear": + lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter + elif self.decay_style == "plateau": + if self.num_iters <= self.plateau_iter: + lr = self.start_lr + else: + lr = ( + self.start_lr + * (self.end_iter - self.num_iters) + / (self.end_iter - self.plateau_iter) + ) + elif self.decay_style == "cosine": + lr = self.start_lr / 2.0 * (math.cos(math.pi * num_iters_ / self.end_iter) + 1) + elif self.decay_style == "exponential": + # exp(-0.693) = 1/2 + lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) + else: + lr = self.start_lr + return max(lr, self.min_lr) + + def step(self, step_num=None): + """Set lr for all parameters groups.""" + if step_num is None: + step_num = self.num_iters + 1 + self.num_iters = step_num + new_lr = self.get_lr() + for group in self.optimizer.param_groups: + group["lr"] = new_lr + + def state_dict(self): + """State dict.""" + state_dict = { + "start_lr": self.start_lr, + "warmup_iter": self.warmup_iter, + "num_iters": self.num_iters, + "decay_style": self.decay_style, + "end_iter": self.end_iter, + "min_lr": self.min_lr, + } + return state_dict + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_lr_scheduler: + if self.rank == 0: + _logger.info(f"Overriding {name} value to {cls_value}") + return cls_value + + if not self.use_checkpoint_lr_scheduler: + assert ( + cls_value == sd_value + ), f"AnnealingLR: class input value and checkpoint values for {name} do not match" + if self.rank == 0: + _logger.info(f" > using checkpoint value {sd_value} for {name}") + return sd_value + + def load_state_dict(self, sd): + """Load state dict.""" + self.start_lr = self._check_and_set(self.start_lr, sd["start_lr"], "learning rate") + self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate") + self.warmup_iter = self._check_and_set( + self.warmup_iter, sd["warmup_iter"], "warmup iterations" + ) + self.end_iter = self._check_and_set( + self.end_iter, sd["end_iter"], "total number of iterations" + ) + self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style") + + self.num_iters = sd["num_iters"] + self.step(self.num_iters) + +def get_learning_rate_scheduler(optimizer, args): + """Get learning rate scheduler.""" + use_checkpoint_lr_scheduler = args.resume_from_checkpoint is not None + + # Add linear learning rate scheduler. 
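+    # args.warmup and args.plateau are fractions of the decay horizon (lr_decay_iters if set,
+    # otherwise max_steps); they are converted into absolute iteration counts below before
+    # constructing the AnnealingLR schedule.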
+ if args.lr_decay_iters is not None: + num_iters = args.lr_decay_iters + else: + num_iters = args.max_steps + num_iters = max(1, num_iters) + init_step = 0 + warmup_iter = args.warmup * num_iters + plateau_iter = warmup_iter + args.plateau * num_iters + lr_scheduler = AnnealingLR( + optimizer, + start_lr=args.lr, + warmup_iter=warmup_iter, + plateau_iter=plateau_iter, + total_iters=num_iters, + decay_style=args.lr_decay_style, + last_iter=init_step, + min_lr=args.min_lr, + use_checkpoint_lr_scheduler=use_checkpoint_lr_scheduler, + override_lr_scheduler=False, + ) + + return lr_scheduler + +def create_streaming_dataloader(dataset, + tokenizer, + name, + global_rank, + batch_size, + max_context_width, + workers, + split=None): + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + data = load_dataset(dataset, name=name, streaming=True, split=split, trust_remote_code=True).shuffle(42+global_rank) + train_concat_dataset = ConcatTokensDataset(data, tokenizer, max_context_width, True) + train_dataloader = DataLoader(train_concat_dataset, + batch_size=batch_size, + num_workers=workers, + pin_memory=True, + prefetch_factor=4) + return train_dataloader diff --git a/3.test_cases/XX.transformer-engine/te_llama.py b/3.test_cases/XX.transformer-engine/te_llama.py new file mode 100644 index 00000000..a52f45d9 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/te_llama.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +from contextlib import contextmanager + +import torch +from torch import nn + +import transformer_engine as te +from transformer_engine.pytorch.attention import RotaryPositionEmbedding +from transformer_engine.pytorch.fp8 import fp8_model_init + +import transformers +from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig +from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.utils import WEIGHTS_INDEX_NAME +from transformers.utils.hub import get_checkpoint_shard_files + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`. + """ + original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer + transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls + + +class TELlamaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `LlamaDecoderLayer` and easier to replace it in the code. 
+ + Args: + config: LlamaConfig + args: positional args (for compatibility with `LlamaDecoderLayer`) + kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) + """ + def __init__(self, config, *args, **kwargs): + default_device = torch.tensor([]).device + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=False, + normalization="RMSNorm", + activation="swiglu", + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + device=default_device, + ) + te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) + self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() + + def forward(self, + hidden_states, + *args, + attention_mask, + **kwargs): + """ + Custom forward to make sure we only pass relevant arguments to the + forward pass of the `TransformerLayer`. Also, make sure the output + format matches the output of the HF's `LlamaDecoderLayer`. + """ + return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),) + + +class TELlamaForCausalLM: + """ + Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer` + class is monkey-patched with `TELlamaDecoderLayer` class before + initializing the causal LM with `LlamaForCausalLM`. + + Args: + config: LlamaConfig + """ + + def __new__(cls, config: LlamaConfig): + with replace_decoder(te_decoder_cls=TELlamaDecoderLayer): + llama_for_causal_lm = LlamaForCausalLM(config) + return llama_for_causal_lm + + @classmethod + def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + vanilla_model = cls(config).to(kwargs['torch_dtype']) + is_local = os.path.isdir(pretrained_model_name_or_path) + subfolder = "" + variant = None + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + else: + raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") + + + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + archive_file, + ) + + # If the checkpoint is not sharded, it's a trivial sharding case + if not is_sharded: + assert not isinstance(resolved_archive_file, list) + resolved_archive_file = [resolved_archive_file] + + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + # replace_params copies parameters relevant only to TransformerEngine + replace_params(state_dict, vanilla_model.state_dict(), config) + # _load_state_dict_into_model copies parameters other than those in TransformerEngine + _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + + # Force mem release. 
Taken from huggingface code + del state_dict + gc.collect() + + return vanilla_model + +def replace_params(hf_state_dict, te_state_dict, config): + # collect all layer prefixes to update + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = 'model.layers.\d+.' + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + + + + for layer_prefix in all_layer_prefixes: + # When loading weights into models with less number of layers, skip the + # copy if the corresponding layer doesn't exist in HF model + if layer_prefix + 'input_layernorm.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] + + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] + + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] + + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] + + if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] + + if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + + # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to + # load them separately. + if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \ + hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data + + if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \ + hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data + + if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] + return all_layer_prefixes diff --git a/3.test_cases/XX.transformer-engine/train.py b/3.test_cases/XX.transformer-engine/train.py new file mode 100644 index 00000000..39de0218 --- /dev/null +++ b/3.test_cases/XX.transformer-engine/train.py @@ -0,0 +1,272 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: MIT-0 + +import functools +import math +import time + +import torch +import torch.distributed as dist +import torch.utils.data +import transformer_engine.pytorch as te +from torch import optim +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformer_engine.common.recipe import Format, DelayedScaling +from transformers import AutoModelForCausalLM + +from model_utils.arguments import parse_args +from model_utils.checkpoint import save_checkpoint, load_checkpoint +from model_utils.train_utils import (get_model_config, + compute_num_params, + get_param_groups_by_weight_decay, + get_logger, + get_learning_rate_scheduler, + create_streaming_dataloader, apply_activation_checkpoint) +from te_llama import TELlamaForCausalLM + +logger = get_logger() + + +def eval_model(model, dataloader, num_batches): + """Eval step.""" + model = model.eval() + n_batches = 0 + loss = 0.0 + + with torch.no_grad(): + for batch_idx, input_data in enumerate(dataloader): + if batch_idx >= num_batches: + break + + loss += model(input_ids=input_data, attention_mask=None, labels=input_data)["loss"] + n_batches += 1 + + if n_batches > 0: + detached_loss = loss.detach() + torch.distributed.all_reduce(detached_loss) + loss = detached_loss.item() / dist.get_world_size() + loss /= n_batches + ppl = math.exp(loss) + else: + loss = -1.0 + ppl = -1.0 + + return loss, ppl + + +def train( + model, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + model_config, + num_params, + args, + global_rank, + world_size, + total_steps=0, + start_batch_index=0 +): + model.train() + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + all_gpus = dist.new_group(backend='nccl') + for index in range(args.epochs): + for batch_idx, input_data in enumerate(train_dataloader): + if batch_idx < start_batch_index: + continue + optimizer.zero_grad(set_to_none=True) + step_start = time.time() + with te.fp8_autocast(enabled=args.fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + loss = model(input_ids=input_data, attention_mask=None, labels=input_data)["loss"] + loss.backward() + model.clip_grad_norm_(args.grad_clip) + optimizer.step() + lr_scheduler.step() + total_steps += 1 + loss_metric = loss.item() + step_time = time.time() - step_start + sample_processed = input_data.shape[0] * world_size + throughput = sample_processed / step_time + loss_scalar = loss.item() + current_lr = lr_scheduler.get_lr() + if global_rank == 0 and batch_idx % args.logging_freq == 0: + logger.info( + "Batch %d Loss: %.5f, Speed: %.2f samples/sec, lr: %.6f", # pylint: disable=line-too-long + batch_idx, + loss_scalar, + throughput, + current_lr, + ) + if args.validation_freq and not total_steps % args.validation_freq: + val_loss, val_ppl = eval_model( + model, val_dataloader, args.validation_batches + ) + model = model.train() + if global_rank == 0: + logger.info( + "Batch %d Validation loss: %s", + batch_idx, + val_loss, + ) + if args.checkpoint_dir and not total_steps % args.checkpoint_freq: + user_content = { + "cli_args": args.__dict__, + "num_params": num_params, + "total_steps": total_steps, + "model_config": model_config, + "start_batch_index": batch_idx + 1, + } + sub_dir = f"{args.model_type}-{total_steps}steps" + + save_checkpoint( + model, + optimizer, 
+ lr_scheduler, + user_content, + args.checkpoint_dir, + sub_dir, + ) + if total_steps >= args.max_steps: + break + + +def main(args): + dist.init_process_group() + global_rank = dist.get_rank() + device = global_rank % torch.cuda.device_count() + world_size = dist.get_world_size() + torch.cuda.set_device(device) + + if args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.get_default_dtype() + + model_config = get_model_config(args) + if global_rank == 0: + logger.info( + f"Creating Model with {model_config=}" + ) + if args.fp8: + model = TELlamaForCausalLM(model_config).to(device) + else: + model = AutoModelForCausalLM.from_config(model_config).to(device) + + num_params = compute_num_params(model) + if global_rank == 0: + logger.info( + "Created model with total parameters: %d (%.2f B)", num_params, num_params * 1e-9 + ) + transformer_layer = type(model.model.layers[0]) + + gpt_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + transformer_layer, + }, + ) + + mixed_precision_policy = MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype + ) + + if args.sharding_strategy == "full": + sharding_strategy = ShardingStrategy.FULL_SHARD + elif args.sharding_strategy == "hybrid": + sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise NotImplementedError("Available sharding strategies are full and hybrid") + + model = FSDP( + model, + auto_wrap_policy=gpt_auto_wrap_policy, + mixed_precision=mixed_precision_policy, + limit_all_gathers=args.limit_all_gathers, + device_id=torch.cuda.current_device(), + use_orig_params=False, + sharding_strategy=sharding_strategy, + sync_module_states=True, + param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)) + if global_rank != 0 else None, + ) + + if global_rank == 0: + logger.info("Wrapped model with FSDP") + + # TODO(belevich) + # if args.activation_checkpointing > 0: + # apply_activation_checkpoint(args, model=model) + # + # if args.offload_activations > 0: + # from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper + # + # model = offload_wrapper(model) + + param_groups = get_param_groups_by_weight_decay(model) + + optimizer = optim.AdamW( + param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay + ) + + if global_rank == 0: + logger.info("Created optimizer") + + lr_scheduler = get_learning_rate_scheduler(optimizer, args) + + if args.resume_from_checkpoint: + ( + model, + optimizer, + lr_scheduler, + total_steps, + start_batch_index, + ) = load_checkpoint(model, + optimizer, + lr_scheduler, + args.resume_from_checkpoint, + args.model_type, + device) + else: + total_steps = 0 + start_batch_index = 0 + + train_dataloader = create_streaming_dataloader(args.dataset, + args.tokenizer, + args.dataset_config_name, + global_rank, + args.train_batch_size, + args.max_context_width, + workers=4, + split='train') + + val_dataloader = create_streaming_dataloader(args.dataset, + args.tokenizer, + args.dataset_config_name, + global_rank, + args.train_batch_size, + args.max_context_width, + workers=2, + split='validation') + + train(model, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + model_config, + num_params, + args, + global_rank, + world_size, + total_steps, + start_batch_index) + + +if __name__ == "__main__": + args, _ = parse_args() + main(args)