From a384042105ff0ad6acb5ef2060f6205a5063fbf9 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 16 Nov 2020 16:58:12 +0100 Subject: [PATCH] added alpha IR refactoring --- fast_brainfuck.lua | 439 +++++++++++++++++++++++++++++++++------------ 1 file changed, 324 insertions(+), 115 deletions(-) diff --git a/fast_brainfuck.lua b/fast_brainfuck.lua index bfa7d17..6a31e13 100644 --- a/fast_brainfuck.lua +++ b/fast_brainfuck.lua @@ -1,11 +1,20 @@ --usage : luajit fast_brainfuck.lua mandelbrot.bf if jit then jit.opt.start("loopunroll=100") end -local STATS = false -- set to true to print optimizations count for each pass +local STATS = true -- set to true to print optimizations count for each pass local vmSettings = { ram = 32768, cellType = "char", } + +local shouldCreateSubFunctions = false -- only use for HUGE programs because it slows down the code +local subFunctionMinimumMatches = 50 +local subFunctionPrefix = "lf_" + + +local subFunctions = {} + + local artithmeticsIns = { ["+"] = 1, ["-"] = -1, @@ -50,10 +59,140 @@ local IRToCode = { [UNROLLED_ASSIGNATION] = "data[i+%i] = data[i+%i] + (-(data[i]/%i))*%i ", [IFSTART] = "if (data[i] ~= 0) then ", [IFEND] = "end ", - [FUNC_CALL] = "%s()" + [FUNC_CALL] = "%s() " } + + + +-- it's low likely we find duplicable nested loops se we only search for single ones +local function nextNonNestingWhileLoop(IRList, curPos, maxPos) + local loopStart + + while (curPos <= maxPos) do + --print("nextNonNestingWhileLoop", curPos, maxPos) + if IRList[curPos][1] == LOOPSTART then + loopStart = curPos + local i = 1 -- 1 because we directly check the next instruction + + while (loopStart + i <= maxPos) do + local curIR = IRList[loopStart + i][1] + + if curIR == LOOPSTART then + i = i - 1 -- -1 because we need to analyze it in the next outter Loop iteration + break; + elseif curIR == LOOPEND then + local loopEnd = loopStart + i + local IRListOUTPUT = {} + while loopStart <= loopEnd do + table.insert(IRListOUTPUT, IRList[loopStart]) + loopStart = loopStart + 1 + end + + return loopStart, IRListOUTPUT + end + + i = i + 1 + end + + + + end + + curPos = curPos + 1 + end + + return nil, nil +end + +local function IREqual(IR1, IR2) + if IR1[1] ~= IR2[1] then return false end + local i = 1 + while (true) do + if IR1[i] ~= IR2[i] then return false end + if IR1[i] == nil then return true end + i = i + 1 + end + return true +end + +assert(IREqual({INC, 1, 2, 3}, {INC, 1, 2, 3}) == true) +assert(IREqual({INC, 1, 2}, {INC, 1, 2, 3}) == false) +assert(IREqual({INC, 1, 2, 3}, {INC, 1, 2}) == false) +assert(IREqual({INC, 1, 2, 3}, {INC, 1, 2, 4}) == false) +assert(IREqual({INC, 1, 2, 3}, {MOVE, 1, 2, 4}) == false) +assert(IREqual({INC}, {INC}) == true) + +local function findIRMatches(haystack, needle, startPos) + local foundCount = 0 + local i = startPos + local max = #haystack + local needlesize = #needle + + while (i <= max) do + local needleI = 0 + + while (needleI < needlesize and (i + needleI) <= max and IREqual(haystack[i + needleI - 1], needle[needleI + 1])) == true do + needleI = needleI + 1 + end + + if needleI == needlesize then + i = i + needlesize + foundCount = foundCount + 1 + else + i = i + 1 + end + end + + return foundCount +end + + +local function replaceIRs(haystack, needle, replaceBy, startPos) + local replacmentCount = 0 + local i = startPos + local max = #haystack + local needlesize = #needle + local replacmentSize = #replaceBy + + while (i <= max) do + local needleI = 0 + + while (needleI < needlesize and (i + needleI) <= max and IREqual(haystack[i + needleI], needle[needleI + 1])) do + needleI = needleI + 1 + end + + if needleI == needlesize then + local replaceByI = 1 + + --remove needle IR + while (replaceByI <= needlesize) do + table.remove(haystack, i) + replaceByI = replaceByI + 1 + end + + --and insert new IR + replaceByI = 0 + + while (replaceByI < replacmentSize) do + -- here we can do a ref copy, not a real copy as we don't plan to edit the instructions/IR content later + table.insert(haystack, i + replaceByI, replaceBy[replaceByI + 1]) + replaceByI = replaceByI + 1 + end + + max = max - needlesize + replacmentSize + i = i + replacmentSize + replacmentCount = replacmentCount + 1 + else + i = i + 1 + end + end + + return replacmentCount +end + + local function firstPassOptimization(instList) --[[ while data[i] ~= 0 do @@ -92,6 +231,110 @@ local function firstPassOptimization(instList) end +local function secondPassMemset(instList) + if type(rawget(_G, "jit")) ~= "table" then + if STATS then + print("memset() pass is DISABLED because ffi.fill is not available on this platform.") + end + return + end + --[[ + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + i = i + 1 + data[i] = 0 + + vvvvvvvvvvvvvv + ffi.fill(data + i, 9, 0) + i = i + 9 + it also might automerge with second i+i instruction and remove if sum is zero +]] + local i = 1 + local minimumAssignations = 2 + local max = #instList + local currentFindSize = 0 + local currentAssignation = 0 + local optimizationCount = 0 + + while (i <= max - 2) do + if instList[i][1] == MOVE and instList[i][2] == 1 and instList[i + 1][1] == ASSIGNATION then + currentFindSize = 1 + currentAssignation = instList[i + 1][2] + local i2 = i + 2 + + while (i2 <= max) do + local ptsShiftCandidate = instList[i2] + local dataAssignationCandidate = instList[i2 + 1] + + if ptsShiftCandidate[1] ~= MOVE or ptsShiftCandidate[2] ~= 1 or dataAssignationCandidate[1] ~= ASSIGNATION or dataAssignationCandidate[2] ~= currentAssignation then + -- create memset instruction + if currentFindSize < minimumAssignations then + i = i + (currentFindSize * 2) - 1 -- -1 because right after this batch could be another one, don't skip the first member + goto doubleBreakMemset + end + + local i3 = 0 + -- clear the instruction so you can replace them by the memset one + while (i3 < (currentFindSize * 2)) do + table.remove(instList, i) + i3 = i3 + 1 + end + + -- the assignation row may not have started with a pointer shift for some reasons, so let's cover this case + -- we handle the possible ptr+1 or just ptr as starting mem pos + if instList[i - 1][1] == ASSIGNATION and instList[i - 1][2] == currentAssignation then + i = i - 1 + table.remove(instList, i) + table.insert(instList, i, {MEMSET, 0, currentFindSize + 1, currentAssignation}) + max = max - (currentFindSize + 1) * 2 + else + table.insert(instList, i, {MEMSET, 1, currentFindSize , currentAssignation}) + max = max - (currentFindSize * 2 - 1) + end + + local nextIns = instList[i + 1] + -- folding with next possible ptr ins + if nextIns[1] == MOVE then + if nextIns[2] + currentFindSize == 0 then + table.remove(instList, i + 1) + i = i - 1 + else + nextIns[2] = nextIns[2] + currentFindSize + end + else + table.insert(instList, i + 1, {MOVE, currentFindSize}) + i = i + 1 + end + optimizationCount = optimizationCount + 1 + goto doubleBreakMemset + else + currentFindSize = currentFindSize + 1 + end + + i2 = i2 + 2 + end + + ::doubleBreakMemset:: + end + + i = i + 1 + end + if STATS then print("memset() pass : ", optimizationCount ) end +end local function thirdPassUnRolledAssignation(instList) --[[ @@ -213,111 +456,6 @@ local function thirdPassUnRolledAssignation(instList) end -local function secondPassMemset(instList) - if type(rawget(_G, "jit")) ~= "table" then - if STATS then - print("memset() pass is DISABLED because ffi.fill is not available on this platform.") - end - return - end - --[[ - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - i = i + 1 - data[i] = 0 - - vvvvvvvvvvvvvv - ffi.fill(data + i, 9, 0) - i = i + 9 - it also might automerge with second i+i instruction and remove if sum is zero -]] - local i = 1 - local minimumAssignations = 2 - local max = #instList - local currentFindSize = 0 - local currentAssignation = 0 - local optimizationCount = 0 - - while (i <= max - 2) do - if instList[i][1] == MOVE and instList[i][2] == 1 and instList[i + 1][1] == ASSIGNATION then - currentFindSize = 1 - currentAssignation = instList[i + 1][2] - local i2 = i + 2 - - while (i2 <= max) do - local ptsShiftCandidate = instList[i2] - local dataAssignationCandidate = instList[i2 + 1] - - if ptsShiftCandidate[1] ~= MOVE or ptsShiftCandidate[2] ~= 1 or dataAssignationCandidate[1] ~= ASSIGNATION or dataAssignationCandidate[2] ~= currentAssignation then - -- create memset instruction - if currentFindSize < minimumAssignations then - i = i + (currentFindSize * 2) - 1 -- -1 because right after this batch could be another one, don't skip the first member - goto doubleBreakMemset - end - - local i3 = 0 - -- clear the instruction so you can replace them by the memset one - while (i3 < (currentFindSize * 2)) do - table.remove(instList, i) - i3 = i3 + 1 - end - - -- the assignation row may not have started with a pointer shift for some reasons, so let's cover this case - -- we handle the possible ptr+1 or just ptr as starting mem pos - if instList[i - 1][1] == ASSIGNATION and instList[i - 1][2] == currentAssignation then - i = i - 1 - table.remove(instList, i) - table.insert(instList, i, {MEMSET, 0, currentFindSize + 1, currentAssignation}) - max = max - (currentFindSize + 1) * 2 - else - table.insert(instList, i, {MEMSET, 1, currentFindSize , currentAssignation}) - max = max - (currentFindSize * 2 - 1) - end - - local nextIns = instList[i + 1] - -- folding with next possible ptr ins - if nextIns[1] == MOVE then - if nextIns[2] + currentFindSize == 0 then - table.remove(instList, i + 1) - i = i - 1 - else - nextIns[2] = nextIns[2] + currentFindSize - end - else - table.insert(instList, i + 1, {MOVE, currentFindSize}) - i = i + 1 - end - optimizationCount = optimizationCount + 1 - goto doubleBreakMemset - else - currentFindSize = currentFindSize + 1 - end - - i2 = i2 + 2 - end - - ::doubleBreakMemset:: - end - - i = i + 1 - end - if STATS then print("memset() pass : ", optimizationCount ) end -end - local brainfuck = function(s) s = s:gsub("[^%+%-<>%.,%[%]]+", "") -- remove new lines local instList = {} @@ -376,15 +514,11 @@ local brainfuck = function(s) secondPassMemset(instList) thirdPassUnRolledAssignation(instList) local insTableStr = {} - local i = 1 - local max = #instList + + -- lua 54 & jit compatiblity local unpack = unpack or table.unpack - while (i <= max) do - local IR = instList[i] - insTableStr[i] = string.format(IRToCode[IR[1]], select(2, unpack(IR))):gsub("%+%-", "-") - i = i + 1 - end + local code = [[local data; local ffi @@ -409,8 +543,83 @@ local r = function() return io.read(1):byte() end -]] .. table.concat(insTableStr) +]] + if shouldCreateSubFunctions then + + --remove the IR from main code + do + -- let's save the latest good candidate that could close/end the extraction chunk + + + local i = 1 + local max = #instList + local optReplaceCount = 0 + while (i <= max) do + local startPos, patternIRList = nextNonNestingWhileLoop(instList, i, max) + if startPos ~= nil then + local matches = findIRMatches(instList, patternIRList, startPos) + + local funcName = subFunctionPrefix .. tostring(patternIRList):sub(8) + + if matches >= subFunctionMinimumMatches then + subFunctions[funcName] = patternIRList + local replaceCount = replaceIRs(instList, patternIRList, {{FUNC_CALL, funcName}}, startPos) + assert(replaceCount == matches, "expected : " .. matches .. " got : " .. replaceCount) + optReplaceCount = optReplaceCount + replaceCount + max = max - ((replaceCount-1) * #patternIRList) + end + i = startPos + #patternIRList + 1 + else + break + end + end + if STATS then print("Refactoring pass : ", optReplaceCount) end + end + + + + + + + + + --output the extracted IR to Lua code + local subFunctionTableString = {} + + local subFunctionsNames = {} + for k, v in pairs(subFunctions) do + table.insert(subFunctionsNames, k) + end + + code = code .. "local " .. table.concat(subFunctionsNames, ", ") .. ";" + + for fName, IRtbl in pairs(subFunctions) do + local subFIR = {} + local i2 = 1 + local max = #IRtbl + while (i2 <= max) do + local IR = IRtbl[i2] + subFIR[i2] = string.format(IRToCode[IR[1]], select(2, unpack(IR))):gsub("%+%-", "-") + i2 = i2 + 1 + end + table.insert(subFunctionTableString, string.format(" %s = function() %s end ", fName, table.concat(subFIR))) + end + code = code .. table.concat(subFunctionTableString, "\n") + + end + + + i = 1 + local max = #instList + while (i <= max) do + local IR = instList[i] + insTableStr[i] = string.format(IRToCode[IR[1]], select(2, unpack(IR))):gsub("%+%-", "-") + i = i + 1 + end + + code = code .. table.concat(insTableStr, "\n") local loadstring = loadstring or load + --print(loadstring(code, string.format("Brainfuck Interpreter %p",instList ))) local status = loadstring(code, string.format("Brainfuck Interpreter %p",instList )) --print(code) do return end return status