Skip to content

Commit

Permalink
[tflite] Add profiling related APIs to tfjs-tflite (#5911)
Browse files Browse the repository at this point in the history
* [tflite] Add profiling related APIs to tfjs-tflite

* test

* fix

* address comments
  • Loading branch information
jinjingforever authored Dec 2, 2021
1 parent a2b5ccf commit 23f6ebd
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 17 deletions.
27 changes: 23 additions & 4 deletions tfjs-tflite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,35 @@ if SIMD is not enabled). Without multi-threading support, certain models might
not achieve the best performance. See [here][cross origin setup steps] for the
high-level steps to set up the cross-origin isolation.

Setting the number of threads when calling `loadTFLiteModel` can also help with
the performance. In most cases, the threads count should be the same as the
number of physical cores, which is half of `navigator.hardwareConcurrency` on
many x86-64 processors.
By default, the runtime uses the number of physical cores as the thread count.
You can tune this number by setting the `numThreads` option when loading the
TFLite model:

```js
const tfliteModel = await tflite.loadTFLiteModel(
'path/to/your/my_model.tflite',
{numThreads: navigator.hardwareConcurrency / 2});
```

# Profiling

Profiling can be enabled by setting the `enableProfiling` option to true when
loading the TFLite model:

```js
const tfliteModel = await tflite.loadTFLiteModel(
'path/to/your/my_model.tflite',
{enableProfiling: true});
```

Once it is enabled, the runtime will record per-op latency data when the
`predict` method is called. The profiling results can be retrieved in two ways:

- `tfliteModel.getProfilingResults()`: this method will return an array of
`{nodeType, nodeName, nodeExecMs}`.
- `tfliteModel.getProfilingSummary()`: this method will return a human-readable
profiling result summary that looks like [this][profiling summary].

# Development

## Building
Expand All @@ -130,3 +148,4 @@ $ yarn build-npm
[xnnpack]: https://github.com/google/XNNPACK
[xnnpack doc]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#limitations-and-supported-operators
[cross origin setup steps]: https://github.com/tensorflow/tfjs/tree/master/tfjs-backend-wasm#setting-up-cross-origin-isolation
[profiling summary]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/benchmark/README.md#profiling-model-operators
2 changes: 1 addition & 1 deletion tfjs-tflite/scripts/download-tflite-web-api.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ set -e
OUTPUT_DIR="$1"

# The default version.
CURRENT_VERSION=0.0.5
CURRENT_VERSION=0.0.6

# Get the version from the second parameter.
# Default to the value in CURRENT_VERSION.
Expand Down
21 changes: 10 additions & 11 deletions tfjs-tflite/src/tflite_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

import {DataType, InferenceModel, ModelPredictConfig, ModelTensorInfo, NamedTensorMap, tensor, Tensor} from '@tensorflow/tfjs-core';

import {getDefaultNumThreads} from './tflite_task_library_client/common';
import * as tfliteWebAPIClient from './tflite_web_api_client';
import {TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
import {ProfileItem, TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';

const TFHUB_SEARCH_PARAM = '?lite-format=tflite';

Expand Down Expand Up @@ -178,6 +177,14 @@ export class TFLiteModel implements InferenceModel {
throw new Error('execute() of TFLiteModel is not supported yet.');
}

/**
 * Returns the per-op profiling entries recorded by the underlying model
 * runner. Presumably empty unless the model was loaded with
 * `enableProfiling: true` — see TFLiteWebModelRunnerOptions.
 */
getProfilingResults(): ProfileItem[] {
  // Straight delegation to the wrapped TFLiteWebModelRunner.
  const profilingEntries = this.modelRunner.getProfilingResults();
  return profilingEntries;
}

/**
 * Returns a human-readable summary of the recorded profiling data from the
 * underlying model runner. Presumably only meaningful when profiling was
 * enabled at load time.
 */
getProfilingSummary(): string {
  // Straight delegation to the wrapped TFLiteWebModelRunner.
  const summaryText = this.modelRunner.getProfilingSummary();
  return summaryText;
}

private setModelInputFromTensor(
modelInput: TFLiteWebModelRunnerTensorInfo, tensor: Tensor) {
// String and complex tensors are not supported.
Expand Down Expand Up @@ -317,17 +324,9 @@ export async function loadTFLiteModel(
model = `${model}${TFHUB_SEARCH_PARAM}`;
}

// Process options.
const curOptions: TFLiteWebModelRunnerOptions = {};
if (options && options.numThreads !== undefined) {
curOptions.numThreads = options.numThreads;
} else {
curOptions.numThreads = await getDefaultNumThreads();
}

const tfliteModelRunner =
await tfliteWebAPIClient.tfweb.TFLiteWebModelRunner.create(
model, curOptions);
model, options);
return new TFLiteModel(tfliteModelRunner);
}

Expand Down
10 changes: 9 additions & 1 deletion tfjs-tflite/src/tflite_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import * as tf from '@tensorflow/tfjs-core';
import {DataType, NamedTensorMap} from '@tensorflow/tfjs-core';

import {TFLiteModel} from './tflite_model';
import {TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
import {ProfileItem, TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';

// A mock TFLiteWebModelRunner that doubles the data from input tensors to
// output tensors during inference.
Expand Down Expand Up @@ -61,6 +61,14 @@ class MockModelRunner implements TFLiteWebModelRunner {

cleanUp() {}

getProfilingResults(): ProfileItem[] {
  // The mock performs no real inference, so there is no profiling data to
  // report.
  const noProfilingData: ProfileItem[] = [];
  return noProfilingData;
}

getProfilingSummary(): string {
  // No profiling happens in the mock; report an empty summary.
  const emptySummary = '';
  return emptySummary;
}

private getTensorInfos(firstTensorType: TFLiteDataType = 'int32'):
TFLiteWebModelRunnerTensorInfo[] {
const shape0 = [1, 2, 3];
Expand Down
40 changes: 40 additions & 0 deletions tfjs-tflite/src/types/tflite_web_model_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,31 @@ export declare interface TFLiteWebModelRunner extends BaseTaskLibrary {
* @return Whether the inference is successful or not.
*/
infer(): boolean;

/**
* Gets per-node profiling results.
*
* This is only useful when TFLiteWebModelRunnerOptions.enableProfiling is
* set to true.
*/
getProfilingResults(): ProfileItem[];

/**
* Gets the profiling summary.
*
* This is only useful when TFLiteWebModelRunnerOptions.enableProfiling is
* set to true.
*/
getProfilingSummary(): string;
}

/** A single per-node entry in the profiling results returned by
 * TFLiteWebModelRunner.getProfilingResults. */
export declare interface ProfileItem {
/** The type of the node, e.g. "CONV_2D". */
nodeType: string;
/** The name of the node, e.g. "MobilenetV1/MobilenetV1/Conv2d_0/Relu6". */
nodeName: string;
/** The execution time (in ms) of the node. */
nodeExecMs: number;
}

/** Options for TFLiteWebModelRunner. */
Expand All @@ -61,6 +86,21 @@ export declare interface TFLiteWebModelRunnerOptions {
* not supported by user's browser.
*/
numThreads?: number;
/**
* Whether to enable profiling.
*
* Defaults to false. After it is enabled, the profiling results can be
* retrieved by calling TFLiteWebModelRunner.getProfilingResults or
* TFLiteWebModelRunner.getProfilingSummary. See their comments for more
* details.
*/
enableProfiling?: boolean;
/**
* Maximum number of entries that the profiler can keep.
*
* Defaults to 1024.
*/
maxProfilingBufferEntries?: number;
}

/** Types of TFLite tensor data. */
Expand Down

0 comments on commit 23f6ebd

Please sign in to comment.