# save_activations.py
# %%
from nnsight import LanguageModel
from dictionary_learning.utils import hf_dataset_to_generator
from config import lm, activation_dim, layer, hf, n_ctxs
import torch as t
import einops
from tqdm import tqdm
import os
t.set_grad_enabled(False)
# %%
device = 'cuda:1'
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]
data = hf_dataset_to_generator(hf)
# %%
batch_size = 256
num_batches = 128
ctx_len = 128
total_tokens = batch_size * num_batches * ctx_len
total_memory = total_tokens * activation_dim * 4  # 4 bytes per float32 activation
print(f"Total contexts: {batch_size * num_batches / 1e3:.2f}K")
print(f"Total tokens: {total_tokens / 1e6:.2f}M")
print(f"Total memory: {total_memory / 1e9:.2f}GB")
# %%
# These functions are copied from buffer.py
def text_batch():
    # Pull the next batch_size documents from the dataset generator.
    return [
        next(data) for _ in range(batch_size)
    ]

def tokenized_batch():
    # Tokenize a batch of texts, padding and truncating to at most ctx_len tokens.
    texts = text_batch()
    return model.tokenizer(
        texts,
        return_tensors='pt',
        max_length=ctx_len,
        padding=True,
        truncation=True
    )

def get_activations(input):
    # Run a forward pass, save the output of the chosen transformer block,
    # then drop activations at padding positions using the attention mask.
    with t.no_grad():
        with model.trace(input):
            hidden_states = submodule.output.save()
    hidden_states = hidden_states.value
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]
    hidden_states = hidden_states[input['attention_mask'] != 0]
    return hidden_states
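# %%
# Quick smoke test (not in the original script): run one batch through
# get_activations to check the shape assumption before the full sweep below.
# Note that this consumes one batch_size worth of documents from the data generator.
test_batch = tokenized_batch()
test_acts = get_activations(test_batch)
print(test_acts.shape)  # expect (batch_size * ctx_len, activation_dim) if nothing is padded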
# %%
all_activations = []
all_tokens = []
for _ in tqdm(range(num_batches)):
    batch = tokenized_batch()
    all_tokens.append(batch['input_ids'].cpu())
    activations = get_activations(batch)
    # Reshape flat (tokens, d) activations back to (batch, ctx, d).
    # This assumes every context in the batch is exactly ctx_len tokens long
    # (i.e. no padding positions were dropped); otherwise the rearrange fails.
    activations = einops.rearrange(activations, "(b c) d -> b c d", b=batch_size)
    all_activations.append(activations.cpu())
# %%
# Expected shapes: (num_batches * batch_size, ctx_len, activation_dim) and (num_batches * batch_size, ctx_len)
concatenated_activations = t.cat(all_activations)
concatenated_tokens = t.cat(all_tokens)
print(concatenated_activations.shape, concatenated_tokens.shape)
# %%
# Save activations and tokens to disk
os.makedirs('data', exist_ok=True)
t.save(concatenated_activations, f'data/gpt2_activations_layer{layer}.pt')
t.save(concatenated_tokens, 'data/gpt2_tokens.pt')
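# %%
# Minimal usage sketch (not in the original script): reload the saved tensors
# for downstream work, e.g. training a sparse autoencoder on these activations.
# Paths mirror the save calls above; `layer` comes from config.
loaded_acts = t.load(f'data/gpt2_activations_layer{layer}.pt')  # (contexts, ctx_len, activation_dim)
loaded_tokens = t.load('data/gpt2_tokens.pt')                   # (contexts, ctx_len)
print(loaded_acts.shape, loaded_tokens.shape)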