-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathXOR_network_with_one_hidden_unit.m
123 lines (99 loc) · 3.69 KB
/
XOR_network_with_one_hidden_unit.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Set up
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Training inputs: one row per example. Column 1 is the constant bias input,
% columns 2 and 3 enumerate every combination of the two binary XOR operands.
X = [ones(4, 1), [0, 0; 0, 1; 1, 0; 1, 1]];
% Desired output for each row of X: the XOR of its two operand columns.
t = double(xor(X(:, 2), X(:, 3)));
% Random starting weights of the output unit. It receives the three columns of
% X plus the activation of the single hidden unit, hence four weights.
theta_out = rand(4, 1);
% Random starting weights of the hidden unit, one per column of X.
theta_mid = rand(3, 1);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% performing the actual calculations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
fprintf('Count of incorrectly classified examples before training: %d\n', ...
        get_error_count(theta_out, theta_mid, t, X))

% Batch gradient descent with an implicit learning rate of 1: every iteration
% records the current cost and then moves both weight vectors against their
% gradients. Training stops once the most recently recorded cost is at or below
% cost_threshold.
cost_threshold = 0.1;
% Safety net: gradient descent from a random start is not guaranteed to reach
% the threshold, so the loop is bounded instead of being allowed to spin (and
% grow `costs`) forever.
max_iterations = 100000;
costs = [];
while (isempty(costs) || costs(end) > cost_threshold) ...
      && numel(costs) < max_iterations
  costs(end + 1) = get_cost(theta_out, theta_mid, t, X);
  [grad_out, grad_mid] = get_grad(theta_out, theta_mid, t, X);
  theta_out = theta_out - grad_out;
  theta_mid = theta_mid - grad_mid;
end
if costs(end) > cost_threshold
  warning('xor_network:notConverged', ...
          'Cost is still %g (> %g) after %d iterations; training was stopped.', ...
          costs(end), cost_threshold, max_iterations);
end

% Plot costs during training and provide summary information
plot(costs);
title('Cost vs iteration');
ylabel('cost');
xlabel('iteration #');
fprintf('Training completed after %d iterations\n', size(costs, 2));
fprintf('Count of incorrectly classified examples after training: %d\n', ...
        get_error_count(theta_out, theta_mid, t, X))
fprintf('\n <press any key to continue>\n')
% Keep the figure on screen until the user acknowledges it, then close it.
pause;
close;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% definitions of functions used in the calculations above
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [z_mid, z_out, a_mid, a_out] = forward_prop(theta_out, theta_mid, X)
% FORWARD_PROP  Run every row of X through the network.
%
%   theta_out  weights of the output unit (one per column of a_mid)
%   theta_mid  weights of the hidden unit (one per column of X)
%   X          input matrix, one example per row
%
%   z_mid  weighted input of the hidden unit, one entry per example
%   z_out  weighted input of the output unit, one entry per example
%   a_mid  everything fed into the output unit: X with the hidden unit's
%          activation appended as an extra last column
%   a_out  activation of the output unit, one entry per example
  z_mid = X * theta_mid;
  hidden_activation = sigmoid(z_mid);
  % The output unit sees the raw inputs directly as well as the hidden unit.
  a_mid = horzcat(X, hidden_activation);
  z_out = a_mid * theta_out;
  a_out = sigmoid(z_out);
end
function h = sigmoid(z)
% SIGMOID  Element-wise logistic function 1 / (1 + e^(-z)).
%
%   z  scalar, vector or matrix
%   h  array of the same size as z
  denominator = 1 + exp(-z);
  h = 1 ./ denominator;
end
function cost = get_cost(theta_out, theta_mid, t, X)
% GET_COST  Half the sum of squared differences between the targets t and the
% network outputs for the inputs X under the given weights.
  [~, ~, ~, a_out] = forward_prop(theta_out, theta_mid, X);
  residual = t - a_out;
  cost = 0.5 * residual' * residual;
end
function count = get_error_count(theta_out, theta_mid, t, X)
% GET_ERROR_COUNT  Number of examples whose thresholded network output
% (output > 0.5) differs from the corresponding target in t.
  [~, ~, ~, a_out] = forward_prop(theta_out, theta_mid, X);
  predicted = a_out > 0.5;
  is_wrong = predicted ~= t;
  count = sum(is_wrong);
end
function [grad_out, grad_mid] = get_grad(theta_out, theta_mid, t, X)
% GET_GRAD  Analytic gradient of the squared-error cost (see get_cost) with
% respect to both weight vectors, summed over all rows of X.
%
%   theta_out  column vector of output-unit weights; its last entry is the
%              weight on the hidden unit's activation
%   theta_mid  column vector of hidden-unit weights
%   t          column vector of targets, one per row of X
%   X          input matrix, one example per row
%
%   grad_out   d(cost)/d(theta_out), same size as theta_out
%   grad_mid   d(cost)/d(theta_mid), same size as theta_mid
%
% Works for any number of examples and input columns: the hidden unit is
% always addressed as the last column of a_mid / last entry of theta_out,
% which is where forward_prop puts it.
  [~, ~, a_mid, a_out] = forward_prop(theta_out, theta_mid, X);
  a_hidden = a_mid(:, end);

  % Per-example error signals. They are built from (t - a_out), i.e. with the
  % opposite sign of d(cost)/d(a_out), so the gradients below are negated.
  delta_out = (t - a_out) .* a_out .* (1 - a_out);
  delta_mid = delta_out * theta_out(end) .* a_hidden .* (1 - a_hidden);

  % Matrix products sum the per-example contributions over every row.
  grad_out = -(a_mid' * delta_out);
  grad_mid = -(X' * delta_mid);
end
function [grad_out, grad_mid] = get_numerical_grad(theta_out, theta_mid, t, X)
% GET_NUMERICAL_GRAD  Central-difference estimate of the gradient of get_cost
% with respect to every entry of theta_out and theta_mid.
%
% Each weight is moved by -step and +step in turn while all others stay fixed;
% the slope between the two resulting costs is the estimate for that weight.
% Useful as a reference to compare an analytic gradient against.
%
%   grad_out  same size as theta_out
%   grad_mid  same size as theta_mid
  step = 1e-4;

  grad_out = zeros(size(theta_out));
  for k = 1:numel(theta_out)
    saved = theta_out(k);
    theta_out(k) = saved - step;
    cost_below = get_cost(theta_out, theta_mid, t, X);
    theta_out(k) = saved + step;
    cost_above = get_cost(theta_out, theta_mid, t, X);
    % Put the weight back before touching the next one.
    theta_out(k) = saved;
    grad_out(k) = (cost_above - cost_below) / (2 * step);
  end

  grad_mid = zeros(size(theta_mid));
  for k = 1:numel(theta_mid)
    saved = theta_mid(k);
    theta_mid(k) = saved - step;
    cost_below = get_cost(theta_out, theta_mid, t, X);
    theta_mid(k) = saved + step;
    cost_above = get_cost(theta_out, theta_mid, t, X);
    % Put the weight back before touching the next one.
    theta_mid(k) = saved;
    grad_mid(k) = (cost_above - cost_below) / (2 * step);
  end
end