From 809077214ff736758579156121a18bd55f4111bd Mon Sep 17 00:00:00 2001 From: jwendlan <82612519+jwendlan@users.noreply.github.com> Date: Mon, 8 Jan 2024 04:49:26 -0500 Subject: [PATCH] Parallel transformers (#59) * storing parallel hf implementations in axonn * adding parallel hf implementations --------- Co-authored-by: John Wendlandt --- models/transformers/modify_llama.py | 54 ++++++++++++++++++++++++ models/transformers/modify_mistral.py | 58 ++++++++++++++++++++++++++ models/transformers/modify_opt.py | 60 +++++++++++++++++++++++++++ 3 files changed, 172 insertions(+) create mode 100644 models/transformers/modify_llama.py create mode 100644 models/transformers/modify_mistral.py create mode 100644 models/transformers/modify_opt.py diff --git a/models/transformers/modify_llama.py b/models/transformers/modify_llama.py new file mode 100644 index 0000000..af0a580 --- /dev/null +++ b/models/transformers/modify_llama.py @@ -0,0 +1,54 @@ +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN +from axonn.intra_layer import Linear + + +def modified_attention_init(self, config): + super(LlamaAttention, self).__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})." + ) + self.q_proj = Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + self._init_rope() + + +def modified_mlp_init(self, config): + super(LlamaMLP, self).__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + +def monkey_patch_llama_with_axonn(): + LlamaAttention.__init__ = modified_attention_init + LlamaMLP.__init__ = modified_mlp_init diff --git a/models/transformers/modify_mistral.py b/models/transformers/modify_mistral.py new file mode 100644 index 0000000..0597903 --- /dev/null +++ b/models/transformers/modify_mistral.py @@ -0,0 +1,58 @@ +from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralRotaryEmbedding, + MistralMLP, + ACT2FN, +) +from axonn.intra_layer import Linear + + +def modified_attention_init(self, config): + super(MistralAttention, self).__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + # This gives an attribute error, not sure why + # self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})." + ) + self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + +def modified_mlp_init(self, config): + super(MistralMLP, self).__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + +def monkey_patch_mistral_with_axonn(): + MistralAttention.__init__ = modified_attention_init + MistralMLP.__init__ = modified_mlp_init diff --git a/models/transformers/modify_opt.py b/models/transformers/modify_opt.py new file mode 100644 index 0000000..ee1130e --- /dev/null +++ b/models/transformers/modify_opt.py @@ -0,0 +1,60 @@ +from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, ACT2FN +import torch.nn as nn +from axonn.intra_layer import Linear + + +def modified_attention_init( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, +): + super(OPTAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} & `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + +def modified_decoder_init(self, config): + super(OPTDecoderLayer, self).__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=config.enable_bias, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) + self.fc2 = Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + + +def monkey_patch_opt_with_axonn(): + OPTAttention.__init__ = modified_attention_init + OPTDecoderLayer.__init__ = modified_decoder_init