From 953b0d9e2e50acd338d4688b359139a73a9d4528 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sun, 12 Jan 2025 01:00:46 +0100 Subject: [PATCH] LLVMCodeBuilder: Add list range checks --- .../engine/internal/llvm/llvmcodebuilder.cpp | 63 +++++++++- test/dev/llvm/llvmcodebuilder_test.cpp | 117 +++++++++++++++--- 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp index 0404f448..8a8f9fde 100644 --- a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp @@ -660,9 +660,25 @@ std::shared_ptr LLVMCodeBuilder::finalize() assert(step.args.size() == 1); const auto &arg = step.args[0]; const LLVMListPtr &listPtr = m_listPtrs[step.workList]; - llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty()); + + // Range check + llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0)); + llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); + size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); + llvm::Value *index = castValue(arg.second, arg.first); + llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + llvm::BasicBlock *removeBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(inRange, removeBlock, nextBlock); + + // Remove + m_builder.SetInsertPoint(removeBlock); + index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty()); m_builder.CreateCall(resolve_list_remove(), { listPtr.ptr, index }); // NOTE: Removing doesn't deallocate (see List::removeAt()), so there's no need to update the data pointer + m_builder.CreateBr(nextBlock); + + m_builder.SetInsertPoint(nextBlock); break; } @@ -733,11 +749,23 @@ std::shared_ptr LLVMCodeBuilder::finalize() llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); m_builder.CreateStore(m_builder.CreateOr(dataPtrDirty, m_builder.CreateICmpEQ(allocatedSize, size)), listPtr.dataPtrDirty); + // Range check + llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0)); + size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); + llvm::Value *index = castValue(indexArg.second, indexArg.first); + llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLE(index, size)); + llvm::BasicBlock *insertBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(inRange, insertBlock, nextBlock); + // Insert - llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty()); + m_builder.SetInsertPoint(insertBlock); + index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); llvm::Value *itemPtr = m_builder.CreateCall(resolve_list_insert_empty(), { listPtr.ptr, index }); createReusedValueStore(valueArg.second, itemPtr, type, listPtr.type); + m_builder.CreateBr(nextBlock); + m_builder.SetInsertPoint(nextBlock); break; } @@ -747,9 +775,23 @@ std::shared_ptr LLVMCodeBuilder::finalize() const auto &valueArg = step.args[1]; Compiler::StaticType type = optimizeRegisterType(valueArg.second); LLVMListPtr &listPtr = m_listPtrs[step.workList]; - llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty()); + + // Range check + llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0)); + llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); + size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); + llvm::Value *index = castValue(indexArg.second, indexArg.first); + llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + llvm::BasicBlock *replaceBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function); + m_builder.CreateCondBr(inRange, replaceBlock, nextBlock); + + // Replace + m_builder.SetInsertPoint(replaceBlock); + index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); llvm::Value *itemPtr = getListItem(listPtr, index); createValueStore(valueArg.second, itemPtr, type, listPtr.type); + m_builder.CreateBr(nextBlock); auto &typeMap = m_scopeLists.back(); @@ -761,6 +803,7 @@ std::shared_ptr LLVMCodeBuilder::finalize() typeMap[&listPtr] = listPtr.type; } + m_builder.SetInsertPoint(nextBlock); break; } @@ -777,8 +820,18 @@ std::shared_ptr LLVMCodeBuilder::finalize() assert(step.args.size() == 1); const auto &arg = step.args[0]; const LLVMListPtr &listPtr = m_listPtrs[step.workList]; - llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty()); - step.functionReturnReg->value = getListItem(listPtr, index); + + llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0)); + llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr); + size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy()); + llvm::Value *index = castValue(arg.second, arg.first); + llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size)); + + LLVMConstantRegister nullReg(listPtr.type == Compiler::StaticType::Unknown ? Compiler::StaticType::Number : listPtr.type, Value()); + llvm::Value *null = createValue(static_cast(static_cast(&nullReg))); + + index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty()); + step.functionReturnReg->value = m_builder.CreateSelect(inRange, getListItem(listPtr, index), null); step.functionReturnReg->setType(listPtr.type); break; } diff --git a/test/dev/llvm/llvmcodebuilder_test.cpp b/test/dev/llvm/llvmcodebuilder_test.cpp index c8b9c916..4ff82788 100644 --- a/test/dev/llvm/llvmcodebuilder_test.cpp +++ b/test/dev/llvm/llvmcodebuilder_test.cpp @@ -1913,9 +1913,21 @@ TEST_F(LLVMCodeBuilderTest, RemoveFromList) CompilerValue *v = m_builder->addConstValue(1); m_builder->createListRemove(globalList.get(), v); + v = m_builder->addConstValue(-1); + m_builder->createListRemove(globalList.get(), v); + + v = m_builder->addConstValue(3); + m_builder->createListRemove(globalList.get(), v); + v = m_builder->addConstValue(3); m_builder->createListRemove(localList.get(), v); + v = m_builder->addConstValue(-1); + m_builder->createListRemove(localList.get(), v); + + v = m_builder->addConstValue(4); + m_builder->createListRemove(localList.get(), v); + auto code = m_builder->finalize(); Script script(&sprite, nullptr, nullptr); script.setCode(code); @@ -2040,6 +2052,18 @@ TEST_F(LLVMCodeBuilderTest, InsertToList) v2 = m_builder->addConstValue("hello world"); m_builder->createListInsert(localList.get(), v1, v2); + v1 = m_builder->addConstValue(3); + v2 = m_builder->addConstValue("test"); + m_builder->createListInsert(localList.get(), v1, v2); + + v1 = m_builder->addConstValue(-1); + v2 = m_builder->addConstValue(123); + m_builder->createListInsert(localList.get(), v1, v2); + + v1 = m_builder->addConstValue(6); + v2 = m_builder->addConstValue(123); + m_builder->createListInsert(localList.get(), v1, v2); + auto code = m_builder->finalize(); Script script(&sprite, nullptr, nullptr); script.setCode(code); @@ -2049,7 +2073,7 @@ TEST_F(LLVMCodeBuilderTest, InsertToList) code->run(ctx.get()); ASSERT_EQ(globalList->toString(), "1 2 1 test 3"); - ASSERT_EQ(localList->toString(), "false hello world true"); + ASSERT_EQ(localList->toString(), "false hello world true test"); } TEST_F(LLVMCodeBuilderTest, ListReplace) @@ -2099,6 +2123,14 @@ TEST_F(LLVMCodeBuilderTest, ListReplace) v2 = m_builder->addConstValue("hello world"); m_builder->createListReplace(localList.get(), v1, v2); + v1 = m_builder->addConstValue(-1); + v2 = m_builder->addConstValue(123); + m_builder->createListReplace(localList.get(), v1, v2); + + v1 = m_builder->addConstValue(5); + v2 = m_builder->addConstValue(123); + m_builder->createListReplace(localList.get(), v1, v2); + auto code = m_builder->finalize(); Script script(&sprite, nullptr, nullptr); script.setCode(code); @@ -2169,26 +2201,36 @@ TEST_F(LLVMCodeBuilderTest, GetListItem) sprite.setEngine(&m_engine); EXPECT_CALL(m_engine, stage()).WillRepeatedly(Return(&stage)); - std::unordered_map strings; - auto globalList = std::make_shared("", ""); stage.addList(globalList); - auto localList = std::make_shared("", ""); - sprite.addList(localList); + auto localList1 = std::make_shared("", ""); + sprite.addList(localList1); + + auto localList2 = std::make_shared("", ""); + sprite.addList(localList2); + + auto localList3 = std::make_shared("", ""); + sprite.addList(localList3); globalList->append(1); globalList->append(2); globalList->append(3); - localList->append("Lorem"); - localList->append("ipsum"); - localList->append("dolor"); - localList->append("sit"); - strings[localList.get()] = localList->toString(); + localList1->append("Lorem"); + localList1->append("ipsum"); + localList1->append("dolor"); + localList1->append("sit"); + + localList2->append(-564.121); + localList2->append(4257.4); + + localList3->append(true); + localList3->append(false); createBuilder(&sprite, true); + // Global CompilerValue *v = m_builder->addConstValue(2); v = m_builder->addListItem(globalList.get(), v); m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); @@ -2201,24 +2243,67 @@ TEST_F(LLVMCodeBuilderTest, GetListItem) v = m_builder->addListItem(globalList.get(), v); m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + v = m_builder->addConstValue(-1); + v = m_builder->addListItem(globalList.get(), v); + m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + v = m_builder->addConstValue(3); + v = m_builder->addListItem(globalList.get(), v); + m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + // Local 1 v = m_builder->addConstValue(0); - v = m_builder->addListItem(localList.get(), v); + v = m_builder->addListItem(localList1.get(), v); m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); v = m_builder->addConstValue(2); - v = m_builder->addListItem(localList.get(), v); + v = m_builder->addListItem(localList1.get(), v); m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); v = m_builder->addConstValue(3); - v = m_builder->addListItem(localList.get(), v); + v = m_builder->addListItem(localList1.get(), v); + m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + v = m_builder->addConstValue(-1); + v = m_builder->addListItem(localList1.get(), v); m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + v = m_builder->addConstValue(4); + v = m_builder->addListItem(localList1.get(), v); + m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + // Local 2 + v = m_builder->addConstValue(-1); + v = m_builder->addListItem(localList2.get(), v); + m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v }); + + v = m_builder->addConstValue(2); + v = m_builder->addListItem(localList2.get(), v); + m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v }); + + // Local 3 + v = m_builder->addConstValue(-1); + v = m_builder->addListItem(localList3.get(), v); + m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v }); + + v = m_builder->addConstValue(2); + v = m_builder->addListItem(localList3.get(), v); + m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }, { v }); + static const std::string expected = "3\n" "1\n" + "0\n" + "0\n" "Lorem\n" "dolor\n" - "sit\n"; + "sit\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n"; auto code = m_builder->finalize(); Script script(&sprite, nullptr, nullptr); @@ -2231,7 +2316,9 @@ TEST_F(LLVMCodeBuilderTest, GetListItem) ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); ASSERT_EQ(globalList->toString(), "1 test 3"); - ASSERT_EQ(localList->toString(), "Lorem ipsum dolor sit"); + ASSERT_EQ(localList1->toString(), "Lorem ipsum dolor sit"); + ASSERT_EQ(localList2->toString(), "-564.121 4257.4"); + ASSERT_EQ(localList3->toString(), "true false"); } TEST_F(LLVMCodeBuilderTest, GetListSize)