diff --git a/examples/README.md b/examples/README.md
index 53e11ffe21..1221fbf49f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,6 +4,7 @@ This folder contains actively maintained examples of use of Flair, organized alo
 
 ## Table of Tasks
 
-| Task                          | Documentation
-| ----------------------------- | -------------
+| Task                           | Documentation
+| ------------------------------ | -------------
 | Named Entity Recognition (NER) | [Here](ner/)
+| Multi GPU                      | [Here](multi_gpu/)
diff --git a/examples/multi_gpu/README.md b/examples/multi_gpu/README.md
new file mode 100644
index 0000000000..4c12b9d0bd
--- /dev/null
+++ b/examples/multi_gpu/README.md
@@ -0,0 +1,32 @@
+# Multi GPU
+
+Training can be distributed across multiple GPUs on a local machine when using
+[`ModelTrainer`](#flair.trainers.trainer.ModelTrainer).
+
+## Example
+
+See the script `run_multi_gpu.py` and its comments.
+
+## Tutorial
+
+There are two changes that are always required, as well as a few things to consider.
+
+Always required:
+1) Pass the argument `multi_gpu=True` to your [`.train()`](#flair.trainers.trainer.ModelTrainer.train) or `.fine_tune()` call.
+2) Wrap your code in [`launch_distributed`](#flair.distributed_utils.launch_distributed), e.g.
+   `launch_distributed(main, *args)`. This spawns multiple processes, each driving a GPU.
+
+Other considerations:
+- The corpus and other preprocessing must be the same on all processes. For example, if corpus initialization involves
+  anything random, you should either:
+  - Set the random seed before initializing the corpus (e.g. `flair.set_seed(42)`), OR
+  - Initialize the corpus before calling `launch_distributed` and pass the corpus as an argument so it's serialized to
+    all processes.
+- The effective batch size will be larger by a factor of num_gpus.
+  - Each GPU will now process `mini_batch_size` examples before the optimizer steps, resulting in fewer total steps
+    taken relative to training with a single device. To obtain comparable results between single- and multi-GPU training,
+    both mathematically and in terms of wall time, consider the method in the example script.
+- Large batch sizes may be necessary to see faster runs; otherwise, the communication overhead may dominate.
+
+Only the parameter updates in the training process will be distributed across multiple GPUs. Evaluation and prediction
+are still done on a single device.
diff --git a/examples/multi_gpu/__init__.py b/examples/multi_gpu/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/multi_gpu/run_multi_gpu.py b/examples/multi_gpu/run_multi_gpu.py
new file mode 100644
index 0000000000..f7d059111b
--- /dev/null
+++ b/examples/multi_gpu/run_multi_gpu.py
@@ -0,0 +1,54 @@
+import torch
+
+import flair
+from flair.datasets import IMDB
+from flair.distributed_utils import launch_distributed
+from flair.embeddings import TransformerDocumentEmbeddings
+from flair.models import TextClassifier
+from flair.trainers import ModelTrainer
+
+
+def main(multi_gpu):
+    # Note: Multi-GPU can affect corpus loading.
+    # This code will run multiple times -- each GPU gets its own process and each process runs this code. We need to
+    # ensure that the corpus has the same elements and order on all processes, despite sampling. We do that by using
+    # the same seed on all processes.
+    flair.set_seed(42)
+
+    corpus = IMDB()
+    corpus.downsample(0.1)
+    label_type = "sentiment"
+    label_dictionary = corpus.make_label_dictionary(label_type)
+
+    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
+    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
+
+    # Note: Multi-GPU can affect choice of batch size.
+    # In order to compare batch updates fairly between single- and multi-GPU training, we should:
+    # 1) Step the optimizer after the same number of examples to achieve comparable updates
+    # 2) Process the same number of examples in each forward pass
+    mini_batch_chunk_size = 32  # Make this as large as possible without running out of GPU memory to pack each device
+    num_devices_when_distributing = max(torch.cuda.device_count(), 1)
+    mini_batch_size = mini_batch_chunk_size if multi_gpu else mini_batch_chunk_size * num_devices_when_distributing
+    # e.g. Suppose your machine has 2 GPUs. If multi_gpu=False, the first gpu will process 32 examples, then the
+    # first gpu will process another 32 examples, then the optimizer will step. If multi_gpu=True, each gpu will
+    # process 32 examples at the same time, then the optimizer will step.
+
+    trainer = ModelTrainer(model, corpus)
+    trainer.fine_tune(
+        "resources/taggers/multi-gpu",
+        multi_gpu=multi_gpu,  # Required for multi-gpu
+        max_epochs=2,
+        mini_batch_chunk_size=mini_batch_chunk_size,
+        mini_batch_size=mini_batch_size,
+    )
+
+
+if __name__ == "__main__":
+    """Minimal example demonstrating how to train a model on multiple GPUs."""
+    multi_gpu = True
+
+    if multi_gpu:
+        launch_distributed(main, multi_gpu)  # Required for multi-gpu
+    else:
+        main(multi_gpu)
diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py
new file mode 100644
index 0000000000..e774084009
--- /dev/null
+++ b/flair/distributed_utils.py
@@ -0,0 +1,88 @@
+import logging
+import os
+import random
+from multiprocessing.connection import Connection
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from torch.distributed import destroy_process_group, init_process_group
+from torch.utils.data import Dataset
+
+import flair
+from flair.data import Corpus, _len_dataset
+
+log = logging.getLogger("flair")
+
+
+def launch_distributed(fn, *args, **kwargs):
+    """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU).
+
+    If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune.
+
+    Returns: the return value of the function fn(*args, **kwargs) from the rank 0 process
+    """
+    world_size = torch.cuda.device_count()
+    log.info(f"Launching {world_size} processes")
+    parent_conn, child_conn = mp.Pipe()
+    mp.spawn(_process_entrypoint, args=(world_size, child_conn, fn, args, kwargs), nprocs=world_size)
+    return_value = parent_conn.recv()
+    return return_value
+
+
+def _process_entrypoint(
+    rank: int, world_size: int, child_conn: Connection, fn: Callable, args: tuple, kwargs: dict
+) -> None:
+    """Lifecycle of a distributed process -- setup, run, cleanup."""
+    log.info(f"Started process on rank={rank}")
+    try:
+        _ddp_setup(rank, world_size)
+        return_value = fn(*args, **kwargs)
+        if is_main_process():
+            child_conn.send(return_value)
+    finally:
+        destroy_process_group()
+
+
+def _ddp_setup(rank: int, world_size: int) -> None:
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    flair.device = torch.device(rank)
+    torch.cuda.set_device(flair.device)
+    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+
+def is_main_process() -> bool:
+    """True for exactly one process, regardless of whether running on CPU, a single GPU, or multiple GPUs."""
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank() == 0
+    else:
+        return True
+
+
+def aggregate(value, aggregation_fn=np.mean):
+    """Gather `value` from all processes and send to `aggregation_fn` to get a single return value."""
+    if torch.distributed.is_initialized():
+        gathered_values = [None for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather_object(gathered_values, value)
+    else:
+        gathered_values = [value]
+    return aggregation_fn(gathered_values)
+
+
+def validate_corpus_same_each_process(corpus: Corpus) -> None:
+    """Catches most cases in which a corpus is not the same on each process. However, there is no guarantee, for two
+    reasons: 1) it uses a sample for speed, 2) it compares strings to avoid requiring the datasets to be serializable."""
+    for dataset in [corpus.train, corpus.dev, corpus.test]:
+        if dataset is not None:
+            _validate_dataset_same_each_process(dataset)
+
+
+def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None:
+    random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset)))
+    for i in random_indices:
+        example = str(dataset[i])
+        examples = aggregate(example, list)
+        if not all(example == examples[0] for example in examples):
+            raise ValueError("Dataset must be the same on each process")
diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 245d528e5d..c43258a193 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -1356,7 +1356,7 @@ def to_params(self):
         # do not switch the attention implementation upon reload.
config_dict["attn_implementation"] = self.model.config._attn_implementation - del config_dict["_attn_implementation_autoset"] + config_dict.pop("_attn_implementation_autoset", None) super_params = super().to_params() diff --git a/flair/nn/model.py b/flair/nn/model.py index bf13baf2f1..f670c969a0 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -17,6 +17,7 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import is_main_process from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from flair.file_utils import Tqdm, load_torch_state @@ -291,7 +292,7 @@ def evaluate( loader = DataLoader(data_points, batch_size=mini_batch_size) sentence_id = 0 - for batch in Tqdm.tqdm(loader): + for batch in Tqdm.tqdm(loader, disable=not is_main_process()): # remove any previously predicted labels for datapoint in batch: datapoint.remove_labels("predicted") diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py index 663a78d6d3..dcf7240a83 100644 --- a/flair/trainers/plugins/base.py +++ b/flair/trainers/plugins/base.py @@ -13,6 +13,8 @@ cast, ) +from flair.distributed_utils import is_main_process + log = logging.getLogger("flair") @@ -184,6 +186,8 @@ def attach_to(self, pluggable: Pluggable): assert self._pluggable is None assert len(self._hook_handles) == 0 + if not is_main_process() and not self.attach_to_all_processes: + return self._pluggable = pluggable pluggable.append_plugin(self) @@ -252,6 +256,11 @@ def decorator_func(func: Callable): def pluggable(self) -> Optional[Pluggable]: return self._pluggable + @property + def attach_to_all_processes(self) -> bool: + """If set, the plugin will be attached to all processes when distributed, not just the main process.""" + return True + def __str__(self) -> str: return self.__class__.__name__ diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index 1936177835..a8179edbc5 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -1,6 +1,8 @@ import logging from typing import Any +import torch + from flair.trainers.plugins.base import TrainerPlugin log = logging.getLogger("flair") @@ -28,6 +30,12 @@ def after_training_epoch(self, epoch, **kw): ) model_name = "model_epoch_" + str(epoch) + ".pt" self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete + + @property + def attach_to_all_processes(self) -> bool: + return False def get_state(self) -> dict[str, Any]: return { diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py index 2258844129..fdf4752bde 100644 --- a/flair/trainers/plugins/functional/linear_scheduler.py +++ b/flair/trainers/plugins/functional/linear_scheduler.py @@ -1,6 +1,8 @@ import logging from typing import Any +import torch.distributed + from flair.optim import LinearSchedulerWithWarmup from flair.trainers.plugins.base import TrainerPlugin @@ -34,7 +36,8 @@ def after_setup( ): """Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers.""" # calculate warmup steps - steps_per_epoch = 
(dataset_size + mini_batch_size - 1) / mini_batch_size + num_processes = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size / num_processes num_train_steps = int(steps_per_epoch * max_epochs) num_warmup_steps = int(num_train_steps * self.warmup_fraction) diff --git a/flair/trainers/plugins/functional/reduce_transformer_vocab.py b/flair/trainers/plugins/functional/reduce_transformer_vocab.py index 86759c2fe2..eed6d7f1b2 100644 --- a/flair/trainers/plugins/functional/reduce_transformer_vocab.py +++ b/flair/trainers/plugins/functional/reduce_transformer_vocab.py @@ -55,6 +55,10 @@ def save_model_at_the_end(self, **kw): elif (self.base_path / "final-model.pt").exists(): self.model.save(self.base_path / "final-model.pt", checkpoint=self.save_optimizer_state) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_transformer_embeddings(model: Model) -> list[TransformerEmbeddings]: embeddings = model.tars_embeddings if isinstance(model, FewshotClassifier) else getattr(model, "embeddings", None) diff --git a/flair/trainers/plugins/functional/weight_extractor.py b/flair/trainers/plugins/functional/weight_extractor.py index 4ba7c07621..5c9bd4c4ac 100644 --- a/flair/trainers/plugins/functional/weight_extractor.py +++ b/flair/trainers/plugins/functional/weight_extractor.py @@ -21,6 +21,10 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw): if (iteration + 1) % modulo == 0: self.weight_extractor.extract_weights(self.model.state_dict(), iteration) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/clearml_logger.py b/flair/trainers/plugins/loggers/clearml_logger.py index 891b9f9244..18228d2db6 100644 --- a/flair/trainers/plugins/loggers/clearml_logger.py +++ b/flair/trainers/plugins/loggers/clearml_logger.py @@ -40,3 +40,7 @@ def metric_recorded(self, record: MetricRecord) -> None: self.logger.report_text(record.value, print_console=False) elif record.is_histogram: self.logger.report_histogram(record_name, record_name, record.value, record.global_step) + + @property + def attach_to_all_processes(self) -> bool: + return False diff --git a/flair/trainers/plugins/loggers/log_file.py b/flair/trainers/plugins/loggers/log_file.py index 21a8c54632..9bf22a284d 100644 --- a/flair/trainers/plugins/loggers/log_file.py +++ b/flair/trainers/plugins/loggers/log_file.py @@ -21,5 +21,9 @@ def close_file_handler(self, **kw): self.log_handler.close() log.removeHandler(self.log_handler) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return {**super().get_state(), "base_path": str(self.base_path)} diff --git a/flair/trainers/plugins/loggers/loss_file.py b/flair/trainers/plugins/loggers/loss_file.py index b53a23a956..bfe938b72d 100644 --- a/flair/trainers/plugins/loggers/loss_file.py +++ b/flair/trainers/plugins/loggers/loss_file.py @@ -113,3 +113,7 @@ def after_evaluation(self, epoch, **kw): f.write("\t".join([str(self.current_row[col]) for col in self.headers]) + "\n") self.current_row = {} + + @property + def attach_to_all_processes(self) -> bool: + return False diff --git a/flair/trainers/plugins/loggers/metric_history.py b/flair/trainers/plugins/loggers/metric_history.py index a22cf1b0e7..426e055186 100644 --- a/flair/trainers/plugins/loggers/metric_history.py 
+++ b/flair/trainers/plugins/loggers/metric_history.py @@ -34,6 +34,10 @@ def after_training(self, **kw): """Returns metric history.""" self.trainer.return_values.update(self.metric_history) + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/tensorboard.py b/flair/trainers/plugins/loggers/tensorboard.py index a7af50a521..bf2dfcc29d 100644 --- a/flair/trainers/plugins/loggers/tensorboard.py +++ b/flair/trainers/plugins/loggers/tensorboard.py @@ -59,6 +59,10 @@ def _training_finally(self, **kw): assert self.writer is not None self.writer.close() + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/plugins/loggers/wandb.py b/flair/trainers/plugins/loggers/wandb.py index 0f8dc89f73..410d0377ba 100644 --- a/flair/trainers/plugins/loggers/wandb.py +++ b/flair/trainers/plugins/loggers/wandb.py @@ -72,6 +72,10 @@ def metric_recorded(self, record): def _training_finally(self, **kw): self.writer.close() + @property + def attach_to_all_processes(self) -> bool: + return False + def get_state(self) -> dict[str, Any]: return { **super().get_state(), diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 2f32b54c01..cf3bb80eb6 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -9,14 +9,18 @@ from pathlib import Path from typing import Optional, Union +import numpy as np import torch +from torch.nn.parallel import DistributedDataParallel from torch.optim.sgd import SGD +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import ConcatDataset import flair import flair.nn from flair.data import Corpus, Dictionary, _len_dataset from flair.datasets import DataLoader +from flair.distributed_utils import aggregate, is_main_process, validate_corpus_same_each_process from flair.samplers import FlairSampler from flair.trainers.plugins import ( AnnealingPlugin, @@ -163,6 +167,8 @@ def train( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, + # acceleration + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -237,8 +243,9 @@ def fine_tune( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # acceleration use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, attach_default_scheduler: bool = True, @@ -287,8 +294,9 @@ def fine_tune( create_file_logs=create_file_logs, create_loss_file=create_loss_file, write_weights=write_weights, - # amp + # acceleration use_amp=use_amp, + multi_gpu=multi_gpu, # plugins plugins=plugins, **kwargs, @@ -331,8 +339,9 @@ def train_custom( create_file_logs: bool = True, create_loss_file: bool = True, write_weights: bool = False, - # amp + # acceleration use_amp: bool = False, + multi_gpu: bool = False, # plugins plugins: Optional[list[TrainerPlugin]] = None, **kwargs, @@ -375,6 +384,7 @@ def train_custom( create_file_logs: If True, logging output is written to a file create_loss_file: If True, a loss file logging output is created use_amp: If True, uses the torch automatic mixed precision + multi_gpu: If True, distributes training across local GPUs write_weights: If True, write weights to weights.txt on each batch logging event. 
             plugins: Any additional plugins you want to pass to the trainer
             **kwargs: Additional arguments, for instance for the optimizer
@@ -481,6 +491,18 @@ def train_custom(
             sampler.set_dataset(train_data)
             shuffle = False
 
+        # configure special behavior to use multiple GPUs
+        if multi_gpu:
+            if not torch.distributed.is_initialized():
+                raise RuntimeError("multi_gpu=True can only be used inside flair.distributed_utils.launch_distributed()")
+            # Guard against each process initializing the corpus differently due to e.g. different random seeds
+            validate_corpus_same_each_process(self.corpus)
+            self.ddp_model = DistributedDataParallel(
+                self.model, device_ids=[flair.device.index], find_unused_parameters=True
+            )
+            log.disabled = not is_main_process()  # Only print logs once
+            original_forward = self.model.forward
+
         # this field stores the names of all dynamic embeddings in the model (determined after first forward pass)
         dynamic_embeddings = None
 
@@ -508,6 +530,9 @@ def train_custom(
                 if use_final_model_for_eval
                 else "model from best epoch (best-model.pt)"
             )
+            computation_device_info = aggregate(
+                flair.device, lambda devices: ", ".join([str(device) for device in devices])
+            )
 
             log_line(log)
             log.info(f'Model: "{self.model}"')
@@ -534,7 +559,7 @@ def train_custom(
             log.info(f' - metric: "{main_evaluation_metric}"')
             log_line(log)
             log.info("Computation:")
-            log.info(f" - compute on device: {flair.device}")
+            log.info(f" - compute on device: {computation_device_info}")
             log.info(f" - embedding storage: {embeddings_storage_mode}")
             log_line(log)
             log.info(f'Model training base path: "{base_path}"')
@@ -560,12 +585,24 @@ def train_custom(
                 if not shuffle_first_epoch and epoch == 1:
                     shuffle_data_this_epoch = False
 
-                batch_loader = DataLoader(
-                    train_data,
-                    batch_size=mini_batch_size,
-                    shuffle=shuffle_data_this_epoch,
-                    sampler=sampler,
-                )
+                if multi_gpu:
+                    distributed_sampler: DistributedSampler = DistributedSampler(
+                        train_data, shuffle=shuffle_data_this_epoch
+                    )
+                    distributed_sampler.set_epoch(epoch - 1)
+                    batch_loader = DataLoader(
+                        train_data,
+                        batch_size=mini_batch_size,
+                        shuffle=False,
+                        sampler=distributed_sampler,
+                    )
+                else:
+                    batch_loader = DataLoader(
+                        train_data,
+                        batch_size=mini_batch_size,
+                        shuffle=shuffle_data_this_epoch,
+                        sampler=sampler,
+                    )
 
                 self.model.train()
 
@@ -603,7 +640,18 @@ def train_custom(
                     for batch_step in batch_steps:
                         # forward pass
                         with torch.autocast(device_type=flair.device.type, enabled=use_amp):
-                            loss, datapoint_count = self.model.forward_loss(batch_step)
+                            if multi_gpu:
+                                # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+                                # But that calls forward rather than forward_loss. So we patch forward to redirect
+                                # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
+ def wrapped_forward_loss(*args, **kwargs2): + self.model.forward = original_forward + return self.model.forward_loss(*args, **kwargs2) + + self.model.forward = wrapped_forward_loss + loss, datapoint_count = self.ddp_model(batch_step) + else: + loss, datapoint_count = self.model.forward_loss(batch_step) batch_train_samples += datapoint_count batch_train_loss += loss.item() @@ -649,8 +697,11 @@ def train_custom( if epoch_train_samples > 0 else epoch_train_samples / (batch_no + 1) ) + intermittent_loss = aggregate(intermittent_loss) current_time = time.time() + samples_per_second = epoch_train_samples / (current_time - epoch_start_time) + samples_per_second = aggregate(samples_per_second, np.sum) lr_info, momentum_info = self._get_current_lr_and_momentum(batch_count) log.info( @@ -658,7 +709,7 @@ def train_custom( f" - iter {batch_no + 1}/{len(batch_loader)}" f" - loss {intermittent_loss:.8f}" f" - time (sec): {(current_time - epoch_start_time):.2f}" - f" - samples/sec: {epoch_train_samples / (current_time - epoch_start_time):.2f}" + f" - samples/sec: {samples_per_second:.2f}" f"{lr_info}{momentum_info}" ) @@ -667,6 +718,7 @@ def train_custom( self.dispatch("after_training_batch", **batch_kw) train_loss = epoch_train_loss / epoch_train_samples + train_loss = aggregate(train_loss) self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch)) total_train_samples += epoch_train_samples @@ -682,7 +734,7 @@ def train_custom( # Determine if this is the best model or if we need to anneal current_epoch_has_best_model_so_far = False - validation_scores: tuple + validation_scores: tuple = () for evaluation_split, evaluation_split_data in evaluation_splits.items(): eval_result = self.model.evaluate( @@ -722,7 +774,7 @@ def train_custom( if not determine_best_epoch_using_dev_score: validation_scores = (train_loss,) - if epoch_train_loss < best_epoch_score: + if train_loss < best_epoch_score: current_epoch_has_best_model_so_far = True best_epoch_score = train_loss @@ -737,14 +789,14 @@ def train_custom( if save_best_model and current_epoch_has_best_model_so_far: log.info("saving best model") - self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "best-model.pt", checkpoint=save_optimizer_state) # - SWAPlugin -> restores SGD weights from SWA self.dispatch("after_training_loop") # if we do not use dev data for model selection, save final model if save_final_model: - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) except KeyboardInterrupt: log_line(log) @@ -754,7 +806,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except TrainingInterrupt as exc: @@ -765,7 +817,7 @@ def train_custom( if save_final_model: log.info("Saving model ...") - self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state) + self._save_model(base_path / "final-model.pt", checkpoint=save_optimizer_state) log.info("Done.") except Exception: @@ -783,7 +835,7 @@ def train_custom( if (base_path / "best-model.pt").exists(): log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + self._load_model(base_path / "best-model.pt") else: log.info("Testing using last state 
of model ...") @@ -808,7 +860,7 @@ def train_custom( else: if (base_path / "best-model.pt").exists(): log.info("Loading model from best epoch ...") - self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict()) + self._load_model(base_path / "best-model.pt") self.return_values["test_score"] = 0 log.info("Test data not provided setting final score to 0") @@ -905,3 +957,12 @@ def _initialize_model_card(self, **training_parameters): def _record(self, metric): self.dispatch("metric_recorded", metric) + + def _load_model(self, model_file: Union[str, Path]) -> None: + self.model.load_state_dict(self.model.load(model_file).state_dict()) + + def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: + if is_main_process(): + self.model.save(model_file, checkpoint) + if torch.distributed.is_initialized(): + torch.distributed.barrier() # Prevent any process from loading a model until writing is complete