Skip to content

Commit

Permalink
reformat and change Tensor_Parallel_Linear to Linear
Browse files Browse the repository at this point in the history
  • Loading branch information
siddharth9820 committed Oct 17, 2023
1 parent 33caaf3 commit 34abe16
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion axonn/intra_layer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401
from .fully_connected import Linear # noqa: F401
from .communication import Drop, Gather
from axonn import axonn as ax

Expand Down
4 changes: 3 additions & 1 deletion axonn/intra_layer/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def forward(self, x, scatter_input=True, gather_output=True):

bias = self.bias
if gather_output:
bias = Gather.apply(self.bias, self.outer_group if not self.transpose else self.inner_group)
bias = Gather.apply(
self.bias, self.outer_group if not self.transpose else self.inner_group
)

if self.skip_bias_add:
return x, bias
Expand Down

0 comments on commit 34abe16

Please sign in to comment.