
Commit

change parallelize context to use AutoConfig
siddharth9820 committed Feb 28, 2024
1 parent: f975e58 · commit: 1bec8ca
Showing 5 changed files with 54 additions and 32 deletions.
1 change: 1 addition & 0 deletions axonn/__init__.py
@@ -2,3 +2,4 @@
 # See the top-level LICENSE file for details.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from . import models  # noqa: F401
1 change: 1 addition & 0 deletions axonn/models/__init__.py
@@ -1 +1,2 @@
 # For parallelize context manager use
+from . import transformers  # noqa: F401
43 changes: 25 additions & 18 deletions axonn/models/transformers/__init__.py
@@ -1,29 +1,36 @@
 from contextlib import contextmanager
-from modify_opt import monkey_patch_opt_with_axonn
-from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer
-from modify_llama import monkey_patch_llama_with_axonn
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
+from transformers import AutoConfig
+from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn
+from .modify_llama import (
+    monkey_patch_llama_with_axonn,
+    reverse_monkey_patch_llama_with_axonn,
+)
 
 modify_dict = {
-    "facebook/opt-125m": monkey_patch_opt_with_axonn,
-    "facebook/opt-350m": monkey_patch_opt_with_axonn,
-    "facebook/opt-1.3b": monkey_patch_opt_with_axonn,
-    "codellama/CodeLlama-70b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-34b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-13b-hf": monkey_patch_llama_with_axonn,
-    "codellama/CodeLlama-7b-hf": monkey_patch_llama_with_axonn,
-    "deepseek-ai/deepseek-coder-6.7b-base": monkey_patch_llama_with_axonn,
-    "meta-llama/Llama-2-7b-hf": monkey_patch_llama_with_axonn,
+    "OPTForCausalLM": (
+        monkey_patch_opt_with_axonn,
+        reverse_monkey_patch_opt_with_axonn,
+    ),
+    "LlamaForCausalLM": (
+        monkey_patch_llama_with_axonn,
+        reverse_monkey_patch_llama_with_axonn,
+    ),
 }
 
 
 @contextmanager
 def parallelize(model_id):
-    original_inits = modify_dict[model_id]()  # call to monkey patch
+    config = AutoConfig.from_pretrained(model_id)
+    architecture = config.architectures[0]
+    # config.architectures is a list, not sure what to do
+    # if it has multiple elements
+    assert (
+        architecture in modify_dict
+    ), f"{architecture} has not been parallelized within AxoNN"
+
+    monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
+    original_attention_init, original_mlp_init = monkey_patch_fn()
     try:
         yield None
     finally:
-        OPTAttention.__init__ = original_inits["OPTAttention"]
-        OPTDecoderLayer.__init__ = original_inits["OPTDecoderLayer"]
-        LlamaAttention.__init__ = original_inits["LlamaAttention"]
-        LlamaMLP.__init__ = original_inits["LlamaMLP"]
+        reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
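
A minimal usage sketch of the parallelize context manager after this change (illustrative, not part of the commit; it assumes AxoNN's intra-layer parallelism has already been initialized, and uses facebook/opt-125m, whose architecture string is OPTForCausalLM):

    from transformers import AutoModelForCausalLM
    from axonn.models.transformers import parallelize

    # Dispatch is now keyed on config.architectures[0] via AutoConfig, so any
    # checkpoint reporting OPTForCausalLM or LlamaForCausalLM resolves to the
    # matching (monkey_patch, reverse_monkey_patch) pair.
    with parallelize("facebook/opt-125m"):
        # Inside the context, the relevant attention and MLP/decoder layers are
        # constructed with AxoNN's tensor-parallel Linear; the patch is
        # reverted when the context exits.
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
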
31 changes: 21 additions & 10 deletions axonn/models/transformers/modify_llama.py
@@ -1,10 +1,20 @@
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN
 from axonn.intra_layer import Linear
+from typing import Optional
 
 
-def modified_attention_init(self, config):
+def modified_attention_init(self, config, layer_idx: Optional[int] = None):
     super(LlamaAttention, self).__init__()
     self.config = config
+    self.layer_idx = layer_idx
+    if layer_idx is None:
+        logger.warning_once(  # noqa: F821
+            f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "  # noqa: E501
+            "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "  # noqa: E501
+            "when creating this class."
+        )
+
+    self.attention_dropout = config.attention_dropout
     self.hidden_size = config.hidden_size
     self.num_heads = config.num_attention_heads
     self.head_dim = self.hidden_size // self.num_heads
@@ -16,9 +26,10 @@ def modified_attention_init(self, config):
 
     if (self.head_dim * self.num_heads) != self.hidden_size:
         raise ValueError(
-            f"hidden_size must be divisible by num_heads "
-            f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
+            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  # noqa: E501
+            f" and `num_heads`: {self.num_heads})."
         )
+
     self.q_proj = Linear(
         self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
     )
@@ -32,9 +43,7 @@ def modified_attention_init(self, config):
         self.num_key_value_heads * self.head_dim,
         bias=config.attention_bias,
     )
-    self.o_proj = Linear(
-        self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
-    )
+    self.o_proj = Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
     self._init_rope()
 
 
@@ -50,10 +59,12 @@ def modified_mlp_init(self, config):
 
 
 def monkey_patch_llama_with_axonn():
-    original_inits = {
-        "LlamaAttention": LlamaAttention.__init__,
-        "LlamaMLP": LlamaMLP.__init__,
-    }
+    original_inits = LlamaAttention.__init__, LlamaMLP.__init__
     LlamaAttention.__init__ = modified_attention_init
     LlamaMLP.__init__ = modified_mlp_init
     return original_inits
+
+
+def reverse_monkey_patch_llama_with_axonn(original_attention_init, original_mlp_init):
+    LlamaAttention.__init__ = original_attention_init
+    LlamaMLP.__init__ = original_mlp_init
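
The new reverse_monkey_patch_llama_with_axonn mirrors monkey_patch_llama_with_axonn: save the original constructors, swap in the modified ones, and restore them later. A tiny self-contained sketch of that save/patch/restore pattern (illustrative only, using a stand-in class rather than LlamaAttention/LlamaMLP):

    class Layer:
        def __init__(self):
            self.parallel = False

    def patched_init(self):
        self.parallel = True

    original_init = Layer.__init__  # save, like monkey_patch_llama_with_axonn returning the originals
    Layer.__init__ = patched_init   # patch: new instances take the modified path
    assert Layer().parallel
    Layer.__init__ = original_init  # restore, like reverse_monkey_patch_llama_with_axonn
    assert not Layer().parallel
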
10 changes: 6 additions & 4 deletions axonn/models/transformers/modify_opt.py
@@ -56,10 +56,12 @@ def modified_decoder_init(self, config):
 
 
 def monkey_patch_opt_with_axonn():
-    original_inits = {
-        "OPTAttention": OPTAttention.__init__,
-        "OPTDecoderLayer": OPTDecoderLayer.__init__,
-    }
+    original_inits = OPTAttention.__init__, OPTDecoderLayer.__init__
     OPTAttention.__init__ = modified_attention_init
     OPTDecoderLayer.__init__ = modified_decoder_init
     return original_inits
+
+
+def reverse_monkey_patch_opt_with_axonn(original_attention_init, original_mlp_init):
+    OPTAttention.__init__ = original_attention_init
+    OPTDecoderLayer.__init__ = original_mlp_init