|
2 | 2 | import tensorflow as tf
|
3 | 3 | tf.set_random_seed(777) # for reproducibility
|
4 | 4 |
|
5 |
| -x_train = [1, 2, 3] |
6 |
| -y_train = [1, 2, 3] |
7 |
| - |
8 | 5 | # Try to find values for W and b to compute y_data = W * x_data + b
|
9 | 6 | # We know that W should be 1 and b should be 0
|
10 | 7 | # But let's use TensorFlow to figure it out
|
|
14 | 11 | # Now we can use X and Y in place of x_data and y_data
|
15 | 12 | # # placeholders for a tensor that will be always fed using feed_dict
|
16 | 13 | # See http://stackoverflow.com/questions/36693740/
|
17 |
| -X = tf.placeholder(tf.float32) |
18 |
| -Y = tf.placeholder(tf.float32) |
| 14 | +X = tf.placeholder(tf.float32, shape=[None]) |
| 15 | +Y = tf.placeholder(tf.float32, shape=[None]) |
19 | 16 |
|
20 | 17 | # Our hypothesis XW+b
|
21 | 18 | hypothesis = X * W + b
|
|
34 | 31 |
|
35 | 32 | # Fit the line
|
36 | 33 | for step in range(2001):
|
37 |
| - sess.run(train, feed_dict={X: x_train, Y: y_train}) |
| 34 | + cost_val, W_val, b_val, _ = \ |
| 35 | + sess.run([cost, W, b, train], |
| 36 | + feed_dict={X: [1, 2, 3], Y: [1, 2, 3]}) |
38 | 37 | if step % 20 == 0:
|
39 |
| - print(step, sess.run(cost, feed_dict={ |
40 |
| - X: x_train, Y: y_train}), sess.run(W), sess.run(b)) |
| 38 | + print(step, cost_val, W_val, b_val) |
41 | 39 |
|
42 | 40 | # Learns best fit W:[ 1.], b:[ 0]
|
43 | 41 | '''
|
|
47 | 45 | '''
|
48 | 46 |
|
49 | 47 | # Testing our model
|
50 |
| -print(sess.run(hypothesis, feed_dict={X: 5})) |
51 |
| -print(sess.run(hypothesis, feed_dict={X: 2.5})) |
| 48 | +print(sess.run(hypothesis, feed_dict={X: [5]})) |
| 49 | +print(sess.run(hypothesis, feed_dict={X: [2.5]})) |
| 50 | +print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]})) |
52 | 51 |
|
53 | 52 | '''
|
54 | 53 | [ 5.0110054]
|
55 | 54 | [ 2.50091505]
|
| 55 | +[ 1.49687922 3.50495124] |
| 56 | +''' |
| 57 | + |
| 58 | + |
| 59 | +# Fit the line with new training data |
| 60 | +for step in range(2001): |
| 61 | + cost_val, W_val, b_val, _ = \ |
| 62 | + sess.run([cost, W, b, train], |
| 63 | + feed_dict={X: [1, 2, 3, 4, 5], |
| 64 | + Y: [2.1, 3.1, 4.1, 5.1, 6.1]}) |
| 65 | + if step % 20 == 0: |
| 66 | + print(step, cost_val, W_val, b_val) |
| 67 | + |
| 68 | +# Testing our model |
| 69 | +print(sess.run(hypothesis, feed_dict={X: [5]})) |
| 70 | +print(sess.run(hypothesis, feed_dict={X: [2.5]})) |
| 71 | +print(sess.run(hypothesis, feed_dict={X: [1.5, 3.5]})) |
| 72 | + |
| 73 | +''' |
| 74 | +1960 3.32396e-07 [ 1.00037301] [ 1.09865296] |
| 75 | +1980 2.90429e-07 [ 1.00034881] [ 1.09874094] |
| 76 | +2000 2.5373e-07 [ 1.00032604] [ 1.09882331] |
| 77 | +[ 6.10045338] |
| 78 | +[ 3.59963846] |
| 79 | +[ 2.59931231 4.59996414] |
56 | 80 | '''
|
0 commit comments