diff --git a/Examples/Sources/ViewModel.swift b/Examples/Sources/ViewModel.swift index 0797ebf..53b04bb 100644 --- a/Examples/Sources/ViewModel.swift +++ b/Examples/Sources/ViewModel.swift @@ -17,6 +17,7 @@ protocol ServiceProtocol { func append(name: (any Codable) -> (any Codable)?) func get() async throws -> any Codable func read() -> String! + func wrapDataInArray(_ data: T) -> Array } final class ViewModel { @@ -43,4 +44,8 @@ final class ViewModel { _ = try await service.fetchConfig(arg: 2) config.removeAll() } + + func wrapData(_ data: T) -> Array { + service.wrapDataInArray(data) + } } diff --git a/Examples/Tests/ViewModelTests.swift b/Examples/Tests/ViewModelTests.swift index 6034a6e..b63a385 100644 --- a/Examples/Tests/ViewModelTests.swift +++ b/Examples/Tests/ViewModelTests.swift @@ -55,6 +55,16 @@ final class ViewModelTests: XCTestCase { XCTFail("Unexpected error catched") } } + + func testWrapData() { + // Important: When using generics, mocked return value types must match the types that are being returned in the use of the spy. + serviceSpy.wrapDataInArrayReturnValue = [123] + XCTAssertEqual(sut.wrapData(1), [123]) + XCTAssertEqual(serviceSpy.wrapDataInArrayReceivedData as? Int, 1) + + // ⚠️ The following would cause a fatal error, because an Array will be returned by wrapData(), but we provided an Array to wrapDataInArrayReturnValue. ⚠️ + // XCTAssertEqual(sut.wrapData("hi"), ["hello"]) + } } extension ViewModelTests { diff --git a/README.md b/README.md index 2068f3f..62076c6 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,78 @@ func testFetchConfig() async throws { } ``` +### Generic Functions +Generic functions are supported, but require some care to use, as they get treated a little differently from other functionality. + +Given a function: + +```swift +func foo(_ bar: T) -> U +``` + +The following will be created in a spy: + +```swift +class MyProtocolSpy: MyProtocol { + var fooCallsCount = 0 + var fooCalled: Bool { + return fooCallsCount > 0 + } + var fooReceivedBar: Any? + var fooReceivedInvocations: [Any] = [] + var fooReturnValue: Any! + var fooClosure: ((Any) -> Any)? + func foo(_ bar: T) -> U { + fooCallsCount += 1 + fooReceivedBar = (bar) + fooReceivedInvocations.append((bar)) + if fooClosure != nil { + return fooClosure!(bar) as! U + } else { + return fooReturnValue as! U + } + } +} +``` +Uses of `T` and `U` get substituted with `Any` because generics specified only by a function can't be stored as a property in the function's class. Using `Any` lets us store injected closures, invocations, etc. + +Force casts get used to turn an injected closure or returnValue property from `Any` into an expected type. This means that *it's essential that expected types match up with values given to these injected properties*. + +##### Example: +Given the following code: + +```swift +@Spyable +protocol ServiceProtocol { + func wrapDataInArray(_ data: T) -> Array +} + +struct ViewModel { + let service: ServiceProtocol + + func wrapData(_ data: T) -> Array { + service.wrapDataInArray(data) + } +} +``` + +A test for ViewModel's `wrapData()` function could look like this: + +```swift +func testWrapData() { + // Important: When using generics, mocked return value types must match the types that are being returned in the use of the spy. + serviceSpy.wrapDataInArrayReturnValue = [123] + XCTAssertEqual(sut.wrapData(1), [123]) + XCTAssertEqual(serviceSpy.wrapDataInArrayReceivedData as? Int, 1) + + // ⚠️ The following would be incorrect, and cause a fatal error, because an Array will be returned by wrapData(), but here we'd be providing an Array to wrapDataInArrayReturnValue. ⚠️ + // XCTAssertEqual(sut.wrapData("hi"), ["hello"]) +} +``` + +> [!TIP] +> If you see a crash at force casting within a spy's generic function implementation, it most likely means that types are mismatched. + ## Advanced Usage ### Restricting Spy Availability diff --git a/Sources/SpyableMacro/Extensions/FunctionDeclSyntax+Extensions.swift b/Sources/SpyableMacro/Extensions/FunctionDeclSyntax+Extensions.swift new file mode 100644 index 0000000..c6c968b --- /dev/null +++ b/Sources/SpyableMacro/Extensions/FunctionDeclSyntax+Extensions.swift @@ -0,0 +1,21 @@ +import SwiftSyntax + +extension FunctionDeclSyntax { + /// The name of each generic type used. Ex: the set `[T, U]` in `func foo()`. + var genericTypes: Set { + Set(genericParameterClause?.parameters.map { $0.name.text } ?? []) + } + + /// If the function declaration requires being cast to a type, this will specify that type. + /// Namely, this will apply to situations where generics are used in the function, and properties are consequently stored with generic types replaced with `Any`. + /// + /// Ex: `func foo() -> T` will create `var fooReturnValue: Any!`, which will be used in the spy method implementation as `fooReturnValue as! T` + var forceCastType: TypeSyntax? { + guard !genericTypes.isEmpty, + let returnType = signature.returnClause?.type, + returnType.containsGenericType(from: genericTypes) == true else { + return nil + } + return returnType.trimmed + } +} diff --git a/Sources/SpyableMacro/Extensions/TypeSyntax+Extensions.swift b/Sources/SpyableMacro/Extensions/TypeSyntax+Extensions.swift new file mode 100644 index 0000000..637a209 --- /dev/null +++ b/Sources/SpyableMacro/Extensions/TypeSyntax+Extensions.swift @@ -0,0 +1,136 @@ +import SwiftSyntax + +extension TypeSyntax { + + /// Returns `self`, cast to the first supported `TypeSyntaxSupportingGenerics` type that `self` can be cast to, or `nil` if `self` matches none. + private var asTypeSyntaxSupportingGenerics: TypeSyntaxSupportingGenerics? { + for typeSyntax in typeSyntaxesSupportingGenerics { + guard let cast = self.as(typeSyntax.self) else { continue } + return cast + } + return nil + } + + /// An array of all of the `TypeSyntax`s that are used to compose this object. + /// + /// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`, `(A, B)`, this will return the two type syntaxes, `A` & `B`. + private var nestedTypeSyntaxes: [Self] { + // TODO: An improvement upon this could be to throw an error here, instead of falling back to an empty array. This could be ultimately used to emit a diagnostic about the unsupported TypeSyntax for a better user experience. + asTypeSyntaxSupportingGenerics?.nestedTypeSyntaxes ?? [] + } + + /// Type erases generic types by substituting their names with `Any`. + /// + /// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`,`(A, B)`, it will be turned into `(Any, B)` if `genericTypes` contains `"A"`. + /// - Parameter genericTypes: A list of generic type names to check against. + /// - Returns: This object, but with generic types names replaced with `Any`. + func erasingGenericTypes(_ genericTypes: Set) -> Self { + guard !genericTypes.isEmpty else { return self } + + // TODO: An improvement upon this could be to throw an error here, instead of falling back to `self`. This could be ultimately used to emit a diagnostic about the unsupported TypeSyntax for a better user experience. + return TypeSyntax(fromProtocol: asTypeSyntaxSupportingGenerics?.erasingGenericTypes(genericTypes)) ?? self + } + + /// Recurses through type syntaxes to find all `IdentifierTypeSyntax` leaves, and checks each of them to see if its name exists in `genericTypes`. + /// + /// Ex: If this `TypeSyntax` represents a `TupleTypeSyntax`,`(A, B)`, it will return `true` if `genericTypes` contains `"A"`. + /// - Parameter genericTypes: A list of generic type names to check against. + /// - Returns: Whether or not this `TypeSyntax` contains a type matching a name in `genericTypes`. + func containsGenericType(from genericTypes: Set) -> Bool { + guard !genericTypes.isEmpty else { return false } + + return if let type = self.as(IdentifierTypeSyntax.self), + genericTypes.contains(type.name.text) { + true + } else { + nestedTypeSyntaxes.contains { $0.containsGenericType(from: genericTypes) } + } + } +} + +// MARK: - TypeSyntaxSupportingGenerics + +/// Conform type syntaxes to this protocol and add them to `typeSyntaxesSupportingGenerics` to support having their generics scanned or type-erased. +/// +/// - Warning: We are warned in the documentation of `TypeSyntaxProtocol`, "Do not conform to this protocol yourself". However, we don't use this protocol for anything other than defining additional behavior on particular conformers to `TypeSyntaxProtocol`; we're not using this to define a new type syntax. +private protocol TypeSyntaxSupportingGenerics: TypeSyntaxProtocol { + /// Type syntaxes that can be found nested within this type. + /// + /// Ex: A `TupleTypeSyntax` representing `(A, (B, C))` would have the two nested type syntaxes: `IdentityTypeSyntax`, which would represent `A`, and `TupleTypeSyntax` would represent `(B, C)`, which would in turn have its own `nestedTypeSyntaxes`. + var nestedTypeSyntaxes: [TypeSyntax] { get } + + /// Returns `self` with generics replaced with `Any`, when the generic identifiers exist in `genericTypes`. + func erasingGenericTypes(_ genericTypes: Set) -> Self +} + +private let typeSyntaxesSupportingGenerics: [TypeSyntaxSupportingGenerics.Type] = [ + IdentifierTypeSyntax.self, // Start with IdentifierTypeSyntax for the sake of efficiency when looping through this array, as it's the most common TypeSyntax. + ArrayTypeSyntax.self, + GenericArgumentClauseSyntax.self, + TupleTypeSyntax.self, +] + +extension IdentifierTypeSyntax: TypeSyntaxSupportingGenerics { + fileprivate var nestedTypeSyntaxes: [TypeSyntax] { + genericArgumentClause?.nestedTypeSyntaxes ?? [] + } + fileprivate func erasingGenericTypes(_ genericTypes: Set) -> Self { + var copy = self + if genericTypes.contains(name.text) { + copy = copy.with(\.name.tokenKind, .identifier("Any")) + } + if let genericArgumentClause { + copy = copy.with( + \.genericArgumentClause, + genericArgumentClause.erasingGenericTypes(genericTypes) + ) + } + return copy + } +} + +extension ArrayTypeSyntax: TypeSyntaxSupportingGenerics { + fileprivate var nestedTypeSyntaxes: [TypeSyntax] { + [element] + } + fileprivate func erasingGenericTypes(_ genericTypes: Set) -> Self { + with(\.element, element.erasingGenericTypes(genericTypes)) + } +} + +extension GenericArgumentClauseSyntax: TypeSyntaxSupportingGenerics { + fileprivate var nestedTypeSyntaxes: [TypeSyntax] { + arguments.map { $0.argument } + } + fileprivate func erasingGenericTypes(_ genericTypes: Set) -> Self { + with( + \.arguments, + GenericArgumentListSyntax { + for argumentElement in arguments { + argumentElement.with( + \.argument, + argumentElement.argument.erasingGenericTypes(genericTypes) + ) + } + } + ) + } +} + +extension TupleTypeSyntax: TypeSyntaxSupportingGenerics { + fileprivate var nestedTypeSyntaxes: [TypeSyntax] { + elements.map { $0.type } + } + fileprivate func erasingGenericTypes(_ genericTypes: Set) -> Self { + with( + \.elements, + TupleTypeElementListSyntax { + for element in elements { + element.with( + \.type, + element.type.erasingGenericTypes(genericTypes)) + } + } + ) + } +} diff --git a/Sources/SpyableMacro/Factories/ClosureFactory.swift b/Sources/SpyableMacro/Factories/ClosureFactory.swift index 491f178..4fdfb09 100644 --- a/Sources/SpyableMacro/Factories/ClosureFactory.swift +++ b/Sources/SpyableMacro/Factories/ClosureFactory.swift @@ -30,48 +30,20 @@ import SwiftSyntaxBuilder struct ClosureFactory { func variableDeclaration( variablePrefix: String, - functionSignature: FunctionSignatureSyntax + protocolFunctionDeclaration: FunctionDeclSyntax ) throws -> VariableDeclSyntax { - let returnClause: ReturnClauseSyntax - if let functionReturnClause = functionSignature.returnClause { - /* - func f() -> String! - */ - if let implicitlyUnwrappedType = functionReturnClause.type.as( - ImplicitlyUnwrappedOptionalTypeSyntax.self) - { - var functionReturnClause = functionReturnClause - /* - `() -> String!` is not a valid code - so we have to convert it to `() -> String? - */ - functionReturnClause.type = TypeSyntax( - OptionalTypeSyntax(wrappedType: implicitlyUnwrappedType.wrappedType)) - returnClause = functionReturnClause - /* - func f() -> Any - func f() -> Any? - */ - } else { - returnClause = functionReturnClause - } - /* - func f() - */ - } else { - returnClause = ReturnClauseSyntax( - type: IdentifierTypeSyntax( - name: .identifier("Void") - ) - ) - } + let functionSignature = protocolFunctionDeclaration.signature + let genericTypes = protocolFunctionDeclaration.genericTypes + let returnClause = returnClause(protocolFunctionDeclaration: protocolFunctionDeclaration) let elements = TupleTypeElementListSyntax { TupleTypeElementSyntax( type: FunctionTypeSyntax( parameters: TupleTypeElementListSyntax { for parameter in functionSignature.parameterClause.parameters { - TupleTypeElementSyntax(type: parameter.type) + TupleTypeElementSyntax( + type: parameter.type.erasingGenericTypes(genericTypes) + ) } }, effectSpecifiers: TypeEffectSpecifiersSyntax( @@ -90,10 +62,48 @@ struct ClosureFactory { ) } + private func returnClause( + protocolFunctionDeclaration: FunctionDeclSyntax + ) -> ReturnClauseSyntax { + let functionSignature = protocolFunctionDeclaration.signature + let genericTypes = protocolFunctionDeclaration.genericTypes + + if let functionReturnClause = functionSignature.returnClause { + /* + func f() -> String! + */ + if let implicitlyUnwrappedType = functionReturnClause.type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) { + var functionReturnClause = functionReturnClause + /* + `() -> String!` is not a valid code + so we have to convert it to `() -> String? + */ + functionReturnClause.type = TypeSyntax(OptionalTypeSyntax(wrappedType: implicitlyUnwrappedType.wrappedType)) + return functionReturnClause + /* + func f() -> Any + func f() -> Any? + */ + } else { + return functionReturnClause.with(\.type, functionReturnClause.type.erasingGenericTypes(genericTypes)) + } + /* + func f() + */ + } else { + return ReturnClauseSyntax( + type: IdentifierTypeSyntax( + name: .identifier("Void") + ) + ) + } + } + func callExpression( variablePrefix: String, - functionSignature: FunctionSignatureSyntax + protocolFunctionDeclaration: FunctionDeclSyntax ) -> ExprSyntaxProtocol { + let functionSignature = protocolFunctionDeclaration.signature let calledExpression: ExprSyntaxProtocol if functionSignature.returnClause == nil { @@ -144,6 +154,14 @@ struct ClosureFactory { expression = TryExprSyntax(expression: expression) } + if let forceCastType = protocolFunctionDeclaration.forceCastType { + expression = AsExprSyntax( + expression: expression, + questionOrExclamationMark: .exclamationMarkToken(trailingTrivia: .space), + type: forceCastType + ) + } + return expression } diff --git a/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift b/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift index 006748c..dbd7ff4 100644 --- a/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift +++ b/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift @@ -94,7 +94,7 @@ struct FunctionImplementationFactory { if protocolFunctionDeclaration.signature.returnClause == nil { closureFactory.callExpression( variablePrefix: variablePrefix, - functionSignature: protocolFunctionDeclaration.signature + protocolFunctionDeclaration: protocolFunctionDeclaration ) } else { returnExpression( @@ -130,14 +130,17 @@ struct FunctionImplementationFactory { elseKeyword: .keyword(.else), elseBody: .codeBlock( CodeBlockSyntax { - returnValueFactory.returnStatement(variablePrefix: variablePrefix) + returnValueFactory.returnStatement( + variablePrefix: variablePrefix, + forceCastType: protocolFunctionDeclaration.forceCastType + ) } ), bodyBuilder: { ReturnStmtSyntax( expression: closureFactory.callExpression( variablePrefix: variablePrefix, - functionSignature: protocolFunctionDeclaration.signature + protocolFunctionDeclaration: protocolFunctionDeclaration ) ) } diff --git a/Sources/SpyableMacro/Factories/ReturnValueFactory.swift b/Sources/SpyableMacro/Factories/ReturnValueFactory.swift index 4f7222d..0b2159e 100644 --- a/Sources/SpyableMacro/Factories/ReturnValueFactory.swift +++ b/Sources/SpyableMacro/Factories/ReturnValueFactory.swift @@ -82,12 +82,21 @@ struct ReturnValueFactory { ) } - func returnStatement(variablePrefix: String) -> StmtSyntax { - StmtSyntax( - """ - return \(variableIdentifier(variablePrefix: variablePrefix)) - """ + func returnStatement( + variablePrefix: String, + forceCastType: TypeSyntax? = nil + ) -> StmtSyntaxProtocol { + var expression: ExprSyntaxProtocol = DeclReferenceExprSyntax( + baseName: variableIdentifier(variablePrefix: variablePrefix) ) + if let forceCastType { + expression = AsExprSyntax( + expression: expression, + questionOrExclamationMark: .exclamationMarkToken(trailingTrivia: .space), + type: forceCastType + ) + } + return ReturnStmtSyntax(expression: expression) } private func variableIdentifier(variablePrefix: String) -> TokenSyntax { diff --git a/Sources/SpyableMacro/Factories/SpyFactory.swift b/Sources/SpyableMacro/Factories/SpyFactory.swift index b73421b..a1bbb4b 100644 --- a/Sources/SpyableMacro/Factories/SpyFactory.swift +++ b/Sources/SpyableMacro/Factories/SpyFactory.swift @@ -124,7 +124,8 @@ struct SpyFactory { for functionDeclaration in functionDeclarations { let variablePrefix = variablePrefixFactory.text(for: functionDeclaration) - let parameterList = functionDeclaration.signature.parameterClause.parameters + let genericTypes = functionDeclaration.genericTypes + let parameterList = parameterList(protocolFunctionDeclaration: functionDeclaration, genericTypes: genericTypes) try callsCountFactory.variableDeclaration(variablePrefix: variablePrefix) try calledFactory.variableDeclaration(variablePrefix: variablePrefix) @@ -145,15 +146,16 @@ struct SpyFactory { } if let returnType = functionDeclaration.signature.returnClause?.type { + let genericTypeErasedReturnType = returnType.erasingGenericTypes(genericTypes) try returnValueFactory.variableDeclaration( variablePrefix: variablePrefix, - functionReturnType: returnType + functionReturnType: genericTypeErasedReturnType ) } try closureFactory.variableDeclaration( variablePrefix: variablePrefix, - functionSignature: functionDeclaration.signature + protocolFunctionDeclaration: functionDeclaration ) functionImplementationFactory.declaration( @@ -166,6 +168,22 @@ struct SpyFactory { } } +private func parameterList( + protocolFunctionDeclaration: FunctionDeclSyntax, + genericTypes: Set +) -> FunctionParameterListSyntax { + let functionSignatureParameters = protocolFunctionDeclaration.signature.parameterClause.parameters + return if genericTypes.isEmpty { + functionSignatureParameters + } else { + FunctionParameterListSyntax { + for parameter in functionSignatureParameters { + parameter.with(\.type, parameter.type.erasingGenericTypes(genericTypes)) + } + } + } +} + extension SyntaxProtocol { /// - Returns: `self` with leading space `Trivia` removed. fileprivate var removingLeadingSpaces: Self { diff --git a/Tests/SpyableMacroTests/Extensions/UT_FunctionDeclSyntax+Extensions.swift b/Tests/SpyableMacroTests/Extensions/UT_FunctionDeclSyntax+Extensions.swift new file mode 100644 index 0000000..422b155 --- /dev/null +++ b/Tests/SpyableMacroTests/Extensions/UT_FunctionDeclSyntax+Extensions.swift @@ -0,0 +1,54 @@ +import SwiftSyntax +import XCTest + +@testable import SpyableMacro + +final class UT_FunctionDeclSyntaxExtensions: XCTestCase { + + // MARK: - genericTypes + + func testGenericTypes_WithGenerics() throws { + let protocolFunctionDeclaration = try FunctionDeclSyntax( + """ + func foo() -> T + """ + ) {} + + XCTAssertEqual(protocolFunctionDeclaration.genericTypes, ["T", "U"]) + } + + func testGenericTypes_WithoutGenerics() throws { + let protocolFunctionDeclaration = try FunctionDeclSyntax( + """ + func foo() -> T + """ + ) {} + + XCTAssertTrue(protocolFunctionDeclaration.genericTypes.isEmpty) + } + + // MARK: - forceCastType + + func testForceCastType_WithGeneric() throws { + let protocolFunctionDeclaration = try FunctionDeclSyntax( + """ + func foo() -> T + """ + ) {} + + XCTAssertEqual( + try XCTUnwrap(protocolFunctionDeclaration.forceCastType).description, + TypeSyntax(stringLiteral: "T").description + ) + } + + func testForceCastType_WithoutGeneric() throws { + let protocolFunctionDeclaration = try FunctionDeclSyntax( + """ + func foo() -> T + """ + ) {} + + XCTAssertNil(protocolFunctionDeclaration.forceCastType) + } +} diff --git a/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ContainsGenericType.swift b/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ContainsGenericType.swift new file mode 100644 index 0000000..909406c --- /dev/null +++ b/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ContainsGenericType.swift @@ -0,0 +1,107 @@ +import SwiftSyntax +import XCTest + +@testable import SpyableMacro + +final class UT_TypeSyntax_ContainsGenericType: XCTestCase { + func testContainsGenericType_WithTypeSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax(stringLiteral: identifier) + .containsGenericType(from: genericTypes) + } + + XCTAssertTrue(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } + + func testContainsGenericType_WithIdentifierTypeSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax( + IdentifierTypeSyntax( + name: .identifier(identifier) + ) + ) + .containsGenericType(from: genericTypes) + } + + XCTAssertTrue(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } + + func testContainsGenericType_WithArrayTypeSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax( + ArrayTypeSyntax( + element: TypeSyntax(stringLiteral: identifier) + ) + ) + .containsGenericType(from: genericTypes) + } + + XCTAssertTrue(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } + + func testContainsGenericType_WithGenericArgumentClauseSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax( + IdentifierTypeSyntax( + name: .identifier("Array"), + genericArgumentClause: GenericArgumentClauseSyntax { + GenericArgumentSyntax(argument: TypeSyntax(stringLiteral: identifier)) + } + ) + ) + .containsGenericType(from: genericTypes) + } + + XCTAssertTrue(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } + + func testContainsGenericType_WithTupleTypeSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax( + TupleTypeSyntax(elements: TupleTypeElementListSyntax { + TupleTypeElementSyntax(type: IdentifierTypeSyntax( + name: .identifier(identifier) + )) + }) + ) + .containsGenericType(from: genericTypes) + } + + XCTAssertTrue(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } + + func testContainsGenericType_WithUnsupportedTypeSyntax() { + func typeSyntax( + with identifier: String, + containsGenericType genericTypes: Set + ) -> Bool { + TypeSyntax( + MissingTypeSyntax(placeholder: .identifier(identifier)) + ) + .containsGenericType(from: genericTypes) + } + + XCTAssertFalse(typeSyntax(with: "T", containsGenericType: ["T"])) + XCTAssertFalse(typeSyntax(with: "String", containsGenericType: ["T"])) + } +} diff --git a/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ErasingGenericType.swift b/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ErasingGenericType.swift new file mode 100644 index 0000000..5890882 --- /dev/null +++ b/Tests/SpyableMacroTests/Extensions/UT_TypeSyntax+ErasingGenericType.swift @@ -0,0 +1,109 @@ +import SwiftSyntax +import XCTest + +@testable import SpyableMacro + +final class UT_TypeSyntax_ErasingGenericTypes: XCTestCase { + func testErasingGenericTypes_WithTypeSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax(stringLiteral: identifier) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: " T "), " Any ") + XCTAssertEqual(typeSyntaxDescription(with: " String "), " String ") + } + + func testErasingGenericTypes_WithIdentifierTypeSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax( + IdentifierTypeSyntax( + leadingTrivia: .space, + name: .identifier(identifier), + trailingTrivia: .space + ) + ) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: "T"), " Any ") + XCTAssertEqual(typeSyntaxDescription(with: "String"), " String ") + } + + func testErasingGenericTypes_WithArrayTypeSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax( + ArrayTypeSyntax( + leadingTrivia: .space, + element: TypeSyntax(stringLiteral: identifier), + trailingTrivia: .space + ) + ) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: "T"), " [Any] ") + XCTAssertEqual(typeSyntaxDescription(with: "String"), " [String] ") + } + + func testErasingGenericTypes_WithGenericArgumentClauseSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax( + IdentifierTypeSyntax( + leadingTrivia: .space, + name: .identifier("Array"), + genericArgumentClause: GenericArgumentClauseSyntax { + GenericArgumentSyntax(argument: TypeSyntax(stringLiteral: identifier)) + }, + trailingTrivia: .space + ) + ) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: "T"), " Array ") + XCTAssertEqual(typeSyntaxDescription(with: "String"), " Array ") + } + + func testErasingGenericTypes_WithTupleTypeSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax( + TupleTypeSyntax( + leadingTrivia: .space, + elements: TupleTypeElementListSyntax { + TupleTypeElementSyntax(type: IdentifierTypeSyntax( + name: .identifier(identifier) + )) + TupleTypeElementSyntax(type: IdentifierTypeSyntax( + leadingTrivia: .space, + name: .identifier("Unerased") + )) + }, + trailingTrivia: .space + ) + ) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: "T"), " (Any, Unerased) ") + XCTAssertEqual(typeSyntaxDescription(with: "String"), " (String, Unerased) ") + } + + func testErasingGenericTypes_WithUnsupportedTypeSyntax() { + func typeSyntaxDescription(with identifier: String) -> String { + TypeSyntax( + MissingTypeSyntax(placeholder: .identifier(identifier)) + ) + .erasingGenericTypes(["T"]) + .description + } + + XCTAssertEqual(typeSyntaxDescription(with: "T"), "T") + XCTAssertEqual(typeSyntaxDescription(with: "String"), "String") + } +} diff --git a/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift index 29633c6..561f304 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift @@ -55,6 +55,14 @@ final class UT_ClosureFactory: XCTestCase { ) } + func testVariableDeclarationWithGenericParameter() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func _ignore_(value: T)", + prefixForVariable: "_prefix_", + expectingVariableDeclaration: "var _prefix_Closure: ((Any) -> Void)?" + ) + } + func testVariableDeclarationOptionalTypeReturnValue() throws { try assertProtocolFunction( withFunctionDeclaration: "func _ignore_() -> Data?", @@ -74,16 +82,11 @@ final class UT_ClosureFactory: XCTestCase { func testVariableDeclarationEverything() throws { try assertProtocolFunction( withFunctionDeclaration: """ - func _ignore_( - text: inout String, - product: (UInt?, name: String), - added: (() -> Void)?, - removed: @autoclosure @escaping () -> Bool - ) async throws -> (text: String, output: (() -> Void)?) + func _ignore_(text: inout String, value: T, product: (UInt?, name: String), added: (() -> Void)?, removed: @autoclosure @escaping () -> Bool) async throws -> String? """, prefixForVariable: "_prefix_", expectingVariableDeclaration: """ - var _prefix_Closure: ((inout String, (UInt?, name: String), (() -> Void)?, @autoclosure @escaping () -> Bool) async throws -> (text: String, output: (() -> Void)?) )? + var _prefix_Closure: ((inout String, Any, (UInt?, name: String), (() -> Void)?, @autoclosure @escaping () -> Bool) async throws -> String? )? """ ) } @@ -130,13 +133,21 @@ final class UT_ClosureFactory: XCTestCase { ) } + func testCallExpressionWithGenericParameter() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func _ignore_(value: T)", + prefixForVariable: "_prefix_", + expectingCallExpression: "_prefix_Closure?(value)" + ) + } + func testCallExpressionEverything() throws { try assertProtocolFunction( withFunctionDeclaration: """ - func _ignore_(text: inout String, product: (UInt?, name: String), added: (() -> Void)?, removed: @autoclosure @escaping () -> Bool) async throws -> String? + func _ignore_(value: inout T, product: (UInt?, name: String), added: (() -> Void)?, removed: @autoclosure @escaping () -> Bool) async throws -> String? """, prefixForVariable: "_prefix_", - expectingCallExpression: "try await _prefix_Closure!(&text, product, added, removed())" + expectingCallExpression: "try await _prefix_Closure!(&value, product, added, removed())" ) } @@ -153,7 +164,7 @@ final class UT_ClosureFactory: XCTestCase { let result = try ClosureFactory().variableDeclaration( variablePrefix: variablePrefix, - functionSignature: protocolFunctionDeclaration.signature + protocolFunctionDeclaration: protocolFunctionDeclaration ) assertBuildResult(result, expectedDeclaration, file: file, line: line) @@ -170,7 +181,7 @@ final class UT_ClosureFactory: XCTestCase { let result = ClosureFactory().callExpression( variablePrefix: variablePrefix, - functionSignature: protocolFunctionDeclaration.signature + protocolFunctionDeclaration: protocolFunctionDeclaration ) assertBuildResult(result, expectedExpression, file: file, line: line) diff --git a/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift b/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift index eafdefb..a7f92fd 100644 --- a/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift @@ -52,6 +52,25 @@ final class UT_FunctionImplementationFactory: XCTestCase { ) } + func testDeclarationGenerics() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func foo(value: T) -> U", + prefixForVariable: "_prefix_", + expectingFunctionDeclaration: """ + func foo(value: T) -> U { + _prefix_CallsCount += 1 + _prefix_ReceivedValue = (value) + _prefix_ReceivedInvocations.append((value)) + if _prefix_Closure != nil { + return _prefix_Closure!(value) as! U + } else { + return _prefix_ReturnValue as! U + } + } + """ + ) + } + func testDeclarationReturnValueAsyncThrows() throws { try assertProtocolFunction( withFunctionDeclaration: """ diff --git a/Tests/SpyableMacroTests/Factories/UT_ReceivedArgumentsFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ReceivedArgumentsFactory.swift index 949e530..2537366 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ReceivedArgumentsFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ReceivedArgumentsFactory.swift @@ -31,6 +31,14 @@ final class UT_ReceivedArgumentsFactory: XCTestCase { ) } + func testVariableDeclarationSingleGenericArgument() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func foo(bar: T)", + prefixForVariable: "_prefix_", + expectingVariableDeclaration: "var _prefix_ReceivedBar: T?" + ) + } + func testVariableDeclarationSingleArgumentDoubleParameterName() throws { try assertProtocolFunction( withFunctionDeclaration: "func foo(firstName secondName: (String, Int))", diff --git a/Tests/SpyableMacroTests/Factories/UT_ReceivedInvocationsFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ReceivedInvocationsFactory.swift index a4bcb71..ca57026 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ReceivedInvocationsFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ReceivedInvocationsFactory.swift @@ -33,6 +33,14 @@ final class UT_ReceivedInvocationsFactory: XCTestCase { ) } + func testVariableDeclarationSingleGenericArgument() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func foo(bar: T)", + prefixForVariable: "_prefix_", + expectingVariableDeclaration: "var _prefix_ReceivedInvocations: [T] = []" + ) + } + func testVariableDeclarationSingleClosureArgument() throws { try assertProtocolFunction( withFunctionDeclaration: "func foo(completion: () -> Void)", @@ -71,6 +79,17 @@ final class UT_ReceivedInvocationsFactory: XCTestCase { ) } + func testVariableDeclarationMultiArgumentsWithSomeGenericArgument() throws { + try assertProtocolFunction( + withFunctionDeclaration: + "func foo(text: String, value: T, _ count: (x: Int, UInt?)?, final price: Decimal?)", + prefixForVariable: "_prefix_", + expectingVariableDeclaration: """ + var _prefix_ReceivedInvocations: [(text: String, value: T, count: (x: Int, UInt?)?, price: Decimal?)] = [] + """ + ) + } + func testVariableDeclarationMultiArgumentsWithSomeClosureArgument() throws { try assertProtocolFunction( withFunctionDeclaration: @@ -111,6 +130,14 @@ final class UT_ReceivedInvocationsFactory: XCTestCase { ) } + func testAppendValueToVariableExpressionSingleArgumentGenericType() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func foo(bar: T)", + prefixForVariable: "_prefix_", + expectingExpression: "_prefix_ReceivedInvocations.append((bar))" + ) + } + func testAppendValueToVariableExpressionMultiArguments() throws { try assertProtocolFunction( withFunctionDeclaration: diff --git a/Tests/SpyableMacroTests/Factories/UT_ReturnValueFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ReturnValueFactory.swift index cfdb5f4..3ba87c9 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ReturnValueFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ReturnValueFactory.swift @@ -54,6 +54,22 @@ final class UT_ReturnValueFactory: XCTestCase { ) } + func testReturnStatementWithForceCastType() { + let variablePrefix = "function_name" + + let result = ReturnValueFactory().returnStatement( + variablePrefix: variablePrefix, + forceCastType: "MyType" + ) + + assertBuildResult( + result, + """ + return function_nameReturnValue as! MyType + """ + ) + } + // MARK: - Helper Methods for Assertions private func assert( diff --git a/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift b/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift index a63d730..5194adf 100644 --- a/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift @@ -173,6 +173,45 @@ final class UT_SpyFactory: XCTestCase { ) } + func testDeclarationGenericArgument() throws { + let declaration = DeclSyntax( + """ + protocol ViewModelProtocol { + func foo(text: String, value: T) -> U + } + """ + ) + let protocolDeclaration = try XCTUnwrap(ProtocolDeclSyntax(declaration)) + + let result = try SpyFactory().classDeclaration(for: protocolDeclaration) + + assertBuildResult( + result, + """ + class ViewModelProtocolSpy: ViewModelProtocol { + var fooTextValueCallsCount = 0 + var fooTextValueCalled: Bool { + return fooTextValueCallsCount > 0 + } + var fooTextValueReceivedArguments: (text: String, value: Any)? + var fooTextValueReceivedInvocations: [(text: String, value: Any)] = [] + var fooTextValueReturnValue: Any! + var fooTextValueClosure: ((String, Any) -> Any)? + func foo(text: String, value: T) -> U { + fooTextValueCallsCount += 1 + fooTextValueReceivedArguments = (text, value) + fooTextValueReceivedInvocations.append((text, value)) + if fooTextValueClosure != nil { + return fooTextValueClosure!(text, value) as! U + } else { + return fooTextValueReturnValue as! U + } + } + } + """ + ) + } + func testDeclarationEscapingAutoClosureArgument() throws { try assertProtocol( withDeclaration: """ diff --git a/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift index 9fbabea..7e2a377 100644 --- a/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift +++ b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift @@ -37,6 +37,7 @@ final class UT_SpyableMacro: XCTestCase { func onTapBack(context: String, action: () -> Void) func onTapNext(context: String, action: @Sendable () -> Void) func assert(_ message: @autoclosure () -> String) + func useGenerics(values1: [T], values2: Array, values3: (T, U, Int)) } """ @@ -177,6 +178,19 @@ final class UT_SpyableMacro: XCTestCase { assertCallsCount += 1 assertClosure?(message()) } + var useGenericsValues1Values2Values3CallsCount = 0 + var useGenericsValues1Values2Values3Called: Bool { + return useGenericsValues1Values2Values3CallsCount > 0 + } + var useGenericsValues1Values2Values3ReceivedArguments: (values1: [Any], values2: Array, values3: (Any, Any, Int))? + var useGenericsValues1Values2Values3ReceivedInvocations: [(values1: [Any], values2: Array, values3: (Any, Any, Int))] = [] + var useGenericsValues1Values2Values3Closure: (([Any], Array, (Any, Any, Int)) -> Void)? + func useGenerics(values1: [T], values2: Array, values3: (T, U, Int)) { + useGenericsValues1Values2Values3CallsCount += 1 + useGenericsValues1Values2Values3ReceivedArguments = (values1, values2, values3) + useGenericsValues1Values2Values3ReceivedInvocations.append((values1, values2, values3)) + useGenericsValues1Values2Values3Closure?(values1, values2, values3) + } } """, macros: sut