Skip to content

Commit

Permalink
LLVMCodeBuilder: Add list range checks
Browse files Browse the repository at this point in the history
  • Loading branch information
adazem009 committed Jan 12, 2025
1 parent a3de164 commit 953b0d9
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 20 deletions.
63 changes: 58 additions & 5 deletions src/dev/engine/internal/llvm/llvmcodebuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,25 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
assert(step.args.size() == 1);
const auto &arg = step.args[0];
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());

// Range check
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
llvm::Value *index = castValue(arg.second, arg.first);
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
llvm::BasicBlock *removeBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(inRange, removeBlock, nextBlock);

// Remove
m_builder.SetInsertPoint(removeBlock);
index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
m_builder.CreateCall(resolve_list_remove(), { listPtr.ptr, index });
// NOTE: Removing doesn't deallocate (see List::removeAt()), so there's no need to update the data pointer
m_builder.CreateBr(nextBlock);

m_builder.SetInsertPoint(nextBlock);
break;
}

Expand Down Expand Up @@ -733,11 +749,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
m_builder.CreateStore(m_builder.CreateOr(dataPtrDirty, m_builder.CreateICmpEQ(allocatedSize, size)), listPtr.dataPtrDirty);

// Range check
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
llvm::Value *index = castValue(indexArg.second, indexArg.first);
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLE(index, size));
llvm::BasicBlock *insertBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(inRange, insertBlock, nextBlock);

// Insert
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
m_builder.SetInsertPoint(insertBlock);
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
llvm::Value *itemPtr = m_builder.CreateCall(resolve_list_insert_empty(), { listPtr.ptr, index });
createReusedValueStore(valueArg.second, itemPtr, type, listPtr.type);
m_builder.CreateBr(nextBlock);

m_builder.SetInsertPoint(nextBlock);
break;
}

Expand All @@ -747,9 +775,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
const auto &valueArg = step.args[1];
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());

// Range check
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
llvm::Value *index = castValue(indexArg.second, indexArg.first);
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
llvm::BasicBlock *replaceBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(inRange, replaceBlock, nextBlock);

// Replace
m_builder.SetInsertPoint(replaceBlock);
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
llvm::Value *itemPtr = getListItem(listPtr, index);
createValueStore(valueArg.second, itemPtr, type, listPtr.type);
m_builder.CreateBr(nextBlock);

auto &typeMap = m_scopeLists.back();

Expand All @@ -761,6 +803,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
typeMap[&listPtr] = listPtr.type;
}

m_builder.SetInsertPoint(nextBlock);
break;
}

Expand All @@ -777,8 +820,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
assert(step.args.size() == 1);
const auto &arg = step.args[0];
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
step.functionReturnReg->value = getListItem(listPtr, index);

llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
llvm::Value *index = castValue(arg.second, arg.first);
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));

LLVMConstantRegister nullReg(listPtr.type == Compiler::StaticType::Unknown ? Compiler::StaticType::Number : listPtr.type, Value());
llvm::Value *null = createValue(static_cast<LLVMRegister *>(static_cast<CompilerValue *>(&nullReg)));

index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
step.functionReturnReg->value = m_builder.CreateSelect(inRange, getListItem(listPtr, index), null);
step.functionReturnReg->setType(listPtr.type);
break;
}
Expand Down
117 changes: 102 additions & 15 deletions test/dev/llvm/llvmcodebuilder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1913,9 +1913,21 @@ TEST_F(LLVMCodeBuilderTest, RemoveFromList)
CompilerValue *v = m_builder->addConstValue(1);
m_builder->createListRemove(globalList.get(), v);

v = m_builder->addConstValue(-1);
m_builder->createListRemove(globalList.get(), v);

v = m_builder->addConstValue(3);
m_builder->createListRemove(globalList.get(), v);

v = m_builder->addConstValue(3);
m_builder->createListRemove(localList.get(), v);

v = m_builder->addConstValue(-1);
m_builder->createListRemove(localList.get(), v);

v = m_builder->addConstValue(4);
m_builder->createListRemove(localList.get(), v);

auto code = m_builder->finalize();
Script script(&sprite, nullptr, nullptr);
script.setCode(code);
Expand Down Expand Up @@ -2040,6 +2052,18 @@ TEST_F(LLVMCodeBuilderTest, InsertToList)
v2 = m_builder->addConstValue("hello world");
m_builder->createListInsert(localList.get(), v1, v2);

v1 = m_builder->addConstValue(3);
v2 = m_builder->addConstValue("test");
m_builder->createListInsert(localList.get(), v1, v2);

v1 = m_builder->addConstValue(-1);
v2 = m_builder->addConstValue(123);
m_builder->createListInsert(localList.get(), v1, v2);

v1 = m_builder->addConstValue(6);
v2 = m_builder->addConstValue(123);
m_builder->createListInsert(localList.get(), v1, v2);

auto code = m_builder->finalize();
Script script(&sprite, nullptr, nullptr);
script.setCode(code);
Expand All @@ -2049,7 +2073,7 @@ TEST_F(LLVMCodeBuilderTest, InsertToList)
code->run(ctx.get());

ASSERT_EQ(globalList->toString(), "1 2 1 test 3");
ASSERT_EQ(localList->toString(), "false hello world true");
ASSERT_EQ(localList->toString(), "false hello world true test");
}

TEST_F(LLVMCodeBuilderTest, ListReplace)
Expand Down Expand Up @@ -2099,6 +2123,14 @@ TEST_F(LLVMCodeBuilderTest, ListReplace)
v2 = m_builder->addConstValue("hello world");
m_builder->createListReplace(localList.get(), v1, v2);

v1 = m_builder->addConstValue(-1);
v2 = m_builder->addConstValue(123);
m_builder->createListReplace(localList.get(), v1, v2);

v1 = m_builder->addConstValue(5);
v2 = m_builder->addConstValue(123);
m_builder->createListReplace(localList.get(), v1, v2);

auto code = m_builder->finalize();
Script script(&sprite, nullptr, nullptr);
script.setCode(code);
Expand Down Expand Up @@ -2169,26 +2201,36 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
sprite.setEngine(&m_engine);
EXPECT_CALL(m_engine, stage()).WillRepeatedly(Return(&stage));

std::unordered_map<List *, std::string> strings;

auto globalList = std::make_shared<List>("", "");
stage.addList(globalList);

auto localList = std::make_shared<List>("", "");
sprite.addList(localList);
auto localList1 = std::make_shared<List>("", "");
sprite.addList(localList1);

auto localList2 = std::make_shared<List>("", "");
sprite.addList(localList2);

auto localList3 = std::make_shared<List>("", "");
sprite.addList(localList3);

globalList->append(1);
globalList->append(2);
globalList->append(3);

localList->append("Lorem");
localList->append("ipsum");
localList->append("dolor");
localList->append("sit");
strings[localList.get()] = localList->toString();
localList1->append("Lorem");
localList1->append("ipsum");
localList1->append("dolor");
localList1->append("sit");

localList2->append(-564.121);
localList2->append(4257.4);

localList3->append(true);
localList3->append(false);

createBuilder(&sprite, true);

// Global
CompilerValue *v = m_builder->addConstValue(2);
v = m_builder->addListItem(globalList.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
Expand All @@ -2201,24 +2243,67 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
v = m_builder->addListItem(globalList.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(-1);
v = m_builder->addListItem(globalList.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(3);
v = m_builder->addListItem(globalList.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

// Local 1
v = m_builder->addConstValue(0);
v = m_builder->addListItem(localList.get(), v);
v = m_builder->addListItem(localList1.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(2);
v = m_builder->addListItem(localList.get(), v);
v = m_builder->addListItem(localList1.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(3);
v = m_builder->addListItem(localList.get(), v);
v = m_builder->addListItem(localList1.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(-1);
v = m_builder->addListItem(localList1.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

v = m_builder->addConstValue(4);
v = m_builder->addListItem(localList1.get(), v);
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });

// Local 2
v = m_builder->addConstValue(-1);
v = m_builder->addListItem(localList2.get(), v);
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });

v = m_builder->addConstValue(2);
v = m_builder->addListItem(localList2.get(), v);
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });

// Local 3
v = m_builder->addConstValue(-1);
v = m_builder->addListItem(localList3.get(), v);
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });

v = m_builder->addConstValue(2);
v = m_builder->addListItem(localList3.get(), v);
m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v });

static const std::string expected =
"3\n"
"1\n"
"0\n"
"0\n"
"Lorem\n"
"dolor\n"
"sit\n";
"sit\n"
"0\n"
"0\n"
"0\n"
"0\n"
"0\n"
"0\n";

auto code = m_builder->finalize();
Script script(&sprite, nullptr, nullptr);
Expand All @@ -2231,7 +2316,9 @@ TEST_F(LLVMCodeBuilderTest, GetListItem)
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);

ASSERT_EQ(globalList->toString(), "1 test 3");
ASSERT_EQ(localList->toString(), "Lorem ipsum dolor sit");
ASSERT_EQ(localList1->toString(), "Lorem ipsum dolor sit");
ASSERT_EQ(localList2->toString(), "-564.121 4257.4");
ASSERT_EQ(localList3->toString(), "true false");
}

TEST_F(LLVMCodeBuilderTest, GetListSize)
Expand Down

0 comments on commit 953b0d9

Please sign in to comment.