diff --git a/musco/tf/compressor/layers/group_conv_2d.py b/musco/tf/compressor/layers/group_conv_2d.py index d3c4d19..c4893c9 100644 --- a/musco/tf/compressor/layers/group_conv_2d.py +++ b/musco/tf/compressor/layers/group_conv_2d.py @@ -5,7 +5,10 @@ from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec +if tf.__version__.startswith('2.'): + from tensorflow.keras.layers import InputSpec +else: + from tensorflow.python.keras.engine import InputSpec from tensorflow.python.keras.utils import conv_utils from tensorflow.python.ops import nn diff --git a/musco/tf/optimizer/trt.py b/musco/tf/optimizer/trt.py index c87147d..5bc0f38 100644 --- a/musco/tf/optimizer/trt.py +++ b/musco/tf/optimizer/trt.py @@ -1,7 +1,10 @@ import os import gc import tensorflow as tf -import tensorflow.contrib.tensorrt as trt +if tf.__version__.startswith('2.'): + import tensorflow.experimental.tensorrt as trt +else: + import tensorflow.contrib.tensorrt as trt from tensorflow.python.framework import graph_io