Skip to content

Commit 02c41bd

Browse files
authored
Scala 3 fixes and improvements (#509)
* fix default param with type param * path dependent types support added to scala 3 * fix context bounded classes
1 parent a43f4de commit 02c41bd

File tree

6 files changed

+231
-46
lines changed

6 files changed

+231
-46
lines changed

shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
package org.scalamock.clazz
2222

2323
import org.scalamock.context.MockContext
24-
2524
import scala.quoted.*
2625
import scala.reflect.Selectable
2726

@@ -42,16 +41,23 @@ private[clazz] object MockMaker:
4241
def asParent(tree: TypeTree): TypeTree | Term =
4342
val constructorFieldsFilledWithNulls: List[List[Term]] =
4443
tree.tpe.dealias.typeSymbol.primaryConstructor.paramSymss
45-
.filter(_.exists(!_.isType))
46-
.map(_.map(_.typeRef.asType match { case '[t] => '{ null.asInstanceOf[t] }.asTerm }))
44+
.filterNot(_.exists(_.isType))
45+
.map(_.map(_.info.widen match {
46+
case t@AppliedType(inner, applied) =>
47+
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(inner.appliedTo(tpe.typeArgs)))
48+
case other =>
49+
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(other))
50+
}))
4751

4852
if constructorFieldsFilledWithNulls.forall(_.isEmpty) then
4953
tree
5054
else
5155
Select(
5256
New(TypeIdent(tree.tpe.typeSymbol)),
5357
tree.tpe.typeSymbol.primaryConstructor
54-
).appliedToArgss(constructorFieldsFilledWithNulls)
58+
).appliedToTypes(tree.tpe.typeArgs)
59+
.appliedToArgss(constructorFieldsFilledWithNulls)
60+
5561

5662

5763
val parents =
@@ -91,15 +97,15 @@ private[clazz] object MockMaker:
9197
Symbol.newVal(
9298
parent = classSymbol,
9399
name = definition.symbol.name,
94-
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
100+
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
95101
flags = Flags.Override,
96102
privateWithin = Symbol.noSymbol
97103
)
98104
else
99105
Symbol.newMethod(
100106
parent = classSymbol,
101107
name = definition.symbol.name,
102-
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
108+
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
103109
flags = Flags.Override,
104110
privateWithin = Symbol.noSymbol
105111
)
@@ -177,7 +183,7 @@ private[clazz] object MockMaker:
177183
"asInstanceOf"
178184
),
179185
definition.tpe
180-
.resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args)
186+
.resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args)
181187
.asType match { case '[t] => List(TypeTree.of[t]) }
182188
)
183189
)

shared/src/main/scala-3/org/scalamock/clazz/Utils.scala

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@ package org.scalamock.clazz
33
import scala.quoted.*
44
import org.scalamock.context.MockContext
55

6-
import scala.annotation.tailrec
6+
import scala.annotation.{experimental, tailrec}
77
private[clazz] class Utils(using val quotes: Quotes):
88
import quotes.reflect.*
99

1010
extension (tpe: TypeRepr)
11-
def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] =
11+
def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] =
1212
def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] =
1313
currentTpe match
14-
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol))
14+
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol))
1515
case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe)
1616
case TypeRef(inner, name) => loop(inner, name :: names)
1717
case _ => Nil
1818

1919
loop(tpe, Nil)
2020

21-
def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
21+
def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
2222
@tailrec
2323
def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr =
2424
currentTpe match
@@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes):
5353
case _ =>
5454
tpe
5555

56+
@experimental
5657
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
57-
def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr =
58-
typeRepr match
59-
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
60-
methodArgs.head(idx).asInstanceOf[TypeTree].tpe
58+
tpe match
59+
case baseBindings: PolyType =>
60+
def loop(typeRepr: TypeRepr): TypeRepr =
61+
typeRepr match
62+
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
63+
methodArgs.head(idx).asInstanceOf[TypeTree].tpe
6164

62-
case AppliedType(tycon, args) =>
63-
AppliedType(tycon, args.map(arg => loop(baseBindings, arg)))
65+
case AppliedType(tycon, args) =>
66+
AppliedType(loop(tycon), args.map(arg => loop(arg)))
6467

65-
case other => other
68+
case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) =>
69+
def getIndex(bindings: TypeRepr): Int =
70+
@tailrec
71+
def loop(bindings: TypeRepr, idx: Int): Int =
72+
bindings match
73+
case MethodType(_, _, method: MethodType) => loop(method, idx + 1)
74+
case _ => idx
6675

67-
tpe match
68-
case pt: PolyType => loop(pt, resType)
69-
case _ => resType
76+
loop(bindings, 1)
77+
78+
val maxIndex = methodArgs.length
79+
val parameterListIdx = maxIndex - getIndex(bindings)
80+
81+
TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe
82+
83+
case other => other
84+
85+
loop(resType)
86+
case _ =>
87+
resType
7088

7189

72-
def collectTypes: List[TypeRepr] =
73-
def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] =
90+
def collectTypes: (List[TypeRepr], TypeRepr) =
91+
@tailrec
92+
def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) =
7493
currentTpe match
75-
case PolyType(_, _, res) => loop(res, Nil)
76-
case MethodType(_, argTypes, res) => argTypes ++ loop(res, params)
77-
case other => List(other)
78-
loop(tpe, Nil)
94+
case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType)
95+
case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType)
96+
case other => (argTypesAcc.reverse.flatten, other)
97+
loop(tpe, Nil, TypeRepr.of[Nothing])
7998

8099
case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr):
81100
val mockValName = s"mock$$${symbol.name}$$$idx"
82101
val tpe = ownerTpe.memberType(symbol)
83-
private val rawTypes = tpe.widen.collectTypes
102+
private val (rawTypes, rawResType) = tpe.widen.collectTypes
84103
val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init
85104

86-
def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr =
87-
val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol)
88-
val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
89-
rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated)
105+
def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr =
106+
updatePathDependent(rawResType, List(rawResType), classSymbol)
107+
108+
def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr =
109+
updatePathDependent(tpe, rawResType :: rawTypes, classSymbol)
90110

91-
def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr =
92-
val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol))
93-
val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
94-
tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
111+
private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr =
112+
val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol))
113+
val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
114+
where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
95115

96-
def prepareTypesFor(classSymbol: Symbol) = rawTypes
97-
.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
116+
def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType)
117+
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
98118
.map { typeRepr =>
99119
val adjusted =
100120
typeRepr.widen.mapParamRefWithWildcard match
101121
case TypeBounds(lower, upper) => upper
102122
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
103123
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
104-
case other => other
124+
case TypeRef(_: ParamRef, _) =>
125+
TypeRepr.of[Any]
126+
case AppliedType(TypeRef(_: ParamRef, _), _) =>
127+
TypeRepr.of[Any]
128+
case other =>
129+
other
105130
adjusted.asType match
106131
case '[t] => TypeTree.of[t]
107132
}
@@ -128,10 +153,11 @@ private[clazz] class Utils(using val quotes: Quotes):
128153

129154
def apply(tpe: TypeRepr): List[MockableDefinition] =
130155
val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList
131-
.filter(sym => !sym.flags.is(Flags.Private) && !sym.flags.is(Flags.Final) && !sym.flags.is(Flags.Mutable))
132-
.filterNot(sym => tpe.memberType(sym) match
133-
case defaultParam @ ByNameType(AnnotatedType(_, Apply(Select(New(Inferred()), "<init>"), Nil))) => true
134-
case _ => false
156+
.filter(sym =>
157+
!sym.flags.is(Flags.Private) &&
158+
!sym.flags.is(Flags.Final) &&
159+
!sym.flags.is(Flags.Mutable) &&
160+
!sym.name.contains("$default$")
135161
)
136162
.zipWithIndex
137163
.map((sym, idx) => MockableDefinition(idx, sym, tpe))
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.paulbutcher.test
2+
3+
import org.scalamock.scalatest.MockFactory
4+
import org.scalatest.funspec.AnyFunSpec
5+
6+
import scala.reflect.ClassTag
7+
8+
class ClassWithContextBoundSpec extends AnyFunSpec with MockFactory {
9+
10+
it("compile without args") {
11+
class ContextBounded[T: ClassTag] {
12+
def method(x: Int): Unit = ()
13+
}
14+
15+
val m = mock[ContextBounded[String]]
16+
17+
}
18+
19+
it("compile with args") {
20+
class ContextBounded[T: ClassTag](x: Int) {
21+
def method(x: Int): Unit = ()
22+
}
23+
24+
val m = mock[ContextBounded[String]]
25+
26+
}
27+
28+
it("compile with provided explicitly type class") {
29+
class ContextBounded[T](x: ClassTag[T]) {
30+
def method(x: Int): Unit = ()
31+
}
32+
33+
val m = mock[ContextBounded[String]]
34+
35+
}
36+
37+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package com.paulbutcher.test
2+
3+
import org.scalamock.matchers.Matchers
4+
import org.scalamock.scalatest.MockFactory
5+
import org.scalatest.funspec.AnyFunSpec
6+
7+
class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory {
8+
9+
trait Command {
10+
type Answer
11+
type AnswerConstructor[A]
12+
}
13+
14+
case class IntCommand() extends Command {
15+
override type Answer = Int
16+
override type AnswerConstructor[A] = Option[A]
17+
}
18+
19+
val cmd = IntCommand()
20+
21+
trait PathDependent {
22+
23+
def call0[T <: Command](cmd: T): cmd.Answer
24+
25+
def call1[T <: Command](x: Int)(cmd: T): cmd.Answer
26+
27+
def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer
28+
29+
def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer
30+
31+
def call4[T <: Command](cmd: T): Option[cmd.Answer]
32+
33+
def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit
34+
35+
def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int]
36+
37+
def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit
38+
}
39+
40+
41+
it("path dependent in return type") {
42+
val pathDependent = mock[PathDependent]
43+
44+
(pathDependent.call0[IntCommand] _).expects(cmd).returns(5)
45+
46+
assert(pathDependent.call0(cmd) == 5)
47+
}
48+
49+
it("path dependent in return type and parameter in last parameter list") {
50+
val pathDependent = mock[PathDependent]
51+
52+
(pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5)
53+
54+
assert(pathDependent.call1(5)(cmd) == 5)
55+
}
56+
57+
it("path dependent in return type and parameter in middle parameter list ") {
58+
val pathDependent = mock[PathDependent]
59+
60+
(pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5)
61+
62+
assert(pathDependent.call2("5")(cmd)(5) == 5)
63+
}
64+
65+
it("path dependent in return type and parameter in first parameter list ") {
66+
val pathDependent = mock[PathDependent]
67+
68+
(pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5)
69+
70+
assert(pathDependent.call3(cmd)("5")(5) == 5)
71+
}
72+
73+
it("path dependent in tycon return type") {
74+
val pathDependent = mock[PathDependent]
75+
76+
(pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5))
77+
78+
assert(pathDependent.call4(cmd) == Some(5))
79+
}
80+
81+
it("path dependent in parameter list") {
82+
val pathDependent = mock[PathDependent]
83+
84+
(pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(())
85+
86+
assert(pathDependent.call5(cmd)(5) == ())
87+
}
88+
89+
it("path dependent tycon in return type") {
90+
val pathDependent = mock[PathDependent]
91+
92+
(pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5))
93+
94+
assert(pathDependent.call6(cmd) == Some(5))
95+
}
96+
97+
it("path dependent tycon in parameter list") {
98+
val pathDependent = mock[PathDependent]
99+
100+
(pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int))
101+
.expects(cmd, Some("5"), 6)
102+
.returns(())
103+
104+
assert(pathDependent.call7(cmd)(Some("5"))(6) == ())
105+
}
106+
107+
}

shared/src/test/scala/com/paulbutcher/test/mock/MethodsWithDefaultParamsTest.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {
3434

3535
trait TraitHavingMethodsWithDefaultParams {
3636
def withAllDefaultParams(a: String = "default", b: CaseClass = CaseClass(42)): String
37+
38+
def withDefaultParamAndTypeParam[T](a: String = "default", b: Int = 5): T
3739
}
3840

3941
behavior of "Mocks"
@@ -84,5 +86,13 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {
8486
m.withAllDefaultParams("other", CaseClass(99))
8587
}
8688

89+
they should "mock trait methods with type param and default parameters" in {
90+
val m = mock[TraitHavingMethodsWithDefaultParams]
91+
92+
(m.withDefaultParamAndTypeParam[Int] _).expects("default", 5).returns(5)
93+
94+
m.withDefaultParamAndTypeParam[Int]("default", 5) shouldBe 5
95+
}
96+
8797
override def newInstance = new MethodsWithDefaultParamsTest
8898
}

shared/src/test/scala/org/scalamock/test/scalatest/AsyncSyncMixinTest.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020

2121
package org.scalamock.test.scalatest
2222

23-
import org.scalatest.flatspec.AnyFlatSpec
24-
import org.scalatest._
23+
import org.scalatest.flatspec.{AnyFlatSpec, AsyncFlatSpec}
24+
import org.scalamock.scalatest.{MockFactory, AsyncMockFactory}
2525

2626
/**
2727
* Tests for issue #371
2828
*/
29-
@Ignore
3029
class AsyncSyncMixinTest extends AnyFlatSpec {
3130

3231
"MockFactory" should "be mixed only with Any*Spec and not Async*Spec traits" in {

0 commit comments

Comments
 (0)