From 57a895d8e859ee9a33a2ada551b186b91a8ca360 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 1 Jul 2024 03:57:24 -0400 Subject: [PATCH] Supporting init_module, load/save checkpoint (#83) --- axonn/intra_layer/__init__.py | 1 + axonn/lightning/axonn_strategy.py | 71 ++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index a2f42a2..98a4546 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -184,6 +184,7 @@ def optimize_communication( finally: clear_handles() accumulate() + clear_weights_cache() OVERLAP_ALL_REDUCE = False OVERLAP_REDUCE_SCATTER = False ALL_GATHER_ITERATOR = None diff --git a/axonn/lightning/axonn_strategy.py b/axonn/lightning/axonn_strategy.py index 8e65584..5ae56b1 100644 --- a/axonn/lightning/axonn_strategy.py +++ b/axonn/lightning/axonn_strategy.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from datetime import timedelta -from typing import Any, Dict, List, Optional, Union, ContextManager +from typing import Any, Dict, List, Optional, Union, ContextManager, Callable from contextlib import nullcontext import torch @@ -35,6 +35,7 @@ ) from lightning.fabric.utilities.distributed import group as _group from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.utilities.types import _PATH from axonn import axonn as ax from axonn.intra_layer import ( @@ -42,7 +43,11 @@ sync_gradients_depth_parallel, clip_grad_norm_, no_grad_sync, + auto_parallelize, + optimize_communication, ) +from axonn.checkpoint import get_prefix_for_checkpoint +import os class AxonnStrategy(ParallelStrategy): @@ -212,26 +217,44 @@ def backward( if self.G_data > 1: sync_gradients_data_parallel(module, mean=True) - def save_checkpoint( + @override + def load_checkpoint( self, - *args, - **kwargs, - ) -> None: - assert False, ( - "Current fabric.save(..) is not supported with the " - "AxoNN strategy. Use axonn.save instead." - ) + path: _PATH, + state: Optional[ + Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]] + ] = None, + strict: bool = True, + ) -> Dict[str, Any]: + # different prefix for different tensor parallel ranks + checkpoint_prefix = get_prefix_for_checkpoint() + directory, filename = os.path.split(path) + directory = os.path.join(directory, checkpoint_prefix) + path = os.path.join(directory, filename) + return super().load_checkpoint(path, state, strict) - def load_checkpoint( + @override + def save_checkpoint( self, - *args, - **kwargs, + path: _PATH, + state: Dict[str, Union[Module, Optimizer, Any]], + storage_options: Optional[Any] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: - assert False, ( - "Current fabric.load(..) is not supported with the" - " AxoNN strategy. Use axonn.load instead." 
- ) + if torch.distributed.get_rank(ax.comm_handle.data_parallel_group) == 0: + # different prefix for different tensor parallel ranks + checkpoint_prefix = get_prefix_for_checkpoint() + directory, filename = os.path.split(path) + directory = os.path.join(directory, checkpoint_prefix) + state = self._convert_stateful_objects_in_state( + state, filter=(filter or {}) + ) + path = os.path.join(directory, filename) + self.checkpoint_io.save_checkpoint( + checkpoint=state, path=path, storage_options=storage_options + ) + @override def clip_gradients_norm( self, module: Module, @@ -250,6 +273,22 @@ def clip_gradients_norm( ) return grad_norm + @override + def module_init_context(self, empty_init: Optional[bool] = None): + return self.module_sharded_context() + + @override + def module_sharded_context(self) -> ContextManager: + return auto_parallelize() + + def optimize_communication( + self, module: Module, enabled: bool = True + ) -> ContextManager: + if not enabled: + return nullcontext() + else: + return optimize_communication(True, True, True, module) + class _AxoNNBackwardSyncControl(_BackwardSyncControl): @override
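
Usage sketch for reviewers: the snippet below shows one way the new module_init_context, save_checkpoint, and load_checkpoint overrides could be reached from ordinary Lightning Fabric code. The AxonnStrategy import path, its default constructor arguments, the toy nn.Linear model, and the checkpoint path are illustrative assumptions and not part of this patch; fabric.init_module(), fabric.setup(), fabric.save(), and fabric.load() are the standard Fabric entry points that dispatch to the strategy.

    import torch
    import torch.nn as nn
    from lightning.fabric import Fabric
    from axonn.lightning import AxonnStrategy  # assumed import path for the strategy

    # Constructor arguments (tensor-parallel dimensions, etc.) are left at their
    # defaults here; real runs would configure them explicitly.
    fabric = Fabric(strategy=AxonnStrategy(), devices=2)
    fabric.launch()

    # fabric.init_module() now enters module_init_context(), i.e. auto_parallelize(),
    # so supported layers are constructed directly in their tensor-parallel form.
    with fabric.init_module():
        model = nn.Linear(1024, 1024)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer = fabric.setup(model, optimizer)
    state = {"model": model, "optimizer": optimizer, "step": 0}

    # fabric.save() dispatches to the strategy's save_checkpoint(): only rank 0 of
    # each data-parallel group writes, under a tensor-parallel-specific prefix
    # directory derived from get_prefix_for_checkpoint().
    fabric.save("checkpoints/ckpt.pt", state)

    # fabric.load() resolves the same per-rank prefix before delegating to the
    # default load path, so every rank reads back its own shard.
    fabric.load("checkpoints/ckpt.pt", state)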
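
Continuing the sketch above, the new strategy-level optimize_communication() helper, which wraps axonn.intra_layer.optimize_communication (now also clearing the weights cache on exit), might be used around the forward/backward of a training step. fabric.strategy, fabric.device, and fabric.backward() are standard Fabric attributes; the batch shape is an arbitrary example.

    # Overlap tensor-parallel communication with compute for this step; passing
    # enabled=False returns a nullcontext and leaves behaviour unchanged.
    with fabric.strategy.optimize_communication(model, enabled=True):
        batch = torch.randn(8, 1024, device=fabric.device)
        loss = model(batch).sum()
        fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()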