Skip to content

Commit

Permalink
Add a unit test checking proper custom pattern matching
Browse files Browse the repository at this point in the history
The test should covers basic pattern matching, including the
skipping of module instructions not present in the pattern.
  • Loading branch information
zacikpa committed Nov 9, 2023
1 parent 1e4f4d1 commit 07b6004
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions tests/unit_tests/simpll/DifferentialFunctionComparatorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include <Config.h>
#include <CustomPatternSet.h>
#include <DebugInfo.h>
#include <DifferentialFunctionComparator.h>
#include <ModuleComparator.h>
Expand Down Expand Up @@ -2051,3 +2052,93 @@ TEST_F(DifferentialFunctionComparatorTest, ReorderedBinaryOperationNeedLeaf) {

ASSERT_EQ(DiffComp->compare(), 0);
}

TEST_F(DifferentialFunctionComparatorTest, CustomPatternSkippingInstruction) {
// Test custom pattern matching and skipping of instructions therein.
//
// ; Old side of the pattern:
// define void @diffkemp.old.pattern() {
// %1 = sub i8 0, 1
// ret void
// }
//
// ; New side of the pattern:
// define void @diffkemp.new.pattern() {
// %1 = sub i8 1, 0
// %2 = sdiv i8 %1, %1
// ret void
// }
//
// ; Old compared function:
// define void @old.function() {
// %1 = sub i8 0, 1 ; matched
// %2 = add i8 0, 1 ; skipped
// ret void
// }
//
// ; New compared function:
// define void @new.function() {
// %1 = sub i8 1, 0 ; matched
// %2 = add i8 0, 1 ; skipped
// %3 = sdiv i8 %1, %1 ; matched
// ret void
// }

LLVMContext PatCtx;
auto PatMod = std::make_unique<Module>("PatternMod", PatCtx);

auto PatFL = Function::Create(
FunctionType::get(Type::getVoidTy(PatCtx), {}, false),
GlobalValue::ExternalLinkage,
"diffkemp.old.pattern",
PatMod.get());
auto PatFR = Function::Create(
FunctionType::get(Type::getVoidTy(PatCtx), {}, false),
GlobalValue::ExternalLinkage,
"diffkemp.new.pattern",
PatMod.get());

BasicBlock *PatBBL = BasicBlock::Create(PatCtx, "", PatFL);
BasicBlock *PatBBR = BasicBlock::Create(PatCtx, "", PatFR);

Constant *PatConstL1 = ConstantInt::get(Type::getInt8Ty(PatCtx), 0);
Constant *PatConstL2 = ConstantInt::get(Type::getInt8Ty(PatCtx), 1);
Constant *PatConstR1 = ConstantInt::get(Type::getInt8Ty(PatCtx), 0);
Constant *PatConstR2 = ConstantInt::get(Type::getInt8Ty(PatCtx), 1);

BinaryOperator::Create(
BinaryOperator::Sub, PatConstL1, PatConstL2, "", PatBBL);
auto PatSubR = BinaryOperator::Create(
BinaryOperator::Sub, PatConstR2, PatConstR1, "", PatBBR);

BinaryOperator::Create(BinaryOperator::SDiv, PatSubR, PatSubR, "", PatBBR);

ReturnInst::Create(PatCtx, nullptr, PatBBL);
ReturnInst::Create(PatCtx, nullptr, PatBBR);

CustomPatternSet PatSet;
PatSet.addPattern(std::move(PatMod));
DiffComp->addCustomPatternSet(&PatSet);

BasicBlock *BBL = BasicBlock::Create(CtxL, "", FL);
BasicBlock *BBR = BasicBlock::Create(CtxR, "", FR);

Constant *ConstL1 = ConstantInt::get(Type::getInt8Ty(CtxL), 0);
Constant *ConstL2 = ConstantInt::get(Type::getInt8Ty(CtxL), 1);
Constant *ConstR1 = ConstantInt::get(Type::getInt8Ty(CtxR), 0);
Constant *ConstR2 = ConstantInt::get(Type::getInt8Ty(CtxR), 1);

BinaryOperator::Create(BinaryOperator::Sub, ConstL1, ConstL2, "", BBL);
auto SubR = BinaryOperator::Create(
BinaryOperator::Sub, ConstR2, ConstR1, "", BBR);

BinaryOperator::Create(BinaryOperator::Add, ConstL1, ConstL2, "", BBL);
BinaryOperator::Create(BinaryOperator::Add, ConstR1, ConstR2, "", BBR);

BinaryOperator::Create(BinaryOperator::SDiv, SubR, SubR, "", BBR);

ReturnInst::Create(CtxL, nullptr, BBL);
ReturnInst::Create(CtxR, nullptr, BBR);

ASSERT_EQ(DiffComp->compare(), 0);
}

0 comments on commit 07b6004

Please sign in to comment.