diff --git a/src/Juvix/Compiler/Asm/Pipeline.hs b/src/Juvix/Compiler/Asm/Pipeline.hs index 321a056da2..e124a84f56 100644 --- a/src/Juvix/Compiler/Asm/Pipeline.hs +++ b/src/Juvix/Compiler/Asm/Pipeline.hs @@ -20,7 +20,7 @@ toReg' = validate >=> filterUnreachable >=> computeStackUsage >=> computePreallo -- | Perform transformations on JuvixAsm necessary before the translation to -- Nockma toNockma' :: (Members '[Error AsmError, Reader Options] r) => InfoTable -> Sem r InfoTable -toNockma' = validate >=> filterUnreachable >=> computeTempHeight +toNockma' = validate >=> filterUnreachable toReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => InfoTable -> Sem r InfoTable toReg = mapReader fromEntryPoint . mapError (JuvixError @AsmError) . toReg' diff --git a/src/Juvix/Compiler/Asm/Transformation.hs b/src/Juvix/Compiler/Asm/Transformation.hs index b0771b3359..be382194d7 100644 --- a/src/Juvix/Compiler/Asm/Transformation.hs +++ b/src/Juvix/Compiler/Asm/Transformation.hs @@ -3,12 +3,10 @@ module Juvix.Compiler.Asm.Transformation module Juvix.Compiler.Asm.Transformation.Prealloc, module Juvix.Compiler.Asm.Transformation.Validate, module Juvix.Compiler.Asm.Transformation.FilterUnreachable, - module Juvix.Compiler.Asm.Transformation.TempHeight, ) where import Juvix.Compiler.Asm.Transformation.FilterUnreachable import Juvix.Compiler.Asm.Transformation.Prealloc import Juvix.Compiler.Asm.Transformation.StackUsage -import Juvix.Compiler.Asm.Transformation.TempHeight import Juvix.Compiler.Asm.Transformation.Validate diff --git a/src/Juvix/Compiler/Asm/Transformation/TempHeight.hs b/src/Juvix/Compiler/Asm/Transformation/TempHeight.hs deleted file mode 100644 index 1eab2d8909..0000000000 --- a/src/Juvix/Compiler/Asm/Transformation/TempHeight.hs +++ /dev/null @@ -1,75 +0,0 @@ -module Juvix.Compiler.Asm.Transformation.TempHeight where - -import Juvix.Compiler.Asm.Transformation.Base - -computeFunctionTempHeight :: - (Member (Error AsmError) r) => - InfoTable -> - FunctionInfo -> - Sem r FunctionInfo -computeFunctionTempHeight tab fi = do - ps :: [Command] <- recurseS sig (fi ^. functionCode) - return (set functionCode ps fi) - where - sig :: RecursorSig StackInfo r Command - sig = - RecursorSig - { _recursorInfoTable = tab, - _recurseInstr = goInstr, - _recurseBranch = goBranch, - _recurseCase = goCase, - _recurseSave = goSave - } - - goInstr :: StackInfo -> CmdInstr -> Sem r Command - goInstr si cmd@(CmdInstr _ instr) = case instr of - Push (Ref (DRef (TempRef r))) -> - let h = si ^. stackInfoTempStackHeight - r' = set refTempTempHeight (Just h) r - instr' = Push (Ref (DRef (TempRef r'))) - in return (Instr (set cmdInstrInstruction instr' cmd)) - Push (Ref (ConstrRef field@Field {_fieldRef = TempRef r})) -> - let h = si ^. stackInfoTempStackHeight - r' = set refTempTempHeight (Just h) r - instr' = - Push - ( Ref - ( ConstrRef - field - { _fieldRef = TempRef r' - } - ) - ) - in return (Instr (set cmdInstrInstruction instr' cmd)) - _ -> return (Instr cmd) - - goCase :: StackInfo -> CmdCase -> [Code] -> Maybe Code -> Sem r Command - goCase _ cmd brs mdef = - return - ( Case - cmd - { _cmdCaseBranches = branches', - _cmdCaseDefault = mdef - } - ) - where - branches' :: [CaseBranch] - branches' = - [ set caseBranchCode newCode oldBr - | (oldBr, newCode) <- zipExact (cmd ^. cmdCaseBranches) brs - ] - - goBranch :: StackInfo -> CmdBranch -> Code -> Code -> Sem r Command - goBranch _ cmd t f = - return - ( Branch - cmd - { _cmdBranchTrue = t, - _cmdBranchFalse = f - } - ) - goSave :: StackInfo -> CmdSave -> Code -> Sem r Command - goSave _ cmd code = return (Save (set cmdSaveCode code cmd)) - -computeTempHeight :: (Member (Error AsmError) r) => InfoTable -> Sem r InfoTable -computeTempHeight tab = liftFunctionTransformation (computeFunctionTempHeight tab) tab diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId.hs b/src/Juvix/Compiler/Tree/Data/TransformationId.hs index 2e1b139721..2ead0ee432 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId.hs @@ -9,6 +9,7 @@ data TransformationId | IdentityU | IdentityD | Apply + | TempHeight deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -19,7 +20,7 @@ data PipelineId type TransformationLikeId = TransformationLikeId' TransformationId PipelineId toNockmaTransformations :: [TransformationId] -toNockmaTransformations = [Apply] +toNockmaTransformations = [Apply, TempHeight] toAsmTransformations :: [TransformationId] toAsmTransformations = [] @@ -31,6 +32,7 @@ instance TransformationId' TransformationId where IdentityU -> strIdentityU IdentityD -> strIdentityD Apply -> strApply + TempHeight -> strTempHeight instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs index fa473dc2ec..d5512d024a 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs @@ -19,3 +19,6 @@ strIdentityD = "identity-dmap" strApply :: Text strApply = "apply" + +strTempHeight :: Text +strTempHeight = "temp-height" diff --git a/src/Juvix/Compiler/Tree/Transformation.hs b/src/Juvix/Compiler/Tree/Transformation.hs index d4a30f0a76..714fe1a82c 100644 --- a/src/Juvix/Compiler/Tree/Transformation.hs +++ b/src/Juvix/Compiler/Tree/Transformation.hs @@ -9,6 +9,7 @@ import Juvix.Compiler.Tree.Data.TransformationId import Juvix.Compiler.Tree.Transformation.Apply import Juvix.Compiler.Tree.Transformation.Base import Juvix.Compiler.Tree.Transformation.Identity +import Juvix.Compiler.Tree.Transformation.TempHeight applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable applyTransformations ts tbl = foldM (flip appTrans) tbl ts @@ -19,3 +20,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts IdentityU -> return . identityU IdentityD -> return . identityD Apply -> return . computeApply + TempHeight -> return . computeTempHeight diff --git a/src/Juvix/Compiler/Tree/Transformation/TempHeight.hs b/src/Juvix/Compiler/Tree/Transformation/TempHeight.hs new file mode 100644 index 0000000000..15fd9d0f7a --- /dev/null +++ b/src/Juvix/Compiler/Tree/Transformation/TempHeight.hs @@ -0,0 +1,25 @@ +module Juvix.Compiler.Tree.Transformation.TempHeight where + +import Juvix.Compiler.Tree.Extra.Recursors +import Juvix.Compiler.Tree.Transformation.Base + +computeFunctionTempHeight :: Node -> Node +computeFunctionTempHeight = umapN go + where + go :: Int -> Node -> Node + go k = \case + MemRef (DRef (TempRef r)) -> + let r' = set refTempTempHeight (Just k) r + in MemRef $ DRef (TempRef r') + MemRef (ConstrRef field@Field {_fieldRef = TempRef r}) -> + let r' = set refTempTempHeight (Just k) r + in MemRef + ( ConstrRef + field + { _fieldRef = TempRef r' + } + ) + node -> node + +computeTempHeight :: InfoTable -> InfoTable +computeTempHeight = mapT (const computeFunctionTempHeight)