Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #52 - Support generic functions #71

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Examples/Sources/ViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ protocol ServiceProtocol {
func append(name: (any Codable) -> (any Codable)?)
func get() async throws -> any Codable
func read() -> String!
func wrapDataInArray<T>(_ data: T) -> Array<T>
}

final class ViewModel {
Expand All @@ -43,4 +44,8 @@ final class ViewModel {
_ = try await service.fetchConfig(arg: 2)
config.removeAll()
}

func wrapData<T>(_ data: T) -> Array<T> {
service.wrapDataInArray(data)
}
}
10 changes: 10 additions & 0 deletions Examples/Tests/ViewModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ final class ViewModelTests: XCTestCase {
XCTFail("Unexpected error catched")
}
}

func testWrapData() {
// Important: When using generics, mocked return value types must match the types that are being returned in the use of the spy.
serviceSpy.wrapDataInArrayReturnValue = [123]
XCTAssertEqual(sut.wrapData(1), [123])
XCTAssertEqual(serviceSpy.wrapDataInArrayReceivedData as? Int, 1)

// ⚠️ The following would cause a fatal error, because an Array<String> will be returned by wrapData(), but we provided an Array<Int> to wrapDataInArrayReturnValue. ⚠️
// XCTAssertEqual(sut.wrapData("hi"), ["hello"])
}
}

extension ViewModelTests {
Expand Down
72 changes: 72 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,78 @@ func testFetchConfig() async throws {
}
```

### Generic Functions
Generic functions are supported, but require some care to use, as they get treated a little differently from other functionality.

Given a function:

```swift
func foo<T, U>(_ bar: T) -> U
```

The following will be created in a spy:

```swift
class MyProtocolSpy: MyProtocol {
var fooCallsCount = 0
var fooCalled: Bool {
return fooCallsCount > 0
}
var fooReceivedBar: Any?
var fooReceivedInvocations: [Any] = []
var fooReturnValue: Any!
var fooClosure: ((Any) -> Any)?
func foo<T, U>(_ bar: T) -> U {
fooCallsCount += 1
fooReceivedBar = (bar)
fooReceivedInvocations.append((bar))
if fooClosure != nil {
return fooClosure!(bar) as! U
} else {
return fooReturnValue as! U
}
}
}
```
Uses of `T` and `U` get substituted with `Any` because generics specified only by a function can't be stored as a property in the function's class. Using `Any` lets us store injected closures, invocations, etc.

Force casts get used to turn an injected closure or returnValue property from `Any` into an expected type. This means that *it's essential that expected types match up with values given to these injected properties*.

##### Example:
Given the following code:

```swift
@Spyable
protocol ServiceProtocol {
func wrapDataInArray<T>(_ data: T) -> Array<T>
}

struct ViewModel {
let service: ServiceProtocol

func wrapData<T>(_ data: T) -> Array<T> {
service.wrapDataInArray(data)
}
}
```

A test for ViewModel's `wrapData()` function could look like this:

```swift
func testWrapData() {
// Important: When using generics, mocked return value types must match the types that are being returned in the use of the spy.
serviceSpy.wrapDataInArrayReturnValue = [123]
XCTAssertEqual(sut.wrapData(1), [123])
XCTAssertEqual(serviceSpy.wrapDataInArrayReceivedData as? Int, 1)

// ⚠️ The following would be incorrect, and cause a fatal error, because an Array<String> will be returned by wrapData(), but here we'd be providing an Array<Int> to wrapDataInArrayReturnValue. ⚠️
// XCTAssertEqual(sut.wrapData("hi"), ["hello"])
}
```

> [!TIP]
> If you see a crash at force casting within a spy's generic function implementation, it most likely means that types are mismatched.

## Advanced Usage

### Restricting Spy Availability
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import SwiftSyntax

extension FunctionDeclSyntax {
/// The name of each generic type used. Ex: the set `[T, U]` in `func foo<T, U>()`.
var genericTypes: Set<String> {
Set(genericParameterClause?.parameters.map { $0.name.text } ?? [])
}

/// If the function declaration requires being cast to a type, this will specify that type.
/// Namely, this will apply to situations where generics are used in the function, and properties are consequently stored with generic types replaced with `Any`.
///
/// Ex: `func foo() -> T` will create `var fooReturnValue: Any!`, which will be used in the spy method implementation as `fooReturnValue as! T`
var forceCastType: TypeSyntax? {
guard !genericTypes.isEmpty,
let returnType = signature.returnClause?.type,
returnType.containsGenericType(from: genericTypes) == true else {
return nil
}
return returnType.trimmed
}
}
136 changes: 136 additions & 0 deletions Sources/SpyableMacro/Extensions/TypeSyntax+Extensions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import SwiftSyntax

extension TypeSyntax {

/// Returns `self`, cast to the first supported `TypeSyntaxSupportingGenerics` type that `self` can be cast to, or `nil` if `self` matches none.
private var asTypeSyntaxSupportingGenerics: TypeSyntaxSupportingGenerics? {
for typeSyntax in typeSyntaxesSupportingGenerics {
guard let cast = self.as(typeSyntax.self) else { continue }
return cast
}
return nil
}

/// An array of all of the `TypeSyntax`s that are used to compose this object.
///
/// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`, `(A, B)`, this will return the two type syntaxes, `A` & `B`.
private var nestedTypeSyntaxes: [Self] {
// TODO: An improvement upon this could be to throw an error here, instead of falling back to an empty array. This could be ultimately used to emit a diagnostic about the unsupported TypeSyntax for a better user experience.
asTypeSyntaxSupportingGenerics?.nestedTypeSyntaxes ?? []
}

/// Type erases generic types by substituting their names with `Any`.
///
/// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`,`(A, B)`, it will be turned into `(Any, B)` if `genericTypes` contains `"A"`.
/// - Parameter genericTypes: A list of generic type names to check against.
/// - Returns: This object, but with generic types names replaced with `Any`.
func erasingGenericTypes(_ genericTypes: Set<String>) -> Self {
guard !genericTypes.isEmpty else { return self }

// TODO: An improvement upon this could be to throw an error here, instead of falling back to `self`. This could be ultimately used to emit a diagnostic about the unsupported TypeSyntax for a better user experience.
return TypeSyntax(fromProtocol: asTypeSyntaxSupportingGenerics?.erasingGenericTypes(genericTypes)) ?? self
}

/// Recurses through type syntaxes to find all `IdentifierTypeSyntax` leaves, and checks each of them to see if its name exists in `genericTypes`.
///
/// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`,`(A, B)`, it will return `true` if `genericTypes` contains `"A"`.
/// - Parameter genericTypes: A list of generic type names to check against.
/// - Returns: Whether or not this `TypeSyntax` contains a type matching a name in `genericTypes`.
func containsGenericType(from genericTypes: Set<String>) -> Bool {
guard !genericTypes.isEmpty else { return false }

return if let type = self.as(IdentifierTypeSyntax.self),
genericTypes.contains(type.name.text) {
true
} else {
nestedTypeSyntaxes.contains { $0.containsGenericType(from: genericTypes) }
}
}
}

// MARK: - TypeSyntaxSupportingGenerics

/// Conform type syntaxes to this protocol and add them to `typeSyntaxesSupportingGenerics` to support having their generics scanned or type-erased.
///
/// - Warning: We are warned in the documentation of `TypeSyntaxProtocol`, "Do not conform to this protocol yourself". However, we don't use this protocol for anything other than defining additional behavior on particular conformers to `TypeSyntaxProtocol`; we're not using this to define a new type syntax.
private protocol TypeSyntaxSupportingGenerics: TypeSyntaxProtocol {
/// Type syntaxes that can be found nested within this type.
///
/// Ex: A `TupleTypeSyntax` representing `(A, (B, C))` would have the two nested type syntaxes: `IdentityTypeSyntax`, which would represent `A`, and `TupleTypeSyntax` would represent `(B, C)`, which would in turn have its own `nestedTypeSyntaxes`.
var nestedTypeSyntaxes: [TypeSyntax] { get }

/// Returns `self` with generics replaced with `Any`, when the generic identifiers exist in `genericTypes`.
func erasingGenericTypes(_ genericTypes: Set<String>) -> Self
}

private let typeSyntaxesSupportingGenerics: [TypeSyntaxSupportingGenerics.Type] = [
IdentifierTypeSyntax.self, // Start with IdentifierTypeSyntax for the sake of efficiency when looping through this array, as it's the most common TypeSyntax.
ArrayTypeSyntax.self,
GenericArgumentClauseSyntax.self,
TupleTypeSyntax.self,
]

extension IdentifierTypeSyntax: TypeSyntaxSupportingGenerics {
fileprivate var nestedTypeSyntaxes: [TypeSyntax] {
genericArgumentClause?.nestedTypeSyntaxes ?? []
}
fileprivate func erasingGenericTypes(_ genericTypes: Set<String>) -> Self {
var copy = self
if genericTypes.contains(name.text) {
copy = copy.with(\.name.tokenKind, .identifier("Any"))
}
if let genericArgumentClause {
copy = copy.with(
\.genericArgumentClause,
genericArgumentClause.erasingGenericTypes(genericTypes)
)
}
return copy
}
}

extension ArrayTypeSyntax: TypeSyntaxSupportingGenerics {
fileprivate var nestedTypeSyntaxes: [TypeSyntax] {
[element]
}
fileprivate func erasingGenericTypes(_ genericTypes: Set<String>) -> Self {
with(\.element, element.erasingGenericTypes(genericTypes))
}
}

extension GenericArgumentClauseSyntax: TypeSyntaxSupportingGenerics {
fileprivate var nestedTypeSyntaxes: [TypeSyntax] {
arguments.map { $0.argument }
}
fileprivate func erasingGenericTypes(_ genericTypes: Set<String>) -> Self {
with(
\.arguments,
GenericArgumentListSyntax {
for argumentElement in arguments {
argumentElement.with(
\.argument,
argumentElement.argument.erasingGenericTypes(genericTypes)
)
}
}
)
}
}

extension TupleTypeSyntax: TypeSyntaxSupportingGenerics {
fileprivate var nestedTypeSyntaxes: [TypeSyntax] {
elements.map { $0.type }
}
fileprivate func erasingGenericTypes(_ genericTypes: Set<String>) -> Self {
with(
\.elements,
TupleTypeElementListSyntax {
for element in elements {
element.with(
\.type,
element.type.erasingGenericTypes(genericTypes))
}
}
)
}
}
90 changes: 54 additions & 36 deletions Sources/SpyableMacro/Factories/ClosureFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,48 +30,20 @@ import SwiftSyntaxBuilder
struct ClosureFactory {
func variableDeclaration(
variablePrefix: String,
functionSignature: FunctionSignatureSyntax
protocolFunctionDeclaration: FunctionDeclSyntax
) throws -> VariableDeclSyntax {
let returnClause: ReturnClauseSyntax
if let functionReturnClause = functionSignature.returnClause {
/*
func f() -> String!
*/
if let implicitlyUnwrappedType = functionReturnClause.type.as(
ImplicitlyUnwrappedOptionalTypeSyntax.self)
{
var functionReturnClause = functionReturnClause
/*
`() -> String!` is not a valid code
so we have to convert it to `() -> String?
*/
functionReturnClause.type = TypeSyntax(
OptionalTypeSyntax(wrappedType: implicitlyUnwrappedType.wrappedType))
returnClause = functionReturnClause
/*
func f() -> Any
func f() -> Any?
*/
} else {
returnClause = functionReturnClause
}
/*
func f()
*/
} else {
returnClause = ReturnClauseSyntax(
type: IdentifierTypeSyntax(
name: .identifier("Void")
)
)
}
let functionSignature = protocolFunctionDeclaration.signature
let genericTypes = protocolFunctionDeclaration.genericTypes
let returnClause = returnClause(protocolFunctionDeclaration: protocolFunctionDeclaration)

let elements = TupleTypeElementListSyntax {
TupleTypeElementSyntax(
type: FunctionTypeSyntax(
parameters: TupleTypeElementListSyntax {
for parameter in functionSignature.parameterClause.parameters {
TupleTypeElementSyntax(type: parameter.type)
TupleTypeElementSyntax(
type: parameter.type.erasingGenericTypes(genericTypes)
)
}
},
effectSpecifiers: TypeEffectSpecifiersSyntax(
Expand All @@ -90,10 +62,48 @@ struct ClosureFactory {
)
}

private func returnClause(
protocolFunctionDeclaration: FunctionDeclSyntax
) -> ReturnClauseSyntax {
let functionSignature = protocolFunctionDeclaration.signature
let genericTypes = protocolFunctionDeclaration.genericTypes

if let functionReturnClause = functionSignature.returnClause {
/*
func f() -> String!
*/
if let implicitlyUnwrappedType = functionReturnClause.type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
var functionReturnClause = functionReturnClause
/*
`() -> String!` is not a valid code
so we have to convert it to `() -> String?
*/
functionReturnClause.type = TypeSyntax(OptionalTypeSyntax(wrappedType: implicitlyUnwrappedType.wrappedType))
return functionReturnClause
/*
func f() -> Any
func f() -> Any?
*/
} else {
return functionReturnClause.with(\.type, functionReturnClause.type.erasingGenericTypes(genericTypes))
}
/*
func f()
*/
} else {
return ReturnClauseSyntax(
type: IdentifierTypeSyntax(
name: .identifier("Void")
)
)
}
}

func callExpression(
variablePrefix: String,
functionSignature: FunctionSignatureSyntax
protocolFunctionDeclaration: FunctionDeclSyntax
) -> ExprSyntaxProtocol {
let functionSignature = protocolFunctionDeclaration.signature
let calledExpression: ExprSyntaxProtocol

if functionSignature.returnClause == nil {
Expand Down Expand Up @@ -144,6 +154,14 @@ struct ClosureFactory {
expression = TryExprSyntax(expression: expression)
}

if let forceCastType = protocolFunctionDeclaration.forceCastType {
expression = AsExprSyntax(
expression: expression,
questionOrExclamationMark: .exclamationMarkToken(trailingTrivia: .space),
type: forceCastType
)
}

return expression
}

Expand Down
Loading