From a42bebae0f8ce50018907f01f664e99154e55e54 Mon Sep 17 00:00:00 2001 From: noti0na1 Date: Tue, 22 Oct 2024 19:36:27 +0200 Subject: [PATCH] Dependent pattern match prototype --- .../src/dotty/tools/dotc/ast/Desugar.scala | 13 +++- .../dotty/tools/dotc/typer/Applications.scala | 5 +- .../src/dotty/tools/dotc/typer/Typer.scala | 31 +++++++--- library/src/scala/Tuple.scala | 2 +- .../dependent-pattern-match-type-test.scala | 17 ++++++ tests/pos/dependent-pattern-match.scala | 59 +++++++++++++++++++ 6 files changed, 113 insertions(+), 14 deletions(-) create mode 100644 tests/pos/dependent-pattern-match-type-test.scala create mode 100644 tests/pos/dependent-pattern-match.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index e66c71731b4f..0eba7fad00a4 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -810,8 +810,12 @@ object desugar { for i <- List.range(0, arity) selName = nme.selectorName(i) if (selName ne caseParams(i).name) && !selectorNamesInBody.contains(selName) - yield syntheticProperty(selName, caseParams(i).tpt, - Select(This(EmptyTypeIdent), caseParams(i).name)) + yield + val ptp = + if caseParams(i).mods.mods.exists(_.isInstanceOf[Mod.Var]) then caseParams(i).tpt + else SingletonTypeTree(Select(This(EmptyTypeIdent), caseParams(i).name)) + syntheticProperty(selName, ptp, + Select(This(EmptyTypeIdent), caseParams(i).name)) def enumCaseMeths = if isEnumCase then @@ -918,7 +922,10 @@ object desugar { if (arity == 0) Literal(Constant(true)) else if caseClassInScala2Library then scala2LibCompatUnapplyRhs(unapplyParam.name) else Ident(unapplyParam.name) - val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree() + val unapplyResTp = + if (arity == 0) Literal(Constant(true)) + else if caseClassInScala2Library then TypeTree() + else SingletonTypeTree(Ident(unapplyParam.name)) DefDef( methName, diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 17be2acc7378..a3a30cd96031 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1625,7 +1625,8 @@ trait Applications extends Compatibility { val patternBound = maximizeType(unapplyArgType, unapplyFn.span.endPos) if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound) unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}") - unapplyArgType + selType & unapplyArgType + // unapplyArgType val dummyArg = dummyTreeOfType(ownType) val (newUnapplyFn, unapplyApp) = @@ -1637,7 +1638,7 @@ trait Applications extends Compatibility { .typedPatterns(qual, this) val result = assignType(cpy.UnApply(tree)(newUnapplyFn, unapplyImplicits(dummyArg, unapplyApp), unapplyPatterns), ownType) if (ownType.stripped eq selType.stripped) || ownType.isError then result - else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType) + else tryWithTypeTest(Typed(result, TypeTree(unapplyArgType)), selType) case tp => val unapplyErr = if (tp.isError) unapplyFn else notAnExtractor(unapplyFn) val typedArgsErr = unadaptedArgs.mapconserve(typed(_, defn.AnyType)) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 3810bc66841e..35bb669bee98 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -593,7 +593,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typr.println(s"typed ident $kind$name in ${ctx.owner}") if ctx.mode.is(Mode.Pattern) then if name == nme.WILDCARD then - return tree.withType(pt) + val wpt = pt match + case pt: TermRef => pt.widen + case pt => pt + return tree.withType(wpt) if name == tpnme.WILDCARD then return tree.withType(defn.AnyType) if untpd.isVarPattern(tree) && name.isTermName then @@ -2031,7 +2034,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val rawSelectorTpe = fullyDefinedType(sel1.tpe, "pattern selector", tree.srcPos) val selType = rawSelectorTpe match case c: ConstantType if tree.isInline => c + case ref: TermRef => ref case otherTpe => otherTpe.widen + val selTypeW = selType match + case tr: TermRef => tr.widen + case _ => selType + /** Does `tree` has the same shape as the given match type? * We only support typed patterns with empty guards, but @@ -2049,7 +2057,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // To check that pattern types correspond we need to type // check `pat` here and throw away the result. val gadtCtx: Context = ctx.fresh.setFreshGADTBounds - val pat1 = typedPattern(pat, selType)(using gadtCtx) + val pat1 = typedPattern(pat, selTypeW)(using gadtCtx) val tpt = tpd.unbind(tpd.unsplice(pat1)) match case Typed(_, tpt) => tpt case UnApply(fun, _, p1 :: _) if fun.symbol == defn.TypeTest_unapply => p1 @@ -2067,7 +2075,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val result = pt.underlyingNormalizable match { case mt: MatchType if isMatchTypeShaped(mt) => - typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt) + typedDependentMatchFinish(tree, sel1, selTypeW, tree.cases, mt) case _ => typedMatchFinish(tree, sel1, selType, tree.cases, pt) } @@ -2677,6 +2685,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedBind(tree: untpd.Bind, pt: Type)(using Context): Tree = { if !isFullyDefined(pt, ForceDegree.all) then return errorTree(tree, em"expected type of $tree is not fully defined") + val wpt = pt match + case pt: TermRef => pt.widen + case _ => pt val body1 = typed(tree.body, pt) body1 match { case UnApply(fn, Nil, arg :: Nil) @@ -2705,9 +2716,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val symTp = if isStableIdentifierOrLiteral || pt.isNamedTupleType then pt // need to combine tuple element types with expected named type - else if isWildcardStarArg(body1) - || pt == defn.ImplicitScrutineeTypeRef - || body1.tpe <:< pt // There is some strange interaction with gadt matching. + else if isWildcardStarArg(body1) || pt == defn.ImplicitScrutineeTypeRef + then body1.tpe + else if body1.tpe <:< wpt // There is some strange interaction with gadt matching. // and implicit scopes. // run/t2755.scala fails to compile if this subtype test is omitted // and the else clause is changed to `body1.tpe & pt`. What @@ -2717,14 +2728,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer // it is Array[T] we get an implicit not found. To avoid fragility // wrt to operand order for `&`, we include the explicit subtype test here. // See also #5649. - then body1.tpe + then + if ctx.mode.is(Mode.Pattern) && (pt ne wpt) // && !(body1.tpe <:< pt) + then AndType(pt, body1.tpe) + // then pt & body1.tpe + else body1.tpe else body1.tpe match case btpe: TypeRef if btpe.symbol == defn.TupleXXLClass && pt.tupleElementTypes.isDefined => // leave the original tuple type; don't mix with & TupleXXL which would only obscure things pt case _ => - pt & body1.tpe + wpt & body1.tpe val sym = newPatternBoundSymbol(name, symTp, tree.span) if (pt == defn.ImplicitScrutineeTypeRef || tree.mods.is(Given)) sym.setFlag(Given) if (ctx.mode.is(Mode.InPatternAlternative)) diff --git a/library/src/scala/Tuple.scala b/library/src/scala/Tuple.scala index 8074fe3664e5..2f163d81e5e3 100644 --- a/library/src/scala/Tuple.scala +++ b/library/src/scala/Tuple.scala @@ -265,7 +265,7 @@ object Tuple { /** Convert an array into a tuple of unknown arity and types */ def fromArray[T](xs: Array[T]): Tuple = { - val xs2 = xs match { + val xs2: Array[Object] = xs match { case xs: Array[Object] => xs case xs => xs.map(_.asInstanceOf[Object]) } diff --git a/tests/pos/dependent-pattern-match-type-test.scala b/tests/pos/dependent-pattern-match-type-test.scala new file mode 100644 index 000000000000..c7b9036678d0 --- /dev/null +++ b/tests/pos/dependent-pattern-match-type-test.scala @@ -0,0 +1,17 @@ +//> using options -Xfatal-warnings + +import scala.reflect.TypeTest + +trait M: + type Tree <: AnyRef + type Apply <: Tree + given TypeTest[Tree, Apply] = ??? + val Apply: ApplyModule + trait ApplyModule: + this: Apply.type => + def unapply(x: Apply): (Tree, Tree) = ??? + + def quote(x: Tree) = x match + case Apply(f, args) => + println(args) + case _ => \ No newline at end of file diff --git a/tests/pos/dependent-pattern-match.scala b/tests/pos/dependent-pattern-match.scala new file mode 100644 index 000000000000..8c799b790a4b --- /dev/null +++ b/tests/pos/dependent-pattern-match.scala @@ -0,0 +1,59 @@ +trait A: + val x: Int + type T +case class B(x: Int, y: Int) extends A: + type T = String + val z: T = "B" +case class C(x: Int) extends A: + type T = Int + val z: T = 1 + +object BParts: + def unapply(b: B): Option[(b.x.type, b.y.type)] = Some((b.x, b.y)) + +def test1(a: A) = a match + case b @ B(x, y) => + // b: a.type & B + // x: (a.type & B).x.type + // y: (a.type & B).y.type + val e1: a.type = b + val e2: a.x.type = b.x + val e3: a.x.type = x + val e4: b.x.type = x + x + y + case C(x) => + val e1: a.x.type = x + x + case BParts(x, y) => + val e1: a.x.type = x + x + y + case _ => 0 + +def test2(a: A): a.T = + a match + case b: B => + // b: a.type & B + // b.z: b.T = (a & B)#T = a.T & String + b.z + case c: C => c.z + +def test3(a: A): a.T = + a match + case b: B => + // b: a.type & B; hence b.T <:< a.T & String + // We don't have a: b.type in the body, + // so we can't prove String <:< a.T + val x: b.T = b.z + "" + x + case c: C => + val x: c.T = c.z + 0 + x + +def test4(x: A, y: A) = + x match + case z: y.type => + // if x.eq(y) then z = x = y, + // z: x.type & y.type + val a: x.type = z + val b: y.type = z + case _ => \ No newline at end of file