-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
change parallelize context to use AutoConfig
- Loading branch information
1 parent: f975e58 — commit: 1bec8ca
Showing 5 changed files with 54 additions and 32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
# For parallelize context manager use | ||
from . import transformers # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,36 @@ | ||
from contextlib import contextmanager | ||
from modify_opt import monkey_patch_opt_with_axonn | ||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer | ||
from modify_llama import monkey_patch_llama_with_axonn | ||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP | ||
from transformers import AutoConfig | ||
from .modify_opt import monkey_patch_opt_with_axonn, reverse_monkey_patch_opt_with_axonn | ||
from .modify_llama import ( | ||
monkey_patch_llama_with_axonn, | ||
reverse_monkey_patch_llama_with_axonn, | ||
) | ||
|
||
# Registry of AxoNN monkey-patch hooks, keyed by the Hugging Face
# architecture class name (as reported by AutoConfig.architectures).
# Each value is a (patch_fn, unpatch_fn) pair: patch_fn installs the
# AxoNN-parallel __init__ overrides and returns the originals;
# unpatch_fn restores them.
modify_dict = {
    "OPTForCausalLM": (
        monkey_patch_opt_with_axonn,
        reverse_monkey_patch_opt_with_axonn,
    ),
    "LlamaForCausalLM": (
        monkey_patch_llama_with_axonn,
        reverse_monkey_patch_llama_with_axonn,
    ),
}
@contextmanager
def parallelize(model_id):
    """Temporarily monkey-patch a Hugging Face model class for AxoNN parallelism.

    Looks up the model's architecture via ``AutoConfig.from_pretrained`` and
    applies the matching (patch, unpatch) pair from ``modify_dict``. The patch
    is always reverted on exit, even if the managed block raises.

    Args:
        model_id: Hugging Face model identifier (hub name or local path)
            accepted by ``AutoConfig.from_pretrained``.

    Yields:
        None. Models instantiated inside the ``with`` block pick up the
        AxoNN-parallel layer implementations.

    Raises:
        ValueError: if the model's architecture has no registered patch.
    """
    config = AutoConfig.from_pretrained(model_id)
    # config.architectures is a list; only the first entry is used.
    # NOTE(review): behavior for configs with multiple architectures is
    # undefined here — confirm whether that case can occur.
    architecture = config.architectures[0]
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if architecture not in modify_dict:
        raise ValueError(f"{architecture} has not been parallelized within AxoNN")

    monkey_patch_fn, reverse_monkey_patch_fn = modify_dict[architecture]
    original_attention_init, original_mlp_init = monkey_patch_fn()
    try:
        yield None
    finally:
        # Always restore the original __init__ implementations so the
        # patch does not leak outside the context manager.
        reverse_monkey_patch_fn(original_attention_init, original_mlp_init)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters