Skip to content

Commit

Permalink
Address wanming's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mingmingtasd committed May 7, 2024
1 parent f353896 commit c53f0d3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
14 changes: 6 additions & 8 deletions image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ export class EfficientNetFP16Nchw {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.targetDataType_ = 'float16';
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/efficientnet_fp16_nchw_optimized/weights/';
this.inputOptions = {
Expand All @@ -18,7 +17,6 @@ export class EfficientNetFP16Nchw {
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
dataType: 'float32',
};
this.outputDimensions = [1, 1000];
}
Expand All @@ -32,9 +30,9 @@ export class EfficientNetFP16Nchw {
prefix = this.weightsUrl_ + 'conv' + name;
}
const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
this.targetDataType_ = 'float16');
'float16');
options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
this.targetDataType_ = 'float16');
'float16');
if (clip) {
return this.builder_.clamp(
this.builder_.conv2d(await input, await weight, options),
Expand All @@ -47,13 +45,13 @@ export class EfficientNetFP16Nchw {
const prefix = this.weightsUrl_ + 'dense' + name;
const weightName = prefix + '_w.npy';
const weight = buildConstantByNpy(this.builder_, weightName,
this.targetDataType_ = 'float16');
'float16');
const biasName = prefix + '_b.npy';
const bias = buildConstantByNpy(this.builder_, biasName,
this.targetDataType_ = 'float16');
'float16');
const options =
{c: this.builder_.reshape(await bias, [1, 1000])};
return this.builder_.gemm(await input, await weight, options);
return await this.builder_.gemm(await input, await weight, options);
}

async buildBottleneck_(input, blockName, group, pad = 1) {
Expand All @@ -78,7 +76,7 @@ export class EfficientNetFP16Nchw {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
let data = this.builder_.input('input', {
dataType: this.inputOptions.dataType,
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
data = this.builder_.cast(data, 'float16');
Expand Down
31 changes: 15 additions & 16 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';

// MobileNet V2 model with 'nchw' input layout
export class MobileNetV2Nchw {
constructor(targetDataType = 'float32') {
constructor(dataType = 'float32') {
this.context_ = null;
this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
this.targetDataType_ = targetDataType;
this.dataType_ = dataType;
this.weightsUrl_ = weightsOrigin();
if (this.targetDataType_ === 'float32') {
if (this.dataType_ === 'float32') {
this.weightsUrl_ += '/test-data/models/mobilenetv2_nchw/weights/';
} else if (this.targetDataType_ === 'float16') {
} else if (this.dataType_ === 'float16') {
this.weightsUrl_ +=
'/test-data/models/mobilenetv2_fp16_nchw_optimized/weights/';
} else {
throw new Error(`Unsupported dataType: ${this.targetDataType_}`);
throw new Error(`Unsupported dataType: ${this.dataType_}`);
}
this.inputOptions = {
mean: [0.485, 0.456, 0.406],
Expand All @@ -26,25 +26,24 @@ export class MobileNetV2Nchw {
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
dataType: 'float32',
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, relu6 = true, options = {}) {
let weights;
if (this.targetDataType_==='float32') {
if (this.dataType_ === 'float32') {
weights = buildConstantByNpy(this.builder_,
`${this.weightsUrl_}conv_${name}_weight.npy`, this.targetDataType_);
`${this.weightsUrl_}conv_${name}_weight.npy`);
options.bias = await buildConstantByNpy(this.builder_,
`${this.weightsUrl_}conv_${name}_bias.npy`, this.targetDataType_);
`${this.weightsUrl_}conv_${name}_bias.npy`);
} else {
weights = buildConstantByNpy(this.builder_,
`${this.weightsUrl_}w${name}.npy`, this.targetDataType_);
`${this.weightsUrl_}w${name}.npy`, this.dataType_);
// Only node 97 has no bias input
if (name !== '97') {
options.bias = await buildConstantByNpy(this.builder_,
`${this.weightsUrl_}b${name}.npy`, this.targetDataType_);
`${this.weightsUrl_}b${name}.npy`, this.dataType_);
}
}

Expand All @@ -67,10 +66,10 @@ export class MobileNetV2Nchw {
const prefix = this.weightsUrl_ + 'gemm_' + name;
const weightsName = prefix + '_weight.npy';
const weights = buildConstantByNpy(this.builder_, weightsName,
this.targetDataType_);
this.dataType_);
const biasName = prefix + '_bias.npy';
const bias = buildConstantByNpy(this.builder_, biasName,
this.targetDataType_);
this.dataType_);
const options = {c: await bias, bTranspose: true};
return this.builder_.gemm(await input, await weights, options);
}
Expand Down Expand Up @@ -99,10 +98,10 @@ export class MobileNetV2Nchw {
this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
let data = this.builder_.input('input', {
dataType: this.inputOptions.dataType,
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
if (this.targetDataType_ === 'float16') {
if (this.dataType_ === 'float16') {
data = this.builder_.cast(data, 'float16');
}
const conv0 = this.buildConv_(
Expand Down Expand Up @@ -144,7 +143,7 @@ export class MobileNetV2Nchw {
bottleneck14, ['90', '92', '94'], 960, 1, false);

const conv3 = this.buildConv_(bottleneck15, '95', true);
if (this.targetDataType_ == 'float32') {
if (this.dataType_ == 'float32') {
const pool = this.builder_.averagePool2d(await conv3);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = this.buildGemm_(reshape, '104');
Expand Down
12 changes: 5 additions & 7 deletions image_classification/resnet50v1_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ export class ResNet50V1FP16Nchw {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.targetDataType_ = 'float16';
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/resnet50v1_fp16_nchw_optimized/weights/';
this.inputOptions = {
Expand All @@ -18,7 +17,6 @@ export class ResNet50V1FP16Nchw {
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
dataType: 'float32',
};
this.outputDimensions = [1, 1000];
}
Expand All @@ -32,9 +30,9 @@ export class ResNet50V1FP16Nchw {
prefix = this.weightsUrl_ + 'conv' + name;
}
const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
this.targetDataType_);
'float16');
options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
this.targetDataType_);
'float16');
if (relu) {
options.activation = this.builder_.relu();
}
Expand All @@ -46,10 +44,10 @@ export class ResNet50V1FP16Nchw {
const prefix = this.weightsUrl_ + 'dense' + name;
const weightName = prefix + '_w.npy';
const weight = buildConstantByNpy(this.builder_, weightName,
this.targetDataType_);
'float16');
const biasName = prefix + '_b.npy';
const bias = buildConstantByNpy(this.builder_, biasName,
this.targetDataType_);
'float16');
const options =
{c: this.builder_.reshape(await bias, [1, 1000]), bTranspose: true};
return this.builder_.gemm(await input, await weight, options);
Expand Down Expand Up @@ -81,7 +79,7 @@ export class ResNet50V1FP16Nchw {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
let data = this.builder_.input('input', {
dataType: this.inputOptions.dataType,
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
data = this.builder_.cast(data, 'float16');
Expand Down

0 comments on commit c53f0d3

Please sign in to comment.