diff --git a/tfjs-tflite/README.md b/tfjs-tflite/README.md
index e3ee1638cca..391c00b2806 100644
--- a/tfjs-tflite/README.md
+++ b/tfjs-tflite/README.md
@@ -93,10 +93,9 @@ if SIMD is not enabled). Without multi-threading support, certain models might
 not achieve the best performance. See [here][cross origin setup steps] for the
 high-level steps to set up the cross-origin isolation.
 
-Setting the number of threads when calling `loadTFLiteModel` can also help with
-the performance. In most cases, the threads count should be the same as the
-number of physical cores, which is half of `navigator.hardwareConcurrency` on
-many x86-64 processors.
+By default, the runtime uses the number of physical cores as the thread count.
+You can tune this number by setting the `numThreads` option when loading the
+TFLite model:
 
 ```js
 const tfliteModel = await tflite.loadTFLiteModel(
@@ -104,6 +103,25 @@ const tfliteModel = await tflite.loadTFLiteModel(
     {numThreads: navigator.hardwareConcurrency / 2});
 ```
 
+# Profiling
+
+Profiling can be enabled by setting the `enableProfiling` option to true when
+loading the TFLite model:
+
+```js
+const tfliteModel = await tflite.loadTFLiteModel(
+    'path/to/your/my_model.tflite',
+    {enableProfiling: true});
+```
+
+Once enabled, the runtime records per-op latency data each time the `predict`
+method is called. The profiling results can be retrieved in two ways:
+
+- `tfliteModel.getProfilingResults()`: returns an array of
+  `{nodeType, nodeName, nodeExecMs}` entries.
+- `tfliteModel.getProfilingSummary()`: returns a human-readable profiling
+  summary that looks like [this][profiling summary].
+
 # Development
 
 ## Building
@@ -130,3 +148,4 @@ $ yarn build-npm
 [xnnpack]: https://github.com/google/XNNPACK
 [xnnpack doc]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#limitations-and-supported-operators
 [cross origin setup steps]: https://github.com/tensorflow/tfjs/tree/master/tfjs-backend-wasm#setting-up-cross-origin-isolation
+[profiling summary]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/benchmark/README.md#profiling-model-operators
diff --git a/tfjs-tflite/scripts/download-tflite-web-api.sh b/tfjs-tflite/scripts/download-tflite-web-api.sh
index 53609852375..4a2fadd2237 100755
--- a/tfjs-tflite/scripts/download-tflite-web-api.sh
+++ b/tfjs-tflite/scripts/download-tflite-web-api.sh
@@ -22,7 +22,7 @@ set -e
 OUTPUT_DIR="$1"
 
 # The default version.
-CURRENT_VERSION=0.0.5
+CURRENT_VERSION=0.0.6
 
 # Get the version from the second parameter.
 # Default to the value in CURRENT_VERSION.
diff --git a/tfjs-tflite/src/tflite_model.ts b/tfjs-tflite/src/tflite_model.ts
index 6dadd4649f4..f45379babdc 100644
--- a/tfjs-tflite/src/tflite_model.ts
+++ b/tfjs-tflite/src/tflite_model.ts
@@ -17,9 +17,8 @@
 import {DataType, InferenceModel, ModelPredictConfig, ModelTensorInfo, NamedTensorMap, tensor, Tensor} from '@tensorflow/tfjs-core';
 
-import {getDefaultNumThreads} from './tflite_task_library_client/common';
 import * as tfliteWebAPIClient from './tflite_web_api_client';
-import {TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
+import {ProfileItem, TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
 
 const TFHUB_SEARCH_PARAM = '?lite-format=tflite';
 
@@ -178,6 +177,14 @@ export class TFLiteModel implements InferenceModel {
     throw new Error('execute() of TFLiteModel is not supported yet.');
   }
 
+  getProfilingResults(): ProfileItem[] {
+    return this.modelRunner.getProfilingResults();
+  }
+
+  getProfilingSummary(): string {
+    return this.modelRunner.getProfilingSummary();
+  }
+
   private setModelInputFromTensor(
       modelInput: TFLiteWebModelRunnerTensorInfo, tensor: Tensor) {
     // String and complex tensors are not supported.
@@ -317,17 +324,9 @@ export async function loadTFLiteModel(
     model = `${model}${TFHUB_SEARCH_PARAM}`;
   }
 
-  // Process options.
-  const curOptions: TFLiteWebModelRunnerOptions = {};
-  if (options && options.numThreads !== undefined) {
-    curOptions.numThreads = options.numThreads;
-  } else {
-    curOptions.numThreads = await getDefaultNumThreads();
-  }
-
   const tfliteModelRunner =
       await tfliteWebAPIClient.tfweb.TFLiteWebModelRunner.create(
-          model, curOptions);
+          model, options);
   return new TFLiteModel(tfliteModelRunner);
 }
 
diff --git a/tfjs-tflite/src/tflite_model_test.ts b/tfjs-tflite/src/tflite_model_test.ts
index 30c9e3fd43f..e7d566013d0 100644
--- a/tfjs-tflite/src/tflite_model_test.ts
+++ b/tfjs-tflite/src/tflite_model_test.ts
@@ -19,7 +19,7 @@ import * as tf from '@tensorflow/tfjs-core';
 import {DataType, NamedTensorMap} from '@tensorflow/tfjs-core';
 
 import {TFLiteModel} from './tflite_model';
-import {TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
+import {ProfileItem, TFLiteDataType, TFLiteWebModelRunner, TFLiteWebModelRunnerOptions, TFLiteWebModelRunnerTensorInfo} from './types/tflite_web_model_runner';
 
 // A mock TFLiteWebModelRunner that doubles the data from input tensors to
 // output tensors during inference.
@@ -61,6 +61,14 @@ class MockModelRunner implements TFLiteWebModelRunner {
 
   cleanUp() {}
 
+  getProfilingResults(): ProfileItem[] {
+    return [];
+  }
+
+  getProfilingSummary(): string {
+    return '';
+  }
+
   private getTensorInfos(firstTensorType: TFLiteDataType = 'int32'):
       TFLiteWebModelRunnerTensorInfo[] {
     const shape0 = [1, 2, 3];
diff --git a/tfjs-tflite/src/types/tflite_web_model_runner.ts b/tfjs-tflite/src/types/tflite_web_model_runner.ts
index 820348b5f51..8060973d7f3 100644
--- a/tfjs-tflite/src/types/tflite_web_model_runner.ts
+++ b/tfjs-tflite/src/types/tflite_web_model_runner.ts
@@ -50,6 +50,31 @@ export declare interface TFLiteWebModelRunner extends BaseTaskLibrary {
    * @return Whether the inference is successful or not.
    */
   infer(): boolean;
+
+  /**
+   * Gets per-node profiling results.
+   *
+   * This is only useful when TFLiteWebModelRunnerOptions.enableProfiling is
+   * set to true.
+   */
+  getProfilingResults(): ProfileItem[];
+
+  /**
+   * Gets the profiling summary.
+   *
+   * This is only useful when TFLiteWebModelRunnerOptions.enableProfiling is
+   * set to true.
+   */
+  getProfilingSummary(): string;
+}
+
+export declare interface ProfileItem {
+  /** The type of the node, e.g. "CONV_2D". */
+  nodeType: string;
+  /** The name of the node, e.g. "MobilenetV1/MobilenetV1/Conv2d_0/Relu6". */
+  nodeName: string;
+  /** The execution time (in ms) of the node. */
+  nodeExecMs: number;
 }
 
 /** Options for TFLiteWebModelRunner. */
@@ -61,6 +86,21 @@ export declare interface TFLiteWebModelRunnerOptions {
    * not supported by user's browser.
    */
   numThreads?: number;
+  /**
+   * Whether to enable profiling.
+   *
+   * Defaults to false. When enabled, the profiling results can be
+   * retrieved by calling TFLiteWebModelRunner.getProfilingResults or
+   * TFLiteWebModelRunner.getProfilingSummary. See their comments for more
+   * details.
+   */
+  enableProfiling?: boolean;
+  /**
+   * Maximum number of entries that the profiler can keep.
+   *
+   * Defaults to 1024.
+   */
+  maxProfilingBufferEntries?: number;
 }
 
 /** Types of TFLite tensor data. */
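
Taken together, the new `enableProfiling` option and the `getProfilingResults()` / `getProfilingSummary()` accessors added by this patch are used as below. A minimal sketch, assuming a model with a single `[1, 224, 224, 3]` float input; the model path is the README's placeholder and the input shape is an illustrative assumption, not something this patch prescribes:

```js
import * as tf from '@tensorflow/tfjs-core';
import * as tflite from '@tensorflow/tfjs-tflite';

async function profileModel() {
  // Load the model with profiling enabled. maxProfilingBufferEntries is
  // optional and defaults to 1024 (see TFLiteWebModelRunnerOptions).
  const tfliteModel = await tflite.loadTFLiteModel(
      'path/to/your/my_model.tflite',
      {enableProfiling: true});

  // The profiler records per-op latency while predict() runs, so run
  // inference at least once before reading the results.
  const input = tf.zeros([1, 224, 224, 3]);  // Assumed input shape.
  tfliteModel.predict(input);

  // Per-node results: an array of {nodeType, nodeName, nodeExecMs}.
  for (const item of tfliteModel.getProfilingResults()) {
    console.log(`${item.nodeType} ${item.nodeName}: ${item.nodeExecMs}ms`);
  }

  // Or a human-readable summary, formatted like the TFLite benchmark
  // tool's profiling output.
  console.log(tfliteModel.getProfilingSummary());
}
```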