-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiPerceptron.java
77 lines (68 loc) · 2.5 KB
/
MultiPerceptron.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 * A multi-class classifier built from one {@code Perceptron} per class.
 *
 * <p>Training is one-vs-all: for a labeled example, the perceptron for the labeled class is
 * trained with target {@code +1} and every other perceptron with target {@code -1}. Prediction
 * returns the class whose perceptron reports the largest weighted sum.
 */
public class MultiPerceptron {
    private final int n;                 // number of inputs (feature-vector length)
    private final int m;                 // number of classes / perceptrons
    private final Perceptron[] storage;  // storage[i] is the perceptron for class i

    /**
     * Creates a multi-perceptron with {@code m} classes and {@code n} inputs.
     *
     * @param m number of classes; one perceptron is created per class
     * @param n number of inputs each perceptron accepts
     */
    public MultiPerceptron(int m, int n) {
        this.n = n;
        this.m = m;
        this.storage = new Perceptron[m];
        for (int i = 0; i < storage.length; i++) {
            storage[i] = new Perceptron(n);
        }
    }

    /**
     * Returns the number of classes.
     *
     * @return the number of classes {@code m}
     */
    public int numberOfClasses() {
        return m;
    }

    /**
     * Returns the number of inputs (length of the feature vector).
     *
     * @return the number of inputs {@code n}
     */
    public int numberOfInputs() {
        return n;
    }

    /**
     * Returns the predicted label for the given input.
     *
     * <p>The label is the index of the perceptron with the largest weighted sum. On ties the
     * lowest index wins, because only a strictly greater score replaces the current best.
     *
     * @param x the feature vector
     * @return the predicted label, between 0 and m-1
     * @throws IllegalStateException if this multi-perceptron has no classes
     */
    public int predictMulti(double[] x) {
        if (storage.length == 0) {
            throw new IllegalStateException("cannot predict with a multi-perceptron of 0 classes");
        }
        int best = 0;
        double bestScore = storage[0].weightedSum(x);
        // Each perceptron's score is computed exactly once.
        for (int i = 1; i < storage.length; i++) {
            double score = storage[i].weightedSum(x);
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    /**
     * Trains this multi-perceptron on one labeled input (one-vs-all).
     *
     * @param x     the feature vector
     * @param label the correct class, between 0 and m-1
     * @throws IllegalArgumentException if {@code label} is outside [0, m-1]; checked before any
     *                                  perceptron is updated, so a bad label leaves state unchanged
     */
    public void trainMulti(double[] x, int label) {
        if (label < 0 || label >= m) {
            throw new IllegalArgumentException(
                    "label must be between 0 and " + (m - 1) + ", got " + label);
        }
        for (int i = 0; i < storage.length; i++) {
            // The labeled class is the positive example; all others are negative.
            storage[i].train(x, i == label ? 1 : -1);
        }
    }

    /**
     * Returns a string representation of this multi-perceptron.
     *
     * <p>The result is the perceptrons' string forms, comma-separated, inside parentheses. It is
     * {@code "()"} when there are no classes.
     *
     * @return the string representation
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("(");
        for (int i = 0; i < storage.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(storage[i]);
        }
        return sb.append(')').toString();
    }

    /**
     * Small demo: builds a 2-class, 3-input multi-perceptron and prints it after each of four
     * calls to {@code trainMulti}. It exercises the constructor, {@code trainMulti} and
     * {@code toString}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int m = 2;
        int n = 3;
        double[] training1 = { 5.0, -4.0, 3.0 };
        double[] training2 = { 2.0, 3.0, -2.0 };
        double[] training3 = { 4.0, 3.0, 2.0 };
        double[] training4 = { -6.0, -5.0, 7.0 };
        MultiPerceptron perceptron = new MultiPerceptron(m, n);
        StdOut.println(perceptron);
        perceptron.trainMulti(training1, 1);
        StdOut.println(perceptron);
        perceptron.trainMulti(training2, 0);
        StdOut.println(perceptron);
        perceptron.trainMulti(training3, 1);
        StdOut.println(perceptron);
        perceptron.trainMulti(training4, 0);
        StdOut.println(perceptron);
    }
}