implement layerNormalization #60

Merged
merged 1 commit on Dec 12, 2023
61 changes: 61 additions & 0 deletions src/layer_normalization.js
@@ -0,0 +1,61 @@
'use strict';

import {add, sub, mul, div, pow} from './binary.js';
import {reduceMean} from './reduce.js';
import {reshape} from './reshape.js';
import {sqrt} from './unary.js';
import {Tensor, Scalar} from './lib/tensor.js';
import {transpose} from './transpose.js';
import {validateLayerNormalizationParams} from './lib/validate-input.js';

/**
* Sort the indexes of the elements in the axes array
* based on their values and return the sorted index array
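* e.g. getIndexOfSortedValue([2, 0, 1]) returns [1, 2, 0].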
* @param {Array} axes
* @return {Array}
*/
export function getIndexOfSortedValue(axes) {
const sortedIndices = axes.map((_, index) => index);
sortedIndices.sort((a, b) => axes[a] - axes[b]);
return sortedIndices;
}

/**
* Normalize the tensor values of input features using
* [layer-Normalization](https://arxiv.org/abs/1607.06450)
* @param {Tensor} input
* @param {MLLayerNormalizationOptions} [options]
* @return {Tensor}
*/
export function layerNormalization(input, {scale, bias, axes, epsilon = 1e-5} = {}) {
validateLayerNormalizationParams(...arguments);
if (axes === undefined) {
Contributor: Synced with @shiyi9801; the default axes here should be [1, ..., N-1], where N is the rank of the input.

axes = Array.from({length: input.rank - 1}, (_, index) => index + 1);
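// e.g. for a rank-3 input this yields [1, 2], i.e. every dimension except dimension 0.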
}
const sortAxes = getIndexOfSortedValue(axes);
if (scale) {
scale = transpose(scale, {permutation: sortAxes});
}
if (bias) {
bias = transpose(bias, {permutation: sortAxes});
}
// The output tensor has the same shape as the input tensor.
let output = new Tensor(input.shape);
const inputShape = input.shape;
const compatibleShape = new Array(input.rank).fill(1);
for (let i = 0; i < axes.length; i++) {
const axis = axes[i];
compatibleShape[axis] = inputShape[axis];
}
const reduceOptions = {axes, keepDimensions: true};
const mean = reduceMean(input, reduceOptions);
const variance = reduceMean(pow(sub(input, mean), new Scalar(2)), reduceOptions);
output = div(sub(input, mean), sqrt(add(variance, new Scalar(epsilon))));
if (scale) {
output = mul(output, reshape(scale, compatibleShape));
}
if (bias) {
output = add(output, reshape(bias, compatibleShape));
}
return output;
}
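
For reference, a minimal usage sketch of the new op (an editorial illustration, not part of the diff; it assumes a script placed under src/ so the relative imports resolve, and uses the Tensor constructor the way the tests below do, passing a shape and a flat value array):

'use strict';

import {Tensor} from './lib/tensor.js';
import {layerNormalization} from './layer_normalization.js';

// Rank-2 input, so the default axes are [1] and each row is normalized independently.
const input = new Tensor([2, 3], [1, 2, 3, 3, 6, 24]);
const output = layerNormalization(input, {});
// output has shape [2, 3]; its first row is approximately [-1.2247, 0, 1.2247].
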
34 changes: 34 additions & 0 deletions src/lib/validate-input.js
@@ -33,6 +33,40 @@ export function validateBatchNormalizationParams(input, mean, variance,
check1DTensorWithSize(bias, dim, 'bias');
}

export function validateLayerNormalizationParams(input, {axes, scale, bias} = {}) {
if (scale && axes) {
if (scale.rank !== axes.length) {
throw new Error('DataError: the rank of scale is not equal to the size of axes.');
@BruceDai (Contributor), Dec 7, 2023: These error messages follow the Algorithm section of the layerNormalization op; updating the validation checks for the other ops would be a low-priority todo enhancement. @huningxin WDYT, thanks.

Contributor: Sounds good to me, thanks!

}
}
if (bias && axes) {
if (bias.rank !== axes.length) {
throw new Error('DataError: the rank of bias is not equal to the size of axes.');
}
}
if (axes) {
for (let i = 0; i < axes.length; i++) {
const axis = axes[i];
if (axis >= input.rank) {
throw new Error('DataError: the value of axis in axes should be smaller than input.rank');
}
const dim = input.shape[axis];
if (scale) {
if (scale.shape[i] !== dim) {
throw new Error(`The length ${scale.shape[i]} of the scale values is not equal to the ` +
`size ${dim} of the input dimension denoted by options.axis.`);
}
}
if (bias) {
if (bias.shape[i] !== dim) {
throw new Error(`The length ${bias.shape[i]} of the bias values is not equal to the ` +
`size ${dim} of the input dimension denoted by options.axis.`);
}
}
}
}
}
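
// Editorial sketch (not part of this diff): a call that the checks above reject.
// The scale tensor is rank 1 while axes has two entries, so the first check throws
// 'DataError: the rank of scale is not equal to the size of axes.':
//   validateLayerNormalizationParams(
//       new Tensor([2, 2, 3]), {axes: [1, 2], scale: new Tensor([3])});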

export function validateInstanceNormalizationParams(
input,
{
225 changes: 225 additions & 0 deletions test/layer_normalization_test.js
@@ -0,0 +1,225 @@
'use strict';

import {layerNormalization} from '../src/layer_normalization.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';


describe('test layerNormalization', function() {
function testLayerNorm(
input, expected, scale = undefined, bias = undefined, axes = undefined, options = {}) {
const inputTensor = new Tensor(input.shape, input.value);
if (scale) {
options.scale = new Tensor(scale.shape, scale.value);
}
if (bias) {
options.bias = new Tensor(bias.shape, bias.value);
}
if (axes) {
options.axes = axes;
}
const outputTensor = layerNormalization(inputTensor, options);
utils.checkShape(outputTensor, input.shape);
utils.checkValue(outputTensor, expected);
}

it('layerNormalization default 2D', function() {
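// Row [1, 2, 3] has mean 2 and variance 2/3, so (1 - 2) / sqrt(2/3 + 1e-5) is about -1.2247.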
testLayerNorm(
{
shape: [2, 3],
value: [1, 2, 3, 3, 6, 24],
},
[
-1.2247356859083902,
0,
1.2247356859083902,
-0.8626621354727335,
-0.5391638346704585,
1.4018259701431919,
],
);
});

it('layerNormalization default 3D', function() {
testLayerNorm(
{
shape: [2, 2, 3],
value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[
-1.4638475999719223,
-0.8783085599831534,
-0.29276951999438444,
1.4638475999719223,
0.8783085599831534,
0.29276951999438444,
-0.1645769966453613,
0.131661597316289,
1.9090931610861905,
-1.4482775704791793,
-0.46081559060701155,
0.032915399329072226,
],
);
});

it('layerNormalization default 3D with axes=[2]', function() {
testLayerNorm(
{
shape: [2, 2, 3],
value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[
-1.2247356859083902,
0,
1.2247356859083902,
1.2247356859083902,
0,
-1.2247356859083902,
-0.8626621354727335,
-0.5391638346704585,
1.4018259701431919,
-1.3363060377513567,
0.26726120755027133,
1.0690448302010853,
],
undefined,
undefined,
[2],
);
});

it('layerNormalization 3D with scale and axes=[2]', function() {
testLayerNorm(
{
shape: [2, 2, 3],
value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[
-2.4494713718167804,
0,
4.898942743633561,
2.4494713718167804,
0,
-4.898942743633561,
-1.725324270945467,
-1.6174915040113755,
5.6073038805727675,
-2.6726120755027134,
0.8017836226508139,
4.276179320804341,
],
{
shape: [3],
value: [2, 3, 4],
},
undefined,
[2],
);
});

it('layerNormalization 3D with scale and bias and axes=[2]', function() {
testLayerNorm(
{
shape: [2, 2, 3],
value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[
-1.4494713718167804,
2,
7.898942743633561,
3.4494713718167804,
2,
-1.8989427436335609,
-0.725324270945467,
0.38250849598862446,
8.607303880572768,
-1.6726120755027134,
2.801783622650814,
7.276179320804341,
],
{
shape: [3],
value: [2, 3, 4],
},
{
shape: [3],
value: [1, 2, 3],
},
[2],
);
});

it('layerNormalization 3D with epsilon', function() {
testLayerNorm(
{
shape: [2, 2, 3],
value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[
-1.4635992282002273,
-0.8781595369201364,
-0.29271984564004544,
1.4635992282002273,
0.8781595369201364,
0.29271984564004544,
-0.16457620229526346,
0.1316609618362107,
1.9090839466250555,
-1.4482705801983182,
-0.46081336642673765,
0.032915240459052655,
],
undefined,
undefined,
undefined,
{
epsilon: 1e-3,
},
);
});

it('layerNormalization test descending order axes', function() {
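// axes is [2, 1] (descending), so the scale tensor of shape [3, 2] is transposed to [2, 3] before being applied.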
testLayerNorm(
{
shape: [1, 2, 3],
value: [1, 2, 3, 4, 5, 6],
},
[
-1.4638475999719223,
-2.6349256799494603,
-1.4638475999719223,
0.5855390399887689,
3.5132342399326135,
8.783085599831534,
],
{
shape: [3, 2],
value: [1, 2, 3, 4, 5, 6],
},
undefined,
[2, 1],
);
});

it('layerNormalization Ascending order axis', function() {
Contributor: nit: lowercase Ascending to ascending

testLayerNorm(
{
shape: [1, 2, 3],
value: [1, 2, 3, 4, 5, 6],
},
[
-1.4638475999719223,
-1.7566171199663068,
-0.8783085599831533,
1.1710780799775378,
4.391542799915767,
8.783085599831534,
],
{
shape: [2, 3],
value: [1, 2, 3, 4, 5, 6],
},
undefined,
[1, 2],
);
});
});