-
Notifications
You must be signed in to change notification settings - Fork 0
/
Perceptron.java
76 lines (68 loc) · 2.27 KB
/
Perceptron.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
/**
 * A single perceptron: a linear binary classifier over a fixed number of
 * real-valued inputs.
 *
 * <p>The perceptron keeps one weight per input, all initially zero. It predicts
 * {@code +1} when the weighted sum of an input vector is strictly positive and
 * {@code -1} otherwise, and it adjusts its weights only when a training example
 * is misclassified.
 */
public class Perceptron {
    private final int n;       // number of inputs
    private final double[] m;  // weight vector, one weight per input

    /**
     * Creates a perceptron with {@code n} inputs and all weights set to zero.
     *
     * @param n the number of inputs
     * @throws NegativeArraySizeException if {@code n} is negative
     */
    public Perceptron(int n) {
        this.n = n;
        this.m = new double[n];
    }

    /**
     * Returns the number of inputs this perceptron was created with.
     *
     * @return the number of inputs
     */
    public int numberOfInputs() {
        return n;
    }

    /**
     * Returns the weighted sum (dot product) of the weight vector and {@code x}.
     *
     * <p>Only the first {@code n} entries of {@code x} are read; any additional
     * entries are ignored.
     *
     * @param x the input vector
     * @return the sum of {@code weight[i] * x[i]} over all inputs
     * @throws ArrayIndexOutOfBoundsException if {@code x} has fewer entries
     *         than this perceptron has inputs
     */
    public double weightedSum(double[] x) {
        double sum = 0.0;
        for (int i = 0; i < m.length; i++) {
            sum += m[i] * x[i];
        }
        return sum;
    }

    /**
     * Predicts the label of the input {@code x}.
     *
     * @param x the input vector
     * @return {@code 1} if the weighted sum is strictly positive, otherwise
     *         {@code -1} (a weighted sum of exactly zero yields {@code -1})
     */
    public int predict(double[] x) {
        return weightedSum(x) > 0 ? 1 : -1;
    }

    /**
     * Trains this perceptron on the labeled input {@code x}.
     *
     * <p>If the current prediction equals {@code label}, the weights are left
     * unchanged. Otherwise {@code x} is subtracted from the weights when the
     * prediction was {@code +1}, and added to them when it was {@code -1}.
     *
     * @param x     the input vector
     * @param label the expected label ({@code +1} or {@code -1})
     */
    public void train(double[] x, int label) {
        int predicted = predict(x);
        if (predicted == label) {
            return; // correctly classified: nothing to learn
        }
        // predict() only ever returns +1 or -1, so the direction of the
        // correction is fully determined by the sign of the prediction.
        boolean predictedPositive = predicted == 1;
        for (int i = 0; i < m.length; i++) {
            if (predictedPositive) {
                m[i] -= x[i];
            } else {
                m[i] += x[i];
            }
        }
    }

    /**
     * Returns a string representation of this perceptron: the weights in
     * order, separated by {@code ", "} and enclosed in parentheses, for
     * example {@code (5.0, -4.0, 3.0)}. A perceptron with no inputs is
     * rendered as {@code ()}.
     *
     * @return the weight vector formatted as a parenthesized list
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("(");
        for (int i = 0; i < m.length; i++) {
            // Separator goes between elements only, so an empty weight
            // vector needs no trailing-text trimming.
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(m[i]);
        }
        return sb.append(')').toString();
    }

    /**
     * Tests this class by directly calling all instance methods: trains a
     * 3-input perceptron on four labeled examples and prints the weights
     * before training and after each step.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        double[] training1 = { 5.0, -4.0, 3.0 };   // yes
        double[] training2 = { 2.0, 3.0, -2.0 };   // no
        double[] training3 = { 4.0, 3.0, 2.0 };    // yes
        double[] training4 = { -6.0, -5.0, 7.0 };  // no

        int n = 3;
        Perceptron perceptron = new Perceptron(n);
        // Printing goes through System.out so the demo needs only the JDK.
        System.out.println(perceptron);
        perceptron.train(training1, +1);
        System.out.println(perceptron);
        perceptron.train(training2, -1);
        System.out.println(perceptron);
        perceptron.train(training3, +1);
        System.out.println(perceptron);
        perceptron.train(training4, -1);
        System.out.println(perceptron);
    }
}