Commit

Merge branch 'develop' into neo-x
siddharth9820 authored Oct 18, 2023
2 parents 2af4e18 + e447012 commit 2cbe4e9
Showing 9 changed files with 140 additions and 47 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,18 @@
name: ci

on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]

concurrency:
group: ci-${{github.ref}}-${{github.event.pull_request.number || github.run_number}}
cancel-in-progress: true

jobs:
formatting:
uses: ./.github/workflows/formatting.yaml
nvidia-gpu:
needs: [ formatting ]
uses: ./.github/workflows/nvidia-rtx-3090-tests.yaml
.github/workflows/formatting.yaml
@@ -1,18 +1,16 @@
name: formatting tests

on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
workflow_dispatch:
workflow_call:

concurrency:
group: unit_tests-${{github.ref}}-${{github.event.pull_request.number || github.run_number}}
cancel-in-progress: true

jobs:
formatting:
runs-on: ${{ matrix.os }}

strategy:
matrix:
os: [ubuntu-latest, macos-latest]
runs-on: [ubuntu-latest]

steps:
- uses: actions/checkout@v2
@@ -34,3 +32,4 @@ jobs:
run: |
black --diff --check .
flake8
.github/workflows/nvidia-rtx-3090-tests.yaml
@@ -1,19 +1,22 @@
name: nvidia-rtx-3090 tests

on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
workflow_dispatch:
workflow_call:

concurrency:
group: unit_tests-${{github.ref}}-${{github.event.pull_request.number || github.run_number}}
cancel-in-progress: true

jobs:
mnist-trainer:
inter-layer:
runs-on: [ nvidia ]

strategy:
matrix:
ginter: [ 1, 2 ]
memopt: [ '0', '1' ]

steps:
- uses: actions/checkout@v3
- name: Install AxoNN
@@ -28,7 +31,23 @@ jobs:
export G_data=$(( 2 / G_inter ))
export memopt=${{ matrix.memopt }}
echo "training with G_inter = ${G_inter}, G_data = $(( 2 / G_inter )) ${{ matrix.memopt }}"
mpirun -n 2 pytest --with-mpi
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
intra-layer:
runs-on: [ nvidia ]

steps:
- uses: actions/checkout@v3
- name: Install AxoNN
run: |
pip install -r requirements.txt
- name: Run intra-layer unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
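Both GPU jobs launch their test suites through mpirun -n 2 pytest --with-mpi, which appears to rely on the pytest-mpi plugin (the plugin provides the --with-mpi flag and the mpi marker). A minimal sketch of the test pattern this enables is shown below; the test name and the min_size argument are illustrative and not taken from this repository:

import pytest
from mpi4py import MPI


@pytest.mark.mpi(min_size=2)  # collected as an MPI test; runs only when pytest is invoked with --with-mpi
def test_allreduce_sums_ranks():
    comm = MPI.COMM_WORLD
    # Each rank contributes 1; the reduced value must equal the communicator size.
    total = comm.allreduce(1, op=MPI.SUM)
    assert total == comm.Get_size()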
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# AxoNN
# <img src="https://github.com/axonn-ai/axonn/blob/c356b821c2020c7dcd2181dfacc226619bfd5240/logo.png" width="64" valign="middle" alt="hatchet"/> AxoNN

[![rtx-3090 tests](https://github.com/hpcgroup/axonn/actions/workflows/nvidia-tests.yaml/badge.svg)](https://github.com/hpcgroup/axonn/actions/workflows/nvidia-tests.yaml)
[![docs](https://readthedocs.org/projects/axonn/badge/?version=latest)](https://axonn.readthedocs.io/en/latest/?badge=latest)
9 changes: 0 additions & 9 deletions axonn/communication.py
@@ -211,12 +211,3 @@ def allreduce(self, tensor, async_op: bool = True):
def broadcast_inter_layer(self, tensor, root):
mpi4py_compatible_array = self._torch_to_mpi(tensor)
self.p2p_mpi_comm.Bcast(mpi4py_compatible_array, root=root)

def get_tensor_model_parallel_rank(self):
return self.intra_layer_parallel_rank

def get_tensor_model_parallel_world_size(self):
return self.G_intra

def get_tensor_model_parallel_group(self):
return self.intra_layer_group
8 changes: 7 additions & 1 deletion axonn/intra_layer/communication.py
@@ -3,12 +3,16 @@


def _all_reduce(input_, process_group=None):
dist.all_reduce(input_.contiguous(), group=process_group)
if dist.get_world_size(process_group) > 1:
dist.all_reduce(input_.contiguous(), group=process_group)
return input_


def _drop(input_, dim, process_group=None):
"""Divide a tensor among the tensor parallel ranks"""
if dist.get_world_size(process_group) == 1:
return input_

total_chunks = dist.get_world_size(process_group)
this_chunk = dist.get_rank(process_group)
assert input_.shape[dim] % total_chunks == 0
@@ -19,6 +23,8 @@ def _drop(input_, dim, process_group=None):

def _gather(input_, dim, process_group=None):
"""Gather tensors and concatenate them along a dimension"""
if dist.get_world_size(process_group) == 1:
return input_

input_ = input_.contiguous()
# Size and dimension.
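The guards added above make _all_reduce, _drop, and _gather no-ops when the process group has a single member. For intuition, the chunk-and-reassemble behaviour of _drop and _gather can be sketched in plain PyTorch as follows; the standalone helper names and the two-way split are illustrative, not part of the library:

import torch


def drop_local_chunk(x, dim, rank, world_size):
    # Keep only this rank's slice along `dim`, mirroring _drop's chunking logic.
    assert x.shape[dim] % world_size == 0
    return torch.chunk(x, world_size, dim=dim)[rank].contiguous()


def gather_chunks(chunks, dim):
    # Concatenate per-rank slices back along `dim`, mirroring _gather without the all_gather call.
    return torch.cat([c.contiguous() for c in chunks], dim=dim)


# A (2, 8) tensor split across two "ranks" along dim=1 and reassembled.
x = torch.arange(16.0).reshape(2, 8)
parts = [drop_local_chunk(x, dim=1, rank=r, world_size=2) for r in range(2)]
assert torch.equal(gather_chunks(parts, dim=1), x)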
88 changes: 72 additions & 16 deletions axonn/intra_layer/fully_connected.py
@@ -4,8 +4,34 @@
from .communication import ForwardAllReduce, BackwardAllReduce, Drop


def divide(a, b):
assert a % b == 0
return a // b


@torch.no_grad()
def initialize_params(
out_features, in_features, out_features_group, in_features_group, init_method
):
params = torch.empty((out_features, in_features))
init_method(params)
params = Drop.apply(torch.t(params).contiguous(), out_features_group)
params = torch.t(params).contiguous()
params = Drop.apply(params, in_features_group)
return params


class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, *args, transpose=False, **kwargs):
def __init__(
self,
in_features,
out_features,
*args,
transpose=False,
skip_bias_add=False,
init_method=None,
**kwargs
):
super(Linear, self).__init__()
self.inner_group = ax.comm_handle.inner_intra_layer_parallel_group
self.outer_group = ax.comm_handle.outer_intra_layer_parallel_group
@@ -16,25 +42,51 @@ def __init__(self, in_features, out_features, *args, transpose=False, **kwargs):
if not transpose:
assert in_features % self.inner_group_size == 0
assert out_features % self.outer_group_size == 0
self.local_in_features = in_features // self.inner_group_size
self.linear = torch.nn.Linear(
in_features=in_features // self.inner_group_size,
out_features=out_features // self.outer_group_size,
*args,
**kwargs,
)
self.local_in_features = divide(in_features, self.inner_group_size)
self.local_out_features = divide(out_features, self.outer_group_size)
if init_method:
initial_params = initialize_params(
out_features,
in_features,
self.outer_group,
self.inner_group,
init_method,
)
else:
assert out_features % self.inner_group_size == 0
assert in_features % self.outer_group_size == 0
self.local_in_features = in_features // self.outer_group_size
self.linear = torch.nn.Linear(
in_features=in_features // self.outer_group_size,
out_features=out_features // self.inner_group_size,
*args,
**kwargs,
)
self.local_in_features = divide(in_features, self.outer_group_size)
self.local_out_features = divide(out_features, self.inner_group_size)
if init_method:
initial_params = initialize_params(
out_features,
in_features,
self.inner_group,
self.outer_group,
init_method,
)

self.linear = torch.nn.Linear(
in_features=self.local_in_features,
out_features=self.local_out_features,
*args,
**kwargs,
bias=False
)

if init_method:
self.linear.weight.data.copy_(initial_params)

self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
)
self.transpose = transpose
self.skip_bias_add = skip_bias_add

def get_output_feature_size(self):
return self.local_out_features

def forward(self, x):
if not self.transpose:
@@ -49,4 +101,8 @@ def forward(self, x):
x = BackwardAllReduce.apply(x, self.inner_group)
x = self.linear(x)
x = ForwardAllReduce.apply(x, self.outer_group)
return x

if self.skip_bias_add:
return x, self.bias
else:
return x + self.bias
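A hedged usage sketch of the reworked Linear layer follows. It assumes axonn has already been initialized so that ax.comm_handle exposes the intra-layer process groups; the import path, feature sizes, and choice of init_method are illustrative (the test suite imports the same class under the alias Tensor_Parallel_Linear):

import torch
from axonn import axonn as ax
from axonn.intra_layer.fully_connected import Linear

# Assumes ax.init(...) has already been called with an intra-layer parallel configuration.
layer = Linear(
    in_features=1024,
    out_features=1024,
    skip_bias_add=True,                          # forward returns (output, bias) instead of output + bias
    init_method=torch.nn.init.xavier_uniform_,   # full-size weight is initialized once, then Drop-ed to each rank
).cuda()

x = torch.randn(8, layer.local_in_features, device="cuda")
y, bias = layer(x)  # with skip_bias_add=False the call would return y + bias directly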
12 changes: 8 additions & 4 deletions axonn/tests/test_intra_layer_fc.py
@@ -26,11 +26,13 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
X_local = _drop(
X, 1, inner_group
) # divide columns of X along the inner tensor group
layer = Tensor_Parallel_Linear(in_features=H, out_features=H, bias=False).cuda()
layer = Tensor_Parallel_Linear(
in_features=H, out_features=H, skip_bias_add=True
).cuda()

with torch.no_grad():
# parallel FW pass
Y_local = layer(X_local)
Y_local, _ = layer(X_local)
Y_parallel = _gather(Y_local.clone(), 1, outer_group)

# sequential FW pass
@@ -65,12 +67,14 @@ def test_bw_pass(G_intra_r, G_intra_c, B, H):
outer_group = ax.comm_handle.outer_intra_layer_parallel_group

# parallel backward pass
layer = Tensor_Parallel_Linear(in_features=H, out_features=H, bias=False).cuda()
layer = Tensor_Parallel_Linear(
in_features=H, out_features=H, skip_bias_add=True
).cuda()
X_local = (
_drop(X, 1, inner_group).detach().clone()
) # divide columns of X along the inner tensor group
X_local.requires_grad = True
Y_local = layer(X_local)
Y_local, _ = layer(X_local)
Y_local_grad = _drop(Y_grad, 1, outer_group)
Y_local.backward(Y_local_grad)

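The invariant test_fw_pass checks — dropping the input columns and weight shards, running the parallel layer, then gathering and reducing, reproduces the dense result — can be illustrated on a single process. The sketch below covers only the contraction-dimension (inner-group) split, with illustrative shapes and a two-way partition:

import torch

torch.manual_seed(0)
B, H, P = 4, 8, 2                    # batch size, hidden size, number of shards in this sketch
X = torch.randn(B, H)
W = torch.randn(H, H)                # dense weight matrix of a bias-free linear map

Y_ref = X @ W                        # sequential reference

# "Parallel" version: split X's columns and W's rows into P shards,
# compute the partial products, and sum them (the all-reduce step).
X_shards = torch.chunk(X, P, dim=1)
W_shards = torch.chunk(W, P, dim=0)
Y_sum = sum(x_s @ w_s for x_s, w_s in zip(X_shards, W_shards))

assert torch.allclose(Y_ref, Y_sum, atol=1e-5)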
Binary file added logo.png
