"""
This file contains the principal classes needed to perform the classification task.
PatchCore is the main class, as it is the main subject of this implementation, and is a subclass of the KNNExtractor class.
The PatchCore class extends the KNNExtractor class with fit and predict methods to perform a specific type of image classification using a coreset sampling approach.
"""
import torch
import torch.nn.functional as F
from torch import tensor
import numpy as np
import timm
from typing import Tuple
from torch.utils.data import DataLoader
from torchvision import models
from utils import get_coreset_idx, plot_roc
from sklearn.metrics import roc_auc_score


class KNNExtractor(torch.nn.Module):
    def __init__(
        self,
        featExtract_model_name: str = "wideresnet50",
        output_indices: Tuple = (2, 3),
        pool_last: bool = False,
        depth: int = 20,
        width_multiplier: float = 1.0,
        num_classes: int = 15,
        dropout: float = 0.2,
        featExtract_model=None,
    ):
        """
        KNNExtractor is the superclass of PatchCore. It initializes the base components needed by the PatchCore implementation
        and is built on top of the PyTorch library.
        Its main parameters are:
        - featExtract_model_name : name of the neural network model used as the feature extractor
        - output_indices : tuple of the layer indices whose outputs are used as the extracted features
        - pool_last : if True, an adaptive average pooling layer is applied to the last feature map returned by the feature extractor
        - depth : number of layers in each block of the WideResNet architecture
        - width_multiplier : factor that controls the width of the layers; it multiplies the number of channels in each layer
        - num_classes : number of output classes of the classification task
        - dropout : dropout rate, used as a regularization technique to prevent overfitting
        After the base nn.Module initialization, the __init__ method creates an instance of a neural network model as the feature
        extractor (here WideResNet-50-2), sets up the pooling layer if requested and selects the computation device.
        """
        super().__init__()
        self.depth = depth
        self.width_multiplier = width_multiplier
        self.num_classes = num_classes
        self.dropout = dropout
        self.output_indices = output_indices
        self.pool_last = pool_last
        self.backbone_name = featExtract_model_name
        self.preprocess = None
        # Creation of the WideResNet-50-2 neural network used as the feature extractor (the backbone is currently hard-coded)
        self.featExtract_model = timm.create_model(
            "wide_resnet50_2",
            out_indices=output_indices,
            features_only=True,
            pretrained=True,
        )
        # The backbone is frozen: its weights are not updated during training
        for param in self.featExtract_model.parameters():
            param.requires_grad = False
        # The feature extractor model is set to evaluation mode
        self.featExtract_model.eval()
        # If pool_last is True, an adaptive average pooling layer is applied to the last feature map
        self.pool = torch.nn.AdaptiveAvgPool2d(1) if pool_last else None
        self.featExtract_model_name = featExtract_model_name
        # The computation device is chosen and the model is moved to it
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.featExtract_model = self.featExtract_model.to(self.device)
    def __call__(self, tensor: torch.Tensor) -> Tuple:
        """
        The __call__ method allows an object of the class to be called like a function.
        It extracts features from the input tensor using the neural network model defined as the feature extractor.
        It takes as parameter the tensor we want to extract features from.
        The output is a tuple composed of the extracted features.
        """
        # Disable gradient tracking: gradients are not needed here, so they are not stored in memory
        with torch.no_grad():
            # Extract the features from the input tensor using the feature extractor
            extracted_features = self.featExtract_model(tensor.to(self.device))
        # If the 'pool_last' attribute is set to True,
        if self.pool:
            # return a tuple containing all of the extracted feature maps except the last one,
            # and the last feature map after being passed through the 'AdaptiveAvgPool2d' layer
            extracted_features = extracted_features[:-1], self.pool(extracted_features[-1]).to("cpu")
        # If the 'pool_last' attribute is set to False,
        else:
            # return a list containing all of the extracted feature maps, moved to the CPU
            extracted_features = [x.to("cpu") for x in extracted_features]
        return extracted_features
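    # Illustrative note (not from the original code): with the default wide_resnet50_2 backbone,
    # output_indices=(2, 3) and a (1, 3, 224, 224) input, this call is expected to return two CPU
    # feature maps of shapes roughly (1, 512, 28, 28) and (1, 1024, 14, 14); the exact shapes depend
    # on the timm version and the input resolution.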
    def evaluate(self, test_data: DataLoader, cls) -> float:
        """
        The evaluate method measures the model performance on the test data.
        Image-level predictions are collected, then the ROC AUC score is computed with scikit-learn.
        The output is the image-level ROC AUC score.
        """
        # Initialization of the lists
        image_predictions = []
        image_labels = []
        # Iterate over the test data to evaluate the performance
        for sample, mask, label in test_data:
            # Retrieve the prediction (anomaly score) for the current sample
            z_score = self.predict(sample)
            # The z_score is added to the list of image predictions
            image_predictions.append(z_score.numpy())
            # And the label to the list of image labels
            image_labels.append(label[0].numpy())
        # Compute the ROC AUC score of the image-level predictions
        print("Sklearn roc_auc metric computation loading ...")
        image_roc_auc = roc_auc_score(image_labels, image_predictions)
        # Plot the ROC curve
        plot_roc(image_labels, image_predictions, cls)
        # Finally, return the image-level ROC AUC value
        return image_roc_auc
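    # Note on the expected test_data format (added for clarity): each batch is assumed to be a
    # (sample, mask, label) triple, where label[0] holds the image-level ground truth used by
    # roc_auc_score (typically 0 for a good image and 1 for an anomalous one).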
    def get_parameters(self, extra_params: dict = None) -> dict:
        # Merge the base parameters with any extra parameters provided by a subclass
        return {
            "backbone_name": self.backbone_name,
            "out_indices": self.output_indices,
            **(extra_params or {}),
        }
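    # Illustrative example of the returned dictionary, assuming the default constructor
    # arguments and no extra parameters:
    #   {"backbone_name": "wideresnet50", "out_indices": (2, 3)}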


class PatchCore(KNNExtractor):
    """
    This is the PatchCore class, a subclass of the KNNExtractor class.
    It is the main class and implements the core functionality of the anomaly detection.
    Its __init__ method overrides the one from the KNNExtractor class and takes 3 parameters of its own:
    - f_coreset : float corresponding to the percentage used for the coreset sampling method.
      It represents the fraction of the training patches to keep in the memory bank
    - backbone_name : string specifying the name of the neural network to use as the backbone
    - coreset_eps : float corresponding to the sparse projection parameter, used when selecting a random subset of points as the coreset
    PatchCore inherits the evaluate method from the KNNExtractor class and adds two methods, fit() and predict(), which
    respectively fit the model to the training data and predict anomaly scores on the test data using the trained model.
    """
    def __init__(
        self,
        f_coreset: float = 0.01,
        backbone_name: str = "wideresnet50",
        coreset_eps: float = 0.90,
        out_indices: Tuple = None,
        pool_last: bool = False,
    ):
        # Initialize the parent class (default to feature levels 2 and 3 if no indices are given)
        super().__init__(backbone_name, output_indices=out_indices or (2, 3), pool_last=pool_last)
        # Store the additional parameters
        self.f_coreset = f_coreset
        self.coreset_eps = coreset_eps
        self.image_size = 224
        self.average = torch.nn.AvgPool2d(3, stride=1)
        self.n_reweight = 3
        self.memory_bank = []
        self.resize = None
        self.featExtract_model_name = backbone_name
    def fit(self, train_dl):
        """
        This method fits the model to the training data. It extracts the features from the training samples using the backbone
        network and stores the obtained patches in a patch memory bank. This memory bank is then subsampled with the coreset
        sampling method.
        """
        print("Beginning of the fit method...")
        # First, initialize an empty list used to store the patches created from the training data
        memory_bank = []
        # This variable stores the size of the largest feature map
        largest_fmap_size = None
        # Iterate over the training data to build the memory bank
        for sample in train_dl:
            # The feature maps are extracted from the sample (self(sample) calls the __call__ method of the KNNExtractor class)
            feature_maps = self(sample[0])
            # If the size of the largest feature map has not been set yet,
            # set it to the size of the current feature map
            if largest_fmap_size is None:
                largest_fmap_size = feature_maps[0].shape[-2:]
                # Initialize the resize layer as an adaptive average pooling with the size of the largest feature map
                self.resize = torch.nn.AdaptiveAvgPool2d(largest_fmap_size)
            # The extracted feature maps are locally averaged and resized to a common spatial size
            resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]
            # A patch tensor is created by concatenating the resized feature maps along the channel dimension
            patch = torch.cat(resized_maps, 1)
            # Reshape: the patch tensor is flattened and then transposed so that each row is one patch vector
            patch = patch.reshape(patch.shape[1], -1).T
            # Finally, add the patches to the memory bank
            memory_bank.append(patch)
        # The patches in the memory bank are then concatenated into a single tensor
        self.memory_bank = torch.cat(memory_bank, 0)
        # If the chosen coreset percentage is less than 1,
        # select a subset of the patch library as the coreset
        if self.f_coreset < 1:
            self.coreset_idx = get_coreset_idx(
                self.memory_bank,
                n=int(self.f_coreset * self.memory_bank.shape[0]),
                eps=self.coreset_eps,
            )
            self.memory_bank = self.memory_bank[self.coreset_idx]
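    # Rough size illustration (assumption: 224x224 inputs and the default output_indices=(2, 3)):
    # each image then yields a 28x28 grid of 784 patch vectors of dimension 512 + 1024 = 1536,
    # so 200 training images give 156,800 patches, and f_coreset=0.01 keeps
    # int(0.01 * 156800) = 1568 of them in the subsampled memory bank.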
    def predict(self, sample):
        """
        The predict method is applied to the test data to classify an image as a good one or as an outlier (detected as an anomaly).
        First, the patches are extracted from the sample, then the distance between each patch and the patches in the memory bank is computed.
        The test patch with the largest distance to its nearest memory bank patch is identified.
        A reweighting is then applied to this distance, based on the distances between that patch and the nearest neighbours
        of its closest memory bank patch.
        The resulting weighted score is returned.
        """
        # These first steps are the same as the ones applied to the training data when fitting the model
        feature_maps = self(sample)
        resized_maps = [self.resize(self.average(fmap)) for fmap in feature_maps]
        patch = torch.cat(resized_maps, 1)
        patch = patch.reshape(patch.shape[1], -1).T
        # Compute the distance between every patch of the sample and every patch of the memory bank
        dist = torch.cdist(patch, self.memory_bank)
        # Find, for every patch of the sample, the closest patch in the memory bank and the corresponding distance
        min_val, min_idx = torch.min(dist, dim=1)
        # Among the distances to the nearest memory bank patches, take the index of the biggest one
        s_idx = torch.argmax(min_val)
        # Among the distances to the nearest memory bank patches, take the value of the biggest one
        s_star = torch.max(min_val)
        # Anomalous patch (the one with the biggest minimum distance)
        m_test = patch[s_idx].unsqueeze(0)
        # Closest neighbour (in the memory bank)
        m_star = self.memory_bank[min_idx[s_idx]].unsqueeze(0)
        # Find the knn of m_star pt.1 | Compute the distances between the closest neighbour and all the other patches in the memory bank
        w_dist = torch.cdist(m_star, self.memory_bank)
        # Pt.2 | Take the indices of the top k neighbours in the memory bank
        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)
        # Compute the distances between the "worst" test patch and the knn patches of m_star in the memory bank
        m_star_knn = torch.linalg.norm(m_test - self.memory_bank[nn_idx[0, 1:]], dim=1)
        # Softmax normalization trick as in transformers:
        # as the patch vectors grow larger, their norms might differ a lot and exp(norm) can give infinities,
        # so the distances are scaled by sqrt(D) before the exponential.
        # Apply equation 7 from the PatchCore paper
        D = torch.sqrt(torch.tensor(patch.shape[1]))
        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))
        s = w * s_star
        # Return the anomaly score
        return s
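    # Note on the score above (added for clarity; notation follows the code, not necessarily the paper's exact symbols):
    #   s_star     = largest distance from a test patch to its nearest memory bank patch
    #   m_test     = the test patch realising that maximum
    #   m_star     = its nearest neighbour in the memory bank
    #   m_star_knn = distances from m_test to the nearest neighbours of m_star (excluding m_star itself)
    #   s          = (1 - exp(s_star / sqrt(d)) / sum(exp(m_star_knn / sqrt(d)))) * s_star, with d the patch dimension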
    def get_parameters(self):
        return super().get_parameters({
            "f_coreset": self.f_coreset,
            "n_reweight": self.n_reweight,
        })
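

# ---------------------------------------------------------------------------
# Minimal usage sketch (added illustration, not part of the original module).
# It shows the expected fit/predict workflow on randomly generated tensors that
# stand in for real image batches; the tensor shapes, number of batches and the
# f_coreset value below are assumptions chosen only for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Fake "training set": 8 batches of one 3x224x224 image each,
    # wrapped as (image,) tuples so that fit() can index sample[0]
    fake_train_dl = [(torch.randn(1, 3, 224, 224),) for _ in range(8)]

    model = PatchCore(f_coreset=0.1)
    model.fit(fake_train_dl)

    # Score a single "test" image: higher values mean more anomalous
    score = model.predict(torch.randn(1, 3, 224, 224))
    print("anomaly score:", score.item())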