Skip to content

Commit

Permalink
Enable mobilenetv2_7 model on NPU
Browse files Browse the repository at this point in the history
  • Loading branch information
mingmingtasd committed Feb 22, 2024
1 parent 01a0801 commit 4586e4f
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 21 deletions.
2 changes: 1 addition & 1 deletion common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export async function buildConstantByNpy(builder, url) {
typedArray[i] = dataView[getFuncName](
i * TypedArrayConstructor.BYTES_PER_ELEMENT, littleEndian);
}
return builder.constant({type, dimensions}, typedArray);
return builder.constant({dataType: type, type, dimensions}, typedArray);
}

// Convert video frame to a canvas element
Expand Down
6 changes: 6 additions & 0 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off" checked>WebNN (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU)
</label>
</div>
</div>
</div>
Expand Down Expand Up @@ -220,6 +223,9 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js"
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.9.0/dist/tf.min.js"
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8="
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV"
crossorigin="anonymous"></script>
Expand Down
42 changes: 29 additions & 13 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ let buildTime = 0;
let computeTime = 0;
let inputOptions;
let outputBuffer;
let devicePreference = '';
let lastDevicePreference = '';
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
const disabledSelectors = ['#tabs > li', '.btn'];
Expand Down Expand Up @@ -137,6 +137,25 @@ async function renderCamStream() {

// Get top 3 classes of labels from output buffer
function getTopClasses(buffer, labels) {
// Convert output buffer from float16 to float32, because tf.tensor/tf.softmax doesn't
// support float16 data type according to https://js.tensorflow.org/api/latest/#tensor.
if (inputOptions.dataType === 'float16') {
let elements_count = utils.sizeOfShape(netInstance.outputDimensions);
let float32Buffer = new Float32Array(elements_count);
for (let i = 0; i < elements_count; ++i) {
float32Buffer[i] = utils.float16ToNumber(buffer[i]);
}
buffer = float32Buffer;
}

// Softmax
buffer = tf.tidy(() => {
const a =
tf.tensor(buffer, netInstance.outputDimensions, 'float32');
const b = tf.softmax(a);
return b.dataSync();
});

const probs = Array.from(buffer);
const indexes = probs.map((prob, index) => [prob, index]);
const sorted = indexes.sort((a, b) => {
Expand All @@ -150,9 +169,6 @@ function getTopClasses(buffer, labels) {

for (let i = 0; i < 3; ++i) {
let prob = sorted[i][0];
if (inputOptions.dataType === 'float16') {
prob = utils.float16ToNumber(prob);
}
const index = sorted[i][1];
const c = {
label: labels[index],
Expand Down Expand Up @@ -218,7 +234,7 @@ function constructNetObject(type) {
async function main() {
try {
if (modelName === '') return;
[backend, devicePreference] =
[backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
Expand All @@ -228,12 +244,12 @@ async function main() {
// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
if (isFirstTimeLoad || instanceType !== modelName + layout ||
lastDevicePreference != devicePreference || lastBackend != backend) {
if (lastDevicePreference != devicePreference || lastBackend != backend) {
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
await utils.setBackend(backend, devicePreference);
lastDevicePreference = lastDevicePreference != devicePreference ?
devicePreference : lastDevicePreference;
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
}
if (netInstance !== null) {
Expand All @@ -257,7 +273,7 @@ async function main() {
// UI shows model loading progress
await ui.showProgressComponent('current', 'pending', 'pending');
console.log('- Loading weights... ');
const contextOptions = {type: 'webnn', deviceType: devicePreference};
const contextOptions = {type: 'webnn', deviceType: deviceType};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
}
Expand Down Expand Up @@ -318,4 +334,4 @@ async function main() {
ui.addAlert(error.message);
}
ui.handleClick(disabledSelectors, false);
}
}
10 changes: 4 additions & 6 deletions image_classification/mobilenetv2_7_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export class MobileNetV27Nchw {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
const data = this.builder_.input('input',
{type: this.dataType_, dimensions: this.inputOptions.inputDimensions});
{dataType: this.dataType_, dimensions: this.inputOptions.inputDimensions});
const conv0 = await this.buildConv_(
data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]});
const conv1 = await this.buildConv_(
Expand Down Expand Up @@ -122,10 +122,8 @@ export class MobileNetV27Nchw {
const conv3 = await this.buildConv_(bottleneck15, '95', true);
const conv4 = await this.buildConv_(conv3, '97', false, {groups: 1280, strides: [7, 7]});
const conv5 = await this.buildConv_(conv4, '104', false);
const reshape = this.builder_.reshape(conv5, [1, null]);
// return reshape;
// const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(reshape);
return this.builder_.reshape(conv5, [1, 1000]);
// return this.builder_.softmax(reshape);
}

async build(outputOperand) {
Expand All @@ -146,4 +144,4 @@ export class MobileNetV27Nchw {
const results = await this.context_.compute(this.graph_, inputs, outputs);
return results;
}
}
}
2 changes: 1 addition & 1 deletion test-data
Submodule test-data updated 1026 files

0 comments on commit 4586e4f

Please sign in to comment.