Skip to content

Commit

Permalink
modified expand.js
Browse files Browse the repository at this point in the history
  • Loading branch information
mei1127 committed Nov 27, 2023
1 parent 3feda15 commit 6026dba
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/expand.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'use strict';

import {broadcast, getBroadcastShape} from './lib/broadcast.js';
import {Tensor} from '../src/lib/tensor.js';
import {Scalar} from '../src/lib/tensor.js';

/**
* Expand any dimension of size 1 of the input tensor to a
Expand All @@ -10,14 +10,10 @@ import {Tensor} from '../src/lib/tensor.js';
* @param {Array} newShape
* @return {Tensor}
*/


export function expand(input, newShape) {
if (input.shape.length === 0) {
const inputReshape = new Tensor([1], input.data);
const outputShape = getBroadcastShape(inputReshape.shape, newShape);
return broadcast(inputReshape, outputShape);
} else {
const inputReshape = new Tensor(input.shape, input.data);
const outputShape = getBroadcastShape(inputReshape.shape, newShape);
return broadcast(inputReshape, outputShape);
}
const inputReshape = input.shape.length === 0 ? new Scalar(input.data) : input;
const outputShape = getBroadcastShape(inputReshape.shape, newShape);
return broadcast(inputReshape, outputShape);
}

0 comments on commit 6026dba

Please sign in to comment.