From 61911b4772b6338919b0f5aefdc8942fd42cdb8a Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 31 Oct 2023 12:16:17 -0400 Subject: [PATCH] [WIP] A tensor parallel API for beginners (#40) * Easy TP that works with hf models --- axonn/intra_layer/__init__.py | 6 +- axonn/intra_layer/fully_connected.py | 152 +++++++++++++++++--- axonn/intra_layer/gradient_normalization.py | 90 ++++++++++++ axonn/tests/test_intra_layer_fc.py | 121 ++++++++++++---- 4 files changed, 322 insertions(+), 47 deletions(-) create mode 100644 axonn/intra_layer/gradient_normalization.py diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 7180e9e..cbe63a7 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -1,6 +1,9 @@ -from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401 +from .fully_connected import Linear # noqa: F401 from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401 + from .communication import Drop, Gather +from .gradient_normalization import clip_grad_norm_ # noqa: F401 + from axonn import axonn as ax @@ -18,4 +21,5 @@ def gather(x, transpose=False, dim=-1): group = ax.comm_handle.inner_intra_layer_parallel_group else: group = ax.comm_handle.outer_intra_layer_parallel_group + return Gather.apply(x, group, dim) diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index dd758fb..faa7715 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -1,8 +1,9 @@ from axonn import axonn as ax import torch.distributed as dist import torch -from .communication import Drop +from .communication import Drop, Gather from torch.autograd import Function +from torch.cuda.amp import custom_fwd, custom_bwd import math @@ -11,20 +12,34 @@ def divide(a, b): return a // b +def extract_local_params_from_full_params( + full_params, out_features_group, in_features_group +): + params = Drop.apply(torch.t(full_params).contiguous(), out_features_group) + params = torch.t(params).contiguous() + params = Drop.apply(params, in_features_group) + return params + + @torch.no_grad() def initialize_params( out_features, in_features, out_features_group, in_features_group, init_method ): params = torch.empty((out_features, in_features)) init_method(params) - params = Drop.apply(torch.t(params).contiguous(), out_features_group) - params = torch.t(params).contiguous() - params = Drop.apply(params, in_features_group) + params = extract_local_params_from_full_params( + params, out_features_group, in_features_group + ) return params +def default_init_method(weight): + return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + class AsyncLinear(Function): @staticmethod + @custom_fwd def forward( ctx, input_, @@ -41,6 +56,7 @@ def forward( return output @staticmethod + @custom_bwd def backward(ctx, grad_output): input_, weight = ctx.saved_tensors handle = None @@ -53,7 +69,7 @@ def backward(ctx, grad_output): ) if ctx.needs_input_grad[1]: grad_weight = ( - grad_output.view(-1, grad_output.shape[-1]) + grad_output.reshape(-1, grad_output.shape[-1]) .t() .mm(input_.view(-1, input_.shape[-1])) ) @@ -62,10 +78,6 @@ def backward(ctx, grad_output): return grad_input, grad_weight, None, None, None -def default_init_method(weight): - return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - - class Linear(torch.nn.Module): def __init__( self, @@ -73,6 +85,7 @@ def __init__( out_features, *args, transpose=False, + bias=True, skip_bias_add=False, init_method=None, async_comm_in_backward_pass=True, @@ -84,6 +97,10 @@ def __init__( self.inner_group_size = dist.get_world_size(self.inner_group) self.outer_group_size = dist.get_world_size(self.outer_group) + + self.in_features = in_features + self.out_features = out_features + self.async_comm_in_backward_pass = async_comm_in_backward_pass if init_method is None: @@ -116,19 +133,47 @@ def __init__( self.weight = torch.nn.Parameter(initial_params, requires_grad=True) - self.bias = torch.nn.Parameter( - torch.zeros( - self.local_out_features, - ) + setattr(self.weight, "is_tensor_parallel", True) + setattr( + self.weight, + "process_group_for_norm_reduction", + ax.comm_handle.intra_layer_group, ) + + if bias: + self.bias = torch.nn.Parameter( + torch.zeros( + self.local_out_features, + ) + ) + setattr(self.bias, "is_tensor_parallel", True) + if not transpose: + setattr( + self.bias, + "process_group_for_norm_reduction", + ax.comm_handle.outer_intra_layer_parallel_group, + ) + else: + setattr( + self.bias, + "process_group_for_norm_reduction", + ax.comm_handle.inner_intra_layer_parallel_group, + ) + else: + self.bias = None + self.transpose = transpose self.skip_bias_add = skip_bias_add + self._old_load_from_state_dict = self._load_from_state_dict + self._load_from_state_dict = self._modified_load_from_state_dict def get_output_feature_size(self): return self.local_out_features - def forward(self, x): + def forward(self, x, scatter_input=True, gather_output=True): if not self.transpose: + if scatter_input: + x = Drop.apply(x, self.inner_group) x = AsyncLinear.apply( x, self.weight, @@ -136,7 +181,11 @@ def forward(self, x): self.outer_group, self.async_comm_in_backward_pass, ) + if gather_output: + x = Gather.apply(x, self.outer_group) else: + if scatter_input: + x = Drop.apply(x, self.outer_group) x = AsyncLinear.apply( x, self.weight, @@ -144,7 +193,76 @@ def forward(self, x): self.inner_group, self.async_comm_in_backward_pass, ) - if self.skip_bias_add: - return x, self.bias + if gather_output: + x = Gather.apply(x, self.inner_group) + + if self.bias is None: + return x else: - return x + self.bias + bias = self.bias + if gather_output: + bias = Gather.apply( + self.bias, + self.outer_group if not self.transpose else self.inner_group, + ) + if self.skip_bias_add: + return x, bias + else: + return x + bias + + def _is_full_weight_matrix(self, weight): + return (weight.size(0) == self.out_features) and ( + weight.size(1) == self.in_features + ) + + def _is_sharded_weight_matrix(self, weight): + return (weight.size(0) == self.local_out_features) and ( + weight.size(1) == self.local_in_features + ) + + @torch.no_grad() + def _modified_load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + weight = ( + state_dict[prefix + "weight"] if prefix + "weight" in state_dict else None + ) + + if weight is not None: + is_full_weight_matrix = self._is_full_weight_matrix(weight) + is_sharded_weight_matrix = self._is_sharded_weight_matrix(weight) + + assert ( + is_full_weight_matrix or is_sharded_weight_matrix + ), "This is neither a full checkpoint nor a sharded checkpoint" + + if is_full_weight_matrix: + out_features_group, in_features_group = ( + self.outer_group, + self.inner_group, + ) + if self.transpose: + out_features_group, in_features_group = ( + self.inner_group, + self.outer_group, + ) + weight = extract_local_params_from_full_params( + weight, out_features_group, in_features_group + ) + state_dict[prefix + "weight"] = weight + + if self.bias is not None: + bias = ( + state_dict[prefix + "bias"] if prefix + "bias" in state_dict else None + ) + if bias is not None: + if bias.size(0) == self.out_features: + bias = Drop.apply( + bias, + self.outer_group if not self.transpose else self.inner_group, + ) + state_dict[prefix + "bias"] = bias + else: + assert ( + bias.size(0) == self.local_out_features + ), "This is neither a full checkpoint nor a sharded checkpoint" + + self._old_load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/axonn/intra_layer/gradient_normalization.py b/axonn/intra_layer/gradient_normalization.py new file mode 100644 index 0000000..71d880f --- /dev/null +++ b/axonn/intra_layer/gradient_normalization.py @@ -0,0 +1,90 @@ +import torch + +# for backwards compatibility with pytorch 1.13 +try: + from torch._six import inf +except ImportError: + from torch import inf + +import torch.distributed as dist +from collections import defaultdict + + +def get_total_norm(tensors, norm_type, error_if_nonfinite): + if len(tensors) == 0: + return torch.tensor(0.0) + device = tensors[0].device + total_norm = torch.norm( + torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in tensors]), + norm_type, + ) + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + + return total_norm + + +def clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False): + if norm_type == inf: + raise NotImplementedError + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + tensor_parallel_params = defaultdict(list) + non_tensor_parallel_params = [] + for p in parameters: + if hasattr(p, "is_tensor_parallel") and p.is_tensor_parallel: + assert hasattr( + p, "process_group_for_norm_reduction" + ), "each tensor parallel tensor should" + "have a process group for all-reducing norms" + tensor_parallel_params[p.process_group_for_norm_reduction].append(p) + else: + non_tensor_parallel_params.append(p) + + tensor_parallel_grads = {} + for process_group, group_params in tensor_parallel_params.items(): + tensor_parallel_grads[process_group] = [ + p.grad for p in group_params if p.grad is not None + ] + + non_tensor_parallel_grads = [ + p.grad for p in non_tensor_parallel_params if p.grad is not None + ] + + max_norm = float(max_norm) + norm_type = float(norm_type) + + non_tensor_parallel_norm = get_total_norm( + non_tensor_parallel_grads, norm_type, error_if_nonfinite + ) + + tensor_parallel_norms = [] + for process_group, grads in tensor_parallel_grads.items(): + local_tensor_parallel_norm = get_total_norm( + grads, norm_type, error_if_nonfinite + ) + tensor_parallel_norm = local_tensor_parallel_norm**norm_type + dist.all_reduce(tensor_parallel_norm, group=process_group) + tensor_parallel_norm = tensor_parallel_norm ** (1.0 / norm_type) + tensor_parallel_norms.append(tensor_parallel_norm) + + all_norms = tensor_parallel_norms + [non_tensor_parallel_norm] + total_norm = get_total_norm(all_norms, norm_type, error_if_nonfinite) + + clip_coef = max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in non_tensor_parallel_grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + + for group_grads in tensor_parallel_grads.values(): + for g in group_grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + + return total_norm diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py index f3f9d34..5fed505 100644 --- a/axonn/tests/test_intra_layer_fc.py +++ b/axonn/tests/test_intra_layer_fc.py @@ -2,13 +2,15 @@ import pytest from axonn import axonn as ax from axonn.intra_layer.communication import _drop, _gather -from axonn.intra_layer import Tensor_Parallel_Linear +from axonn.intra_layer import Linear, clip_grad_norm_ @pytest.mark.mpi @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)]) @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)]) -def test_fw_pass(G_intra_r, G_intra_c, B, H): +@pytest.mark.parametrize("easy_tp", [False, True]) +@pytest.mark.parametrize("bias", [False, True]) +def test_fw_pass(G_intra_r, G_intra_c, B, H, easy_tp, bias): # These tests are in fp-32 torch.manual_seed(42) ax.init( @@ -23,22 +25,30 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H): inner_group = ax.comm_handle.inner_intra_layer_parallel_group outer_group = ax.comm_handle.outer_intra_layer_parallel_group - X_local = _drop( - X, 1, inner_group - ) # divide colunns of X along the inner tensor group - layer = Tensor_Parallel_Linear( - in_features=H, out_features=H, skip_bias_add=True - ).cuda() + if not easy_tp: + # manually divide input + X_local = _drop( + X, 1, inner_group + ) # divide colunns of X along the inner tensor group + else: + X_local = X + + layer = Linear(in_features=H, out_features=H, bias=bias).cuda() + layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda() + + # test if load state dict works with a sequential checkpoint + layer.load_state_dict(layer_sequential.state_dict()) + # test if load state dict works with a sharded checkpoint + layer.load_state_dict(layer.state_dict()) with torch.no_grad(): # parallel FW pass - Y_local, _ = layer(X_local) - Y_parallel = _gather(Y_local.clone(), 1, outer_group) - + Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) + if not easy_tp: # gather output manually + Y_parallel = _gather(Y_local.clone(), 1, outer_group) + else: + Y_parallel = Y_local # sequential FW pass - layer_sequential = torch.nn.Linear( - in_features=H, out_features=H, bias=False - ).cuda() weight_sequential = _gather( _gather(layer.weight, 1, inner_group), 0, outer_group ) @@ -52,7 +62,19 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H): @pytest.mark.parametrize("B, H", [(32, 64), (16, 128), (2, 256)]) @pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)]) @pytest.mark.parametrize("async_comm_in_backward_pass", [True, False]) -def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass): +@pytest.mark.parametrize("easy_tp", [False, True]) +@pytest.mark.parametrize("clip_grad_norm", [-1, 1e-3]) +@pytest.mark.parametrize("bias", [False]) +def test_bw_pass( + G_intra_r, + G_intra_c, + B, + H, + async_comm_in_backward_pass, + easy_tp, + clip_grad_norm, + bias, +): # These tests are in fp-32 torch.manual_seed(42) ax.init( @@ -68,32 +90,52 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass): outer_group = ax.comm_handle.outer_intra_layer_parallel_group # parallel backward pass - layer = Tensor_Parallel_Linear( + layer = Linear( in_features=H, out_features=H, - skip_bias_add=True, + bias=bias, async_comm_in_backward_pass=async_comm_in_backward_pass, ).cuda() - X_local = ( - _drop(X, 1, inner_group).detach().clone() - ) # divide colunns of X along the inner tensor group + layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda() + + # test if load state dict works with a sequential checkpoint + layer.load_state_dict(layer_sequential.state_dict()) + # test if load state dict works with a sharded checkpoint + layer.load_state_dict(layer.state_dict()) + + if not easy_tp: + X_local = ( + _drop(X, 1, inner_group).detach().clone() + ) # divide colunns of X along the inner tensor group + else: + X_local = X + X_local.requires_grad = True - Y_local, _ = layer(X_local) - Y_local_grad = _drop(Y_grad, 1, outer_group) + Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) + + if not easy_tp: + Y_local_grad = _drop(Y_grad, 1, outer_group) + else: + Y_local_grad = Y_grad + Y_local.backward(Y_local_grad) # sequential backward pass - layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=False).cuda() - with torch.no_grad(): - weight_sequential = _gather( - _gather(layer.weight, 1, inner_group), 0, outer_group - ) - layer_sequential.weight.copy_(weight_sequential) X.requires_grad = True Y_sequential = layer_sequential(X) Y_sequential.backward(Y_grad) - X_grad_parallel = _gather(X_local.grad, 1, inner_group) + if clip_grad_norm > 0: + clip_grad_norm_(layer.parameters(), max_norm=clip_grad_norm) + torch.nn.utils.clip_grad_norm_( + layer_sequential.parameters(), max_norm=clip_grad_norm + ) + + if not easy_tp: + X_grad_parallel = _gather(X_local.grad, 1, inner_group) + else: + X_grad_parallel = X_local.grad + assert torch.allclose( X_grad_parallel, X.grad ), "BW Pass - gradients of input do not match" @@ -104,3 +146,24 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H, async_comm_in_backward_pass): assert torch.allclose( weight_grad_parallel, layer_sequential.weight.grad ), "BW Pass - gradients of weight do not match" + + if bias: + bias_grad_parallel = _gather(layer.bias.grad, 0, outer_group) + assert torch.allclose( + bias_grad_parallel, layer_sequential.bias.grad + ), "BW Pass - gradients of bias do not match" + + +if __name__ == "__main__": + test_fw_pass(G_intra_r=2, G_intra_c=1, B=4, H=256, easy_tp=True, bias=True) + test_bw_pass( + G_intra_r=2, + G_intra_c=1, + B=4, + H=256, + async_comm_in_backward_pass=True, + easy_tp=True, + clip_grad_norm=0.01, + bias=True, + ) + print("finished")