-
Notifications
You must be signed in to change notification settings - Fork 1
/
adn_pmnist.py
129 lines (116 loc) · 5.03 KB
/
adn_pmnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
'''
Script to train active dendrite neuron networks on 10 tasks of PermutedMNIST.
'''
import os
import dendritic_mlp as D
from datasets.permutedMNIST import make_loader
from sparse_weights import rezero_weights
import numpy
import torch
from torch import nn
from tqdm import tqdm
# ---- Hyperparameters --------------------------------------------------------
num_epochs = 3   # training epochs per task
train_bs = 256   # training batch size
test_bs = 512    # evaluation batch size
num_tasks = 10   # number of PermutedMNIST tasks

# DendriticMLP configuration (passed as **conf to the constructor).
conf = {
    "input_size": 784,            # flattened 28x28 MNIST image
    "output_size": 10,            # one logit per digit class
    "hidden_sizes": [2048, 2048],
    "dim_context": num_tasks,     # 784,
    "kw": True,
    "kw_percent_on": 0.05,
    "weight_sparsity": 0.5,
    "context_percent_on": 0.05,   # used for weight init, but paper reported using dense context
    "num_segments": 10,
}
if __name__ == "__main__":
    # NOTE(review): the original file's indentation was lost in transit; the
    # nesting below is reconstructed from the code's semantics. Evaluation is
    # assumed to run once per task (after its epochs) — the per-seed
    # "single accuracies" print, one entry per task, supports this. Confirm
    # against the upstream repository.
    seeds = [33, 34, 35, 36, 37]
    # Per-seed accuracy curves; averaged across seeds at the end.
    all_single_acc = []
    all_avg_acc = []
    for seed in seeds:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = D.DendriticMLP(**conf)
        model = model.to(device)
        train_loader = make_loader(num_tasks, seed, train_bs, train=True)
        test_loader = make_loader(num_tasks, seed, test_bs, train=False)

        # Optimizer and Loss
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
        criterion = nn.CrossEntropyLoss()

        # @TODO use Euclidean distance to infer which task's input at test time
        # calculate all the context vectors, avg's of each tasks' inputs
        """
        contexts = []
        for curr_task in range(num_tasks):
            train_loader.sampler.set_active_tasks(curr_task)
            _sum = 0
            for batch_idx, (imgs, _) in enumerate(train_loader):
                imgs = imgs.to(device)
                imgs = imgs.flatten(start_dim=1)
                _sum += imgs.sum(0)
            # hardcoded for mnist train
            avg_task_input = _sum / 6000
            contexts.append(avg_task_input)
        """

        single_acc = []  # test accuracy on the most recently trained task
        avg_acc = []     # running average accuracy over all tasks seen so far
        for curr_task in range(num_tasks):
            train_loader.sampler.set_active_tasks(curr_task)

            # The one-hot task context depends only on the task, not the
            # batch: build it once here instead of once per batch (the
            # original rebuilt it inside the batch loop).
            task_context = torch.zeros(num_tasks, device=device)
            task_context[curr_task] = 1

            for e in range(num_epochs):
                model.train()
                for imgs, targets in train_loader:
                    optimizer.zero_grad()
                    imgs, targets = imgs.to(device), targets.to(device)
                    # Broadcast the task context to one row per example.
                    context = task_context.unsqueeze(0).repeat(imgs.shape[0], 1)
                    imgs = imgs.flatten(start_dim=1)
                    output = model(imgs, context)
                    # (Removed a dead `pred = output.data.max(...)` here — it
                    # was computed every training batch and never used.)
                    train_loss = criterion(output, targets)
                    train_loss.backward()
                    optimizer.step()
                    # Re-zero pruned weights so sparsity survives the update.
                    model.apply(rezero_weights)

            # Evaluate on every task trained so far.
            model.eval()
            total_correct = 0
            with torch.no_grad():
                for t in range(curr_task + 1):
                    latest_correct = 0
                    test_loader.sampler.set_active_tasks(t)
                    eval_context = torch.zeros(num_tasks, device=device)
                    eval_context[t] = 1
                    for imgs, targets in test_loader:
                        imgs, targets = imgs.to(device), targets.to(device)
                        context = eval_context.unsqueeze(0).repeat(imgs.shape[0], 1)
                        imgs = imgs.flatten(start_dim=1)
                        output = model(imgs, context)
                        # argmax replaces the deprecated `.data.max(...)` form.
                        pred = output.argmax(dim=1, keepdim=True)
                        latest_correct += pred.eq(targets.view_as(pred)).sum().item()
                    total_correct += latest_correct
                    # record latest trained task's test acc
                    if t == curr_task:
                        # hardcoded number of test examples per mnist task —
                        # TODO confirm against the loader's per-task size
                        single_acc.append(100 * latest_correct / 10000)
                # len(test_loader.dataset) presumably spans all num_tasks
                # tasks, so rescale by num_tasks / (curr_task + 1) to average
                # only over the tasks evaluated — verify against make_loader.
                acc = 100. * total_correct * num_tasks / (curr_task + 1) / len(test_loader.dataset)
                avg_acc.append(acc)
        print("single accuracies: ", single_acc)
        print("running avg accuracies: ", avg_acc)
        all_single_acc.append(single_acc)
        all_avg_acc.append(avg_acc)

    # figure out average wrt all seeds: element-wise mean of each per-task
    # entry across the seed runs.
    avg_seed_acc = list(map(lambda x: sum(x) / len(x), zip(*all_avg_acc)))
    avg_single_acc = list(map(lambda x: sum(x) / len(x), zip(*all_single_acc)))
    print("seed avg running avg accuracies: ", avg_seed_acc)
    print("seed avg single accuracies: ", avg_single_acc)
    print("SCRIPT FINISHED!")