Merge pull request #3548 from jeffpicard/multi-gpu

Add multi gpu support

HallerPatrick authored Nov 22, 2024
2 parents fb27c7e + fb4f07a commit 9a962cb
Showing 19 changed files with 314 additions and 25 deletions.
5 changes: 3 additions & 2 deletions examples/README.md
@@ -4,6 +4,7 @@ This folder contains actively maintained examples of use of Flair, organized alo

## Table of Tasks

| Task | Documentation
| ----------------------------- | -------------
| Task | Documentation
| ------------------------------ | -------------
| Named Entity Recognition (NER) | [Here](ner/)
| Multi GPU | [Here](multi_gpu/)
32 changes: 32 additions & 0 deletions examples/multi_gpu/README.md
@@ -0,0 +1,32 @@
# Multi GPU

Training can be distributed across multiple GPUs on a local machine when using
[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer).

## Example

See the script `run_multi_gpu.py` and its comments.

## Tutorial

There are two changes that are always required, as well as a few things to consider.

Always Required:
1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()`
2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g.
`launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU (see the sketch below).
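
For orientation, here is a minimal sketch of what these two changes can look like together. `build_corpus_and_model` is a hypothetical placeholder for your own corpus/model setup code; only `multi_gpu=True` and `launch_distributed` come from this PR.

```python
from flair.distributed_utils import launch_distributed
from flair.trainers import ModelTrainer


def main():
    corpus, model = build_corpus_and_model()  # hypothetical helper: build your corpus and model here
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune(
        "resources/taggers/example",
        multi_gpu=True,  # required change 1
    )


if __name__ == "__main__":
    launch_distributed(main)  # required change 2: spawns one process per local GPU
```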

Other considerations:
- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves
anything random, you should either
- Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`), OR
- Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to
all processes
- The effective batch size will be larger by a factor of the number of GPUs.
- Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps
taken relative to training with a single device. To obtain comparable results between single- and multi-GPU training,
both mathematically and in terms of wall time, consider the method in the example script.
- Large batch sizes may be necessary to see faster runs; otherwise, the communication overhead may dominate.

Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction
are still done on a single device.
Empty file added examples/multi_gpu/__init__.py
Empty file.
54 changes: 54 additions & 0 deletions examples/multi_gpu/run_multi_gpu.py
@@ -0,0 +1,54 @@
import torch

import flair
from flair.datasets import IMDB
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def main(multi_gpu):
# Note: Multi-GPU can affect corpus loading
# This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
# ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
# the same seed on all processes.
flair.set_seed(42)

corpus = IMDB()
corpus.downsample(0.1)
label_type = "sentiment"
label_dictionary = corpus.make_label_dictionary(label_type)

embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)

# Note: Multi-GPU can affect choice of batch size.
# In order to compare batch updates fairly between single and multi-GPU training, we should:
# 1) Step the optimizer after the same number of examples to achieve comparable updates
# 2) Process the same number of examples in each forward pass
mini_batch_chunk_size = 32  # Make this as large as possible without running out of GPU memory, to keep each device fully utilized
num_devices_when_distributing = max(torch.cuda.device_count(), 1)
mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing
# e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the
# first gpu will process another 32 examples, then the optimizer will step. If multi_gpu=True, each gpu will
# process 32 examples at the same time, then the optimizer will step.

trainer = ModelTrainer(model, corpus)
trainer.fine_tune(
"resources/taggers/multi-gpu",
multi_gpu=multi_gpu, # Required for multi-gpu
max_epochs=2,
mini_batch_chunk_size=mini_batch_chunk_size,
mini_batch_size=mini_batch_size,
)


if __name__ == "__main__":
"""Minimal example demonstrating how to train a model on multiple GPUs."""
multi_gpu = True

if multi_gpu:
launch_distributed(main, multi_gpu) # Required for multi-gpu
else:
main(multi_gpu)
88 changes: 88 additions & 0 deletions flair/distributed_utils.py
@@ -0,0 +1,88 @@
import logging
import os
import random
from multiprocessing.connection import Connection
from typing import Callable

import numpy as np
import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group
from torch.utils.data import Dataset

import flair
from flair.data import Corpus, _len_dataset

log = logging.getLogger("flair")


def launch_distributed(fn, *args, **kwargs):
"""Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).
If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune.
Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process
"""
world_size = torch.cuda.device_count()
log.info(f"Launching {world_size} processes")
parent_conn, child_conn = mp.Pipe()
mp.spawn(_process_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size)
return_value = parent_conn.recv()
return return_value


def _process_entrypoint(
rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict
) -> None:
"""Lifecycle of a distributed process -- setup, run, cleanup."""
log.info(f"Started process on rank={rank}")
try:
_ddp_setup(rank, world_size)
return_value = fn(*args, **kwargs)
if is_main_process():
child_conn.send(return_value)
finally:
destroy_process_group()


def _ddp_setup(rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
flair.device = torch.device(rank)
torch.cuda.set_device(flair.device)
init_process_group(backend="nccl", rank=rank, world_size=world_size)


def is_main_process() -> bool:
"""True for exactly 1 process, regardless of whether being run on CPU/single-GPU/multi-gpu."""
if torch.distributed.is_initialized():
return torch.distributed.get_rank() == 0
else:
return True


def aggregate(value, aggregation_fn=np.mean):
"""Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
if torch.distributed.is_initialized():
gathered_values = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_values, value)
else:
gathered_values = [value]
return aggregation_fn(gathered_values)
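
As a brief illustration (values are made up; `local_loss` stands in for a number computed on the current process), `aggregate` might be used like this:

```python
import numpy as np

from flair.distributed_utils import aggregate

local_loss = 0.73                           # per-process value (illustrative)
mean_loss = aggregate(local_loss, np.mean)  # average across all ranks
total_count = aggregate(128, np.sum)        # or sum per-process counts instead
# When torch.distributed is not initialized, each call just aggregates the single local value.
```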


def validate_corpus_same_each_process(corpus: Corpus) -> None:
"""Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable"""
for dataset in [corpus.train, corpus.dev, corpus.test]:
if dataset is not None:
_validate_dataset_same_each_process(dataset)


def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
for i in random_indices:
example = str(dataset[i])
examples = aggregate(example, list)
if not all(example == examples[0] for example in examples):
raise ValueError("Dataset must be the same on each process")
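
As a sketch of how this check might be invoked by hand (an assumption, not something the PR requires), it could be called inside the function passed to `launch_distributed`, after the corpus is built and before training starts:

```python
import flair
from flair.datasets import IMDB
from flair.distributed_utils import validate_corpus_same_each_process

flair.set_seed(42)  # keep corpus construction identical on every process
corpus = IMDB()
corpus.downsample(0.1)
validate_corpus_same_each_process(corpus)  # raises ValueError if any process sees different data
```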
2 changes: 1 addition & 1 deletion flair/embeddings/transformer.py
@@ -1356,7 +1356,7 @@ def to_params(self):

# do not switch the attention implementation upon reload.
config_dict["attn_implementation"] = self.model.config._attn_implementation
del config_dict["_attn_implementation_autoset"]
config_dict.pop("_attn_implementation_autoset", None)

super_params = super().to_params()

3 changes: 2 additions & 1 deletion flair/nn/model.py
@@ -17,6 +17,7 @@
from flair.class_utils import get_non_abstract_subclasses
from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.distributed_utils import is_main_process
from flair.embeddings import Embeddings
from flair.embeddings.base import load_embeddings
from flair.file_utils import Tqdm, load_torch_state
@@ -291,7 +292,7 @@ def evaluate(
loader = DataLoader(data_points, batch_size=mini_batch_size)

sentence_id = 0
for batch in Tqdm.tqdm(loader):
for batch in Tqdm.tqdm(loader, disable=not is_main_process()):
# remove any previously predicted labels
for datapoint in batch:
datapoint.remove_labels("predicted")
9 changes: 9 additions & 0 deletions flair/trainers/plugins/base.py
@@ -13,6 +13,8 @@
cast,
)

from flair.distributed_utils import is_main_process

log = logging.getLogger("flair")


@@ -184,6 +186,8 @@ def attach_to(self, pluggable: Pluggable):
assert self._pluggable is None
assert len(self._hook_handles) == 0

if not is_main_process() and not self.attach_to_all_processes:
return
self._pluggable = pluggable

pluggable.append_plugin(self)
@@ -252,6 +256,11 @@ def decorator_func(func: Callable):
def pluggable(self) -> Optional[Pluggable]:
return self._pluggable

@property
def attach_to_all_processes(self) -> bool:
"""If set, the plugin will be attached to all processes when distributed, not just the main process."""
return True

def __str__(self) -> str:
return self.__class__.__name__

8 changes: 8 additions & 0 deletions flair/trainers/plugins/functional/checkpoints.py
@@ -1,6 +1,8 @@
import logging
from typing import Any

import torch

from flair.trainers.plugins.base import TrainerPlugin

log = logging.getLogger("flair")
@@ -28,6 +30,12 @@ def after_training_epoch(self, epoch, **kw):
)
model_name = "model_epoch_" + str(epoch) + ".pt"
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
if torch.distributed.is_initialized():
torch.distributed.barrier() # Prevent any process from loading a model until writing is complete

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {
5 changes: 4 additions & 1 deletion flair/trainers/plugins/functional/linear_scheduler.py
@@ -1,6 +1,8 @@
import logging
from typing import Any

import torch.distributed

from flair.optim import LinearSchedulerWithWarmup
from flair.trainers.plugins.base import TrainerPlugin

@@ -34,7 +36,8 @@ def after_setup(
):
"""Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers."""
# calculate warmup steps
steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size
num_processes = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size / num_processes
num_train_steps = int(steps_per_epoch * max_epochs)
num_warmup_steps = int(num_train_steps * self.warmup_fraction)

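
A small worked example of this warmup calculation with illustrative numbers (6,400 training examples, `mini_batch_size=32`, 2 GPUs, 10 epochs, a warmup fraction of 0.1):

```python
dataset_size, mini_batch_size, max_epochs, warmup_fraction = 6400, 32, 10, 0.1
num_processes = 2  # torch.distributed.get_world_size() when distributed

steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size / num_processes  # ~100.5 per process
num_train_steps = int(steps_per_epoch * max_epochs)        # 1004 optimizer steps in total
num_warmup_steps = int(num_train_steps * warmup_fraction)  # 100 warmup steps

# With num_processes = 1, the same data would give roughly 2009 steps, so without the
# division the scheduler would plan its warmup and decay over about twice as many
# steps as the distributed run actually performs.
```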
4 changes: 4 additions & 0 deletions flair/trainers/plugins/functional/reduce_transformer_vocab.py
@@ -55,6 +55,10 @@ def save_model_at_the_end(self, **kw):
elif (self.base_path / "final-model.pt").exists():
self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state)

@property
def attach_to_all_processes(self) -> bool:
return False


def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]:
embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None)
4 changes: 4 additions & 0 deletions flair/trainers/plugins/functional/weight_extractor.py
@@ -21,6 +21,10 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw):
if (iteration + 1) % modulo == 0:
self.weight_extractor.extract_weights(self.model.state_dict(), iteration)

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/clearml_logger.py
@@ -40,3 +40,7 @@ def metric_recorded(self, record: MetricRecord) -> None:
self.logger.report_text(record.value, print_console=False)
elif record.is_histogram:
self.logger.report_histogram(record_name, record_name, record.value, record.global_step)

@property
def attach_to_all_processes(self) -> bool:
return False
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/log_file.py
@@ -21,5 +21,9 @@ def close_file_handler(self, **kw):
self.log_handler.close()
log.removeHandler(self.log_handler)

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {**super().get_state(), "base_path": str(self.base_path)}
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/loss_file.py
@@ -113,3 +113,7 @@ def after_evaluation(self, epoch, **kw):
f.write("\t".join([str(self.current_row[col]) for col in self.headers]) + "\n")

self.current_row = {}

@property
def attach_to_all_processes(self) -> bool:
return False
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/metric_history.py
@@ -34,6 +34,10 @@ def after_training(self, **kw):
"""Returns metric history."""
self.trainer.return_values.update(self.metric_history)

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/tensorboard.py
@@ -59,6 +59,10 @@ def _training_finally(self, **kw):
assert self.writer is not None
self.writer.close()

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),
4 changes: 4 additions & 0 deletions flair/trainers/plugins/loggers/wandb.py
@@ -72,6 +72,10 @@ def metric_recorded(self, record):
def _training_finally(self, **kw):
self.writer.close()

@property
def attach_to_all_processes(self) -> bool:
return False

def get_state(self) -> dict[str, Any]:
return {
**super().get_state(),