-
Notifications
You must be signed in to change notification settings - Fork 1
/
sgd.py
32 lines (23 loc) · 955 Bytes
/
sgd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Author : Youness Landa
# The SGD w/ momentum algorithm for network optimization
class SGD:
    """Stochastic gradient descent with momentum.

    Each call to :meth:`optim` applies, pairwise over a list of parameters::

        v <- beta * v - lr * g
        w <- w + v

    Parameters may be plain numbers or array-like objects supporting
    ``*``, ``-`` and ``+`` (e.g. NumPy arrays).
    """

    def __init__(self, lr, beta=0.9):
        """Create the optimizer.

        Args:
            lr: Learning rate that scales each gradient.
            beta: Momentum coefficient applied to the previous velocity
                (default 0.9; with 0 the step reduces to plain SGD,
                ``w <- w - lr * g``).
        """
        self.beta = beta
        self.lr = lr

    def optim(self, weights, gradients, velocities=None):
        """Perform one momentum step and return the updated state.

        The inputs are not modified: new weight and velocity objects are
        returned, so callers should keep both results for the next step.

        Args:
            weights: Iterable of current parameters.
            gradients: Iterable of gradients, aligned with ``weights``.
            velocities: Velocities returned by the previous call, or
                ``None`` on the first step (treated as all zeros).

        Returns:
            A ``(new_weights, new_velocities)`` tuple of lists.

        Raises:
            ValueError: If the inputs do not all have the same length.
        """
        # Materialise once so a generator is not exhausted by the velocity
        # initialisation before the update loop gets to read it.
        weights = list(weights)
        gradients = list(gradients)
        if velocities is None:
            velocities = [0] * len(weights)
        if len(gradients) != len(weights):
            raise ValueError(
                f"expected {len(weights)} gradients, got {len(gradients)}"
            )
        velocities = self.update_velocities(gradients, velocities)
        # Out-of-place add: `weight += velocity` would mutate caller-owned
        # arrays and fails for integer arrays receiving float updates.
        new_weights = [
            weight + velocity for weight, velocity in zip(weights, velocities)
        ]
        return new_weights, velocities

    def update_velocities(self, gradients, velocities):
        """Return the velocities for the next step.

        Computes ``beta * velocity - lr * gradient`` for each aligned
        (gradient, velocity) pair.

        Args:
            gradients: Iterable of gradients.
            velocities: Iterable of previous velocities, aligned with
                ``gradients``.

        Returns:
            A new list of velocities.

        Raises:
            ValueError: If ``gradients`` and ``velocities`` differ in length.
        """
        gradients = list(gradients)
        velocities = list(velocities)
        if len(gradients) != len(velocities):
            raise ValueError(
                f"expected {len(velocities)} gradients, got {len(gradients)}"
            )
        return [
            self.beta * velocity - self.lr * gradient
            for gradient, velocity in zip(gradients, velocities)
        ]