Skip to content

Commit

Permalink
Comparison optimization (#2443)
Browse files Browse the repository at this point in the history
* Closes #2440
  • Loading branch information
lukaszcz authored Oct 12, 2023
1 parent 81f8339 commit 9e3e07d
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ data TransformationId
| SimplifyIfs
| SpecializeArgs
| CaseFolding
| CasePermutation
| FilterUnreachable
| OptPhaseEval
| OptPhaseExec
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ transformationText = \case
SimplifyIfs -> strSimplifyIfs
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
FilterUnreachable -> strFilterUnreachable
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec
Expand Down Expand Up @@ -210,6 +211,9 @@ strSpecializeArgs = "specialize-args"
strCaseFolding :: Text
strCaseFolding = "case-folding"

strCasePermutation :: Text
strCasePermutation = "case-permutation"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

Expand Down
28 changes: 23 additions & 5 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ isImmediate tab = \case
NVar {} -> True
NIdt {} -> True
NCst {} -> True
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' tab _constrTag ->
let paramsNum = length (takeWhile (isTypeConstr tab) (typeArgs (ci ^. constructorType)))
in length _constrArgs <= paramsNum
| otherwise -> all (isType tab mempty) _constrArgs
node@(NApp {}) ->
let (h, args) = unfoldApps' node
in case h of
Expand All @@ -125,6 +130,16 @@ isFailNode = \case
NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True
_ -> False

isTrueConstr :: Node -> Bool
isTrueConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True
_ -> False

isFalseConstr :: Node -> Bool
isFalseConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
_ -> False

freeVarsSortedMany :: [Node] -> Set Var
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)

Expand Down Expand Up @@ -356,26 +371,29 @@ builtinOpArgTypes = \case
OpFail -> [mkTypeString']

translateCase :: (Node -> Node -> Node -> a) -> a -> Case -> a
translateCase translateIf dflt Case {..} = case _caseBranches of
translateCase translateIfFun dflt Case {..} = case _caseBranches of
[br@CaseBranch {..}]
| _caseBranchTag == BuiltinTag TagTrue ->
translateIf _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault)
translateIfFun _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault)
[br@CaseBranch {..}]
| _caseBranchTag == BuiltinTag TagFalse ->
translateIf _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody)
translateIfFun _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody)
[br1, br2]
| br1 ^. caseBranchTag == BuiltinTag TagTrue
&& br2 ^. caseBranchTag == BuiltinTag TagFalse ->
translateIf _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody)
translateIfFun _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody)
| br1 ^. caseBranchTag == BuiltinTag TagFalse
&& br2 ^. caseBranchTag == BuiltinTag TagTrue ->
translateIf _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody)
translateIfFun _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody)
_ ->
dflt
where
branchFailure :: Node
branchFailure = mkBuiltinApp' OpFail [mkConstant' (ConstString "illegal `if` branch")]

translateCaseIf :: (Node -> Node -> Node -> a) -> Case -> a
translateCaseIf f = translateCase f impossible

checkDepth :: InfoTable -> BinderList Binder -> Int -> Node -> Bool
checkDepth tab bl 0 node = isType tab bl node
checkDepth tab bl d node = case node of
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Juvix.Compiler.Core.Transformation.NatToPrimInt
import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
Expand Down Expand Up @@ -81,6 +82,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SimplifyIfs -> return . simplifyIfs
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation
FilterUnreachable -> return . filterUnreachable
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize
Expand Down
70 changes: 70 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
module Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

isConstructorTree :: InfoTable -> Case -> Node -> Bool
isConstructorTree tab c node = case run $ runFail $ go mempty node of
Just ctrsMap ->
all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault)
Nothing -> False
where
tags = map (^. caseBranchTag) (c ^. caseBranches)
tagMap = HashMap.fromList (map (\br -> (br ^. caseBranchTag, br ^. caseBranchBody)) (c ^. caseBranches))

checkOne :: HashMap Tag Int -> Tag -> Bool
checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of
Just 1 -> True
Nothing -> True
_ -> isImmediate tab (fromJust $ HashMap.lookup tag tagMap)

checkDefault :: HashMap Tag Int -> Maybe Node -> Bool
checkDefault ctrsMap = \case
Just d ->
sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1
|| isImmediate tab d
where
tags' = HashSet.fromList tags
Nothing -> True

go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int)
go ctrs = \case
NCtr Constr {..} ->
return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs
NCase Case {..} -> do
ctrs' <- maybe (return ctrs) (go ctrs) _caseDefault
foldM go ctrs' (map (^. caseBranchBody) _caseBranches)
_ ->
fail

convertNode :: InfoTable -> Node -> Node
convertNode tab = dmap go
where
go :: Node -> Node
go node = case node of
NCase c@Case {..} -> case _caseValue of
NCase c'
| isConstructorTree tab c _caseValue ->
NCase
c'
{ _caseBranches = map permuteBranch (c' ^. caseBranches),
_caseDefault = fmap (mkBody c) (c' ^. caseDefault)
}
where
permuteBranch :: CaseBranch -> CaseBranch
permuteBranch br@CaseBranch {..} =
case shift _caseBranchBindersNum (NCase c {_caseValue = mkBottom'}) of
NCase cs ->
over caseBranchBody (mkBody cs) br
_ -> impossible

mkBody :: Case -> Node -> Node
mkBody cs n = NCase cs {_caseValue = n}
_ ->
node
_ -> node

casePermutation :: InfoTable -> InfoTable
casePermutation tab = mapAllNodes (convertNode tab) tab
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs

optimize' :: CoreOptions -> InfoTable -> InfoTable
Expand All @@ -18,7 +21,10 @@ optimize' CoreOptions {..} tab =
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. doInlining
. simplifyIfs' (_optOptimizationLevel <= 1)
. simplifyComparisons
. caseFolding
. casePermutation
. letFolding' (isInlineableLambda _optInliningDepth)
. lambdaFolding
. specializeArgs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) where

import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

convertNode :: InfoTable -> Node -> Node
convertNode tab = dmap go
where
boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive

go :: Node -> Node
go node = case node of
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf goIf c
_ -> node

goIf :: Node -> Node -> Node -> Node
goIf v b1 b2 = case v of
NBlt blt@BuiltinApp {..}
| OpEq <- _builtinAppOp ->
case b2 of
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf (goCmp v b1) c
NBlt blt'
| OpIntLt <- blt' ^. builtinAppOp,
blt ^. builtinAppArgs == blt' ^. builtinAppArgs ->
if
| isFalseConstr b1 ->
b2
| isTrueConstr b1 ->
NBlt blt {_builtinAppOp = OpIntLe}
| otherwise ->
mkIf' boolSym v b1 b2
_ ->
mkIf' boolSym v b1 b2
| OpIntLt <- _builtinAppOp,
isFalseConstr b1 && isTrueConstr b2 ->
NBlt
blt
{ _builtinAppOp = OpIntLe,
_builtinAppArgs = reverse _builtinAppArgs
}
| OpIntLe <- _builtinAppOp,
isFalseConstr b1 && isTrueConstr b2 ->
NBlt
blt
{ _builtinAppOp = OpIntLt,
_builtinAppArgs = reverse _builtinAppArgs
}
_ ->
mkIf' boolSym v b1 b2

goCmp :: Node -> Node -> Node -> Node -> Node -> Node
goCmp v b1 v' b1' b2' = case (v, v') of
(NBlt blt, NBlt blt')
| (OpEq, OpIntLt) <- (blt ^. builtinAppOp, blt' ^. builtinAppOp),
blt ^. builtinAppArgs == blt' ^. builtinAppArgs ->
if
| isFalseConstr b1 && isTrueConstr b1' && isFalseConstr b2' ->
v'
| isTrueConstr b1 && isFalseConstr b1' && isFalseConstr b2' ->
v
| isTrueConstr b1 && isTrueConstr b1' && isFalseConstr b2' ->
NBlt blt {_builtinAppOp = OpIntLe}
| isFalseConstr b1 && isFalseConstr b1' && isTrueConstr b2' ->
NBlt
blt
{ _builtinAppOp = OpIntLt,
_builtinAppArgs = reverse (blt ^. builtinAppArgs)
}
| isTrueConstr b1 && isFalseConstr b1' && isTrueConstr b2' ->
NBlt
blt
{ _builtinAppOp = OpIntLe,
_builtinAppArgs = reverse (blt ^. builtinAppArgs)
}
| isFalseConstr b1 && isTrueConstr b1' && isTrueConstr b2' ->
mkIf' boolSym v b1 b1'
| b1 == b2' ->
mkIf' boolSym v' b1' b2'
| b1' == b2' ->
mkIf' boolSym v b1 b1'
| b1 == b1' ->
mkIf' boolSym (NBlt blt {_builtinAppOp = OpIntLe}) b1 b2'
| otherwise ->
theIfs
_ ->
theIfs
where
theIfs = mkIf' boolSym v b1 (mkIf' boolSym v' b1' b2')

simplifyComparisons :: InfoTable -> InfoTable
simplifyComparisons tab = mapAllNodes (convertNode tab) tab
31 changes: 20 additions & 11 deletions src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs, simplifyIfs') where

import Data.List qualified as List
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

convertNode :: Node -> Node
convertNode = umap go
convertNode :: Bool -> InfoTable -> Node -> Node
convertNode bFast tab = umap go
where
boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive

go :: Node -> Node
go node = case node of
NCase Case {..}
| isCaseBoolean _caseBranches
&& all (== List.head bodies) (List.tail bodies) ->
List.head bodies
where
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf goIf c
_ -> node

goIf :: Node -> Node -> Node -> Node
goIf v b1 b2
| isTrueConstr b1 && isFalseConstr b2 = v
| bFast && isTrueConstr b1 && isTrueConstr b2 = b1
| bFast && isFalseConstr b1 && isFalseConstr b2 = b1
| not bFast && b1 == b2 = b1
| otherwise = mkIf' boolSym v b1 b2

simplifyIfs' :: Bool -> InfoTable -> InfoTable
simplifyIfs' bFast tab = mapAllNodes (convertNode bFast tab) tab

simplifyIfs :: InfoTable -> InfoTable
simplifyIfs = mapAllNodes convertNode
simplifyIfs = simplifyIfs' False

0 comments on commit 9e3e07d

Please sign in to comment.