From 01a0801e5b04623170cf66980aeff0278ae446fd Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 2 Nov 2023 11:25:30 +0800 Subject: [PATCH] Update to align with latest spec --- image_classification/main.js | 11 +++++++---- image_classification/mobilenetv2_7_nchw.js | 11 ++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/image_classification/main.js b/image_classification/main.js index db5c228a..912cdca7 100644 --- a/image_classification/main.js +++ b/image_classification/main.js @@ -124,8 +124,9 @@ async function renderCamStream() { const inputCanvas = utils.getVideoFrame(camElement); console.log('- Computing... '); const start = performance.now(); - await netInstance.compute(inputBuffer, outputBuffer); + const results = await netInstance.compute(inputBuffer, outputBuffer); computeTime = (performance.now() - start).toFixed(2); + outputBuffer = results.outputs.output; console.log(` done in ${computeTime} ms.`); drawInput(inputCanvas, 'camInCanvas'); showPerfResult(); @@ -256,7 +257,7 @@ async function main() { // UI shows model loading progress await ui.showProgressComponent('current', 'pending', 'pending'); console.log('- Loading weights... '); - const contextOptions = {type: 'webnn', devicePreference: devicePreference}; + const contextOptions = {type: 'webnn', deviceType: devicePreference}; if (powerPreference) { contextOptions['powerPreference'] = powerPreference; } @@ -281,11 +282,12 @@ async function main() { let medianComputeTime; // Do warm up - await netInstance.compute(inputBuffer, outputBuffer); + let results = await netInstance.compute(inputBuffer, outputBuffer); for (let i = 0; i < numRuns; i++) { start = performance.now(); - await netInstance.compute(inputBuffer, outputBuffer); + results = await netInstance.compute( + results.inputs.input, results.outputs.output); computeTime = (performance.now() - start).toFixed(2); console.log(` compute time ${i+1}: ${computeTime} ms`); computeTimeArray.push(Number(computeTime)); @@ -295,6 +297,7 @@ async function main() { medianComputeTime = medianComputeTime.toFixed(2); console.log(` median compute time: ${medianComputeTime} ms`); } + outputBuffer = results.outputs.output; console.log('outputBuffer: ', outputBuffer); await ui.showProgressComponent('done', 'done', 'done'); ui.readyShowResultComponents(); diff --git a/image_classification/mobilenetv2_7_nchw.js b/image_classification/mobilenetv2_7_nchw.js index 016601bd..cb20a8ed 100644 --- a/image_classification/mobilenetv2_7_nchw.js +++ b/image_classification/mobilenetv2_7_nchw.js @@ -40,11 +40,11 @@ export class MobileNetV27Nchw { } if (clip) { // implement `clip` by `clamp` of WebNN API - options.activation = this.builder_.clamp({minValue: 0, maxValue: 6}); + return this.builder_.clamp(this.builder_.conv2d(input, weights, options), {minValue: 0, maxValue: 6}); } else { options.activation = undefined; + return this.builder_.conv2d(input, weights, options); } - return this.builder_.conv2d(input, weights, options); } async buildGemm_(input, name) { @@ -122,14 +122,14 @@ export class MobileNetV27Nchw { const conv3 = await this.buildConv_(bottleneck15, '95', true); const conv4 = await this.buildConv_(conv3, '97', false, {groups: 1280, strides: [7, 7]}); const conv5 = await this.buildConv_(conv4, '104', false); - const reshape = this.builder_.reshape(conv5, [1, -1]); + const reshape = this.builder_.reshape(conv5, [1, null]); // return reshape; // const gemm = await this.buildGemm_(reshape, '104'); return this.builder_.softmax(reshape); } async build(outputOperand) { - this.graph_ = await this.builder_.buildAsync({'output': outputOperand}); + this.graph_ = await this.builder_.build({'output': outputOperand}); } // Release the constant tensors of a model @@ -143,6 +143,7 @@ export class MobileNetV27Nchw { async compute(inputBuffer, outputBuffer) { const inputs = {'input': inputBuffer}; const outputs = {'output': outputBuffer}; - await this.context_.compute(this.graph_, inputs, outputs); + const results = await this.context_.compute(this.graph_, inputs, outputs); + return results; } }