You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
def plot_gradient_descent(theta, eta, *, n_epochs=1000, n_shown=20):
    """Run batch gradient descent for linear regression and plot its progress.

    The predictions on ``X_new`` for the first ``n_shown`` epochs are drawn
    as lines coloured along the ``OrRd`` colormap (later epochs darker).
    The training data ``(X, y)`` is then drawn on top of them so the points
    stay visible.

    Relies on module-level names defined by the surrounding notebook/script:
    ``X_b``, ``X_new_b``, ``X_new``, ``X``, ``y``, ``plt`` (matplotlib.pyplot)
    and ``mpl`` (matplotlib).

    Args:
        theta: Initial parameter vector. Presumably shaped to match the
            columns of ``X_b`` (e.g. (n_features + 1, 1)); TODO confirm
            against the caller.
        eta: Learning rate (step size); also shown in the plot title.
        n_epochs: Number of full-batch gradient steps (default 1000).
        n_shown: Number of initial epochs whose prediction line is drawn
            (default 20).

    Returns:
        List of the parameter vectors after each epoch (``n_epochs`` items).
    """
    m = len(X_b)
    theta_path = []
    for epoch in range(n_epochs):
        if epoch < n_shown:
            y_predict = X_new_b @ theta
            # The +0.15 offset presumably skips the near-white low end of
            # OrRd so the earliest lines are still visible.
            color = mpl.colors.rgb2hex(plt.cm.OrRd(epoch / n_shown + 0.15))
            plt.plot(X_new, y_predict, linestyle="solid", color=color)
        # Gradient of the MSE cost: (2 / m) * X_b^T (X_b @ theta - y).
        gradients = 2 / m * X_b.T @ (X_b @ theta - y)
        theta = theta - eta * gradients
        theta_path.append(theta)
    # Plot the dataset AFTER the prediction lines so the points are drawn
    # on top of them and remain clearly visible in the figure.
    plt.plot(X, y, "b.")
    plt.xlabel("$x_1$")
    plt.axis([0, 2, 0, 15])
    plt.grid()
    plt.title(fr"$\eta = {eta}$")
    return theta_path
The text was updated successfully, but these errors were encountered:
def plot_gradient_descent(theta, eta):
m = len(X_b)
The text was updated successfully, but these errors were encountered: