diff --git a/jax/experimental/jax2tf/examples/tflite/README.md b/jax/experimental/jax2tf/examples/tflite/README.md
deleted file mode 100644
index e8ef038ac377..000000000000
--- a/jax/experimental/jax2tf/examples/tflite/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-This directory contains examples of using the jax2tf converter to produce
-models that can be used with TensorFlow Lite.
-Note that this is still highly experimental.
diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/README.md b/jax/experimental/jax2tf/examples/tflite/mnist/README.md
deleted file mode 100644
index f39bd9c7ea9f..000000000000
--- a/jax/experimental/jax2tf/examples/tflite/mnist/README.md
+++ /dev/null
@@ -1,138 +0,0 @@
-# Image classification on the MNIST dataset for on-device machine learning inference
-
-This directory contains code that demonstrates how to:
-
-1. Train a simple convolutional neural network on the MNIST dataset; and
-2. Use jax2tf, which stages JAX programs out as TensorFlow graphs, to export the
-   trained model as a TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model)
-   that can then be converted into the [TensorFlow Lite (TF Lite)](https://www.tensorflow.org/lite/)
-   format for on-device machine learning inference.
-
-This example is based on the [TensorFlow Lite demo](https://developer.android.com/codelabs/digit-classifier-tflite)
-of building a handwritten digit classifier app for Android with TensorFlow Lite.
-The model training code is based on the [Flax Linen example](https://github.com/google/flax/tree/master/linen_examples/mnist)
-for classification on the MNIST dataset.
-
-## Requirements
-
-* This example uses [Flax](http://github.com/google/flax) and
-  [TensorFlow Datasets](https://www.tensorflow.org/datasets).
-  The code downloads the MNIST dataset from TensorFlow Datasets and preprocesses
-  it before feeding it to the neural network as input.
-
-## Training the model
-
-To train the model in JAX with the Flax library, convert it into a
-TensorFlow SavedModel, and export it into the TF Lite format,
-run the following command:
-
-```shell
-python mnist.py
-```
-
-The training itself should take about 1 minute, assuming the
-dataset has already been downloaded.
-The dataset is downloaded into a `data/` directory
-under `/tmp/jax2tf/mnist`.
-
-This example's training loop runs for 10 epochs for
-demonstration purposes. After training, the model is converted to
-a SavedModel and then to the TF Lite format.
-In the end, it is saved as `mnist.tflite` under `/tmp/jax2tf/mnist/`.
-
-### Convert to a TensorFlow function using the jax2tf converter
-
-You can write a prediction function for the trained model and use
-the jax2tf converter to convert it to a TF function as follows:
-
-```python
-def predict(image):
-  return CNN().apply({'params': optimizer.target}, image)
-```
-```python
-# Convert your Flax model to a TF function.
-tf_predict = tf.function(
-    jax2tf.convert(predict, enable_xla=False),
-    input_signature=[
-        tf.TensorSpec(shape=[1, 28, 28, 1], dtype=tf.float32, name='input')
-    ],
-    autograph=False)
-```
-
-Note the `enable_xla=False` parameter to `jax2tf.convert`.
-It instructs the converter to avoid a few special TensorFlow ops that are
-only available with the XLA compiler and that are not (yet) understood by
-the TF Lite converter used below.
-
-Check out [more details about this limitation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md),
-including which JAX primitives it applies to.
-
-### Convert the trained model to the TF Lite format
-
-The [TF Lite converter](https://www.tensorflow.org/lite/convert#python_api_)
-provides a number of ways to convert a TensorFlow model to the TF Lite format.
-For example, you can use the `from_concrete_functions` API as follows:
-
-```python
-# Convert your TF function to the TF Lite format.
-converter = tf.lite.TFLiteConverter.from_concrete_functions(
-    [tf_predict.get_concrete_function()], tf_predict)
-
-converter.target_spec.supported_ops = [
-    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
-    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
-]
-tflite_float_model = converter.convert()
-```
-
-Note that `tf.lite.OpsSet.SELECT_TF_OPS` must be included in the supported
-ops so that regular TensorFlow ops (as opposed to TensorFlow Lite builtin ops)
-can run in the TF Lite runtime.
-
-### Apply quantization
-
-The next step is to perform [quantization](https://www.tensorflow.org/lite/performance/post_training_quantization).
-Because the converted model is no different from one converted from a regular
-TensorFlow SavedModel, you can apply quantization in the usual way:
-
-```python
-# Re-convert the model to TF Lite using quantization.
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-tflite_quantized_model = converter.convert()
-```
-
-## Deploy the TF Lite model to your Android app
-
-To deploy the TF Lite model to your Android app, follow the
-instructions in the [Build a handwritten digit classifier app](https://developer.android.com/codelabs/digit-classifier-tflite)
-codelab.
-
-When using a TF Lite model that has been converted with support for select
-TensorFlow ops, the client must also use a TF Lite runtime that includes the
-necessary library of TensorFlow ops.
-In this example, you can use the standard
-TensorFlow [AAR](https://developer.android.com/studio/projects/android-library#aar-contents) file
-by specifying the [nightly dependencies](https://www.tensorflow.org/lite/guide/ops_select#android_aar):
-
-```
-dependencies {
-    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-    // This dependency adds the necessary TF op support.
-    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly'
-}
-```
-
-You can also specify `abiFilters` to reduce the size of the TensorFlow op dependencies:
-
-```
-android {
-    defaultConfig {
-        ndk {
-            abiFilters 'armeabi-v7a', 'arm64-v8a'
-        }
-    }
-}
-```
-
diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py
deleted file mode 100644
index 71f2eebee2a4..000000000000
--- a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2020 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from absl import app
-from absl import flags
-
-from jax.experimental import jax2tf
-from jax.experimental.jax2tf.examples import mnist_lib
-
-import numpy as np
-
-import tensorflow as tf
-import tensorflow_datasets as tfds  # type: ignore[import-not-found]
-
-_TFLITE_FILE_PATH = flags.DEFINE_string(
-    'tflite_file_path',
-    '/tmp/mnist.tflite',
-    'Path where to save the TensorFlow Lite file.',
-)
-_SERVING_BATCH_SIZE = flags.DEFINE_integer(
-    'serving_batch_size',
-    4,
-    'For what batch size to prepare the serving signature.',
-)
-_NUM_EPOCHS = flags.DEFINE_integer(
-    'num_epochs', 10, 'For how many epochs to train.'
-)
-
-
-# A helper function to evaluate the TF Lite model using the "test" dataset.
-def evaluate_tflite_model(tflite_model, test_ds):
-  # Initialize the TF Lite interpreter using the model.
-  interpreter = tf.lite.Interpreter(model_content=tflite_model)
-  interpreter.allocate_tensors()
-  input_tensor_index = interpreter.get_input_details()[0]['index']
-  output = interpreter.tensor(interpreter.get_output_details()[0]['index'])
-
-  # Run predictions on every image in the "test" dataset.
-  prediction_digits = []
-  labels = []
-  for image, one_hot_label in test_ds:
-    interpreter.set_tensor(input_tensor_index, image)
-
-    # Run inference.
-    interpreter.invoke()
-
-    # Post-processing: for each example in the batch, find the digit with the
-    # highest probability.
-    digits = np.argmax(output(), axis=1)
-    prediction_digits.extend(digits)
-    labels.extend(np.argmax(one_hot_label, axis=1))
-
-  # Compare prediction results with ground truth labels to calculate accuracy.
-  accurate_count = 0
-  for index in range(len(prediction_digits)):
-    if prediction_digits[index] == labels[index]:
-      accurate_count += 1
-  accuracy = accurate_count * 1.0 / len(prediction_digits)
-  return accuracy
-
-
-def main(_):
-  logging.info('Loading the MNIST TensorFlow dataset')
-  train_ds = mnist_lib.load_mnist(
-      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
-  test_ds = mnist_lib.load_mnist(
-      tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE.value)
-
-  (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
-      train_ds, test_ds, _NUM_EPOCHS.value
-  )
-
-  def predict(image):
-    return flax_predict(flax_params, image)
-
-  # Convert the Flax model to a TF function.
-  tf_predict = tf.function(
-      jax2tf.convert(predict, enable_xla=False),
-      input_signature=[
-          tf.TensorSpec(
-              shape=[_SERVING_BATCH_SIZE.value, 28, 28, 1],
-              dtype=tf.float32,
-              name='input')
-      ],
-      autograph=False)
-
-  # Convert the TF function to the TF Lite format.
-  converter = tf.lite.TFLiteConverter.from_concrete_functions(
-      [tf_predict.get_concrete_function()], tf_predict)
-  converter.target_spec.supported_ops = [
-      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
-      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
-  ]
-  tflite_float_model = converter.convert()
-
-  # Show the model size in KBs.
-  float_model_size = len(tflite_float_model) / 1024
-  print('Float model size = %dKBs.' % float_model_size)
-
-  # Re-convert the model to TF Lite using quantization.
-  converter.optimizations = [tf.lite.Optimize.DEFAULT]
-  tflite_quantized_model = converter.convert()
-
-  # Show the model size in KBs.
-  quantized_model_size = len(tflite_quantized_model) / 1024
-  print('Quantized model size = %dKBs,' % quantized_model_size)
-  print('which is about %d%% of the float model size.' %
-        (quantized_model_size * 100 / float_model_size))
-
-  # Evaluate the TF Lite float model. You'll find that its accuracy is
-  # identical to the original Flax model because they are essentially the
-  # same model stored in a different format.
-  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
-  print('Float model accuracy = %.4f' % float_accuracy)
-
-  # Evaluate the TF Lite quantized model.
-  # Don't be surprised if the quantized model's accuracy is higher than
-  # the original float model's. It happens sometimes :)
-  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
-  print('Quantized model accuracy = %.4f' % quantized_accuracy)
-  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
-
-  with open(_TFLITE_FILE_PATH.value, 'wb') as f:
-    f.write(tflite_quantized_model)
-
-
-if __name__ == '__main__':
-  app.run(main)
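For reference, running the exported model needs only the TF Lite interpreter. Below is a minimal sketch of loading the `.tflite` file produced by `mnist.py` and running one batch of predictions; it assumes the default flag values above (the quantized model written to `/tmp/mnist.tflite`, serving batch size 4) and substitutes a zero-filled dummy batch for real, preprocessed MNIST images.

```python
import numpy as np
import tensorflow as tf

# Load the quantized model written by mnist.py (default --tflite_file_path).
interpreter = tf.lite.Interpreter(model_path='/tmp/mnist.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# The serving signature expects a float32 batch of shape [4, 28, 28, 1];
# an all-zeros batch stands in for real images here.
images = np.zeros(input_details['shape'], dtype=np.float32)
interpreter.set_tensor(input_details['index'], images)
interpreter.invoke()

# One row of class scores per image in the batch.
logits = interpreter.get_tensor(output_details['index'])
print('Predicted digits:', np.argmax(logits, axis=1))
```

These are the same interpreter calls used by `evaluate_tflite_model` above, so real inputs can be fed the same way by swapping the dummy batch for images loaded via `mnist_lib.load_mnist`.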