diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 431098a07f..d1a426e4ff 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -32,6 +32,7 @@ data TransformationId | SimplifyIfs | SpecializeArgs | CaseFolding + | CasePermutation | FilterUnreachable | OptPhaseEval | OptPhaseExec diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs index 37803a7add..9da953abac 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs @@ -92,6 +92,7 @@ transformationText = \case SimplifyIfs -> strSimplifyIfs SpecializeArgs -> strSpecializeArgs CaseFolding -> strCaseFolding + CasePermutation -> strCasePermutation FilterUnreachable -> strFilterUnreachable OptPhaseEval -> strOptPhaseEval OptPhaseExec -> strOptPhaseExec @@ -210,6 +211,9 @@ strSpecializeArgs = "specialize-args" strCaseFolding :: Text strCaseFolding = "case-folding" +strCasePermutation :: Text +strCasePermutation = "case-permutation" + strFilterUnreachable :: Text strFilterUnreachable = "filter-unreachable" diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 7dbcda3b84..3dacaaaa77 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -100,6 +100,11 @@ isImmediate tab = \case NVar {} -> True NIdt {} -> True NCst {} -> True + NCtr Constr {..} + | Just ci <- lookupConstructorInfo' tab _constrTag -> + let paramsNum = length (takeWhile (isTypeConstr tab) (typeArgs (ci ^. constructorType))) + in length _constrArgs <= paramsNum + | otherwise -> all (isType tab mempty) _constrArgs node@(NApp {}) -> let (h, args) = unfoldApps' node in case h of @@ -125,6 +130,16 @@ isFailNode = \case NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True _ -> False +isTrueConstr :: Node -> Bool +isTrueConstr = \case + NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True + _ -> False + +isFalseConstr :: Node -> Bool +isFalseConstr = \case + NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True + _ -> False + freeVarsSortedMany :: [Node] -> Set Var freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars) @@ -356,26 +371,29 @@ builtinOpArgTypes = \case OpFail -> [mkTypeString'] translateCase :: (Node -> Node -> Node -> a) -> a -> Case -> a -translateCase translateIf dflt Case {..} = case _caseBranches of +translateCase translateIfFun dflt Case {..} = case _caseBranches of [br@CaseBranch {..}] | _caseBranchTag == BuiltinTag TagTrue -> - translateIf _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault) + translateIfFun _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault) [br@CaseBranch {..}] | _caseBranchTag == BuiltinTag TagFalse -> - translateIf _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody) + translateIfFun _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody) [br1, br2] | br1 ^. caseBranchTag == BuiltinTag TagTrue && br2 ^. caseBranchTag == BuiltinTag TagFalse -> - translateIf _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody) + translateIfFun _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody) | br1 ^. caseBranchTag == BuiltinTag TagFalse && br2 ^. caseBranchTag == BuiltinTag TagTrue -> - translateIf _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody) + translateIfFun _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody) _ -> dflt where branchFailure :: Node branchFailure = mkBuiltinApp' OpFail [mkConstant' (ConstString "illegal `if` branch")] +translateCaseIf :: (Node -> Node -> Node -> a) -> Case -> a +translateCaseIf f = translateCase f impossible + checkDepth :: InfoTable -> BinderList Binder -> Int -> Node -> Bool checkDepth tab bl 0 node = isType tab bl node checkDepth tab bl d node = case node of diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index dd223b2004..c615d46bb8 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -31,6 +31,7 @@ import Juvix.Compiler.Core.Transformation.NatToPrimInt import Juvix.Compiler.Core.Transformation.Normalize import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable) import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding @@ -81,6 +82,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts SimplifyIfs -> return . simplifyIfs SpecializeArgs -> return . specializeArgs CaseFolding -> return . caseFolding + CasePermutation -> return . casePermutation FilterUnreachable -> return . filterUnreachable OptPhaseEval -> Phase.Eval.optimize OptPhaseExec -> Phase.Exec.optimize diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs new file mode 100644 index 0000000000..e961ff7b43 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs @@ -0,0 +1,70 @@ +module Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) where + +import Data.HashMap.Strict qualified as HashMap +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +isConstructorTree :: InfoTable -> Case -> Node -> Bool +isConstructorTree tab c node = case run $ runFail $ go mempty node of + Just ctrsMap -> + all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault) + Nothing -> False + where + tags = map (^. caseBranchTag) (c ^. caseBranches) + tagMap = HashMap.fromList (map (\br -> (br ^. caseBranchTag, br ^. caseBranchBody)) (c ^. caseBranches)) + + checkOne :: HashMap Tag Int -> Tag -> Bool + checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of + Just 1 -> True + Nothing -> True + _ -> isImmediate tab (fromJust $ HashMap.lookup tag tagMap) + + checkDefault :: HashMap Tag Int -> Maybe Node -> Bool + checkDefault ctrsMap = \case + Just d -> + sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1 + || isImmediate tab d + where + tags' = HashSet.fromList tags + Nothing -> True + + go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int) + go ctrs = \case + NCtr Constr {..} -> + return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs + NCase Case {..} -> do + ctrs' <- maybe (return ctrs) (go ctrs) _caseDefault + foldM go ctrs' (map (^. caseBranchBody) _caseBranches) + _ -> + fail + +convertNode :: InfoTable -> Node -> Node +convertNode tab = dmap go + where + go :: Node -> Node + go node = case node of + NCase c@Case {..} -> case _caseValue of + NCase c' + | isConstructorTree tab c _caseValue -> + NCase + c' + { _caseBranches = map permuteBranch (c' ^. caseBranches), + _caseDefault = fmap (mkBody c) (c' ^. caseDefault) + } + where + permuteBranch :: CaseBranch -> CaseBranch + permuteBranch br@CaseBranch {..} = + case shift _caseBranchBindersNum (NCase c {_caseValue = mkBottom'}) of + NCase cs -> + over caseBranchBody (mkBody cs) br + _ -> impossible + + mkBody :: Case -> Node -> Node + mkBody cs n = NCase cs {_caseValue = n} + _ -> + node + _ -> node + +casePermutation :: InfoTable -> InfoTable +casePermutation tab = mapAllNodes (convertNode tab) tab diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 97915b50b0..c88b3e7ff5 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -4,10 +4,13 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding +import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons +import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs optimize' :: CoreOptions -> InfoTable -> InfoTable @@ -18,7 +21,10 @@ optimize' CoreOptions {..} tab = ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) . lambdaFolding . doInlining + . simplifyIfs' (_optOptimizationLevel <= 1) + . simplifyComparisons . caseFolding + . casePermutation . letFolding' (isInlineableLambda _optInliningDepth) . lambdaFolding . specializeArgs diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs new file mode 100644 index 0000000000..78ca2d25e8 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs @@ -0,0 +1,95 @@ +module Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +convertNode :: InfoTable -> Node -> Node +convertNode tab = dmap go + where + boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive + + go :: Node -> Node + go node = case node of + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf goIf c + _ -> node + + goIf :: Node -> Node -> Node -> Node + goIf v b1 b2 = case v of + NBlt blt@BuiltinApp {..} + | OpEq <- _builtinAppOp -> + case b2 of + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf (goCmp v b1) c + NBlt blt' + | OpIntLt <- blt' ^. builtinAppOp, + blt ^. builtinAppArgs == blt' ^. builtinAppArgs -> + if + | isFalseConstr b1 -> + b2 + | isTrueConstr b1 -> + NBlt blt {_builtinAppOp = OpIntLe} + | otherwise -> + mkIf' boolSym v b1 b2 + _ -> + mkIf' boolSym v b1 b2 + | OpIntLt <- _builtinAppOp, + isFalseConstr b1 && isTrueConstr b2 -> + NBlt + blt + { _builtinAppOp = OpIntLe, + _builtinAppArgs = reverse _builtinAppArgs + } + | OpIntLe <- _builtinAppOp, + isFalseConstr b1 && isTrueConstr b2 -> + NBlt + blt + { _builtinAppOp = OpIntLt, + _builtinAppArgs = reverse _builtinAppArgs + } + _ -> + mkIf' boolSym v b1 b2 + + goCmp :: Node -> Node -> Node -> Node -> Node -> Node + goCmp v b1 v' b1' b2' = case (v, v') of + (NBlt blt, NBlt blt') + | (OpEq, OpIntLt) <- (blt ^. builtinAppOp, blt' ^. builtinAppOp), + blt ^. builtinAppArgs == blt' ^. builtinAppArgs -> + if + | isFalseConstr b1 && isTrueConstr b1' && isFalseConstr b2' -> + v' + | isTrueConstr b1 && isFalseConstr b1' && isFalseConstr b2' -> + v + | isTrueConstr b1 && isTrueConstr b1' && isFalseConstr b2' -> + NBlt blt {_builtinAppOp = OpIntLe} + | isFalseConstr b1 && isFalseConstr b1' && isTrueConstr b2' -> + NBlt + blt + { _builtinAppOp = OpIntLt, + _builtinAppArgs = reverse (blt ^. builtinAppArgs) + } + | isTrueConstr b1 && isFalseConstr b1' && isTrueConstr b2' -> + NBlt + blt + { _builtinAppOp = OpIntLe, + _builtinAppArgs = reverse (blt ^. builtinAppArgs) + } + | isFalseConstr b1 && isTrueConstr b1' && isTrueConstr b2' -> + mkIf' boolSym v b1 b1' + | b1 == b2' -> + mkIf' boolSym v' b1' b2' + | b1' == b2' -> + mkIf' boolSym v b1 b1' + | b1 == b1' -> + mkIf' boolSym (NBlt blt {_builtinAppOp = OpIntLe}) b1 b2' + | otherwise -> + theIfs + _ -> + theIfs + where + theIfs = mkIf' boolSym v b1 (mkIf' boolSym v' b1' b2') + +simplifyComparisons :: InfoTable -> InfoTable +simplifyComparisons tab = mapAllNodes (convertNode tab) tab diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs index 8b9249bef3..bb01f8c333 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs @@ -1,21 +1,30 @@ -module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where +module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs, simplifyIfs') where -import Data.List qualified as List import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Transformation.Base -convertNode :: Node -> Node -convertNode = umap go +convertNode :: Bool -> InfoTable -> Node -> Node +convertNode bFast tab = umap go where + boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive + go :: Node -> Node go node = case node of - NCase Case {..} - | isCaseBoolean _caseBranches - && all (== List.head bodies) (List.tail bodies) -> - List.head bodies - where - bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf goIf c _ -> node + goIf :: Node -> Node -> Node -> Node + goIf v b1 b2 + | isTrueConstr b1 && isFalseConstr b2 = v + | bFast && isTrueConstr b1 && isTrueConstr b2 = b1 + | bFast && isFalseConstr b1 && isFalseConstr b2 = b1 + | not bFast && b1 == b2 = b1 + | otherwise = mkIf' boolSym v b1 b2 + +simplifyIfs' :: Bool -> InfoTable -> InfoTable +simplifyIfs' bFast tab = mapAllNodes (convertNode bFast tab) tab + simplifyIfs :: InfoTable -> InfoTable -simplifyIfs = mapAllNodes convertNode +simplifyIfs = simplifyIfs' False