Skip to content

Commit

Permalink
Support simple tuple patterns which is good for making examples
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Dec 26, 2023
1 parent dbf627e commit 9cce624
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 92 deletions.
11 changes: 6 additions & 5 deletions shared/src/main/scala/mlscript/pretyper/PreTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ class PreTyper(override val debugTopics: Opt[Set[Str]]) extends Traceable with D
trace(s"extractParameters <== ${inspect.deep(fields)}") {
fields match {
case Tup(arguments) =>
arguments.map {
case (S(nme: Var), Fld(_, _)) => new ValueSymbol(nme, false)
case (_, Fld(_, nme: Var)) => new ValueSymbol(nme, false)
case (_, Fld(_, Bra(false, nme: Var))) => new ValueSymbol(nme, false)
case (_, _) => ???
arguments.flatMap {
case (S(nme: Var), Fld(_, _)) => new ValueSymbol(nme, false) :: Nil
case (_, Fld(_, nme: Var)) => new ValueSymbol(nme, false) :: Nil
case (_, Fld(_, Bra(false, nme: Var))) => new ValueSymbol(nme, false) :: Nil
case (_, Fld(_, tuple @ Tup(_))) => extractParameters(tuple)
case (_, Fld(_, _)) => ???
}
case PlainTup(arguments @ _*) =>
arguments.map {
Expand Down
12 changes: 12 additions & 0 deletions shared/src/main/scala/mlscript/pretyper/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ package object symbol {
// Urgh, let's do this in the next refactor.
// I really should move these imperative and stateful functions to a
// separate class!
val tupleSubScrutineeMap: MutMap[Int, MutMap[Var, ValueSymbol]] = MutMap.empty
// Note that the innermost map is a map from variable names to symbols.
// Sometimes a class parameter may have many names. We maintain the
// uniqueness of the symbol for now.

def getSubScrutineeSymbolOrElse(
classLikeSymbol: TypeSymbol,
Expand All @@ -95,6 +99,14 @@ package object symbol {
subScrutineeMap.getOrElseUpdate(classLikeSymbol, MutMap.empty)
.getOrElseUpdate(index, MutMap.empty)
.getOrElseUpdate(name, default)

def getTupleSubScrutineeSymbolOrElse(
index: Int,
name: Var, // <-- Remove this parameter after we remove `ScrutineeSymbol`.
default: => ValueSymbol
): ValueSymbol =
tupleSubScrutineeMap.getOrElseUpdate(index, MutMap.empty)
.getOrElseUpdate(name, default)

def addMatchedClass(symbol: TypeSymbol, loc: Opt[Loc]): Unit = {
matchedClasses.getOrElseUpdate(symbol, Buffer.empty) ++= loc
Expand Down
123 changes: 93 additions & 30 deletions shared/src/main/scala/mlscript/ucs/stages/Desugaring.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait Desugaring { self: PreTyper =>
* `args_x$Cons`.
*/
private def makeUnappliedVar(scrutinee: Var, className: Var): Var =
Var(s"args_${scrutinee.name}$$${className.name}")
Var(s"$unappliedPrefix${scrutinee.name}$$${className.name}")

// I plan to integrate scrutinee symbols into a field of `ValueSymbol`.
// Because each `ValueSymbol` can be matched in multiple UCS expressions.
Expand Down Expand Up @@ -177,21 +177,21 @@ trait Desugaring { self: PreTyper =>
)

private def flattenClassParameters(parentScrutinee: Var, parentClassLikeSymbol: TypeSymbol, parameters: Ls[Opt[s.Pattern]]): Ls[Opt[Var -> Opt[s.Pattern]]] =
parameters.zipWithIndex.map {
parameters.iterator.zipWithIndex.map {
case (N, _) => N
case (S(s.NamePattern(name)), index) =>
val symbol = parentScrutinee.getScrutineeSymbol.getSubScrutineeSymbolOrElse(
parentClassLikeSymbol, index, name, new ValueSymbol(name, false)
)
S(name.withSymbol(symbol) -> N)
case (S(parameterPattern @ (s.ClassPattern(_, _) | s.LiteralPattern(_))), index) =>
case (S(parameterPattern @ (s.ClassPattern(_, _) | s.LiteralPattern(_) | s.TuplePattern(_))), index) =>
val scrutinee = freshScrutinee(parentScrutinee, parentClassLikeSymbol.name, index)
val symbol = parentScrutinee.getScrutineeSymbol.getSubScrutineeSymbolOrElse(
parentClassLikeSymbol, index, scrutinee, new ValueSymbol(scrutinee, false)
)
S(scrutinee.withSymbol(symbol) -> S(parameterPattern))
case _ => ??? // Other patterns are not implemented yet.
}
}.toList

/**
* Recursively decompose and flatten a possibly nested class pattern. Any
Expand All @@ -211,7 +211,7 @@ trait Desugaring { self: PreTyper =>
* @param initialScope the scope before flattening the class pattern
* @return a tuple of the augmented scope and a function that wrap a split
*/
private def flattenClassPattern(pattern: s.ClassPattern, scrutinee: Var, initialScope: Scope): (Scope, c.Split => c.Branch) = {
private def desugarClassPattern(pattern: s.ClassPattern, scrutinee: Var, initialScope: Scope): (Scope, c.Split => c.Branch) = {
val scrutineeSymbol = scrutinee.getScrutineeSymbol
val patternClassSymbol = pattern.nme.resolveTypeSymbol(initialScope)
// Most importantly, we need to add the class to the list of matched classes.
Expand All @@ -234,36 +234,86 @@ trait Desugaring { self: PreTyper =>
bindNextParameter.andThen { c.Split.Let(false, parameter, Sel(unapp, Var(index.toString)), _) }
}.andThen { c.Split.Let(false, unapp, makeUnapplyCall(scrutinee, pattern.nme), _): c.Split }
val scopeWithClassParameters = initialScope ++ (unapp.symbol :: nestedPatterns.flatMap(_.map(_._1.symbol)))
// Second, collect bindings from sub-patterns and accumulate a function
// that add let bindings to a split (we call it "binder").
nestedPatterns.foldLeft((scopeWithClassParameters, bindClassParameters)) {
// If this parameter is not matched with a sub-pattern, then we do
// nothing and pass on scope and binder.
case (acc, S(_ -> N)) => acc
// If this sub-pattern is a class pattern, we need to recursively flatten
// the class pattern. We will get a scope with all bindings and a function
// that adds all bindings to a split. The scope can be passed on to the
// next sub-pattern. The binder needs to be composed with the previous
// binder.
case ((scope, bindPrevious), S(nme -> S(pattern: s.ClassPattern))) =>
val (scopeWithNestedAll, bindNestedAll) = flattenClassPattern(pattern, nme, scope)
(scopeWithNestedAll, split => bindPrevious(bindNestedAll(split) :: c.Split.Nil))
case ((scope, bindPrevious), S(nme -> S(pattern: s.LiteralPattern))) =>
val test = freshTest().withFreshSymbol
(scope + test.symbol, makeLiteralTest(test, nme, pattern.literal)(scope).andThen(bindPrevious))
// Well, other patterns are not supported yet.
case (acc, S((nme, pattern))) => ???
// If this parameter is empty (e.g. produced by wildcard), then we do
// nothing and pass on scope and binder.
case (acc, N) => acc
}
desugarNestedPatterns(nestedPatterns, scopeWithClassParameters, bindClassParameters)
// If there is no parameter, then we are done.
case N => (initialScope, identity(_: c.Split))
}
// Last, return the scope with all bindings and a function that adds all matches and bindings to a split.
(scopeWithAll, split => c.Branch(scrutinee, c.Pattern.Class(pattern.nme), bindAll(split)))
}

/**
* This function collects bindings from nested patterns and accumulate a
* function that add let bindings to a split (we call such function a
* "binder"). This function is supposed to be called from pattern desugaring
* functions.
*
* @param nestedPatterns nested patterns are a list of sub-scrutinees and
* corresponding sub-patterns
* @param scopeWithScrutinees a scope with all sub-scrutinees
* @param bindScrutinees a function that adds all bindings to a split
*/
private def desugarNestedPatterns(
nestedPatterns: Ls[Opt[Var -> Opt[s.Pattern]]],
scopeWithScrutinees: Scope,
bindScrutinees: c.Split => c.Split
): (Scope, c.Split => c.Split) = {
nestedPatterns.foldLeft((scopeWithScrutinees, bindScrutinees)) {
// If this parameter is not matched with a sub-pattern, then we do
// nothing and pass on scope and binder.
case (acc, S(_ -> N)) => acc
// If this sub-pattern is a class pattern, we need to recursively flatten
// the class pattern. We will get a scope with all bindings and a function
// that adds all bindings to a split. The scope can be passed on to the
// next sub-pattern. The binder needs to be composed with the previous
// binder.
case ((scope, bindPrevious), S(nme -> S(pattern: s.ClassPattern))) =>
val (scopeWithNestedAll, bindNestedAll) = desugarClassPattern(pattern, nme, scope)
(scopeWithNestedAll, split => bindPrevious(bindNestedAll(split) :: c.Split.Nil))
case ((scope, bindPrevious), S(nme -> S(pattern: s.LiteralPattern))) =>
val test = freshTest().withFreshSymbol
(scope + test.symbol, makeLiteralTest(test, nme, pattern.literal)(scope).andThen(bindPrevious))
case ((scope, bindPrevious), S(nme -> S(s.TuplePattern(fields)))) =>
val (scopeWithNestedAll, bindNestedAll) = desugarTuplePattern(fields, nme, scope)
(scopeWithNestedAll, bindNestedAll.andThen(bindPrevious))
// Well, other patterns are not supported yet.
case (acc, S((nme, pattern))) => ???
// If this parameter is empty (e.g. produced by wildcard), then we do
// nothing and pass on scope and binder.
case (acc, N) => acc
}
}

private def flattenTupleFields(parentScrutinee: Var, fields: Ls[Opt[s.Pattern]]): Ls[Opt[Var -> Opt[s.Pattern]]] =
fields.iterator.zipWithIndex.map {
case (N, _) => N
case (S(s.NamePattern(name)), index) =>
val symbol = parentScrutinee.getScrutineeSymbol.getTupleSubScrutineeSymbolOrElse(
index, name, new ValueSymbol(name, false)
)
S(name.withSymbol(symbol) -> N)
case (S(parameterPattern @ (s.ClassPattern(_, _) | s.LiteralPattern(_) | s.TuplePattern(_))), index) =>
val scrutinee = freshScrutinee(parentScrutinee, "Tuple$2", index)
val symbol = parentScrutinee.getScrutineeSymbol.getTupleSubScrutineeSymbolOrElse(
index, scrutinee, new ValueSymbol(scrutinee, false)
)
S(scrutinee.withSymbol(symbol) -> S(parameterPattern))
case _ => ???
}.toList

private def desugarTuplePattern(fields: Ls[Opt[s.Pattern]], scrutinee: Var, initialScope: Scope): (Scope, c.Split => c.Split) = {
val scrutineeSymbol = scrutinee.getScrutineeSymbol
val nestedPatterns = flattenTupleFields(scrutinee, fields)
val bindTupleFields = nestedPatterns.iterator.zipWithIndex.foldRight[c.Split => c.Split](identity) {
case ((N, _), bindNextField) => bindNextField
case ((S(parameter -> _), index), bindNextField) =>
val indexVar = Var(index.toString).withLoc(parameter.toLoc)
bindNextField.andThen { c.Split.Let(false, parameter, Sel(scrutinee, indexVar), _) }
}
val scopeWithTupleFields = initialScope ++ nestedPatterns.flatMap(_.map(_._1.symbol))
desugarNestedPatterns(nestedPatterns, scopeWithTupleFields, bindTupleFields)
}

private def desugarPatternSplit(split: s.PatternSplit)(implicit scrutinee: Term, scope: Scope): c.Split = {
def rec(scrutinee: Var, split: s.PatternSplit)(implicit scope: Scope): c.Split = split match {
case s.Split.Cons(head, tail) =>
Expand Down Expand Up @@ -293,10 +343,19 @@ trait Desugaring { self: PreTyper =>
case pattern @ s.ClassPattern(nme, fields) =>
println(s"find term symbol of $scrutinee in ${scope.showLocalSymbols}")
scrutinee.symbol = scope.getTermSymbol(scrutinee.name).getOrElse(???)
val (scopeWithAll, bindAll) = flattenClassPattern(pattern, scrutinee, scope)
val (scopeWithAll, bindAll) = desugarClassPattern(pattern, scrutinee, scope)
val continuation = desugarTermSplit(head.continuation)(PartialTerm.Empty, scopeWithAll)
bindAll(continuation) :: rec(scrutinee, tail)
case s.TuplePattern(fields) => ???
case s.TuplePattern(fields) =>
scrutinee.symbol = scope.getTermSymbol(scrutinee.name).getOrElse(???)
val (scopeWithAll, bindAll) = desugarTuplePattern(fields, scrutinee, scope)
val continuation = desugarTermSplit(head.continuation)(PartialTerm.Empty, scopeWithAll)
val withBindings = bindAll(continuation)
if (withBindings.hasElse) {
withBindings
} else {
withBindings ++ rec(scrutinee, tail)
}
case s.RecordPattern(entries) => ???
}
case s.Split.Let(isRec, nme, rhs, tail) =>
Expand Down Expand Up @@ -327,8 +386,12 @@ object Desugaring {
val cachePrefix = "cache$"
val scrutineePrefix = "scrut$"
val testPrefix = "test$"
val unappliedPrefix = "args_"

def isCacheVar(nme: Var): Bool = nme.name.startsWith(cachePrefix)
def isScrutineeVar(nme: Var): Bool = nme.name.startsWith(scrutineePrefix)
def isTestVar(nme: Var): Bool = nme.name.startsWith(testPrefix)
def isUnappliedVar(nme: Var): Bool = nme.name.startsWith(unappliedPrefix)
def isGeneratedVar(nme: Var): Bool =
isCacheVar(nme) || isScrutineeVar(nme) || isTestVar(nme) || isUnappliedVar(nme)
}
Loading

0 comments on commit 9cce624

Please sign in to comment.