From 8f2c98c36cf5ce20795373c6959f582cba12715c Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 15 May 2024 12:56:18 -0400 Subject: [PATCH] Make monkeypatching more efficient and change easy API to a single argument (#72) --- axonn/intra_layer/fully_connected.py | 14 +-- axonn/models/transformers/__init__.py | 16 +++ axonn/models/transformers/modify_llama.py | 42 +++++++- axonn/models/transformers/modify_mistral.py | 60 +++++++++-- axonn/models/transformers/modify_mixtral.py | 104 ++++++++++++++++++++ axonn/models/transformers/modify_opt.py | 26 +++-- axonn/tests/test_intra_layer_fc.py | 12 +-- 7 files changed, 242 insertions(+), 32 deletions(-) create mode 100644 axonn/models/transformers/modify_mixtral.py diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index 5b75346..1cae5d4 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -169,6 +169,7 @@ def __init__( bias=True, skip_bias_add=False, init_method=None, + use_easy_api=True, **kwargs ): super(Linear, self).__init__() @@ -182,6 +183,7 @@ def __init__( self.in_features = in_features self.out_features = out_features + self.use_easy_api = use_easy_api if init_method is None: init_method = default_init_method @@ -257,8 +259,6 @@ def get_output_feature_size(self): def forward( self, x, - scatter_input=True, - gather_output=True, cache_weights_in_all_gather=False, ): # gather weights from depth parallel group @@ -266,7 +266,7 @@ def forward( weight = self.weight if not self.transpose: - if scatter_input: + if self.use_easy_api: x = Drop.apply(x, self.inner_group) x = AsyncLinear.apply( x, @@ -279,10 +279,10 @@ def forward( axonn.intra_layer.OVERLAP_ALL_REDUCE, False, ) - if gather_output: + if self.use_easy_api: x = Gather.apply(x, self.outer_group) else: - if scatter_input: + if self.use_easy_api: x = Drop.apply(x, self.outer_group) x = AsyncLinear.apply( @@ -296,14 +296,14 @@ def forward( axonn.intra_layer.OVERLAP_ALL_REDUCE, False, ) - if gather_output: + if self.use_easy_api: x = Gather.apply(x, self.inner_group) if self.bias is None: return x else: bias = self.bias - if gather_output: + if self.use_easy_api: bias = Gather.apply( bias, self.outer_group if not self.transpose else self.inner_group, diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py index db33249..7ce150b 100644 --- a/axonn/models/transformers/__init__.py +++ b/axonn/models/transformers/__init__.py @@ -5,6 +5,14 @@ monkey_patch_llama_with_axonn, reverse_monkey_patch_llama_with_axonn, ) +from .modify_mixtral import ( + monkey_patch_mixtral_with_axonn, + reverse_monkey_patch_mixtral_with_axonn, +) +from .modify_mistral import ( + monkey_patch_mistral_with_axonn, + reverse_monkey_patch_mistral_with_axonn, +) modify_dict = { "OPTForCausalLM": ( @@ -15,6 +23,14 @@ monkey_patch_llama_with_axonn, reverse_monkey_patch_llama_with_axonn, ), + "MixtralForCausalLM": ( + monkey_patch_mixtral_with_axonn, + reverse_monkey_patch_mixtral_with_axonn, + ), + "MistralForCausalLM": ( + monkey_patch_mistral_with_axonn, + reverse_monkey_patch_mistral_with_axonn, + ), } diff --git a/axonn/models/transformers/modify_llama.py b/axonn/models/transformers/modify_llama.py index 4a1d2f5..27496e4 100644 --- a/axonn/models/transformers/modify_llama.py +++ b/axonn/models/transformers/modify_llama.py @@ -1,6 +1,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN from axonn.intra_layer import Linear from typing import Optional +from axonn 
import axonn as ax def modified_attention_init(self, config, layer_idx: Optional[int] = None): @@ -31,19 +32,40 @@ def modified_attention_init(self, config, layer_idx: Optional[int] = None): ) self.q_proj = Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + use_easy_api=False, ) self.k_proj = Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, + use_easy_api=False, ) self.v_proj = Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, + use_easy_api=False, ) - self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + use_easy_api=False, + transpose=True, + ) + + assert self.num_heads % ax.config.G_intra_r == 0 + self.num_heads //= ax.config.G_intra_r + + assert self.num_key_value_heads % ax.config.G_intra_r == 0 + self.num_key_value_heads //= ax.config.G_intra_r + + assert self.hidden_size % ax.config.G_intra_r == 0 + self.hidden_size //= ax.config.G_intra_r + self._init_rope() @@ -52,9 +74,19 @@ def modified_mlp_init(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = Linear( + self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False + ) + self.up_proj = Linear( + self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False + ) + self.down_proj = Linear( + self.intermediate_size, + self.hidden_size, + bias=False, + use_easy_api=False, + transpose=True, + ) self.act_fn = ACT2FN[config.hidden_act] diff --git a/axonn/models/transformers/modify_mistral.py b/axonn/models/transformers/modify_mistral.py index 0597903..7815fd7 100644 --- a/axonn/models/transformers/modify_mistral.py +++ b/axonn/models/transformers/modify_mistral.py @@ -5,11 +5,14 @@ ACT2FN, ) from axonn.intra_layer import Linear +from axonn import axonn as ax -def modified_attention_init(self, config): +def modified_attention_init(self, config, layer_idx): super(MistralAttention, self).__init__() self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -18,6 +21,7 @@ def modified_attention_init(self, config): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True + self.attention_dropout = config.attention_dropout # This gives an attribute error, not sure why # self.attention_dropout = config.attention_dropout @@ -26,14 +30,28 @@ def modified_attention_init(self, config): f"hidden_size must be divisible by num_heads " f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})." 
) - self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.q_proj = Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False, use_easy_api=False + ) self.k_proj = Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_easy_api=False, ) self.v_proj = Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_easy_api=False, + ) + self.o_proj = Linear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + use_easy_api=False, + transpose=True, ) - self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = MistralRotaryEmbedding( self.head_dim, @@ -41,18 +59,44 @@ def modified_attention_init(self, config): base=self.rope_theta, ) + assert self.num_heads % ax.config.G_intra_r == 0 + self.num_heads //= ax.config.G_intra_r + + assert self.num_key_value_heads % ax.config.G_intra_r == 0 + self.num_key_value_heads //= ax.config.G_intra_r + + assert self.hidden_size % ax.config.G_intra_r == 0 + self.hidden_size //= ax.config.G_intra_r + def modified_mlp_init(self, config): super(MistralMLP, self).__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = Linear( + self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False + ) + self.up_proj = Linear( + self.hidden_size, self.intermediate_size, bias=False, use_easy_api=False + ) + self.down_proj = Linear( + self.intermediate_size, + self.hidden_size, + bias=False, + use_easy_api=False, + transpose=True, + ) self.act_fn = ACT2FN[config.hidden_act] def monkey_patch_mistral_with_axonn(): + original_inits = MistralAttention.__init__, MistralMLP.__init__ MistralAttention.__init__ = modified_attention_init MistralMLP.__init__ = modified_mlp_init + return original_inits + + +def reverse_monkey_patch_mistral_with_axonn(original_attention_init, original_mlp_init): + MistralAttention.__init__ = original_attention_init + MistralMLP.__init__ = original_mlp_init diff --git a/axonn/models/transformers/modify_mixtral.py b/axonn/models/transformers/modify_mixtral.py new file mode 100644 index 0000000..19695d9 --- /dev/null +++ b/axonn/models/transformers/modify_mixtral.py @@ -0,0 +1,104 @@ +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralRotaryEmbedding, + MixtralBlockSparseTop2MLP, + ACT2FN, +) +from axonn.intra_layer import Linear +from axonn import axonn as ax + + +def modified_attention_init(self, config, layer_idx): + super(MixtralAttention, self).__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + import logger + + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a" + f"`layer_idx` is not recommended and will" + f"lead to errors during the forward call if" + f"caching is used. Please make sure to provide a `layer_idx` " + f"when creating this class." 
+ ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads" + f"(got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=False, use_easy_api=False + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_easy_api=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_easy_api=False, + ) + self.o_proj = Linear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + use_easy_api=False, + transpose=True, + ) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + assert self.num_heads % ax.config.G_intra_r == 0 + self.num_heads //= ax.config.G_intra_r + + assert self.num_key_value_heads % ax.config.G_intra_r == 0 + self.num_key_value_heads //= ax.config.G_intra_r + + assert self.hidden_size % ax.config.G_intra_r == 0 + self.hidden_size //= ax.config.G_intra_r + + +def modified_mlp_init(self, config): + super(MixtralBlockSparseTop2MLP, self).__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = Linear(self.hidden_dim, self.ffn_dim, bias=False, use_easy_api=False) + self.w2 = Linear( + self.ffn_dim, self.hidden_dim, bias=False, use_easy_api=False, transpose=True + ) + self.w3 = Linear(self.hidden_dim, self.ffn_dim, bias=False, use_easy_api=False) + + self.act_fn = ACT2FN[config.hidden_act] + + +def monkey_patch_mixtral_with_axonn(): + original_inits = MixtralAttention.__init__, MixtralBlockSparseTop2MLP.__init__ + MixtralAttention.__init__ = modified_attention_init + MixtralBlockSparseTop2MLP.__init__ = modified_mlp_init + return original_inits + + +def reverse_monkey_patch_mixtral_with_axonn(original_attention_init, original_mlp_init): + MixtralAttention.__init__ = original_attention_init + MixtralBlockSparseTop2MLP.__init__ = original_mlp_init diff --git a/axonn/models/transformers/modify_opt.py b/axonn/models/transformers/modify_opt.py index 7043748..120855a 100644 --- a/axonn/models/transformers/modify_opt.py +++ b/axonn/models/transformers/modify_opt.py @@ -1,6 +1,7 @@ from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, ACT2FN import torch.nn as nn from axonn.intra_layer import Linear +from axonn import axonn as ax def modified_attention_init( @@ -25,10 +26,15 @@ def modified_attention_init( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = Linear(embed_dim, embed_dim, bias=bias, easy_api=False) + self.v_proj = Linear(embed_dim, embed_dim, bias=bias, easy_api=False) + self.q_proj = Linear(embed_dim, embed_dim, 
bias=bias, easy_api=False) + self.out_proj = Linear( + embed_dim, embed_dim, bias=bias, easy_api=False, transpose=True + ) + + assert self.num_heads % ax.config.G_intra_r == 0 + self.num_heads //= ax.config.G_intra_r def modified_decoder_init(self, config): @@ -48,8 +54,16 @@ def modified_decoder_init(self, config): self.self_attn_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine ) - self.fc1 = Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) - self.fc2 = Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) + self.fc1 = Linear( + self.embed_dim, config.ffn_dim, bias=config.enable_bias, easy_api=False + ) + self.fc2 = Linear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + easy_api=False, + transpose=True, + ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine ) diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py index af09ab7..dbcb2a0 100644 --- a/axonn/tests/test_intra_layer_fc.py +++ b/axonn/tests/test_intra_layer_fc.py @@ -44,7 +44,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias): ) # divide colunns of X along the inner tensor group # manually divide input - layer = Linear(in_features=H, out_features=H, bias=bias).cuda() + layer = Linear( + in_features=H, out_features=H, bias=bias, use_easy_api=easy_tp + ).cuda() layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda() # test if load state dict works with a sequential checkpoint @@ -54,7 +56,7 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias): with torch.no_grad(): # parallel FW pass - Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) + Y_local = layer(X_local) Y_parallel = _gather(Y_local.clone(), 0, depth_group) if not easy_tp: # gather output manually Y_parallel = _gather(Y_local.clone(), 1, outer_group) @@ -101,9 +103,7 @@ def test_bw_pass( # parallel backward pass layer = Linear( - in_features=H, - out_features=H, - bias=bias, + in_features=H, out_features=H, bias=bias, use_easy_api=easy_tp ).cuda() layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda() @@ -133,7 +133,7 @@ def test_bw_pass( overlap_all_gather=comm_opt_level == 4, model_object_for_overlapping_allgathers=layer, ): - Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) + Y_local = layer(X_local) Y_local.backward(Y_local_grad) sync_gradients(layer)
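
Usage note (not part of the patch): after this change, Linear.forward no longer accepts scatter_input/gather_output; whether the layer drops its input and gathers its output is fixed once at construction time through the single use_easy_api argument. The sketch below is illustrative only; the ax.init call, group sizes, and tensor shapes are assumptions and not values taken from this patch.

    import torch
    from axonn import axonn as ax
    from axonn.intra_layer import Linear

    # Assumed initialization; the group sizes here are illustrative and must
    # match your own launch configuration.
    ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=1, G_intra_d=1)

    # Easy API (default): the layer internally Drop()s its input across the
    # tensor-parallel group and Gather()s its output, so it is called exactly
    # like torch.nn.Linear with full-sized tensors.
    easy_layer = Linear(in_features=1024, out_features=1024, bias=True).cuda()
    y = easy_layer(torch.randn(8, 1024).cuda())  # full-sized input and output

    # Expert mode: use_easy_api=False skips the internal Drop/Gather, so the
    # caller is responsible for feeding an already-partitioned input and for
    # handling the partitioned output.
    expert_layer = Linear(
        in_features=1024, out_features=1024, bias=True, use_easy_api=False
    ).cuda()

This is the mode the monkey-patched attention and MLP modules in this patch rely on: they construct their projections with use_easy_api=False and then divide num_heads, num_key_value_heads, and hidden_size by ax.config.G_intra_r so the partitioned activations line up.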
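
Reversibility sketch (illustrative, not part of the patch): each monkey_patch_*_with_axonn function now returns the original __init__ methods, and the matching reverse_* function restores them; modify_dict in axonn/models/transformers/__init__.py maps a model class name to this (patch, reverse) pair. The Mistral configuration values below are made up for the example, AxoNN must already be initialized because the patched constructors read ax.config.G_intra_r, and a transformers version compatible with the patched constructors is assumed.

    from transformers import MistralConfig, MistralForCausalLM
    from axonn.models.transformers import modify_dict

    # Look up the patch/reverse pair registered for Mistral in this PR.
    patch_fn, reverse_fn = modify_dict["MistralForCausalLM"]

    # Swap in the AxoNN tensor-parallel __init__ methods, keeping the originals.
    original_inits = patch_fn()

    # Illustrative small config; num_attention_heads and num_key_value_heads
    # must be divisible by ax.config.G_intra_r to satisfy the asserts added in
    # this patch.
    config = MistralConfig(
        hidden_size=512,
        intermediate_size=1024,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=4,
    )
    model = MistralForCausalLM(config)  # built with axonn.intra_layer.Linear

    # Restore the stock MistralAttention / MistralMLP constructors afterwards.
    reverse_fn(*original_inits)

Returning and restoring the original constructors keeps the patching side-effect-free across tests or across scripts that mix AxoNN-parallel and stock Hugging Face models in the same process.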