-
Notifications
You must be signed in to change notification settings - Fork 0
/
TransferLearningResnet101.m
81 lines (67 loc) · 2.91 KB
/
TransferLearningResnet101.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
% Transfer learning with ResNet-101, adapted from:
% https://www.mathworks.com/help/nnet/examples/transfer-learning-using-googlenet.html
net = resnet101;

% Build datastores for the three splits; folder names provide the labels.
% fullfile keeps the paths platform-independent (the original hard-coded
% Windows-only '\' separators).
dataRoot = fullfile('ResizedImg','images','images');
testImages       = imageDatastore(fullfile(dataRoot,'test'),    'IncludeSubfolders',true,'LabelSource','foldernames');
trainingImages   = imageDatastore(fullfile(dataRoot,'train'),   'IncludeSubfolders',true,'LabelSource','foldernames');
validationImages = imageDatastore(fullfile(dataRoot,'validate'),'IncludeSubfolders',true,'LabelSource','foldernames');

% ResNet-101 expects 224x224 inputs, so resize every image on read.
resizeFcn = @(loc) imresize(imread(loc),[224,224]);
testImages.ReadFcn       = resizeFcn;
trainingImages.ReadFcn   = resizeFcn;
validationImages.ReadFcn = resizeFcn;
% Pull the layer graph out of the pretrained network and visualize it.
lgraph = layerGraph(net);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)

% Swap the 1000-way ImageNet classification head for one sized to this
% dataset: drop the last three layers, then append a fresh head.
lgraph = removeLayers(lgraph, {'fc1000','prob','ClassificationLayer_predictions'});
numClasses = numel(categories(trainingImages.Labels));

% The new fully connected layer gets 20x learn-rate factors so its fresh
% weights adapt faster than the pretrained backbone.
newLayers = [
    fullyConnectedLayer(numClasses, 'Name','fc', ...
        'WeightLearnRateFactor',20, 'BiasLearnRateFactor',20)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph, newLayers);

% Wire the new head onto the backbone's final pooling output.
lgraph = connectLayers(lgraph, 'pool5', 'fc');
figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])
% Fine-tune the network: SGD with momentum and a small initial learn rate
% (the backbone is pretrained), validating once per epoch.
miniBatchSize = 10;
numIterationsPerEpoch = floor(numel(trainingImages.Labels) / miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',     1e-4, ...
    'MaxEpochs',            4, ...
    'MiniBatchSize',        miniBatchSize, ...
    'ValidationData',       validationImages, ...
    'ValidationFrequency',  numIterationsPerEpoch, ...
    'Plots',                'training-progress', ...
    'Verbose',              false);
netTransfer = trainNetwork(trainingImages, lgraph, options);
% Evaluate on the held-out test set.
% BUG FIX: the original called classify with a single output, so 'score'
% (used in the ROC sweep and the save below) was never defined. Capture
% both the labels and the per-class posterior scores.
[predictedLabels, score] = classify(netTransfer, testImages);
testLabels = testImages.Labels;
accuracy = mean(predictedLabels == testLabels)   % unsuppressed on purpose: echoes accuracy

% Sweep the decision threshold over [0,1] and collect ROC points.
% score(:,2) is treated as the 'sunset' class posterior — NOTE(review):
% confirm the column order against categories(trainingImages.Labels).
ROC = [];
isSunset = (testLabels == 'sunset');
for threshold = linspace(0,1,101)
    pos = score(:,2) > threshold;   % predicted positive at this threshold
    tp = sum( pos &  isSunset);
    fp = sum( pos & ~isSunset);
    % BUG FIX: the original used '<' here, so samples whose score exactly
    % equals the threshold fell out of both confusion-matrix halves;
    % using the complement of 'pos' counts every sample exactly once.
    fn = sum(~pos &  isSunset);
    tn = sum(~pos & ~isSunset);
    tpr = tp/(tp+fn);
    fpr = fp/(fp+tn);
    ROC = [ROC; threshold,tpr,fpr,tp,tn,fp,fn]; %#ok<AGROW> % 101 rows; growth cost negligible
end

% Plot the ROC curve (line plus point markers).
figure(140);
hold on;
threshold = ROC(:,1);
FPR = ROC(:,3);
TPR = ROC(:,2);
plot(FPR, TPR, 'b-', 'LineWidth', 2);
plot(FPR, TPR, 'b.', 'MarkerSize', 6, 'LineWidth', 2);
grid;
% Title corrected: this script fine-tunes ResNet-101, not AlexNet.
title('ROC for ResNet-101 Transfer Learning Using Softmax Layer', 'fontSize', 18);
xlabel('False Positive Rate', 'fontWeight', 'bold');
ylabel('True Positive Rate', 'fontWeight', 'bold');
axis([0 1 0 1]);

% Best operating point = ROC point closest to the ideal corner (FPR=0, TPR=1).
dist2 = FPR.^2 + (1-TPR).^2;
% find(...,1) guarantees a scalar even when several thresholds tie for the minimum.
best_threshold = threshold(find(dist2 == min(dist2), 1));
save('resnet_results.mat','netTransfer','score','ROC','best_threshold');