Skip to content

Commit

Permalink
named patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Aug 22, 2024
1 parent 0e4581a commit 0759a45
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 70 deletions.
94 changes: 62 additions & 32 deletions src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.List.NonEmpty.Extra qualified as NonEmpty
import Data.Text qualified as T
import Data.Text qualified as Text
import Juvix.Compiler.Backend.Isabelle.Data.Result
Expand Down Expand Up @@ -193,7 +194,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
oneClause expr =
nonEmpty'
[ Clause
{ _clausePatterns = nonEmpty' $ map PatVar argnames,
{ _clausePatterns = nonEmpty' (map PatVar argnames),
_clauseBody = expr
}
]
Expand All @@ -210,13 +211,28 @@ goModule onlyTypes infoTable Internal.Module {..} =
: goClauses cls
Nested pats npats ->
let rhs = goExpression'' nset' nmap' _lambdaBody
vnames = map (overNameText (disambiguate (nset' ^. nameSet))) argnames
argnames' = fmap getPatternArgName _lambdaPatterns
vnames =
fmap
( \(idx :: Int, mname) ->
maybe
( defaultName
( disambiguate
(nset' ^. nameSet)
("v_" <> show idx)
)
)
(overNameText (disambiguate (nset' ^. nameSet)))
mname
)
(NonEmpty.zip (nonEmpty' [0 ..]) argnames')
nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames
remainingBranches = goLambdaClauses'' nset'' nmap' cls
valTuple = ExprTuple (Tuple (nonEmpty' (map ExprIden vnames)))
brs = goNestedBranches valTuple rhs remainingBranches (PatTuple (Tuple (nonEmpty' pats))) (nonEmpty' npats)
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
patTuple = PatTuple (Tuple (nonEmpty' pats))
brs = goNestedBranches valTuple rhs remainingBranches patTuple (nonEmpty' npats)
in [ Clause
{ _clausePatterns = nonEmpty' $ map PatVar vnames,
{ _clausePatterns = fmap PatVar vnames,
_clauseBody =
ExprCase
Case
Expand Down Expand Up @@ -890,6 +906,18 @@ goModule onlyTypes infoTable Internal.Module {..} =
]
[] -> return []

isPatternArgVar :: Internal.PatternArg -> Bool
isPatternArgVar Internal.PatternArg {..} =
case _patternArgPattern of
Internal.PatternVariable {} -> True
_ -> False

getPatternArgName :: Internal.PatternArg -> Maybe Name
getPatternArgName Internal.PatternArg {..} =
case _patternArgPattern of
Internal.PatternVariable name -> Just name
_ -> _patternArgName

goPatternArgsTop :: [Internal.PatternArg] -> (Nested [Pattern], NameSet, NameMap)
goPatternArgsTop pats =
(Nested pats' npats, nset, nmap)
Expand All @@ -914,18 +942,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
goPatternArgs isTop = mapM (goPatternArg isTop)

goPatternArg :: forall r. (Members '[State NameSet, State NameMap, Output (Expression, Nested Pattern)] r) => Bool -> Internal.PatternArg -> Sem r Pattern
goPatternArg isTop Internal.PatternArg {..} =
-- TODO: named patterns
goPattern isTop _patternArgPattern

goNestedPatternArg :: forall r. (Members '[State NameSet, State NameMap, Output (Expression, Nested Pattern)] r) => Bool -> Internal.PatternArg -> Sem r (Nested Pattern)
goNestedPatternArg isTop Internal.PatternArg {..} =
-- TODO: named patterns
goNestedPattern isTop _patternArgPattern

goNestedPattern :: forall r. (Members '[State NameSet, State NameMap, Output (Expression, Nested Pattern)] r) => Bool -> Internal.Pattern -> Sem r (Nested Pattern)
goNestedPattern isTop pat = do
(npats, pat') <- runOutputList $ goPattern isTop pat
goPatternArg isTop Internal.PatternArg {..}
| Just name <- _patternArgName = do
binders <- gets (^. nameSet)
let name' = overNameText (disambiguate binders) name
modify' (over nameSet (HashSet.insert (name' ^. namePretty)))
modify' (over nameMap (HashMap.insert name (ExprIden name')))
npat <- goNestedPattern _patternArgPattern
output (ExprIden name', npat)
return $ PatVar name'
| otherwise =
goPattern isTop _patternArgPattern

goNestedPatternArg :: forall r. (Members '[State NameSet, State NameMap] r) => Internal.PatternArg -> Sem r (Nested Pattern)
goNestedPatternArg Internal.PatternArg {..}
| Just name <- _patternArgName = do
binders <- gets (^. nameSet)
let name' = overNameText (disambiguate binders) name
modify' (over nameSet (HashSet.insert (name' ^. namePretty)))
modify' (over nameMap (HashMap.insert name (ExprIden name')))
npat <- goNestedPattern _patternArgPattern
return $ Nested (PatVar name') [(ExprIden name', npat)]
| otherwise =
goNestedPattern _patternArgPattern

goNestedPattern :: forall r. (Members '[State NameSet, State NameMap] r) => Internal.Pattern -> Sem r (Nested Pattern)
goNestedPattern pat = do
(npats, pat') <- runOutputList $ goPattern False pat
return $ Nested pat' npats

goPattern :: forall r. (Members '[State NameSet, State NameMap, Output (Expression, Nested Pattern)] r) => Bool -> Internal.Pattern -> Sem r Pattern
Expand Down Expand Up @@ -960,13 +1003,13 @@ goModule onlyTypes infoTable Internal.Module {..} =
let name' = overNameText (\n -> indName ^. nameText <> "." <> n) name
in ExprApp (Application (ExprIden name') (ExprIden vname))
vname = defaultName (disambiguate binders "v")
fieldsVars = map (second getPatternArgVar) $ map (first adjustName) $ filter (isPatternArgVar . snd) fields
fieldsVars = map (second (fromJust . getPatternArgName)) $ map (first adjustName) $ filter (isPatternArgVar . snd) fields
fieldsNonVars = map (first adjustName) $ filter (not . isPatternArgVar . snd) fields
modify' (over nameSet (HashSet.insert (vname ^. namePretty)))
forM fieldsVars $ \(e, fname) -> do
modify' (over nameSet (HashSet.insert (fname ^. namePretty)))
modify' (over nameMap (HashMap.insert fname e))
fieldsNonVars' <- mapM (secondM (goNestedPatternArg False)) fieldsNonVars
fieldsNonVars' <- mapM (secondM goNestedPatternArg) fieldsNonVars
forM fieldsNonVars' output
return (PatVar vname)
| Just (x, y) <- getPairPat _constrAppConstructor _constrAppParameters = do
Expand All @@ -988,19 +1031,6 @@ goModule onlyTypes infoTable Internal.Module {..} =
_constrAppArgs = args
}

isPatternArgVar :: Internal.PatternArg -> Bool
isPatternArgVar Internal.PatternArg {..} =
isNothing _patternArgName
&& case _patternArgPattern of
Internal.PatternVariable {} -> True
_ -> False

getPatternArgVar :: Internal.PatternArg -> Name
getPatternArgVar Internal.PatternArg {..} =
case _patternArgPattern of
Internal.PatternVariable name -> name
_ -> impossible

-- This function cannot be simply merged with `getList` because in patterns
-- the constructors don't get the type arguments.
getListPat :: Name -> [Internal.PatternArg] -> Maybe [Internal.PatternArg]
Expand Down
3 changes: 3 additions & 0 deletions tests/positive/Isabelle/Program.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ type R := mkR {
r2 : Nat;
};

r : R := mkR 0 1;
v : Nat := 0;

funR (r : R) : R :=
case r of
| mkR@{r1; r2} := r@R{r1 := r1 + r2};
Expand Down
87 changes: 49 additions & 38 deletions tests/positive/Isabelle/isabelle/Program.thy
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ datatype 'A Queue
= queue "'A list" "'A list"

fun qfst :: "'A Queue \<Rightarrow> 'A list" where
"qfst (queue x v) = x"
"qfst (queue x v') = x"

fun qsnd :: "'A Queue \<Rightarrow> 'A list" where
"qsnd (queue v v') = v'"
"qsnd (queue v' v'0) = v'0"

fun pop_front :: "'A Queue \<Rightarrow> 'A Queue" where
"pop_front q =
(let
q' = queue (tl (qfst q)) (qsnd q)
in case qfst q' of
[] \<Rightarrow> queue (rev (qsnd q')) [] |
v \<Rightarrow> q')"
v' \<Rightarrow> q')"

fun push_back :: "'A Queue \<Rightarrow> 'A \<Rightarrow> 'A Queue" where
"push_back q x =
Expand All @@ -80,8 +80,8 @@ fun is_empty :: "'A Queue \<Rightarrow> bool" where
[] \<Rightarrow>
(case qsnd q of
[] \<Rightarrow> True |
v \<Rightarrow> False) |
v \<Rightarrow> False)"
v' \<Rightarrow> False) |
v' \<Rightarrow> False)"

definition empty :: "'A Queue" where
"empty = queue [] []"
Expand Down Expand Up @@ -109,17 +109,28 @@ fun r1 :: "R \<Rightarrow> nat" where
fun r2 :: "R \<Rightarrow> nat" where
"r2 (| R.r1 = r1'0, R.r2 = r2'0 |) = r2'0"

definition r :: R where
"r = (| R.r1 = 0, R.r2 = 1 |)"

definition v :: nat where
"v = 0"

fun funR :: "R \<Rightarrow> R" where
"funR r =
(case r of
v \<Rightarrow>
"funR r' =
(case r' of
v' \<Rightarrow>
(\<lambda> x0 . case x0 of
v' \<Rightarrow> (| R.r1 = R.r1 v' + R.r2 v', R.r2 = R.r2 v' |)) r)"
v'0 \<Rightarrow> (| R.r1 = R.r1 v'0 + R.r2 v'0, R.r2 = R.r2 v'0 |)) r')"

fun funRR :: "R \<Rightarrow> R" where
"funRR (| R.r1 = r1'0, R.r2 = r2'0 |) =
((\<lambda> x0 . case x0 of
v \<Rightarrow> (| R.r1 = R.r1 v + R.r2 v, R.r2 = R.r2 v |)) r)"
"funRR r'0 =
(case (r'0) of
(r') \<Rightarrow>
(case (r') of
(v') \<Rightarrow>
(\<lambda> x0 . case x0 of
v'0 \<Rightarrow>
(| R.r1 = R.r1 v'0 + R.r2 v'0, R.r2 = R.r2 v'0 |)) r'))"

fun funR' :: "R \<Rightarrow> R" where
"funR' (| R.r1 = rr1, R.r2 = rr2 |) =
Expand All @@ -141,50 +152,50 @@ fun funR1 :: "R \<Rightarrow> R" where
in (| R.r1 = r1'0, R.r2 = r2'0 |))"

fun funR2 :: "R \<Rightarrow> R" where
"funR2 r =
(case r of
v' \<Rightarrow>
(case v' of
v \<Rightarrow>
(case (R.r1 v) of
"funR2 r' =
(case r' of
v'0 \<Rightarrow>
(case v'0 of
v' \<Rightarrow>
(case (R.r1 v') of
(0) \<Rightarrow>
let
r1'0 = R.r2 v;
r2'0 = R.r2 v
r1'0 = R.r2 v';
r2'0 = R.r2 v'
in (| R.r1 = r1'0, R.r2 = r2'0 |) |
_ \<Rightarrow>
(case v' of
v'0 \<Rightarrow>
(case v'0 of
v'1 \<Rightarrow>
let
r1'0 = R.r2 v'0;
r2'0 = R.r1 v'0
r1'0 = R.r2 v'1;
r2'0 = R.r1 v'1
in (| R.r1 = r1'0, R.r2 = r2'0 |)))))"

fun funR3 :: "(R, R) Either' \<Rightarrow> R" where
"funR3 er =
(case er of
v' \<Rightarrow>
(case v' of
(Left' v) \<Rightarrow>
(case (R.r1 v) of
v'0 \<Rightarrow>
(case v'0 of
(Left' v') \<Rightarrow>
(case (R.r1 v') of
(0) \<Rightarrow>
let
r1'0 = R.r2 v;
r2'0 = R.r2 v
r1'0 = R.r2 v';
r2'0 = R.r2 v'
in (| R.r1 = r1'0, R.r2 = r2'0 |) |
_ \<Rightarrow>
(case v' of
(Left' v'0) \<Rightarrow>
(case v'0 of
(Left' v'1) \<Rightarrow>
let
r1'0 = R.r2 v'0;
r2'0 = R.r1 v'0
r1'0 = R.r2 v'1;
r2'0 = R.r1 v'1
in (| R.r1 = r1'0, R.r2 = r2'0 |) |
(Right' r) \<Rightarrow>
(Right' r') \<Rightarrow>
(\<lambda> x0 . case x0 of
v'0 \<Rightarrow> (| R.r1 = 2, R.r2 = R.r2 v'0 |)) r)) |
(Right' r) \<Rightarrow>
v'1 \<Rightarrow> (| R.r1 = 2, R.r2 = R.r2 v'1 |)) r')) |
(Right' r') \<Rightarrow>
(\<lambda> x0 . case x0 of
v'0 \<Rightarrow> (| R.r1 = 2, R.r2 = R.r2 v'0 |)) r))"
v'1 \<Rightarrow> (| R.r1 = 2, R.r2 = R.r2 v'1 |)) r'))"

fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
"bf b1 b2 = (\<not> (b1 \<and> b2))"
Expand Down

0 comments on commit 0759a45

Please sign in to comment.