implement layerNormalization #60
Conversation
Thanks @mei1127, LGTM with some nits.
src/layer_normalization.js
Outdated
/**
 * Normalize the tensor values of input features using
 * [layer-Normalization]
Please update [layer-Normalization] with a link, and combine the two lines into one:
* Normalize the tensor values of input features using
* [layer-Normalization]
to
* Normalize the tensor values of input features using [layer-Normalization](https://arxiv.org/abs/1607.06450)
ok, thanks
src/layer_normalization.js
Outdated
 */
export function layerNormalization(input, {scale, bias, axes, epsilon=1e-5}) {
  validateLayerNormalizationParams(...arguments);
  console.log('axes :', axes);
Please remove this debugging console.log code.
ok
export function layerNormalization(input, {scale, bias, axes, epsilon=1e-5}) {
  validateLayerNormalizationParams(...arguments);
  console.log('axes :', axes);
  if (axes === undefined) {
Synced with @shiyi9801: here the default axes should be [1, ..., N-1], where N is the rank of the input.
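A minimal sketch of that default (the helper name is hypothetical, not the repository's actual code):

```javascript
// Default `axes` for layerNormalization: [1, ..., N-1], where N is the
// rank of the input tensor. Axis 0 (the batch dimension) is skipped and
// all remaining axes are normalized over.
function defaultAxes(rank) {
  return Array.from({length: rank - 1}, (_, i) => i + 1);
}
```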
src/layer_normalization.js
Outdated
const mean = reduceMean(input, reduceOptions);
const variance = reduceMean(pow(sub(input, mean), new Scalar(2)), reduceOptions);
output = div(sub(input, mean),
    pow(add(variance, new Scalar(epsilon)), new Scalar(0.5)));
Suggestion: the sqrt op was added, so we could invoke sqrt here for simpler usage.
ok
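The suggested simplification, sketched with plain arrays standing in for tensors (the real code calls the library's `sqrt`, `add`, `sub`, and `div` tensor ops, not `Math`):

```javascript
// Sketch of the normalization step, elementwise over a plain array.
// The change under discussion: Math.pow(x, 0.5) -> Math.sqrt(x).
function normalize(values, epsilon = 1e-5) {
  const mean = values.reduce((a, b) => a + b, 0) / values.length;
  const variance =
      values.reduce((a, b) => a + (b - mean) ** 2, 0) / values.length;
  // Before: Math.pow(variance + epsilon, 0.5); after: Math.sqrt(...).
  const denom = Math.sqrt(variance + epsilon);
  return values.map((v) => (v - mean) / denom);
}
```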
export function validateLayerNormalizationParams(input, {axes, scale, bias} = {}) {
  if (scale && axes ) {
    if (scale.rank !== axes.length) {
      throw new Error('DataError: the rank of scale is not equal to the size of axes.');
These error messages follow the Algorithm part of the layerNormalization op exactly; it would be a todo enhancement (low priority) to update the validation checking for other ops. @huningxin WDYT, thanks.
sounds good to me, thanks!
src/layer_normalization.js
Outdated
output = div(sub(input, mean),
    pow(add(variance, new Scalar(epsilon)), new Scalar(0.5)));
if (scale) {
  output = mul(output, reshape(scale, shape));
Here, `scale` and `bias` should not be directly reshaped to `shape`.
The reason is that `axes` can be out-of-order.
Let's say an input shape = [1, 2, 3, 4], with axes = [1, 3], scale shape = [2, 4], then scale can be directly reshaped to [1, 2, 1, 4], and it's broadcast compatible to [1, 2, 3, 4].
But with axes = [3, 1], scale shape = [4, 2], if it's directly reshaped to [1, 2, 1, 4], then the broadcast to [1, 2, 3, 4] will work incorrectly.
So we should transpose the `scale` and `bias` from [4, 2] to [2, 4] before reshaping them, and the transpose should follow how the `axes` are sorted into ascending order.
For example, with axes = [2, 0, 1] and scale shape = [4, 1, 2], we should transpose the scale following the permutation [1, 2, 0], so scale shape will be transposed to [1, 2, 4], and then we can reshape it to [1, 2, 1, 4].
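The permutation described here can be computed by argsorting `axes`; a sketch of the shape bookkeeping only, with hypothetical helper names (no tensor ops involved):

```javascript
// Permutation that sorts `axes` ascending: indices into `axes` ordered
// so that axes[perm[0]] <= axes[perm[1]] <= ...
function permutationToSortAxes(axes) {
  return axes.map((_, i) => i).sort((a, b) => axes[a] - axes[b]);
}

// Apply the permutation to the scale/bias shape, then place each dimension
// at its (sorted) axis position in a rank-`inputRank` shape of ones, which
// is then broadcast-compatible with the input.
function broadcastShape(axes, scaleShape, inputRank) {
  const perm = permutationToSortAxes(axes);
  const transposed = perm.map((i) => scaleShape[i]); // e.g. [4, 1, 2] -> [1, 2, 4]
  const sortedAxes = perm.map((i) => axes[i]);
  const shape = new Array(inputRank).fill(1);
  sortedAxes.forEach((axis, i) => { shape[axis] = transposed[i]; });
  return shape;
}
```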
thanks, let me think about it
One important comment and some nits, but otherwise LGTM.
src/layer_normalization.js
Outdated
// The output tensor has the same shape as the input tensor.
let output = new Tensor(input.shape);
const inputShape = input.shape;
const shape = new Array(input.rank).fill(1);
How about renaming `shape` to `compatibleShape` or `broadcastableShape`?
ok
src/layer_normalization.js
Outdated
const mean = reduceMean(input, reduceOptions);
const variance = reduceMean(pow(sub(input, mean), new Scalar(2)), reduceOptions);
output = div(sub(input, mean),
    pow(add(variance, new Scalar(epsilon)), new Scalar(0.5)));
Can we just use the `sqrt` operator? I think it's implemented now.
ok
src/lib/validate-input.js
Outdated
      throw new Error('DataError: the rank of scale is not equal to the size of axes.');
    }
  }
  if (bias && axes ) {
Suggested change:
- if (bias && axes ) {
+ if (bias && axes) {
ok
src/lib/validate-input.js
Outdated
for (let i = 0; i < axes.length; i++) {
  const axis = axes[i];
  if (axis >= input.rank) {
    throw new Error('DataError:the value of axis in axes should be smaller than input.rank');
Suggested change:
- throw new Error('DataError:the value of axis in axes should be smaller than input.rank');
+ throw new Error('DataError: the value of axis in axes should be smaller than input.rank');
ok
src/lib/validate-input.js
Outdated
  }
  const dim = input.shape[axis];
  if (scale) {
    if (scale.shape[i] == !dim) {
Suggested change:
- if (scale.shape[i] == !dim) {
+ if (scale.shape[i] !== dim) {
ok
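Putting the suggested fixes together, the validation might look roughly like this (a simplified sketch, not the repository's actual `src/lib/validate-input.js`; the shape-mismatch error messages are hypothetical, and `input`/`scale`/`bias` are assumed to be plain objects with `rank` and `shape` fields):

```javascript
// Sketch of layerNormalization parameter validation with the review fixes
// applied: consistent spacing, 'DataError: ' prefix, and `!==` instead of
// the original `== !dim` typo.
function validateLayerNormalizationParams(input, {axes, scale, bias} = {}) {
  if (scale && axes) {
    if (scale.rank !== axes.length) {
      throw new Error('DataError: the rank of scale is not equal to the size of axes.');
    }
  }
  if (bias && axes) {
    if (bias.rank !== axes.length) {
      throw new Error('DataError: the rank of bias is not equal to the size of axes.');
    }
  }
  if (axes) {
    for (let i = 0; i < axes.length; i++) {
      const axis = axes[i];
      if (axis >= input.rank) {
        throw new Error('DataError: the value of axis in axes should be smaller than input.rank');
      }
      const dim = input.shape[axis];
      // Fixed comparison: `!==`, not `== !dim`.
      if (scale && scale.shape[i] !== dim) {
        throw new Error('DataError: the shape of scale does not match input along axes.');
      }
      if (bias && bias.shape[i] !== dim) {
        throw new Error('DataError: the shape of bias does not match input along axes.');
      }
    }
  }
}
```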
test/layer_normalization_test.js
Outdated
  shape: [2, 2, 3],
  value: [1, 2, 3, 6, 5, 4, 3, 6, 24, -10, 0, 5],
},
[-2.4494713718167804, 0, 4.898942743633561,
Suggested change:
- [-2.4494713718167804, 0, 4.898942743633561,
+ [
+   -2.4494713718167804, 0, 4.898942743633561,

New line for consistency with elsewhere, like below:
[
  -1.4494713718167804,
  2,
  7.898942743633561,
  3.4494713718167804,
ok
test/layer_normalization_test.js
Outdated
1.4638475999719223, 0.8783085599831534, 0.29276951999438444,
-0.1645769966453613, 0.131661597316289, 1.9090931610861905,
-1.4482775704791793, -0.46081559060701155, 0.032915399329072226,
[-1.4638475999719223,
Oh, I wasn't talking about the wrapping of the numbers. Was just requesting a new line, like other places :b.
[ <---
-1.4638475999719223,
0.8783085599831534,
ok, thanks! I got it
src/layer_normalization.js
Outdated
const reduceOptions = {axes, keepDimensions: true};
const mean = reduceMean(input, reduceOptions);
const variance = reduceMean(pow(sub(input, mean), new Scalar(2)), reduceOptions);
output = div(sub(input, mean),
nit: these two lines can be combined into one.
src/layer_normalization.js
Outdated
 * @param {MLBatchNormalizationOptions} [options]
 * @return {Tensor}
 */
export function sortByValue(axes) {
The function description doesn't align with the sortByValue function below. Please add a description for each function.
And I suggest renaming `sortByValue` to `getIndexOfSortedValue` or `getOriginalIndexOfSortedValue`; that would be more understandable.
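For reference, such a function returns, for each position in ascending sorted order, the original index of that value; a sketch under the suggested (hypothetical) name:

```javascript
// Sketch of the suggested rename: for each value of `axes` taken in
// ascending order, return its index in the original array. For
// axes = [2, 0, 1] this yields [1, 2, 0], the permutation discussed above.
function getOriginalIndexOfSortedValue(axes) {
  return axes
      .map((value, index) => ({value, index}))
      .sort((a, b) => a.value - b.value)
      .map((entry) => entry.index);
}
```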
src/layer_normalization.js
Outdated
 * Normalize the tensor values of input features using
 * [layer-Normalization](https://arxiv.org/abs/1607.06450)
 * @param {Tensor} input
 * @param {Array} axes
Please remove this parameter description line.
src/layer_normalization.js
Outdated
 * [layer-Normalization](https://arxiv.org/abs/1607.06450)
 * @param {Tensor} input
 * @param {Array} axes
 * @param {MLBatchNormalizationOptions} [options]
Need to modify MLBatchNormalizationOptions to MLLayerNormalizationOptions.
Thanks for addressing these review comments.
Force-pushed from e6b9d71 to 5e5f729.
  );
});

it('layerNormalization Ascending order axis', function() {
nit: lowercase Ascending to ascending
Don't you just squash merge the end result anyway, which results in a clean merge history on the target branch? Unfortunately, forced pushes onto the active CR branch break this very useful functionality where reviewers can quickly see what's new since the previous time they reviewed. 😢
@BruceDai @huningxin @fdwr @shiyi9801 PTAL, thanks!