This document walks through the code of a simple Android mobile application that demonstrates image classification using the device camera.
In the following sections, we walk through the most important parts of the sample code.
This mobile application gets the camera input using the functions defined in the
file
CameraActivity.java
.
This file depends on
AndroidManifest.xml
to set the camera orientation.
CameraActivity
also contains code to capture user preferences from the UI and
make them available to other classes via convenience methods.
model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase());
device = Device.valueOf(deviceSpinner.getSelectedItem().toString());
numThreads = Integer.parseInt(threadsTextView.getText().toString().trim());
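For reference, Model and Device are simple enums; a minimal sketch with members inferred from the model and device handling shown later in this walkthrough (the exact definitions live alongside the Classifier code):

/** The model type used for classification. */
public enum Model {
  FLOAT,
  QUANTIZED
}

/** The runtime device used for executing classification. */
public enum Device {
  CPU,
  NNAPI,
  GPU
}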
This Image Classification Android reference app demonstrates two implementation
solutions:
lib_task_api
, which leverages the out-of-the-box API from the
TensorFlow Lite Task Library,
and
lib_support
, which builds a custom inference pipeline using the
TensorFlow Lite Support Library.
Both solutions implement the file Classifier.java
(see
the one in lib_task_api
and
the one in lib_support)
that contains most of the complex logic for processing the camera input and
running inference.
Two subclasses of Classifier exist, ClassifierFloatMobileNet.java
and ClassifierQuantizedMobileNet.java
, which contain the settings for the
floating point and
quantized
models, respectively.
The Classifier
class implements a static method, create
, which is used to
instantiate the appropriate subclass based on the supplied model type (quantized
vs floating point).
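The create factory roughly dispatches on the model type; a minimal sketch, assuming the two subclasses listed above (see Classifier.java for the exact code):

public static Classifier create(Activity activity, Model model, Device device, int numThreads)
    throws IOException {
  if (model == Model.QUANTIZED) {
    // Quantized MobileNet: uint8 weights and activations.
    return new ClassifierQuantizedMobileNet(activity, device, numThreads);
  } else {
    // Float MobileNet: float32 weights and activations.
    return new ClassifierFloatMobileNet(activity, device, numThreads);
  }
}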
Inference can be done using just a few lines of code with the
ImageClassifier
in the TensorFlow Lite Task Library.
ImageClassifier
expects a model populated with the
model metadata and the label
file. See the
model compatibility requirements
for more details.
ImageClassifierOptions
allows configuring various inference options, such
as setting the maximum number of top-scored results to return using
setMaxResults(MAX_RESULTS)
, and setting the score threshold using
setScoreThreshold(scoreThreshold)
.
// Create the ImageClassifier instance.
ImageClassifierOptions options =
ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build();
imageClassifier = ImageClassifier.createFromFileAndOptions(activity,
getModelPath(), options);
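To also filter out low-confidence results, the score threshold can be set on the same builder; for example (the 0.5f value here is purely illustrative):

ImageClassifierOptions options =
    ImageClassifierOptions.builder()
        .setMaxResults(MAX_RESULTS)
        .setScoreThreshold(0.5f) // Only return categories scoring at least 0.5.
        .build();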
ImageClassifier
currently does not support configuring delegates or
multithreading, but these are on our roadmap. Please stay tuned!
ImageClassifier
contains built-in logic to preprocess the input image, such as
rotating and resizing an image. Processing options can be configured through
ImageProcessingOptions
. In the following example, input images are rotated to
their upright orientation and center-cropped, since the model expects a square input
(224x224
). See the
Java doc of ImageClassifier
for more details about how the underlying image processing is performed.
TensorImage inputImage = TensorImage.fromBitmap(bitmap);
int width = bitmap.getWidth();
int height = bitmap.getHeight();
int cropSize = min(width, height);
ImageProcessingOptions imageOptions =
ImageProcessingOptions.builder()
.setOrientation(getOrientation(sensorOrientation))
// Set the ROI to the center of the image.
.setRoi(
new Rect(
/*left=*/ (width - cropSize) / 2,
/*top=*/ (height - cropSize) / 2,
/*right=*/ (width + cropSize) / 2,
/*bottom=*/ (height + cropSize) / 2))
.build();
List<Classifications> results = imageClassifier.classify(inputImage,
imageOptions);
The output of ImageClassifier
is a list of Classifications
instances, where
each Classifications
element is the result for a single classification head. All the
demo models are single-head models; therefore, results
contains only one
Classifications
object. Use Classifications.getCategories()
to get a list of the
top-k categories as specified with MAX_RESULTS
. Each Category
object
contains the string label and the score of that category.
To match the implementation of
lib_support
,
results
is converted into a List<Recognition>
in the getRecognitions
method.
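A minimal sketch of that conversion, assuming the four-argument Recognition constructor (id, title, confidence, location) used elsewhere in the sample:

private static List<Recognition> getRecognitions(List<Classifications> classifications) {
  final ArrayList<Recognition> recognitions = new ArrayList<>();
  // All demo models are single-head, so there is only one Classifications element,
  // but iterating keeps the conversion general.
  for (Classifications classificationResult : classifications) {
    for (Category category : classificationResult.getCategories()) {
      recognitions.add(
          new Recognition(
              /*id=*/ "" + category.getLabel(),
              /*title=*/ category.getLabel(),
              /*confidence=*/ category.getScore(),
              /*location=*/ null));
    }
  }
  return recognitions;
}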
To perform inference, we need to load a model file and instantiate an
Interpreter
. This happens in the constructor of the Classifier
class, along
with loading the list of class labels. Information about the device type and
number of threads is used to configure the Interpreter
via the
Interpreter.Options
instance passed into its constructor. Note that if a GPU,
DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a
Delegate
can be used
to take full advantage of this hardware.
Please note that there are performance edge cases, and developers are advised to test with a representative set of devices prior to production.
protected Classifier(Activity activity, Device device, int numThreads) throws
IOException {
tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
switch (device) {
case NNAPI:
nnApiDelegate = new NnApiDelegate();
tfliteOptions.addDelegate(nnApiDelegate);
break;
case GPU:
gpuDelegate = new GpuDelegate();
tfliteOptions.addDelegate(gpuDelegate);
break;
case CPU:
break;
}
tfliteOptions.setNumThreads(numThreads);
tflite = new Interpreter(tfliteModel, tfliteOptions);
labels = FileUtil.loadLabels(activity, getLabelPath());
...
For Android devices, we recommend pre-loading and memory mapping the model file
to offer faster load times and reduce the dirty pages in memory. The method
FileUtil.loadMappedFile
does this, returning a MappedByteBuffer
containing
the model.
The MappedByteBuffer
is passed into the Interpreter
constructor, along with
an Interpreter.Options
object. This object can be used to configure the
interpreter, for example by setting the number of threads (.setNumThreads(1)
)
or enabling NNAPI
(.addDelegate(nnApiDelegate)
).
Next, in the Classifier
class, we take the input camera bitmap image,
convert it to a TensorImage
format for efficient processing, and pre-process
it. The steps are shown in the private loadImage method:
/** Loads input image, and applies preprocessing. */
private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) {
  // Loads bitmap into a TensorImage.
  inputImageBuffer.load(bitmap);
  // Creates processor for the TensorImage.
  int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
  int numRotation = sensorOrientation / 90;
  ImageProcessor imageProcessor =
      new ImageProcessor.Builder()
          .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
          .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR))
          .add(new Rot90Op(numRotation))
          .add(getPreprocessNormalizeOp())
          .build();
  return imageProcessor.process(inputImageBuffer);
}
The pre-processing is largely the same for quantized and float models with one exception: Normalization.
In ClassifierFloatMobileNet
, the normalization parameters are defined as:
private static final float IMAGE_MEAN = 127.5f;
private static final float IMAGE_STD = 127.5f;
In ClassifierQuantizedMobileNet
, normalization is not required. Thus the
normalization parameters are defined as:
private static final float IMAGE_MEAN = 0.0f;
private static final float IMAGE_STD = 1.0f;
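These constants are consumed by the NormalizeOp that loadImage adds via getPreprocessNormalizeOp(). A minimal sketch of how each subclass can supply that step (the post-processing counterpart, getPostprocessNormalizeOp(), follows the same pattern with PROBABILITY_MEAN and PROBABILITY_STD, discussed below):

@Override
protected TensorOperator getPreprocessNormalizeOp() {
  // Normalizes each pixel value as (value - IMAGE_MEAN) / IMAGE_STD.
  return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
}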
Next, create the output TensorBuffer
that will hold the output of the model.
/** Output probability TensorBuffer. */
private final TensorBuffer outputProbabilityBuffer;
//...
// Get the array size for the output buffer from the TensorFlow Lite model file
int probabilityTensorIndex = 0;
int[] probabilityShape =
tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001}
DataType probabilityDataType =
tflite.getOutputTensor(probabilityTensorIndex).dataType();
// Creates the output tensor and its processor.
outputProbabilityBuffer =
TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
// Creates the post processor for the output probability.
probabilityProcessor =
new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build();
For quantized models, we need to de-quantize the prediction with the NormalizeOp (since de-quantization is essentially a linear transformation). For float models, de-quantization is not required, but to keep the API uniform, the same step is applied to float models too, with the mean and std set to 0.0f and 1.0f, respectively. To be more specific:
In ClassifierQuantizedMobileNet
, the normalization parameters are defined as:
private static final float PROBABILITY_MEAN = 0.0f;
private static final float PROBABILITY_STD = 255.0f;
In ClassifierFloatMobileNet
, the normalization parameters are defined as:
private static final float PROBABILITY_MEAN = 0.0f;
private static final float PROBABILITY_STD = 1.0f;
Inference is performed using the following in Classifier
class:
tflite.run(inputImageBuffer.getBuffer(),
outputProbabilityBuffer.getBuffer().rewind());
Rather than call run
directly, the method recognizeImage
is used. It accepts
a bitmap and sensor orientation, runs inference, and returns a sorted List
of
Recognition
instances, each corresponding to a label. The method will return a
number of results bounded by MAX_RESULTS
, which is 3 by default.
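Putting the pieces together, recognizeImage looks roughly like the following sketch (field names follow the snippets in this document; the actual method in Classifier.java also records timing information):

public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
  // Pre-processes the bitmap (crop, resize, rotate, normalize).
  inputImageBuffer = loadImage(bitmap, sensorOrientation);

  // Runs inference; the output buffer is rewound so it can be reused across frames.
  tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());

  // Maps the output probabilities to labels and keeps the top-k results.
  Map<String, Float> labeledProbability =
      new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
          .getMapWithFloatValue();
  return getTopKProbability(labeledProbability);
}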
Recognition
is a simple class that contains information about a specific
recognition result, including its title
and confidence
. Using the
post-processing normalization method specified, the confidence is converted into a
value between 0 and 1 that represents the probability of the given class being present in the image.
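For reference, Recognition is essentially a small value class; a minimal sketch, assuming the four-argument constructor used in getTopKProbability below:

public static class Recognition {
  private final String id;        // Unique identifier for what has been recognized.
  private final String title;     // Display name shown in the UI.
  private final Float confidence; // Score between 0 and 1 for the recognition.
  private RectF location;         // Optional bounding box; unused for classification.

  public Recognition(String id, String title, Float confidence, RectF location) {
    this.id = id;
    this.title = title;
    this.confidence = confidence;
    this.location = location;
  }

  public String getTitle() {
    return title;
  }

  public Float getConfidence() {
    return confidence;
  }
}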
/** Gets the label to probability map. */
Map<String, Float> labeledProbability =
new TensorLabel(labels,
probabilityProcessor.process(outputProbabilityBuffer))
.getMapWithFloatValue();
A PriorityQueue
is used for sorting.
/** Gets the top-k results. */
private static List<Recognition> getTopKProbability(
Map<String, Float> labelProb) {
// Find the best classifications.
PriorityQueue<Recognition> pq =
new PriorityQueue<>(
MAX_RESULTS,
new Comparator<Recognition>() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
// Intentionally reversed to put high confidence at the head of
// the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
pq.add(new Recognition("" + entry.getKey(), entry.getKey(),
entry.getValue(), null));
}
final ArrayList<Recognition> recognitions = new ArrayList<>();
int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
for (int i = 0; i < recognitionsSize; ++i) {
recognitions.add(pq.poll());
}
return recognitions;
}
The classifier is invoked and inference results are displayed by the
processImage()
function in
ClassifierActivity.java
.
ClassifierActivity
is a subclass of CameraActivity
that contains method
implementations that render the camera image, run classification, and display
the results. The method processImage()
runs classification on a background
thread as fast as possible, rendering information on the UI thread to avoid
blocking inference and creating latency.
@Override
protected void processImage() {
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth,
previewHeight);
final int imageSizeX = classifier.getImageSizeX();
final int imageSizeY = classifier.getImageSizeY();
runInBackground(
new Runnable() {
@Override
public void run() {
if (classifier != null) {
final long startTime = SystemClock.uptimeMillis();
final List<Classifier.Recognition> results =
classifier.recognizeImage(rgbFrameBitmap, sensorOrientation);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
LOGGER.v("Detect: %s", results);
runOnUiThread(
new Runnable() {
@Override
public void run() {
showResultsInBottomSheet(results);
showFrameInfo(previewWidth + "x" + previewHeight);
showCropInfo(imageSizeX + "x" + imageSizeY);
showCameraResolution(imageSizeX + "x" + imageSizeY);
showRotationInfo(String.valueOf(sensorOrientation));
showInference(lastProcessingTimeMs + "ms");
}
});
}
readyForNextImage();
}
});
}
Another important role of ClassifierActivity
is to determine user preferences
(by interrogating CameraActivity
), and instantiate the appropriately
configured Classifier
subclass. This happens when the video feed begins (via
onPreviewSizeChosen()
) and when options are changed in the UI (via
onInferenceConfigurationChanged()
).
private void recreateClassifier(Model model, Device device, int numThreads) {
if (classifier != null) {
LOGGER.d("Closing classifier.");
classifier.close();
classifier = null;
}
if (device == Device.GPU && model == Model.QUANTIZED) {
LOGGER.d("Not creating classifier: GPU doesn't support quantized models.");
runOnUiThread(
() -> {
Toast.makeText(this, "GPU does not yet support quantized models.",
Toast.LENGTH_LONG)
.show();
});
return;
}
try {
LOGGER.d(
"Creating classifier (model=%s, device=%s, numThreads=%d)", model,
device, numThreads);
classifier = Classifier.create(this, model, device, numThreads);
} catch (IOException e) {
LOGGER.e(e, "Failed to create classifier.");
}
}
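For completeness, a sketch of how recreateClassifier is typically invoked when the user changes an option in the UI; getModel(), getDevice() and getNumThreads() are assumed names for the CameraActivity convenience methods mentioned earlier (see ClassifierActivity.java for the exact code):

@Override
public void onInferenceConfigurationChanged() {
  if (rgbFrameBitmap == null) {
    // Defer classifier creation until the preview size is known.
    return;
  }
  final Device device = getDevice();
  final Model model = getModel();
  final int numThreads = getNumThreads();
  runInBackground(() -> recreateClassifier(model, device, numThreads));
}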