Skip to content

Commit

Permalink
[js/common] allows using Uint16Array as data for float16 tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Feb 26, 2025
1 parent 5be82eb commit 802ea65
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 17 deletions.
9 changes: 3 additions & 6 deletions js/common/lib/tensor-impl-type-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
[Uint32Array, 'uint32'],
]);

// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
}

// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
// polyfill if available.
Expand All @@ -59,6 +53,9 @@ export const checkTypedArray = () => {
isTypedArrayChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;

// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array = (globalThis as any).Float16Array;
const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;

if (isBigInt64ArrayAvailable) {
Expand Down
7 changes: 7 additions & 0 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ export class Tensor implements TensorInterface {
} else {
throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`);
}
} else if (arg0 === 'float16' && arg1 instanceof Uint16Array && typedArrayConstructor !== Uint16Array) {
// when Float16Array is available and data is of type Uint16Array.
// We allow Uint16Array to be passed in as data for 'float16' tensor until Float16Array is generally
// supported in JavaScript environment.

// eslint-disable-next-line @typescript-eslint/no-explicit-any
data = new (globalThis as any).Float16Array(arg1.buffer, arg1.byteOffset, arg1.length);
} else {
throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
}
Expand Down
4 changes: 3 additions & 1 deletion js/common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"build": "node ./build.js",
"prepare": "npm run build",
"pretest": "tsc --build ./test",
"test": "mocha ./test/**/*.js --timeout 30000"
"test": "npm run test:default && npm run test:f16",
"test:default": "mocha ./test/**/*.js --timeout 30000",
"test:f16": "node --js-float16array ../node_modules/mocha/bin/mocha.js ./test/**/*.js --timeout 30000"
},
"devDependencies": {
"typedoc": "^0.25.7"
Expand Down
5 changes: 3 additions & 2 deletions js/common/test/unit-tests/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ export const NUMBER_COMPATIBLE_NUMERICAL_TYPES = [
export const BIGINT_TYPES = [['int64', BigInt64Array, true] as const, ['uint64', BigUint64Array, true] as const];

/**
* float16 type, data represented by Uint16Array
* float16 type, data represented by Uint16Array/Float16Array
*/
export const FLOAT16_TYPE = ['float16', Uint16Array, false] as const;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const FLOAT16_TYPE = ['float16', (globalThis as any).Float16Array ?? Uint16Array, false] as const;

/**
* A list of all numerical types.
Expand Down
62 changes: 62 additions & 0 deletions js/common/test/unit-tests/tensor/constructor-f16.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import assert from 'assert/strict';
import { Tensor } from 'onnxruntime-common';

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const globalF16 = (globalThis as any).Float16Array;

(globalF16 ? describe : describe.skip)('Tensor Constructor Tests - check type float16 (Float16Array available)', () => {
it("[float16] new Tensor('float16', numbers, dims): allow number array when Float16Array is available", () => {
const tensor = new Tensor('float16', [1, 2, 3, 4], [2, 2]);
assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'");
assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'");
assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1');
assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2');
assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3');
assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4');
assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4');
});

it("[float16] new Tensor('float16', float16array, dims): allow Float16Array when Float16Array is available", () => {
const tensor = new Tensor('float16', new globalF16([1, 2, 3, 4]), [2, 2]);
assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'");
assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'");
assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1');
assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2');
assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3');
assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4');
assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4');
});

it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array when Float16Array is available", () => {
const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]);
assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'");
assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'");
assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1');
assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2');
assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3');
assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4');
assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4');
});
});

(globalF16 ? describe.skip : describe)(
'Tensor Constructor Tests - check type float16 (Float16Array not available)',
() => {
it(
"[float16] new Tensor('float16', numbers, dims): " +
"expect to throw because it's not allowed to construct 'float16' tensor from number array",
() => {
assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError);
},
);

it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array", () => {
const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]);
assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'");
assert(tensor.data instanceof Uint16Array, "tensor.data should be an instance of 'Uint16Array'");
});
},
);
8 changes: 0 additions & 8 deletions js/common/test/unit-tests/tensor/constructor-type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,6 @@ describe('Tensor Constructor Tests - check types', () => {
assert(tensor.data instanceof Uint8Array, "tensor.data should be an instance of 'Uint8Array'");
});

it(
"[float16] new Tensor('float16', numbers, dims): " +
"expect to throw because it's not allowed to construct 'float16' tensor from number array",
() => {
assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError);
},
);

it("[badtype] new Tensor('a', numbers, dims): expect to throw because 'a' is an invalid type", () => {
assert.throws(() => new TensorAny('a', [1, 2, 3, 4], [2, 2]), TypeError);
});
Expand Down

0 comments on commit 802ea65

Please sign in to comment.