
Commit

response to comments
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
AlexandreEichenberger committed Sep 3, 2024
1 parent dab7633 commit eeea5cd
Showing 1 changed file with 1 addition and 76 deletions.
77 changes: 1 addition & 76 deletions src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -416,7 +416,6 @@ bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc,
int64_t totVL =
computeSuitableUnrollFactor(inputType, collapsedInnermostLoops, mix,
canOverCompute, simdLoopStaticTripCount, simdOnly);
#if 1 // new approach that enables parallelism
// Test if loop trip count is long enough for a parallel execution.
if (enableParallel) {
int64_t parId;
@@ -454,6 +453,7 @@ bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc,
// where each (possibly virtual) thread is responsible for one chunk.
// Second round computes the final reduction done by one thread.

// TODO: this should not be hardwired but gotten from an option.
int64_t tNum = 8;

// Round 1.
@@ -506,81 +506,6 @@ bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc,
output2, tmp1, tmp2, alloc1, alloc2, divisorForMean);
}

#else // old approach
// Compute type of small temporary reduction vector.
MemRefType outputType = MemRefType::get({}, elementType);
MemRefType redType = MemRefType::get({totVL}, elementType);
VectorType vecType = VectorType::get({totVL}, elementType);
SmallVector<Value, 2> inputs, tmps, outputs, initVals;
SmallVector<DimsExpr, 2> inputAFs, tmpAFs, outputAFs;
DimsExpr emptyAF;
DimsExpr zeroAF(1, zero);
// Initialize data for input (same for both reduction).
inputs.emplace_back(flatInput);
inputAFs.emplace_back(zeroAF);
// Init data for 1st reduction: tmp (redAlloc), output (alloc), and init.
tmps.emplace_back(create.mem.alignedAlloc(redType));
tmpAFs.emplace_back(zeroAF);
/*output*/ alloc1 = create.mem.alloc(outputType);
outputs.emplace_back(alloc1);
outputAFs.emplace_back(emptyAF);
initVals.emplace_back(getIdentityValue<ONNXReductionOp1>(
rewriter, create.getLoc(), elementType));
// Init data for 2nd reduction.
alloc2 = nullptr;
if (hasTwoRed) {
tmps.emplace_back(create.mem.alignedAlloc(redType));
tmpAFs.emplace_back(zeroAF);
/*output*/ alloc2 = create.mem.alloc(outputType);
outputs.emplace_back(alloc2);
outputAFs.emplace_back(emptyAF);
initVals.emplace_back(getIdentityValue<ONNXReductionOp2>(
rewriter, create.getLoc(), elementType));
}
create.krnl.simdReduceIE(
lb, ub, totVL, simdOnly, inputs, inputAFs, tmps, tmpAFs, outputs,
outputAFs, initVals,
/* reduction function */
[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals,
ArrayRef<Value> tmpVals, llvm::SmallVectorImpl<Value> &resultVals,
int64_t VL) {
Type type = (VL > 1) ? vecType : elementType;
// First reduction, enqueue result.
Value accumulatedVec1 = emitScalarOpFor<ONNXReductionOp1>(
rewriter, create.getLoc(), op, type, {tmpVals[0], inputVals[0]});
resultVals.emplace_back(accumulatedVec1);
if (hasTwoRed) {
// Has a second reduction, also enqueue result.
Value accumulatedVec2 = emitScalarOpFor<ONNXReductionOp2>(
rewriter, create.getLoc(), op, type, {tmpVals[1], inputVals[0]});
resultVals.emplace_back(accumulatedVec2);
}
},
/* post reduction function*/
[&](const KrnlBuilder &kb, ArrayRef<Value> tmpVals,
llvm::SmallVectorImpl<Value> &scalarOutputs, int64_t VL) {
// Perform horizontal reductions.
Value res1 = create.vec.reduction(
getCombiningKind<ONNXReductionOp1>(), tmpVals[0]);
scalarOutputs.emplace_back(res1);
if (hasTwoRed) {
Value res2 = create.vec.reduction(
getCombiningKind<ONNXReductionOp2>(), tmpVals[1]);
scalarOutputs.emplace_back(res2);
}
// Handle means if any.
if (divideByMean<ONNXReductionOp1>())
scalarOutputs[0] = create.math.div(scalarOutputs[0], divisorForMean);
if (hasTwoRed && divideByMean<ONNXReductionOp2>())
scalarOutputs[1] = create.math.div(scalarOutputs[1], divisorForMean);
});
#endif

if (hasTwoRed)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "fused reduction to a scalar");
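
As background for the retained code path: the comments in the second hunk describe a two-round scheme in which round 1 has each of the tNum (possibly virtual) threads reduce its own chunk into a private partial result, and round 2 has a single thread reduce those partials to the final scalar. The following is a minimal plain-C++ sketch of that idea only; it is not the Krnl/MLIR lowering, and the twoRoundSum helper and the std::thread-based chunking are illustrative assumptions.

// Illustration only: the two-round reduction idea from the comments above,
// expressed with std::thread instead of the Krnl parallel loops.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

double twoRoundSum(const std::vector<double> &data, int64_t tNum) {
  // Round 1: thread t reduces its contiguous chunk into partial[t].
  std::vector<double> partial(tNum, 0.0); // identity value of the reduction
  std::vector<std::thread> threads;
  int64_t n = static_cast<int64_t>(data.size());
  for (int64_t t = 0; t < tNum; ++t) {
    int64_t lb = t * n / tNum;
    int64_t ub = (t + 1) * n / tNum;
    threads.emplace_back([&partial, &data, t, lb, ub]() {
      partial[t] = std::accumulate(data.begin() + lb, data.begin() + ub, 0.0);
    });
  }
  for (auto &th : threads)
    th.join();
  // Round 2: one thread reduces the tNum partial results to the final scalar.
  return std::accumulate(partial.begin(), partial.end(), 0.0);
}

int main() {
  std::vector<double> data(1000);
  std::iota(data.begin(), data.end(), 1.0); // 1, 2, ..., 1000
  std::printf("%f\n", twoRoundSum(data, /*tNum=*/8)); // sum of 1..1000 is 500500
}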
Expand Down
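
The removed #else path drove create.krnl.simdReduceIE with a per-lane reduction callback and a post-reduction callback. As a rough mental model only (not the actual KrnlBuilder API; the simdStyleMean helper and the hard-coded VL standing in for totVL are assumptions of this sketch), the pattern is: keep VL running partial sums, accumulate the flattened input into them, collapse them with a horizontal reduction, and for a mean divide by the element count at the end.

// Illustration only: scalar C++ rendering of the accumulate-then-horizontally-
// reduce pattern that the removed simdReduceIE-based code expressed.
#include <array>
#include <cstdio>
#include <vector>

constexpr int VL = 4; // stand-in for the chosen vector length (totVL)

double simdStyleMean(const std::vector<double> &data) {
  std::array<double, VL> lanes{}; // per-lane partials, initialized to the identity (0)
  size_t n = data.size();
  size_t simdUb = n - n % VL;
  // Main loop ("reduction function"): lane j accumulates element i + j of each block.
  for (size_t i = 0; i < simdUb; i += VL)
    for (int j = 0; j < VL; ++j)
      lanes[j] += data[i + j];
  // Scalar remainder for trip counts that are not a multiple of VL.
  double total = 0.0;
  for (size_t i = simdUb; i < n; ++i)
    total += data[i];
  // "Post reduction": horizontal reduction across lanes, then the mean's division.
  for (int j = 0; j < VL; ++j)
    total += lanes[j];
  return total / static_cast<double>(n);
}

int main() {
  std::vector<double> data{1, 2, 3, 4, 5, 6, 7};
  std::printf("%f\n", simdStyleMean(data)); // prints 4.000000
}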
