Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comparison optimization #2443

Merged
merged 6 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ data TransformationId
| SimplifyIfs
| SpecializeArgs
| CaseFolding
| CasePermutation
| FilterUnreachable
| OptPhaseEval
| OptPhaseExec
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ transformationText = \case
SimplifyIfs -> strSimplifyIfs
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
FilterUnreachable -> strFilterUnreachable
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec
Expand Down Expand Up @@ -210,6 +211,9 @@ strSpecializeArgs = "specialize-args"
strCaseFolding :: Text
strCaseFolding = "case-folding"

strCasePermutation :: Text
strCasePermutation = "case-permutation"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

Expand Down
28 changes: 23 additions & 5 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ isImmediate tab = \case
NVar {} -> True
NIdt {} -> True
NCst {} -> True
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' tab _constrTag ->
let paramsNum = length (takeWhile (isTypeConstr tab) (typeArgs (ci ^. constructorType)))
in length _constrArgs <= paramsNum
| otherwise -> all (isType tab mempty) _constrArgs
node@(NApp {}) ->
let (h, args) = unfoldApps' node
in case h of
Expand All @@ -125,6 +130,16 @@ isFailNode = \case
NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True
_ -> False

isTrueConstr :: Node -> Bool
isTrueConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True
_ -> False

isFalseConstr :: Node -> Bool
isFalseConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
_ -> False

freeVarsSortedMany :: [Node] -> Set Var
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)

Expand Down Expand Up @@ -356,26 +371,29 @@ builtinOpArgTypes = \case
OpFail -> [mkTypeString']

translateCase :: (Node -> Node -> Node -> a) -> a -> Case -> a
translateCase translateIf dflt Case {..} = case _caseBranches of
translateCase translateIfFun dflt Case {..} = case _caseBranches of
[br@CaseBranch {..}]
| _caseBranchTag == BuiltinTag TagTrue ->
translateIf _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault)
translateIfFun _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault)
[br@CaseBranch {..}]
| _caseBranchTag == BuiltinTag TagFalse ->
translateIf _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody)
translateIfFun _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody)
[br1, br2]
| br1 ^. caseBranchTag == BuiltinTag TagTrue
&& br2 ^. caseBranchTag == BuiltinTag TagFalse ->
translateIf _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody)
translateIfFun _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody)
| br1 ^. caseBranchTag == BuiltinTag TagFalse
&& br2 ^. caseBranchTag == BuiltinTag TagTrue ->
translateIf _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody)
translateIfFun _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody)
_ ->
dflt
where
branchFailure :: Node
branchFailure = mkBuiltinApp' OpFail [mkConstant' (ConstString "illegal `if` branch")]

translateCaseIf :: (Node -> Node -> Node -> a) -> Case -> a
translateCaseIf f = translateCase f impossible

checkDepth :: InfoTable -> BinderList Binder -> Int -> Node -> Bool
checkDepth tab bl 0 node = isType tab bl node
checkDepth tab bl d node = case node of
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Juvix.Compiler.Core.Transformation.NatToPrimInt
import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
Expand Down Expand Up @@ -81,6 +82,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SimplifyIfs -> return . simplifyIfs
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation
FilterUnreachable -> return . filterUnreachable
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize
Expand Down
70 changes: 70 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
module Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

isConstructorTree :: InfoTable -> Case -> Node -> Bool
isConstructorTree tab c node = case run $ runFail $ go mempty node of
Just ctrsMap ->
all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault)
Nothing -> False
where
tags = map (^. caseBranchTag) (c ^. caseBranches)
tagMap = HashMap.fromList (map (\br -> (br ^. caseBranchTag, br ^. caseBranchBody)) (c ^. caseBranches))

checkOne :: HashMap Tag Int -> Tag -> Bool
checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of
Just 1 -> True
Nothing -> True
_ -> isImmediate tab (fromJust $ HashMap.lookup tag tagMap)

checkDefault :: HashMap Tag Int -> Maybe Node -> Bool
checkDefault ctrsMap = \case
Just d ->
sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1
|| isImmediate tab d
where
tags' = HashSet.fromList tags
Nothing -> True

go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int)
go ctrs = \case
NCtr Constr {..} ->
return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs
NCase Case {..} -> do
ctrs' <- maybe (return ctrs) (go ctrs) _caseDefault
foldM go ctrs' (map (^. caseBranchBody) _caseBranches)
_ ->
fail

convertNode :: InfoTable -> Node -> Node
convertNode tab = dmap go
where
go :: Node -> Node
go node = case node of
NCase c@Case {..} -> case _caseValue of
NCase c'
| isConstructorTree tab c _caseValue ->
NCase
c'
{ _caseBranches = map permuteBranch (c' ^. caseBranches),
_caseDefault = fmap (mkBody c) (c' ^. caseDefault)
}
where
permuteBranch :: CaseBranch -> CaseBranch
permuteBranch br@CaseBranch {..} =
case shift _caseBranchBindersNum (NCase c {_caseValue = mkBottom'}) of
NCase cs ->
over caseBranchBody (mkBody cs) br
_ -> impossible

mkBody :: Case -> Node -> Node
mkBody cs n = NCase cs {_caseValue = n}
_ ->
node
_ -> node

casePermutation :: InfoTable -> InfoTable
casePermutation tab = mapAllNodes (convertNode tab) tab
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs

optimize' :: CoreOptions -> InfoTable -> InfoTable
Expand All @@ -18,7 +21,10 @@ optimize' CoreOptions {..} tab =
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. doInlining
. simplifyIfs' (_optOptimizationLevel <= 1)
. simplifyComparisons
. caseFolding
. casePermutation
. letFolding' (isInlineableLambda _optInliningDepth)
. lambdaFolding
. specializeArgs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) where

import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

convertNode :: InfoTable -> Node -> Node
convertNode tab = dmap go
where
boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive

go :: Node -> Node
go node = case node of
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf goIf c
_ -> node

goIf :: Node -> Node -> Node -> Node
goIf v b1 b2 = case v of
NBlt blt@BuiltinApp {..}
| OpEq <- _builtinAppOp ->
case b2 of
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf (goCmp v b1) c
NBlt blt'
| OpIntLt <- blt' ^. builtinAppOp,
blt ^. builtinAppArgs == blt' ^. builtinAppArgs ->
if
| isFalseConstr b1 ->
b2
| isTrueConstr b1 ->
NBlt blt {_builtinAppOp = OpIntLe}
| otherwise ->
mkIf' boolSym v b1 b2
_ ->
mkIf' boolSym v b1 b2
| OpIntLt <- _builtinAppOp,
isFalseConstr b1 && isTrueConstr b2 ->
NBlt
blt
{ _builtinAppOp = OpIntLe,
_builtinAppArgs = reverse _builtinAppArgs
}
| OpIntLe <- _builtinAppOp,
isFalseConstr b1 && isTrueConstr b2 ->
NBlt
blt
{ _builtinAppOp = OpIntLt,
_builtinAppArgs = reverse _builtinAppArgs
}
_ ->
mkIf' boolSym v b1 b2

goCmp :: Node -> Node -> Node -> Node -> Node -> Node
goCmp v b1 v' b1' b2' = case (v, v') of
(NBlt blt, NBlt blt')
| (OpEq, OpIntLt) <- (blt ^. builtinAppOp, blt' ^. builtinAppOp),
blt ^. builtinAppArgs == blt' ^. builtinAppArgs ->
if
| isFalseConstr b1 && isTrueConstr b1' && isFalseConstr b2' ->
v'
| isTrueConstr b1 && isFalseConstr b1' && isFalseConstr b2' ->
v
| isTrueConstr b1 && isTrueConstr b1' && isFalseConstr b2' ->
NBlt blt {_builtinAppOp = OpIntLe}
| isFalseConstr b1 && isFalseConstr b1' && isTrueConstr b2' ->
NBlt
blt
{ _builtinAppOp = OpIntLt,
_builtinAppArgs = reverse (blt ^. builtinAppArgs)
}
| isTrueConstr b1 && isFalseConstr b1' && isTrueConstr b2' ->
NBlt
blt
{ _builtinAppOp = OpIntLe,
_builtinAppArgs = reverse (blt ^. builtinAppArgs)
}
| isFalseConstr b1 && isTrueConstr b1' && isTrueConstr b2' ->
mkIf' boolSym v b1 b1'
| b1 == b2' ->
mkIf' boolSym v' b1' b2'
| b1' == b2' ->
mkIf' boolSym v b1 b1'
| b1 == b1' ->
mkIf' boolSym (NBlt blt {_builtinAppOp = OpIntLe}) b1 b2'
| otherwise ->
theIfs
_ ->
theIfs
where
theIfs = mkIf' boolSym v b1 (mkIf' boolSym v' b1' b2')

simplifyComparisons :: InfoTable -> InfoTable
simplifyComparisons tab = mapAllNodes (convertNode tab) tab
31 changes: 20 additions & 11 deletions src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs, simplifyIfs') where

import Data.List qualified as List
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

convertNode :: Node -> Node
convertNode = umap go
convertNode :: Bool -> InfoTable -> Node -> Node
convertNode bFast tab = umap go
where
boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive

go :: Node -> Node
go node = case node of
NCase Case {..}
| isCaseBoolean _caseBranches
&& all (== List.head bodies) (List.tail bodies) ->
List.head bodies
where
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
NCase c@Case {..}
| isCaseBoolean _caseBranches ->
translateCaseIf goIf c
_ -> node

goIf :: Node -> Node -> Node -> Node
goIf v b1 b2
| isTrueConstr b1 && isFalseConstr b2 = v
| bFast && isTrueConstr b1 && isTrueConstr b2 = b1
| bFast && isFalseConstr b1 && isFalseConstr b2 = b1
| not bFast && b1 == b2 = b1
| otherwise = mkIf' boolSym v b1 b2

simplifyIfs' :: Bool -> InfoTable -> InfoTable
simplifyIfs' bFast tab = mapAllNodes (convertNode bFast tab) tab

simplifyIfs :: InfoTable -> InfoTable
simplifyIfs = mapAllNodes convertNode
simplifyIfs = simplifyIfs' False