Supporting init_module, load/save checkpoint (#83)
siddharth9820 authored Jul 1, 2024
1 parent 0ee9562 commit 57a895d
Showing 2 changed files with 56 additions and 16 deletions.
1 change: 1 addition & 0 deletions axonn/intra_layer/__init__.py
@@ -184,6 +184,7 @@ def optimize_communication(
     finally:
         clear_handles()
         accumulate()
+        clear_weights_cache()
         OVERLAP_ALL_REDUCE = False
         OVERLAP_REDUCE_SCATTER = False
         ALL_GATHER_ITERATOR = None
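
The hunk above adds clear_weights_cache() to the cleanup that runs when the optimize_communication context manager exits. Below is a minimal sketch (not part of this commit) of the usage pattern that cleanup supports; the toy layer, optimizer, and data are placeholders, it assumes AxoNN has already been initialized, and the three positional flags simply mirror the optimize_communication(True, True, True, module) call added to the strategy further down.

import torch
from axonn.intra_layer import Linear, optimize_communication

# Placeholder tensor-parallel layer and optimizer (assumes AxoNN was initialized earlier).
model = Linear(in_features=1024, out_features=1024)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024)
with optimize_communication(True, True, True, model):
    # While the context is active, communication for this forward/backward
    # pass is overlapped and weights may be cached.
    loss = model(x).sum()
    loss.backward()
# On exit, the finally block clears handles, accumulates gradients, and
# (with this commit) also clears the weights cache.
optimizer.step()
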
71 changes: 55 additions & 16 deletions axonn/lightning/axonn_strategy.py
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 from datetime import timedelta
-from typing import Any, Dict, List, Optional, Union, ContextManager
+from typing import Any, Dict, List, Optional, Union, ContextManager, Callable
 from contextlib import nullcontext

 import torch
@@ -35,14 +35,19 @@
 )
 from lightning.fabric.utilities.distributed import group as _group
 from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.types import _PATH

 from axonn import axonn as ax
 from axonn.intra_layer import (
     sync_gradients_data_parallel,
     sync_gradients_depth_parallel,
     clip_grad_norm_,
     no_grad_sync,
+    auto_parallelize,
+    optimize_communication,
 )
+from axonn.checkpoint import get_prefix_for_checkpoint
+import os


 class AxonnStrategy(ParallelStrategy):
@@ -212,26 +217,44 @@ def backward(
         if self.G_data > 1:
             sync_gradients_data_parallel(module, mean=True)

-    def save_checkpoint(
+    @override
+    def load_checkpoint(
         self,
-        *args,
-        **kwargs,
-    ) -> None:
-        assert False, (
-            "Current fabric.save(..) is not supported with the "
-            "AxoNN strategy. Use axonn.save instead."
-        )
+        path: _PATH,
+        state: Optional[
+            Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]
+        ] = None,
+        strict: bool = True,
+    ) -> Dict[str, Any]:
+        # different prefix for different tensor parallel ranks
+        checkpoint_prefix = get_prefix_for_checkpoint()
+        directory, filename = os.path.split(path)
+        directory = os.path.join(directory, checkpoint_prefix)
+        path = os.path.join(directory, filename)
+        return super().load_checkpoint(path, state, strict)

-    def load_checkpoint(
+    @override
+    def save_checkpoint(
         self,
-        *args,
-        **kwargs,
+        path: _PATH,
+        state: Dict[str, Union[Module, Optimizer, Any]],
+        storage_options: Optional[Any] = None,
+        filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
     ) -> None:
-        assert False, (
-            "Current fabric.load(..) is not supported with the"
-            " AxoNN strategy. Use axonn.load instead."
-        )
+        if torch.distributed.get_rank(ax.comm_handle.data_parallel_group) == 0:
+            # different prefix for different tensor parallel ranks
+            checkpoint_prefix = get_prefix_for_checkpoint()
+            directory, filename = os.path.split(path)
+            directory = os.path.join(directory, checkpoint_prefix)
+            state = self._convert_stateful_objects_in_state(
+                state, filter=(filter or {})
+            )
+            path = os.path.join(directory, filename)
+            self.checkpoint_io.save_checkpoint(
+                checkpoint=state, path=path, storage_options=storage_options
+            )

+    @override
     def clip_gradients_norm(
         self,
         module: Module,
@@ -250,6 +273,22 @@ def clip_gradients_norm(
         )
         return grad_norm

+    @override
+    def module_init_context(self, empty_init: Optional[bool] = None):
+        return self.module_sharded_context()
+
+    @override
+    def module_sharded_context(self) -> ContextManager:
+        return auto_parallelize()
+
+    def optimize_communication(
+        self, module: Module, enabled: bool = True
+    ) -> ContextManager:
+        if not enabled:
+            return nullcontext()
+        else:
+            return optimize_communication(True, True, True, module)
+

 class _AxoNNBackwardSyncControl(_BackwardSyncControl):
     @override
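
Taken together, the new overrides change what a Lightning Fabric user can do with this strategy: fabric.init_module() now builds the model inside auto_parallelize(), and fabric.save()/fabric.load() delegate to the overridden save_checkpoint()/load_checkpoint(), which prepend a tensor-parallel-rank prefix to the checkpoint directory. A rough end-to-end sketch follows; the import path, the AxonnStrategy constructor argument (G_intra_r), the device count, and the toy model are assumptions for illustration, not taken from this commit.

import torch
from lightning.fabric import Fabric
from axonn.lightning import AxonnStrategy  # assumed import path for this strategy

# Assumed setup: 4 GPUs with row tensor parallelism of 4 (constructor args are illustrative).
fabric = Fabric(devices=4, strategy=AxonnStrategy(G_intra_r=4))
fabric.launch()

# module_init_context()/module_sharded_context() wrap construction in auto_parallelize(),
# so supported layers are created as their tensor-parallel counterparts.
with fabric.init_module():
    model = torch.nn.Linear(1024, 1024)

optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

state = {"model": model, "optimizer": optimizer}
# fabric.save()/fabric.load() now reach the overridden strategy methods above, which
# write and read per-rank files under a checkpoint-prefix subdirectory of the given path.
fabric.save("checkpoints/model.pt", state)
fabric.load("checkpoints/model.pt", state)
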
