Skip to content

Commit

Permalink
case permutation (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Oct 11, 2023
1 parent a5516a5 commit 7714635
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

isConstructorTree :: Case -> Node -> Bool
isConstructorTree c node = case run $ runFail $ go mempty node of
Just ctrsMap ->
all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault)
Nothing -> False
where
tags = map (^. caseBranchTag) (c ^. caseBranches)

checkOne :: HashMap Tag Int -> Tag -> Bool
checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of
Just 1 -> True
Nothing -> True
_ -> {- isImmediate -} False

checkDefault :: HashMap Tag Int -> Maybe Node -> Bool
checkDefault ctrsMap = \case
Just {} ->
-- or isImmediate
sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1
where
tags' = HashSet.fromList tags
Nothing -> True

go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int)
go ctrs = \case
NCtr Constr {..} -> return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs
NCase Case {..} -> foldM go ctrs (map (^. caseBranchBody) _caseBranches)
_ -> fail

convertNode :: Node -> Node
convertNode = dmap go
where
go :: Node -> Node
go node = case node of
NCase c@Case {..} -> case _caseValue of
NCase c'
| isConstructorTree c _caseValue ->
NCase
c'
{ _caseBranches = map permuteBranch (c' ^. caseBranches),
_caseDefault = fmap (mkBody c) (c' ^. caseDefault)
}
where
permuteBranch :: CaseBranch -> CaseBranch
permuteBranch br@CaseBranch {..} =
case shift _caseBranchBindersNum (NCase c {_caseValue = mkBottom'}) of
NCase cs ->
over caseBranchBody (mkBody cs) br
_ -> impossible

mkBody :: Case -> Node -> Node
mkBody cs n = NCase cs {_caseValue = n}
_ ->
node
_ -> node

casePermutation :: InfoTable -> InfoTable
casePermutation = mapAllNodes convertNode
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
Expand All @@ -19,6 +20,7 @@ optimize' CoreOptions {..} tab =
. lambdaFolding
. doInlining
. caseFolding
. casePermutation
. letFolding' (isInlineableLambda _optInliningDepth)
. lambdaFolding
. specializeArgs
Expand Down

0 comments on commit 7714635

Please sign in to comment.