Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement scatterElements #105

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,44 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}

export function validateScatterElementsParams(input, indices, updates, {axis = 0} = {}) {
const inputRank = input.rank;
const indicesRank = indices.rank;
const updatesRank = updates.rank;
if (inputRank < 1) {
throw new Error(`The input should be at least a 1-D tensor.`);
}
if (indicesRank < 1) {
throw new Error(`The indices should be at least a 1-D tensor.`);
}
if (updatesRank < 1) {
throw new Error(`The updates should be at least a 1-D tensor.`);
}
if (indicesRank !== inputRank) {
throw new Error(`Invalid indices value - indices rank should be equal to input rank.`);
}
if (updatesRank !== indicesRank) {
throw new Error(`Invalid updates value - updates rank should be equal to indices rank.`);
}
if (!updates.shape.every((value, index) => value === indices.shape[index])) {
throw new Error(`Invalid updates value - updates shape should be same as indices shape.`);
}
if (axis !== undefined) {
if (!Number.isInteger(axis) || axis < 0 || axis >= inputRank) {
throw new Error(
`The axis ${axis} should be an unsigned integer in the interval [0, ${inputRank}).`);
}
}
huningxin marked this conversation as resolved.
Show resolved Hide resolved
const axisSize = input.shape[axis];
for (let i = 0; i < sizeOfShape(indices.shape); ++i) {
const index = indices.getValueByIndex(i);
if (!Number.isInteger(index) || index < -axisSize || index >= axisSize) {
throw new Error(`Invalid indices value - it should be an integer in the interval ` +
`[${-axisSize}, ${axisSize - 1}]`);
}
}
}

export function validateTriangularParams(input, {diagonal = 0} = {}) {
const inputRank = input.rank;
if (inputRank < 2) {
Expand Down
33 changes: 33 additions & 0 deletions src/scatter_elements.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
'use strict';

import {sizeOfShape} from './lib/tensor.js';
import {validateScatterElementsParams} from './lib/validate-input.js';
import {identity} from './unary.js';

/**
* Create a copy of the input data, and then update its value to values specified by updates at
* specific index positions specified by indices.
* @param {Tensor} input
* @param {Tensor} indices
* @param {Tensor} updates
* @param {MLScatterOptions} [options]
* @return {Tensor}
*/
export function scatterElements(input, indices, updates, {axis = 0} = {}) {
validateScatterElementsParams(...arguments);
const output = identity(input);

for (let indicesIndex = 0; indicesIndex < sizeOfShape(indices.shape); ++indicesIndex) {
// output[indices[i, j, k, ...], j, k, ...] = updates[i, j, k, ...] // if axis == 0
// output[i, indices[i, j, k, ...], k, ...] = updates[i, j, k, ...] // if axis == 1
// output[i, j, indices[i, j, k, ...], ...] = updates[i, j, k, ...] // if axis == 2
const indicesLocation = indices.locationFromIndex(indicesIndex);
let indiceValue = indices.getValueByIndex(indicesIndex);
indiceValue = indiceValue < 0 ? indiceValue + input.shape[axis] : indiceValue;
const outputLocation = indicesLocation.slice();
outputLocation[axis] = indiceValue;
BruceDai marked this conversation as resolved.
Show resolved Hide resolved
output.setValueByLocation(outputLocation, updates.getValueByIndex(indicesIndex));
}

return output;
}
2 changes: 1 addition & 1 deletion src/unary.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export const log = (input) => unary(input, Math.log);
export const neg = (input) => unary(input, (x) => -1 * x);
export const sin = (input) => unary(input, Math.sin);
export const tan = (input) => unary(input, Math.tan);
export const copy = (input) => unary(input, (x) => x);
export const identity = (input) => unary(input, (x) => x);
export const reciprocal = (input) => unary(input, (x) => 1 / x);
export const sqrt = (input) => unary(input, Math.sqrt);
export const erf = (input) => unary(input, erfKernel);
Expand Down
124 changes: 124 additions & 0 deletions test/scatter_elements_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
'use strict';

import {scatterElements} from '../src/scatter_elements.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test scatterElements', function() {
function testScatterElements(input, indices, updates, expected, options = {}) {
const inputTensor = new Tensor(input.shape, input.data);
const indicesTensor = new Tensor(indices.shape, indices.data);
const updatesTensor = new Tensor(updates.shape, updates.data);
const outputTensor = scatterElements(inputTensor, indicesTensor, updatesTensor, options);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('scatterElements 2D default', function() {
const input = {
shape: [3, 3],
data: [
0, 0, 0,
0, 0, 0,
0, 0, 0,
],
};
const indices = {
shape: [2, 3],
data: [
1, 0, 2,
0, 2, 1,
],
};
const updates = {
shape: [2, 3],
data: [
1, 1.1, 1.2,
2, 2.1, 2.2,
],
};
const expected = {
shape: [3, 3],
data: [
2, 1.1, 0,
1, 0, 2.2,
0, 2.1, 1.2,
],
};
testScatterElements(input, indices, updates, expected);
});

it('scatterElements 2D, explicit axis=0', function() {
const input = {
shape: [3, 3],
data: [
0, 0, 0,
0, 0, 0,
0, 0, 0,
],
};
const indices = {
shape: [2, 3],
data: [
1, 0, 2,
0, 2, 1,
],
};
const updates = {
shape: [2, 3],
data: [
1, 1.1, 1.2,
2, 2.1, 2.2,
],
};
const expected = {
shape: [3, 3],
data: [
2, 1.1, 0,
1, 0, 2.2,
0, 2.1, 1.2,
],
};
testScatterElements(input, indices, updates, expected, {axis: 0});
});

it('scatterElements 2D, axis=1', function() {
const input = {
shape: [1, 5],
data: [1, 2, 3, 4, 5],
};
const indices = {
shape: [1, 2],
data: [1, 3],
};
const updates = {
shape: [1, 2],
data: [1.1, 2.1],
};
const expected = {
shape: [1, 5],
data: [1, 1.1, 3, 2.1, 5],
};
testScatterElements(input, indices, updates, expected, {axis: 1});
});

it('scatterElements 2D negative indices, axis=1', function() {
const input = {
shape: [1, 5],
data: [1, 2, 3, 4, 5],
};
const indices = {
shape: [1, 2],
data: [1, -2],
};
const updates = {
shape: [1, 2],
data: [1.1, 2.1],
};
const expected = {
shape: [1, 5],
data: [1, 1.1, 3, 2.1, 5],
};
testScatterElements(input, indices, updates, expected, {axis: 1});
});
});
12 changes: 6 additions & 6 deletions test/unary_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -764,16 +764,16 @@ describe('test unary', function() {
[3, 2, 2, 1]);
});

it('copy', function() {
it('identity', function() {
// 0D scalar
testUnary('copy', [1.4124068], [1.4124068], []);
testUnary('identity', [1.4124068], [1.4124068], []);
testUnary(
'copy',
'identity',
[1.4124068, 1.9740626, -0.06506752, 0.73539704],
[1.4124068, 1.9740626, -0.06506752, 0.73539704],
[4]);
testUnary(
'copy',
'identity',
[
1.4124068, 1.9740626, -0.06506752, 0.73539704,
-0.56439203, 0.89806247, 0.12939146, -0.34816208,
Expand All @@ -786,7 +786,7 @@ describe('test unary', function() {
],
[3, 4]);
testUnary(
'copy',
'identity',
[
1.4124068, 1.9740626,
-0.06506752, 0.73539704,
Expand All @@ -805,7 +805,7 @@ describe('test unary', function() {
],
[3, 2, 2]);
testUnary(
'copy',
'identity',
[
1.4124068, 1.9740626,
-0.06506752, 0.73539704,
Expand Down
Loading