From fb8d20b498e0594286540d340d3c8d134d3a3e30 Mon Sep 17 00:00:00 2001 From: Vinh Nguyen Date: Thu, 1 Aug 2019 06:03:44 +0000 Subject: [PATCH] adding GPU automatic mixed precision training support --- gcn/models.py | 23 +++++++++++++++++++++-- gcn/train.py | 2 ++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/gcn/models.py b/gcn/models.py index 68c814b..35a0e9c 100644 --- a/gcn/models.py +++ b/gcn/models.py @@ -1,3 +1,5 @@ +import os + from gcn.layers import * from gcn.metrics import * @@ -93,7 +95,16 @@ def __init__(self, placeholders, input_dim, **kwargs): self.placeholders = placeholders self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) - + + if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1' or \ + ('gpu_auto_mixed_precision' in FLAGS and FLAGS.gpu_auto_mixed_precision): + tf_version_list = tf.__version__.split(".") + if int(tf_version_list[0]) < 2: + if int(tf_version_list[1]) < 14: + raise RuntimeError("TensorFlow>=1.14 is required for automatic precision.") + print("=============Enabling GPU Automatic Mixed Precision=============") + self.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(self.optimizer) + self.build() def _loss(self): @@ -140,7 +151,15 @@ def __init__(self, placeholders, input_dim, **kwargs): self.placeholders = placeholders self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) - + if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1' or \ + ('gpu_auto_mixed_precision' in FLAGS and FLAGS.gpu_auto_mixed_precision): + tf_version_list = tf.__version__.split(".") + if int(tf_version_list[0]) < 2: + if int(tf_version_list[1]) < 14: + raise RuntimeError("TensorFlow>=1.14 is required for automatic precision.") + print("=============Enabling GPU Automatic Mixed Precision=============") + self.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(self.optimizer) + self.build() def _loss(self): diff --git a/gcn/train.py b/gcn/train.py index 3c9bb2d..ab29ca8 100644 --- a/gcn/train.py +++ b/gcn/train.py @@ -24,6 +24,8 @@ flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.') flags.DEFINE_integer('early_stopping', 10, 'Tolerance for early stopping (# of epochs).') flags.DEFINE_integer('max_degree', 3, 'Maximum Chebyshev polynomial degree.') +flags.DEFINE_bool("gpu_auto_mixed_precision", default=False, + help="Enabling GPU automatic mixed precision training.") # Load data adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(FLAGS.dataset)