Skip to content

Commit

Permalink
Fixed support for derived classes defined in WinRT. (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanlabelle authored Apr 2, 2024
1 parent e1f4746 commit 298c5dc
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 146 deletions.
6 changes: 6 additions & 0 deletions Generator/Sources/ProjectionModel/SupportModules.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ extension SupportModules.WinRT {
public static var winRTDelegateProjection: SwiftType { .chain(moduleName, "WinRTDelegateProjection") }
public static var winRTClassProjection: SwiftType { .chain(moduleName, "WinRTClassProjection") }

public static var winRTComposableClass: SwiftType { .chain(moduleName, "WinRTComposableClass") }

public static func winRTImport(of type: SwiftType) -> SwiftType {
.chain([ .init(moduleName), .init("WinRTImport", genericArgs: [type]) ])
}

public static func winRTArrayProjection(of type: SwiftType) -> SwiftType {
.chain([ .init(moduleName), .init("WinRTArrayProjection", genericArgs: [type]) ])
}
Expand Down
20 changes: 12 additions & 8 deletions Generator/Sources/ProjectionModel/SwiftProjection+types.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,7 @@ extension SwiftProjection {
public func toType(_ type: TypeNode) throws -> SwiftType {
switch type {
case let .bound(type):
if let specialTypeProjection = try getSpecialTypeProjection(type) {
return specialTypeProjection.swiftType
}

let swiftObjectType = SwiftType.identifier(
name: try toTypeName(type.definition),
genericArgs: try type.genericArgs.map { try toType($0) })
return type.definition.isValueType ? swiftObjectType : .optional(wrapped: swiftObjectType)
return try toType(type)
case let .genericParam(param):
return .identifier(param.name)
case let .array(of: element):
Expand All @@ -23,6 +16,17 @@ extension SwiftProjection {
}
}

public func toType(_ boundType: BoundType, nullable: Bool = true) throws -> SwiftType {
if let specialTypeProjection = try getSpecialTypeProjection(boundType) {
return specialTypeProjection.swiftType
}

let swiftObjectType = SwiftType.identifier(
name: try toTypeName(boundType.definition),
genericArgs: try boundType.genericArgs.map { try toType($0) })
return boundType.definition.isReferenceType && nullable ? .optional(wrapped: swiftObjectType) : swiftObjectType
}

public func isProjectionInert(_ typeDefinition: TypeDefinition) throws -> Bool {
switch typeDefinition {
case is InterfaceDefinition, is DelegateDefinition, is ClassDefinition: return false
Expand Down
125 changes: 82 additions & 43 deletions Generator/Sources/SwiftWinRT/Writing/ClassDefinition.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@ import WindowsMetadata
import struct Foundation.UUID

internal func writeClassDefinition(_ classDefinition: ClassDefinition, projection: SwiftProjection, to writer: SwiftSourceFileWriter) throws {
let composable = try classDefinition.hasAttribute(ComposableAttribute.self)
let interfaces = try ClassInterfaces(of: classDefinition, composable: composable)
let classKind = try ClassKind(classDefinition)
let interfaces = try ClassInterfaces(of: classDefinition, kind: classKind)
let typeName = try projection.toTypeName(classDefinition)

if interfaces.default != nil {
if classKind != .static {
let projectionTypeName = try projection.toProjectionTypeName(classDefinition)
assert(classDefinition.isSealed || composable)
assert(!classDefinition.isAbstract || composable)

// Write the Swift class definition
let base: SwiftType = .chain(.init("WindowsRuntime"), composable
? .init("WinRTComposableClass")
: .init("WinRTImport", genericArgs: [ .identifier(projectionTypeName) ]))
assert(classDefinition.isSealed || classKind.isComposable)
assert(!classDefinition.isAbstract || classKind.isComposable)

let base: SwiftType
switch classKind {
case .composable(base: .some(let baseClassDefinition)):
base = try projection.toType(baseClassDefinition.bindType(), nullable: false)
case .composable(base: nil):
base = SupportModules.WinRT.winRTComposableClass
default:
base = SupportModules.WinRT.winRTImport(of: .identifier(projectionTypeName))
}

var protocolConformances: [SwiftType] = []
for baseInterface in classDefinition.baseInterfaces {
Expand All @@ -32,7 +37,7 @@ internal func writeClassDefinition(_ classDefinition: ClassDefinition, projectio
visibility: SwiftProjection.toVisibility(classDefinition.visibility, inheritableClass: !classDefinition.isSealed),
final: classDefinition.isSealed, name: typeName, base: base, protocolConformances: protocolConformances) { writer in
try writeClassMembers(
classDefinition, interfaces: interfaces, composable: composable,
classDefinition, interfaces: interfaces, kind: classKind,
projection: projection, to: writer)
}
}
Expand All @@ -45,12 +50,40 @@ internal func writeClassDefinition(_ classDefinition: ClassDefinition, projectio
visibility: SwiftProjection.toVisibility(classDefinition.visibility),
name: typeName) { writer in
try writeClassMembers(
classDefinition, interfaces: interfaces, composable: false,
classDefinition, interfaces: interfaces, kind: .static,
projection: projection, to: writer)
}
}
}

fileprivate enum ClassKind: Equatable {
case activatable
case composable(base: ClassDefinition?)
case `static`

init(_ classDefinition: ClassDefinition) throws {
if classDefinition.isStatic {
self = .static
} else if try classDefinition.hasAttribute(ComposableAttribute.self) {
if let baseClassDefinition = try classDefinition.base?.definition as? ClassDefinition,
try baseClassDefinition != classDefinition.context.coreLibrary.systemObject {
self = .composable(base: baseClassDefinition)
} else {
self = .composable(base: nil)
}
} else {
self = .activatable
}
}

public var isComposable: Bool {
switch self {
case .composable: return true
default: return false
}
}
}

fileprivate struct ClassInterfaces {
var hasDefaultFactory = false
var factories: [InterfaceDefinition] = []
Expand All @@ -64,25 +97,30 @@ fileprivate struct ClassInterfaces {
var protected: Bool
}

public init(of classDefinition: ClassDefinition, composable: Bool) throws {
if composable {
factories = try classDefinition.getAttributes(ComposableAttribute.self).map { $0.factory }
hasDefaultFactory = false
}
else {
let activatableAttributes = try classDefinition.getAttributes(ActivatableAttribute.self)
factories = activatableAttributes.compactMap { $0.factory }
hasDefaultFactory = activatableAttributes.count > factories.count
public init(of classDefinition: ClassDefinition, kind: ClassKind) throws {
switch kind {
case .activatable:
let activatableAttributes = try classDefinition.getAttributes(ActivatableAttribute.self)
factories = activatableAttributes.compactMap { $0.factory }
hasDefaultFactory = activatableAttributes.count > factories.count
case .composable:
factories = try classDefinition.getAttributes(ComposableAttribute.self).map { $0.factory }
hasDefaultFactory = false
case .static:
factories = []
hasDefaultFactory = false
}

for baseInterface in classDefinition.baseInterfaces {
if try baseInterface.hasAttribute(DefaultAttribute.self) {
`default` = try baseInterface.interface
}
else {
let overridable = try baseInterface.hasAttribute(OverridableAttribute.self)
let protected = try baseInterface.hasAttribute(ProtectedAttribute.self)
secondary.append(Secondary(interface: try baseInterface.interface, overridable: overridable, protected: protected))
if kind != .static {
for baseInterface in classDefinition.baseInterfaces {
if try baseInterface.hasAttribute(DefaultAttribute.self) {
`default` = try baseInterface.interface
}
else {
let overridable = try baseInterface.hasAttribute(OverridableAttribute.self)
let protected = try baseInterface.hasAttribute(ProtectedAttribute.self)
secondary.append(Secondary(interface: try baseInterface.interface, overridable: overridable, protected: protected))
}
}
}

Expand All @@ -91,20 +129,20 @@ fileprivate struct ClassInterfaces {
}

fileprivate func writeClassMembers(
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, composable: Bool,
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, kind: ClassKind,
projection: SwiftProjection, to writer: SwiftTypeDefinitionWriter) throws {
try writeGenericTypeAliases(interfaces: classDefinition.baseInterfaces.map { try $0.interface }, projection: projection, to: writer)

try writeClassInterfaceImplementations(
classDefinition, interfaces: interfaces, composable: composable,
classDefinition, interfaces: interfaces, kind: kind,
projection: projection, to: writer)

writer.writeMarkComment("Implementation details")
try writeClassInterfaceProperties(
classDefinition, interfaces: interfaces, composable: composable,
classDefinition, interfaces: interfaces, kind: kind,
projection: projection, to: writer)

if composable {
if kind.isComposable {
let overridableInterfaces = interfaces.secondary.compactMap { $0.overridable ? $0.interface : nil }
if !overridableInterfaces.isEmpty {
writer.writeMarkComment("Override support")
Expand All @@ -114,7 +152,7 @@ fileprivate func writeClassMembers(
}

fileprivate func writeClassInterfaceImplementations(
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, composable: Bool,
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, kind: ClassKind,
projection: SwiftProjection, to writer: SwiftTypeDefinitionWriter) throws {
if interfaces.hasDefaultFactory {
writeMarkComment(forInterface: "IActivationFactory", to: writer)
Expand All @@ -124,8 +162,8 @@ fileprivate func writeClassInterfaceImplementations(
for factoryInterface in interfaces.factories {
if factoryInterface.methods.isEmpty { continue }
try writeMarkComment(forInterface: factoryInterface.bind(), to: writer)
if composable {
try writeComposableInitializers(classDefinition, factoryInterface: factoryInterface, projection: projection, to: writer)
if case .composable(base: let base) = kind {
try writeComposableInitializers(classDefinition, factoryInterface: factoryInterface, base: base, projection: projection, to: writer)
}
else {
try writeActivatableInitializers(classDefinition, activationFactory: factoryInterface, projection: projection, to: writer)
Expand All @@ -134,7 +172,7 @@ fileprivate func writeClassInterfaceImplementations(

if let defaultInterface = interfaces.default, !defaultInterface.definition.methods.isEmpty {
try writeMarkComment(forInterface: defaultInterface, to: writer)
let thisPointer: ThisPointer = composable
let thisPointer: ThisPointer = kind.isComposable
? .init(name: SecondaryInterfaces.getPropertyName(defaultInterface), lazy: true)
: .init(name: "_interop")
try writeInterfaceImplementation(
Expand Down Expand Up @@ -164,18 +202,18 @@ fileprivate func writeClassInterfaceImplementations(
}

fileprivate func writeClassInterfaceProperties(
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, composable: Bool,
_ classDefinition: ClassDefinition, interfaces: ClassInterfaces, kind: ClassKind,
projection: SwiftProjection, to writer: SwiftTypeDefinitionWriter) throws {
// Instance properties, initializers and deinit
if composable, let defaultInterface = interfaces.default {
try SecondaryInterfaces.writeDeclaration(defaultInterface, composable: composable, projection: projection, to: writer)
if kind.isComposable, let defaultInterface = interfaces.default {
try SecondaryInterfaces.writeDeclaration(defaultInterface, composable: kind.isComposable, projection: projection, to: writer)
}

for secondaryInterface in interfaces.secondary {
try SecondaryInterfaces.writeDeclaration(secondaryInterface.interface, composable: composable, projection: projection, to: writer)
try SecondaryInterfaces.writeDeclaration(secondaryInterface.interface, composable: kind.isComposable, projection: projection, to: writer)
}

if composable, let defaultInterface = interfaces.default {
if kind.isComposable, let defaultInterface = interfaces.default {
try writeSupportComposableInitializers(defaultInterface: defaultInterface, projection: projection, to: writer)
}

Expand Down Expand Up @@ -249,7 +287,7 @@ fileprivate func writeMarkComment(forInterface interfaceName: String, to writer:
}

fileprivate func writeComposableInitializers(
_ classDefinition: ClassDefinition, factoryInterface: InterfaceDefinition,
_ classDefinition: ClassDefinition, factoryInterface: InterfaceDefinition, base: ClassDefinition?,
projection: SwiftProjection, to writer: SwiftTypeDefinitionWriter) throws {
let propertyName = SecondaryInterfaces.getPropertyName(factoryInterface.bind())

Expand All @@ -259,6 +297,7 @@ fileprivate func writeComposableInitializers(
try writer.writeInit(
documentation: try projection.getDocumentationComment(abiMember: method, classDefinition: classDefinition),
visibility: .public,
override: params.count == 2 && base != nil, // Hack: assume all base classes have a default initializer we are overriding
params: params.dropLast(2).map { $0.toSwiftParam() },
throws: true) { writer in
let output = writer.output
Expand Down
50 changes: 50 additions & 0 deletions InteropTests/Tests/ClassInheritanceTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import COM
import XCTest
import WinRTComponent

class ClassInheritanceTests : XCTestCase {
public func testOverridableMemberOnBaseClass() throws {
// Created from Swift
XCTAssertEqual(try MinimalBaseClass()._typeName(), "MinimalBaseClass")
XCTAssertEqual(try MinimalBaseClass.getTypeName(MinimalBaseClass()), "MinimalBaseClass")

// Created from WinRT
XCTAssertEqual(try MinimalBaseClass.createBase()._typeName(), "MinimalBaseClass")
XCTAssertEqual(try MinimalBaseClass.getTypeName(MinimalBaseClass.createBase()), "MinimalBaseClass")
}

public func testOverridenMemberInWinRTDerivedClass() throws {
// Created from Swift
XCTAssertEqual(try MinimalDerivedClass()._typeName(), "MinimalDerivedClass")
XCTAssertEqual(try MinimalBaseClass.getTypeName(MinimalDerivedClass()), "MinimalDerivedClass")

// Created from WinRT
XCTAssertEqual(try MinimalDerivedClass.createDerived()._typeName(), "MinimalDerivedClass")
XCTAssertEqual(try MinimalBaseClass.getTypeName(MinimalDerivedClass.createDerived()), "MinimalDerivedClass")
}

public func testOverridenMemberInWinRTPrivateClass() throws {
XCTAssertEqual(try MinimalBaseClass.createPrivate()._typeName(), "PrivateClass")
XCTAssertEqual(try MinimalBaseClass.getTypeName(MinimalBaseClass.createPrivate()), "PrivateClass")
}

public func testOverridenMemberInSwiftClass() throws {
class SwiftDerived: MinimalBaseClass {
public override init() throws { try super.init() }
public override func _typeName() throws -> String { "SwiftDerived" }
}

XCTAssertEqual(try SwiftDerived()._typeName(), "SwiftDerived")
XCTAssertEqual(try MinimalBaseClass.getTypeName(SwiftDerived()), "SwiftDerived")
}

public func testOverridenMemberInSwiftClassDerivedFromWinRTDerivedClass() throws {
class SwiftDerived2: MinimalDerivedClass {
public override init() throws { try super.init() }
public override func _typeName() throws -> String { "SwiftDerived2" }
}

XCTAssertEqual(try SwiftDerived2()._typeName(), "SwiftDerived2")
XCTAssertEqual(try MinimalBaseClass.getTypeName(SwiftDerived2()), "SwiftDerived2")
}
}
22 changes: 0 additions & 22 deletions InteropTests/Tests/CompositionTests.swift

This file was deleted.

32 changes: 0 additions & 32 deletions InteropTests/WinRTComponent/MinimalUnsealedClass.cpp

This file was deleted.

Loading

0 comments on commit 298c5dc

Please sign in to comment.