@@ -3,22 +3,22 @@ package org.scalamock.clazz
3
3
import scala .quoted .*
4
4
import org .scalamock .context .MockContext
5
5
6
- import scala .annotation .tailrec
6
+ import scala .annotation .{ experimental , tailrec }
7
7
private [clazz] class Utils (using val quotes : Quotes ):
8
8
import quotes .reflect .*
9
9
10
10
extension (tpe : TypeRepr )
11
- def collectPathDependent (ownerSymbol : Symbol ): List [TypeRepr ] =
11
+ def collectInnerTypes (ownerSymbol : Symbol ): List [TypeRepr ] =
12
12
def loop (currentTpe : TypeRepr , names : List [String ]): List [TypeRepr ] =
13
13
currentTpe match
14
- case AppliedType (inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent (ownerSymbol))
14
+ case AppliedType (inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes (ownerSymbol))
15
15
case TypeRef (inner, name) if name == ownerSymbol.name && names.nonEmpty => List (tpe)
16
16
case TypeRef (inner, name) => loop(inner, name :: names)
17
17
case _ => Nil
18
18
19
19
loop(tpe, Nil )
20
20
21
- def pathDependentOverride (ownerSymbol : Symbol , newOwnerSymbol : Symbol , applyTypes : Boolean ): TypeRepr =
21
+ def innerTypeOverride (ownerSymbol : Symbol , newOwnerSymbol : Symbol , applyTypes : Boolean ): TypeRepr =
22
22
@ tailrec
23
23
def loop (currentTpe : TypeRepr , names : List [(String , List [TypeRepr ])], appliedTypes : List [TypeRepr ]): TypeRepr =
24
24
currentTpe match
@@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes):
53
53
case _ =>
54
54
tpe
55
55
56
+ @ experimental
56
57
def resolveParamRefs (resType : TypeRepr , methodArgs : List [List [Tree ]]) =
57
- def loop (baseBindings : TypeRepr , typeRepr : TypeRepr ): TypeRepr =
58
- typeRepr match
59
- case pr@ ParamRef (bindings, idx) if bindings == baseBindings =>
60
- methodArgs.head(idx).asInstanceOf [TypeTree ].tpe
58
+ tpe match
59
+ case baseBindings : PolyType =>
60
+ def loop (typeRepr : TypeRepr ): TypeRepr =
61
+ typeRepr match
62
+ case pr@ ParamRef (bindings, idx) if bindings == baseBindings =>
63
+ methodArgs.head(idx).asInstanceOf [TypeTree ].tpe
61
64
62
- case AppliedType (tycon, args) =>
63
- AppliedType (tycon, args.map(arg => loop(baseBindings, arg)))
65
+ case AppliedType (tycon, args) =>
66
+ AppliedType (loop( tycon) , args.map(arg => loop(arg)))
64
67
65
- case other => other
68
+ case ff @ TypeRef (ref @ ParamRef (bindings, idx), name) =>
69
+ def getIndex (bindings : TypeRepr ): Int =
70
+ @ tailrec
71
+ def loop (bindings : TypeRepr , idx : Int ): Int =
72
+ bindings match
73
+ case MethodType (_, _, method : MethodType ) => loop(method, idx + 1 )
74
+ case _ => idx
66
75
67
- tpe match
68
- case pt : PolyType => loop(pt, resType)
69
- case _ => resType
76
+ loop(bindings, 1 )
77
+
78
+ val maxIndex = methodArgs.length
79
+ val parameterListIdx = maxIndex - getIndex(bindings)
80
+
81
+ TypeSelect (methodArgs(parameterListIdx)(idx).asInstanceOf [Term ], name).tpe
82
+
83
+ case other => other
84
+
85
+ loop(resType)
86
+ case _ =>
87
+ resType
70
88
71
89
72
- def collectTypes : List [TypeRepr ] =
73
- def loop (currentTpe : TypeRepr , params : List [TypeRepr ]): List [TypeRepr ] =
90
+ def collectTypes : (List [TypeRepr ], TypeRepr ) =
91
+ @ tailrec
92
+ def loop (currentTpe : TypeRepr , argTypesAcc : List [List [TypeRepr ]], resType : TypeRepr ): (List [TypeRepr ], TypeRepr ) =
74
93
currentTpe match
75
- case PolyType (_, _, res) => loop(res, Nil )
76
- case MethodType (_, argTypes, res) => argTypes ++ loop(res, params )
77
- case other => List ( other)
78
- loop(tpe, Nil )
94
+ case PolyType (_, _, res) => loop(res, List .empty[ TypeRepr ] :: argTypesAcc, resType )
95
+ case MethodType (_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType )
96
+ case other => (argTypesAcc.reverse.flatten, other)
97
+ loop(tpe, Nil , TypeRepr .of[ Nothing ] )
79
98
80
99
case class MockableDefinition (idx : Int , symbol : Symbol , ownerTpe : TypeRepr ):
81
100
val mockValName = s " mock $$ ${symbol.name}$$ $idx"
82
101
val tpe = ownerTpe.memberType(symbol)
83
- private val rawTypes = tpe.widen.collectTypes
102
+ private val ( rawTypes, rawResType) = tpe.widen.collectTypes
84
103
val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init
85
104
86
- def resTypeWithPathDependentOverrideFor (classSymbol : Symbol ): TypeRepr =
87
- val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol)
88
- val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false ))
89
- rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated)
105
+ def resTypeWithInnerTypesOverrideFor (classSymbol : Symbol ): TypeRepr =
106
+ updatePathDependent(rawResType, List (rawResType), classSymbol)
107
+
108
+ def tpeWithSubstitutedInnerTypesFor (classSymbol : Symbol ): TypeRepr =
109
+ updatePathDependent(tpe, rawResType :: rawTypes, classSymbol)
90
110
91
- def tpeWithSubstitutedPathDependentFor ( classSymbol : Symbol ): TypeRepr =
92
- val pathDependentTypes = rawTypes .flatMap(_.collectPathDependent (ownerTpe.typeSymbol))
93
- val pdUpdated = pathDependentTypes.map(_.pathDependentOverride (ownerTpe.typeSymbol, classSymbol, applyTypes = false ))
94
- tpe .substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
111
+ private def updatePathDependent ( where : TypeRepr , types : List [ TypeRepr ], classSymbol : Symbol ): TypeRepr =
112
+ val pathDependentTypes = types .flatMap(_.collectInnerTypes (ownerTpe.typeSymbol))
113
+ val pdUpdated = pathDependentTypes.map(_.innerTypeOverride (ownerTpe.typeSymbol, classSymbol, applyTypes = false ))
114
+ where .substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
95
115
96
- def prepareTypesFor (classSymbol : Symbol ) = rawTypes
97
- .map(_.pathDependentOverride (ownerTpe.typeSymbol, classSymbol, applyTypes = true ))
116
+ def prepareTypesFor (classSymbol : Symbol ) = ( rawTypes :+ rawResType)
117
+ .map(_.innerTypeOverride (ownerTpe.typeSymbol, classSymbol, applyTypes = true ))
98
118
.map { typeRepr =>
99
119
val adjusted =
100
120
typeRepr.widen.mapParamRefWithWildcard match
101
121
case TypeBounds (lower, upper) => upper
102
122
case AppliedType (TypeRef (_, " <repeated>" ), elemTyps) =>
103
123
TypeRepr .typeConstructorOf(classOf [Seq [_]]).appliedTo(elemTyps)
104
- case other => other
124
+ case TypeRef (_ : ParamRef , _) =>
125
+ TypeRepr .of[Any ]
126
+ case AppliedType (TypeRef (_ : ParamRef , _), _) =>
127
+ TypeRepr .of[Any ]
128
+ case other =>
129
+ other
105
130
adjusted.asType match
106
131
case ' [t] => TypeTree .of[t]
107
132
}
@@ -128,10 +153,11 @@ private[clazz] class Utils(using val quotes: Quotes):
128
153
129
154
def apply (tpe : TypeRepr ): List [MockableDefinition ] =
130
155
val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr .of[Object ].typeSymbol.methodMembers).toList
131
- .filter(sym => ! sym.flags.is(Flags .Private ) && ! sym.flags.is(Flags .Final ) && ! sym.flags.is(Flags .Mutable ))
132
- .filterNot(sym => tpe.memberType(sym) match
133
- case defaultParam @ ByNameType (AnnotatedType (_, Apply (Select (New (Inferred ()), " <init>" ), Nil ))) => true
134
- case _ => false
156
+ .filter(sym =>
157
+ ! sym.flags.is(Flags .Private ) &&
158
+ ! sym.flags.is(Flags .Final ) &&
159
+ ! sym.flags.is(Flags .Mutable ) &&
160
+ ! sym.name.contains(" $default$" )
135
161
)
136
162
.zipWithIndex
137
163
.map((sym, idx) => MockableDefinition (idx, sym, tpe))
0 commit comments