This project demonstrates how to use TensorFlow Mobile on Android to classify handwritten digits from MNIST.
A prebuilt APK can be downloaded from here.
If you are interested in a TensorFlow Lite version, please refer to tflite-mnist-android.
- Python 3.6, TensorFlow 1.8.0
- Android Studio 3.0, Gradle 4.1
The model is defined in mnist.py. Run the following command to train it:
```bash
python train.py --model_dir ./saved_model --iterations 10000
```
train.py uses a simple convolutional neural network. train_bn.py provides a bigger network with batch normalization, which is expected to reach roughly 99.5% accuracy on the validation set within 10000 iterations, as shown below.
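To illustrate what batch normalization adds in train_bn.py, here is a minimal pure-Python sketch of the standard technique (not the project's code): during training each feature is normalized with the current batch's mean and variance and a running average of those statistics is kept; at inference time the running averages are used instead. The function names, `momentum`, and `eps` values are illustrative assumptions.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def batch_norm_train(batch: Matrix, gamma: List[float], beta: List[float],
                     eps: float = 1e-3) -> Tuple[Matrix, List[float], List[float]]:
    """Normalize each feature (column) over the batch, then scale and shift."""
    n = len(batch)
    num_features = len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(num_features)]
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n
                 for j in range(num_features)]
    out = [[gamma[j] * (row[j] - means[j]) / math.sqrt(variances[j] + eps) + beta[j]
            for j in range(num_features)] for row in batch]
    return out, means, variances


def update_moving(moving: List[float], batch_stat: List[float],
                  momentum: float = 0.99) -> List[float]:
    """Exponential moving average of a batch statistic (mean or variance)."""
    return [momentum * m + (1.0 - momentum) * b for m, b in zip(moving, batch_stat)]


def batch_norm_infer(x: List[float], moving_mean: List[float], moving_var: List[float],
                     gamma: List[float], beta: List[float], eps: float = 1e-3) -> List[float]:
    """Inference-time normalization of one example using the stored moving statistics."""
    return [gamma[j] * (x[j] - moving_mean[j]) / math.sqrt(moving_var[j] + eps) + beta[j]
            for j in range(len(x))]
```

Because inference uses fixed statistics, the normalization becomes a per-feature affine transform, which is why it can later be simplified away from the training-only parts of the graph.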
After training, a collection of checkpoint files and a frozen GraphDef file, mnist.pb, will be generated in ./saved_model.
You can evaluate the model on the test set with the command below:
```bash
python test.py --model_dir ./saved_model
```
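Test-set accuracy for a classifier like this is the fraction of examples whose highest-scoring class matches the label. The sketch below shows that computation in plain Python; it assumes per-example score lists and integer class labels, and the helper names are hypothetical rather than taken from test.py.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(all_scores: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of examples whose predicted class equals the integer label."""
    if len(all_scores) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of predictions and labels")
    correct = sum(1 for scores, label in zip(all_scores, labels) if argmax(scores) == label)
    return correct / len(labels)
```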
A pre-trained model can be downloaded from here.
TensorFlow provides optimize_for_inference.py, which optimizes the model by removing the parts of the graph that are only needed for training.
Navigate to the TensorFlow repository directory and run the following command to optimize the model:
```bash
python tensorflow/python/tools/optimize_for_inference.py \
  --input=model_path/mnist.pb \
  --output=output_path/mnist_optimized.pb \
  --input_names=x \
  --output_names=output
```
The --input argument should point to the TensorFlow GraphDef file (mnist.pb) produced in Step 1. The --output argument specifies the location of the optimized model.
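Part of what the optimization does is keep only the nodes that the requested outputs actually depend on, treating the named inputs as the graph's entry points. The sketch below models a graph as a plain dict from node name to input names and performs that backward reachability pass. It is a simplified illustration of the pruning idea only, not the tool's implementation, and it ignores the tool's other rewrites; the toy graph and function names are hypothetical.

```python
from typing import Dict, List, Set

Graph = Dict[str, List[str]]  # node name -> names of its input nodes


def nodes_needed(graph: Graph, output_names: List[str], input_names: List[str]) -> Set[str]:
    """Walk backwards from the outputs; stop descending at the declared inputs."""
    keep: Set[str] = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        if name in input_names:
            continue  # inputs are entry points, so their own inputs are not needed
        stack.extend(graph.get(name, []))
    return keep


def strip_unused(graph: Graph, output_names: List[str], input_names: List[str]) -> Graph:
    """Return a pruned copy of the graph containing only inference-relevant nodes."""
    keep = nodes_needed(graph, output_names, input_names)
    return {name: ([] if name in input_names else list(inputs))
            for name, inputs in graph.items() if name in keep}


# Hypothetical toy graph mixing inference nodes with training-only nodes.
TOY_GRAPH: Graph = {
    "x": [],
    "w": [],
    "matmul": ["x", "w"],
    "output": ["matmul"],
    "labels": [],
    "loss": ["output", "labels"],
    "gradients": ["loss"],
    "train_op": ["gradients", "w"],
}
```

With `input_names=["x"]` and `output_names=["output"]`, the loss, gradient, and training nodes are dropped because nothing on the path to `output` uses them.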
Note that the mnist.pb generated by train.py is already frozen; otherwise, we would have to freeze the graph with freeze_graph.py before optimizing it.
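Freezing replaces each variable in the graph with a constant holding its trained value from the checkpoint, so the GraphDef can be loaded on its own without checkpoint files. The sketch below shows only that variable-to-constant step on a simplified node model; the `Node` class and the op strings `"Variable"` and `"Const"` are illustrative stand-ins, not TensorFlow's protobuf types or freeze_graph.py's code.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Node:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)
    value: Any = None  # only meaningful for constants


def freeze(nodes: List[Node], checkpoint: Dict[str, Any]) -> List[Node]:
    """Return new nodes in which every variable is replaced by a constant."""
    frozen: List[Node] = []
    for node in nodes:
        if node.op == "Variable":
            if node.name not in checkpoint:
                raise KeyError(f"no checkpoint value for variable {node.name!r}")
            # Bake the trained value into the graph as a constant with no inputs.
            frozen.append(Node(node.name, "Const", [], checkpoint[node.name]))
        else:
            frozen.append(node)
    return frozen
```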
An optimized model file can be downloaded from here.
Copy the mnist_optimized.pb generated in Step 2 to /android/app/src/main/assets, then build and run the app.
The Classifier creates a TensorFlowInferenceInterface from mnist_optimized.pb. TensorFlowInferenceInterface provides an interface for inference and performance summarization, and it is included in the following library:
```gradle
implementation "org.tensorflow:tensorflow-android:1.8.0"
```
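Before the Classifier can feed a drawing to the model's input node x, the canvas has to be reduced to MNIST's 28x28 format. The app's actual preprocessing is not described here, so the following Python sketch (Python only for illustration; the app itself runs on Android) rests on assumptions: a square grayscale canvas with values 0 to 255, white strokes on black like MNIST, a side length divisible by 28, and a model that takes a flat vector of 784 floats in [0, 1]. It box-averages each block of canvas pixels into one input pixel.

```python
from typing import List


def downsample_to_mnist(canvas: List[List[int]], size: int = 28) -> List[float]:
    """Box-average a square 0-255 grayscale canvas to size x size, scaled to [0, 1]."""
    n = len(canvas)
    if n == 0 or n % size != 0 or any(len(row) != n for row in canvas):
        raise ValueError("canvas must be square with a side divisible by size")
    block = n // size
    pixels: List[float] = []
    for by in range(size):
        for bx in range(size):
            total = 0
            for y in range(by * block, (by + 1) * block):
                for x in range(bx * block, (bx + 1) * block):
                    total += canvas[y][x]
            # Average the block, then map 0..255 to 0.0..1.0.
            pixels.append(total / (block * block) / 255.0)
    return pixels  # row-major, length size * size
```

Box averaging is one simple choice; a real app might instead use the platform's bitmap scaling, which can produce slightly different pixel values.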
- The basic model architecture comes from tensorflow-mnist-tutorial.
- The official TensorFlow Mobile Android demo.
- The FingerPaint sample from the Android API demos.