diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 63cf96a..f98929e 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -611,6 +611,42 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) { } } +export function validateScatterElementsParams(input, indices, updates, {axis = 0} = {}) { + const inputRank = input.rank; + const indicesRank = indices.rank; + const updatesRank = updates.rank; + if (inputRank < 1) { + throw new Error(`The input should be at least a 1-D tensor.`); + } + if (indicesRank < 1) { + throw new Error(`The indices should be at least a 1-D tensor.`); + } + if (updatesRank < 1) { + throw new Error(`The updates should be at least a 1-D tensor.`); + } + if (indicesRank !== inputRank) { + throw new Error(`Invalid indices value - indices rank should be equal to input rank.`); + } + if (updatesRank !== indicesRank) { + throw new Error(`Invalid updates value - updates rank should be equal to indices rank.`); + } + if (!updates.shape.every((value, index) => value === indices.shape[index])) { + throw new Error(`Invalid updates value - updates shape should be same as indices shape.`); + } + if (!Number.isInteger(axis) || axis < 0 || axis >= inputRank) { + throw new Error( + `The axis ${axis} should be an unsigned integer in the interval [0, ${inputRank}).`); + } + const axisSize = input.shape[axis]; + for (let i = 0; i < sizeOfShape(indices.shape); ++i) { + const index = indices.getValueByIndex(i); + if (!Number.isInteger(index) || index < -axisSize || index >= axisSize) { + throw new Error(`Invalid indices value - it should be an integer in the interval ` + + `[${-axisSize}, ${axisSize - 1}]`); + } + } +} + export function validateTriangularParams(input, {diagonal = 0} = {}) { const inputRank = input.rank; if (inputRank < 2) { diff --git a/src/scatter_elements.js b/src/scatter_elements.js new file mode 100644 index 0000000..61b484b --- /dev/null +++ b/src/scatter_elements.js @@ -0,0 +1,33 @@ +'use strict'; + +import {sizeOfShape} from './lib/tensor.js'; +import {validateScatterElementsParams} from './lib/validate-input.js'; +import {identity} from './unary.js'; + +/** + * Create a copy of the input data, and then update its value to values specified by updates at + * specific index positions specified by indices. + * @param {Tensor} input + * @param {Tensor} indices + * @param {Tensor} updates + * @param {MLScatterOptions} [options] + * @return {Tensor} + */ +export function scatterElements(input, indices, updates, {axis = 0} = {}) { + validateScatterElementsParams(...arguments); + const output = identity(input); + + for (let indicesIndex = 0; indicesIndex < sizeOfShape(indices.shape); ++indicesIndex) { + // output[indices[i, j, k, ...], j, k, ...] = updates[i, j, k, ...] // if axis == 0 + // output[i, indices[i, j, k, ...], k, ...] = updates[i, j, k, ...] // if axis == 1 + // output[i, j, indices[i, j, k, ...], ...] = updates[i, j, k, ...] // if axis == 2 + const indicesLocation = indices.locationFromIndex(indicesIndex); + let indicesValue = indices.getValueByIndex(indicesIndex); + indicesValue = indicesValue < 0 ? indicesValue + input.shape[axis] : indicesValue; + const outputLocation = + [...indicesLocation.slice(0, axis), indicesValue, ...indicesLocation .slice(axis + 1)]; + output.setValueByLocation(outputLocation, updates.getValueByIndex(indicesIndex)); + } + + return output; +} diff --git a/src/unary.js b/src/unary.js index a55ee2f..737d190 100644 --- a/src/unary.js +++ b/src/unary.js @@ -45,7 +45,7 @@ export const log = (input) => unary(input, Math.log); export const neg = (input) => unary(input, (x) => -1 * x); export const sin = (input) => unary(input, Math.sin); export const tan = (input) => unary(input, Math.tan); -export const copy = (input) => unary(input, (x) => x); +export const identity = (input) => unary(input, (x) => x); export const reciprocal = (input) => unary(input, (x) => 1 / x); export const sqrt = (input) => unary(input, Math.sqrt); export const erf = (input) => unary(input, erfKernel); diff --git a/test/scatter_elements_test.js b/test/scatter_elements_test.js new file mode 100644 index 0000000..871d47e --- /dev/null +++ b/test/scatter_elements_test.js @@ -0,0 +1,124 @@ +'use strict'; + +import {scatterElements} from '../src/scatter_elements.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test scatterElements', function() { + function testScatterElements(input, indices, updates, expected, options = {}) { + const inputTensor = new Tensor(input.shape, input.data); + const indicesTensor = new Tensor(indices.shape, indices.data); + const updatesTensor = new Tensor(updates.shape, updates.data); + const outputTensor = scatterElements(inputTensor, indicesTensor, updatesTensor, options); + utils.checkShape(outputTensor, expected.shape); + utils.checkValue(outputTensor, expected.data); + } + + it('scatterElements 2D default', function() { + const input = { + shape: [3, 3], + data: [ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + ], + }; + const indices = { + shape: [2, 3], + data: [ + 1, 0, 2, + 0, 2, 1, + ], + }; + const updates = { + shape: [2, 3], + data: [ + 1, 1.1, 1.2, + 2, 2.1, 2.2, + ], + }; + const expected = { + shape: [3, 3], + data: [ + 2, 1.1, 0, + 1, 0, 2.2, + 0, 2.1, 1.2, + ], + }; + testScatterElements(input, indices, updates, expected); + }); + + it('scatterElements 2D, explicit axis=0', function() { + const input = { + shape: [3, 3], + data: [ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + ], + }; + const indices = { + shape: [2, 3], + data: [ + 1, 0, 2, + 0, 2, 1, + ], + }; + const updates = { + shape: [2, 3], + data: [ + 1, 1.1, 1.2, + 2, 2.1, 2.2, + ], + }; + const expected = { + shape: [3, 3], + data: [ + 2, 1.1, 0, + 1, 0, 2.2, + 0, 2.1, 1.2, + ], + }; + testScatterElements(input, indices, updates, expected, {axis: 0}); + }); + + it('scatterElements 2D, axis=1', function() { + const input = { + shape: [1, 5], + data: [1, 2, 3, 4, 5], + }; + const indices = { + shape: [1, 2], + data: [1, 3], + }; + const updates = { + shape: [1, 2], + data: [1.1, 2.1], + }; + const expected = { + shape: [1, 5], + data: [1, 1.1, 3, 2.1, 5], + }; + testScatterElements(input, indices, updates, expected, {axis: 1}); + }); + + it('scatterElements 2D negative indices, axis=1', function() { + const input = { + shape: [1, 5], + data: [1, 2, 3, 4, 5], + }; + const indices = { + shape: [1, 2], + data: [1, -2], + }; + const updates = { + shape: [1, 2], + data: [1.1, 2.1], + }; + const expected = { + shape: [1, 5], + data: [1, 1.1, 3, 2.1, 5], + }; + testScatterElements(input, indices, updates, expected, {axis: 1}); + }); +}); diff --git a/test/unary_test.js b/test/unary_test.js index 6a0d9f3..d392d9d 100644 --- a/test/unary_test.js +++ b/test/unary_test.js @@ -764,16 +764,16 @@ describe('test unary', function() { [3, 2, 2, 1]); }); - it('copy', function() { + it('identity', function() { // 0D scalar - testUnary('copy', [1.4124068], [1.4124068], []); + testUnary('identity', [1.4124068], [1.4124068], []); testUnary( - 'copy', + 'identity', [1.4124068, 1.9740626, -0.06506752, 0.73539704], [1.4124068, 1.9740626, -0.06506752, 0.73539704], [4]); testUnary( - 'copy', + 'identity', [ 1.4124068, 1.9740626, -0.06506752, 0.73539704, -0.56439203, 0.89806247, 0.12939146, -0.34816208, @@ -786,7 +786,7 @@ describe('test unary', function() { ], [3, 4]); testUnary( - 'copy', + 'identity', [ 1.4124068, 1.9740626, -0.06506752, 0.73539704, @@ -805,7 +805,7 @@ describe('test unary', function() { ], [3, 2, 2]); testUnary( - 'copy', + 'identity', [ 1.4124068, 1.9740626, -0.06506752, 0.73539704,