# Utils.py
import os
import shutil
from collections.abc import Iterator

import torch
from torch import Tensor, eq, cat
from torch.nn.init import normal_, constant_
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


class ExperienceDataset(Dataset):
    """In-memory dataset over preloaded image/target tensors, moving each item to `device` on access."""

    def __init__(self, image: Tensor, target: Tensor, device: str = "cpu"):
        self.image, self.target = image, target
        self.device = device

    def __len__(self):
        return self.image.size(0)

    def __getitem__(self, idx):
        return self.image[idx].to(self.device), self.target[idx].to(self.device)
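
# A minimal usage sketch (the random tensors below are hypothetical, not part of this
# module): wrap preloaded tensors in ExperienceDataset and batch them with a DataLoader.
#
#     images = torch.randn(64, 1, 32, 32)
#     targets = torch.randint(0, 10, (64,))
#     loader = DataLoader(ExperienceDataset(images, targets), batch_size=16)
#     batch_x, batch_y = next(iter(loader))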


def weights_init_normal(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm2d weights ~ N(1, 0.02), biases zero."""
    class_name = m.__class__.__name__
    if class_name.find("Conv") != -1:
        normal_(m.weight.data, 0.0, 0.02)
    elif class_name.find("BatchNorm2d") != -1:
        normal_(m.weight.data, 1.0, 0.02)
        constant_(m.bias.data, 0.0)
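
# A minimal usage sketch (the Sequential model is an illustrative assumption):
# `nn.Module.apply` visits every submodule, so conv and BatchNorm2d layers receive
# the initialization above.
#
#     from torch import nn
#     net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
#     net.apply(weights_init_normal)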


def compute_acc(predicted: Tensor, labels: Tensor) -> float:
    """
    Compute the model's classification accuracy (used for both real and fake images).
    :param predicted: per-class scores predicted by the discriminator, shape (N, C)
    :param labels: true labels, shape (N,)
    :return: fraction of correctly classified examples
    """
    correct = eq(predicted.argmax(dim=1), labels).sum().item()
    return float(correct) / float(labels.size(0))
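
# A minimal sketch on hypothetical scores: the argmax of rows 0 and 2 matches the
# labels, so the accuracy is 2/3.
#
#     scores = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
#     labels = torch.tensor([0, 1, 1])
#     assert abs(compute_acc(scores, labels) - 2 / 3) < 1e-6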


def generate_mnist_dataset(dataset_path: str, temp_path: str = "temp_mnist"):
    """Download MNIST once and save the preprocessed digits into one file per class under `dataset_path`."""
    if not os.path.exists(dataset_path):
        print("Dataset not found...")
        os.makedirs(dataset_path)
        # Download the MNIST digits
        transformations = Compose([Resize((32, 32)), ToTensor(), Normalize([0.5], [0.5])])
        mnist_data = MNIST(temp_path, download=True, train=True, transform=transformations)
        # A straightforward trick to apply all the transformations at once:
        # load the whole dataset as a single batch
        print("Preprocessing numbers...")
        dataloader = DataLoader(mnist_data, shuffle=False, batch_size=mnist_data.data.size(0))
        x, y = next(iter(dataloader))
        # Save the preprocessed digits of each class into a separate file
        for n in range(10):
            print("Saving number:", n)
            idx = torch.where(y == n)[0]
            torch.save([x[idx], y[idx]], os.path.join(dataset_path, f"num_{n}.pt"))
        shutil.rmtree(temp_path)
    else:
        print("Dataset found...")


def custom_mnist(experiences: list[list[int]], max_sampling: int = -1,
                 dataset_path: str = "../single_digit") -> Iterator[tuple[list[int], Tensor, Tensor]]:
    """
    Generates a custom version of MNIST based on the list of experiences.
    We introduced `max_sampling` because in the "joint retrain" we concatenate the current
    experience (the digits inside the lists) with the replay buffer. The naive approach leads to
    class imbalance: each digit class in a new experience has roughly 6k examples, while the number
    of digits in the buffer is a hyperparameter. To balance the classes, `max_sampling` limits the
    number of examples per class in the new experiences (-1 keeps them all).
    :return: yields (experience, images, labels) for each experience
    """
    # Check that the dataset is ready, generating it if needed
    generate_mnist_dataset(dataset_path)
    for t_ in experiences:  # iterate over the experiences
        img_x, img_y = None, None
        for n in t_:  # concatenate the digit classes of the current experience
            num_x, num_y = torch.load(f"{dataset_path}/num_{n}.pt")
            if max_sampling != -1:
                num_x, num_y = num_x[:max_sampling], num_y[:max_sampling]
            img_x = num_x if img_x is None else cat([img_x, num_x])
            img_y = num_y if img_y is None else cat([img_y, num_y])
        yield t_, img_x, img_y
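

# A minimal end-to-end sketch, runnable as a script (the two-experience split is an
# illustrative choice, not something this module prescribes): stream experiences from
# custom_mnist, wrap each in an ExperienceDataset, and batch it with a DataLoader.
if __name__ == "__main__":
    experiences = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]  # hypothetical class split
    for classes, x, y in custom_mnist(experiences, max_sampling=1000):
        loader = DataLoader(ExperienceDataset(x, y), batch_size=64, shuffle=True)
        batch_x, batch_y = next(iter(loader))
        print(f"experience {classes}: {len(x)} samples, batch shape {tuple(batch_x.shape)}")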