-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathperceptron
29 lines (26 loc) · 927 Bytes
/
perceptron
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np


class Perceptron:
    """Binary perceptron classifier that predicts the labels 1 and -1.

    The decision rule is ``1 if dot(model_w, x) + model_b >= 0 else -1``.
    Training applies the update ``w += eta * y_i * x_i`` and
    ``b += eta * y_i`` for every misclassified sample.

    Attributes:
        eta: Learning rate used to scale each update.
        n_iter: Maximum number of passes over the training data.
        model_w: Weight vector, stored as a float ``numpy`` array.
        model_b: Bias term.
    """

    def __init__(self, eta=.1, n_iter=10, model_w=(.0, .0), model_b=.0):
        """Store the hyper-parameters and the initial model.

        Args:
            eta: Learning rate.
            n_iter: Maximum number of epochs run by ``fit``.
            model_w: Initial weights (any sequence of numbers). The values
                are copied, so the caller's object is never modified.
            model_b: Initial bias.
        """
        self.eta = eta
        self.n_iter = n_iter
        # Copy into a fresh float array: a list default would be shared
        # between instances, and ``list += ndarray`` extends the list
        # instead of adding element-wise.
        self.model_w = np.array(model_w, dtype=float)
        self.model_b = model_b
        # Data from the most recent ``fit`` call; used by ``update_weights``
        # when it is called without explicit ``x`` / ``y``.
        self._x_train = None
        self._y_train = None

    def predict(self, x):
        """Classify one sample or a batch of samples.

        Args:
            x: A single feature vector, or a 2-D collection with one
                sample per row.

        Returns:
            ``1`` or ``-1`` for a single sample; an integer array of
            ``1`` / ``-1`` values for a batch.
        """
        # ``dot(x, w)`` equals ``dot(w, x)`` for vectors and also handles
        # a (n_samples, n_features) batch.
        activation = np.dot(x, self.model_w) + self.model_b
        if np.ndim(activation) == 0:
            return 1 if activation >= 0 else -1
        return np.where(activation >= 0, 1, -1)

    def update_weights(self, idx, model_w, model_b, x=None, y=None):
        """Return the weights and bias after one perceptron update.

        Args:
            idx: Index of the misclassified sample.
            model_w: Current weights; not modified.
            model_b: Current bias.
            x: Training samples. Defaults to the data given to the most
                recent ``fit`` call.
            y: Training labels. Defaults to the labels given to the most
                recent ``fit`` call.

        Returns:
            A ``(weights, bias)`` tuple with the update applied.

        Raises:
            ValueError: If no training data is supplied and ``fit`` has
                not been called yet.
        """
        x = self._x_train if x is None else x
        y = self._y_train if y is None else y
        if x is None or y is None:
            raise ValueError(
                'update_weights needs training data: pass x and y or '
                'call fit first')
        step = self.eta * y[idx]
        # Build new values instead of using ``+=`` so the caller's
        # weights are left untouched.
        w = np.asarray(model_w, dtype=float) + step * np.asarray(
            x[idx], dtype=float)
        b = model_b + step
        return w, b

    def fit(self, x, y):
        """Train the model on samples ``x`` with labels ``y``.

        Runs at most ``n_iter`` epochs and stops early once a full epoch
        produces no misclassification (further epochs would change
        nothing).

        Args:
            x: Indexable collection of feature vectors.
            y: Indexable collection of labels, same length as ``x``.

        Returns:
            ``False`` (after printing ``'error'``) when ``x`` and ``y``
            differ in length; otherwise ``None``.
        """
        if len(x) != len(y):
            print('error')
            return False
        self._x_train, self._y_train = x, y
        for _ in range(self.n_iter):
            updated = False
            for idx in range(len(x)):
                if y[idx] != self.predict(x[idx]):
                    self.model_w, self.model_b = self.update_weights(
                        idx, self.model_w, self.model_b, x=x, y=y)
                    updated = True
            if not updated:
                break