diff --git a/src/expand.js b/src/expand.js index cd3759a..2319adf 100644 --- a/src/expand.js +++ b/src/expand.js @@ -1,7 +1,7 @@ 'use strict'; import {broadcast, getBroadcastShape} from './lib/broadcast.js'; -import {Tensor} from '../src/lib/tensor.js'; +import {Scalar} from '../src/lib/tensor.js'; /** * Expand any dimension of size 1 of the input tensor to a @@ -10,14 +10,10 @@ import {Tensor} from '../src/lib/tensor.js'; * @param {Array} newShape * @return {Tensor} */ + + export function expand(input, newShape) { - if (input.shape.length === 0) { - const inputReshape = new Tensor([1], input.data); - const outputShape = getBroadcastShape(inputReshape.shape, newShape); - return broadcast(inputReshape, outputShape); - } else { - const inputReshape = new Tensor(input.shape, input.data); - const outputShape = getBroadcastShape(inputReshape.shape, newShape); - return broadcast(inputReshape, outputShape); - } + const inputReshape = input.shape.length === 0 ? new Scalar(input.data) : input; + const outputShape = getBroadcastShape(inputReshape.shape, newShape); + return broadcast(inputReshape, outputShape); }