Skip to content

Commit

Permalink
Dependent pattern match prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Oct 22, 2024
1 parent ecc332f commit a42beba
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 14 deletions.
13 changes: 10 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,12 @@ object desugar {
for i <- List.range(0, arity)
selName = nme.selectorName(i)
if (selName ne caseParams(i).name) && !selectorNamesInBody.contains(selName)
yield syntheticProperty(selName, caseParams(i).tpt,
Select(This(EmptyTypeIdent), caseParams(i).name))
yield
val ptp =
if caseParams(i).mods.mods.exists(_.isInstanceOf[Mod.Var]) then caseParams(i).tpt
else SingletonTypeTree(Select(This(EmptyTypeIdent), caseParams(i).name))
syntheticProperty(selName, ptp,
Select(This(EmptyTypeIdent), caseParams(i).name))

def enumCaseMeths =
if isEnumCase then
Expand Down Expand Up @@ -918,7 +922,10 @@ object desugar {
if (arity == 0) Literal(Constant(true))
else if caseClassInScala2Library then scala2LibCompatUnapplyRhs(unapplyParam.name)
else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp =
if (arity == 0) Literal(Constant(true))
else if caseClassInScala2Library then TypeTree()
else SingletonTypeTree(Ident(unapplyParam.name))

DefDef(
methName,
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,8 @@ trait Applications extends Compatibility {
val patternBound = maximizeType(unapplyArgType, unapplyFn.span.endPos)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
unapplyArgType
selType & unapplyArgType
// unapplyArgType

val dummyArg = dummyTreeOfType(ownType)
val (newUnapplyFn, unapplyApp) =
Expand All @@ -1637,7 +1638,7 @@ trait Applications extends Compatibility {
.typedPatterns(qual, this)
val result = assignType(cpy.UnApply(tree)(newUnapplyFn, unapplyImplicits(dummyArg, unapplyApp), unapplyPatterns), ownType)
if (ownType.stripped eq selType.stripped) || ownType.isError then result
else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType)
else tryWithTypeTest(Typed(result, TypeTree(unapplyArgType)), selType)
case tp =>
val unapplyErr = if (tp.isError) unapplyFn else notAnExtractor(unapplyFn)
val typedArgsErr = unadaptedArgs.mapconserve(typed(_, defn.AnyType))
Expand Down
31 changes: 23 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typr.println(s"typed ident $kind$name in ${ctx.owner}")
if ctx.mode.is(Mode.Pattern) then
if name == nme.WILDCARD then
return tree.withType(pt)
val wpt = pt match
case pt: TermRef => pt.widen
case pt => pt
return tree.withType(wpt)
if name == tpnme.WILDCARD then
return tree.withType(defn.AnyType)
if untpd.isVarPattern(tree) && name.isTermName then
Expand Down Expand Up @@ -2031,7 +2034,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val rawSelectorTpe = fullyDefinedType(sel1.tpe, "pattern selector", tree.srcPos)
val selType = rawSelectorTpe match
case c: ConstantType if tree.isInline => c
case ref: TermRef => ref
case otherTpe => otherTpe.widen
val selTypeW = selType match
case tr: TermRef => tr.widen
case _ => selType


/** Does `tree` has the same shape as the given match type?
* We only support typed patterns with empty guards, but
Expand All @@ -2049,7 +2057,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// To check that pattern types correspond we need to type
// check `pat` here and throw away the result.
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
val pat1 = typedPattern(pat, selType)(using gadtCtx)
val pat1 = typedPattern(pat, selTypeW)(using gadtCtx)
val tpt = tpd.unbind(tpd.unsplice(pat1)) match
case Typed(_, tpt) => tpt
case UnApply(fun, _, p1 :: _) if fun.symbol == defn.TypeTest_unapply => p1
Expand All @@ -2067,7 +2075,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

val result = pt.underlyingNormalizable match {
case mt: MatchType if isMatchTypeShaped(mt) =>
typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
typedDependentMatchFinish(tree, sel1, selTypeW, tree.cases, mt)
case _ =>
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
}
Expand Down Expand Up @@ -2677,6 +2685,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedBind(tree: untpd.Bind, pt: Type)(using Context): Tree = {
if !isFullyDefined(pt, ForceDegree.all) then
return errorTree(tree, em"expected type of $tree is not fully defined")
val wpt = pt match
case pt: TermRef => pt.widen
case _ => pt
val body1 = typed(tree.body, pt)
body1 match {
case UnApply(fn, Nil, arg :: Nil)
Expand Down Expand Up @@ -2705,9 +2716,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val symTp =
if isStableIdentifierOrLiteral || pt.isNamedTupleType then pt
// need to combine tuple element types with expected named type
else if isWildcardStarArg(body1)
|| pt == defn.ImplicitScrutineeTypeRef
|| body1.tpe <:< pt // There is some strange interaction with gadt matching.
else if isWildcardStarArg(body1) || pt == defn.ImplicitScrutineeTypeRef
then body1.tpe
else if body1.tpe <:< wpt // There is some strange interaction with gadt matching.
// and implicit scopes.
// run/t2755.scala fails to compile if this subtype test is omitted
// and the else clause is changed to `body1.tpe & pt`. What
Expand All @@ -2717,14 +2728,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// it is Array[T] we get an implicit not found. To avoid fragility
// wrt to operand order for `&`, we include the explicit subtype test here.
// See also #5649.
then body1.tpe
then
if ctx.mode.is(Mode.Pattern) && (pt ne wpt) // && !(body1.tpe <:< pt)
then AndType(pt, body1.tpe)
// then pt & body1.tpe
else body1.tpe
else body1.tpe match
case btpe: TypeRef
if btpe.symbol == defn.TupleXXLClass && pt.tupleElementTypes.isDefined =>
// leave the original tuple type; don't mix with & TupleXXL which would only obscure things
pt
case _ =>
pt & body1.tpe
wpt & body1.tpe
val sym = newPatternBoundSymbol(name, symTp, tree.span)
if (pt == defn.ImplicitScrutineeTypeRef || tree.mods.is(Given)) sym.setFlag(Given)
if (ctx.mode.is(Mode.InPatternAlternative))
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/Tuple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ object Tuple {

/** Convert an array into a tuple of unknown arity and types */
def fromArray[T](xs: Array[T]): Tuple = {
val xs2 = xs match {
val xs2: Array[Object] = xs match {
case xs: Array[Object] => xs
case xs => xs.map(_.asInstanceOf[Object])
}
Expand Down
17 changes: 17 additions & 0 deletions tests/pos/dependent-pattern-match-type-test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//> using options -Xfatal-warnings

import scala.reflect.TypeTest

trait M:
type Tree <: AnyRef
type Apply <: Tree
given TypeTest[Tree, Apply] = ???
val Apply: ApplyModule
trait ApplyModule:
this: Apply.type =>
def unapply(x: Apply): (Tree, Tree) = ???

def quote(x: Tree) = x match
case Apply(f, args) =>
println(args)
case _ =>
59 changes: 59 additions & 0 deletions tests/pos/dependent-pattern-match.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
trait A:
val x: Int
type T
case class B(x: Int, y: Int) extends A:
type T = String
val z: T = "B"
case class C(x: Int) extends A:
type T = Int
val z: T = 1

object BParts:
def unapply(b: B): Option[(b.x.type, b.y.type)] = Some((b.x, b.y))

def test1(a: A) = a match
case b @ B(x, y) =>
// b: a.type & B
// x: (a.type & B).x.type
// y: (a.type & B).y.type
val e1: a.type = b
val e2: a.x.type = b.x
val e3: a.x.type = x
val e4: b.x.type = x
x + y
case C(x) =>
val e1: a.x.type = x
x
case BParts(x, y) =>
val e1: a.x.type = x
x + y
case _ => 0

def test2(a: A): a.T =
a match
case b: B =>
// b: a.type & B
// b.z: b.T = (a & B)#T = a.T & String
b.z
case c: C => c.z

def test3(a: A): a.T =
a match
case b: B =>
// b: a.type & B; hence b.T <:< a.T & String
// We don't have a: b.type in the body,
// so we can't prove String <:< a.T
val x: b.T = b.z + ""
x
case c: C =>
val x: c.T = c.z + 0
x

def test4(x: A, y: A) =
x match
case z: y.type =>
// if x.eq(y) then z = x = y,
// z: x.type & y.type
val a: x.type = z
val b: y.type = z
case _ =>

0 comments on commit a42beba

Please sign in to comment.