Skip to content

Commit f293c86

Browse files
committed
more interesting ecamples
1 parent 97a2b03 commit f293c86

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

lab-02-2-linear_regression_feed.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
import tensorflow as tf
33
tf.set_random_seed(777) # for reproducibility
44

5-
x_train = [1, 2, 3]
6-
y_train = [1, 2, 3]
7-
85
# Try to find values for W and b to compute y_data = W * x_data + b
96
# We know that W should be 1 and b should be 0
107
# But let's use TensorFlow to figure it out
@@ -14,8 +11,8 @@
1411
# Now we can use X and Y in place of x_data and y_data
1512
# # placeholders for a tensor that will be always fed using feed_dict
1613
# See http://stackoverflow.com/questions/36693740/
17-
X = tf.placeholder(tf.float32)
18-
Y = tf.placeholder(tf.float32)
14+
X = tf.placeholder(tf.float32, shape=[None])
15+
Y = tf.placeholder(tf.float32, shape=[None])
1916

2017
# Our hypothesis XW+b
2118
hypothesis = X * W + b
@@ -34,10 +31,11 @@
3431

3532
# Fit the line
3633
for step in range(2001):
37-
sess.run(train, feed_dict={X: x_train, Y: y_train})
34+
cost_val, W_val, b_val, _ = \
35+
sess.run([cost, W, b, train],
36+
feed_dict={X: [1, 2, 3], Y: [1, 2, 3]})
3837
if step % 20 == 0:
39-
print(step, sess.run(cost, feed_dict={
40-
X: x_train, Y: y_train}), sess.run(W), sess.run(b))
38+
print(step, cost_val, W_val, b_val)
4139

4240
# Learns best fit W:[ 1.], b:[ 0]
4341
'''
@@ -47,10 +45,36 @@
4745
'''
4846

4947
# Testing our model
50-
print(sess.run(hypothesis, feed_dict={X: 5}))
51-
print(sess.run(hypothesis, feed_dict={X: 2.5}))
48+
print(sess.run(hypothesis, feed_dict={X: [5]}))
49+
print(sess.run(hypothesis, feed_dict={X: [2.5]}))
50+
print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]}))
5251

5352
'''
5453
[ 5.0110054]
5554
[ 2.50091505]
55+
[ 1.49687922 3.50495124]
56+
'''
57+
58+
59+
# Fit the line with new training data
60+
for step in range(2001):
61+
cost_val, W_val, b_val, _ = \
62+
sess.run([cost, W, b, train],
63+
feed_dict={X: [1, 2, 3, 4, 5],
64+
Y: [2.1, 3.1, 4.1, 5.1, 6.1]})
65+
if step % 20 == 0:
66+
print(step, cost_val, W_val, b_val)
67+
68+
# Testing our model
69+
print(sess.run(hypothesis, feed_dict={X: [5]}))
70+
print(sess.run(hypothesis, feed_dict={X: [2.5]}))
71+
print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]}))
72+
73+
'''
74+
1960 3.32396e-07 [ 1.00037301] [ 1.09865296]
75+
1980 2.90429e-07 [ 1.00034881] [ 1.09874094]
76+
2000 2.5373e-07 [ 1.00032604] [ 1.09882331]
77+
[ 6.10045338]
78+
[ 3.59963846]
79+
[ 2.59931231 4.59996414]
5680
'''

0 commit comments

Comments
 (0)