diff --git a/models_generation/distilbert.py b/models_generation/distilbert.py index e761dc5..5a70841 100644 --- a/models_generation/distilbert.py +++ b/models_generation/distilbert.py @@ -6,6 +6,14 @@ input_spec = tf.TensorSpec([1, 384], tf.int32) model._set_inputs(input_spec, training=False) + +# For tensorflow>2.2.0, set inputs in the following way. +# Otherwise, the model.inputs and model.outputs will be None. +# keras_input = tf.keras.Input([384], batch_size=1, dtype=tf.int32) +# keras_output = model(keras_input, training=False) +# model = tf.keras.Model(keras_input, keras_output) + + print(model.inputs) print(model.outputs) diff --git a/models_generation/gpt2.py b/models_generation/gpt2.py index 60d267c..6d2d2e4 100644 --- a/models_generation/gpt2.py +++ b/models_generation/gpt2.py @@ -6,6 +6,12 @@ input_spec = tf.TensorSpec([1, 64], tf.int32) model._set_inputs(input_spec, training=False) +# For tensorflow>2.2.0, set inputs in the following way. +# Otherwise, the model.inputs and model.outputs will be None. +# keras_input = tf.keras.Input([64], batch_size=1, dtype=tf.int32) +# keras_output = model(keras_input, training=False) +# model = tf.keras.Model(keras_input, keras_output) + print(model.inputs) print(model.outputs)