Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numeric, ordering and equality traits #2433

Merged
merged 12 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions examples/demo/Demo.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ module Demo;

-- standard library prelude
import Stdlib.Prelude open;
-- for comparisons on natural numbers
import Stdlib.Data.Nat.Ord open;
-- for Ordering

even : Nat → Bool
| zero := true
Expand Down Expand Up @@ -35,13 +32,13 @@ preorder : {A : Type} → Tree A → List A
| (node x l r) := x :: nil ++ preorder l ++ preorder r;

terminating
sort : {A : Type} → (A → A → Ordering) → List A → List A
| _ nil := nil
| _ xs@(_ :: nil) := xs
| {A} cmp xs :=
sort {A} {{Ord A}} : List A → List A
| nil := nil
| xs@(_ :: nil) := xs
| xs :=
uncurry
(merge {{mkOrd cmp}})
(both (sort cmp) (splitAt (div (length xs) 2) xs));
merge
(both sort (splitAt (div (length xs) 2) xs));

printNatListLn : List Nat → IO
| nil := printStringLn "nil"
Expand All @@ -51,6 +48,6 @@ printNatListLn : List Nat → IO
main : IO :=
printStringLn "Hello!"
>> printNatListLn (preorder (mirror tree))
>> printNatListLn (sort compare (preorder (mirror tree)))
>> printNatListLn (sort (preorder (mirror tree)))
>> printNatLn (log2 3)
>> printNatLn (log2 130);
1 change: 0 additions & 1 deletion examples/midsquare/MidSquareHash.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
module MidSquareHash;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

--- `pow N` is 2 ^ N
pow : Nat -> Nat
Expand Down
1 change: 0 additions & 1 deletion examples/midsquare/MidSquareHashUnrolled.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
module MidSquareHashUnrolled;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

--- `powN` is 2 ^ N
pow0 : Nat := 1;
Expand Down
2 changes: 0 additions & 2 deletions examples/milestone/Bank/Bank.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ module Bank;
import Stdlib.Prelude open;
import Stdlib.Debug.Fail open;

import Stdlib.Data.Nat.Ord open;

import Stdlib.Data.Nat as Nat;

Address : Type := Nat;
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/Collatz/Collatz.juvix
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Collatz;

import Stdlib.Prelude open;
import Stdlib.Data.Nat.Ord open;

collatzNext (n : Nat) : Nat :=
if (mod n 2 == 0) (div n 2) (3 * n + 1);
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/CLI/TicTacToe.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
--- The module Logic.Game contains the game logic.
module CLI.TicTacToe;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;
import Logic.Game open;

Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/Logic/Extra.juvix
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
--- Some generic helper definitions.
module Logic.Extra;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;

--- Concatenates a list of strings
Expand Down
1 change: 0 additions & 1 deletion examples/milestone/TicTacToe/Logic/Game.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
--- diagonal row is the winner. It is a solved game, with a forced draw assuming best play from both players.
module Logic.Game;

import Stdlib.Data.Nat.Ord open;
import Stdlib.Prelude open;
import Logic.Extra open public;
import Logic.Board open public;
Expand Down
7 changes: 1 addition & 6 deletions examples/milestone/TicTacToe/Logic/Square.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Logic.Square;

import Stdlib.Prelude open;
import Logic.Symbol open;
import Stdlib.Data.Nat.Ord open;
import Logic.Extra open;

--- A square is each of the holes in a board
Expand All @@ -24,9 +23,5 @@ showSquare : Square → String
| (occupied s) := " " ++str showSymbol s ++str " ";

replace (player : Symbol) (k : Nat) : Square → Square
| (empty n) :=
if
(n Stdlib.Data.Nat.Ord.== k)
(occupied player)
(empty n)
| (empty n) := if (n == k) (occupied player) (empty n)
| s := s;
2 changes: 0 additions & 2 deletions examples/milestone/Tutorial/Tutorial.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@ module Tutorial;

-- import the standard library prelude and bring it into scope
import Stdlib.Prelude open;
-- bring comparison operators on Nat into scope
import Stdlib.Data.Nat.Ord open;

main : IO := printStringLn "Hello world!";
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Backend/Geb/Translation/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fromCore :: Core.InfoTable -> (Morphism, Object)
fromCore tab = case tab ^. Core.infoMain of
Just sym ->
let node = Core.lookupIdentifierNode tab sym
syms = reverse $ filter (/= sym) $ Core.createIdentDependencyInfo tab ^. Core.depInfoTopSort
syms = reverse $ filter (/= sym) $ Core.createCallGraph tab ^. Core.depInfoTopSort
idents = map (Core.lookupIdentifierInfo tab) syms
morph = run . runReader emptyEnv $ goIdents node idents
obj = convertType $ Info.getNodeType node
Expand Down
30 changes: 27 additions & 3 deletions src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Data.IdentDependencyInfo where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Extra.Utils
Expand All @@ -9,8 +10,8 @@ import Juvix.Compiler.Core.Language
type IdentDependencyInfo = DependencyInfo Symbol

-- | Compute the call graph
createIdentDependencyInfo :: InfoTable -> IdentDependencyInfo
createIdentDependencyInfo tab = createDependencyInfo graph startVertices
createCallGraph :: InfoTable -> IdentDependencyInfo
createCallGraph tab = createDependencyInfo graph startVertices
where
graph :: HashMap Symbol (HashSet Symbol)
graph =
Expand All @@ -27,5 +28,28 @@ createIdentDependencyInfo tab = createDependencyInfo graph startVertices
syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMain)

createSymbolDependencyInfo :: InfoTable -> IdentDependencyInfo
createSymbolDependencyInfo tab = createDependencyInfo graph startVertices
where
graph :: HashMap Symbol (HashSet Symbol)
graph =
fmap
( \IdentifierInfo {..} ->
getSymbols tab (lookupIdentifierNode tab _identifierSymbol)
)
(tab ^. infoIdentifiers)
<> foldr
( \ConstructorInfo {..} ->
HashMap.insert _constructorInductive (getSymbols tab _constructorType)
)
mempty
(tab ^. infoConstructors)

startVertices :: HashSet Symbol
startVertices = HashSet.fromList syms

syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMain)

recursiveIdents :: InfoTable -> HashSet Symbol
recursiveIdents = nodesOnCycles . createIdentDependencyInfo
recursiveIdents = nodesOnCycles . createCallGraph
35 changes: 24 additions & 11 deletions src/Juvix/Compiler/Core/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,33 @@ filterByFile f t =
matchesLocation :: Maybe Location -> Bool
matchesLocation l = l ^? _Just . intervalFile == Just f

-- | Prunes the orphaned entries of identMap and indentContext, i.e., ones that
-- have no corresponding entries in infoIdentifiers or infoInductives
-- | Prunes the orphaned entries of identMap, indentContext and
-- infoConstructors, i.e., ones that have no corresponding entries in
-- infoIdentifiers or infoInductives
pruneInfoTable :: InfoTable -> InfoTable
pruneInfoTable tab =
over
identMap
( HashMap.filter
( \case
IdentFun s -> HashMap.member s (tab ^. infoIdentifiers)
IdentInd s -> HashMap.member s (tab ^. infoInductives)
IdentConstr tag -> HashMap.member tag (tab ^. infoConstructors)
)
)
pruneIdentMap
$ over
infoConstructors
( HashMap.filter
( \ConstructorInfo {..} ->
HashMap.member _constructorInductive (tab ^. infoInductives)
)
)
$ over
identContext
(HashMap.filterWithKey (\s _ -> HashMap.member s (tab ^. infoIdentifiers)))
tab
where
pruneIdentMap :: InfoTable -> InfoTable
pruneIdentMap tab' =
over
identMap
( HashMap.filter
( \case
IdentFun s -> HashMap.member s (tab' ^. infoIdentifiers)
IdentInd s -> HashMap.member s (tab' ^. infoInductives)
IdentConstr tag -> HashMap.member tag (tab' ^. infoConstructors)
)
)
tab'
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ data TransformationId
| LambdaFolding
| LetHoisting
| Inlining
| MandatoryInlining
| FoldTypeSynonyms
| CaseCallLifting
| SimplifyIfs
Expand Down Expand Up @@ -75,7 +76,7 @@ toNormalizeTransformations :: [TransformationId]
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]

toVampIRTransformations :: [TransformationId]
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]
toVampIRTransformations = toEvalTransformations ++ [FilterUnreachable, CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]

toStrippedTransformations :: [TransformationId]
toStrippedTransformations =
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ transformationText = \case
LambdaFolding -> strLambdaFolding
LetHoisting -> strLetHoisting
Inlining -> strInlining
MandatoryInlining -> strMandatoryInlining
FoldTypeSynonyms -> strFoldTypeSynonyms
CaseCallLifting -> strCaseCallLifting
SimplifyIfs -> strSimplifyIfs
Expand Down Expand Up @@ -191,6 +192,9 @@ strLambdaFolding = "lambda-folding"
strInlining :: Text
strInlining = "inlining"

strMandatoryInlining :: Text
strMandatoryInlining = "mandatory-inlining"

strFoldTypeSynonyms :: Text
strFoldTypeSynonyms = "fold-type-synonyms"

Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ geval opts herr ctx env0 = eval' env0
Closure env' (NLam (Lambda i' bi b)) ->
let !v = eval' env r in evalBody i' bi env' v b
lv
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
let !v = eval' env r in goNormApp i lv v
| otherwise ->
evalError "invalid application" (mkApp i lv (substEnv env r))
Expand All @@ -106,7 +106,7 @@ geval opts herr ctx env0 = eval' env0
NCtr (Constr _ tag args) ->
branch n env args tag def bs
v'
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
lukaszcz marked this conversation as resolved.
Show resolved Hide resolved
goNormCase env i sym v' bs def
| otherwise ->
evalError "matching on non-data" (substEnv env (mkCase i sym v' bs def))
Expand Down Expand Up @@ -214,7 +214,7 @@ geval opts herr ctx env0 = eval' env0
(Just v1, Just v2) ->
toNode (v1 `op` v2)
_
| opts ^. evalOptionsNormalize ->
| opts ^. evalOptionsNormalize || opts ^. evalOptionsNoFailure ->
mkBuiltinApp' opcode [vl, vr]
| otherwise ->
evalError "wrong operand type" n
Expand Down
16 changes: 16 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,29 @@ nodeIdents f = ufoldA reassemble go
NIdt i -> NIdt <$> f i
n -> pure n

getInductives :: Node -> HashSet Symbol
getInductives n = HashSet.fromList (n ^.. nodeInductives)

nodeInductives :: Traversal' Node Symbol
nodeInductives f = ufoldA reassemble go
where
go = \case
NTyp ty -> NTyp <$> traverseOf typeConstrSymbol f ty
n -> pure n

getSymbols :: InfoTable -> Node -> HashSet Symbol
getSymbols tab = gather go mempty
where
go :: HashSet Symbol -> Node -> HashSet Symbol
go acc = \case
NTyp TypeConstr {..} -> HashSet.insert _typeConstrSymbol acc
NIdt Ident {..} -> HashSet.insert _identSymbol acc
NCase Case {..} -> HashSet.insert _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' tab _constrTag ->
HashSet.insert (ci ^. constructorInductive) acc
_ -> acc

-- | Prism for NRec
_NRec :: SimpleFold Node LetRec
_NRec f = \case
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnre
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
Expand Down Expand Up @@ -74,6 +75,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
LambdaFolding -> return . lambdaFolding
LetHoisting -> return . letHoisting
Inlining -> inlining
MandatoryInlining -> return . mandatoryInlining
FoldTypeSynonyms -> return . foldTypeSynonyms
CaseCallLifting -> return . caseCallLifting
SimplifyIfs -> return . simplifyIfs
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Transformation/LetHoisting.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- Moves al let expressions at the top, just after the top lambdas. This
-- Moves all let expressions at the top, just after the top lambdas. This
-- transformation assumes:
-- - There are no LetRecs, Lambdas (other than the ones at the top), nor Match.
-- - Case nodes do not have binders.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Transformation.Base

filterUnreachable :: InfoTable -> InfoTable
filterUnreachable tab = pruneInfoTable $ over infoIdentifiers goFilter tab
filterUnreachable tab =
pruneInfoTable $
over infoInductives goFilter $
over infoIdentifiers goFilter tab
where
depInfo = createIdentDependencyInfo tab
depInfo = createSymbolDependencyInfo tab

goFilter :: HashMap Symbol IdentifierInfo -> HashMap Symbol IdentifierInfo
goFilter idents =
HashMap.filterWithKey (\sym _ -> isReachable depInfo sym) idents
goFilter :: HashMap Symbol a -> HashMap Symbol a
goFilter =
HashMap.filterWithKey (\sym _ -> isReachable depInfo sym)
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ convertNode inlineDepth recSyms tab = dmapL go
Just (InlinePartiallyApplied k)
| length args >= k ->
mkApps def args
Just InlineAlways ->
mkApps def args
Just InlineNever ->
node
_
Expand All @@ -48,6 +50,17 @@ convertNode inlineDepth recSyms tab = dmapL go
def = lookupIdentifierNode tab _identSymbol
_ ->
node
NIdt Ident {..} ->
case pi of
Just InlineFullyApplied | argsNum == 0 -> def
Just (InlinePartiallyApplied 0) -> def
Just InlineAlways -> def
_ -> node
where
ii = lookupIdentifierInfo tab _identSymbol
pi = ii ^. identifierPragmas . pragmasInline
argsNum = ii ^. identifierArgsNum
def = lookupIdentifierNode tab _identSymbol
-- inline zero-argument definitions automatically if inlining would result
-- in case reduction
NCase cs@Case {..} -> case _caseValue of
Expand Down
Loading