import tensorflow as tf
import numpy as np
import random
import Model
from TrajectoryLoader import TrajectoryLoader


# parameters for training
learning_rate = 0.001
num_batches = 100000
batch_size = 256
display_step = 50
# parameters for the seq2seq model
n_lstm = 128
encoder_length = 20
decoder_length = 10

# available attention score functions
attention_func1 = 'dot'
attention_func2 = 'general'
attention_func3 = 'concat'

# Choose the Adam optimizer.
optimizer = tf.keras.optimizers.Adam(learning_rate)

# Create the encoder and the attention decoder (attention_func2 = 'general').
encoder_a = Model.Encoder(n_lstm, batch_size)
decoder_a = Model.DecoderAttention(n_lstm, batch_size, attention_func2)

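# Run one dummy batch through both models so their weights are built and
# summary() can report them; from the calls below, encoder_a is assumed to
# return (encoder_outputs, state_h, state_c).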
x = np.zeros((batch_size, 1, 5), dtype=np.float32)
output = encoder_a(x)
decoder_a(x, output[1:], output[0])
encoder_a.summary()
decoder_a.summary()

# TensorBoard writer; trace the graph and enable the profiler.
summary_writer = tf.summary.create_file_writer('tensorboard')
tf.summary.trace_on(profiler=True)
# Checkpoint managers for the encoder and the decoder.
checkpoint1 = tf.train.Checkpoint(EncoderAttention=encoder_a)
manager1 = tf.train.CheckpointManager(checkpoint1, directory='./SaveEncoderAttention',
                                      checkpoint_name='EncoderAttention.ckpt', max_to_keep=10)
checkpoint2 = tf.train.Checkpoint(DecoderAttention=decoder_a)
manager2 = tf.train.CheckpointManager(checkpoint2, directory='./SaveDecoderAttention',
                                      checkpoint_name='DecoderAttention.ckpt', max_to_keep=10)
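
# To resume an interrupted run, the latest checkpoints could be restored
# before training, e.g.:
#     checkpoint1.restore(manager1.latest_checkpoint)
#     checkpoint2.restore(manager2.latest_checkpoint)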


def RunOptimization(source_seq, target_seq_in, target_seq_out, step):
    loss = 0
    decoder_length = target_seq_out.shape[1]
    with tf.GradientTape() as tape:
        encoder_outputs = encoder_a(source_seq)
        states = encoder_outputs[1:]
        y_sample = 0
        for t in range(decoder_length):
            # Scheduled sampling (disabled): feed the decoder its own previous
            # prediction on a random half of the steps.
            # if t == 0 or random.randint(0, 1) == 0:
            #     decoder_in = tf.expand_dims(target_seq_in[:, t], 1)
            # else:
            #     decoder_in = tf.expand_dims(y_sample, 1)
            # Teacher forcing: always feed the ground-truth step.
            decoder_in = tf.expand_dims(target_seq_in[:, t], 1)
            logit, de_state_h, de_state_c, _ = decoder_a(decoder_in, states, encoder_outputs[0])
            # TODO: scheduled sampling
            y_sample = logit
            states = de_state_h, de_state_c
            # loss function: RMSE over feature columns 1:3 of each step (TODO)
            loss_0 = tf.keras.losses.MSE(target_seq_out[:, t, 1:3], logit[:, 1:3])
            loss += tf.sqrt(loss_0)  # TODO

    variables = encoder_a.trainable_variables + decoder_a.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    # Average over the batch and the decoded steps for logging.
    loss = tf.reduce_mean(loss)
    loss = loss / decoder_length
    with summary_writer.as_default():
        tf.summary.scalar("loss", loss.numpy(), step=step)

    return loss
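
# A minimal sketch for the scheduled-sampling TODO above, assuming y_sample
# (the previous prediction) has the same feature layout as target_seq_in;
# it would replace the teacher-forcing line inside the loop:
#
#     use_truth = t == 0 or random.randint(0, 1) == 0
#     decoder_in = tf.expand_dims(target_seq_in[:, t] if use_truth else y_sample, 1)
#
# so the decoder sees its own output on roughly half of the steps.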


# Load trajectory data.
seq2seq_loader = TrajectoryLoader()
seq2seq_loader.loadTrajectoryData("./DataSet/TrajectoryMillion.csv")
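# getBatchSeq2Seq is expected to yield an encoder batch of shape
# (batch_size, encoder_length, n_features) and a decoder batch with
# decoder_length + 1 steps, so it can be shifted into input/target pairs below.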


for batch_index in range(1, num_batches + 1):
    seq_encoder, seq_decoder = seq2seq_loader.getBatchSeq2Seq(batch_size, encoder_length, decoder_length)
    # Shift the decoder sequence by one step: input covers steps 0..decoder_length-1,
    # target covers steps 1..decoder_length (teacher forcing).
    seq_decoder_in = seq_decoder[:, :decoder_length, :]
    seq_decoder_out = seq_decoder[:, 1:decoder_length + 1, :]
    loss = RunOptimization(seq_encoder, seq_decoder_in, seq_decoder_out, batch_index)

    if batch_index % display_step == 0:
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
        path1 = manager1.save(checkpoint_number=batch_index)
        path2 = manager2.save(checkpoint_number=batch_index)

with summary_writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir='tensorboard')
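
# The loss curve and the exported trace can then be inspected with:
#     tensorboard --logdir tensorboard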