get_router_weights.py
import argparse
import json
import os
import random
from contextlib import nullcontext

import numpy as np
import torch
from tqdm import tqdm

import models
from data.utils import get_dataset, prepare_dataset


def iceildiv(x, y):
    # Integer ceiling division, e.g. iceildiv(7, 2) == 4.
    return (x + y - 1) // y
def get_as_batch(data, seq_length, batch_size, device='cpu', sample_size=None):
    """Yield (x, y) batches of token ids, where y is x shifted by one position."""
    all_ix = list(range(0, len(data), seq_length))
    assert all_ix[-1] + seq_length + 1 > len(data)
    all_ix.pop()  # drop the last window, which may not have a full target
    if sample_size is not None:
        # Subsample start indices so that roughly `sample_size` tokens are covered.
        all_ix = np.random.choice(all_ix, size=sample_size // seq_length, replace=False).tolist()
    for idx in range(0, len(all_ix), batch_size):
        ix = all_ix[idx:idx + batch_size]
        assert all(i + seq_length + 1 <= len(data) for i in ix)
        x = torch.stack([torch.from_numpy(data[i:i + seq_length].astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + seq_length].astype(np.int64)) for i in ix])
        if device != 'cpu':
            x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        yield x, y
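
# A minimal, hedged sketch of the generator above (never called by the script);
# the toy data and expected shapes are illustrative assumptions, not part of the original.
def _demo_get_as_batch():
    data = np.arange(10, dtype=np.uint16)  # 10 toy tokens
    x, y = next(get_as_batch(data, seq_length=4, batch_size=2))
    # x has shape (2, 4); y is x shifted left by one token.
    assert x.shape == (2, 4) and torch.equal(y[:, :-1], x[:, 1:])
    return x, y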
def forward(self, idx, targets):
    b, t = idx.size()
    assert t <= self.config.sequence_length, \
        f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
    # Forward the GPT model itself.
    index_shift = 0
    cache_context = None
    if getattr(self.transformer.wpe, "needs_iter", False):
        idx, pos_emb_closure = self.transformer.wpe(idx, iter=iter)  # position embeddings of shape (1, t, n_embd); note `iter` here resolves to the Python builtin
    else:
        idx, pos_emb_closure = self.transformer.wpe(idx)  # position embeddings of shape (1, t, n_embd)
    x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
    x = self.transformer.drop(x)
    x = pos_emb_closure.adapt_model_input(x, start_index=index_shift)
    for block in self.transformer.h_begin:
        x = block(x, pos_emb_closure, cache_context, start_index=index_shift)
    B, T, D = x.shape
    active_indices = (index_shift + torch.arange(T, device=x.device)).unsqueeze(0).repeat(B, 1)
    router_weights = None
    final_mask = []
    all_outputs = []
    all_indices = []
    total_expected_length = T * self.n_repeat
    for block in self.transformer.h_mid:
        block.attn.init_cache(total_expected_length)
    sum_active = 0
    all_router_weights = []
    # Run the middle blocks n_repeat times; after each pass, a router decides
    # which tokens continue, and its weights gate a residual mix of input and output.
    for rep_idx in range(1, self.n_repeat + 1):
        x_in = x
        if self.depth_emb is not None:
            x = self.depth_emb(x, indices=torch.full_like(active_indices, self.n_repeat - rep_idx))
        for block in self.transformer.h_mid:
            x = block(x, pos_emb_closure, cache_context, start_index=None, indices=active_indices)
        x = self.transformer.ln_mid(x)
        if router_weights is not None:
            x = x_in * (1 - router_weights) + x * router_weights
            all_router_weights.append(router_weights)
        sum_active += x.shape[1]
        if rep_idx < self.n_repeat:
            is_final, selected_indices, router_weights = self.transformer.mod[rep_idx - 1](x, capacity_factor=1.0)
        else:
            # Last repeat: every remaining token is final; nothing is selected.
            is_final = x.new_ones((B, x.shape[1])) == 1.
            selected_indices = x.new_ones((B, 0, 1)).long()
            router_weights = None  # no longer used
        final_mask.append(is_final)
        all_outputs.append(x)
        all_indices.append(active_indices)  # no need to clone, as we don't do in-place operations
        # Keep only the tokens the router selected for the next repeat.
        x = x.take_along_dim(selected_indices, dim=1)
        active_indices = active_indices.take_along_dim(selected_indices.squeeze(2), dim=1)
    for block in self.transformer.h_mid:
        block.attn.drop_cache()
    return all_router_weights
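
# A minimal, hedged sketch (never called) of the token-pruning step above:
# `take_along_dim` keeps only the rows whose indices the router selected.
# The tensors below are illustrative assumptions, not part of the original.
def _demo_take_along_dim():
    x = torch.arange(12.).view(1, 6, 2)            # (B=1, T=6, D=2)
    selected_indices = torch.tensor([[[1], [4]]])  # keep tokens 1 and 4
    kept = x.take_along_dim(selected_indices, dim=1)
    assert kept.shape == (1, 2, 2)
    return kept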
def main(args):
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TensorFloat-32 matmuls
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Resolve the checkpoint path; a bare directory implies 'ckpt.pt' inside it.
    if os.path.isdir(args.checkpoint):
        args.checkpoint = args.checkpoint + '/'
    checkpoint_dir, checkpoint_filename = os.path.split(args.checkpoint)
    if not checkpoint_filename:
        checkpoint_filename = 'ckpt.pt'
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)

    class Args:
        pass

    # Rebuild the training config from the run's summary.json.
    with open(os.path.join(checkpoint_dir, "summary.json")) as f:
        config = Args()
        config.__dict__ = json.load(f)['args']
    model = models.make_model_from_args(config)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(
        {k: v
         for k, v in checkpoint['model'].items()
         if "attn.bias" not in k and "wpe" not in k}, strict=False)
    data = get_dataset(config)
    config.device = "cuda:0"
    device_type = 'cuda' if 'cuda' in str(config.device) else 'cpu'
    type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
        device_type=device_type, dtype=config.dtype)
    model.cuda()
    model.eval()

    all_router_weights = [[] for _ in range(model.n_repeat)]
    for idx, (x, y) in tqdm(
        enumerate(get_as_batch(
            data['train'],
            config.sequence_length,
            config.batch_size,
            device=config.device,
            sample_size=len(data['val']),
        )),
        total=iceildiv(
            iceildiv(len(data['val']), config.sequence_length),
            config.batch_size,
        ),
    ):
        with torch.no_grad():
            with type_ctx:
                router_weights = forward(model, x, y)
        # Take an elementwise running minimum so each entry (from index 2 on)
        # never exceeds the previous one.
        for i in range(2, len(router_weights)):
            for j in range(len(router_weights[i])):
                router_weights[i][j] = torch.minimum(router_weights[i][j], router_weights[i - 1][j])
        for acc, w in zip(all_router_weights, router_weights):
            if w is not None:
                acc += w.detach().view(-1).tolist()

    all_router_weights = [np.array(w) if w is not None else w for w in all_router_weights]
    output_path = os.path.join(checkpoint_dir, "router_weights.npy")
    with open(output_path, "wb") as f:
        np.save(f, all_router_weights)
    print(output_path)
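
# A minimal, hedged sketch (never called) of the clamping step in main():
# `torch.minimum` is elementwise, so each weight is capped by its predecessor.
# The values below are illustrative assumptions, not part of the original.
def _demo_running_minimum():
    prev = torch.tensor([0.9, 0.2, 0.7])
    curr = torch.tensor([0.5, 0.4, 0.1])
    return torch.minimum(curr, prev)  # tensor([0.5000, 0.2000, 0.1000])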
if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--checkpoint', type=str, required=True)
    args, rem_args = parser.parse_known_args()
    main(args)
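
# Example invocation (assumed layout: a run directory containing ckpt.pt and
# summary.json, as loaded in main() above):
#   python get_router_weights.py --checkpoint runs/my_run/
# The script writes router_weights.npy next to the checkpoint and prints its path.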