Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cast #63

Merged
merged 14 commits into from
Dec 21, 2023
1 change: 1 addition & 0 deletions .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module.exports = {
'chai': 'readonly',
'BigInt': 'readonly',
'BigInt64Array': 'readonly',
'BigUint64Array': 'readonly',
},
rules: {
'semi': 'error',
Expand Down
46 changes: 46 additions & 0 deletions src/cast.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
'use strict';

import {Tensor} from '../src/lib/tensor.js';

/**
* Cast each element in the input tensor to the target data type.
* @param {Tensor} input
* @return {Tensor}
*/

export function cast(input, type) {
let outputArray;
switch (type) {
case 'int8':
outputArray = new Int8Array(Array.from(input.data, (num) => (Math.round(num))));
break;
case 'uint8':
outputArray = new Uint8Array(Array.from(input.data, (num) => (Math.round(num))));
break;
case 'int32':
outputArray = new Int32Array(Array.from(input.data, (num) => (Math.round(num))));
break;
case 'uint32':
outputArray = new Uint32Array(Array.from(input.data, (num) => (Math.round(num))));
Copy link

@fdwr fdwr Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truncation toward zero is the expected casting behavior for float to int, not rounding to nearest. That's consistent with existing ML libraries I know of. e.g. float32 3.7 -> uint32 3. Isn't the default behavior of Uint32Array already the expected truncation toward zero behavior? Otherwise we need Math.trunc().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry for the misunderstanding. I've revised the code :)

break;
case 'int64':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix this error "TypeError: Cannot convert a BigInt value to a number" of CI.

outputArray = new BigInt64Array(Array.from(input.data, (num) => BigInt(Math.round(num))));
break;
case 'float32':
outputArray = new Float32Array(input.data);
break;
case 'float64':
outputArray = new Float64Array(input.data);
break;
case 'float16':
// todo
huningxin marked this conversation as resolved.
Show resolved Hide resolved
throw new Error('Unsupported output type: float16' );
case 'uint64':
// todo
huningxin marked this conversation as resolved.
Show resolved Hide resolved
throw new Error('Unsupported output type: uint64' );
default:
throw new Error('Unsupported output type: ' + type);
}
const output = new Tensor(input.shape, outputArray);
return output;
}
175 changes: 175 additions & 0 deletions test/cast_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
'use strict';

import {Tensor} from '../src/lib/tensor.js';
import {cast} from '../src/cast.js';
import * as utils from './utils.js';

describe('test cast', function() {
const InputDataType = {
int8: Int8Array,
uint8: Uint8Array,
int32: Int32Array,
uint32: Uint32Array,
int64: BigInt64Array,
float32: Float32Array,
float64: Float64Array,
};
function testCast(input, type, expected) {
let tensorInput;
if (input.type) {
tensorInput = new Tensor(input.shape, new InputDataType[input.type](input.data));
} else {
tensorInput = new Tensor(input.shape, input.data);
}
const outputTensor = cast(tensorInput, type);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('cast float64 to int8', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.75, 14, -14,
],
};
const expected = {
shape: [5],
data: [
0, 0, 4, 14, -14,
],
};
testCast(input, 'int8', expected);
});

it('cast float64 to uint8', function() {
const input = {
shape: [5],
data: [
0.25, 0.75, 3.75, 14, 15,
],
};
const expected = {
shape: [5],
data: [
0, 1, 4, 14, 15,
],
};
testCast(input, 'uint8', expected);
});

it('cast float64 to int32', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
0, 0, 3, 1234, -1234,
],
};
testCast(input, 'int32', expected);
});

it('cast float64 to uint32', function() {
const input = {
shape: [5],
data: [
0.75, 0.25, 3.21, 14, 15,
],
};
const expected = {
shape: [5],
data: [
1, 0, 3, 14, 15,
],
};
testCast(input, 'uint32', expected);
});

it('cast float64 to int64', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
0n, 0n, 3n, 1234n, -1234n,
],
};
testCast(input, 'int64', expected);
});

it('cast float64 to float32', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
-0.25, 0.25, 3.2100000381469727, 1234, -1234,
],
};
testCast(input, 'float32', expected);
});

it('cast int32 to float32', function() {
const input = {
shape: [5],
data: [
0, 1, -2, -3, 3,
],
type: 'int32',
};
const expected = {
shape: [5],
data: [
0, 1, -2, -3, 3,
],
};
testCast(input, 'float32', expected);
});

it('cast uint32 to float64', function() {
const input = {
shape: [5],
data: [
0, 1, 22, 33, 33,
],
type: 'uint32',
};
const expected = {
shape: [5],
data: [
0, 1, 22, 33, 33,
],
};
testCast(input, 'float64', expected);
});

it('cast float32 to float64', function() {
const input = {
shape: [5],
data: [
0, 0.1, 0.2, -300, 993,
],
type: 'float32',
};
const expected = {
shape: [5],
data: [
0, 0.10000000149011612, 0.20000000298023224, -300, 993,
],
};
testCast(input, 'float64', expected);
});
});
16 changes: 11 additions & 5 deletions test/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ function getBitwise(value) {
* false: The distance between a and b is far away from the given ULP distance.
*/
assert.isAlmostEqualUlp = function(a, b, nulp, message) {
const aBitwise = getBitwise(a);
const bBitwise = getBitwise(b);
let distance = aBitwise - bBitwise;
distance = distance >= 0 ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
if (typeof(a) == 'number') {
const aBitwise = getBitwise(a);
const bBitwise = getBitwise(b);
let distance = aBitwise - bBitwise;
distance = distance >= 0 ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
} else {
let distance = a - b;
distance = distance >= 0n ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
}
};

export function checkValue(tensor, expected, nulp = 0) {
Expand Down