-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_multiple_classifiers.m
144 lines (124 loc) · 5.54 KB
/
test_multiple_classifiers.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Copyright (c) 2021, Christopher E. Arcadia (CC BY-NC 4.0) %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Script to Test Multiple Classifiers
% Script-level options for the multi-classifier test, gathered in one nested
% struct (multitest.option) so they stay separate from the `option` struct
% that the project's configuration scripts create below.
multitest = struct('option', struct( ...
    'iterations', 100, ...          % number of random class subsets to test
    'train_epochs', 700, ...        % copied into option.epochs before training
    'class_count', 5, ...           % classes drawn (without replacement) per iteration
    'random_seed', 0, ...           % base seed; iteration k reseeds with random_seed + k
    'evaluate_diversity', true));   % also compute the class-average diversity metric
% Load project settings and the dataset. These are scripts sharing this
% workspace; together they presumably create `option` and `database`
% (both are read below) -- their order is preserved deliberately.
configure_options
set_parameters
load_data
% Override some defaults for batch testing: use the script's epoch count,
% and switch off every verbosity / figure flag in one multiple assignment.
option.epochs = multitest.option.train_epochs;
[option.verbose, option.verbose_simulation, option.verbose_training, option.show_all_images] = deal(false);
% Initialize output fields. num_class and num_iter are left in the workspace
% because the test loop below uses them (they are cleared after it).
num_class = length(database.classnames);
num_iter = multitest.option.iterations;
% raw per-iteration outputs, one cell per test iteration
multitest.training = cell(num_iter,1);
multitest.simulation = cell(num_iter,1);
% class bookkeeping: every class index and name, plus one selection-mask
% row per iteration (row k marks the classes drawn in iteration k)
multitest.class.indices = 1:num_class;
multitest.class.names = database.classnames;
multitest.class.mask = zeros(num_iter,num_class);
% one summary row per iteration; all columns are preallocated in a single
% constructor call, in the same order as before
multitest.summary = table( ...
    zeros(num_iter,1), zeros(num_iter,1), zeros(num_iter,1), zeros(num_iter,1), ...
    cell(num_iter,1), zeros(num_iter,1), ...
    'VariableNames', {'iteration','accuracy','accuracyTrain','accuracyTest','classes','diversity'});
% Optionally analyze the dataset once, up front. The loop below reads
% avg.distance, presumably the class-average distance matrix built here.
if multitest.option.evaluate_diversity
    analyze_dataset
end
% Test multiple classifiers. Each iteration draws a random subset of
% multitest.option.class_count classes, optionally measures how far apart
% their class averages are, then trains and simulates a network on them.
% NOTE: select_data, train_network and simulate_experiment are scripts that
% share this workspace. They presumably read option.classes (set below), and
% they must leave `training` and `simulation` behind, which are stored per
% iteration. For that reason the variable names here must not change.

% Fail early with a clear message instead of an index-out-of-bounds error
% when more classes are requested than the dataset provides.
if multitest.option.class_count > num_class
    error('test_multiple_classifiers:tooManyClasses', ...
        'multitest.option.class_count (%g) exceeds the number of available classes (%g).', ...
        multitest.option.class_count, num_class)
end
rng(multitest.option.random_seed) % seed the random number generator for repeatability
for iter = 1:num_iter
    disp(['Starting test iteration ' num2str(iter) ' of ' num2str(num_iter) '.'])
    % reseed per iteration so any single iteration can be reproduced on its own
    rng(multitest.option.random_seed + iter)
    % randomly select classes for this iteration (without replacement)
    class_indices = randperm(num_class);
    class_indices = class_indices(1:multitest.option.class_count);
    class_mask = false(1,num_class);
    class_mask(class_indices) = true;
    option.classes = database.classnames(class_indices);
    class_names = strjoin(option.classes,', ');
    disp([' - classes selected: ' class_names])
    % Compare class averages as a crude measure of input diversity: the
    % minimum distance over all distinct pairs of selected classes.
    if multitest.option.evaluate_diversity
        % Only the strict lower triangle (i > j) is filled, so a class is
        % never compared with itself. Genuine zero distances between two
        % different classes are kept rather than filtered out. Unfilled
        % entries stay NaN.
        interdist = nan(multitest.option.class_count);
        for i = 1:multitest.option.class_count
            for j = 1:i-1
                interdist(i,j) = avg.distance(class_indices(i),class_indices(j));
            end
        end
        % min() omits NaN by default. The leading NaN makes the result NaN
        % (instead of empty) when fewer than two classes are selected, so
        % the summary assignment below still receives a scalar.
        class_diversity = min([NaN; interdist(:)]);
        clear i j interdist
        disp([' - class diversity (minimum distance between class averages): ' num2str(class_diversity)])
    end
    % run test using the selected classes
    select_data
    train_network
    simulate_experiment
    disp([' - classification accuracy: ' num2str(simulation.accuracy.all*100) '%'])
    % store output test results
    multitest.training{iter} = training;
    multitest.simulation{iter} = simulation;
    multitest.class.mask(iter,:) = class_mask;
    multitest.summary.iteration(iter) = iter;
    multitest.summary.accuracy(iter) = simulation.accuracy.all;
    multitest.summary.accuracyTrain(iter) = simulation.accuracy.train;
    multitest.summary.accuracyTest(iter) = simulation.accuracy.test;
    multitest.summary.classes{iter} = class_names;
    if multitest.option.evaluate_diversity
        multitest.summary.diversity(iter) = class_diversity;
    end
    % close all autogenerated figures
    close all
end
clear iter num_iter num_class class_indices class_names class_mask simulation training
disp('Completed all test iterations.')
% Check uniqueness of class selections. Each row of the selection mask
% encodes one iteration's class subset, so identical rows are repeated
% subsets. unique(...,'rows') returns one representative iteration index per
% distinct subset.
[~,ind_unique] = unique(multitest.class.mask,'rows');
num_unique = numel(ind_unique);
if num_unique ~= multitest.option.iterations
    disp(['Note that some iterations are redundant (' num2str(multitest.option.iterations-num_unique) ' out of ' num2str(multitest.option.iterations) ').'])
end
% Logical column flagging the representative iterations. It is also copied
% into the summary table so the flag appears in the printed results.
multitest.class.uniques = ismember((1:multitest.option.iterations)', ind_unique);
multitest.summary.unique = multitest.class.uniques;
clear ind_unique num_unique
% Review classification results: print the full per-iteration summary, then
% the mean accuracies (as percentages, two decimals) over the unique class
% subsets only, so repeated subsets are not double-counted.
disp(multitest.summary)
fprintf('Mean Classification Accuracy: %0.2f%% overall, %0.2f%% on training data, %0.2f%% on test data.\n', ...
    100*mean(multitest.summary.accuracy(multitest.class.uniques)), ...
    100*mean(multitest.summary.accuracyTrain(multitest.class.uniques)), ...
    100*mean(multitest.summary.accuracyTest(multitest.class.uniques)))
% Visualize classification results (unique class subsets only).
figure('Color','w','Name','Multiple Main Test Results');
% Top panel: accuracy histogram, with a dashed vertical line at the chance
% level for class_count classes (100/class_count percent).
subplot(2,1,1)
histogram(multitest.summary.accuracy(multitest.class.uniques)*100)
xlabel('Accuracy [%]')
ylabel('Count')
hold on
plot((1/multitest.option.class_count)*100*[1,1],ylim,'--','LineWidth',2)
hold off
% Bottom panel: accuracy against class diversity when it was evaluated,
% otherwise against the iteration (network) index.
subplot(2,1,2)
if multitest.option.evaluate_diversity
    plot(multitest.summary.diversity(multitest.class.uniques), ...
        multitest.summary.accuracy(multitest.class.uniques)*100,'.','MarkerSize',15)
    xlabel('Class Diversity')
else
    plot(multitest.summary.iteration(multitest.class.uniques), ...
        multitest.summary.accuracy(multitest.class.uniques)*100,'.','MarkerSize',15)
    xlabel('Network Index')
end
ylabel('Accuracy [%]')
% Fit the axes to the data, then pad each side by 2% of the x span and 10%
% of the y span so edge markers are not clipped.
axis tight
xlim(xlim + diff(xlim)/50*[-1,1])
ylim(ylim + diff(ylim)/10*[-1,1])