Commit 8f2c98c

Make monkeypatching more efficient and change easy API to a single argument (#72)

siddharth9820 authored May 15, 2024
1 parent b96bce0, commit 8f2c98c

Showing 7 changed files with 242 additions and 32 deletions.
14 changes: 7 additions & 7 deletions axonn/intra_layer/fully_connected.py
@@ -169,6 +169,7 @@ def __init__(
bias=True,
skip_bias_add=False,
init_method=None,
use_easy_api=True,
**kwargs
):
super(Linear, self).__init__()
@@ -182,6 +183,7 @@ def __init__(

self.in_features = in_features
self.out_features = out_features
self.use_easy_api = use_easy_api

if init_method is None:
init_method = default_init_method
@@ -257,16 +259,14 @@ def get_output_feature_size(self):
def forward(
self,
x,
- scatter_input=True,
- gather_output=True,
cache_weights_in_all_gather=False,
):
# gather weights from depth parallel group
# reduce scatter in the backward pass

weight = self.weight
if not self.transpose:
- if scatter_input:
+ if self.use_easy_api:
x = Drop.apply(x, self.inner_group)
x = AsyncLinear.apply(
x,
@@ -279,10 +279,10 @@
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
- if gather_output:
+ if self.use_easy_api:
x = Gather.apply(x, self.outer_group)
else:
- if scatter_input:
+ if self.use_easy_api:
x = Drop.apply(x, self.outer_group)

x = AsyncLinear.apply(
Expand All @@ -296,14 +296,14 @@ def forward(
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
- if gather_output:
+ if self.use_easy_api:
x = Gather.apply(x, self.inner_group)

if self.bias is None:
return x
else:
bias = self.bias
- if gather_output:
+ if self.use_easy_api:
bias = Gather.apply(
bias,
self.outer_group if not self.transpose else self.inner_group,
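For orientation, a minimal usage sketch (not part of the commit): it shows the old per-call scatter/gather flags replaced by the single use_easy_api constructor argument. AxoNN's process groups are assumed to be initialized already, and the tensor-parallel degree, shapes, and variable names below are illustrative assumptions.

import torch
from axonn.intra_layer import Linear

# Easy API: the layer drops (scatters) its input and gathers its output
# internally, so callers see full-sized tensors and pass nothing extra to forward().
x = torch.randn(8, 1024).cuda()
fc_easy = Linear(1024, 4096, use_easy_api=True).cuda()
y_full = fc_easy(x)

# Expert mode (what the monkey patches below use): the caller keeps activations
# sharded across the tensor-parallel group, so no Drop/Gather happens inside.
tp = 2                                         # assumed tensor-parallel degree
x_local = torch.randn(8, 1024 // tp).cuda()    # this rank's shard of the input features
fc_expert = Linear(1024, 4096, use_easy_api=False).cuda()
y_local = fc_expert(x_local)                   # output also stays sharded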
16 changes: 16 additions & 0 deletions axonn/models/transformers/__init__.py
@@ -5,6 +5,14 @@
monkey_patch_llama_with_axonn,
reverse_monkey_patch_llama_with_axonn,
)
from .modify_mixtral import (
monkey_patch_mixtral_with_axonn,
reverse_monkey_patch_mixtral_with_axonn,
)
from .modify_mistral import (
monkey_patch_mistral_with_axonn,
reverse_monkey_patch_mistral_with_axonn,
)

modify_dict = {
"OPTForCausalLM": (
@@ -15,6 +23,14 @@
monkey_patch_llama_with_axonn,
reverse_monkey_patch_llama_with_axonn,
),
"MixtralForCausalLM": (
monkey_patch_mixtral_with_axonn,
reverse_monkey_patch_mixtral_with_axonn,
),
"MistralForCausalLM": (
monkey_patch_mistral_with_axonn,
reverse_monkey_patch_mistral_with_axonn,
),
}


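A hedged sketch of how a (patch, reverse-patch) pair from modify_dict might be applied around model construction; build_with_axonn and its arguments are illustrative helpers, not AxoNN APIs.

from axonn.models.transformers import modify_dict

def build_with_axonn(model_cls, config):
    # Look up the patch/unpatch pair registered for this architecture name.
    patch, unpatch = modify_dict[model_cls.__name__]
    originals = patch()            # swap in the tensor-parallel __init__ methods
    try:
        model = model_cls(config)  # layers are now built with axonn.intra_layer.Linear
    finally:
        unpatch(*originals)        # restore the stock Hugging Face __init__ methods
    return model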
42 changes: 37 additions & 5 deletions axonn/models/transformers/modify_llama.py
@@ -1,6 +1,7 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
from axonn.intra_layer import Linear
from typing import Optional
from axonn import axonn as ax


def modified_attention_init(self, config, layer_idx: Optional[int] = None):
@@ -31,19 +32,40 @@ def modified_attention_init(self, config, layer_idx: Optional[int] = None):
)

self.q_proj = Linear(
- self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ use_easy_api=False,
)
self.k_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
use_easy_api=False,
)
self.v_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
use_easy_api=False,
)
- self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self.o_proj = Linear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=config.attention_bias,
+ use_easy_api=False,
+ transpose=True,
+ )

assert self.num_heads % ax.config.G_intra_r == 0
self.num_heads //= ax.config.G_intra_r

assert self.num_key_value_heads % ax.config.G_intra_r == 0
self.num_key_value_heads //= ax.config.G_intra_r

assert self.hidden_size % ax.config.G_intra_r == 0
self.hidden_size //= ax.config.G_intra_r

self._init_rope()


@@ -52,9 +74,19 @@ def modified_mlp_init(self, config):
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
- self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.gate_proj = Linear(
+ self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False
+ )
+ self.up_proj = Linear(
+ self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False
+ )
+ self.down_proj = Linear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ use_easy_api=False,
+ transpose=True,
+ )
self.act_fn = ACT2FN[config.hidden_act]


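A short worked example (with assumed numbers, not taken from the commit) of why the patched __init__ divides num_heads, num_key_value_heads, and hidden_size by ax.config.G_intra_r: with use_easy_api=False the projections return only this rank's shard of the output features, so the attention arithmetic has to use local head counts.

G_intra_r = 4                                           # assumed row tensor-parallel degree
num_heads, num_key_value_heads, head_dim = 32, 8, 128   # Llama-like config (assumed)

assert num_heads % G_intra_r == 0 and num_key_value_heads % G_intra_r == 0
local_heads = num_heads // G_intra_r                # 8 query heads per rank
local_kv_heads = num_key_value_heads // G_intra_r   # 2 key/value heads per rank
local_q_features = local_heads * head_dim           # 1024 columns in this rank's q_proj output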
60 changes: 52 additions & 8 deletions axonn/models/transformers/modify_mistral.py
@@ -5,11 +5,14 @@
ACT2FN,
)
from axonn.intra_layer import Linear
from axonn import axonn as ax


- def modified_attention_init(self, config):
+ def modified_attention_init(self, config, layer_idx):
super(MistralAttention, self).__init__()
self.config = config
self.layer_idx = layer_idx

self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -18,6 +21,7 @@ def modified_attention_init(self, config):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout
# This gives an attribute error, not sure why
# self.attention_dropout = config.attention_dropout

@@ -26,33 +30,73 @@ def modified_attention_init(self, config):
f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
)
- self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.q_proj = Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False, use_easy_api=False
+ )
self.k_proj = Linear(
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=False,
+ use_easy_api=False,
)
self.v_proj = Linear(
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=False,
+ use_easy_api=False,
)
- self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.o_proj = Linear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ use_easy_api=False,
+ transpose=True,
+ )

self.rotary_emb = MistralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)

assert self.num_heads % ax.config.G_intra_r == 0
self.num_heads //= ax.config.G_intra_r

assert self.num_key_value_heads % ax.config.G_intra_r == 0
self.num_key_value_heads //= ax.config.G_intra_r

assert self.hidden_size % ax.config.G_intra_r == 0
self.hidden_size //= ax.config.G_intra_r


def modified_mlp_init(self, config):
super(MistralMLP, self).__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
- self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.gate_proj = Linear(
+ self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False
+ )
+ self.up_proj = Linear(
+ self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False
+ )
+ self.down_proj = Linear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ use_easy_api=False,
+ transpose=True,
+ )
self.act_fn = ACT2FN[config.hidden_act]


def monkey_patch_mistral_with_axonn():
original_inits = MistralAttention.__init__, MistralMLP.__init__
MistralAttention.__init__ = modified_attention_init
MistralMLP.__init__ = modified_mlp_init
return original_inits


def reverse_monkey_patch_mistral_with_axonn(original_attention_init, original_mlp_init):
MistralAttention.__init__ = original_attention_init
MistralMLP.__init__ = original_mlp_init
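As in the Llama patch, the output and down projections are built with transpose=True so they can consume the still-sharded activations produced by the preceding use_easy_api=False layers, which is presumably where the efficiency gain in this commit comes from. A simplified sketch of that pairing (gate projection and activation omitted; sizes are Mistral-7B-like assumptions, and the comments describe the standard column-/row-parallel pattern rather than verified AxoNN internals):

from axonn.intra_layer import Linear

hidden, intermediate = 4096, 14336   # assumed sizes
up = Linear(hidden, intermediate, bias=False, use_easy_api=False)
down = Linear(intermediate, hidden, bias=False, use_easy_api=False, transpose=True)

def mlp_forward(x_local):
    h_local = up(x_local)     # output features stay sharded across the parallel group
    return down(h_local)      # transposed layer accepts the shard; no gather in between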
104 changes: 104 additions & 0 deletions axonn/models/transformers/modify_mixtral.py
@@ -0,0 +1,104 @@
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralRotaryEmbedding,
MixtralBlockSparseTop2MLP,
ACT2FN,
)
from axonn.intra_layer import Linear
from axonn import axonn as ax


def modified_attention_init(self, config, layer_idx):
super(MixtralAttention, self).__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
from transformers.utils import logging

logger = logging.get_logger(__name__)
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a "
f"`layer_idx` is not recommended and will "
f"lead to errors during the forward call if "
f"caching is used. Please make sure to provide a `layer_idx` "
f"when creating this class."
)

self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads"
f"(got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False, use_easy_api=False
)
self.k_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
use_easy_api=False,
)
self.v_proj = Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
use_easy_api=False,
)
self.o_proj = Linear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=False,
use_easy_api=False,
transpose=True,
)

self.rotary_emb = MixtralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)

assert self.num_heads % ax.config.G_intra_r == 0
self.num_heads //= ax.config.G_intra_r

assert self.num_key_value_heads % ax.config.G_intra_r == 0
self.num_key_value_heads //= ax.config.G_intra_r

assert self.hidden_size % ax.config.G_intra_r == 0
self.hidden_size //= ax.config.G_intra_r


def modified_mlp_init(self, config):
super(MixtralBlockSparseTop2MLP, self).__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size

self.w1 = Linear(self.hidden_dim, self.ffn_dim, bias=False, use_easy_api=False)
self.w2 = Linear(
self.ffn_dim, self.hidden_dim, bias=False, use_easy_api=False, transpose=True
)
self.w3 = Linear(self.hidden_dim, self.ffn_dim, bias=False, use_easy_api=False)

self.act_fn = ACT2FN[config.hidden_act]


def monkey_patch_mixtral_with_axonn():
original_inits = MixtralAttention.__init__, MixtralBlockSparseTop2MLP.__init__
MixtralAttention.__init__ = modified_attention_init
MixtralBlockSparseTop2MLP.__init__ = modified_mlp_init
return original_inits


def reverse_monkey_patch_mixtral_with_axonn(original_attention_init, original_mlp_init):
MixtralAttention.__init__ = original_attention_init
MixtralBlockSparseTop2MLP.__init__ = original_mlp_init
(The remaining 2 changed files are not rendered on this page.)