Skip to content

Commit

Permalink
Preserve Swift errors through WinRT boundaries via ILanguageException…
Browse files Browse the repository at this point in the history
…ErrorInfo (#261)
  • Loading branch information
tristanlabelle authored Sep 4, 2024
1 parent 8211356 commit 8b3bc8b
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 116 deletions.
2 changes: 0 additions & 2 deletions InteropTests/Tests/ErrorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class ErrorTests: WinRTTestCase {
}

func testSwiftErrorPreserved() throws {
try XCTSkipIf(true, "TODO(#248): Fix preserving Swift error objects across the WinRT boundary.")

struct SwiftError: Error {}
do {
try Errors.call { throw SwiftError() }
Expand Down
25 changes: 25 additions & 0 deletions Support/Sources/COM/COMError+SwiftErrorInfo.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
extension COMError {
/// Wraps a Swift Error object into an `IErrorInfo` to preserve it across COM boundaries.
internal final class SwiftErrorInfo: COMPrimaryExport<IErrorInfoProjection>, IErrorInfoProtocol {
public let error: Error

public init(error: Error) {
self.error = error
}

public func toABI(setErrorInfo: Bool = true) -> HResult.Value {
if setErrorInfo { try? COMError.setErrorInfo(self) }
return hresult.value
}

public var hresult: HResult { (self.error as? ErrorWithHResult)?.hresult ?? HResult.fail }
public var message: String { String(describing: error) }

// IErrorInfo
public var guid: GUID { get throws { throw COMError.fail } }
public var source: String? { get throws { throw COMError.fail } }
public var description: String? { self.message }
public var helpFile: String? { get throws { throw COMError.fail } }
public var helpContext: UInt32 { get throws { throw COMError.fail } }
}
}
23 changes: 0 additions & 23 deletions Support/Sources/COM/SwiftErrorInfo.swift

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import COM
import WindowsRuntime_ABI

public typealias IActivationFactory = any IActivationFactoryProtocol
public protocol IActivationFactoryProtocol: IInspectableProtocol {
func activateInstance() throws -> IInspectable
}

import WindowsRuntime_ABI

public enum IActivationFactoryProjection: InterfaceProjection {
public typealias SwiftObject = IActivationFactory
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IActivationFactory
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import COM
import WindowsRuntime_ABI

public typealias IInspectable = any IInspectableProtocol
public protocol IInspectableProtocol: IUnknownProtocol {
Expand All @@ -8,6 +7,8 @@ public protocol IInspectableProtocol: IUnknownProtocol {
func getTrustLevel() throws -> TrustLevel
}

import WindowsRuntime_ABI

public enum IInspectableProjection: InterfaceProjection {
public typealias SwiftObject = IInspectable
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IInspectable
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import COM

public typealias ILanguageExceptionErrorInfo = any ILanguageExceptionErrorInfoProtocol
public protocol ILanguageExceptionErrorInfoProtocol: IUnknownProtocol {
var languageException: IUnknown { get throws }
}

import WindowsRuntime_ABI

public enum ILanguageExceptionErrorInfoProjection: COMTwoWayProjection {
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo
public typealias SwiftObject = ILanguageExceptionErrorInfo

public static var interfaceID: COMInterfaceID { uuidof(ABIStruct.self) }
public static var virtualTablePointer: UnsafeRawPointer { .init(withUnsafePointer(to: &virtualTable) { $0 }) }

public static func _wrap(_ reference: consuming ABIReference) -> SwiftObject {
Import(_wrapping: reference)
}

public static func toCOM(_ object: SwiftObject) throws -> ABIReference {
try Import.toCOM(object)
}

private final class Import: COMImport<ILanguageExceptionErrorInfoProjection>, ILanguageExceptionErrorInfoProtocol {
var languageException: IUnknown {
get throws { try NullResult.unwrap(_interop.getLanguageException()) }
}
}

private static var virtualTable: WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo_VirtualTable = .init(
QueryInterface: { IUnknownVirtualTable.QueryInterface($0, $1, $2) },
AddRef: { IUnknownVirtualTable.AddRef($0) },
Release: { IUnknownVirtualTable.Release($0) },
GetLanguageException: { this, languageException in _implement(this) { this in
guard let languageException else { throw COMError.fail }
languageException.pointee = try IUnknownProjection.toABI(this.languageException)
} })
}

public func uuidof(_: WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo.Type) -> COMInterfaceID {
.init(0x04a2dbf3, 0xdf83, 0x116c, 0x0946, 0x0812abf6e07d)
}

extension COMInterop where ABIStruct == WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo {
public func getLanguageException() throws -> IUnknown {
var result: IUnknownPointer? = nil // IUnknownProjection.abiDefaultValue (compiler bug?)
defer { IUnknownProjection.release(&result) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.GetLanguageException(this, &result))
return try NullResult.unwrap(IUnknownProjection.toSwift(consuming: &result))
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import WindowsRuntime_ABI
import COM

public typealias IRestrictedErrorInfo = any IRestrictedErrorInfoProtocol
public protocol IRestrictedErrorInfoProtocol: IUnknownProtocol {
Expand All @@ -10,6 +10,8 @@ public protocol IRestrictedErrorInfoProtocol: IUnknownProtocol {
var reference: String? { get throws }
}

import WindowsRuntime_ABI

public enum IRestrictedErrorInfoProjection: COMTwoWayProjection {
public typealias SwiftObject = IRestrictedErrorInfo
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IRestrictedErrorInfo
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import WindowsRuntime_ABI
import COM

public typealias IWeakReference = any IWeakReferenceProtocol
public protocol IWeakReferenceProtocol: IUnknownProtocol {
func resolve() throws -> IInspectable?
}

import WindowsRuntime_ABI

public enum IWeakReferenceProjection: COMTwoWayProjection {
public typealias SwiftObject = IWeakReference
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IWeakReference
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import WindowsRuntime_ABI
import COM

public typealias IWeakReferenceSource = any IWeakReferenceSourceProtocol
public protocol IWeakReferenceSourceProtocol: IUnknownProtocol {
func getWeakReference() throws -> IWeakReference
}

import WindowsRuntime_ABI

public enum IWeakReferenceSourceProjection: COMTwoWayProjection {
public typealias SwiftObject = IWeakReferenceSource
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IWeakReferenceSource
Expand Down
36 changes: 0 additions & 36 deletions Support/Sources/WindowsRuntime/SwiftRestrictedErrorInfo.swift

This file was deleted.

10 changes: 10 additions & 0 deletions Support/Sources/WindowsRuntime/WinRTError+LanguageException.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
extension WinRTError {
/// Wraps a Swift Error object so it can be associated with an `IRestrictedErrorInfo`.
internal final class LanguageException: COMPrimaryExport<IUnknownProjection> {
public let error: Error

public init(error: Error) {
self.error = error
}
}
}
81 changes: 31 additions & 50 deletions Support/Sources/WindowsRuntime/WinRTError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
}

public init(hresult: HResult, message: String?) {
self.init(hresult: hresult, errorInfo: message.map { MessageRestrictedErrorInfo(hresult: hresult, message: $0) })
let errorInfo = message.flatMap { try? Self.createRestrictedErrorInfo(hresult: hresult, message: $0) }
self.init(hresult: hresult, errorInfo: errorInfo)
}

public var errorInfo: IErrorInfo? {
Expand All @@ -38,12 +39,13 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
public static func fromABI(captureErrorInfo: Bool = true, _ hresult: WindowsRuntime_ABI.SWRT_HResult) throws -> HResult {
let hresult = HResult(hresult)
guard hresult.isFailure else { return hresult }
guard captureErrorInfo else { throw WinRTError(hresult: hresult) }
guard captureErrorInfo, let restrictedErrorInfo = try? Self.getRestrictedErrorInfo(matching: hresult) else {
throw WinRTError(hresult: hresult)
}

let restrictedErrorInfo = try? Self.getRestrictedErrorInfo(matching: hresult)
if let swiftErrorInfo = restrictedErrorInfo as? SwiftRestrictedErrorInfo, swiftErrorInfo.hresult == hresult {
// This was originally a Swift error, throw it as such.
throw swiftErrorInfo.error
if let languageExceptionErrorInfo = try? restrictedErrorInfo.queryInterface(ILanguageExceptionErrorInfoProjection.self),
let languageException = try? languageExceptionErrorInfo.languageException as? LanguageException {
throw languageException.error
}

throw WinRTError(hresult: hresult, errorInfo: restrictedErrorInfo)
Expand All @@ -61,15 +63,19 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
if let comError = error as? any COMErrorProtocol { return comError.toABI() }

// Otherwise, originate a new error
let restrictedErrorInfo = SwiftRestrictedErrorInfo(error: error)
if originate { restrictedErrorInfo.originate(captureContext: captureContext) }
return restrictedErrorInfo.hresult.value
let hresult = (error as? ErrorWithHResult)?.hresult ?? HResult.fail
if originate && Self.originate(hresult: hresult, message: String(describing: error), languageException: LanguageException(error: error)) {
if captureContext { try? Self.captureContext(hresult: hresult) }
}

return hresult.value
}

public static func toABI(hresult: HResult, message: String? = nil, captureContext: Bool = true) -> HResult.Value {
guard hresult.isFailure else { return hresult.value }
Self.originate(hresult: hresult, message: message)
if captureContext { try? Self.captureContext(hresult: hresult) }
if Self.originate(hresult: hresult, message: message) {
if captureContext { try? Self.captureContext(hresult: hresult) }
}
return hresult.value
}

Expand All @@ -81,12 +87,12 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
}

@discardableResult
public static func originate(hresult: HResult, message: String?, restrictedErrorInfo: IRestrictedErrorInfo?) -> Bool {
guard let restrictedErrorInfo else { return originate(hresult: hresult, message: message) }
public static func originate(hresult: HResult, message: String?, languageException: IUnknown?) -> Bool {
guard let languageException else { return originate(hresult: hresult, message: message) }

var message = message == nil ? nil : try? StringProjection.toABI(message!)
defer { StringProjection.release(&message) }
var iunknown = try? IUnknownProjection.toABI(restrictedErrorInfo)
var iunknown = try? IUnknownProjection.toABI(languageException)
defer { IUnknownProjection.release(&iunknown) }
return WindowsRuntime_ABI.SWRT_RoOriginateLanguageException(hresult.value, message, iunknown)
}
Expand All @@ -103,6 +109,17 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
SWRT_RoFailFastWithErrorContext(hresult.value)
}

public static func createRestrictedErrorInfo(hresult: HResult, message: String?, languageException: IUnknown? = nil) throws -> IRestrictedErrorInfo {
// From the SetRestrictedErrorInfo docs at https://learn.microsoft.com/en-us/windows/win32/api/roerrorapi/nf-roerrorapi-setrestrictederrorinfo:
// > The call fails if IRestrictedErrorInfo isn't the system implementation.
// > To create an IRestrictedErrorInfo object, call the OriginateError, TransformError, or RoCaptureErrorContext functions.
// But RoOriginateError overwrites the current thread error info object,
// so we need to manually save and restore it around the call to RoOriginateError.
let previousErrorInfo = try? COMError.getErrorInfo()
defer { try? COMError.setErrorInfo(previousErrorInfo) }
return try NullResult.unwrap(Self.originate(hresult: hresult, message: message) ? try? Self.getRestrictedErrorInfo() : nil)
}

public static func getRestrictedErrorInfo() throws -> IRestrictedErrorInfo? {
var restrictedErrorInfo: UnsafeMutablePointer<SWRT_IRestrictedErrorInfo>?
defer { IRestrictedErrorInfoProjection.release(&restrictedErrorInfo) }
Expand All @@ -122,40 +139,4 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
defer { IRestrictedErrorInfoProjection.release(&abiValue) }
try fromABI(captureErrorInfo: false, WindowsRuntime_ABI.SWRT_SetRestrictedErrorInfo(abiValue))
}

private final class MessageRestrictedErrorInfo: COMPrimaryExport<IRestrictedErrorInfoProjection>,
IRestrictedErrorInfoProtocol, IErrorInfoProtocol {
public override class var implements: [COMImplements] { [
.init(IErrorInfoProjection.self)
] }

public let hresult: HResult
public let message: String

public init(hresult: HResult, message: String) {
self.hresult = hresult
self.message = message
}

// IErrorInfo
public var guid: GUID { get throws { throw COMError.fail } }
public var source: String? { get throws { throw COMError.fail } }
public var description: String? { self.message }
public var helpFile: String? { get throws { throw COMError.fail } }
public var helpContext: UInt32 { get throws { throw COMError.fail } }

// IRestrictedErrorInfo
public func getErrorDetails(
description: inout String?,
error: inout HResult,
restrictedDescription: inout String?,
capabilitySid: inout String?) throws {
description = self.message
error = self.hresult
restrictedDescription = self.message
capabilitySid = nil
}

public var reference: String? { get throws { throw COMError.fail } }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
#include "SWRT/windows/oleauto.h"
#include "SWRT/windows/unknwn.h"

typedef struct SWRT_ILanguageExceptionErrorInfo {
struct SWRT_ILanguageExceptionErrorInfo_VirtualTable* VirtualTable;
} SWRT_ILanguageExceptionErrorInfo;

struct SWRT_ILanguageExceptionErrorInfo_VirtualTable {
SWRT_HResult (__stdcall *QueryInterface)(SWRT_ILanguageExceptionErrorInfo* _this, SWRT_Guid* riid, void** ppvObject);
uint32_t (__stdcall *AddRef)(SWRT_ILanguageExceptionErrorInfo* _this);
uint32_t (__stdcall *Release)(SWRT_ILanguageExceptionErrorInfo* _this);
SWRT_HResult (__stdcall *GetLanguageException)(SWRT_ILanguageExceptionErrorInfo* _this, SWRT_IUnknown** languageException);
};

typedef struct SWRT_IRestrictedErrorInfo {
struct SWRT_IRestrictedErrorInfo_VirtualTable* VirtualTable;
} SWRT_IRestrictedErrorInfo;
Expand Down

0 comments on commit 8b3bc8b

Please sign in to comment.