Commit
added comments
Mahua Singh committed Mar 20, 2024
1 parent 03a01bb commit af409b7
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion axonn/intra_layer/__init__.py
@@ -8,7 +8,7 @@
 from axonn import axonn as ax
 import torch
 import torch.distributed as dist
-from .automatic_parallelism import auto_parallellize  # noqa: F401
+from .automatic_parallelism import auto_parallelize  # noqa: F401


 def drop(
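
With the spelling fixed, the function is importable from the package root. Below is a minimal usage sketch; the initialization step is an assumption (auto_parallelize reads the group sizes from ax.config, so AxoNN must be set up first, but the init call itself is not shown in this commit):

import torch.nn as nn
from axonn import axonn as ax
from axonn.intra_layer import auto_parallelize

# Assumed setup: AxoNN must be initialized before auto_parallelize is
# called, since it reads G_intra_r / G_intra_c / G_intra_d from ax.config.
# The exact ax.init(...) arguments are not part of this commit.
# ax.init(...)

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
auto_parallelize(model)  # swaps eligible nn.Linear layers in place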
13 changes: 7 additions & 6 deletions axonn/intra_layer/automatic_parallelism.py
@@ -1,20 +1,20 @@
 import torch.nn as nn
 import sys
 import os
 from axonn import axonn as ax
 from axonn.intra_layer import Linear
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
 
-def auto_parallellize(model):
+def auto_parallelize(model):
     G_row = ax.config.G_intra_r
     G_col = ax.config.G_intra_c
-    G_depth = ax.config.G_intra_r
+    G_depth = ax.config.G_intra_d
+    # Iterate through all modules in the model
     for name, module in model.named_modules():
         if isinstance(module, nn.Module):
+            # Iterate through all child modules of each module
             for attr_name, attr_module in module.named_children():
+                # Check if the module is a linear layer
                 if isinstance(attr_module, nn.Linear):
+                    # Check if layer is "parallelizable"
                     if (
                         (attr_module.out_features % G_row == 0)
                         and (attr_module.in_features % G_col == 0)
@@ -25,6 +25,7 @@ def auto_parallellize(model):
                             == 0
                         )
                     ):
+                        # Replace the linear layer with Axonn's linear layer
                         setattr(
                             module,
                             attr_name,
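
For context, the "parallelizable" check above keeps a layer only if its dimensions divide evenly across the tensor-parallel groups. Here is a standalone sketch of that predicate; the middle of the condition is collapsed in this view, so the G_depth clause below is an assumption, not the exact code:

import torch.nn as nn


def is_parallelizable(layer: nn.Linear, G_row: int, G_col: int, G_depth: int) -> bool:
    # The first two clauses mirror the visible diff; the depth clause is an
    # assumption, since the relevant lines are collapsed on this page.
    return (
        layer.out_features % G_row == 0
        and layer.in_features % G_col == 0
        and (layer.in_features * layer.out_features) // (G_row * G_col) % G_depth == 0
    )


# Example: a 1024 -> 4096 layer splits cleanly over G_row=2, G_col=4, G_depth=1.
print(is_parallelizable(nn.Linear(1024, 4096), 2, 4, 1))  # True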
