-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
107 lines (83 loc) · 3.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
import random
from continuum.datasets import MNIST, CIFAR10, CIFAR100, Core50, CUB200, TinyImageNet200, OxfordFlower102
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
def update_memory(features, pseudo_labels, memory):
    """Append one mean-feature prototype per distinct pseudo-label to ``memory``.

    For every value yielded by ``pseudo_labels.unique()`` the entries of
    ``features`` carrying that label are averaged over dim 0 and the result
    becomes one new row of the memory.

    Args:
        features: Tensor indexed by sample along dim 0.
            Presumably shaped (num_samples, feature_dim) -- confirm with callers.
        pseudo_labels: Tensor of labels aligned with dim 0 of ``features``.
        memory: Tensor holding the prototypes stored so far. A tensor with
            zero elements is treated as "no memory yet" and is replaced
            rather than concatenated to.

    Returns:
        The updated memory tensor: existing rows first, then one row per
        distinct label. ``memory`` itself is returned when ``pseudo_labels``
        is empty. The input tensor is never modified in place.
    """
    prototypes = [
        features[torch.where(pseudo_labels == label)[0]].mean(dim=0).unsqueeze(0)
        for label in pseudo_labels.unique()
    ]
    if not prototypes:
        return memory

    if memory.numel() != 0:
        prototypes.insert(0, memory)
    # One concatenation instead of one per label: avoids recopying the
    # growing memory tensor on every iteration.
    return torch.cat(prototypes, dim=0)
def get_dset(data_path, dset_name, train, download=True):
    """ Returns a tuple with the dataset object, the total number of
        classes and the shape of the original data

    Args:
        data_path: Forwarded to the dataset constructor as ``data_path``.
        dset_name: One of 'CIFAR100', 'CIFAR10', 'MNIST', 'CUB200',
            'OxfordFlower102', 'TinyImageNet200' or 'Core50'.
        train: Forwarded to the dataset constructor as ``train``.
        download: Forwarded to the dataset constructor as ``download``.

    Returns:
        ``(dataset, num_classes, data_shape)``. For 'OxfordFlower102' the
        spatial entries of ``data_shape`` are the placeholder string ``'?'``.

    Raises:
        NotImplementedError: If ``dset_name`` is not one of the names above.
    """
    common = dict(data_path=data_path, download=download, train=train)
    # Builders are lambdas so that only the requested dataset is constructed
    # (construction may trigger a download).
    builders = {
        'CIFAR100': lambda: (CIFAR100(**common), 100, (3, 32, 32)),
        'CIFAR10': lambda: (CIFAR10(**common), 10, (3, 32, 32)),
        'MNIST': lambda: (MNIST(**common), 10, (1, 28, 28)),
        'CUB200': lambda: (CUB200(**common), 200, (3, 224, 224)),
        'OxfordFlower102': lambda: (OxfordFlower102(**common), 102, (3, '?', '?')),
        'TinyImageNet200': lambda: (TinyImageNet200(**common), 200, (3, 64, 64)),
        'Core50': lambda: (Core50(**common), 50, (3, 224, 224)),
    }
    if dset_name not in builders:
        raise NotImplementedError(
            f"Unsupported dataset {dset_name!r}; expected one of {sorted(builders)}"
        )
    return builders[dset_name]()
def get_transform(dset_name, resize=(224,224)):
    """ Returns the proper transformations

    Args:
        dset_name: Dataset name. 'CIFAR100', 'CIFAR10', 'CUB200', 'Core50',
            'TinyImageNet200' and 'OxfordFlower102' share one pipeline;
            'MNIST' only converts to a tensor.
        resize: Size passed to ``transforms.Resize``. Ignored for 'MNIST'.

    Returns:
        A plain list of transform objects (not wrapped in a ``Compose``).

    Raises:
        NotImplementedError: If ``dset_name`` is not a supported name.
    """
    resized_rgb_datasets = (
        'CIFAR100', 'CIFAR10', 'CUB200', 'Core50',
        'TinyImageNet200', 'OxfordFlower102',
    )
    if dset_name in resized_rgb_datasets:
        # NOTE(review): the same per-channel mean/std is applied to every
        # dataset in this group -- confirm that is intentional.
        return [transforms.Resize(resize),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
    if dset_name == 'MNIST':
        return [transforms.ToTensor()]
    raise NotImplementedError(
        f"No transform defined for dataset {dset_name!r}; expected 'MNIST' "
        f"or one of {list(resized_rgb_datasets)}"
    )
def forward_taskset(opt, model, taskset):
    """Run ``model`` over every batch of ``taskset`` and gather the results.

    The model is called as-is under ``torch.no_grad()``; this function does
    not switch it to eval mode or move it to a device.

    Args:
        opt: Options object. Reads ``opt.batch_size``, ``opt.cuda`` and,
            when ``opt.cuda`` is truthy, ``opt.gpu`` (the target device for
            the inputs).
        model: Callable applied to each input batch; its output must
            support ``.cpu()``.
        taskset: Dataset whose items are ``(x, y, t)`` triples. Only ``x``
            and ``y`` are used; the third element is ignored.

    Returns:
        ``(features, labels)``: the model outputs moved to the CPU and
        concatenated along dim 0, and the labels concatenated along the
        last dim, both in iteration order.

    Raises:
        ValueError: If ``taskset`` yields no batches.
    """
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for x, y, _ in DataLoader(taskset, batch_size=opt.batch_size):
            if opt.cuda:
                x = x.to(opt.gpu)
            feature_chunks.append(model(x).cpu())
            label_chunks.append(y)

        if not feature_chunks:
            # Previously this surfaced as an UnboundLocalError at `return`.
            raise ValueError("forward_taskset: taskset yielded no batches")

        # Concatenate once rather than per batch to avoid recopying the
        # accumulated tensors on every iteration.
        features = torch.cat(feature_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=-1)
    return features, labels