Merge branch 'develop' into neo-x
siddharth9820 authored Oct 31, 2023
2 parents 8d9c08e + 08d46d4 commit 744cbc0
Showing 4 changed files with 322 additions and 47 deletions.
6 changes: 5 additions & 1 deletion axonn/intra_layer/__init__.py
@@ -1,6 +1,9 @@
from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401
from .fully_connected import Linear # noqa: F401
from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401

from .communication import Drop, Gather
from .gradient_normalization import clip_grad_norm_ # noqa: F401

from axonn import axonn as ax


@@ -18,4 +21,5 @@ def gather(x, transpose=False, dim=-1):
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

return Gather.apply(x, group, dim)
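
For orientation, a short sketch of what the updated __init__.py exposes; the import list below reflects only the names visible in this diff and assumes axonn is installed:

# Names visible in the updated axonn/intra_layer/__init__.py.
from axonn.intra_layer import (
    Linear,                 # also re-exported as Tensor_Parallel_Linear
    Tensor_Parallel_Conv2d,
    clip_grad_norm_,        # new export added by this commit
    gather,                 # wraps Gather.apply over an intra-layer group
)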
152 changes: 135 additions & 17 deletions axonn/intra_layer/fully_connected.py
@@ -1,8 +1,9 @@
from axonn import axonn as ax
import torch.distributed as dist
import torch
from .communication import Drop
from .communication import Drop, Gather
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
import math


@@ -11,20 +12,34 @@ def divide(a, b):
return a // b


def extract_local_params_from_full_params(
full_params, out_features_group, in_features_group
):
params = Drop.apply(torch.t(full_params).contiguous(), out_features_group)
params = torch.t(params).contiguous()
params = Drop.apply(params, in_features_group)
return params


@torch.no_grad()
def initialize_params(
out_features, in_features, out_features_group, in_features_group, init_method
):
params = torch.empty((out_features, in_features))
init_method(params)
params = Drop.apply(torch.t(params).contiguous(), out_features_group)
params = torch.t(params).contiguous()
params = Drop.apply(params, in_features_group)
params = extract_local_params_from_full_params(
params, out_features_group, in_features_group
)
return params


def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class AsyncLinear(Function):
@staticmethod
@custom_fwd
def forward(
ctx,
input_,
@@ -41,6 +56,7 @@ def forward(
return output

@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
handle = None
@@ -53,7 +69,7 @@ def backward(ctx, grad_output):
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.view(-1, grad_output.shape[-1])
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)
@@ -62,17 +78,14 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, None, None, None


def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class Linear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
*args,
transpose=False,
bias=True,
skip_bias_add=False,
init_method=None,
async_comm_in_backward_pass=True,
@@ -84,6 +97,10 @@ def __init__(

self.inner_group_size = dist.get_world_size(self.inner_group)
self.outer_group_size = dist.get_world_size(self.outer_group)

self.in_features = in_features
self.out_features = out_features

self.async_comm_in_backward_pass = async_comm_in_backward_pass

if init_method is None:
@@ -116,35 +133,136 @@ def __init__(

self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
setattr(self.weight, "is_tensor_parallel", True)
setattr(
self.weight,
"process_group_for_norm_reduction",
ax.comm_handle.intra_layer_group,
)

if bias:
self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
)
setattr(self.bias, "is_tensor_parallel", True)
if not transpose:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.outer_intra_layer_parallel_group,
)
else:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.inner_intra_layer_parallel_group,
)
else:
self.bias = None

self.transpose = transpose
self.skip_bias_add = skip_bias_add
self._old_load_from_state_dict = self._load_from_state_dict
self._load_from_state_dict = self._modified_load_from_state_dict

def get_output_feature_size(self):
return self.local_out_features

def forward(self, x):
def forward(self, x, scatter_input=True, gather_output=True):
if not self.transpose:
if scatter_input:
x = Drop.apply(x, self.inner_group)
x = AsyncLinear.apply(
x,
self.weight,
self.inner_group,
self.outer_group,
self.async_comm_in_backward_pass,
)
if gather_output:
x = Gather.apply(x, self.outer_group)
else:
if scatter_input:
x = Drop.apply(x, self.outer_group)
x = AsyncLinear.apply(
x,
self.weight,
self.outer_group,
self.inner_group,
self.async_comm_in_backward_pass,
)
if self.skip_bias_add:
return x, self.bias
if gather_output:
x = Gather.apply(x, self.inner_group)

if self.bias is None:
return x
else:
return x + self.bias
bias = self.bias
if gather_output:
bias = Gather.apply(
self.bias,
self.outer_group if not self.transpose else self.inner_group,
)
if self.skip_bias_add:
return x, bias
else:
return x + bias

def _is_full_weight_matrix(self, weight):
return (weight.size(0) == self.out_features) and (
weight.size(1) == self.in_features
)

def _is_sharded_weight_matrix(self, weight):
return (weight.size(0) == self.local_out_features) and (
weight.size(1) == self.local_in_features
)

@torch.no_grad()
def _modified_load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
weight = (
state_dict[prefix + "weight"] if prefix + "weight" in state_dict else None
)

if weight is not None:
is_full_weight_matrix = self._is_full_weight_matrix(weight)
is_sharded_weight_matrix = self._is_sharded_weight_matrix(weight)

assert (
is_full_weight_matrix or is_sharded_weight_matrix
), "This is neither a full checkpoint nor a sharded checkpoint"

if is_full_weight_matrix:
out_features_group, in_features_group = (
self.outer_group,
self.inner_group,
)
if self.transpose:
out_features_group, in_features_group = (
self.inner_group,
self.outer_group,
)
weight = extract_local_params_from_full_params(
weight, out_features_group, in_features_group
)
state_dict[prefix + "weight"] = weight

if self.bias is not None:
bias = (
state_dict[prefix + "bias"] if prefix + "bias" in state_dict else None
)
if bias is not None:
if bias.size(0) == self.out_features:
bias = Drop.apply(
bias,
self.outer_group if not self.transpose else self.inner_group,
)
state_dict[prefix + "bias"] = bias
else:
assert (
bias.size(0) == self.local_out_features
), "This is neither a full checkpoint nor a sharded checkpoint"

self._old_load_from_state_dict(state_dict, prefix, *args, **kwargs)
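
A hedged usage sketch of the Linear API after this change. It assumes axonn's communicator has already been initialized elsewhere (not part of this diff), since Linear.__init__ looks up the intra-layer process groups; the shapes are illustrative only.

import torch
from axonn.intra_layer import Linear

# in_features and out_features refer to the full (unsharded) layer
# dimensions; transpose/bias/skip_bias_add keep their defaults.
layer = Linear(in_features=1024, out_features=4096, bias=True)

x = torch.randn(8, 1024)

# New keyword arguments introduced in this diff: scatter_input drops x
# across the input-feature group before the matmul, and gather_output
# all-gathers the result (and the bias) afterwards. Both default to
# True, so existing call sites keep their old behaviour.
y = layer(x, scatter_input=True, gather_output=True)

# With skip_bias_add=True the forward pass instead returns (output, bias)
# so the caller can fuse the bias addition into a later operation.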
90 changes: 90 additions & 0 deletions axonn/intra_layer/gradient_normalization.py
@@ -0,0 +1,90 @@
import torch

# for backwards compatibility with pytorch 1.13
try:
from torch._six import inf
except ImportError:
from torch import inf

import torch.distributed as dist
from collections import defaultdict


def get_total_norm(tensors, norm_type, error_if_nonfinite):
if len(tensors) == 0:
return torch.tensor(0.0)
device = tensors[0].device
total_norm = torch.norm(
torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in tensors]),
norm_type,
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)

return total_norm


def clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
if norm_type == inf:
raise NotImplementedError

if isinstance(parameters, torch.Tensor):
parameters = [parameters]

tensor_parallel_params = defaultdict(list)
non_tensor_parallel_params = []
for p in parameters:
if hasattr(p, "is_tensor_parallel") and p.is_tensor_parallel:
assert hasattr(
p, "process_group_for_norm_reduction"
), "each tensor parallel tensor should"
"have a process group for all-reducing norms"
tensor_parallel_params[p.process_group_for_norm_reduction].append(p)
else:
non_tensor_parallel_params.append(p)

tensor_parallel_grads = {}
for process_group, group_params in tensor_parallel_params.items():
tensor_parallel_grads[process_group] = [
p.grad for p in group_params if p.grad is not None
]

non_tensor_parallel_grads = [
p.grad for p in non_tensor_parallel_params if p.grad is not None
]

max_norm = float(max_norm)
norm_type = float(norm_type)

non_tensor_parallel_norm = get_total_norm(
non_tensor_parallel_grads, norm_type, error_if_nonfinite
)

tensor_parallel_norms = []
for process_group, grads in tensor_parallel_grads.items():
local_tensor_parallel_norm = get_total_norm(
grads, norm_type, error_if_nonfinite
)
tensor_parallel_norm = local_tensor_parallel_norm**norm_type
dist.all_reduce(tensor_parallel_norm, group=process_group)
tensor_parallel_norm = tensor_parallel_norm ** (1.0 / norm_type)
tensor_parallel_norms.append(tensor_parallel_norm)

all_norms = tensor_parallel_norms + [non_tensor_parallel_norm]
total_norm = get_total_norm(all_norms, norm_type, error_if_nonfinite)

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in non_tensor_parallel_grads:
g.detach().mul_(clip_coef_clamped.to(g.device))

for group_grads in tensor_parallel_grads.values():
for g in group_grads:
g.detach().mul_(clip_coef_clamped.to(g.device))

return total_norm
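
A minimal sketch of calling the new clip_grad_norm_; it uses only a plain replicated parameter, so no process group is required. Parameters created by axonn's tensor-parallel layers would additionally carry the is_tensor_parallel and process_group_for_norm_reduction attributes set in fully_connected.py, and their norms would be all-reduced over that group.

import torch
from axonn.intra_layer import clip_grad_norm_

# A replicated (non-tensor-parallel) parameter, for illustration only.
w = torch.nn.Parameter(torch.randn(16, 16))
loss = (w ** 2).sum()
loss.backward()

# Returns the total gradient norm computed before clipping, mirroring
# torch.nn.utils.clip_grad_norm_.
total_norm = clip_grad_norm_([w], max_norm=1.0, norm_type=2.0)
print(float(total_norm))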
