Skip to content

Commit

Permalink
Numeric, ordering and equality traits (#2433)
Browse files Browse the repository at this point in the history
* Adapts to anoma/juvix-stdlib#86
* Adds a pass in `toEvalTransformations` to automatically inline all
record projection functions, regardless of the optimization level. This
is necessary to ensure that arithmetic operations and comparisons on
`Nat` or `Int` are always represented directly with the corresponding
built-in Core functions. This is generally highly desirable and required
for the Geb target.
* Adds the `inline: always` pragma which indicates that a function
should always be inlined during the mandatory inlining phase, regardless
of optimization level.
  • Loading branch information
lukaszcz authored Oct 9, 2023
1 parent 0e4c27b commit 60a191b
Show file tree
Hide file tree
Showing 79 changed files with 170 additions and 141 deletions.
17 changes: 7 additions & 10 deletions examples/demo/Demo.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ module Demo;

-- standard library prelude
import Stdlib.Prelude open;
-- for comparisons on natural numbers
import Stdlib.Data.Nat.Ord open;
-- for Ordering

even : Nat → Bool
| zero := true
Expand Down Expand Up @@ -35,13 +32,13 @@ preorder : {A : Type} → Tree A → List A
| (node x l r) := x :: nil ++ preorder l ++ preorder r;

terminating
sort : {A : Type} → (A → A → Ordering) → List A → List A
| _ nil := nil
| _ xs@(_ :: nil) := xs
| {A} cmp xs :=
sort {A} {{Ord A}} : List A → List A
| nil := nil
| xs@(_ :: nil) := xs
| xs :=
uncurry
(merge {{mkOrd cmp}})
(both (sort cmp) (splitAt (div (length xs) 2) xs));
merge
(both sort (splitAt (div (length xs) 2) xs));

printNatListLn : List Nat → IO
| nil := printStringLn "nil"
Expand All @@ -51,6 +48,6 @@ printNatListLn : List Nat → IO
main : IO :=
printStringLn "Hello!"
>> printNatListLn (preorder (mirror tree))
>> printNatListLn (sort compare (preorder (mirror tree)))
>> printNatListLn (sort (preorder (mirror tree)))
>> printNatLn (log2 3)
>> printNatLn (log2 130);
1 change: 0 additions & 1 deletion examples/midsquare/MidSquareHash.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
module MidSquareHash;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

--- `pow N` is 2 ^ N
pow : Nat -> Nat
Expand Down
1 change: 0 additions & 1 deletion examples/midsquare/MidSquareHashUnrolled.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
module MidSquareHashUnrolled;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

--- `powN` is 2 ^ N
pow0 : Nat := 1;
Expand Down
2 changes: 0 additions & 2 deletions examples/milestone/Bank/Bank.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ module Bank;
import Stdlib.Prelude open;
import Stdlib.Debug.Fail open;

import Stdlib.Data.Nat.Ord open;

import Stdlib.Data.Nat as Nat;

Address : Type := Nat;
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/Collatz/Collatz.juvix
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Collatz;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

collatzNext (n : Nat) : Nat :=
if (mod n 2 == 0) (div n 2) (3 * n + 1);
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/CLI/TicTacToe.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
--- The module Logic.Game contains the game logic.
module CLI.TicTacToe;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;
import Logic.Game open;

Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/Logic/Extra.juvix
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
--- Some generic helper definitions.
module Logic.Extra;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;

--- Concatenates a list of strings
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/Logic/Game.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
--- diagonal row is the winner. It is a solved game, with a forced draw assuming best play from both players.
module Logic.Game;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;
import Logic.Extra open public;
import Logic.Board open public;
Expand Down
7 changes: 1 addition & 6 deletions examples/milestone/TicTacToe/Logic/Square.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Logic.Square;

import Stdlib.Prelude open;
import Logic.Symbol open;
import Stdlib.Data.Nat.Ord open;
import Logic.Extra open;

--- A square is each of the holes in a board
Expand All @@ -24,9 +23,5 @@ showSquare : Square → String
| (occupied s) := " " ++str showSymbol s ++str " ";

replace (player : Symbol) (k : Nat) : Square → Square
| (empty n) :=
if
(n Stdlib.Data.Nat.Ord.== k)
(occupied player)
(empty n)
| (empty n) := if (n == k) (occupied player) (empty n)
| s := s;
2 changes: 0 additions & 2 deletions examples/milestone/Tutorial/Tutorial.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@ module Tutorial;

-- import the standard library prelude and bring it into scope
import Stdlib.Prelude open;
-- bring comparison operators on Nat into scope
import Stdlib.Data.Nat.Ord open;

main : IO := printStringLn "Hello world!";
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Backend/Geb/Translation/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fromCore :: Core.InfoTable -> (Morphism, Object)
fromCore tab = case tab ^. Core.infoMain of
Just sym ->
let node = Core.lookupIdentifierNode tab sym
syms = reverse $ filter (/= sym) $ Core.createIdentDependencyInfo tab ^. Core.depInfoTopSort
syms = reverse $ filter (/= sym) $ Core.createCallGraph tab ^. Core.depInfoTopSort
idents = map (Core.lookupIdentifierInfo tab) syms
morph = run . runReader emptyEnv $ goIdents node idents
obj = convertType $ Info.getNodeType node
Expand Down
30 changes: 27 additions & 3 deletions src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Data.IdentDependencyInfo where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Extra.Utils
Expand All @@ -9,8 +10,8 @@ import Juvix.Compiler.Core.Language
type IdentDependencyInfo = DependencyInfo Symbol

-- | Compute the call graph
createIdentDependencyInfo :: InfoTable -> IdentDependencyInfo
createIdentDependencyInfo tab = createDependencyInfo graph startVertices
createCallGraph :: InfoTable -> IdentDependencyInfo
createCallGraph tab = createDependencyInfo graph startVertices
where
graph :: HashMap Symbol (HashSet Symbol)
graph =
Expand All @@ -27,5 +28,28 @@ createIdentDependencyInfo tab = createDependencyInfo graph startVertices
syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMain)

createSymbolDependencyInfo :: InfoTable -> IdentDependencyInfo
createSymbolDependencyInfo tab = createDependencyInfo graph startVertices
where
graph :: HashMap Symbol (HashSet Symbol)
graph =
fmap
( \IdentifierInfo {..} ->
getSymbols tab (lookupIdentifierNode tab _identifierSymbol)
)
(tab ^. infoIdentifiers)
<> foldr
( \ConstructorInfo {..} ->
HashMap.insert _constructorInductive (getSymbols tab _constructorType)
)
mempty
(tab ^. infoConstructors)

startVertices :: HashSet Symbol
startVertices = HashSet.fromList syms

syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMain)

recursiveIdents :: InfoTable -> HashSet Symbol
recursiveIdents = nodesOnCycles . createIdentDependencyInfo
recursiveIdents = nodesOnCycles . createCallGraph
35 changes: 24 additions & 11 deletions src/Juvix/Compiler/Core/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,33 @@ filterByFile f t =
matchesLocation :: Maybe Location -> Bool
matchesLocation l = l ^? _Just . intervalFile == Just f

-- | Prunes the orphaned entries of identMap and indentContext, i.e., ones that
-- have no corresponding entries in infoIdentifiers or infoInductives
-- | Prunes the orphaned entries of identMap, indentContext and
-- infoConstructors, i.e., ones that have no corresponding entries in
-- infoIdentifiers or infoInductives
pruneInfoTable :: InfoTable -> InfoTable
pruneInfoTable tab =
over
identMap
( HashMap.filter
( \case
IdentFun s -> HashMap.member s (tab ^. infoIdentifiers)
IdentInd s -> HashMap.member s (tab ^. infoInductives)
IdentConstr tag -> HashMap.member tag (tab ^. infoConstructors)
)
)
pruneIdentMap
$ over
infoConstructors
( HashMap.filter
( \ConstructorInfo {..} ->
HashMap.member _constructorInductive (tab ^. infoInductives)
)
)
$ over
identContext
(HashMap.filterWithKey (\s _ -> HashMap.member s (tab ^. infoIdentifiers)))
tab
where
pruneIdentMap :: InfoTable -> InfoTable
pruneIdentMap tab' =
over
identMap
( HashMap.filter
( \case
IdentFun s -> HashMap.member s (tab' ^. infoIdentifiers)
IdentInd s -> HashMap.member s (tab' ^. infoInductives)
IdentConstr tag -> HashMap.member tag (tab' ^. infoConstructors)
)
)
tab'
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ data TransformationId
| LambdaFolding
| LetHoisting
| Inlining
| MandatoryInlining
| FoldTypeSynonyms
| CaseCallLifting
| SimplifyIfs
Expand Down Expand Up @@ -75,7 +76,7 @@ toNormalizeTransformations :: [TransformationId]
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]

toVampIRTransformations :: [TransformationId]
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]
toVampIRTransformations = toEvalTransformations ++ [FilterUnreachable, CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]

toStrippedTransformations :: [TransformationId]
toStrippedTransformations =
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ transformationText = \case
LambdaFolding -> strLambdaFolding
LetHoisting -> strLetHoisting
Inlining -> strInlining
MandatoryInlining -> strMandatoryInlining
FoldTypeSynonyms -> strFoldTypeSynonyms
CaseCallLifting -> strCaseCallLifting
SimplifyIfs -> strSimplifyIfs
Expand Down Expand Up @@ -191,6 +192,9 @@ strLambdaFolding = "lambda-folding"
strInlining :: Text
strInlining = "inlining"

strMandatoryInlining :: Text
strMandatoryInlining = "mandatory-inlining"

strFoldTypeSynonyms :: Text
strFoldTypeSynonyms = "fold-type-synonyms"

Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ geval opts herr ctx env0 = eval' env0
Closure env' (NLam (Lambda i' bi b)) ->
let !v = eval' env r in evalBody i' bi env' v b
lv
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
let !v = eval' env r in goNormApp i lv v
| otherwise ->
evalError "invalid application" (mkApp i lv (substEnv env r))
Expand All @@ -106,7 +106,7 @@ geval opts herr ctx env0 = eval' env0
NCtr (Constr _ tag args) ->
branch n env args tag def bs
v'
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
goNormCase env i sym v' bs def
| otherwise ->
evalError "matching on non-data" (substEnv env (mkCase i sym v' bs def))
Expand Down Expand Up @@ -214,7 +214,7 @@ geval opts herr ctx env0 = eval' env0
(Just v1, Just v2) ->
toNode (v1 `op` v2)
_
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
mkBuiltinApp' opcode [vl, vr]
| otherwise ->
evalError "wrong operand type" n
Expand Down
16 changes: 16 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,29 @@ nodeIdents f = ufoldA reassemble go
NIdt i -> NIdt <$> f i
n -> pure n

getInductives :: Node -> HashSet Symbol
getInductives n = HashSet.fromList (n ^.. nodeInductives)

nodeInductives :: Traversal' Node Symbol
nodeInductives f = ufoldA reassemble go
where
go = \case
NTyp ty -> NTyp <$> traverseOf typeConstrSymbol f ty
n -> pure n

getSymbols :: InfoTable -> Node -> HashSet Symbol
getSymbols tab = gather go mempty
where
go :: HashSet Symbol -> Node -> HashSet Symbol
go acc = \case
NTyp TypeConstr {..} -> HashSet.insert _typeConstrSymbol acc
NIdt Ident {..} -> HashSet.insert _identSymbol acc
NCase Case {..} -> HashSet.insert _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' tab _constrTag ->
HashSet.insert (ci ^. constructorInductive) acc
_ -> acc

-- | Prism for NRec
_NRec :: SimpleFold Node LetRec
_NRec f = \case
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnre
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
Expand Down Expand Up @@ -74,6 +75,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
LambdaFolding -> return . lambdaFolding
LetHoisting -> return . letHoisting
Inlining -> inlining
MandatoryInlining -> return . mandatoryInlining
FoldTypeSynonyms -> return . foldTypeSynonyms
CaseCallLifting -> return . caseCallLifting
SimplifyIfs -> return . simplifyIfs
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Transformation/LetHoisting.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- Moves al let expressions at the top, just after the top lambdas. This
-- Moves all let expressions at the top, just after the top lambdas. This
-- transformation assumes:
-- - There are no LetRecs, Lambdas (other than the ones at the top), nor Match.
-- - Case nodes do not have binders.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Transformation.Base

filterUnreachable :: InfoTable -> InfoTable
filterUnreachable tab = pruneInfoTable $ over infoIdentifiers goFilter tab
filterUnreachable tab =
pruneInfoTable $
over infoInductives goFilter $
over infoIdentifiers goFilter tab
where
depInfo = createIdentDependencyInfo tab
depInfo = createSymbolDependencyInfo tab

goFilter :: HashMap Symbol IdentifierInfo -> HashMap Symbol IdentifierInfo
goFilter idents =
HashMap.filterWithKey (\sym _ -> isReachable depInfo sym) idents
goFilter :: HashMap Symbol a -> HashMap Symbol a
goFilter =
HashMap.filterWithKey (\sym _ -> isReachable depInfo sym)
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ convertNode inlineDepth recSyms tab = dmapL go
Just (InlinePartiallyApplied k)
| length args >= k ->
mkApps def args
Just InlineAlways ->
mkApps def args
Just InlineNever ->
node
_
Expand All @@ -48,6 +50,17 @@ convertNode inlineDepth recSyms tab = dmapL go
def = lookupIdentifierNode tab _identSymbol
_ ->
node
NIdt Ident {..} ->
case pi of
Just InlineFullyApplied | argsNum == 0 -> def
Just (InlinePartiallyApplied 0) -> def
Just InlineAlways -> def
_ -> node
where
ii = lookupIdentifierInfo tab _identSymbol
pi = ii ^. identifierPragmas . pragmasInline
argsNum = ii ^. identifierArgsNum
def = lookupIdentifierNode tab _identSymbol
-- inline zero-argument definitions automatically if inlining would result
-- in case reduction
NCase cs@Case {..} -> case _caseValue of
Expand Down
Loading

0 comments on commit 60a191b

Please sign in to comment.