Skip to content

Commit

Permalink
Update OpenMP reduction detection for new ops
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii committed Feb 7, 2025
1 parent 56c00d3 commit b58b1ee
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
9 changes: 5 additions & 4 deletions codon/cir/transform/parallel/openmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
static void extractAssociativeOpChain(Value *v, const std::string &op,
types::Type *type,
std::vector<Value *> &result) {
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
auto *call = cast<CallInstr>(v);
extractAssociativeOpChain(call->front(), op, type, result);
extractAssociativeOpChain(call->back(), op, type, result);
Expand Down Expand Up @@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {

for (auto &rf : reductionFunctions) {
if (rf.method) {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
continue;
} else {
if (!util::isCallOf(item, rf.name,
Expand All @@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {

if (rf.method) {
std::vector<Value *> opChain;
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(),
opChain);
extractAssociativeOpChain(callRHS, rf.name, type, opChain);
if (opChain.size() < 2)
continue;

Expand Down
13 changes: 9 additions & 4 deletions codon/cir/util/irtools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,

unsigned i = 0;
for (auto *arg : *call) {
if (!arg->getType()->is(inputs[i++]))
if (inputs[i] && !arg->getType()->is(inputs[i]))
return false;
++i;
}

if (output && !value->getType()->is(output))
return false;

if (method &&
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0])))
return false;
if (method) {
if (inputs.empty() || !fn->getParentType())
return false;

if (inputs[0] && !fn->getParentType()->is(inputs[0]))
return false;
}

return true;
}
Expand Down
24 changes: 24 additions & 0 deletions test/transform/omp.codon
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,18 @@ def test_omp_reductions():
c = min(b, c)
assert c == -1.

c = 0.
@par
for i in L:
c += i # float-int op
assert c == expected(N, 0., float.__add__)

c = 0.
@par
for i in L:
c = i + c # int-float op
assert c == expected(N, 0., float.__add__)

# float32s
c = f32(0.)
# this one can give different results due to
Expand Down Expand Up @@ -479,6 +491,18 @@ def test_omp_reductions():
c = min(b, c)
assert c == f32(-1.)

c = f32(0.)
@par
for i in L[:12]:
c += i # float-int op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)

c = f32(0.)
@par
for i in L[:12]:
c = i + c # int-float op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)

x_add = 10.
x_min = inf
x_max = -inf
Expand Down

0 comments on commit b58b1ee

Please sign in to comment.