From 13c310cb05e84084ca0d44e62d4049ae62aed4dd Mon Sep 17 00:00:00 2001
From: Mahua Singh <35654846+S-Mahua@users.noreply.github.com>
Date: Sat, 20 Apr 2024 11:20:43 +0530
Subject: [PATCH] added automatic_parallelism (#70)

* added automatic_parallelism

* corrected formatting

* added comments

* swap linear layers at source

* formatting changes

---------

Co-authored-by: Mahua Singh
Co-authored-by: Siddharth Singh
---
 axonn/__init__.py                          |  2 +-
 axonn/intra_layer/__init__.py              |  1 +
 axonn/intra_layer/automatic_parallelism.py | 41 ++++++++++++++++++++++
 3 files changed, 43 insertions(+), 1 deletion(-)
 create mode 100644 axonn/intra_layer/automatic_parallelism.py

diff --git a/axonn/__init__.py b/axonn/__init__.py
index ab84c20..e7733ee 100644
--- a/axonn/__init__.py
+++ b/axonn/__init__.py
@@ -2,4 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from . import models  # noqa: F401
+# from . import models  # noqa: F401
diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index 5d08cbc..ae0aecf 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -8,6 +8,7 @@
 from axonn import axonn as ax
 import torch
 import torch.distributed as dist
+from .automatic_parallelism import auto_parallelize  # noqa: F401
 
 
 def drop(
diff --git a/axonn/intra_layer/automatic_parallelism.py b/axonn/intra_layer/automatic_parallelism.py
new file mode 100644
index 0000000..c0b95d2
--- /dev/null
+++ b/axonn/intra_layer/automatic_parallelism.py
@@ -0,0 +1,41 @@
+import torch.nn as nn
+from axonn import axonn as ax
+from axonn.intra_layer import Linear
+from contextlib import contextmanager
+
+
+def is_parallelizable(in_features, out_features):
+    G_row = ax.config.G_intra_r
+    G_col = ax.config.G_intra_c
+    G_depth = ax.config.G_intra_d
+    row_col_condition = out_features % G_row == 0 and in_features % G_col == 0
+    depth_condition = (out_features * in_features // (G_row * G_col)) % G_depth == 0
+    return row_col_condition and depth_condition
+
+
+class patched_linear:
+    def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None):
+        if is_parallelizable(in_features, out_features):
+            parallel_layer = Linear(in_features, out_features, bias=bias)
+            if device is not None:
+                parallel_layer = parallel_layer.to(device)
+            if dtype is not None:
+                parallel_layer = parallel_layer.to(dtype)
+            return parallel_layer
+        else:
+            sequential_layer = nn.Linear(in_features, out_features, bias=bias)
+            if device is not None:
+                sequential_layer = sequential_layer.to(device)
+            if dtype is not None:
+                sequential_layer = sequential_layer.to(dtype)
+            return sequential_layer
+
+
+@contextmanager
+def auto_parallelize():
+    old_linear = nn.Linear
+    nn.Linear = patched_linear
+    try:
+        yield None
+    finally:
+        nn.Linear = old_linear
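
A minimal usage sketch of the new context manager (illustrative only, not part of the patch). It assumes the distributed runtime and AxoNN's intra-layer configuration have already been initialized so that ax.config.G_intra_r, G_intra_c, and G_intra_d are set; the exact ax.init keyword arguments shown below are an assumption for illustration and may differ from the real API. Inside the with auto_parallelize() block, torch.nn.Linear is temporarily swapped for patched_linear, so layers whose shapes divide evenly across the configured tensor-parallel groups are built as axonn.intra_layer.Linear, while incompatible shapes fall back to the stock nn.Linear.

    # Usage sketch (illustrative, not part of the patch).
    # Assumes AxoNN and torch.distributed are already initialized; the ax.init
    # keyword arguments below are an assumption, not a documented signature.
    import torch
    import torch.nn as nn

    from axonn import axonn as ax
    from axonn.intra_layer import auto_parallelize

    ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=1, G_intra_d=1)

    with auto_parallelize():
        # nn.Linear is monkey-patched here: shapes divisible by
        # (G_intra_r, G_intra_c, G_intra_d) become axonn.intra_layer.Linear,
        # everything else is constructed as a regular torch.nn.Linear.
        model = nn.Sequential(
            nn.Linear(1024, 4096),
            nn.GELU(),
            nn.Linear(4096, 1024),
        )

    # Outside the block, nn.Linear is restored to the original class, so
    # subsequently constructed layers are untouched.

Note that because patched_linear overrides __new__, the swap happens at construction time ("at source"), so any third-party model code that calls nn.Linear inside the block picks up the parallel layer without modification.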