Skip to content

Commit

Permalink
Bugfix: compiler looping with the specialize pragma (#2899)
Browse files Browse the repository at this point in the history
* Closes #2884
  • Loading branch information
lukaszcz authored Jul 15, 2024
1 parent 7d2a59c commit 5a76e5d
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 36 deletions.
2 changes: 1 addition & 1 deletion juvix-stdlib
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ nonRecursiveIdents' tab =
HashSet.difference
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
(recursiveIdentsClosure tab)

nonRecursiveIdents :: Module -> HashSet Symbol
nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data TransformationId
| SpecializeArgs
| CaseFolding
| CasePermutation
| ConstantFolding
| FilterUnreachable
| OptPhaseEval
| OptPhaseExec
Expand Down Expand Up @@ -113,6 +114,7 @@ instance TransformationId' TransformationId where
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
ConstantFolding -> strConstantFolding
FilterUnreachable -> strFilterUnreachable
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ strCaseFolding = "case-folding"
strCasePermutation :: Text
strCasePermutation = "case-permutation"

strConstantFolding :: Text
strConstantFolding = "constant-folding"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
import Juvix.Compiler.Core.Transformation.Optimize.ConstantFolding
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
Expand Down Expand Up @@ -96,6 +97,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation
ConstantFolding -> constantFolding
FilterUnreachable -> return . filterUnreachable
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize
Expand Down
12 changes: 6 additions & 6 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ isInlineableLambda inlineDepth md bl node = case node of
False

convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth recSyms md = dmapL go
convertNode inlineDepth nonRecSyms md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
Expand All @@ -37,7 +37,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineNever ->
node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isInlineableLambda inlineDepth md bl def
&& length args >= argsNum ->
mkApps def args
Expand All @@ -57,7 +57,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineAlways -> def
Just InlineNever -> node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
def
| otherwise ->
Expand All @@ -76,7 +76,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Nothing
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
&& checkDepth md bl inlineDepth def ->
NCase cs {_caseValue = mkApps def args}
Expand All @@ -92,9 +92,9 @@ convertNode inlineDepth recSyms md = dmapL go
node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth recSyms md = mapT (const (convertNode inliningDepth recSyms md)) md
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md

inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (recursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) md
11 changes: 4 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ optimize' opts@CoreOptions {..} md =
tab :: InfoTable
tab = computeCombinedInfoTable md

recs :: HashSet Symbol
recs = recursiveIdents' tab

nonRecs :: HashSet Symbol
nonRecs = nonRecursiveIdents' tab

Expand All @@ -48,12 +45,12 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecs

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth recs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' md'
where
recs' =
nonRecs' =
if
| _optOptimizationLevel > 1 -> recursiveIdents md'
| otherwise -> recs
| _optOptimizationLevel > 1 -> nonRecursiveIdents md'
| otherwise -> nonRecs

doSimplification :: Int -> Module -> Module
doSimplification n =
Expand Down
35 changes: 13 additions & 22 deletions tests/Compilation/positive/test056.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,12 @@ mymap {A B} (f : A -> B) : List A -> List B
| (x :: xs) := f x :: mymap f xs;

{-# specialize: [2, 5], inline: false #-}
myf
: {A B : Type}
-> A
-> (A -> A -> B)
-> A
-> B
-> Bool
-> B
myf : {A B : Type} -> A -> (A -> A -> B) -> A -> B -> Bool -> B
| a0 f a b true := f a0 a
| a0 f a b false := b;

{-# inline: false #-}
myf'
: {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
myf' : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
| a0 f a b := myf a0 (f a0) a b true;

sum : List Nat -> Nat
Expand All @@ -40,29 +32,28 @@ funa : {A : Type} -> (A -> A) -> A -> A
{-# specialize: true #-}
type Additive A := mkAdditive {add : A -> A -> A};

type Multiplicative A :=
mkMultiplicative {mul : A -> A -> A};
type Multiplicative A := mkMultiplicative {mul : A -> A -> A};

addNat : Additive Nat := mkAdditive (+);

{-# specialize: true #-}
mulNat : Multiplicative Nat := mkMultiplicative (*);

{-# inline: false #-}
fadd {A} (a : Additive A) (x y : A) : A :=
Additive.add a x y;
fadd {A} (a : Additive A) (x y : A) : A := Additive.add a x y;

{-# inline: false #-}
fmul {A} (m : Multiplicative A) (x y : A) : A :=
Multiplicative.mul m x y;
fmul {A} (m : Multiplicative A) (x y : A) : A := Multiplicative.mul m x y;

{-# specialize: [1] #-}
myfilter {A} (f : A → Bool) : List A → List A
| nil := nil
| (h :: hs) := ite (f h) (h :: myfilter f hs) (myfilter f hs);

main : Nat :=
sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum
(flatten
(mymap
(mymap λ {x := x + 2})
((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
sum (myfilter (const false) [])
+ sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum (flatten (mymap (mymap λ {x := x + 2}) ((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
+ myf 3 (*) 2 5 true
+ myf 1 (+) 2 0 false
+ myf' 7 (const (+)) 2 0
Expand Down

0 comments on commit 5a76e5d

Please sign in to comment.