Skip to content

Commit

Permalink
Implemented Swift object unwrapping.
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanlabelle committed Dec 21, 2023
1 parent 85d3ab8 commit 8454295
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 117 deletions.
1 change: 0 additions & 1 deletion InteropTests/Tests/ObjectExportingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class ObjectExportingTests: WinRTTestCase {
}

func testUnwrapping() throws {
try XCTSkipIf(true, "TODO: Implement unwrapping of exported Swift objects")
let obj: IInspectable = ExportedClass()
let returnArgument = try XCTUnwrap(ReturnArgument.create())
let roundtripped = try XCTUnwrap(returnArgument.object(obj))
Expand Down
162 changes: 105 additions & 57 deletions Support/Sources/COM/COMExportedObject.swift
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import CWinRTCore

public protocol COMExportedObjectProtocol: IUnknownProtocol {
var unknown: IUnknownPointer { get }
var anyImplementation: Any { get }
var identity: any COMExportedObjectProtocol { get }
var queriableInterfaces: [COMExportInterface] { get }
}

public struct COMExportInterface {
public let id: COMInterfaceID
public let queryPointer: (_ identity: any COMExportedObjectProtocol) throws -> IUnknownPointer
public let queryPointer: (_ identity: COMExportedObjectCore) throws -> IUnknownPointer

public init<TargetProjection: COMTwoWayProjection>(_: TargetProjection.Type) {
self.id = TargetProjection.id
Expand All @@ -22,48 +15,31 @@ public struct COMExportInterface {
}
}

open class COMExportedObject<Projection: COMTwoWayProjection>: COMExportedObjectProtocol, IUnknownProtocol {
private struct COMInterface {
/// Virtual function table called by COM
public let virtualTablePointer: Projection.COMVirtualTablePointer = Projection.virtualTablePointer
public var object: Unmanaged<COMExportedObject<Projection>>! = nil
/// Provides an object layout that can be passed as a pointer to COM consumers, with a leading virtual table pointer.
open class COMExportedObjectCore: IUnknownProtocol {
/// Identifies that a COM object is an instance of this class.
fileprivate static let markerInterfaceId = COMInterfaceID(0x33934271, 0x7009, 0x4EF3, 0x90F1, 0x02090D7EBD64)

fileprivate struct COMInterface {
public let virtualTable: UnsafeRawPointer
public var this: Unmanaged<COMExportedObjectCore>!
}

private enum IdentityData {
fileprivate enum IdentityData {
case own(queriableInterfaces: [COMExportInterface])
case foreign(any COMExportedObjectProtocol)
case foreign(COMExportedObjectCore)
}

private var comInterface: COMInterface
private let identityData: IdentityData
public let implementation: Projection.SwiftObject
public var anyImplementation: Any { implementation }

public init(implementation: Projection.SwiftObject, queriableInterfaces: [COMExportInterface]) {
self.comInterface = COMInterface()
self.identityData = .own(queriableInterfaces: queriableInterfaces)
self.implementation = implementation
self.comInterface.object = Unmanaged.passUnretained(self)
fileprivate init(virtualTable: UnsafeRawPointer, identityData: IdentityData) {
self.comInterface = COMInterface(virtualTable: virtualTable, this: nil)
self.identityData = identityData
self.comInterface.this = Unmanaged.passUnretained(self)
}

fileprivate init(implementation: Projection.SwiftObject, identity: any COMExportedObjectProtocol) {
self.comInterface = COMInterface()
self.identityData = .foreign(identity)
self.implementation = implementation
self.comInterface.object = Unmanaged.passUnretained(self)
}

public var pointer: Projection.COMPointer {
withUnsafeMutablePointer(to: &comInterface) {
$0.withMemoryRebound(to: Projection.COMInterface.self, capacity: 1) { $0 }
}
}

public var unknown: IUnknownPointer {
IUnknownPointer.cast(pointer)
}

public var identity: any COMExportedObjectProtocol {
public var identity: COMExportedObjectCore {
switch identityData {
case .own: self
case .foreign(let other): other
Expand All @@ -77,12 +53,19 @@ open class COMExportedObject<Projection: COMTwoWayProjection>: COMExportedObject
}
}

public var unknown: IUnknownPointer {
withUnsafeMutablePointer(to: &comInterface) {
IUnknownPointer.cast($0)
}
}

// Overriden in derived class
public var anyImplementation: Any { fatalError() }

open func _queryInterfacePointer(_ id: COMInterfaceID) throws -> IUnknownPointer {
if id == Projection.id { return unknown.addingRef() }

switch identityData {
case .own(let queriableInterfaces):
if id == IUnknownProjection.id { return unknown.addingRef() }
if id == IUnknownProjection.id || id == Self.markerInterfaceId { return unknown.addingRef() }
guard let interface = queriableInterfaces.first(where: { $0.id == id }) else {
throw HResult.Error.noInterface
}
Expand All @@ -93,32 +76,97 @@ open class COMExportedObject<Projection: COMTwoWayProjection>: COMExportedObject
}
}

private static func cast(_ this: Projection.COMPointer) -> UnsafeMutablePointer<COMInterface> {
this.withMemoryRebound(to: COMInterface.self, capacity: 1) { $0 }
private static func toUnmanagedUnsafe(_ this: IUnknownPointer) -> Unmanaged<COMExportedObjectCore> {
this.withMemoryRebound(to: COMInterface.self, capacity: 1) { $0.pointee.this }
}

internal static func castUnsafe(_ this: IUnknownPointer) -> COMExportedObjectCore {
toUnmanagedUnsafe(this).takeUnretainedValue()
}

internal static func from(_ this: Projection.COMPointer) -> COMExportedObject<Projection> {
cast(this).pointee.object.takeUnretainedValue()
internal static func unwrapUnsafe(_ this: IUnknownPointer) -> Any {
castUnsafe(this).anyImplementation
}

@discardableResult
internal static func addRef(_ this: Projection.COMPointer) -> UInt32 {
let this = cast(this)
_ = this.pointee.object.retain()
internal static func addRefUnsafe(_ this: IUnknownPointer) -> UInt32 {
let unmanaged = toUnmanagedUnsafe(this)
_ = unmanaged.retain()
// Best effort refcount
return UInt32(_getRetainCount(this.pointee.object.takeUnretainedValue()))
return UInt32(_getRetainCount(unmanaged.takeUnretainedValue()))
}

@discardableResult
internal static func release(_ this: Projection.COMPointer) -> UInt32 {
let this = cast(this)
let oldRetainCount = _getRetainCount(this.pointee.object.takeUnretainedValue())
this.pointee.object.release()
internal static func releaseUnsafe(_ this: IUnknownPointer) -> UInt32 {
let unmanaged = toUnmanagedUnsafe(this)
let oldRetainCount = _getRetainCount(unmanaged.takeUnretainedValue())
unmanaged.release()
// Best effort refcount
return UInt32(oldRetainCount - 1)
}

internal static func queryInterface(_ this: Projection.COMPointer, _ id: COMInterfaceID) throws -> IUnknownPointer {
try cast(this).pointee.object.takeUnretainedValue()._queryInterfacePointer(id)
internal static func queryInterfaceUnsafe(_ this: IUnknownPointer, _ id: COMInterfaceID) throws -> IUnknownPointer {
try castUnsafe(this)._queryInterfacePointer(id)
}

public static func unwrap(_ this: IUnknownPointer) -> Any? {
// Use the marker interface to test if this is a COMExportedObject
guard let result = try? this.queryInterface(Self.markerInterfaceId) else { return nil }
result.release()
return unwrapUnsafe(this)
}
}

open class COMExportedObject<Projection: COMTwoWayProjection>: COMExportedObjectCore {
public let implementation: Projection.SwiftObject
public override var anyImplementation: Any { implementation }

public init(implementation: Projection.SwiftObject, queriableInterfaces: [COMExportInterface]) {
self.implementation = implementation
super.init(
virtualTable: Projection.virtualTablePointer,
identityData: .own(queriableInterfaces: queriableInterfaces))
}

fileprivate init(implementation: Projection.SwiftObject, identity: COMExportedObjectCore) {
self.implementation = implementation
super.init(
virtualTable: Projection.virtualTablePointer,
identityData: .foreign(identity))
}

public var pointer: Projection.COMPointer {
unknown.withMemoryRebound(to: Projection.COMInterface.self, capacity: 1) { $0 }
}

open override func _queryInterfacePointer(_ id: COMInterfaceID) throws -> IUnknownPointer {
if id == Projection.id { return unknown.addingRef() }
return try super._queryInterfacePointer(id)
}

internal static func castUnsafe(_ this: Projection.COMPointer) -> Self {
COMExportedObjectCore.castUnsafe(IUnknownPointer.cast(this)) as! Self
}

internal static func unwrapUnsafe(_ this: Projection.COMPointer) -> Projection.SwiftObject {
castUnsafe(this).implementation
}

@discardableResult
internal static func addRefUnsafe(_ this: Projection.COMPointer) -> UInt32 {
COMExportedObjectCore.addRefUnsafe(IUnknownPointer.cast(this))
}

@discardableResult
internal static func releaseUnsafe(_ this: Projection.COMPointer) -> UInt32 {
COMExportedObjectCore.releaseUnsafe(IUnknownPointer.cast(this))
}

internal static func queryInterfaceUnsafe(_ this: Projection.COMPointer, _ id: COMInterfaceID) throws -> IUnknownPointer {
try COMExportedObjectCore.queryInterfaceUnsafe(IUnknownPointer.cast(this), id)
}

public static func unwrap(_ this: Projection.COMPointer) -> Projection.SwiftObject? {
COMExportedObjectCore.unwrap(IUnknownPointer.cast(this)) as? Projection.SwiftObject
}
}
8 changes: 4 additions & 4 deletions Support/Sources/COM/COMImport+TwoWay.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import CWinRTCore

extension COMImport where Projection: COMTwoWayProjection {
public static func _getImplementation(_ pointer: Projection.COMPointer) -> Projection.SwiftObject {
COMExportedObject<Projection>.from(pointer).implementation
COMExportedObject<Projection>.unwrapUnsafe(pointer)
}

public static func _getImplementation(_ pointer: Projection.COMPointer?) -> Projection.SwiftObject? {
Expand Down Expand Up @@ -39,7 +39,7 @@ extension COMImport where Projection: COMTwoWayProjection {

return HResult.catchValue {
let id = GUIDProjection.toSwift(iid.pointee)
let unknownWithRef = try COMExportedObject<Projection>.queryInterface(this, id)
let unknownWithRef = try COMExportedObject<Projection>.queryInterfaceUnsafe(this, id)
ppvObject.pointee = UnsafeMutableRawPointer(unknownWithRef)
}
}
Expand All @@ -49,14 +49,14 @@ extension COMImport where Projection: COMTwoWayProjection {
assertionFailure("COM this pointer was null")
return 1
}
return COMExportedObject<Projection>.addRef(this)
return COMExportedObject<Projection>.addRefUnsafe(this)
}

public static func _release(_ this: Projection.COMPointer?) -> UInt32 {
guard let this else {
assertionFailure("COM this pointer was null")
return 0
}
return COMExportedObject<Projection>.release(this)
return COMExportedObject<Projection>.releaseUnsafe(this)
}
}
7 changes: 6 additions & 1 deletion Support/Sources/COM/COMProjection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ extension COMProjection {
public static func toSwift<Implementation: COMImport<Self>>(
transferringRef comPointer: COMPointer,
implementation: Implementation.Type) -> SwiftObject {
Implementation(transferringRef: comPointer).swiftObject
if let unwrappedAny = COMExportedObjectCore.unwrap(IUnknownPointer.cast(comPointer)),
let unwrapped = unwrappedAny as? SwiftObject {
IUnknownPointer.release(comPointer)
return unwrapped
}
return Implementation(transferringRef: comPointer).swiftObject
}

public static func toCOM<Implementation: COMImport<Self>>(
Expand Down
1 change: 0 additions & 1 deletion Support/Sources/CWinRTCore/include/CWinRTCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

#include "WinRT.h"
#include "IBufferByteAccess.h"
#include "ISwiftObject.h"

#include "Functions.h"
16 changes: 0 additions & 16 deletions Support/Sources/CWinRTCore/include/ISwiftObject.h

This file was deleted.

8 changes: 4 additions & 4 deletions Support/Sources/CWinRTCore/include/WinRT.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct SWRT_IPropertyValueVTable {
SWRT_HResult (__stdcall *GetUInt64)(SWRT_IPropertyValue* _this, uint64_t* value);
SWRT_HResult (__stdcall *GetSingle)(SWRT_IPropertyValue* _this, float* value);
SWRT_HResult (__stdcall *GetDouble)(SWRT_IPropertyValue* _this, double* value);
SWRT_HResult (__stdcall *GetChar16)(SWRT_IPropertyValue* _this, uint16_t* value);
SWRT_HResult (__stdcall *GetChar16)(SWRT_IPropertyValue* _this, char16_t* value);
SWRT_HResult (__stdcall *GetBoolean)(SWRT_IPropertyValue* _this, bool* value);
SWRT_HResult (__stdcall *GetString)(SWRT_IPropertyValue* _this, SWRT_HString* value);
SWRT_HResult (__stdcall *GetGuid)(SWRT_IPropertyValue* _this, SWRT_Guid* value);
Expand All @@ -116,7 +116,7 @@ struct SWRT_IPropertyValueVTable {
SWRT_HResult (__stdcall *GetUInt64Array)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, uint64_t** value);
SWRT_HResult (__stdcall *GetSingleArray)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, float** value);
SWRT_HResult (__stdcall *GetDoubleArray)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, double** value);
SWRT_HResult (__stdcall *GetChar16Array)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, uint16_t** value);
SWRT_HResult (__stdcall *GetChar16Array)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, char16_t** value);
SWRT_HResult (__stdcall *GetBooleanArray)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, bool** value);
SWRT_HResult (__stdcall *GetStringArray)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, SWRT_HString** value);
SWRT_HResult (__stdcall *GetInspectableArray)(SWRT_IPropertyValue* _this, uint32_t* __valueSize, SWRT_IInspectable*** value);
Expand Down Expand Up @@ -150,7 +150,7 @@ struct SWRT_IPropertyValueStaticsVTable {
SWRT_HResult (__stdcall *CreateUInt64)(SWRT_IPropertyValueStatics* _this, uint64_t value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateSingle)(SWRT_IPropertyValueStatics* _this, float value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateDouble)(SWRT_IPropertyValueStatics* _this, double value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateChar16)(SWRT_IPropertyValueStatics* _this, uint16_t value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateChar16)(SWRT_IPropertyValueStatics* _this, char16_t value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateBoolean)(SWRT_IPropertyValueStatics* _this, bool value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateString)(SWRT_IPropertyValueStatics* _this, SWRT_HString value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateInspectable)(SWRT_IPropertyValueStatics* _this, SWRT_IInspectable* value, SWRT_IInspectable** propertyValue);
Expand All @@ -169,7 +169,7 @@ struct SWRT_IPropertyValueStaticsVTable {
SWRT_HResult (__stdcall *CreateUInt64Array)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, uint64_t* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateSingleArray)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, float* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateDoubleArray)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, double* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateChar16Array)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, uint16_t* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateChar16Array)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, char16_t* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateBooleanArray)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, bool* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateStringArray)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, SWRT_HString* value, SWRT_IInspectable** propertyValue);
SWRT_HResult (__stdcall *CreateInspectableArray)(SWRT_IPropertyValueStatics* _this, uint32_t __valueSize, SWRT_IInspectable** value, SWRT_IInspectable** propertyValue);
Expand Down
26 changes: 14 additions & 12 deletions Support/Sources/WindowsRuntime/WinRTExport.swift
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import COM

/// Base for classes exported to WinRT.
open class WinRTExportedObject<Projection: WinRTTwoWayProjection>
: COMExportedObject<Projection>, IInspectableProtocol {
public override func _queryInterfacePointer(_ id: COMInterfaceID) throws -> IUnknownPointer {
return id == IInspectableProjection.id
? identity.unknown.addingRef()
: try super._queryInterfacePointer(id)
}
open class WinRTExport<Projection: WinRTTwoWayProjection>
: COMExport<Projection>, IInspectableProtocol {
open class var _runtimeClassName: String { String(describing: Self.self) }
open class var _trustLevel: TrustLevel { .base }

public final func getIids() throws -> [COMInterfaceID] { queriableInterfaces.map { $0.id } }
open func getRuntimeClassName() throws -> String { try (implementation as! IInspectable).getRuntimeClassName() }
open func getTrustLevel() throws -> TrustLevel { try (implementation as! IInspectable).getTrustLevel() }
}
public override func _createCOMObject() -> COMExportedObject<Projection> {
WinRTExportedObject<Projection>(
implementation: self as! Projection.SwiftObject,
queriableInterfaces: Self.queriableInterfaces)
}

public final func getIids() throws -> [COMInterfaceID] { Self.queriableInterfaces.map { $0.id } }
public final func getRuntimeClassName() throws -> String { Self._runtimeClassName }
public final func getTrustLevel() throws -> TrustLevel { Self._trustLevel }
}
17 changes: 0 additions & 17 deletions Support/Sources/WindowsRuntime/WinRTExportBase.swift

This file was deleted.

Loading

0 comments on commit 8454295

Please sign in to comment.