
torchtitan

torchtitan is currently in a pre-release state and under extensive development.

torchtitan is a proof of concept for large-scale LLM training using native PyTorch. It is (and will continue to be) a repo to showcase PyTorch's latest distributed training features in a clean, minimal codebase. torchtitan is complementary to, and not a replacement for, any of the great large-scale LLM training codebases such as Megatron, MegaBlocks, LLM Foundry, DeepSpeed, etc. Instead, we hope that the features showcased in torchtitan will be adopted by these codebases quickly. torchtitan is unlikely to ever grow a large community around it.

Our guiding principles when building torchtitan:

  • Designed to be easy to understand, use and extend for different training purposes.
  • Minimal changes to the model code when applying 1D, 2D, or (soon) 3D Parallel.
  • Modular components instead of a monolithic codebase.
  • Get started in minutes, not hours!

Intro video - learn more about torchtitan in under 4 mins:

Welcome to torchtitan!

Dive into the code

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, start with the model definition and the code that applies the parallelisms to it.

Pre-Release Updates:

(4/25/2024): torchtitan is now public but in a pre-release state and under development.

Currently, we showcase pre-training Llama 3 and Llama 2 LLMs of various sizes from scratch. torchtitan is tested and verified with the PyTorch nightly version torch-2.4.0.dev20240412 (we recommend using the latest PyTorch nightly).

Key features available

  1. FSDP2 with per-param sharding
  2. Tensor Parallel
  3. Selective layer and operator activation checkpointing
  4. Distributed checkpointing
  5. 2 datasets pre-configured (45K - 144M)
  6. GPU usage, MFU, tokens per second, and more displayed via Aim (see the note after this list)
  7. Learning rate scheduler, meta init, optional fused RMSNorm
  8. All options easily configured via toml files
  9. Interoperable checkpoints which can be loaded directly into torchtune for fine-tuning
  10. Float8 support
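
For the Aim-based metrics in item 6, the Aim web UI can be started with the standard Aim CLI. This is a hedged sketch: aim up is Aim's own command for launching the UI, but where this repo writes its Aim run data is configured in the toml files, so check your config for the exact directory.

# Launch the Aim UI to browse logged metrics
# (run from the directory that contains the .aim run repository)
aim up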

We report performance verified on 64 A100 GPUs.

Coming soon

  1. Async checkpointing
  2. Context Parallel
  3. 3D Pipeline Parallel
  4. torch.compile support
  5. Scalable data loading solution

Installation

git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
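
To confirm that the nightly build was actually installed (the tested version noted above starts with 2.4.0.dev), you can print the installed PyTorch version:

# Should print a nightly version string, e.g. 2.4.0.devYYYYMMDD+cu121
python -c "import torch; print(torch.__version__)"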

Downloading a tokenizer

torchtitan currently supports training Llama 3 (8B, 70B), and Llama 2 (7B, 13B, 70B) out of the box. To get started training these models, we need to download a tokenizer.model. Follow the instructions on the official meta-llama repository to ensure you have access to the Llama model weights.

Once you have confirmed access, you can run the following command to download the Llama 3 / Llama 2 tokenizer to your local machine.

# Get your HF token from https://huggingface.co/settings/tokens

# chemlactica-125m
python torchtitan/tokenizers/download_tokenizer.py --repo_id yerevann/chemlactica-125m
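
If the tokenizer repository is gated, authenticate with Hugging Face before running the script. Assuming the download script uses huggingface_hub under the hood (an assumption; check torchtitan/tokenizers/download_tokenizer.py), either of the following makes your token available to it:

# Option 1: interactive login (stores the token locally)
huggingface-cli login

# Option 2: export the token for the current shell session
export HF_TOKEN=<your token from https://huggingface.co/settings/tokens>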

Start a training run

Llama 3 8B model locally on 8 GPUs

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
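
If you have fewer than 8 GPUs, the launcher script can usually be pointed at a different GPU count via an environment variable. The variable name below is an assumption based on upstream torchtitan; check run_llama_train.sh in your checkout for the exact name.

# e.g. train on 4 GPUs instead of the default 8
# (NGPU is the variable read by the upstream launcher script)
NGPU=4 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh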

Multi-Node Training

For training on ParallelCluster/Slurm type configurations, you can use the multinode_trainer.slurm file to submit your sbatch job.

To get started, adjust the number of nodes and GPUs:

#SBATCH --ntasks=2
#SBATCH --nodes=2

Then start a run where nnodes is your total node count, matching the sbatch node count above.

srun torchrun --nnodes 2

If your GPU count per node is not 8, adjust:

--nproc_per_node

in the torchrun command and

#SBATCH --gpus-per-task

in the SBATCH command section.
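
Putting these pieces together, a 2-node job with 4 GPUs per node would adjust the following lines (an illustrative excerpt; keep the rest of multinode_trainer.slurm as shipped):

# SBATCH header: 2 nodes, one task per node, 4 GPUs per task
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --gpus-per-task=4

# torchrun launch line: node count and per-node GPU count must match the header
srun torchrun --nnodes 2 --nproc_per_node 4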

License

This code is made available under the BSD 3-Clause license. However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models, data, etc.