From 4586e4f62a3c6022fb900d7ec6ffc5f2a29792a1 Mon Sep 17 00:00:00 2001 From: mingmingtasd Date: Thu, 22 Feb 2024 15:19:09 +0800 Subject: [PATCH] Enable mobilenetv2_7 model on NPU --- common/utils.js | 2 +- image_classification/index.html | 6 ++++ image_classification/main.js | 42 +++++++++++++++------- image_classification/mobilenetv2_7_nchw.js | 10 +++--- test-data | 2 +- 5 files changed, 41 insertions(+), 21 deletions(-) diff --git a/common/utils.js b/common/utils.js index 8b861a8e..1541e704 100644 --- a/common/utils.js +++ b/common/utils.js @@ -60,7 +60,7 @@ export async function buildConstantByNpy(builder, url) { typedArray[i] = dataView[getFuncName]( i * TypedArrayConstructor.BYTES_PER_ELEMENT, littleEndian); } - return builder.constant({type, dimensions}, typedArray); + return builder.constant({dataType: type, type, dimensions}, typedArray); } // Convert video frame to a canvas element diff --git a/image_classification/index.html b/image_classification/index.html index fc62aeb2..0d099769 100644 --- a/image_classification/index.html +++ b/image_classification/index.html @@ -43,6 +43,9 @@ + @@ -220,6 +223,9 @@

No model selected

+ diff --git a/image_classification/main.js b/image_classification/main.js index 912cdca7..a3efe078 100644 --- a/image_classification/main.js +++ b/image_classification/main.js @@ -29,8 +29,8 @@ let buildTime = 0; let computeTime = 0; let inputOptions; let outputBuffer; -let devicePreference = ''; -let lastDevicePreference = ''; +let deviceType = ''; +let lastdeviceType = ''; let backend = ''; let lastBackend = ''; const disabledSelectors = ['#tabs > li', '.btn']; @@ -137,6 +137,25 @@ async function renderCamStream() { // Get top 3 classes of labels from output buffer function getTopClasses(buffer, labels) { + // Convert output buffer from float16 to float32, because tf.tensor/tf.softmax doesn't + // support float16 data type according to https://js.tensorflow.org/api/latest/#tensor. + if (inputOptions.dataType === 'float16') { + let elements_count = utils.sizeOfShape(netInstance.outputDimensions); + let float32Buffer = new Float32Array(elements_count); + for (let i = 0; i < elements_count; ++i) { + float32Buffer[i] = utils.float16ToNumber(buffer[i]); + } + buffer = float32Buffer; + } + + // Softmax + buffer = tf.tidy(() => { + const a = + tf.tensor(buffer, netInstance.outputDimensions, 'float32'); + const b = tf.softmax(a); + return b.dataSync(); + }); + const probs = Array.from(buffer); const indexes = probs.map((prob, index) => [prob, index]); const sorted = indexes.sort((a, b) => { @@ -150,9 +169,6 @@ function getTopClasses(buffer, labels) { for (let i = 0; i < 3; ++i) { let prob = sorted[i][0]; - if (inputOptions.dataType === 'float16') { - prob = utils.float16ToNumber(prob); - } const index = sorted[i][1]; const c = { label: labels[index], @@ -218,7 +234,7 @@ function constructNetObject(type) { async function main() { try { if (modelName === '') return; - [backend, devicePreference] = + [backend, deviceType] = $('input[name="backend"]:checked').attr('id').split('_'); ui.handleClick(disabledSelectors, true); if (isFirstTimeLoad) $('#hint').hide(); @@ -228,12 +244,12 @@ async function main() { // Only do load() and build() when model first time loads, // there's new model choosed, backend changed or device changed if (isFirstTimeLoad || instanceType !== modelName + layout || - lastDevicePreference != devicePreference || lastBackend != backend) { - if (lastDevicePreference != devicePreference || lastBackend != backend) { + lastdeviceType != deviceType || lastBackend != backend) { + if (lastdeviceType != deviceType || lastBackend != backend) { // Set backend and device - await utils.setBackend(backend, devicePreference); - lastDevicePreference = lastDevicePreference != devicePreference ? - devicePreference : lastDevicePreference; + await utils.setBackend(backend, deviceType); + lastdeviceType = lastdeviceType != deviceType ? + deviceType : lastdeviceType; lastBackend = lastBackend != backend ? backend : lastBackend; } if (netInstance !== null) { @@ -257,7 +273,7 @@ async function main() { // UI shows model loading progress await ui.showProgressComponent('current', 'pending', 'pending'); console.log('- Loading weights... '); - const contextOptions = {type: 'webnn', deviceType: devicePreference}; + const contextOptions = {type: 'webnn', deviceType: deviceType}; if (powerPreference) { contextOptions['powerPreference'] = powerPreference; } @@ -318,4 +334,4 @@ async function main() { ui.addAlert(error.message); } ui.handleClick(disabledSelectors, false); -} +} \ No newline at end of file diff --git a/image_classification/mobilenetv2_7_nchw.js b/image_classification/mobilenetv2_7_nchw.js index cb20a8ed..fdd34cdb 100644 --- a/image_classification/mobilenetv2_7_nchw.js +++ b/image_classification/mobilenetv2_7_nchw.js @@ -80,7 +80,7 @@ export class MobileNetV27Nchw { this.context_ = await navigator.ml.createContext(contextOptions); this.builder_ = new MLGraphBuilder(this.context_); const data = this.builder_.input('input', - {type: this.dataType_, dimensions: this.inputOptions.inputDimensions}); + {dataType: this.dataType_, dimensions: this.inputOptions.inputDimensions}); const conv0 = await this.buildConv_( data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]}); const conv1 = await this.buildConv_( @@ -122,10 +122,8 @@ export class MobileNetV27Nchw { const conv3 = await this.buildConv_(bottleneck15, '95', true); const conv4 = await this.buildConv_(conv3, '97', false, {groups: 1280, strides: [7, 7]}); const conv5 = await this.buildConv_(conv4, '104', false); - const reshape = this.builder_.reshape(conv5, [1, null]); - // return reshape; - // const gemm = await this.buildGemm_(reshape, '104'); - return this.builder_.softmax(reshape); + return this.builder_.reshape(conv5, [1, 1000]); + // return this.builder_.softmax(reshape); } async build(outputOperand) { @@ -146,4 +144,4 @@ export class MobileNetV27Nchw { const results = await this.context_.compute(this.graph_, inputs, outputs); return results; } -} +} \ No newline at end of file diff --git a/test-data b/test-data index 045017d3..ea628619 160000 --- a/test-data +++ b/test-data @@ -1 +1 @@ -Subproject commit 045017d38ea0133807fa26af9e5b030147cb2314 +Subproject commit ea628619df93141fa82535b063e6dde8aa7a5c9a