| Documentation | Poster | Design Doc |
⚠️ WARNING: This is an alpha prototype for PyTorch fault tolerance and may have bugs or breaking changes, as it is actively under development. We'd love to collaborate, and contributions are welcome. Please reach out if you're interested in torchft or want to discuss fault tolerance in PyTorch.
This repository implements techniques for per-step fault tolerance, so you can keep training when errors occur without interrupting the entire training job.
This is based on the large-scale training techniques presented at PyTorch Conference 2024.
torchft is designed to provide the primitives required to implement fault tolerance in any application or training script, as well as the primitives needed to implement custom fault tolerance strategies.
Out of the box, torchft provides the following algorithms:
- Fault Tolerant DDP
- Fault Tolerant HSDP: fault tolerance across the replicated dimension with any mix of FSDP, TP, etc. across the other dimensions.
- LocalSGD
- DiLoCo
To implement these, torchft provides some key reusable components:
- Coordination primitives that can determine which workers are healthy via heartbeating on a per-step basis
- Fault tolerant ProcessGroup implementations that report errors sanely and can be reinitialized gracefully.
- Checkpoint transports that can be used to do live recovery from a healthy peer during scale-up operations.
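Per-step heartbeating, the first component above, can be illustrated with a minimal, dependency-free sketch: each worker periodically reports in, and a worker counts as healthy only if its last heartbeat falls within a timeout window. HeartbeatTracker and its timeout_s parameter are hypothetical names chosen for illustration, not part of torchft's API.

```python
import time
from typing import Dict, List, Optional


class HeartbeatTracker:
    """Hypothetical sketch (not torchft's API): tracks the last heartbeat per
    worker and reports which workers are still considered healthy."""

    def __init__(self, timeout_s: float) -> None:
        self.timeout_s = timeout_s
        self._last_seen: Dict[str, float] = {}

    def heartbeat(self, worker_id: str, now: Optional[float] = None) -> None:
        # Record the time of the most recent heartbeat from this worker.
        self._last_seen[worker_id] = time.monotonic() if now is None else now

    def healthy(self, now: Optional[float] = None) -> List[str]:
        # A worker is healthy if it heartbeated within the timeout window.
        now = time.monotonic() if now is None else now
        return sorted(
            w for w, t in self._last_seen.items() if now - t <= self.timeout_s
        )


tracker = HeartbeatTracker(timeout_s=5.0)
tracker.heartbeat("replica_0", now=100.0)
tracker.heartbeat("replica_1", now=103.0)
print(tracker.healthy(now=106.0))  # ['replica_1']
```

Passing explicit timestamps keeps the example deterministic; in practice the tracker would use the monotonic clock.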
The following component diagram shows the high-level components and how they relate to each other:
See torchft's documentation for more details.
torchtitan provides an out-of-the-box fault tolerant HSDP training loop built on top of torchft that can be used to train models such as Llama 3 70B.
It also serves as a good example of how you can integrate torchft into your own training script for use with HSDP.
See torchtitan's documentation for end to end usage.
We have a minimal DDP train loop that highlights all of the key components in torchft.
See train_ddp.py for more info.
LocalSGD and DiLoCo are currently experimental.
See the diloco_train_loop/local_sgd_train_loop tests for examples of how to integrate these algorithms into your training loop.
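The synchronization step that distinguishes the two algorithms can be sketched on plain Python lists standing in for parameter tensors. This is a minimal illustration, not torchft's LocalSGD or DiLoCo classes: local_sgd_sync, DiLoCoOuter, and their parameters are hypothetical names. LocalSGD simply averages the replicas' weights after a number of local steps; DiLoCo instead treats the difference between the last global weights and the averaged local weights as a pseudo-gradient and applies it with an outer optimizer (SGD with Nesterov momentum here).

```python
from typing import List

Vector = List[float]


def average(replica_params: List[Vector]) -> Vector:
    # Element-wise mean across replicas, i.e. what an allreduce-average computes.
    n = len(replica_params)
    return [sum(vals) / n for vals in zip(*replica_params)]


def local_sgd_sync(replica_params: List[Vector]) -> Vector:
    # LocalSGD: after several purely local steps, every replica adopts the
    # averaged weights.
    return average(replica_params)


class DiLoCoOuter:
    """Hypothetical DiLoCo-style outer step: (global - averaged local) is
    treated as a pseudo-gradient and applied with SGD + Nesterov momentum."""

    def __init__(self, global_params: Vector, lr: float, momentum: float) -> None:
        self.global_params = list(global_params)
        self.lr = lr
        self.momentum = momentum
        self.buf = [0.0] * len(global_params)

    def sync(self, replica_params: List[Vector]) -> Vector:
        avg = average(replica_params)
        new_params = []
        for i, (g, a) in enumerate(zip(self.global_params, avg)):
            pseudo_grad = g - a
            self.buf[i] = self.momentum * self.buf[i] + pseudo_grad
            update = pseudo_grad + self.momentum * self.buf[i]  # Nesterov look-ahead
            new_params.append(g - self.lr * update)
        self.global_params = new_params
        return list(new_params)


replicas = [[0.5, 1.0], [0.75, 0.5]]  # weights after local steps, started from [1.0, 1.0]
print(local_sgd_sync(replicas))  # [0.625, 0.75]

outer = DiLoCoOuter([1.0, 1.0], lr=0.5, momentum=0.5)
print(outer.sync(replicas))  # [0.71875, 0.8125]
```

After either sync, every replica would continue local training from the returned weights. With lr=1.0 and momentum=0.0 the outer step reduces to plain LocalSGD averaging, which is a handy sanity check.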
torchft is designed to allow for fault tolerance when training with replicated weights, such as in DDP or HSDP (FSDP with DDP).
See the design doc for the most detailed explanation.
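Because every replica holds the same weights, a step can still proceed when one replica group drops out, as long as the gradient average covers only the replicas that participated. The sketch below illustrates that idea with plain lists; fault_tolerant_average is a hypothetical helper, and dividing by the live participant count rather than a fixed world size is this sketch's assumption about how averaging should adapt to a changing replica count.

```python
from typing import Dict, List, Optional

Gradient = List[float]


def fault_tolerant_average(grads: Dict[str, Optional[Gradient]]) -> Gradient:
    """Hypothetical helper: average gradients over the replicas that
    participated in this step.

    A value of None marks a replica group that failed or dropped out; it is
    excluded instead of stalling the whole job.
    """
    participants = [g for g in grads.values() if g is not None]
    if not participants:
        raise RuntimeError("no healthy replicas participated in this step")
    n = len(participants)
    # Divide by the live participant count, not the original world size, so
    # the result is still a true mean when the replica count changes.
    return [sum(vals) / n for vals in zip(*participants)]


step_grads = {
    "replica_0": [1.0, 2.0],
    "replica_1": None,  # failed mid-step
    "replica_2": [3.0, 4.0],
}
print(fault_tolerant_average(step_grads))  # [2.0, 3.0]
```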
torchft implements a lighthouse server that coordinates across the different replica groups, plus a per-replica-group manager and fault tolerance library that can be used in a standard PyTorch training loop.
This allows for membership changes at the granularity of a single training step, which can greatly improve efficiency by avoiding stop-the-world pauses of training when errors occur.
torchft provides an implementation of a fault tolerant HSDP/DDP algorithm. The following diagram shows the high-level operations that need to happen in the train loop to ensure everything stays consistent during a healing operation.
See the design doc linked above for more details.
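The healing step can be approximated with a small planning function: once a quorum forms, replicas whose step count is behind the quorum's maximum fetch state from a peer that is up to date, while up-to-date replicas continue as normal. The Member type, plan_recovery, and the round-robin choice of source peer are hypothetical illustrations; see the design doc for torchft's actual recovery protocol.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Member:
    replica_id: str
    step: int  # last training step this replica completed


def plan_recovery(members: List[Member]) -> Dict[str, str]:
    """Hypothetical sketch: map each lagging replica to a healthy peer that
    holds the latest step."""
    max_step = max(m.step for m in members)
    up_to_date = sorted(m.replica_id for m in members if m.step == max_step)
    lagging = sorted(m.replica_id for m in members if m.step < max_step)
    # Spread recovering replicas across up-to-date peers round-robin so that
    # no single peer serves every checkpoint transfer.
    return {r: up_to_date[i % len(up_to_date)] for i, r in enumerate(lagging)}


quorum = [Member("a", 10), Member("b", 10), Member("c", 0), Member("d", 7)]
print(plan_recovery(quorum))  # {'c': 'a', 'd': 'b'}
```

A replica that just joined (step 0) and one that fell a few steps behind are handled the same way: both load the latest state before taking part in the next step.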
Before proceeding, ensure you have the following installed:
- Rust (with necessary dependencies)
- protobuf-compiler and the corresponding development package for Protobuf.
Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend using the official rustup installer:
curl --proto '=https' --tlsv1.2 https://sh.rustup.rs -sSf | sh
To install the required packages on a Debian-based system (such as Ubuntu) using apt, run:
sudo apt install protobuf-compiler libprotobuf-dev
or for a Red Hat-based system, run:
sudo dnf install protobuf-compiler protobuf-devel
pip install .
This uses pyo3 and maturin to build the package, so you'll need maturin installed.
If the installation command fails to invoke cargo update due to an inability to fetch the manifest, the cause may be the proxy, proxySSLCert, and proxySSLKey settings in your .gitconfig file, which affect the cargo command. To resolve this issue, try temporarily removing these fields from your .gitconfig before running the installation command.
To install in editable mode with the Rust extensions and development dependencies, you can use the normal pip install command:
pip install -e '.[dev]'
The lighthouse is used for fault tolerance across replicated workers (DDP/FSDP) when using synchronous training.
You can start a lighthouse server by running:
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
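The flags above suggest how the lighthouse decides when a quorum is ready. The sketch below is only our reading of the flag names, not a specification of the Rust implementation: a quorum needs at least min_replicas members, is re-evaluated on each periodic tick (cf. --quorum_tick_ms), forms immediately once every member of the previous quorum has rejoined, and otherwise waits up to join_timeout_ms for stragglers. QuorumState and its methods are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class QuorumState:
    """Hypothetical model of quorum formation (not the lighthouse's code)."""

    min_replicas: int
    join_timeout_ms: int
    previous_members: Set[str] = field(default_factory=set)
    joined: Set[str] = field(default_factory=set)
    first_join_ms: Optional[int] = None

    def join(self, replica_id: str, now_ms: int) -> None:
        # Remember when the first replica asked for this quorum.
        if self.first_join_ms is None:
            self.first_join_ms = now_ms
        self.joined.add(replica_id)

    def tick(self, now_ms: int) -> Optional[List[str]]:
        """Called periodically; returns the members if a quorum forms."""
        if len(self.joined) < self.min_replicas:
            return None
        all_back = self.previous_members <= self.joined
        timed_out = now_ms - self.first_join_ms >= self.join_timeout_ms
        if all_back or timed_out:
            members = sorted(self.joined)
            # Start collecting joins for the next quorum.
            self.previous_members = set(members)
            self.joined = set()
            self.first_join_ms = None
            return members
        return None


q = QuorumState(min_replicas=2, join_timeout_ms=10000, previous_members={"a", "b", "c"})
q.join("a", now_ms=0)
q.join("b", now_ms=50)
print(q.tick(now_ms=100))    # None: "c" may still rejoin
print(q.tick(now_ms=10000))  # ['a', 'b']
```

The trade-off this models: a short join timeout shrinks the quorum quickly after a failure, while a longer one gives a restarting replica time to rejoin before the group moves on without it.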
See train_ddp.py for the full example.
Invoke with:
TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train_ddp.py
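train.py:
import torch
from torch import nn, optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

# ProcessGroupGloo is used below, so this example runs on CPU.
device = torch.device("cpu")

manager = Manager(
    pg=ProcessGroupGloo(),
    # Callbacks (supplied by you) that load and return the model/optimizer
    # state; they are used for live recovery from a healthy peer.
    load_state_dict=...,
    state_dict=...,
)

# Wrap the model and optimizer with torchft's fault tolerant versions.
m = nn.Linear(2, 3).to(device)
m = DistributedDataParallel(manager, m)
optimizer = Optimizer(manager, optim.AdamW(m.parameters()))

for i in range(1000):
    batch = torch.rand(2, 2, device=device)

    optimizer.zero_grad()

    out = m(batch)
    loss = out.sum()

    loss.backward()

    optimizer.step()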
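The loop above calls the wrapped optimizer rather than AdamW directly. Our understanding is that this lets a step's update be skipped when the step did not succeed, so that replicas never apply a partially synchronized update. The dependency-free sketch below illustrates that idea only; FakeManager and CommitGatedOptimizer are hypothetical stand-ins, not torchft.Manager or torchft.Optimizer, and a real implementation would need all ranks to agree before committing, which this single-process sketch skips.

```python
from typing import Callable, List


class FakeManager:
    """Hypothetical stand-in for a fault tolerance manager (not torchft.Manager)."""

    def __init__(self) -> None:
        self.errors: List[Exception] = []

    def start_quorum(self) -> None:
        # A real manager would (re)join the quorum here; the sketch only
        # clears errors recorded during the previous step.
        self.errors = []

    def report_error(self, e: Exception) -> None:
        self.errors.append(e)

    def should_commit(self) -> bool:
        # Commit only if no error was observed during this step.
        return not self.errors


class CommitGatedOptimizer:
    """Hypothetical wrapper: applies the update only when the step may commit."""

    def __init__(self, manager: FakeManager, apply_update: Callable[[], None]) -> None:
        self.manager = manager
        self.apply_update = apply_update

    def zero_grad(self) -> None:
        # The start of each step doubles as the point where membership is settled.
        self.manager.start_quorum()

    def step(self) -> bool:
        if self.manager.should_commit():
            self.apply_update()
            return True
        return False  # discard this step's update; weights stay consistent


updates: List[str] = []
manager = FakeManager()
opt = CommitGatedOptimizer(manager, apply_update=lambda: updates.append("applied"))

opt.zero_grad()
print(opt.step())  # True: no errors, update applied

opt.zero_grad()
manager.report_error(RuntimeError("allreduce failed"))
print(opt.step())  # False: step discarded
print(updates)     # ['applied']
```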
torchft has a fault tolerant parameter server implementation built on its reconfigurable ProcessGroups. This does not require or use a Lighthouse server.
See parameter_server_test.py for an example.
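The benefit of reconfigurable process groups for a parameter server can be illustrated with per-client sessions: if each client communicates with the server through its own session, a failing client can be torn down without disturbing the others. The class below is a hypothetical sketch of that isolation idea and does not reproduce torchft's parameter server API; treating each session as a stand-in for a per-client process group is our assumption.

```python
from typing import Dict, List

Vector = List[float]


class ParameterServerSketch:
    """Hypothetical sketch: each client talks to the server through its own
    session, so one client failing does not disturb the others."""

    def __init__(self, params: Vector) -> None:
        self.params = list(params)
        self._live: Dict[int, bool] = {}
        self._next_id = 0

    def new_session(self) -> int:
        # Stand-in for setting up a fresh, per-client communication group.
        session_id = self._next_id
        self._next_id += 1
        self._live[session_id] = True
        return session_id

    def close_on_error(self, session_id: int) -> None:
        # Tear down only the failed client's session; others keep running.
        self._live[session_id] = False

    def pull(self, session_id: int) -> Vector:
        self._check(session_id)
        return list(self.params)

    def push(self, session_id: int, grad: Vector, lr: float) -> None:
        # Apply a plain SGD update from this client's gradient.
        self._check(session_id)
        self.params = [p - lr * g for p, g in zip(self.params, grad)]

    def _check(self, session_id: int) -> None:
        if not self._live.get(session_id, False):
            raise RuntimeError(f"session {session_id} is not live")


server = ParameterServerSketch([1.0, 1.0])
s0, s1 = server.new_session(), server.new_session()
server.push(s0, [1.0, 2.0], lr=0.25)  # params -> [0.75, 0.5]
server.close_on_error(s0)             # client 0 fails
server.push(s1, [1.0, 0.0], lr=0.25)  # client 1 is unaffected
print(server.pull(s1))  # [0.5, 0.5]
```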
We welcome PRs! See the CONTRIBUTING file.
torchft is BSD 3-Clause licensed. See LICENSE for more details.