-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtree_predict.m
More file actions
52 lines (47 loc) · 1.92 KB
/
tree_predict.m
File metadata and controls
52 lines (47 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
function y_pred = tree_predict(tree, X_test, X_type)
% Predict the classes of the test set.
%
% tree is a classifier generated by tree_fit on a training set which
% has the same structure as X_test. As read by this function, a leaf is a
% numeric label. An internal node is a cell array t where:
%   t{1}      - index of the feature (column of X_test) used for the split
%   t{j}, j>1 - a {value, subtree} pair
% For threshold nodes, t{2} is the "<= threshold" branch and t{3} is the
% "> threshold" branch. Both t{2}{1} and t{3}{1} hold the threshold.
%
% X_test is a matrix that contains data that shall be classified. It
% is structured like the training set, with one instance per row.
%
% X_type is a vector that specifies the type of features of the test
% set, which should be exactly the same as the vector used in tree_fit.
%   1 - split by exact value match (one branch per stored value)
%   2 - split by a single threshold
% Any other code raises the error 'tree_predict:badType'.
%
% Returns y_pred, a column vector that contains the class predictions for
% X_test. An instance that cannot be routed to a leaf gets the default
% label -1. This happens when its value matches no branch of a match
% split, or when it is NaN at a threshold split.

n = size(X_test, 1);
preds = cell(n, 1); % collected per instance, concatenated once at the end

for i = 1:n % Iterate over instances
    t = tree;
    while ~isnumeric(t) % Descend until a numeric leaf is reached
        feature = t{1};
        value = X_test(i, feature);
        switch X_type(feature)
            case 1
                % Follow the branch whose stored value equals the instance's
                % value. If no branch matches, the tree cannot label this
                % instance, so fall back to the default label -1.
                child = -1;
                for j = 2:numel(t) % Iterate over splits (subsets)
                    if t{j}{1} == value
                        child = t{j}{2};
                        break
                    end
                end
                t = child;
            case 2
                % NOTE: t{2}{1} and t{3}{1} store the same threshold value.
                if value <= t{2}{1}
                    t = t{2}{2};
                elseif value > t{3}{1}
                    t = t{3}{2};
                else
                    % Neither comparison holds (e.g. value is NaN). Without
                    % this branch t would never change and the loop would
                    % never terminate; use the default label instead.
                    t = -1;
                end
            otherwise
                % An unknown type code would otherwise loop forever.
                error('tree_predict:badType', ...
                    'Unsupported feature type %g for feature %d.', ...
                    X_type(feature), feature);
        end
    end
    preds{i} = t;
end

% vertcat over the collected leaves gives the same result (class and shape,
% including [] when X_test has no rows) as appending one leaf at a time,
% without reallocating the output on every iteration.
y_pred = vertcat(preds{:});
end