[WIP] ILP Conv Layer support #38

Merged: 6 commits, merged on Oct 24, 2023. Changes shown are from all commits.
5 changes: 4 additions & 1 deletion .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -45,9 +45,12 @@ jobs:
- name: Install AxoNN
run: |
pip install -r requirements.txt
- name: Run unit intra-layer unit tests
- name: Run intra-layer FC unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
- name: Run intra-layer Conv unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
9 changes: 5 additions & 4 deletions axonn/intra_layer/__init__.py
@@ -1,20 +1,21 @@
from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401
from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401
from .communication import Drop, Gather
from axonn import axonn as ax


def drop(x, transpose=False):
def drop(x, transpose=False, dim=-1):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

return Drop.apply(x, group)
return Drop.apply(x, group, dim)


def gather(x, transpose=False):
def gather(x, transpose=False, dim=-1):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group
return Gather.apply(x, group)
return Gather.apply(x, group, dim)
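
With this change, the package-level drop and gather helpers take the split dimension explicitly instead of always operating on the last dimension. A minimal, illustrative sketch of the updated call pattern, not part of this PR (it assumes two ranks launched under mpirun, as in the CI workflow above):

import torch
from axonn import axonn as ax
from axonn.intra_layer import drop, gather

ax.init(G_data=1, G_inter=1, G_intra_r=2, G_intra_c=1)  # two ranks, as in the unit tests
x = torch.randn(1, 8, 64, 64).cuda()
x_local = drop(x, dim=1)         # keep this rank's slice of the channel dimension
x_full = gather(x_local, dim=1)  # all-gather the slices back; x_full matches x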
30 changes: 20 additions & 10 deletions axonn/intra_layer/communication.py
@@ -73,29 +73,39 @@ def backward(ctx, grad_output):

class Drop(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, process_group=None):
return _drop(input_, dim=-1, process_group=process_group)
def symbolic(graph, input_, process_group=None, dim=-1):
return _drop(input_, dim=dim, process_group=process_group)

@staticmethod
def forward(ctx, input_, process_group=None):
def forward(ctx, input_, process_group=None, dim=-1):
ctx.process_group = process_group
return _drop(input_, dim=-1, process_group=process_group)
ctx.dim = dim
return _drop(input_, dim=dim, process_group=process_group)

@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, dim=-1, process_group=ctx.process_group), None
return (
_gather(grad_output, dim=ctx.dim, process_group=ctx.process_group),
None,
None,
)


class Gather(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, process_group=None):
return _gather(input_, dim=-1, process_group=process_group)
def symbolic(graph, input_, process_group=None, dim=-1):
return _gather(input_, dim=dim, process_group=process_group)

@staticmethod
def forward(ctx, input_, process_group=None):
def forward(ctx, input_, process_group=None, dim=-1):
ctx.process_group = process_group
return _gather(input_, dim=-1, process_group=process_group)
ctx.dim = dim
return _gather(input_, dim=dim, process_group=process_group)

@staticmethod
def backward(ctx, grad_output):
return _drop(grad_output, dim=-1, process_group=ctx.process_group), None
return (
_drop(grad_output, dim=ctx.dim, process_group=ctx.process_group),
None,
None,
)
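
The dim argument threaded through Drop and Gather above is what lets the new conv layer shard its weight along both channel dimensions rather than only the last one. A minimal sketch of how the calls compose; the helper name is hypothetical and simply mirrors initialize_params in axonn/intra_layer/conv.py below:

from axonn.intra_layer.communication import Drop

def shard_conv_weight(weight, outer_group, inner_group):
    # weight has shape (out_channels, in_channels, kernel_size, kernel_size)
    weight = Drop.apply(weight, outer_group, 0)  # keep this rank's slice of the output channels
    weight = Drop.apply(weight, inner_group, 1)  # keep this rank's slice of the input channels
    return weight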
77 changes: 77 additions & 0 deletions axonn/intra_layer/conv.py
@@ -0,0 +1,77 @@
from axonn import axonn as ax
import torch.distributed as dist
import torch
from .communication import ForwardAllReduce, BackwardAllReduce, Drop
from .utils import divide


@torch.no_grad()
def initialize_params(
out_channels, in_channels, kernel_size, outer_group, inner_group, init_method
):
params = torch.empty((out_channels, in_channels, kernel_size, kernel_size))
init_method(params)
params = Drop.apply(params, outer_group, 0)
params = Drop.apply(params, inner_group, 1)
return params


class Conv2d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
*args,
transpose=False,
skip_bias_add=False,
init_method=None,
**kwargs
):
super(Conv2d, self).__init__()

if not transpose:
self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
else:
self.outer_group = ax.comm_handle.inner_intra_layer_parallel_group
self.inner_group = ax.comm_handle.outer_intra_layer_parallel_group

self.inner_group_size = dist.get_world_size(self.inner_group)
self.outer_group_size = dist.get_world_size(self.outer_group)

self.in_channels = divide(in_channels, self.inner_group_size)
self.out_channels = divide(out_channels, self.outer_group_size)

self.conv = torch.nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
bias=False,
**kwargs
)

if init_method:
initial_params = initialize_params(
out_channels,
in_channels,
kernel_size,
self.outer_group,
self.inner_group,
init_method,
)
self.conv.weight.data.copy_(initial_params)

self.bias = torch.nn.Parameter(torch.zeros(self.out_channels))

self.skip_bias_add = skip_bias_add

def forward(self, x):
x = BackwardAllReduce.apply(x, self.outer_group)
h = self.conv(x)
h = ForwardAllReduce.apply(h, self.inner_group)
if self.skip_bias_add:
return h, self.bias
else:
return h + self.bias.view(1, -1, 1, 1)
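
Putting the pieces together, the intended usage pattern, as exercised by the unit tests below, is: shard the input's channel dimension across the inner group, run the parallel layer (which all-reduces its partial outputs over the inner group), and gather the output channels from the outer group. A minimal, illustrative sketch assuming two GPUs launched under mpirun as in the CI workflow; the tensor sizes are arbitrary:

import torch
from axonn import axonn as ax
from axonn.intra_layer import Tensor_Parallel_Conv2d
from axonn.intra_layer.communication import _drop, _gather

ax.init(G_data=1, G_inter=1, G_intra_r=2, G_intra_c=1)
inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group

X = torch.randn(1, 8, 64, 64).cuda()
X_local = _drop(X, 1, inner_group)  # shard input channels across the inner group
layer = Tensor_Parallel_Conv2d(in_channels=8, out_channels=16, kernel_size=5).cuda()
Y_local = layer(X_local)              # local conv, all-reduced over the inner group
Y = _gather(Y_local, 1, outer_group)  # reassemble the output channels from the outer group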
3 changes: 3 additions & 0 deletions axonn/intra_layer/utils.py
@@ -0,0 +1,3 @@
def divide(a, b):
assert a % b == 0
return a // b
125 changes: 125 additions & 0 deletions axonn/tests/test_intra_layer_conv.py
@@ -0,0 +1,125 @@
import torch
import pytest
from axonn import axonn as ax
from axonn.intra_layer.communication import _drop, _gather
from axonn.intra_layer import Tensor_Parallel_Conv2d


@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_fw_pass(G_intra_r, G_intra_c, H, W, C):
# These tests are in fp-32
torch.manual_seed(42)
# Need to remove all non-determinism from convolutions
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# This is required because TF32 cores only look at the first 10 bits of mantissa
torch.backends.cudnn.allow_tf32 = False

ax.init(
G_data=1,
G_inter=1,
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
)

X = torch.randn(1, C, H, W).cuda() * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group

X_local = _drop(
X, 1, inner_group
) # divide channels of X along the inner tensor group
layer = Tensor_Parallel_Conv2d(
in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True
).cuda()

with torch.no_grad():
# parallel FW pass
Y_local, _ = layer(X_local)
Y_parallel = _gather(Y_local.clone(), 1, outer_group)

# sequential FW pass
layer_sequential = torch.nn.Conv2d(
in_channels=C,
out_channels=C * 2,
kernel_size=5,
bias=False,
).cuda()
weight_sequential = _gather(
_gather(layer.conv.weight, 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
Y_sequential = layer_sequential(X)

assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"


@pytest.mark.mpi
@pytest.mark.parametrize("H, W, C", [(64, 64, 4), (64, 64, 8), (64, 32, 8)])
@pytest.mark.parametrize("G_intra_r, G_intra_c", [(1, 2), (2, 1)])
def test_bw_pass(G_intra_r, G_intra_c, H, W, C):
# These tests are in fp-32
# Need to remove all non-determinism from convolutions
torch.manual_seed(42)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# This is required because TF32 cores only look at the first 10 bits of mantissa
torch.backends.cudnn.allow_tf32 = False

ax.init(
G_data=1,
G_inter=1,
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
)
X = torch.randn(1, C, H, W).cuda() * 0.01
Y_grad = torch.randn(1, 2 * C, H - 4, W - 4).cuda() * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group

# parallel backward pass
layer = Tensor_Parallel_Conv2d(
in_channels=C, out_channels=2 * C, kernel_size=5, skip_bias_add=True
).cuda()
X_local = (
_drop(X, 1, inner_group).detach().clone()
) # divide input channels of X along the inner tensor group
X_local.requires_grad = True
Y_local, _ = layer(X_local)
Y_local_grad = _drop(Y_grad, 1, outer_group)
Y_local.backward(Y_local_grad)

# sequential backward pass
layer_sequential = torch.nn.Conv2d(
in_channels=C,
out_channels=C * 2,
kernel_size=5,
bias=False,
).cuda()
with torch.no_grad():
weight_sequential = _gather(
_gather(layer.conv.weight, 1, inner_group), 0, outer_group
)
layer_sequential.weight.copy_(weight_sequential)
X.requires_grad = True
Y_sequential = layer_sequential(X)
Y_sequential.backward(Y_grad)

X_grad_parallel = _gather(X_local.grad, 1, inner_group)

assert torch.allclose(
X_grad_parallel, X.grad
), "BW Pass - gradients of input do not match"

weight_grad_parallel = _gather(
_gather(layer.conv.weight.grad, 1, inner_group), 0, outer_group
)
assert torch.allclose(
weight_grad_parallel, layer_sequential.weight.grad
), "BW Pass - gradients of weight do not match"