Commit

ILP Conv Layer: Rebased to develop; Added skip_bias_add and init_method params
prajwal1210 committed Oct 10, 2023
1 parent 89107ab commit 8cd52c2
Showing 2 changed files with 63 additions and 16 deletions.
74 changes: 58 additions & 16 deletions axonn/intra_layer/conv.py
@@ -2,35 +2,77 @@
 import torch.distributed as dist
 import torch
 from .communication import ForwardAllReduce, BackwardAllReduce, Drop
+from .utils import divide
+
+
+@torch.no_grad()
+def initialize_params(out_channels, in_channels, kernel_size, outer_group, inner_group, init_method):
+    params = torch.empty((out_channels, in_channels, kernel_size, kernel_size))
+    init_method(params)
+    params = Drop.apply(params, outer_group, 0)
+    params = Drop.apply(params, inner_group, 1)
+    return params
+
 
 class Conv2d(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, *args, transpose=False, **kwargs):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        *args,
+        transpose=False,
+        skip_bias_add=False,
+        init_method=None,
+        **kwargs
+    ):
         super(Conv2d, self).__init__()
-        self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
-        self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
 
-        if transpose:
-            ordered_groups = [self.outer_group, self.inner_group]
+        if not transpose:
+            self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
+            self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
         else:
-            ordered_groups = [self.inner_group, self.outer_group]
+            self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group
+            self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group
 
-        self.group_sizes = [dist.get_world_size(group=group) for group in ordered_groups]
-        self.ordered_groups = ordered_groups
-        self.in_channels, self.out_channels = in_channels, out_channels
-
-        assert in_channels % self.group_sizes[0] == 0
-        assert out_channels % self.group_sizes[1] == 0
+        self.inner_group_size = dist.get_world_size(self.inner_group)
+        self.outer_group_size = dist.get_world_size(self.outer_group)
+
+        self.in_channels = divide(in_channels, self.inner_group_size)
+        self.out_channels = divide(out_channels, self.outer_group_size)
 
         self.conv = torch.nn.Conv2d(
-            in_channels=in_channels // self.group_sizes[0],
-            out_channels=out_channels // self.group_sizes[1],
+            in_channels=self.in_channels,
+            out_channels=self.out_channels,
             kernel_size=kernel_size,
             bias=False,
             **kwargs)
 
+        if init_method:
+            initial_params = initialize_params(
+                out_channels,
+                in_channels,
+                kernel_size,
+                self.outer_group,
+                self.inner_group,
+                init_method,
+            )
+            self.conv.weight.data.copy_(initial_params)
+
+        self.skip_bias_add = skip_bias_add
+
+        if not self.skip_bias_add:
+            self.bias = torch.nn.Parameter(
+                torch.zeros(self.out_channels)
+            )
+
     def forward(self, x):
-        x = BackwardAllReduce.apply(x, self.ordered_groups[1])
+        x = BackwardAllReduce.apply(x, self.outer_group)
         h = self.conv(x)
-        h = ForwardAllReduce.apply(h, self.ordered_groups[0])
-        return h
+        h = ForwardAllReduce.apply(h, self.inner_group)
+        if self.skip_bias_add:
+            return h
+        else:
+            return h + self.bias.view(1, -1, 1, 1)
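
The two new constructor arguments work as follows: when `init_method` is given, it is applied to a full-size weight tensor inside `initialize_params`, `Drop` then shards that tensor across the outer group (dim 0, output channels) and the inner group (dim 1, input channels), and the local shard is copied into the underlying `torch.nn.Conv2d`; `skip_bias_add=True` skips creating the layer-local bias so the caller can apply the bias addition elsewhere. Below is a minimal usage sketch, assuming AxoNN's intra-layer parallel groups and torch.distributed are already initialized; the Kaiming initializer, shapes, and padding are illustrative choices, not anything mandated by this commit.

import torch
from axonn.intra_layer.conv import Conv2d

# Assumes AxoNN's communication handle (ax.comm_handle) has been set up
# for the chosen intra-layer parallel configuration before this point.
conv = Conv2d(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    padding=1,                                   # forwarded to torch.nn.Conv2d via **kwargs
    init_method=torch.nn.init.kaiming_uniform_,  # applied to the full weight, then sharded via Drop
    skip_bias_add=True,                          # no local bias; caller handles the bias add
)

# Each rank holds a (64 / inner_group_size) -> (128 / outer_group_size) slice of the layer,
# so forward() expects the input-channel shard owned by this rank.
x = torch.randn(8, conv.in_channels, 32, 32)
y = conv(x)  # partial results are summed by ForwardAllReduce over the inner group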
5 changes: 5 additions & 0 deletions axonn/intra_layer/utils.py
@@ -0,0 +1,5 @@
+def divide(a, b):
+    assert a % b == 0
+    return a // b
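
This helper replaces the bare modulo asserts in `Conv2d.__init__`: per-rank channel counts are now computed through `divide`, which fails loudly if a channel dimension does not split evenly across a tensor-parallel group. A quick illustration with made-up group sizes (not part of the commit):

from axonn.intra_layer.utils import divide

# e.g. 256 input channels sharded across an inner group of 4 ranks,
# 512 output channels across an outer group of 2 ranks
local_in_channels = divide(256, 4)   # -> 64 channels per rank
local_out_channels = divide(512, 2)  # -> 256 channels per rank
# divide(300, 4) would raise AssertionError, catching an uneven split early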

