Skip to content

Commit

Permalink
use for-loop
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiangxl committed Jan 29, 2025
1 parent a65cd9e commit 888d015
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 18 deletions.
35 changes: 17 additions & 18 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,31 +998,30 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
auto lhs = gen(bop->lhs());

if (print_inline_) {
if (exponent == 1) {
for (int i = 0; i < exponent; ++i) {
if (i != 0) {
code_ << " * ";
}
code_ << lhs;
} else if (exponent == 2) {
code_ << lhs << " * " << lhs;
} else if (exponent == 3) {
code_ << lhs << " * " << lhs << " * " << lhs;
}
} else {
indent() << gen(bop->out());
if (bop->out()->isScalar()) {
if (exponent == 1) {
code_ << " = " << lhs;
} else if (exponent == 2) {
code_ << " = " << lhs << " * " << lhs;
} else if (exponent == 3) {
code_ << " = " << lhs << " * " << lhs << " * " << lhs;
for (int i = 0; i < exponent; ++i) {
if (i == 0) {
code_ << " = " << lhs;
} else {
code_ << " * " << lhs;
}
}
} else {
code_ << "\n";
if (exponent == 1) {
indent() << kTab << "= " << lhs;
} else if (exponent == 2) {
indent() << kTab << "= " << lhs << "\n * " << lhs;
} else if (exponent == 3) {
indent() << kTab << "= " << lhs << "\n * " << lhs << "\n * " << lhs;
for (int i = 0; i < exponent; ++i) {
if (i == 0) {
code_ << "\n";
indent() << kTab << "= " << lhs;
} else {
indent() << "\n" << kTab << "* " << lhs;
}
}
}
}
Expand Down
59 changes: 59 additions & 0 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9413,6 +9413,65 @@ TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) {
}
}


TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({16, 32});
fusion.addInput(tv0);
auto tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
auto tv2 = makeSymbolicTensor(1);
fusion.addInput(tv2);

auto tv3 = set(tv2);
auto tv4 = broadcast(tv3, {false, true});
auto tv5 = add(tv1, tv4);
auto tv6 = add(tv0, tv5);
fusion.addOutput(tv6);

// Make the loop domains of tv3 and tv4 exact mapped with tv1's loop
// domain
scheduler_tools::scheduleLoopDomainsLike({tv3, tv4}, tv1->getLoopDomain());

EXPECT_TRUE(fusion.hasRegisteredExactMappings());

// tv3 and tv4 should have new cloned IDs that are exact mapped with
// tv1
auto registered_mappings = fusion.registeredExactMappings();
auto registered_mappings_it = registered_mappings.find(tv3->axis(1));
EXPECT_NE(registered_mappings_it, registered_mappings.end());
const auto& registered_ids = registered_mappings_it->second;
EXPECT_TRUE(registered_ids->has(tv4->axis(1)));
EXPECT_TRUE(registered_ids->has(tv1->axis(1)));

{
IdModel id_model(&fusion, /*build_graphs=*/false);
const auto& exact_graph = id_model.buildExactGraph();
for (auto tv : {tv3, tv4}) {
EXPECT_EQ(
exact_graph.toGroups(tv->getLoopDomain()),
exact_graph.toGroups(tv1->getLoopDomain()));
}
}

// tv0 and tv1 are exact mapped. Since tv0 has static extents,
// replaceSymbolicSizes will replace the symbolic extents of tv1 and tv2
// with the static extents of tv0.
replaceSymbolicSizes(&fusion);

// Check if the exact mapping is still alive
{
IdModel id_model(&fusion, /*build_graphs=*/false);
const auto& exact_graph = id_model.buildExactGraph();
for (auto tv : {tv3, tv4}) {
EXPECT_EQ(
exact_graph.toGroups(tv->getLoopDomain()),
exact_graph.toGroups(tv1->getLoopDomain()));
}
}
}
// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser

0 comments on commit 888d015

Please sign in to comment.