From 5087268420ed58b94653ae3222123b30cf0ef051 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Tue, 27 Feb 2024 23:11:19 -0500
Subject: [PATCH] change parallelize context to use AutoConfig (#67)

---
 axonn/__init__.py                         |  1 +
 axonn/models/__init__.py                  |  1 +
 axonn/models/transformers/__init__.py     | 43 +++++++++++++----------
 axonn/models/transformers/modify_llama.py | 31 ++++++++++------
 axonn/models/transformers/modify_opt.py   | 10 +++---
 5 files changed, 54 insertions(+), 32 deletions(-)

diff --git a/axonn/__init__.py b/axonn/__init__.py
index 8abdbbe..ab84c20 100644
--- a/axonn/__init__.py
+++ b/axonn/__init__.py
@@ -2,3 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from . import models  # noqa: F401
diff --git a/axonn/models/__init__.py b/axonn/models/__init__.py
index ea17b1e..da56eb3 100644
--- a/axonn/models/__init__.py
+++ b/axonn/models/__init__.py
@@ -1 +1,2 @@
 # For parallelize context manager use
+from . import transformers  # noqa: F401
diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py
index 27b6013..db33249 100644
--- a/axonn/models/transformers/__init__.py
+++ b/axonn/models/transformers/__init__.py
@@ -1,29 +1,36 @@
 from contextlib import contextmanager
-from modify_opt import monkey_patch_opt_with_axonn
-from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer
-from modify_llama import monkey_patch_llama_with_axonn
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
+from transformers import AutoConfig
+from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn
+from .modify_llama import (
+    monkey_patch_llama_with_axonn,
+    reverse_monkey_patch_llama_with_axonn,
+)
 
 modify_dict = {
-    "facebook/opt-125m": monkey_patch_opt_with_axonn,
-    "facebook/opt-350m": monkey_patch_opt_with_axonn,
-    "facebook/opt-1.3b": monkey_patch_opt_with_axonn,
-    "codellama/CodeLlama-70b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-34b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-13b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-7b-hf": monkey_patch_llama_with_axonn,
-    "deepseek-ai/deepseek-coder-6.7b-base": monkey_patch_llama_with_axonn,
-    "meta-llama/Llama-2-7b-hf": monkey_patch_llama_with_axonn,
+    "OPTForCausalLM": (
+        monkey_patch_opt_with_axonn,
+        reverse_monkey_patch_opt_with_axonn,
+    ),
+    "LlamaForCausalLM": (
+        monkey_patch_llama_with_axonn,
+        reverse_monkey_patch_llama_with_axonn,
+    ),
 }
 
 
 @contextmanager
 def parallelize(model_id):
-    original_inits = modify_dict[model_id]()  # call to monkey patch
+    config = AutoConfig.from_pretrained(model_id)
+    architecture = config.architectures[0]
+    # config.architectures is a list, not sure what to do
+    # if it has multiple elements
+    assert (
+        architecture in modify_dict
+    ), f"{architecture} has not been parallelized within AxoNN"
+
+    monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
+    original_attention_init, original_mlp_init = monkey_patch_fn()
     try:
         yield None
     finally:
-        OPTAttention.__init__ = original_inits["OPTAttention"]
-        OPTDecoderLayer.__init__ = original_inits["OPTDecoderLayer"]
-        LlamaAttention.__init__ = original_inits["LlamaAttention"]
-        LlamaMLP.__init__ = original_inits["LlamaMLP"]
+        reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
diff --git a/axonn/models/transformers/modify_llama.py b/axonn/models/transformers/modify_llama.py
index 1403f5a..4a1d2f5 100644
--- a/axonn/models/transformers/modify_llama.py
+++ b/axonn/models/transformers/modify_llama.py
@@ -1,10 +1,20 @@
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
 from axonn.intra_layer import Linear
+from typing import Optional
 
 
-def modified_attention_init(self, config):
+def modified_attention_init(self, config, layer_idx: Optional[int] = None):
     super(LlamaAttention, self).__init__()
     self.config = config
+    self.layer_idx = layer_idx
+    if layer_idx is None:
+        logger.warning_once(  # noqa: F821
+            f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "  # noqa: E501
+            "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "  # noqa: E501
+            "when creating this class."
+        )
+
+    self.attention_dropout = config.attention_dropout
     self.hidden_size = config.hidden_size
     self.num_heads = config.num_attention_heads
     self.head_dim = self.hidden_size // self.num_heads
@@ -16,9 +26,10 @@ def modified_attention_init(self, config):
 
     if (self.head_dim * self.num_heads) != self.hidden_size:
         raise ValueError(
-            f"hidden_size must be divisible by num_heads "
-            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
+            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  # noqa: E501
+            f" and `num_heads`: {self.num_heads})."
         )
+
     self.q_proj = Linear(
         self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
     )
@@ -32,9 +43,7 @@ def modified_attention_init(self, config):
         self.num_key_value_heads * self.head_dim,
         bias=config.attention_bias,
     )
-    self.o_proj = Linear(
-        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
-    )
+    self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
     self._init_rope()
 
 
@@ -50,10 +59,12 @@ def modified_mlp_init(self, config):
 
 
 def monkey_patch_llama_with_axonn():
-    original_inits = {
-        "LlamaAttention": LlamaAttention.__init__,
-        "LlamaMLP": LlamaMLP.__init__,
-    }
+    original_inits = LlamaAttention.__init__, LlamaMLP.__init__
     LlamaAttention.__init__ = modified_attention_init
     LlamaMLP.__init__ = modified_mlp_init
     return original_inits
+
+
+def reverse_monkey_patch_llama_with_axonn(original_attention_init, original_mlp_init):
+    LlamaAttention.__init__ = original_attention_init
+    LlamaMLP.__init__ = original_mlp_init
diff --git a/axonn/models/transformers/modify_opt.py b/axonn/models/transformers/modify_opt.py
index 1fd6969..7043748 100644
--- a/axonn/models/transformers/modify_opt.py
+++ b/axonn/models/transformers/modify_opt.py
@@ -56,10 +56,12 @@ def modified_decoder_init(self, config):
 
 
 def monkey_patch_opt_with_axonn():
-    original_inits = {
-        "OPTAttention": OPTAttention.__init__,
-        "OPTDecoderLayer": OPTDecoderLayer.__init__,
-    }
+    original_inits = OPTAttention.__init__, OPTDecoderLayer.__init__
     OPTAttention.__init__ = modified_attention_init
     OPTDecoderLayer.__init__ = modified_decoder_init
     return original_inits
+
+
+def reverse_monkey_patch_opt_with_axonn(original_attention_init, original_mlp_init):
+    OPTAttention.__init__ = original_attention_init
+    OPTDecoderLayer.__init__ = original_mlp_init
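
For context, here is a minimal usage sketch of the `parallelize` context manager after this change. It is not part of the patch: it assumes AxoNN has already been initialized on every rank and that the checkpoint's `config.architectures[0]` resolves to one of the supported entries in `modify_dict` (`OPTForCausalLM` or `LlamaForCausalLM`); the `facebook/opt-125m` checkpoint and the `AutoModelForCausalLM.from_pretrained` call are ordinary Hugging Face usage, chosen here only for illustration.

```python
# Illustrative sketch only, not part of the patch.
# Assumes AxoNN's parallelism has already been initialized on every rank.
from transformers import AutoModelForCausalLM

from axonn.models.transformers import parallelize

model_id = "facebook/opt-125m"  # config.architectures[0] == "OPTForCausalLM"

# Inside the context, OPTAttention.__init__ / OPTDecoderLayer.__init__ are
# monkey-patched so the model is built with AxoNN's tensor-parallel Linear layers.
with parallelize(model_id):
    model = AutoModelForCausalLM.from_pretrained(model_id)

# On exit, the original __init__ methods are restored, so models created
# afterwards use the stock Hugging Face layers.
```

With this patch, the lookup key is the architecture name reported by `AutoConfig` rather than a hard-coded checkpoint id, so any checkpoint whose architecture is `OPTForCausalLM` or `LlamaForCausalLM` should work without being listed explicitly.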