Skip to content

Commit

Permalink
record pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Aug 20, 2024
1 parent 600a9c6 commit 23585b0
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 39 deletions.
5 changes: 3 additions & 2 deletions src/Juvix/Compiler/Backend/Isabelle/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ data Cons a = Cons
_consTail :: a
}

newtype Record a = Record
{ _recordFields :: [(Name, a)]
data Record a = Record
{ _recordName :: Name,
_recordFields :: [(Name, a)]
}

data ConstrApp = ConstrApp
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Backend/Isabelle/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ instance (PrettyCode a) => PrettyCode (List a) where

instance (PrettyCode a) => PrettyCode (Record a) where
ppCode Record {..} = do
recName <- ppCode _recordName
names <- mapM (ppCode . fst) _recordFields
elems <- mapM (ppCode . snd) _recordFields
let fields = zipWithExact (\n e -> n <+> "=" <+> e) names elems
let names' = map (\n -> recName <> "." <> n) names
fields = zipWithExact (\n e -> n <+> "=" <+> e) names' elems
return $ "(|" <+> hsep (punctuate comma fields) <+> "|)"

instance (PrettyCode a, HasAtomicity a) => PrettyCode (Cons a) where
Expand Down
77 changes: 44 additions & 33 deletions src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
_clauseBody = goExpression'' nset' nmap' _lambdaBody
}
where
(pats, nset', nmap') = goPatternArgs'' (filterTypeArgs 0 ty (toList _lambdaPatterns))
(pats, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))

goFunctionDef :: Internal.FunctionDef -> Statement
goFunctionDef Internal.FunctionDef {..} = goDef _funDefName _funDefType _funDefArgsInfo (Just _funDefBody)
Expand Down Expand Up @@ -389,9 +389,9 @@ goModule onlyTypes infoTable Internal.Module {..} =
x' <- goExpression x
y' <- goExpression y
return $ ExprTuple (Tuple (x' :| [y']))
| Just fields <- getRecordCreation app = do
| Just (name, fields) <- getRecordCreation app = do
fields' <- mapM (secondM goExpression) fields
return $ ExprRecord (Record fields')
return $ ExprRecord (Record name fields')
| Just (fn, args) <- getIdentApp app = do
fn' <- goExpression fn
args' <- mapM goExpression args
Expand Down Expand Up @@ -546,12 +546,15 @@ goModule onlyTypes infoTable Internal.Module {..} =
where
(fn, args) = Internal.unfoldApplication app

getRecordCreation :: Internal.Application -> Maybe [(Name, Internal.Expression)]
getRecordCreation :: Internal.Application -> Maybe (Name, [(Name, Internal.Expression)])
getRecordCreation app = case fn of
Internal.ExpressionIden (Internal.IdenConstructor name) ->
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
Just ctrInfo
| ctrInfo ^. Internal.constructorInfoRecord -> Just (goRecordFields (getArgtys ctrInfo) (toList args))
| ctrInfo ^. Internal.constructorInfoRecord ->
Just (indName, goRecordFields (getArgtys ctrInfo) (toList args))
where
indName = ctrInfo ^. Internal.constructorInfoInductive
_ -> Nothing
_ -> Nothing
where
Expand Down Expand Up @@ -702,9 +705,9 @@ goModule onlyTypes infoTable Internal.Module {..} =
goClause :: Internal.LambdaClause -> Sem r CaseBranch
goClause Internal.LambdaClause {..} = do
(pat, nset, nmap) <- case _lambdaPatterns of
p :| [] -> goPatternArg' p
p :| [] -> goPatternArgCase p
_ -> do
(pats, nset, nmap) <- goPatternArgs' (toList _lambdaPatterns)
(pats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
let pat =
PatTuple
Tuple
Expand All @@ -731,7 +734,7 @@ goModule onlyTypes infoTable Internal.Module {..} =

goCaseBranch :: Internal.CaseBranch -> Sem r CaseBranch
goCaseBranch Internal.CaseBranch {..} = do
(pat, nset, nmap) <- goPatternArg' _caseBranchPattern
(pat, nset, nmap) <- goPatternArgCase _caseBranchPattern
rhs <- withLocalNames nset nmap $ goCaseBranchRhs _caseBranchRhs
return $
CaseBranch
Expand All @@ -744,32 +747,32 @@ goModule onlyTypes infoTable Internal.Module {..} =
Internal.CaseBranchRhsExpression e -> goExpression e
Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions"

goPatternArgs'' :: [Internal.PatternArg] -> ([Pattern], NameSet, NameMap)
goPatternArgs'' pats =
goPatternArgsTop :: [Internal.PatternArg] -> ([Pattern], NameSet, NameMap)
goPatternArgsTop pats =
(pats', nset, nmap)
where
(nset, (nmap, pats')) = run $ runState (NameSet mempty) $ runState (NameMap mempty) $ goPatternArgs pats
(nset, (nmap, pats')) = run $ runState (NameSet mempty) $ runState (NameMap mempty) $ goPatternArgs True pats

goPatternArg' :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Internal.PatternArg -> Sem r (Pattern, NameSet, NameMap)
goPatternArg' pat = do
goPatternArgCase :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Internal.PatternArg -> Sem r (Pattern, NameSet, NameMap)
goPatternArgCase pat = do
nset <- ask @NameSet
nmap <- ask @NameMap
let (nmap', (nset', pat')) = run $ runState nmap $ runState nset $ goPatternArg pat
let (nmap', (nset', pat')) = run $ runState nmap $ runState nset $ goPatternArg False pat
return (pat', nset', nmap')

goPatternArgs' :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.PatternArg] -> Sem r ([Pattern], NameSet, NameMap)
goPatternArgs' pats = do
goPatternArgsCase :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.PatternArg] -> Sem r ([Pattern], NameSet, NameMap)
goPatternArgsCase pats = do
nset <- ask @NameSet
nmap <- ask @NameMap
let (nmap', (nset', pats')) = run $ runState nmap $ runState nset $ goPatternArgs pats
let (nmap', (nset', pats')) = run $ runState nmap $ runState nset $ goPatternArgs False pats
return (pats', nset', nmap')

goPatternArgs :: forall r. (Members '[State NameSet, State NameMap] r) => [Internal.PatternArg] -> Sem r [Pattern]
goPatternArgs = mapM goPatternArg
goPatternArgs :: forall r. (Members '[State NameSet, State NameMap] r) => Bool -> [Internal.PatternArg] -> Sem r [Pattern]
goPatternArgs isTop = mapM (goPatternArg isTop)

-- TODO: named patterns (`_patternArgName`) are not handled properly
goPatternArg :: forall r. (Members '[State NameSet, State NameMap] r) => Internal.PatternArg -> Sem r Pattern
goPatternArg Internal.PatternArg {..} =
goPatternArg :: forall r. (Members '[State NameSet, State NameMap] r) => Bool -> Internal.PatternArg -> Sem r Pattern
goPatternArg isTop Internal.PatternArg {..} =
goPattern _patternArgPattern
where
goPattern :: Internal.Pattern -> Sem r Pattern
Expand All @@ -786,27 +789,32 @@ goModule onlyTypes infoTable Internal.Module {..} =
goPatternConstructorApp :: Internal.ConstructorApp -> Sem r Pattern
goPatternConstructorApp Internal.ConstructorApp {..}
| Just lst <- getListPat _constrAppConstructor _constrAppParameters = do
pats <- goPatternArgs lst
pats <- goPatternArgs False lst
return $ PatList (List pats)
| Just (x, y) <- getConsPat _constrAppConstructor _constrAppParameters = do
x' <- goPatternArg x
y' <- goPatternArg y
x' <- goPatternArg False x
y' <- goPatternArg False y
return $ PatCons (Cons x' y')
| Just fields <- getRecordPat _constrAppConstructor _constrAppParameters = do
fields' <- mapM (secondM goPatternArg) fields
return $ PatRecord (Record fields')
| Just (name, fields) <- getRecordPat _constrAppConstructor _constrAppParameters =
if
| isTop -> do
fields' <- mapM (secondM (goPatternArg False)) fields
return $ PatRecord (Record name fields')
| otherwise ->
-- TODO: record patterns are not supported in non-top-level patterns
return $ PatVar (defaultName "_")
| Just (x, y) <- getPairPat _constrAppConstructor _constrAppParameters = do
x' <- goPatternArg x
y' <- goPatternArg y
x' <- goPatternArg False x
y' <- goPatternArg False y
return $ PatTuple (Tuple (x' :| [y']))
| Just p <- getNatPat _constrAppConstructor _constrAppParameters =
case p of
Left zero -> return zero
Right arg -> do
arg' <- goPatternArg arg
arg' <- goPatternArg False arg
return (PatConstrApp (ConstrApp (goConstrName _constrAppConstructor) [arg']))
| otherwise = do
args <- mapM goPatternArg _constrAppParameters
args <- mapM (goPatternArg False) _constrAppParameters
return $
PatConstrApp
ConstrApp
Expand Down Expand Up @@ -841,11 +849,14 @@ goModule onlyTypes infoTable Internal.Module {..} =
_ -> Nothing
Nothing -> Nothing

getRecordPat :: Name -> [Internal.PatternArg] -> Maybe [(Name, Internal.PatternArg)]
getRecordPat :: Name -> [Internal.PatternArg] -> Maybe (Name, [(Name, Internal.PatternArg)])
getRecordPat name args =
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
Just ctrInfo
| ctrInfo ^. Internal.constructorInfoRecord -> Just (goRecordFields (getArgtys ctrInfo) args)
| ctrInfo ^. Internal.constructorInfoRecord ->
Just (indName, goRecordFields (getArgtys ctrInfo) args)
where
indName = ctrInfo ^. Internal.constructorInfoInductive
_ -> Nothing

getNatPat :: Name -> [Internal.PatternArg] -> Maybe (Either Pattern Internal.PatternArg)
Expand Down
18 changes: 15 additions & 3 deletions tests/positive/Isabelle/isabelle/Program.thy
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,25 @@ record R =
r1 :: nat
r2 :: nat

fun r1 :: "R \<Rightarrow> nat" where
"r1 \<lparr> r1 = r1', r2 = _ \<rparr> = r1'"
fun r1 :: "R \<Rightarrow> nat" where
"r1 (| R.r1 = r1'0, R.r2 = r2'0 |) = r1'0"

fun r2 :: "R \<Rightarrow> nat" where
"r2 (| R.r1 = r1'0, R.r2 = r2'0 |) = r2'0"

fun funR :: "R \<Rightarrow> R" where
"funR r =
(case r of
_ \<Rightarrow> r\<lparr>r1 := R.r1 r + R.r2 r\<rparr>)"
_ \<Rightarrow>
(\<lambda> x0 . case x0 of
_ \<Rightarrow> (| R.r1 = r1 + r2, R.r2 = r2 |)) r)"

fun funR' :: "R \<Rightarrow> R" where
"funR' (| R.r1 = rr1, R.r2 = rr2 |) =
(let
r1'0 = rr1 + rr2;
r2'0 = rr2
in (| R.r1 = r1'0, R.r2 = r2'0 |))"

fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
"bf b1 b2 = (\<not> (b1 \<and> b2))"
Expand Down

0 comments on commit 23585b0

Please sign in to comment.