Skip to content

Commit

Permalink
Compilation of side conditions in pattern matches (#2984)
Browse files Browse the repository at this point in the history
* Closes #2804 
* Requires #3003
* Front-end syntax for side conditions was implemented in #2852. This PR
implements compilation of side conditions.
* Adds side-conditions to `Match` nodes in Core. Updates Core parsing,
printing and the evaluator.
* Only side-conditions without an `else` branch are allowed in Core. If
there is an `else` branch, the side conditions are translated in
`fromInternal` into nested ifs. Because with `else` the conditions are
exhaustive, there are no implications for pattern exhaustiveness
checking.
* Adjusts the "wildcard row" case in the pattern matching compilation
algorithm to take into account the side conditions.
  • Loading branch information
lukaszcz authored Sep 9, 2024
1 parent 453afff commit ab2d31a
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 43 deletions.
14 changes: 13 additions & 1 deletion src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ geval opts herr tab env0 = eval' env0
match n env vs = \case
br : brs ->
case matchPatterns [] vs (toList (br ^. matchBranchPatterns)) of
Just args -> eval' (args ++ env) (br ^. matchBranchBody)
Just args -> matchRhs (args ++ env) (br ^. matchBranchRhs)
Nothing -> match n env vs brs
where
matchPatterns :: [Node] -> [Node] -> [Pattern] -> Maybe [Node]
Expand All @@ -169,6 +169,18 @@ geval opts herr tab env0 = eval' env0
| tag == _patternConstrTag =
matchPatterns (v : acc) args _patternConstrArgs
patmatch _ _ _ = Nothing

matchRhs :: [Node] -> MatchBranchRhs -> Node
matchRhs env' = \case
MatchBranchRhsExpression e -> eval' env' e
MatchBranchRhsIfs ifs -> matchIfs env' (toList ifs)

matchIfs :: [Node] -> [SideIfBranch] -> Node
matchIfs env' = \case
SideIfBranch {..} : ifs -> case eval' env' _sideIfBranchCondition of
NCtr (Constr _ (BuiltinTag TagTrue) []) -> eval' env' _sideIfBranchBody
_ -> matchIfs env' ifs
[] -> match n env vs brs
[] ->
evalError "no matching pattern" (substEnv env n)

Expand Down
54 changes: 41 additions & 13 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,6 @@ mkMatch i vtys rty vs bs = NMatch (Match i vtys rty vs bs)
mkMatch' :: NonEmpty Type -> Type -> NonEmpty Node -> [MatchBranch] -> Node
mkMatch' = mkMatch Info.empty

mkMatchBranch :: Info -> NonEmpty Pattern -> Node -> MatchBranch
mkMatchBranch = MatchBranch

mkMatchBranch' :: NonEmpty Pattern -> Node -> MatchBranch
mkMatchBranch' = MatchBranch mempty

mkIf :: Info -> Symbol -> Node -> Node -> Node -> Node
mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2)
where
Expand All @@ -122,6 +116,14 @@ mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2)
mkIf' :: Symbol -> Node -> Node -> Node -> Node
mkIf' = mkIf Info.empty

mkIfs :: Symbol -> [(Info, Node, Node)] -> Node -> Node
mkIfs sym = \case
[] -> id
((i, v, b) : rest) -> mkIf i sym v b . mkIfs sym rest

mkIfs' :: Symbol -> [(Node, Node)] -> Node -> Node
mkIfs' sym = mkIfs sym . map (\(v, b) -> (Info.empty, v, b))

mkBinder :: Text -> Type -> Binder
mkBinder name ty = Binder name Nothing ty

Expand Down Expand Up @@ -641,18 +643,27 @@ destruct = \case
: map noBinders (toList vtys)
++ map noBinders (toList vs)
++ concat
[ br
: reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis)
| (bis, br) <- branchChildren
[ brs
++ reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis)
| (bis, brs) <- branchChildren
]
where
branchChildren :: [([Binder], NodeChild)]
branchChildren :: [([Binder], [NodeChild])]
branchChildren =
[ (binders, manyBinders binders (br ^. matchBranchBody))
[ (binders, map (manyBinders binders) (branchRhsChildren (br ^. matchBranchRhs)))
| br <- branches,
let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns))
]

branchRhsChildren :: MatchBranchRhs -> [Node]
branchRhsChildren = \case
MatchBranchRhsExpression e -> [e]
MatchBranchRhsIfs ifs -> concatMap sideIfBranchChildren ifs

sideIfBranchChildren :: SideIfBranch -> [Node]
sideIfBranchChildren SideIfBranch {..} =
[_sideIfBranchCondition, _sideIfBranchBody]

branchInfos :: [Info]
branchInfos =
concat
Expand Down Expand Up @@ -684,14 +695,31 @@ destruct = \case
let mkBranch :: MatchBranch -> Sem '[Input Node, Input Info] MatchBranch
mkBranch br = do
bi' <- inputJust
b' <- inputJust
b' <- mkBranchRhs (br ^. matchBranchRhs)
pats' <- setPatternsInfos (br ^. matchBranchPatterns)
return
br
{ _matchBranchInfo = bi',
_matchBranchPatterns = pats',
_matchBranchBody = b'
_matchBranchRhs = b'
}
mkBranchRhs :: MatchBranchRhs -> Sem '[Input Node, Input Info] MatchBranchRhs
mkBranchRhs = \case
MatchBranchRhsExpression _ -> do
e' <- inputJust
return (MatchBranchRhsExpression e')
MatchBranchRhsIfs ifs -> do
ifs' <- mkSideIfs ifs
return (MatchBranchRhsIfs ifs')
mkSideIfs :: NonEmpty SideIfBranch -> Sem '[Input Node, Input Info] (NonEmpty SideIfBranch)
mkSideIfs brs =
mapM mkSideIfBranch brs
mkSideIfBranch :: SideIfBranch -> Sem '[Input Node, Input Info] SideIfBranch
mkSideIfBranch _ = do
_sideIfBranchInfo <- inputJust
_sideIfBranchCondition <- inputJust
_sideIfBranchBody <- inputJust
return SideIfBranch {..}
numVals = length vs
values' :: NonEmpty Node
valueTypes' :: NonEmpty Node
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
_ ->
recur [] node
where
cont :: Level -> [BinderChange] -> Node -> Node
cont :: [BinderChange] -> Node -> Node
cont bcs = go (recur . (bcs ++)) (k + bindersNumFromBinderChange bcs)
```
produces
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ type PatternWildcard = PatternWildcard' Info Node

type PatternConstr = PatternConstr' Info Node

type MatchBranchRhs = MatchBranchRhs' Info Node

type SideIfBranch = SideIfBranch' Info Node

type Pattern = Pattern' Info Node

type PiLhs = PiLhs' Info Node
Expand Down
22 changes: 21 additions & 1 deletion src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ data Match' i a = Match
data MatchBranch' i a = MatchBranch
{ _matchBranchInfo :: i,
_matchBranchPatterns :: !(NonEmpty (Pattern' i a)),
_matchBranchBody :: !a
_matchBranchRhs :: !(MatchBranchRhs' i a)
}

data Pattern' i a
Expand All @@ -202,6 +202,16 @@ data PatternConstr' i a = PatternConstr
_patternConstrArgs :: ![Pattern' i a]
}

data MatchBranchRhs' i a
= MatchBranchRhsExpression !a
| MatchBranchRhsIfs !(NonEmpty (SideIfBranch' i a))

data SideIfBranch' i a = SideIfBranch
{ _sideIfBranchInfo :: i,
_sideIfBranchCondition :: !a,
_sideIfBranchBody :: !a
}

-- | Useful for unfolding Pi
data PiLhs' i a = PiLhs
{ _piLhsInfo :: i,
Expand Down Expand Up @@ -437,8 +447,10 @@ makeLenses ''Case'
makeLenses ''CaseBranch'
makeLenses ''Match'
makeLenses ''MatchBranch'
makeLenses ''MatchBranchRhs'
makeLenses ''PatternWildcard'
makeLenses ''PatternConstr'
makeLenses ''SideIfBranch'
makeLenses ''Pi'
makeLenses ''Lambda'
makeLenses ''Univ'
Expand Down Expand Up @@ -528,12 +540,20 @@ instance (Eq a) => Eq (Pi' i a) where
eqOn (^. piBinder . binderType)
..&&.. eqOn (^. piBody)

instance (Eq a) => Eq (MatchBranchRhs' i a) where
(MatchBranchRhsExpression e1) == (MatchBranchRhsExpression e2) = e1 == e2
(MatchBranchRhsIfs ifs1) == (MatchBranchRhsIfs ifs2) = ifs1 == ifs2
_ == _ = False

instance (Eq a) => Eq (MatchBranch' i a) where
(MatchBranch _ pats1 b1) == (MatchBranch _ pats2 b2) = pats1 == pats2 && b1 == b2

instance (Eq a) => Eq (PatternConstr' i a) where
(PatternConstr _ _ tag1 ps1) == (PatternConstr _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2

instance (Eq a) => Eq (SideIfBranch' i a) where
(SideIfBranch _ c1 b1) == (SideIfBranch _ c2 b2) = c1 == c2 && b1 == b2

instance Hashable (Ident' i) where
hashWithSalt s = hashWithSalt s . (^. identSymbol)

Expand Down
21 changes: 19 additions & 2 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,23 @@ instance PrettyCode Bottom where
ty' <- ppCode _bottomType
return (parens (kwBottom <+> kwColon <+> ty'))

instance PrettyCode SideIfBranch where
ppCode :: (Member (Reader Options) r) => SideIfBranch -> Sem r (Doc Ann)
ppCode SideIfBranch {..} = do
cond <- ppCode _sideIfBranchCondition
body <- ppCode _sideIfBranchBody
return $ kwIf <+> cond <+> kwAssign <+> body

instance PrettyCode MatchBranchRhs where
ppCode :: (Member (Reader Options) r) => MatchBranchRhs -> Sem r (Doc Ann)
ppCode = \case
MatchBranchRhsExpression x -> do
e <- ppCode x
return $ kwAssign <+> e
MatchBranchRhsIfs x -> do
brs <- mapM ppCode x
return $ vsep brs

instance PrettyCode Node where
ppCode :: forall r. (Member (Reader Options) r) => Node -> Sem r (Doc Ann)
ppCode node = case node of
Expand All @@ -394,11 +411,11 @@ instance PrettyCode Node where
ppCodeCase' branchBinderNames branchBinderTypes branchTagNames x
NMatch Match {..} -> do
let branchPatterns = map (^. matchBranchPatterns) _matchBranches
branchBodies = map (^. matchBranchBody) _matchBranches
branchRhs = map (^. matchBranchRhs) _matchBranches
pats <- mapM ppPatterns branchPatterns
vs <- mapM ppCode _matchValues
vs' <- zipWithM ppWithType (toList vs) (toList _matchValueTypes)
bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> kwAssign <+> br') pats branchBodies
bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> br') pats branchRhs
let bss = bracesIndent $ align $ concatWith (\a b -> a <> kwSemicolon <> line <> b) bs
rty <- ppTypeAnnot _matchReturnType
return $ kwMatch <+> hsep (punctuate comma vs') <+> kwWith <> rty <+> bss
Expand Down
24 changes: 17 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/MatchToCase.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Juvix.Compiler.Core.Transformation.Base

data PatternRow = PatternRow
{ _patternRowPatterns :: [Pattern],
_patternRowBody :: Node,
_patternRowRhs :: MatchBranchRhs,
-- | The number of initial wildcard binders in `_patternRowPatterns` which
-- don't originate from the input
_patternRowIgnoredPatternsNum :: Int,
Expand Down Expand Up @@ -58,7 +58,7 @@ goMatchToCase recur node = case node of
matchBranchToPatternRow MatchBranch {..} =
PatternRow
{ _patternRowPatterns = toList _matchBranchPatterns,
_patternRowBody = _matchBranchBody,
_patternRowRhs = _matchBranchRhs,
_patternRowIgnoredPatternsNum = 0,
_patternRowBinderChangesRev = [BCAdd n]
}
Expand Down Expand Up @@ -104,10 +104,10 @@ goMatchToCase recur node = case node of
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docValueSequence pat
mockFile = $(mkAbsFile "/match-to-case")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
r@PatternRow {..} : _
r@PatternRow {..} : matrix'
| all isPatWildcard _patternRowPatterns ->
-- The first row matches all values (Section 4, case 2)
compileMatchingRow bindersNum vs r
compileMatchingRow err bindersNum vs matrix' r
_ -> do
-- Section 4, case 3
-- Select the first column
Expand Down Expand Up @@ -181,9 +181,19 @@ goMatchToCase recur node = case node of
where
ii = lookupInductiveInfo md ind

compileMatchingRow :: Level -> [Level] -> PatternRow -> Sem r Node
compileMatchingRow bindersNum vs PatternRow {..} =
goMatchToCase (recur . (bcs ++)) _patternRowBody
compileMatchingRow :: ([Value] -> [Value]) -> Level -> [Level] -> PatternMatrix -> PatternRow -> Sem r Node
compileMatchingRow err bindersNum vs matrix PatternRow {..} =
case _patternRowRhs of
MatchBranchRhsExpression body ->
goMatchToCase (recur . (bcs ++)) body
MatchBranchRhsIfs ifs -> do
-- If the branch has side-conditions, then we need to continue pattern
-- matching when none of the conditions is satisfied.
body <- compile err bindersNum vs matrix
md <- ask
let boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive
ifs' = map (\(SideIfBranch i c b) -> (i, c, b)) (toList ifs)
return $ mkIfs boolSym ifs' body
where
bcs =
reverse $
Expand Down
50 changes: 36 additions & 14 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -492,41 +492,63 @@ goCase c = do
rty <- goType (fromJust $ c ^. Internal.caseExpressionWholeType)
return (mkMatch i (pure ty) rty (pure expr) branches)
_ ->
-- If the type of the value matched on is not an inductive type, then the
-- case expression has one branch with a variable pattern.
case c ^. Internal.caseBranches of
Internal.CaseBranch {..} :| _ ->
case _caseBranchPattern ^. Internal.patternArgPattern of
Internal.PatternVariable name -> do
vars <- asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
let vars' = addPatternVariableNames _caseBranchPattern varsNum vars
body <-
rhs <-
local
(set indexTableVars vars')
(underBinders 1 (goCaseBranchRhs _caseBranchRhs))
return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body
case rhs of
MatchBranchRhsExpression body ->
return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body
_ ->
impossible
_ ->
impossible
where
goCaseBranch :: Type -> Internal.CaseBranch -> Sem r MatchBranch
goCaseBranch ty b = goPatternArgs 0 (b ^. Internal.caseBranchRhs) [b ^. Internal.caseBranchPattern] [ty]

-- | FIXME Fix this as soon as side if conditions are implemented in Core. This
-- is needed so that we can test typechecking without a crash.
todoSideIfs ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) =>
Internal.SideIfs ->
Sem r Node
todoSideIfs s = goExpression (s ^. Internal.sideIfBranches . _head1 . Internal.sideIfBranchBody)

goCaseBranchRhs ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) =>
Internal.CaseBranchRhs ->
Sem r Node
Sem r MatchBranchRhs
goCaseBranchRhs = \case
Internal.CaseBranchRhsExpression e -> goExpression e
Internal.CaseBranchRhsIf i -> todoSideIfs i
Internal.CaseBranchRhsExpression e -> MatchBranchRhsExpression <$> goExpression e
Internal.CaseBranchRhsIf Internal.SideIfs {..} -> case _sideIfElse of
Just elseBranch -> do
branches <- toList <$> mapM goSideIfBranch _sideIfBranches
elseBranch' <- goExpression elseBranch
boolSym <- getBoolSymbol
return $ MatchBranchRhsExpression $ mkIfs' boolSym branches elseBranch'
where
goSideIfBranch :: Internal.SideIfBranch -> Sem r (Node, Node)
goSideIfBranch Internal.SideIfBranch {..} = do
cond <- goExpression _sideIfBranchCondition
body <- goExpression _sideIfBranchBody
return (cond, body)
Nothing -> do
branches <- mapM goSideIfBranch _sideIfBranches
return $ MatchBranchRhsIfs branches
where
goSideIfBranch :: Internal.SideIfBranch -> Sem r SideIfBranch
goSideIfBranch Internal.SideIfBranch {..} = do
cond <- goExpression _sideIfBranchCondition
body <- goExpression _sideIfBranchBody
return $
SideIfBranch
{ _sideIfBranchInfo = setInfoLocation (getLoc _sideIfBranchCondition) mempty,
_sideIfBranchCondition = cond,
_sideIfBranchBody = body
}

goLambda ::
forall r.
Expand Down
Loading

0 comments on commit ab2d31a

Please sign in to comment.