Commit
Merge pull request #3548 from jeffpicard/multi-gpu
Add multi gpu support
Showing 19 changed files with 314 additions and 25 deletions.
@@ -0,0 +1,32 @@
# Multi GPU

Training can be distributed across multiple GPUs on a local machine when using
[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer).

## Example

See the script `run_multi_gpu.py` and its comments.

## Tutorial

There are 2 changes that are always required, as well as a few things to consider.
Always Required:
1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()`
2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g.
   `launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU (a minimal sketch follows this list).
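As a rough sketch of step 2) alone (assumed to run on a machine with at least one CUDA GPU; it does no training, it only shows what `launch_distributed` does -- one process per GPU, with the rank 0 process's return value passed back):

```python
import torch

import flair
from flair.distributed_utils import launch_distributed


def main(message):
    # Each spawned process runs this function on its own GPU; by the time it is
    # called, flair.device has been set to that process's GPU.
    rank = torch.distributed.get_rank()
    print(f"{message}: rank {rank} drives {flair.device}")
    return rank


if __name__ == "__main__":
    result = launch_distributed(main, "hello")  # spawns one process per local GPU
    print(f"Rank 0 returned {result}")
```

The full training pattern, including `multi_gpu=True` from step 1), is shown in `run_multi_gpu.py` below.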
Other considerations:
- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves
  anything random, you should either
  - Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`) OR
  - Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to
    all processes
- The effective batch size will be larger by a factor of num_gpus
- Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps
  taken relative to training with a single device. To obtain comparable results between single- and multi-GPU training,
  both mathematically and in terms of wall time, consider the method in the example script (sketched just after this list).
- Large batch sizes may be necessary to see faster runs; otherwise the communication overhead may dominate
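For concreteness, this mirrors the batch-size bookkeeping in the example script (a sketch, assuming 32 examples fit in GPU memory for one forward pass):

```python
import torch

multi_gpu = True
mini_batch_chunk_size = 32  # examples per forward pass on one device (assumed to fit in memory)
num_gpus = max(torch.cuda.device_count(), 1)

# Single GPU: one device works through num_gpus chunks of 32, then the optimizer steps.
# Multi GPU: every device processes one chunk of 32 in parallel, then the optimizer steps.
mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_gpus

# Either way the optimizer steps after mini_batch_chunk_size * num_gpus examples, so the
# effective batch size and the number of updates per epoch stay comparable.
print(f"mini_batch_size={mini_batch_size}, examples per step={mini_batch_chunk_size * num_gpus}")
```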
Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction
are still done on a single device.
run_multi_gpu.py
@@ -0,0 +1,54 @@
import torch

import flair
from flair.datasets import IMDB
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def main(multi_gpu):
    # Note: Multi-GPU can affect corpus loading
    # This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
    # ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
    # the same seed on all processes.
    flair.set_seed(42)

    corpus = IMDB()
    corpus.downsample(0.1)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)

    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)

    # Note: Multi-GPU can affect choice of batch size.
    # In order to compare batch updates fairly between single and multi-GPU training, we should:
    # 1) Step the optimizer after the same number of examples to achieve comparable updates
    # 2) Process the same number of examples in each forward pass
    mini_batch_chunk_size = 32  # Make this as large as possible without running out of GPU memory, to pack the device
    num_devices_when_distributing = max(torch.cuda.device_count(), 1)
    mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing
    # e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the
    # first gpu will process another 32 examples, then the optimizer will step. If multi_gpu=True, each gpu will
    # process 32 examples at the same time, then the optimizer will step.

    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune(
        "resources/taggers/multi-gpu",
        multi_gpu=multi_gpu,  # Required for multi-gpu
        max_epochs=2,
        mini_batch_chunk_size=mini_batch_chunk_size,
        mini_batch_size=mini_batch_size,
    )


if __name__ == "__main__":
    """Minimal example demonstrating how to train a model on multiple GPUs."""
    multi_gpu = True

    if multi_gpu:
        launch_distributed(main, multi_gpu)  # Required for multi-gpu
    else:
        main(multi_gpu)
flair/distributed_utils.py
@@ -0,0 +1,88 @@
import logging
import os
import random
from multiprocessing.connection import Connection
from typing import Callable

import numpy as np
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group
from torch.utils.data import Dataset

import flair
from flair.data import Corpus, _len_dataset

log = logging.getLogger("flair")


def launch_distributed(fn, *args, **kwargs):
    """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).

    If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune.

    Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process
    """
    world_size = torch.cuda.device_count()
    log.info(f"Launching {world_size} processes")
    parent_conn, child_conn = mp.Pipe()
    mp.spawn(_process_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size)
    return_value = parent_conn.recv()
    return return_value


def _process_entrypoint(
    rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict
) -> None:
    """Lifecycle of a distributed process -- setup, run, cleanup."""
    log.info(f"Started process on rank={rank}")
    try:
        _ddp_setup(rank, world_size)
        return_value = fn(*args, **kwargs)
        if is_main_process():
            child_conn.send(return_value)
    finally:
        destroy_process_group()


def _ddp_setup(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    flair.device = torch.device(rank)
    torch.cuda.set_device(flair.device)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)


def is_main_process() -> bool:
    """True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-GPU."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    else:
        return True


def aggregate(value, aggregation_fn=np.mean):
    """Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
    if torch.distributed.is_initialized():
        gathered_values = [None for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather_object(gathered_values, value)
    else:
        gathered_values = [value]
    return aggregation_fn(gathered_values)


def validate_corpus_same_each_process(corpus: Corpus) -> None:
    """Catches most cases in which a corpus is not the same on each process.

    However, there is no guarantee, for two reasons: 1) it uses a sample for speed, and 2) it compares strings to
    avoid requiring the datasets to be serializable.
    """
    for dataset in [corpus.train, corpus.dev, corpus.test]:
        if dataset is not None:
            _validate_dataset_same_each_process(dataset)


def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
    random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
    for i in random_indices:
        example = str(dataset[i])
        examples = aggregate(example, list)
        if not all(example == examples[0] for example in examples):
            raise ValueError("Dataset must be the same on each process")
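As a small illustration of `aggregate` and `is_main_process` (a hypothetical snippet, assumed to run inside a function passed to `launch_distributed`; on a single, non-distributed process `aggregate` simply falls back to the local value):

```python
from flair.distributed_utils import aggregate, is_main_process

# Hypothetical per-process metric, e.g. a loss computed on this process's shard of the data.
local_loss = 0.37

# Gathers the value from every rank and applies np.mean (the default aggregation_fn).
mean_loss = aggregate(local_loss)

if is_main_process():  # log only once, from the rank 0 process
    print(f"mean loss across processes: {mean_loss:.4f}")
```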