
Commit

added automatic_parallelism (#70)
* added automatic_parallelism

* corrected formatting

* added comments

* swap linear layers at source

* formatting changes

---------

Co-authored-by: Mahua Singh <mahua04@pssg-mordor.umiacs.umd.edu>
Co-authored-by: Siddharth Singh <siddharth9820@gmail.com>
3 people authored Apr 20, 2024
1 parent 3a3c538 commit 13c310c
Showing 3 changed files with 43 additions and 1 deletion.
axonn/__init__.py (1 addition, 1 deletion)
@@ -2,4 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from . import models # noqa: F401
+# from . import models # noqa: F401
axonn/intra_layer/__init__.py (1 addition)
@@ -8,6 +8,7 @@
 from axonn import axonn as ax
 import torch
 import torch.distributed as dist
+from .automatic_parallelism import auto_parallelize # noqa: F401


 def drop(
axonn/intra_layer/automatic_parallelism.py (new file, 41 additions)
@@ -0,0 +1,41 @@
import torch.nn as nn
from axonn import axonn as ax
from axonn.intra_layer import Linear
from contextlib import contextmanager


def is_parallelizable(in_features, out_features):
    # A layer is parallelizable only if its weight matrix divides evenly
    # across the row, column, and depth tensor-parallel process groups.
    G_row = ax.config.G_intra_r
    G_col = ax.config.G_intra_c
    G_depth = ax.config.G_intra_d
    row_col_condition = out_features % G_row == 0 and in_features % G_col == 0
    depth_condition = (out_features * in_features // (G_row * G_col)) % G_depth == 0
    return row_col_condition and depth_condition


class patched_linear:
    def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):
        # Return AxoNN's tensor-parallel Linear when the dimensions divide the
        # process grid; otherwise fall back to a regular torch.nn.Linear.
        if is_parallelizable(in_features, out_features):
            parallel_layer = Linear(in_features, out_features, bias=bias)
            if device is not None:
                parallel_layer = parallel_layer.to(device)
            if dtype is not None:
                parallel_layer = parallel_layer.to(dtype)
            return parallel_layer
        else:
            sequential_layer = nn.Linear(in_features, out_features, bias=bias)
            if device is not None:
                sequential_layer = sequential_layer.to(device)
            if dtype is not None:
                sequential_layer = sequential_layer.to(dtype)
            return sequential_layer


@contextmanager
def auto_parallelize():
    # Temporarily monkey-patch nn.Linear so that any model constructed inside
    # this context goes through patched_linear, then restore the original class.
    old_linear = nn.Linear
    nn.Linear = patched_linear
    try:
        yield None
    finally:
        nn.Linear = old_linear

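For reference, a minimal sketch of how the new context manager might be used (not part of this commit): wrapping model construction in auto_parallelize() routes every nn.Linear created inside the block through patched_linear, so layers with compatible shapes become AxoNN tensor-parallel Linear layers. The MLP module and its sizes below are hypothetical, and the sketch assumes AxoNN has already been initialized so that ax.config is populated.

# Hypothetical usage sketch; the MLP module and its sizes are illustrative only.
import torch
import torch.nn as nn
from axonn.intra_layer import auto_parallelize


class MLP(nn.Module):
    def __init__(self, hidden=1024):
        super().__init__()
        # These layers are created while nn.Linear is patched, so they become
        # AxoNN parallel Linear layers whenever their shapes divide the process grid.
        self.fc1 = nn.Linear(hidden, 4 * hidden)
        self.fc2 = nn.Linear(4 * hidden, hidden)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


with auto_parallelize():
    model = MLP()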