Skip to content

Commit

Permalink
optimized axis validation and outputLocation setting
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Oct 8, 2024
1 parent d0807b9 commit 8b20da9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
8 changes: 3 additions & 5 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,9 @@ export function validateScatterElementsParams(input, indices, updates, {axis = 0
if (!updates.shape.every((value, index) => value === indices.shape[index])) {
throw new Error(`Invalid updates value - updates shape should be same as indices shape.`);
}
if (axis !== undefined) {
if (!Number.isInteger(axis) || axis < 0 || axis >= inputRank) {
throw new Error(
`The axis ${axis} should be an unsigned integer in the interval [0, ${inputRank}).`);
}
if (!Number.isInteger(axis) || axis < 0 || axis >= inputRank) {
throw new Error(
`The axis ${axis} should be an unsigned integer in the interval [0, ${inputRank}).`);
}
const axisSize = input.shape[axis];
for (let i = 0; i < sizeOfShape(indices.shape); ++i) {
Expand Down
8 changes: 4 additions & 4 deletions src/scatter_elements.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ export function scatterElements(input, indices, updates, {axis = 0} = {}) {
// output[i, indices[i, j, k, ...], k, ...] = updates[i, j, k, ...] // if axis == 1
// output[i, j, indices[i, j, k, ...], ...] = updates[i, j, k, ...] // if axis == 2
const indicesLocation = indices.locationFromIndex(indicesIndex);
let indiceValue = indices.getValueByIndex(indicesIndex);
indiceValue = indiceValue < 0 ? indiceValue + input.shape[axis] : indiceValue;
const outputLocation = indicesLocation.slice();
outputLocation[axis] = indiceValue;
let indicesValue = indices.getValueByIndex(indicesIndex);
indicesValue = indicesValue < 0 ? indicesValue + input.shape[axis] : indicesValue;
const outputLocation =
[...indicesLocation.slice(0, axis), indicesValue, ...indicesLocation .slice(axis + 1)];
output.setValueByLocation(outputLocation, updates.getValueByIndex(indicesIndex));
}

Expand Down

0 comments on commit 8b20da9

Please sign in to comment.