diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 57efc32..cce4cb5 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -1,4 +1,4 @@ -from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401 +from .fully_connected import Linear # noqa: F401 from .communication import Drop, Gather from axonn import axonn as ax diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index d924824..61ac6ef 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -108,7 +108,9 @@ def forward(self, x, scatter_input=True, gather_output=True): bias = self.bias if gather_output: - bias = Gather.apply(self.bias, self.outer_group if not self.transpose else self.inner_group) + bias = Gather.apply( + self.bias, self.outer_group if not self.transpose else self.inner_group + ) if self.skip_bias_add: return x, bias