utils.py (forked from shadowpa0327/Palu)
import argparse
import importlib
import json
import random

import numpy as np
import torch
from functools import reduce

from palu.model import HeadwiseLowRankModule
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "package.module.ClassName" to the object it names."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
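
# Usage sketch (illustrative; the dotted path below is just a standard-library
# example, not something this repo defines):
#   >>> cls = get_obj_from_str("collections.OrderedDict")
#   >>> cls()
#   OrderedDict()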


# Set seed for reproducibility
def set_seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
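
# Usage sketch: seed the Python, NumPy and PyTorch RNGs before any stochastic
# step (e.g. sampling calibration data) so results are repeatable.
#   >>> set_seed(42)
#   >>> torch.rand(2)  # identical values on every run with the same seed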


def get_model_numel(model):
    # Sums the element counts reported by sub-modules that expose a custom
    # `_nelement()` method; modules without one are skipped.
    param_cnt = 0
    for name, module in model.named_modules():
        if hasattr(module, '_nelement'):
            param_cnt += module._nelement()
    return param_cnt


def get_model_size(model):
    """Return the total size of parameters and buffers in GiB."""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_gb = (param_size + buffer_size) / 1024**3
    return size_all_gb
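
# Usage sketch (illustrative): measure a small module; ~1M fp32 parameters
# come out to roughly 0.004 GiB.
#   >>> import torch.nn as nn
#   >>> get_model_size(nn.Linear(1024, 1024))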


def get_module_by_name(module, module_name):
    """Fetch a nested sub-module by its dotted name, e.g. "model.layers.0.self_attn.k_proj"."""
    names = module_name.split(sep='.')
    return reduce(getattr, names, module)
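
# Usage sketch (illustrative): walk a dotted attribute path on any nn.Module.
#   >>> import torch.nn as nn
#   >>> net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#   >>> get_module_by_name(net, "0.weight").shape
#   torch.Size([4, 4])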


def dump_to_huggingface_repos(model, tokenizer, save_path, args):
    tokenizer.save_pretrained(save_path)
    # NOTE(brian1009): Ad-hoc fix for the generation-config bug in Vicuna/LongChat checkpoints:
    # if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
    #     model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)
    model.save_pretrained(save_path)
    config = model.config.to_dict()
    config["head_wise_ranks"] = {}
    for name, module in model.named_modules():
        if isinstance(module, HeadwiseLowRankModule):
            config["head_wise_ranks"][name] = module.ranks
    if "llama" in model.config._name_or_path.lower() or model.config.model_type == "llama":
        config["model_type"] = "palullama"
        config['architectures'] = ['PaluLlamaForCausalLM']
    elif "mistral" in model.config._name_or_path.lower():
        config["model_type"] = "palumistral"
        config['architectures'] = ['PaluMistralForCausalLM']
    elif "qwen2" in model.config._name_or_path.lower():
        config["model_type"] = "paluqwen2"
        config['architectures'] = ['PaluQwenForCausalLM']
    else:
        raise NotImplementedError
    config["original_model_name_or_path"] = model.config._name_or_path
    with open(save_path + "/config.json", "w") as f:
        json.dump(config, f, indent=2)
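
# Usage sketch (illustrative): persist a Palu-compressed checkpoint next to its
# tokenizer. `compressed_model` and the output directory are placeholders, not
# names defined in this file.
#   >>> model, tokenizer = load_model_and_tokenizer("meta-llama/Llama-2-7b-hf")
#   >>> # ... apply Palu's head-wise low-rank decomposition to obtain `compressed_model` ...
#   >>> dump_to_huggingface_repos(compressed_model, tokenizer, "palu-llama-2-7b", args=None)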


def load_model_and_tokenizer(model_name_or_path, use_flash_attn2=False):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
        attn_implementation="flash_attention_2" if use_flash_attn2 else "sdpa",
    )
    model.eval()
    # Work around the generation-config issue in Vicuna/LongChat checkpoints
    # (sampling parameters are set while do_sample defaults to False).
    # TODO: Add reference to the issue that also faced this bug
    if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
        model.generation_config.do_sample = True
    return model, tokenizer
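
# Usage sketch (illustrative; the checkpoint name is just an example):
#   >>> model, tokenizer = load_model_and_tokenizer("meta-llama/Llama-2-7b-hf")
#   >>> print(f"model weights take about {get_model_size(model):.2f} GiB")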


def add_common_args(parser: argparse.ArgumentParser):
    parser.add_argument('--model_name_or_path', type=str, help='Model to load')
    parser.add_argument('--lt_bits', type=int, default=16,
                        help='Bit-width for the low-rank latents; values below 16 enable low-bit quantization of the latents')
    parser.add_argument('--lt_group_size', type=int, default=0, help='Group size for low-rank latent quantization')
    parser.add_argument('--lt_sym', action='store_true', help='Use symmetric quantization for low-rank latents')
    parser.add_argument('--lt_clip_ratio', type=float, default=1.0, help='Clipping ratio for low-rank latents')
    parser.add_argument('--lt_hadamard', action='store_true', help='Apply a Hadamard transform to low-rank latents')
    parser.add_argument('--flash2', action='store_true', help='Use FlashAttention-2 for attention')
    return parser
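

# Minimal end-to-end sketch of how a Palu script would consume these shared
# flags; the checkpoint name passed below is only an example value.
if __name__ == "__main__":
    demo_parser = add_common_args(argparse.ArgumentParser(description="Palu utils demo"))
    demo_args = demo_parser.parse_args(["--model_name_or_path", "meta-llama/Llama-2-7b-hf", "--lt_bits", "4"])
    set_seed(0)
    print(demo_args)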