[benchmark] Support collecting and showing per-op stats (latency for now) for tflite models in benchmark tool (#5917)

* support per-op stats for tflite models

* fix

* fix
jinjingforever authored Dec 3, 2021
1 parent 23f6ebd commit 9dd2e8d
Showing 3 changed files with 63 additions and 31 deletions.
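Taken together, the three files wire per-op latency collection end to end: the TFLite model is reloaded with profiling enabled, run once, and the recorded per-node timings are rendered in the kernel table. A minimal sketch of that flow, reusing only APIs that appear in this diff (the {enableProfiling} load option, predict(), and getProfilingResults()); the model URL is the MobileNet one used below and the zero-filled input is a placeholder:

  // Load a TFLite model with profiling turned on (the option used in this commit).
  const tfliteModel = await tflite.loadTFLiteModel(
      'https://tfhub.dev/tensorflow/lite-model/mobilenet_v2_1.0_224/1/metadata/1',
      {enableProfiling: true});

  // Run inference once so the runner records per-op timings.
  const output = tfliteModel.predict(tf.zeros([1, 224, 224, 3]));
  tf.dispose(output);

  // Each profiling item carries the op ("node") type and its latency in ms.
  for (const item of tfliteModel.getProfilingResults()) {
    console.log(`${item.nodeType}: ${item.nodeExecMs.toFixed(2)} ms`);
  }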
37 changes: 27 additions & 10 deletions e2e/benchmarks/benchmark_util.js
@@ -335,10 +335,12 @@ async function downloadValuesFromTensorContainer(tensorContainer) {
  * @param model An instance of tf.GraphModel or tf.LayersModel for profiling
  *     memory usage in the inference process.
  * @param input The input tensor container for model inference.
+ * @param isTflite Whether a TFLite model is being profiled or not.
  */
-async function profileModelInference(model, input) {
-  const predict = getPredictFnForModel(model, input);
-  return profileInference(predict);
+async function profileModelInference(model, input, isTflite = false) {
+  const predict = isTflite ? () => tfliteModel.predict(input) :
+      getPredictFnForModel(model, input);
+  return profileInference(predict, isTflite);
 }

 /**
@@ -369,20 +371,35 @@ async function profileModelInference(model, input) {
  * ```
  *
  * @param predict The predict function to execute for profiling memory usage.
+ * @param isTflite Whether a TFLite model is being profiled or not.
  */
-async function profileInference(predict) {
+async function profileInference(predict, isTflite = false) {
   if (typeof predict !== 'function') {
     throw new Error(
         'The first parameter should be a function, while ' +
         `a(n) ${typeof predict} is found.`);
   }

-  const kernelInfo = await tf.profile(async () => {
-    const res = await predict();
-    await downloadValuesFromTensorContainer(res);
-    tf.dispose(res);
-  });
-
+  let kernelInfo = {};
+  if (isTflite) {
+    await predict();
+    const profileItems = tfliteModel.getProfilingResults();
+    kernelInfo.kernels = profileItems.map(item => {
+      return {
+        name: item.nodeType,
+        kernelTimeMs: item.nodeExecMs,
+        // TODO: Shapes are not supported yet.
+        inputShapes: [],
+        outputShapes: [],
+      };
+    });
+  } else {
+    kernelInfo = await tf.profile(async () => {
+      const res = await predict();
+      await downloadValuesFromTensorContainer(res);
+      tf.dispose(res);
+    });
+  }
   kernelInfo.kernels =
       kernelInfo.kernels.sort((a, b) => b.kernelTimeMs - a.kernelTimeMs);
   kernelInfo.aggregatedKernels = aggregateKernelTime(kernelInfo.kernels);
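For context, the TFLite branch above deliberately shapes its result like the tf.profile() output, so the shared sorting and aggregation after the if/else work unchanged. A hypothetical example of the object it builds (op names and timings invented for illustration; shapes stay empty per the TODO):

  const kernelInfo = {
    kernels: [
      {name: 'CONV_2D', kernelTimeMs: 4.21, inputShapes: [], outputShapes: []},
      {name: 'DEPTHWISE_CONV_2D', kernelTimeMs: 1.08, inputShapes: [], outputShapes: []},
    ],
  };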
49 changes: 32 additions & 17 deletions e2e/benchmarks/local-benchmark/index.html
@@ -499,10 +499,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
     let start = performance.now();
     const inputSize = parseFloat(state.inputSize);
     if (isTflite()) {
-      if (tfliteModel) {
-        tfliteModel.modelRunner.cleanUp();
-      }
-      tfliteModel = await benchmark.loadTflite();
+      await loadTfliteModel();
     } else {
       model = await benchmark.load(inputSize, state.architecture, state.inputType);
     }
@@ -584,34 +581,41 @@ <h2>TensorFlow.js Model Benchmark</h2>
   }

   async function profileMemoryAndKernelTime() {
-    await showMsg('Profile memory and kernels');
-
+    // Reload the tflite model with profiling enabled.
     if (isTflite()) {
-      await showMsg(null);
-      const tbody = document.querySelector('#kernels tbody');
-      const nameSpan = document.createElement('span');
-      nameSpan.textContent = '[TODO] Per-kernel stats not supported yet for TFLite models';
-      appendRow(tbody, nameSpan, '');
-      return;
+      await loadTfliteModel(true);
+      // This makes sure the model will be reloaded (with profiling disabled)
+      // the next time "Run benchmark" is clicked.
+      //
+      // Without this, the model with profiling *enabled* would continue to be
+      // used to measure the total run time when "Run benchmark" is clicked,
+      // which is not accurate.
+      state.isModelChanged = true;
+      state.isModelLoaded = false;
     }

+    await showMsg('Profile memory and kernels');
     const start = performance.now();

     let profileInfo;
     if (state.benchmark === 'custom') {
       const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
       try {
-        profileInfo = await profileModelInference(model, input);
+        profileInfo = await profileModelInference(model, input, isTflite());
       } finally {
         tf.dispose(input);
       }
     } else {
-      profileInfo = await profileInference(() => predict(model));
+      profileInfo = await profileInference(() => predict(model), isTflite());
     }

     const elapsed = performance.now() - start;
     await showMsg(null);
-    appendRow(timeTable, 'Peak memory', printMemory(profileInfo.peakBytes));
-    appendRow(timeTable, 'Leaked tensors', profileInfo.newTensors);
+    // TODO: These are not supported by tflite models yet.
+    if (!isTflite()) {
+      appendRow(timeTable, 'Peak memory', printMemory(profileInfo.peakBytes));
+      appendRow(timeTable, 'Leaked tensors', profileInfo.newTensors);
+    }
     appendRow(timeTable, 'Profile time', printTime(elapsed));

     if (state.backend === 'webgl' && !queryTimerIsEnabled()) {
@@ -648,7 +652,10 @@ <h2>TensorFlow.js Model Benchmark</h2>
             inputInfo += `input${index}: ${inputShape.length}D[${inputShape}]`;
           }
         });
-        appendRow(tbody, nameSpan, kernel.kernelTimeMs.toFixed(2), inputInfo, kernel.outputShapes, kernel.extraInfo);
+        appendRow(tbody, nameSpan, kernel.kernelTimeMs.toFixed(2),
+            inputInfo || '-',
+            kernel.outputShapes.length !== 0 ? kernel.outputShapes : '-',
+            kernel.extraInfo || '-');
       });
     } else {
       profileInfo.aggregatedKernels.forEach(r => {
@@ -784,7 +791,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
       'https://cdn.jsdelivr.net/npm/@tensorflow-models/pose-detection',
       // Load tfjs-tflite from jsdelivr because it correctly sets the
       // "cross-origin-resource-policy" header.
-      'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite/dist/tf-tflite.js',
+      'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@latest/dist/tf-tflite.js',
       '../model_config.js',
       '../benchmark_util.js',
       './util.js',
@@ -972,6 +979,14 @@ <h2>TensorFlow.js Model Benchmark</h2>
     return state.backend === 'tflite';
   }

+  async function loadTfliteModel(enableProfiling = false) {
+    if (tfliteModel) {
+      tfliteModel.modelRunner.cleanUp();
+    }
+    const benchmark = benchmarks[state.benchmark];
+    tfliteModel = await benchmark.loadTflite(enableProfiling);
+  }
+
   function updateModelsDropdown(newValues) {
     const tfliteBenchmarksHtml = newValues.map(name => `<option value="${name}">${name}</option>`);
     modelController.domElement.children[0].innerHTML = tfliteBenchmarksHtml;
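A design note on the new loadTfliteModel helper: profiling is fixed at load time via {enableProfiling}, not toggled at run time, which is why profileMemoryAndKernelTime reloads the model and then marks the cached one stale. A condensed sketch of that lifecycle using names from this diff (the input tensor is assumed to be prepared elsewhere on the page):

  // Reload the single cached TFLite runner with profiling enabled.
  await loadTfliteModel(true);
  // Force a clean, non-profiling reload on the next "Run benchmark" click.
  state.isModelChanged = true;
  state.isModelLoaded = false;
  // Collect per-op stats through the shared profiling entry point.
  const profileInfo = await profileInference(() => tfliteModel.predict(input), true);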
8 changes: 4 additions & 4 deletions e2e/benchmarks/model_config.js
@@ -96,10 +96,10 @@ const benchmarks = {
         'https://storage.googleapis.com/learnjs-data/mobilenet_v2_100_fused/model.json';
     return tf.loadGraphModel(url);
   },
-  loadTflite: async () => {
+  loadTflite: async (enableProfiling = false) => {
     const url =
         'https://tfhub.dev/tensorflow/lite-model/mobilenet_v2_1.0_224/1/metadata/1';
-    return tflite.loadTFLiteModel(url);
+    return tflite.loadTFLiteModel(url, {enableProfiling});
   },
   predictFunc: () => {
     const input = tf.randomNormal([1, 224, 224, 3]);
@@ -403,8 +403,8 @@ const benchmarks = {
   load: async () => {
     return loadModelByUrlWithState(state.modelUrl, {}, state);
   },
-  loadTflite: async () => {
-    return tflite.loadTFLiteModel(state.modelUrl);
+  loadTflite: async (enableProfiling = false) => {
+    return tflite.loadTFLiteModel(state.modelUrl, {enableProfiling});
   },
   predictFunc: () => {
     return async (model, customInput) => {
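With both model entries updated, callers opt into profiling through a single argument. A short usage sketch (benchmarks['custom'] and state.modelUrl are defined in this file; the flag defaults to false):

  // Plain benchmark run: profiling off by default.
  const model = await benchmarks['custom'].loadTflite();
  // Profiling run: passes {enableProfiling: true} through to loadTFLiteModel.
  const profiledModel = await benchmarks['custom'].loadTflite(true);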
