From 88e5c3850f3ec630973a06797c395dd7bfbd7a5d Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Tue, 10 Oct 2023 20:42:12 +0200 Subject: [PATCH 1/6] case permutation (wip) --- .../Optimize/CasePermutation.hs | 65 +++++++++++++++++++ .../Transformation/Optimize/Phase/Main.hs | 2 + 2 files changed, 67 insertions(+) create mode 100644 src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs new file mode 100644 index 0000000000..b234f940eb --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs @@ -0,0 +1,65 @@ +module Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) where + +import Data.HashMap.Strict qualified as HashMap +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +isConstructorTree :: Case -> Node -> Bool +isConstructorTree c node = case run $ runFail $ go mempty node of + Just ctrsMap -> + all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault) + Nothing -> False + where + tags = map (^. caseBranchTag) (c ^. caseBranches) + + checkOne :: HashMap Tag Int -> Tag -> Bool + checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of + Just 1 -> True + Nothing -> True + _ -> {- isImmediate -} False + + checkDefault :: HashMap Tag Int -> Maybe Node -> Bool + checkDefault ctrsMap = \case + Just {} -> + -- or isImmediate + sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1 + where + tags' = HashSet.fromList tags + Nothing -> True + + go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int) + go ctrs = \case + NCtr Constr {..} -> return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs + NCase Case {..} -> foldM go ctrs (map (^. caseBranchBody) _caseBranches) + _ -> fail + +convertNode :: Node -> Node +convertNode = dmap go + where + go :: Node -> Node + go node = case node of + NCase c@Case {..} -> case _caseValue of + NCase c' + | isConstructorTree c _caseValue -> + NCase + c' + { _caseBranches = map permuteBranch (c' ^. caseBranches), + _caseDefault = fmap (mkBody c) (c' ^. caseDefault) + } + where + permuteBranch :: CaseBranch -> CaseBranch + permuteBranch br@CaseBranch {..} = + case shift _caseBranchBindersNum (NCase c {_caseValue = mkBottom'}) of + NCase cs -> + over caseBranchBody (mkBody cs) br + _ -> impossible + + mkBody :: Case -> Node -> Node + mkBody cs n = NCase cs {_caseValue = n} + _ -> + node + _ -> node + +casePermutation :: InfoTable -> InfoTable +casePermutation = mapAllNodes convertNode diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 97915b50b0..ac6777971a 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -4,6 +4,7 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding @@ -19,6 +20,7 @@ optimize' CoreOptions {..} tab = . lambdaFolding . doInlining . caseFolding + . casePermutation . letFolding' (isInlineableLambda _optInliningDepth) . lambdaFolding . specializeArgs From fe72286e79afa14376b5823abee200e4fe4e0e54 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 11 Oct 2023 12:31:47 +0200 Subject: [PATCH 2/6] fix case permutation --- src/Juvix/Compiler/Core/Data/TransformationId.hs | 1 + .../Compiler/Core/Data/TransformationId/Parser.hs | 4 ++++ src/Juvix/Compiler/Core/Transformation.hs | 2 ++ .../Core/Transformation/Optimize/CasePermutation.hs | 10 +++++++--- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 431098a07f..d1a426e4ff 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -32,6 +32,7 @@ data TransformationId | SimplifyIfs | SpecializeArgs | CaseFolding + | CasePermutation | FilterUnreachable | OptPhaseEval | OptPhaseExec diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs index 37803a7add..9da953abac 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs @@ -92,6 +92,7 @@ transformationText = \case SimplifyIfs -> strSimplifyIfs SpecializeArgs -> strSpecializeArgs CaseFolding -> strCaseFolding + CasePermutation -> strCasePermutation FilterUnreachable -> strFilterUnreachable OptPhaseEval -> strOptPhaseEval OptPhaseExec -> strOptPhaseExec @@ -210,6 +211,9 @@ strSpecializeArgs = "specialize-args" strCaseFolding :: Text strCaseFolding = "case-folding" +strCasePermutation :: Text +strCasePermutation = "case-permutation" + strFilterUnreachable :: Text strFilterUnreachable = "filter-unreachable" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index dd223b2004..c615d46bb8 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -31,6 +31,7 @@ import Juvix.Compiler.Core.Transformation.NatToPrimInt import Juvix.Compiler.Core.Transformation.Normalize import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable) import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding @@ -81,6 +82,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts SimplifyIfs -> return . simplifyIfs SpecializeArgs -> return . specializeArgs CaseFolding -> return . caseFolding + CasePermutation -> return . casePermutation FilterUnreachable -> return . filterUnreachable OptPhaseEval -> Phase.Eval.optimize OptPhaseExec -> Phase.Exec.optimize diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs index b234f940eb..a8ef691521 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs @@ -30,9 +30,13 @@ isConstructorTree c node = case run $ runFail $ go mempty node of go :: (Member Fail r) => HashMap Tag Int -> Node -> Sem r (HashMap Tag Int) go ctrs = \case - NCtr Constr {..} -> return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs - NCase Case {..} -> foldM go ctrs (map (^. caseBranchBody) _caseBranches) - _ -> fail + NCtr Constr {..} -> + return $ HashMap.alter (Just . maybe 1 (+ 1)) _constrTag ctrs + NCase Case {..} -> do + ctrs' <- maybe (return ctrs) (go ctrs) _caseDefault + foldM go ctrs' (map (^. caseBranchBody) _caseBranches) + _ -> + fail convertNode :: Node -> Node convertNode = dmap go From 20748c18c1c7fe8b22350d9ccc9cb2a22ebf3ca7 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 11 Oct 2023 15:55:02 +0200 Subject: [PATCH 3/6] immediate --- src/Juvix/Compiler/Core/Extra/Utils.hs | 5 +++++ .../Optimize/CasePermutation.hs | 19 ++++++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 7dbcda3b84..3295904fc2 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -100,6 +100,11 @@ isImmediate tab = \case NVar {} -> True NIdt {} -> True NCst {} -> True + NCtr Constr {..} + | Just ci <- lookupConstructorInfo' tab _constrTag -> + let paramsNum = length (takeWhile (isTypeConstr tab) (typeArgs (ci ^. constructorType))) + in length _constrArgs <= paramsNum + | otherwise -> all (isType tab mempty) _constrArgs node@(NApp {}) -> let (h, args) = unfoldApps' node in case h of diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs index a8ef691521..e961ff7b43 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/CasePermutation.hs @@ -5,25 +5,26 @@ import Data.HashSet qualified as HashSet import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Transformation.Base -isConstructorTree :: Case -> Node -> Bool -isConstructorTree c node = case run $ runFail $ go mempty node of +isConstructorTree :: InfoTable -> Case -> Node -> Bool +isConstructorTree tab c node = case run $ runFail $ go mempty node of Just ctrsMap -> all (checkOne ctrsMap) tags && checkDefault ctrsMap (c ^. caseDefault) Nothing -> False where tags = map (^. caseBranchTag) (c ^. caseBranches) + tagMap = HashMap.fromList (map (\br -> (br ^. caseBranchTag, br ^. caseBranchBody)) (c ^. caseBranches)) checkOne :: HashMap Tag Int -> Tag -> Bool checkOne ctrsMap tag = case HashMap.lookup tag ctrsMap of Just 1 -> True Nothing -> True - _ -> {- isImmediate -} False + _ -> isImmediate tab (fromJust $ HashMap.lookup tag tagMap) checkDefault :: HashMap Tag Int -> Maybe Node -> Bool checkDefault ctrsMap = \case - Just {} -> - -- or isImmediate + Just d -> sum (HashMap.filterWithKey (\k _ -> not (HashSet.member k tags')) ctrsMap) <= 1 + || isImmediate tab d where tags' = HashSet.fromList tags Nothing -> True @@ -38,14 +39,14 @@ isConstructorTree c node = case run $ runFail $ go mempty node of _ -> fail -convertNode :: Node -> Node -convertNode = dmap go +convertNode :: InfoTable -> Node -> Node +convertNode tab = dmap go where go :: Node -> Node go node = case node of NCase c@Case {..} -> case _caseValue of NCase c' - | isConstructorTree c _caseValue -> + | isConstructorTree tab c _caseValue -> NCase c' { _caseBranches = map permuteBranch (c' ^. caseBranches), @@ -66,4 +67,4 @@ convertNode = dmap go _ -> node casePermutation :: InfoTable -> InfoTable -casePermutation = mapAllNodes convertNode +casePermutation tab = mapAllNodes (convertNode tab) tab From f36520d3ac494d4487fb767d390eec2faa586cfd Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 11 Oct 2023 18:05:25 +0200 Subject: [PATCH 4/6] simplify comparisons --- src/Juvix/Compiler/Core/Extra/Utils.hs | 23 ++++++-- .../Transformation/Optimize/Phase/Main.hs | 2 + .../Optimize/SimplifyComparisons.hs | 57 +++++++++++++++++++ 3 files changed, 77 insertions(+), 5 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 3295904fc2..3dacaaaa77 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -130,6 +130,16 @@ isFailNode = \case NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True _ -> False +isTrueConstr :: Node -> Bool +isTrueConstr = \case + NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True + _ -> False + +isFalseConstr :: Node -> Bool +isFalseConstr = \case + NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True + _ -> False + freeVarsSortedMany :: [Node] -> Set Var freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars) @@ -361,26 +371,29 @@ builtinOpArgTypes = \case OpFail -> [mkTypeString'] translateCase :: (Node -> Node -> Node -> a) -> a -> Case -> a -translateCase translateIf dflt Case {..} = case _caseBranches of +translateCase translateIfFun dflt Case {..} = case _caseBranches of [br@CaseBranch {..}] | _caseBranchTag == BuiltinTag TagTrue -> - translateIf _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault) + translateIfFun _caseValue (br ^. caseBranchBody) (fromMaybe branchFailure _caseDefault) [br@CaseBranch {..}] | _caseBranchTag == BuiltinTag TagFalse -> - translateIf _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody) + translateIfFun _caseValue (fromMaybe branchFailure _caseDefault) (br ^. caseBranchBody) [br1, br2] | br1 ^. caseBranchTag == BuiltinTag TagTrue && br2 ^. caseBranchTag == BuiltinTag TagFalse -> - translateIf _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody) + translateIfFun _caseValue (br1 ^. caseBranchBody) (br2 ^. caseBranchBody) | br1 ^. caseBranchTag == BuiltinTag TagFalse && br2 ^. caseBranchTag == BuiltinTag TagTrue -> - translateIf _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody) + translateIfFun _caseValue (br2 ^. caseBranchBody) (br1 ^. caseBranchBody) _ -> dflt where branchFailure :: Node branchFailure = mkBuiltinApp' OpFail [mkConstant' (ConstString "illegal `if` branch")] +translateCaseIf :: (Node -> Node -> Node -> a) -> Case -> a +translateCaseIf f = translateCase f impossible + checkDepth :: InfoTable -> BinderList Binder -> Int -> Node -> Bool checkDepth tab bl 0 node = isType tab bl node checkDepth tab bl d node = case node of diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index ac6777971a..185210dc94 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -9,6 +9,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding +import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs optimize' :: CoreOptions -> InfoTable -> InfoTable @@ -19,6 +20,7 @@ optimize' CoreOptions {..} tab = ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) . lambdaFolding . doInlining + . simplifyComparisons . caseFolding . casePermutation . letFolding' (isInlineableLambda _optInliningDepth) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs new file mode 100644 index 0000000000..383b33ea7d --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs @@ -0,0 +1,57 @@ +module Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +convertNode :: InfoTable -> Node -> Node +convertNode tab = dmap go + where + boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive + + go :: Node -> Node + go node = case node of + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf goIf c + _ -> node + + goIf :: Node -> Node -> Node -> Node + goIf v b1 b2 = case v of + NBlt BuiltinApp {..} + | OpEq <- _builtinAppOp -> + case b2 of + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf (goCmp v b1) c + _ -> + mkIf' boolSym v b1 b2 + _ -> + mkIf' boolSym v b1 b2 + + goCmp :: Node -> Node -> Node -> Node -> Node -> Node + goCmp v b1 v' b1' b2' = case (v, v') of + (NBlt blt, NBlt blt') + | (OpEq, OpIntLt) <- (blt ^. builtinAppOp, blt' ^. builtinAppOp), + blt ^. builtinAppArgs == blt' ^. builtinAppArgs -> + if + | isFalseConstr b1 && isTrueConstr b1' && isFalseConstr b2' -> + v' + | isTrueConstr b1 && isFalseConstr b1' && isFalseConstr b2' -> + v + | isTrueConstr b1 && isTrueConstr b1' && isFalseConstr b2' -> + NBlt blt {_builtinAppOp = OpIntLe} + | b1 == b2' -> + mkIf' boolSym v' b1' b2' + | b1' == b2' -> + mkIf' boolSym v b1 b1' + | b1 == b1' -> + mkIf' boolSym (NBlt blt {_builtinAppOp = OpIntLe}) b1 b2' + | otherwise -> + theIfs + _ -> + theIfs + where + theIfs = mkIf' boolSym v b1 (mkIf' boolSym v' b1' b2') + +simplifyComparisons :: InfoTable -> InfoTable +simplifyComparisons tab = mapAllNodes (convertNode tab) tab From de8298d8c8a63eff1a1bff4c7c8ab52f182777d4 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 12 Oct 2023 11:09:54 +0200 Subject: [PATCH 5/6] simplify ifs --- .../Transformation/Optimize/Phase/Main.hs | 2 ++ .../Optimize/SimplifyComparisons.hs | 30 +++++++++++++++++- .../Transformation/Optimize/SimplifyIfs.hs | 31 ++++++++++++------- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 185210dc94..c88b3e7ff5 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -10,6 +10,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons +import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs optimize' :: CoreOptions -> InfoTable -> InfoTable @@ -20,6 +21,7 @@ optimize' CoreOptions {..} tab = ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) . lambdaFolding . doInlining + . simplifyIfs' (_optOptimizationLevel <= 1) . simplifyComparisons . caseFolding . casePermutation diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs index 383b33ea7d..dffbed3395 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs @@ -17,7 +17,7 @@ convertNode tab = dmap go goIf :: Node -> Node -> Node -> Node goIf v b1 b2 = case v of - NBlt BuiltinApp {..} + NBlt blt@BuiltinApp {..} | OpEq <- _builtinAppOp -> case b2 of NCase c@Case {..} @@ -25,6 +25,20 @@ convertNode tab = dmap go translateCaseIf (goCmp v b1) c _ -> mkIf' boolSym v b1 b2 + | OpIntLt <- _builtinAppOp, + isFalseConstr b1 && isTrueConstr b2 -> + NBlt + blt + { _builtinAppOp = OpIntLe, + _builtinAppArgs = reverse _builtinAppArgs + } + | OpIntLe <- _builtinAppOp, + isFalseConstr b1 && isTrueConstr b2 -> + NBlt + blt + { _builtinAppOp = OpIntLt, + _builtinAppArgs = reverse _builtinAppArgs + } _ -> mkIf' boolSym v b1 b2 @@ -40,6 +54,20 @@ convertNode tab = dmap go v | isTrueConstr b1 && isTrueConstr b1' && isFalseConstr b2' -> NBlt blt {_builtinAppOp = OpIntLe} + | isFalseConstr b1 && isFalseConstr b1' && isTrueConstr b2' -> + NBlt + blt + { _builtinAppOp = OpIntLt, + _builtinAppArgs = reverse (blt ^. builtinAppArgs) + } + | isTrueConstr b1 && isFalseConstr b1' && isTrueConstr b2' -> + NBlt + blt + { _builtinAppOp = OpIntLe, + _builtinAppArgs = reverse (blt ^. builtinAppArgs) + } + | isFalseConstr b1 && isTrueConstr b1' && isTrueConstr b2' -> + mkIf' boolSym v b1 b1' | b1 == b2' -> mkIf' boolSym v' b1' b2' | b1' == b2' -> diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs index 8b9249bef3..bb01f8c333 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyIfs.hs @@ -1,21 +1,30 @@ -module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where +module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs, simplifyIfs') where -import Data.List qualified as List import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Transformation.Base -convertNode :: Node -> Node -convertNode = umap go +convertNode :: Bool -> InfoTable -> Node -> Node +convertNode bFast tab = umap go where + boolSym = lookupConstructorInfo tab (BuiltinTag TagTrue) ^. constructorInductive + go :: Node -> Node go node = case node of - NCase Case {..} - | isCaseBoolean _caseBranches - && all (== List.head bodies) (List.tail bodies) -> - List.head bodies - where - bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault + NCase c@Case {..} + | isCaseBoolean _caseBranches -> + translateCaseIf goIf c _ -> node + goIf :: Node -> Node -> Node -> Node + goIf v b1 b2 + | isTrueConstr b1 && isFalseConstr b2 = v + | bFast && isTrueConstr b1 && isTrueConstr b2 = b1 + | bFast && isFalseConstr b1 && isFalseConstr b2 = b1 + | not bFast && b1 == b2 = b1 + | otherwise = mkIf' boolSym v b1 b2 + +simplifyIfs' :: Bool -> InfoTable -> InfoTable +simplifyIfs' bFast tab = mapAllNodes (convertNode bFast tab) tab + simplifyIfs :: InfoTable -> InfoTable -simplifyIfs = mapAllNodes convertNode +simplifyIfs = simplifyIfs' False From 8977c3fa7269bd77b82ea5a22890754b1befa8e4 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 12 Oct 2023 11:22:31 +0200 Subject: [PATCH 6/6] simplify comparisons --- .../Transformation/Optimize/SimplifyComparisons.hs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs index dffbed3395..78ca2d25e8 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SimplifyComparisons.hs @@ -23,6 +23,16 @@ convertNode tab = dmap go NCase c@Case {..} | isCaseBoolean _caseBranches -> translateCaseIf (goCmp v b1) c + NBlt blt' + | OpIntLt <- blt' ^. builtinAppOp, + blt ^. builtinAppArgs == blt' ^. builtinAppArgs -> + if + | isFalseConstr b1 -> + b2 + | isTrueConstr b1 -> + NBlt blt {_builtinAppOp = OpIntLe} + | otherwise -> + mkIf' boolSym v b1 b2 _ -> mkIf' boolSym v b1 b2 | OpIntLt <- _builtinAppOp,