Skip to content

Commit

Permalink
Implement triangular
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Dec 18, 2023
1 parent 058aa77 commit 446e033
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,17 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}
}

export function validateTriangularParams(input, {diagonal = 0} = {}) {
const inputRank = input.rank;
if (inputRank !== 2) {
throw new Error('The input should be a 2-D tensor.');
}
const [i, j] = input.shape;
if (i !== j) {
throw new Error('The input should be a 2-D tensor of [N, N] shape.');
}
if (diagonal >= i || diagonal <= -i) {
throw new Error(`The diagonal should be in in the range [${1 - i}, ${i - 1}].`);
}
}
38 changes: 38 additions & 0 deletions src/triangular.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateTriangularParams} from './lib/validate-input.js';

/**
* Get retained boolean flag.
* @param {Array} location
* @param {Boolean} upper
* @param {Number} diagonal
* @return {Boolean}
*/
function isRetainedValue(location, upper, diagonal) {
const [i, j] = location;
return upper ? j >= i + diagonal : j <= i + diagonal;
}

/**
* Given a 2-D tensor (matrix), return a 2-D tensor containing either the upper or lower triangular
* part of the input tensor.
* @param {Tensor} input
* @param {MLTriangularOptions} [options]
* @return {Tensor}
*/
export function triangular(input, {upper = true, diagonal = 0} = {}) {
validateTriangularParams(...arguments);
const shapeOutput = input.shape.slice();
const output = new Tensor(shapeOutput);

for (let outputIndex = 0; outputIndex < sizeOfShape(shapeOutput); ++outputIndex) {
const outputLoc = output.locationFromIndex(outputIndex);
const retainedFlag = isRetainedValue(outputLoc, upper, diagonal);
const inputValue = retainedFlag ? input.getValueByLocation(outputLoc) : 0;
output.setValueByLocation(outputLoc, inputValue);
}

return output;
}
157 changes: 157 additions & 0 deletions test/triangular_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
'use strict';

import {triangular} from '../src/triangular.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test triangular', function() {
function testTriangular(input, expected, options = {}) {
const x = new Tensor(input.shape, input.data);
const y = triangular(x, options);
utils.checkShape(y, expected.shape);
utils.checkValue(y, expected.data);
}

it('triangular default', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
7, 1, 2,
0, 4, 8,
0, 0, 3,
],
},
);
});

it('triangular diagonal=1', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
0, 1, 2,
0, 0, 8,
0, 0, 0,
],
},
{
diagonal: 1,
},
);
});

it('triangular diagonal=-1', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
0, 6, 3,
],
},
{
diagonal: -1,
},
);
});

it('triangular upper=false', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
7, 0, 0,
9, 4, 0,
2, 6, 3,
],
},
{
upper: false,
},
);
});

it('triangular upper=false diagonal=1', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
7, 1, 0,
9, 4, 8,
2, 6, 3,
],
},
{
upper: false,
diagonal: 1,
},
);
});

it('triangular upper=false diagonal=-1', function() {
testTriangular(
{
shape: [3, 3],
data: [
7, 1, 2,
9, 4, 8,
2, 6, 3,
],
},
{
shape: [3, 3],
data: [
0, 0, 0,
9, 0, 0,
2, 6, 0,
],
},
{
upper: false,
diagonal: -1,
},
);
});
});

0 comments on commit 446e033

Please sign in to comment.