Merge branch 'develop' into conv-depth-tp
Showing 3 changed files with 172 additions and 0 deletions.
@@ -0,0 +1,54 @@
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
from axonn.intra_layer import Linear


def modified_attention_init(self, config):
    # Mirrors the upstream LlamaAttention.__init__, but builds the q/k/v/o
    # projections with AxoNN's tensor-parallel Linear instead of nn.Linear.
    super(LlamaAttention, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = config.max_position_embeddings
    self.rope_theta = config.rope_theta
    self.is_causal = True

    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads "
            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
        )
    self.q_proj = Linear(
        self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
    )
    self.k_proj = Linear(
        self.hidden_size,
        self.num_key_value_heads * self.head_dim,
        bias=config.attention_bias,
    )
    self.v_proj = Linear(
        self.hidden_size,
        self.num_key_value_heads * self.head_dim,
        bias=config.attention_bias,
    )
    self.o_proj = Linear(
        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
    )
    self._init_rope()


def modified_mlp_init(self, config):
    # Same as LlamaMLP.__init__, with AxoNN's Linear for the three projections.
    super(LlamaMLP, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False)
    self.act_fn = ACT2FN[config.hidden_act]


def monkey_patch_llama_with_axonn():
    # Rebind the initializers so every LlamaAttention / LlamaMLP constructed
    # afterwards uses the tensor-parallel layers.
    LlamaAttention.__init__ = modified_attention_init
    LlamaMLP.__init__ = modified_mlp_init
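
A patch function like this has to run before the model is instantiated, since it rebinds the class initializers that Hugging Face calls at construction time. The sketch below is a hedged usage example, not part of the commit: the checkpoint name is a placeholder, and AxoNN's distributed backend is assumed to be initialized separately under the appropriate launcher.

# Hypothetical usage sketch (not in this commit). Apply the patch first so that
# every LlamaAttention / LlamaMLP built afterwards uses axonn.intra_layer.Linear.
from transformers import AutoConfig, AutoModelForCausalLM

monkey_patch_llama_with_axonn()

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
model = AutoModelForCausalLM.from_config(config)  # projections are now AxoNN Linear layers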
@@ -0,0 +1,58 @@
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    MistralRotaryEmbedding,
    MistralMLP,
    ACT2FN,
)
from axonn.intra_layer import Linear


def modified_attention_init(self, config):
    # Mirrors the upstream MistralAttention.__init__, but builds the q/k/v/o
    # projections with AxoNN's tensor-parallel Linear instead of nn.Linear.
    super(MistralAttention, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = config.max_position_embeddings
    self.rope_theta = config.rope_theta
    self.is_causal = True
    # This gives an attribute error, not sure why
    # self.attention_dropout = config.attention_dropout

    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads "
            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
        )
    self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
    )
    self.v_proj = Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
    )
    self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    self.rotary_emb = MistralRotaryEmbedding(
        self.head_dim,
        max_position_embeddings=self.max_position_embeddings,
        base=self.rope_theta,
    )


def modified_mlp_init(self, config):
    # Same as MistralMLP.__init__, with AxoNN's Linear for the three projections.
    super(MistralMLP, self).__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False)
    self.act_fn = ACT2FN[config.hidden_act]


def monkey_patch_mistral_with_axonn():
    # Rebind the initializers so every MistralAttention / MistralMLP constructed
    # afterwards uses the tensor-parallel layers.
    MistralAttention.__init__ = modified_attention_init
    MistralMLP.__init__ = modified_mlp_init
@@ -0,0 +1,60 @@
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, ACT2FN
import torch.nn as nn
from axonn.intra_layer import Linear


def modified_attention_init(
    self,
    embed_dim: int,
    num_heads: int,
    dropout: float = 0.0,
    is_decoder: bool = False,
    bias: bool = True,
):
    # Mirrors the upstream OPTAttention.__init__, but builds the q/k/v/out
    # projections with AxoNN's tensor-parallel Linear instead of nn.Linear.
    super(OPTAttention, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.dropout = dropout
    self.head_dim = embed_dim // num_heads

    if (self.head_dim * num_heads) != self.embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads "
            f"(got `embed_dim`: {self.embed_dim} & `num_heads`: {num_heads})."
        )
    self.scaling = self.head_dim**-0.5
    self.is_decoder = is_decoder

    self.k_proj = Linear(embed_dim, embed_dim, bias=bias)
    self.v_proj = Linear(embed_dim, embed_dim, bias=bias)
    self.q_proj = Linear(embed_dim, embed_dim, bias=bias)
    self.out_proj = Linear(embed_dim, embed_dim, bias=bias)


def modified_decoder_init(self, config):
    # Same as OPTDecoderLayer.__init__, with AxoNN's Linear for the feed-forward
    # projections (fc1/fc2); the layer norms stay as regular nn.LayerNorm.
    super(OPTDecoderLayer, self).__init__()
    self.embed_dim = config.hidden_size
    self.self_attn = OPTAttention(
        embed_dim=self.embed_dim,
        num_heads=config.num_attention_heads,
        dropout=config.attention_dropout,
        is_decoder=True,
        bias=config.enable_bias,
    )
    self.do_layer_norm_before = config.do_layer_norm_before
    self.dropout = config.dropout
    self.activation_fn = ACT2FN[config.activation_function]

    self.self_attn_layer_norm = nn.LayerNorm(
        self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
    )
    self.fc1 = Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
    self.fc2 = Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
    self.final_layer_norm = nn.LayerNorm(
        self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
    )


def monkey_patch_opt_with_axonn():
    # Rebind the initializers so every OPTAttention / OPTDecoderLayer constructed
    # afterwards uses the tensor-parallel layers.
    OPTAttention.__init__ = modified_attention_init
    OPTDecoderLayer.__init__ = modified_decoder_init
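
All three patch functions follow the same pattern, so a caller could pick one by the Hugging Face config's model_type. The dispatcher below is an illustrative assumption, not part of the commit; it presumes the three functions above are importable in the same scope.

# Hypothetical dispatcher (not in this commit): route to the matching patch
# based on a Hugging Face config's model_type ("llama", "mistral", or "opt").
_AXONN_PATCHERS = {
    "llama": monkey_patch_llama_with_axonn,
    "mistral": monkey_patch_mistral_with_axonn,
    "opt": monkey_patch_opt_with_axonn,
}


def monkey_patch_with_axonn(model_type):
    if model_type not in _AXONN_PATCHERS:
        raise ValueError(f"No AxoNN monkey patch registered for {model_type!r}")
    _AXONN_PATCHERS[model_type]()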