From 446e0333f97e208448407a82bcc35200d24b21d7 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Mon, 18 Dec 2023 16:11:07 +0800 Subject: [PATCH] Implement triangular --- src/lib/validate-input.js | 14 ++++ src/triangular.js | 38 +++++++++ test/triangular_test.js | 157 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 src/triangular.js create mode 100644 test/triangular_test.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 7442b8a..218a99e 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -444,3 +444,17 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) { } } } + +export function validateTriangularParams(input, {diagonal = 0} = {}) { + const inputRank = input.rank; + if (inputRank !== 2) { + throw new Error('The input should be a 2-D tensor.'); + } + const [i, j] = input.shape; + if (i !== j) { + throw new Error('The input should be a 2-D tensor of [N, N] shape.'); + } + if (diagonal >= i || diagonal <= -i) { + throw new Error(`The diagonal should be in in the range [${1 - i}, ${i - 1}].`); + } +} diff --git a/src/triangular.js b/src/triangular.js new file mode 100644 index 0000000..56fb81c --- /dev/null +++ b/src/triangular.js @@ -0,0 +1,38 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateTriangularParams} from './lib/validate-input.js'; + +/** + * Get retained boolean flag. + * @param {Array} location + * @param {Boolean} upper + * @param {Number} diagonal + * @return {Boolean} + */ +function isRetainedValue(location, upper, diagonal) { + const [i, j] = location; + return upper ? j >= i + diagonal : j <= i + diagonal; +} + +/** + * Given a 2-D tensor (matrix), return a 2-D tensor containing either the upper or lower triangular + * part of the input tensor. + * @param {Tensor} input + * @param {MLTriangularOptions} [options] + * @return {Tensor} + */ +export function triangular(input, {upper = true, diagonal = 0} = {}) { + validateTriangularParams(...arguments); + const shapeOutput = input.shape.slice(); + const output = new Tensor(shapeOutput); + + for (let outputIndex = 0; outputIndex < sizeOfShape(shapeOutput); ++outputIndex) { + const outputLoc = output.locationFromIndex(outputIndex); + const retainedFlag = isRetainedValue(outputLoc, upper, diagonal); + const inputValue = retainedFlag ? input.getValueByLocation(outputLoc) : 0; + output.setValueByLocation(outputLoc, inputValue); + } + + return output; +} diff --git a/test/triangular_test.js b/test/triangular_test.js new file mode 100644 index 0000000..4c29eb6 --- /dev/null +++ b/test/triangular_test.js @@ -0,0 +1,157 @@ +'use strict'; + +import {triangular} from '../src/triangular.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test triangular', function() { + function testTriangular(input, expected, options = {}) { + const x = new Tensor(input.shape, input.data); + const y = triangular(x, options); + utils.checkShape(y, expected.shape); + utils.checkValue(y, expected.data); + } + + it('triangular default', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 7, 1, 2, + 0, 4, 8, + 0, 0, 3, + ], + }, + ); + }); + + it('triangular diagonal=1', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 0, 1, 2, + 0, 0, 8, + 0, 0, 0, + ], + }, + { + diagonal: 1, + }, + ); + }); + + it('triangular diagonal=-1', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 0, 6, 3, + ], + }, + { + diagonal: -1, + }, + ); + }); + + it('triangular upper=false', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 7, 0, 0, + 9, 4, 0, + 2, 6, 3, + ], + }, + { + upper: false, + }, + ); + }); + + it('triangular upper=false diagonal=1', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 7, 1, 0, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + upper: false, + diagonal: 1, + }, + ); + }); + + it('triangular upper=false diagonal=-1', function() { + testTriangular( + { + shape: [3, 3], + data: [ + 7, 1, 2, + 9, 4, 8, + 2, 6, 3, + ], + }, + { + shape: [3, 3], + data: [ + 0, 0, 0, + 9, 0, 0, + 2, 6, 0, + ], + }, + { + upper: false, + diagonal: -1, + }, + ); + }); +});