Skip to content

Commit

Permalink
Add samples for ONNX mobilenetv2-7 fp16 and fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Nov 3, 2022
1 parent d3a1463 commit 7d41ac7
Show file tree
Hide file tree
Showing 218 changed files with 295 additions and 22 deletions.
119 changes: 113 additions & 6 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ export async function buildConstantByNpy(builder, url) {
const typedArray = new TypedArrayConstructor(sizeOfShape(dimensions));
const dataView = new DataView(npArray.data.buffer);
const littleEndian = npArray.byteOrder === '<';
let getFuncName = `get` + type[0].toUpperCase() + type.substr(1);
if (type == 'float16') {
getFuncName = `getUint16`;
}
for (let i = 0; i < sizeOfShape(dimensions); ++i) {
typedArray[i] = dataView[`get` + type[0].toUpperCase() + type.substr(1)](
typedArray[i] = dataView[getFuncName](
i * TypedArrayConstructor.BYTES_PER_ELEMENT, littleEndian);
}
return builder.constant({type, dimensions}, typedArray);
Expand All @@ -70,6 +74,102 @@ export function getVideoFrame(videoElement) {
return canvasElement;
}

// ref: http://stackoverflow.com/questions/32633585/how-do-you-convert-to-half-floats-in-javascript
const toHalf = (function() {

var floatView = new Float32Array(1);
var int32View = new Int32Array(floatView.buffer);

/* This method is faster than the OpenEXR implementation (very often
* used, eg. in Ogre), with the additional benefit of rounding, inspired
* by James Tursa?s half-precision code. */
return function toHalf(val) {

floatView[0] = val;
var x = int32View[0];

var bits = (x >> 16) & 0x8000; /* Get the sign */
var m = (x >> 12) & 0x07ff; /* Keep one extra bit for rounding */
var e = (x >> 23) & 0xff; /* Using int is faster here */

/* If zero, or denormal, or exponent underflows too much for a denormal
* half, return signed zero. */
if (e < 103) {
return bits;
}

/* If NaN, return NaN. If Inf or exponent overflow, return Inf. */
if (e > 142) {
bits |= 0x7c00;
/* If exponent was 0xff and one mantissa bit was set, it means NaN,
* not Inf, so make sure we set one mantissa bit too. */
bits |= ((e == 255) ? 0 : 1) && (x & 0x007fffff);
return bits;
}

/* If exponent underflows but not too much, return a denormal */
if (e < 113) {
m |= 0x0800;
/* Extra rounding may overflow and set mantissa to 0 and exponent
* to 1, which is OK. */
bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
return bits;
}

bits |= ((e - 112) << 10) | (m >> 1);
/* Extra rounding. An overflow will set mantissa to 0 and increment
* the exponent, which is OK. */
bits += m & 1;
return bits;
};

})();

// This function converts a Float16 stored as the bits of a Uint16 into a Javascript Number.
// Adapted from: https://gist.github.com/martinkallman/5049614
// input is a Uint16 (eg, new Uint16Array([value])[0])

export function float16ToNumber(input) {
// Create a 32 bit DataView to store the input
const arr = new ArrayBuffer(4);
const dv = new DataView(arr);

// Set the Float16 into the last 16 bits of the dataview
// So our dataView is [00xx]
dv.setUint16(2, input, false);

// Get all 32 bits as a 32 bit integer
// (JS bitwise operations are performed on 32 bit signed integers)
const asInt32 = dv.getInt32(0, false);

// All bits aside from the sign
let rest = asInt32 & 0x7fff;
// Sign bit
let sign = asInt32 & 0x8000;
// Exponent bits
const exponent = asInt32 & 0x7c00;

// Shift the non-sign bits into place for a 32 bit Float
rest <<= 13;
// Shift the sign bit into place for a 32 bit Float
sign <<= 16;

// Adjust bias
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Exponent_encoding
rest += 0x38000000;
// Denormals-as-zero
rest = (exponent === 0 ? 0 : rest);
// Re-insert sign bit
rest |= sign;

// Set the adjusted float32 (stored as int32) back into the dataview
dv.setInt32(0, rest, false);

// Get it back out as a float32 (which js will convert to a Number)
const asFloat32 = dv.getFloat32(0, false);

return asFloat32;
}
/**
* This method is used to covert input element to tensor data.
* @param {Object} inputElement, an object of HTML [<img> | <video>] element.
Expand Down Expand Up @@ -105,9 +205,16 @@ export function getVideoFrame(videoElement) {
* @return {Object} tensor, an object of input tensor.
*/
export function getInputTensor(inputElement, inputOptions) {
const dataType = inputOptions.dataType || 'float32';
const inputDimensions = inputOptions.inputDimensions;
const tensor = new Float32Array(
let tensor;
if (dataType === 'float16') {
tensor = new Uint16Array(
inputDimensions.slice(1).reduce((a, b) => a * b));
} else {
tensor = new Float32Array(
inputDimensions.slice(1).reduce((a, b) => a * b));
}

inputElement.width = inputElement.videoWidth ||
inputElement.naturalWidth;
Expand Down Expand Up @@ -164,11 +271,11 @@ export function getInputTensor(inputElement, inputOptions) {
value = pixels[h * width * imageChannels + w * imageChannels + c];
}
if (inputLayout === 'nchw') {
tensor[c * width * height + h * width + w] =
(value - mean[c]) / std[c];
tensor[c * width * height + h * width + w] = dataType === 'float16' ?
toHalf((value - mean[c]) / std[c]) : (value - mean[c]) / std[c];
} else {
tensor[h * width * channels + w * channels + c] =
(value - mean[c]) / std[c];
tensor[h * width * channels + w * channels + c] = dataType === 'float16' ?
toHalf((value - mean[c]) / std[c]) : (value - mean[c]) / std[c];
}
}
}
Expand Down
20 changes: 13 additions & 7 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<label class="btn btn-outline-info custom" name="polyfill">
<!-- <label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
</label>
</label> -->
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off" checked>WebNN (GPU)
</label>
</div>
</div>
Expand All @@ -55,9 +55,9 @@
<label class="btn btn-outline-info active" id='nchw-label'>
<input type="radio" name="layout" id="nchw" autocomplete="off" checked>NCHW
</label>
<label class="btn btn-outline-info btn-sm">
<!-- <label class="btn btn-outline-info btn-sm">
<input type="radio" name="layout" id="nhwc" autocomplete="off">NHWC
</label>
</label> -->
</div>
</div>
</div>
Expand All @@ -67,15 +67,21 @@
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="modelBtns">
<label class="btn btn-outline-info active">
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenetv27fp32" autocomplete="off">MobileNet V2 7 FP32
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenetv27fp16" autocomplete="off">MobileNet V2 7 FP16
</label>
<!-- <label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50" autocomplete="off">ResNet V2 50
</label>
</label> -->
</div>
</div>
</div>
Expand Down
30 changes: 21 additions & 9 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {ResNet50V2Nchw} from './resnet50v2_nchw.js';
import {ResNet50V2Nhwc} from './resnet50v2_nhwc.js';
import * as ui from '../common/ui.js';
import * as utils from '../common/utils.js';
import { MobileNetV27Nchw } from './mobilenetv2_7_nchw.js';

const maxWidth = 380;
const maxHeight = 380;
Expand Down Expand Up @@ -43,7 +44,7 @@ async function fetchLabels(url) {
$(document).ready(() => {
$('.icdisplay').hide();
if (utils.isWebNN()) {
$('#webnn_cpu').click();
$('#webnn_gpu').click();
} else {
$('#polyfill_gpu').click();
}
Expand Down Expand Up @@ -123,7 +124,7 @@ async function renderCamStream() {
const inputCanvas = utils.getVideoFrame(camElement);
console.log('- Computing... ');
const start = performance.now();
netInstance.compute(inputBuffer, outputBuffer);
await netInstance.compute(inputBuffer, outputBuffer);
computeTime = (performance.now() - start).toFixed(2);
console.log(` done in ${computeTime} ms.`);
drawInput(inputCanvas, 'camInCanvas');
Expand All @@ -147,7 +148,10 @@ function getTopClasses(buffer, labels) {
const classes = [];

for (let i = 0; i < 3; ++i) {
const prob = sorted[i][0];
let prob = sorted[i][0];
if (inputOptions.dataType === 'float16') {
prob = utils.float16ToNumber(prob);
}
const index = sorted[i][1];
const c = {
label: labels[index],
Expand Down Expand Up @@ -197,6 +201,8 @@ function showPerfResult(medianComputeTime = undefined) {

function constructNetObject(type) {
const netObject = {
'mobilenetv27fp32nchw': new MobileNetV27Nchw('float32'),
'mobilenetv27fp16nchw': new MobileNetV27Nchw('float16'),
'mobilenetnchw': new MobileNetV2Nchw(),
'mobilenetnhwc': new MobileNetV2Nhwc(),
'squeezenetnchw': new SqueezeNetNchw(),
Expand Down Expand Up @@ -237,14 +243,20 @@ async function main() {
netInstance = constructNetObject(instanceType);
inputOptions = netInstance.inputOptions;
labels = await fetchLabels(inputOptions.labelUrl);
outputBuffer =
new Float32Array(utils.sizeOfShape(netInstance.outputDimensions));
if (inputOptions.dataType === 'float16') {
outputBuffer =
new Uint16Array(utils.sizeOfShape(netInstance.outputDimensions));
} else {
outputBuffer =
new Float32Array(utils.sizeOfShape(netInstance.outputDimensions));
}

isFirstTimeLoad = false;
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
// UI shows model loading progress
await ui.showProgressComponent('current', 'pending', 'pending');
console.log('- Loading weights... ');
const contextOptions = {devicePreference};
const contextOptions = {type: 'webnn', devicePreference: devicePreference};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
}
Expand All @@ -256,7 +268,7 @@ async function main() {
await ui.showProgressComponent('done', 'current', 'pending');
console.log('- Building... ');
start = performance.now();
netInstance.build(outputOperand);
await netInstance.build(outputOperand);
buildTime = (performance.now() - start).toFixed(2);
console.log(` done in ${buildTime} ms.`);
}
Expand All @@ -269,11 +281,11 @@ async function main() {
let medianComputeTime;

// Do warm up
netInstance.compute(inputBuffer, outputBuffer);
await netInstance.compute(inputBuffer, outputBuffer);

for (let i = 0; i < numRuns; i++) {
start = performance.now();
netInstance.compute(inputBuffer, outputBuffer);
await netInstance.compute(inputBuffer, outputBuffer);
computeTime = (performance.now() - start).toFixed(2);
console.log(` compute time ${i+1}: ${computeTime} ms`);
computeTimeArray.push(Number(computeTime));
Expand Down
Loading

0 comments on commit 7d41ac7

Please sign in to comment.