Skip to content

Commit

Permalink
Implement argMax and argMin
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Dec 12, 2023
1 parent a8798d2 commit 9856a81
Show file tree
Hide file tree
Showing 4 changed files with 533 additions and 27 deletions.
66 changes: 66 additions & 0 deletions src/arg_max_min.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js';
import {squeeze} from './squeeze.js';

/**
* Get the index location of the minimum or maxmium values of all the input values along the axes.
* @param {Tensor} input
* @param {Function} reduceFunc
* @param {MLArgMinMaxOptions} [options]
* @return {Tensor}
*/
export function argMaxMin(
input,
reduceFunc,
{
axes = null,
keepDimensions = false,
selectLastIndex = false,
} = {}) {
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const outputShape = input.shape.slice();

for (let i = 0; i < inpAxes.length; ++i) {
outputShape[inpAxes[i]] = 1;
}

let output = new Tensor(outputShape);
const tensor = reduceFunc(input, {axes: inpAxes, keepDimensions: true});

for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) {
const value = tensor.getValueByIndex(outputIndex);
const inputLocation = output.locationFromIndex(outputIndex);
const selectedArray = selectValuesToReduce(input, inpAxes, inputLocation);
const index =
selectLastIndex ? selectedArray.lastIndexOf(value) : selectedArray.indexOf(value);
output.setValueByIndex(outputIndex, index);
}

if (!keepDimensions) {
output = squeeze(output);
}

return output;
}

/**
* Get the index location of the maxmium values of all the input values along the axes.
* @param {Tensor} input
* @param {MLArgMinMaxOptions} [options]
* @return {Tensor}
*/
export function argMax(input, options = {}) {
return argMaxMin(input, reduceMax, options);
}

/**
* Get the index location of the minimum values of all the input values along the axes.
* @param {Tensor} input
* @param {MLArgMinMaxOptions} [options]
* @return {Tensor}
*/
export function argMin(input, options = {}) {
return argMaxMin(input, reduceMin, options);
}
2 changes: 1 addition & 1 deletion src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) {
}
}

export function validateReduceParams(input, _, {axes}) {
export function validateReduceParams(input, {axes}) {
if (axes.length > input.rank) {
throw new Error(`The length ${axes.length} of axes is bigger` +
`than input rank ${input.rank}.`);
Expand Down
62 changes: 36 additions & 26 deletions src/reduce.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,40 @@ import {abs, exp, log} from './unary.js';
import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js';
import {validateReduceParams} from './lib/validate-input.js';

export function selectValuesToReduce(input, axes, inputLocation) {
validateReduceParams(input, {axes});

const outputShape = input.shape.slice();
for (let i = 0; i < axes.length; ++i) {
outputShape[axes[i]] = 1;
}

// Calculate the "strides" across the reduction dimensions given in axes.
axes.sort((a, b) => a - b);
const reduceDims = axes.map((axis) => input.shape[axis]);
const reduceElements = sizeOfShape(reduceDims);
const reduceStrides = new Array(axes.length);
reduceStrides[reduceStrides.length - 1] = 1;
for (let i = reduceStrides.length - 2; i >= 0; --i) {
reduceStrides[i] = reduceStrides[i + 1] * reduceDims[i + 1];
}

const valuesToReduce = [];
// Find all values to reduce.
for (let reduceIndex = 0; reduceIndex < reduceElements; ++reduceIndex) {
// Calculate the input location given index of elements to reduce.
let remainingReduceIndex = reduceIndex;
for (let i = 0; i < axes.length; ++i) {
const axis = axes[i];
inputLocation[axis] = Math.floor(remainingReduceIndex / reduceStrides[i]);
remainingReduceIndex -= inputLocation[axis] * reduceStrides[i];
}
valuesToReduce.push(input.getValueByLocation(inputLocation));
}

return valuesToReduce;
}

/**
* Reduce the input along the dimensions given in axes.
* @param {Tensor} input
Expand All @@ -15,39 +49,15 @@ import {validateReduceParams} from './lib/validate-input.js';
*/
function reduce(input, reduceFunc, {keepDimensions = false, axes} = {}) {
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);

const outputShape = input.shape.slice();
for (let i = 0; i < inpAxes.length; ++i) {
outputShape[inpAxes[i]] = 1;
}

validateReduceParams(input, reduceFunc, {keepDimensions, axes: inpAxes});

// Calculate the "strides" across the reduction dimensions given in axes.
inpAxes.sort((a, b) => a - b);
const reduceDims = inpAxes.map((axis) => input.shape[axis]);
const reduceElements = sizeOfShape(reduceDims);
const reduceStrides = new Array(inpAxes.length);
reduceStrides[reduceStrides.length - 1] = 1;
for (let i = reduceStrides.length - 2; i >= 0; --i) {
reduceStrides[i] = reduceStrides[i + 1] * reduceDims[i + 1];
}

let output = new Tensor(outputShape);
for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) {
const valuesToReduce = [];
// Find all values to reduce.
for (let reduceIndex = 0; reduceIndex < reduceElements; ++reduceIndex) {
// Calculate the input location given index of elements to reduce.
const inputLocation = output.locationFromIndex(outputIndex);
let remainingReduceIndex = reduceIndex;
for (let i = 0; i < inpAxes.length; ++i) {
const axis = inpAxes[i];
inputLocation[axis] = Math.floor(remainingReduceIndex / reduceStrides[i]);
remainingReduceIndex -= inputLocation[axis] * reduceStrides[i];
}
valuesToReduce.push(input.getValueByLocation(inputLocation));
}
const inputLocation = output.locationFromIndex(outputIndex);
const valuesToReduce = selectValuesToReduce(input, inpAxes, inputLocation);
const outputValue = valuesToReduce.reduce(reduceFunc);
output.setValueByIndex(outputIndex, outputValue);
}
Expand Down
Loading

0 comments on commit 9856a81

Please sign in to comment.