forked from webmachinelearning/webnn-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2e28fac
commit ae2c0a5
Showing
592 changed files
with
733 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
'use strict'; | ||
|
||
import {buildConstantByNpy} from '../common/utils.js'; | ||
|
||
// ResNet50 V1 model with 'nchw' input layout | ||
export class EfficientNetFP16Nchw { | ||
constructor() { | ||
this.context_ = null; | ||
this.builder_ = null; | ||
this.graph_ = null; | ||
this.weightsUrl_ = './weights/efficientnet_fp16_nchw_optimized/'; | ||
this.inputOptions = { | ||
mean: [0.485, 0.456, 0.406], | ||
std: [0.229, 0.224, 0.225], | ||
norm: true, | ||
inputLayout: 'nchw', | ||
labelUrl: './labels/labels1000.txt', | ||
inputDimensions: [1, 3, 224, 224], | ||
dataType: 'float16', | ||
}; | ||
this.outputDimensions = [1, 1000]; | ||
} | ||
|
||
async buildConv_(input, name, blockName, clip = false, options = {}) { | ||
let prefix = ''; | ||
if (blockName !== '') { | ||
prefix = this.weightsUrl_ + 'block' + blockName + '_conv' + | ||
name; | ||
} else { | ||
prefix = this.weightsUrl_ + 'conv' + name; | ||
} | ||
const weightName = prefix + '_w.npy'; | ||
const weight = await buildConstantByNpy(this.builder_, weightName); | ||
|
||
const biasName = prefix + '_b.npy'; | ||
options.bias = await buildConstantByNpy(this.builder_, biasName); | ||
if (clip){ | ||
return this.builder_.clamp( | ||
this.builder_.conv2d(input, weight, options), | ||
{minValue: 0, maxValue: 6}); | ||
} | ||
return this.builder_.conv2d(input, weight, options); | ||
} | ||
|
||
async buildGemm_(input, name) { | ||
const prefix = this.weightsUrl_ + 'dense' + name; | ||
const weightName = prefix + '_w.npy'; | ||
const weight = await buildConstantByNpy(this.builder_, weightName); | ||
const biasName = prefix + '_b.npy'; | ||
const bias = await buildConstantByNpy(this.builder_, biasName); | ||
const options = | ||
{c: this.builder_.reshape(bias, [1, 1000]), bTranspose: true}; | ||
return this.builder_.gemm(input, weight, options); | ||
} | ||
|
||
async buildBottlenect_(input, blockName, group, pad = 1) { | ||
const conv1 = await this.buildConv_(input, '0', blockName, true); | ||
const conv2 = await this.buildConv_( | ||
conv1, '1', blockName, true, { groups: group, padding: [pad, pad, pad, pad] }); | ||
const conv3 = await this.buildConv_(conv2, '2', blockName); | ||
return this.builder_.add(conv3, input); | ||
} | ||
|
||
async load(contextOptions) { | ||
this.context_ = await navigator.ml.createContext(contextOptions); | ||
this.builder_ = new MLGraphBuilder(this.context_); | ||
const data = this.builder_.input('input', { dataType: this.inputOptions.dataType, dimensions: this.inputOptions.inputDimensions }); | ||
// Block 0 | ||
const conv1 = await this.buildConv_( | ||
data, '0', '0', true, { padding: [0, 1, 0, 1], strides: [2, 2] }); | ||
const conv2 = await this.buildConv_(conv1, '1', '0', true, { groups: 32, padding: [1, 1, 1, 1] }); | ||
const conv3 = await this.buildConv_(conv2, '2', '0'); | ||
|
||
// Block 1 | ||
const conv4 = await this.buildConv_(conv3, '0', '1', true); | ||
const conv5 = await this.buildConv_(conv4, '1', '1', true, { groups: 144, padding: [0, 1, 0, 1], strides: [2, 2] }); | ||
const conv6 = this.buildConv_(conv5, '2', '1'); | ||
|
||
// Block 2~4 | ||
const bottleneck2 = await this.buildBottlenect_(conv6, '2', 192); | ||
const bottleneck3 = await this.buildBottlenect_(bottleneck2, '3', 192); | ||
const bottleneck4 = await this.buildBottlenect_(bottleneck3, '4', 192); | ||
|
||
// Block 5 | ||
const conv7 = await this.buildConv_(bottleneck4, '0', '5', true); | ||
const conv8 = await this.buildConv_(conv7, '1', '5', true, { groups: 192, padding: [1, 2, 1, 2], strides: [2, 2] }); | ||
const conv9 = await this.buildConv_(conv8, '2', '5'); | ||
|
||
// Block 6~8 | ||
const bottleneck6 = await this.buildBottlenect_(conv9, '6', 336, 2); | ||
const bottleneck7 = await this.buildBottlenect_(bottleneck6, '7', 336, 2); | ||
const bottleneck8 = await this.buildBottlenect_(bottleneck7, '8', 336, 2); | ||
|
||
// Block 9 | ||
const conv10 = await this.buildConv_(bottleneck8, '0', '9', true); | ||
const conv11 = await this.buildConv_(conv10, '1', '9', true, { groups: 336, padding: [0, 1, 0, 1], strides: [2, 2] }); | ||
const conv12 = this.buildConv_(conv11, '2', '9'); | ||
|
||
// Block 10~14 | ||
const bottleneck10 = await this.buildBottlenect_(conv12, '10', 672); | ||
const bottleneck11 = await this.buildBottlenect_(bottleneck10, '11', 672); | ||
const bottleneck12 = await this.buildBottlenect_(bottleneck11, '12', 672); | ||
const bottleneck13 = await this.buildBottlenect_(bottleneck12, '13', 672); | ||
const bottleneck14 = await this.buildBottlenect_(bottleneck13, '14', 672); | ||
|
||
// Block 15 | ||
const conv13 = await this.buildConv_(bottleneck14, '0', '15', true); | ||
const conv14 = await this.buildConv_(conv13, '1', '15', true, { groups: 672, padding: [2, 2, 2, 2] }); | ||
const conv15 = this.buildConv_(conv14, '2', '15'); | ||
|
||
// Block 16~20 | ||
const bottleneck16 = await this.buildBottlenect_(conv15, '16', 960, 2); | ||
const bottleneck17 = await this.buildBottlenect_(bottleneck16, '17', 960, 2); | ||
const bottleneck18 = await this.buildBottlenect_(bottleneck17, '18', 960, 2); | ||
const bottleneck19 = await this.buildBottlenect_(bottleneck18, '19', 960, 2); | ||
const bottleneck20 = await this.buildBottlenect_(bottleneck19, '20', 960, 2); | ||
|
||
// Block 21 | ||
const conv16 = await this.buildConv_(bottleneck20, '0', '21', true); | ||
const conv17 = await this.buildConv_(conv16, '1', '21', true, { groups: 960, padding: [1, 2, 1, 2], strides: [2, 2] }); | ||
const conv18 = await this.buildConv_(conv17, '2', '21'); | ||
|
||
// Block 22~28 | ||
const bottleneck22 = await this.buildBottlenect_(conv18, '22', 1632, 2); | ||
const bottleneck23 = await this.buildBottlenect_(bottleneck22, '23', 1632, 2); | ||
const bottleneck24 = await this.buildBottlenect_(bottleneck23, '24', 1632, 2); | ||
const bottleneck25 = await this.buildBottlenect_(bottleneck24, '25', 1632, 2); | ||
const bottleneck26 = await this.buildBottlenect_(bottleneck25, '26', 1632, 2); | ||
const bottleneck27 = await this.buildBottlenect_(bottleneck26, '27', 1632, 2); | ||
const bottleneck28 = await this.buildBottlenect_(bottleneck27, '28', 1632, 2); | ||
|
||
// Block 29 | ||
const conv19 = await this.buildConv_(bottleneck28, '0', '29', true); | ||
const conv20 = await this.buildConv_(conv19, '1', '29', true, { groups: 1632, padding: [1, 1, 1, 1] }); | ||
const conv21 = await this.buildConv_(conv20, '2', '29'); | ||
|
||
const conv22 = this.buildConv_(conv21, '2', '', true); | ||
const pool1 = await this.builder_.averagePool2d(conv22); | ||
const reshape = await this.builder_.reshape(pool1, [1, 1028]); | ||
return this.buildGemm_(reshape, '0'); | ||
} | ||
|
||
async build(outputOperand) { | ||
this.graph_ = await this.builder_.build({'output': outputOperand}); | ||
} | ||
|
||
// Release the constant tensors of a model | ||
dispose() { | ||
// dispose() is only available in webnn-polyfill | ||
if (this.graph_ !== null && 'dispose' in this.graph_) { | ||
this.graph_.dispose(); | ||
} | ||
} | ||
|
||
async compute(inputBuffer, outputBuffer) { | ||
const inputs = {'input': inputBuffer}; | ||
const outputs = {'output': outputBuffer}; | ||
const results = await this.context_.compute(this.graph_, inputs, outputs); | ||
return results; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.