Skip to content

Commit

Permalink
Support ILanguageExceptionErrorInfo2 for extra context capture (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanlabelle authored Sep 4, 2024
1 parent 9cd34cb commit b871cf3
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 48 deletions.
4 changes: 1 addition & 3 deletions Support/Sources/COM/COMEmbedding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ public enum IUnknownVirtualTable {
guard let this, let iid, let ppvObject else { return COMError.toABI(hresult: HResult.invalidArg) }
ppvObject.pointee = nil

// Avoid setting the error info upon failure since QueryInterface is called
// by RoOriginateError, which is trying to set the error info itself.
return COMError.toABI(setErrorInfo: false) {
return COMError.toABI {
let id = GUIDProjection.toSwift(iid.pointee)
let this = IUnknownPointer(OpaquePointer(this))
let reference = id == uuidof(SWRT_SwiftCOMObject.self)
Expand Down
3 changes: 1 addition & 2 deletions Support/Sources/COM/COMInterop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ public struct COMInterop<ABIStruct> {
public func queryInterface(_ id: COMInterfaceID) throws -> IUnknownReference {
var iid = GUIDProjection.toABI(id)
var rawPointer: UnsafeMutableRawPointer? = nil
// Avoid calling GetErrorInfo since RoOriginateError causes QueryInterface calls
try COMError.fromABI(captureErrorInfo: false, unknown.pointee.VirtualTable.pointee.QueryInterface(unknown, &iid, &rawPointer))
try COMError.fromABI(unknown.pointee.VirtualTable.pointee.QueryInterface(unknown, &iid, &rawPointer))
guard let rawPointer else {
assertionFailure("QueryInterface succeeded but returned a null pointer")
throw COMError.noInterface
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import COM

/// Enables retrieving the IUnknown pointer stored in the error info with the call to RoOriginateLanguageException.
public typealias ILanguageExceptionErrorInfo = any ILanguageExceptionErrorInfoProtocol
public protocol ILanguageExceptionErrorInfoProtocol: IUnknownProtocol {
var languageException: IUnknown { get throws }
var languageException: IUnknown? { get throws }
}

import WindowsRuntime_ABI

public enum ILanguageExceptionErrorInfoProjection: COMTwoWayProjection {
public enum ILanguageExceptionErrorInfoProjection: COMProjection {
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo
public typealias SwiftObject = ILanguageExceptionErrorInfo

public static var interfaceID: COMInterfaceID { uuidof(ABIStruct.self) }
public static var virtualTablePointer: UnsafeRawPointer { .init(withUnsafePointer(to: &virtualTable) { $0 }) }

public static func _wrap(_ reference: consuming ABIReference) -> SwiftObject {
Import(_wrapping: reference)
Expand All @@ -23,30 +23,21 @@ public enum ILanguageExceptionErrorInfoProjection: COMTwoWayProjection {
}

private final class Import: COMImport<ILanguageExceptionErrorInfoProjection>, ILanguageExceptionErrorInfoProtocol {
var languageException: IUnknown {
get throws { try NullResult.unwrap(_interop.getLanguageException()) }
var languageException: IUnknown? {
get throws { try _interop.getLanguageException() }
}
}

private static var virtualTable: WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo_VirtualTable = .init(
QueryInterface: { IUnknownVirtualTable.QueryInterface($0, $1, $2) },
AddRef: { IUnknownVirtualTable.AddRef($0) },
Release: { IUnknownVirtualTable.Release($0) },
GetLanguageException: { this, languageException in _implement(this) { this in
guard let languageException else { throw COMError.fail }
languageException.pointee = try IUnknownProjection.toABI(this.languageException)
} })
}

public func uuidof(_: WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo.Type) -> COMInterfaceID {
.init(0x04a2dbf3, 0xdf83, 0x116c, 0x0946, 0x0812abf6e07d)
}

extension COMInterop where ABIStruct == WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo {
public func getLanguageException() throws -> IUnknown {
public func getLanguageException() throws -> IUnknown? {
var result: IUnknownPointer? = nil // IUnknownProjection.abiDefaultValue (compiler bug?)
defer { IUnknownProjection.release(&result) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.GetLanguageException(this, &result))
return try NullResult.unwrap(IUnknownProjection.toSwift(consuming: &result))
return IUnknownProjection.toSwift(consuming: &result)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import COM

/// Enables language projections to provide and retrieve error information as with ILanguageExceptionErrorInfo,
/// with the additional benefit of working across language boundaries.
public typealias ILanguageExceptionErrorInfo2 = any ILanguageExceptionErrorInfo2Protocol
public protocol ILanguageExceptionErrorInfo2Protocol: ILanguageExceptionErrorInfoProtocol {
var previousLanguageExceptionErrorInfo: ILanguageExceptionErrorInfo2? { get throws }
func capturePropagationContext(_ languageException: IUnknown?) throws
var propagationContextHead: ILanguageExceptionErrorInfo2? { get throws }
}

import WindowsRuntime_ABI

public enum ILanguageExceptionErrorInfo2Projection: COMProjection {
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo2
public typealias SwiftObject = ILanguageExceptionErrorInfo2

public static var interfaceID: COMInterfaceID { uuidof(ABIStruct.self) }

public static func _wrap(_ reference: consuming ABIReference) -> SwiftObject {
Import(_wrapping: reference)
}

public static func toCOM(_ object: SwiftObject) throws -> ABIReference {
try Import.toCOM(object)
}

private final class Import: COMImport<ILanguageExceptionErrorInfo2Projection>, ILanguageExceptionErrorInfo2Protocol {
var languageException: IUnknown? {
get throws { try _interop.getLanguageException() }
}

var previousLanguageExceptionErrorInfo: ILanguageExceptionErrorInfo2? {
get throws { try _interop.getPreviousLanguageExceptionErrorInfo() }
}

func capturePropagationContext(_ languageException: IUnknown?) throws {
try _interop.capturePropagationContext(languageException)
}

var propagationContextHead: ILanguageExceptionErrorInfo2? {
get throws { try _interop.getPropagationContextHead() }
}
}
}

public func uuidof(_: WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo2.Type) -> COMInterfaceID {
.init(0x5746E5C4, 0x5B97, 0x424C, 0xB620, 0x2822915734DD)
}

extension COMInterop where ABIStruct == WindowsRuntime_ABI.SWRT_ILanguageExceptionErrorInfo2 {
public func getLanguageException() throws -> IUnknown? {
var result: IUnknownPointer? = nil // IUnknownProjection.abiDefaultValue (compiler bug?)
defer { IUnknownProjection.release(&result) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.GetLanguageException(this, &result))
return IUnknownProjection.toSwift(consuming: &result)
}

public func getPreviousLanguageExceptionErrorInfo() throws -> ILanguageExceptionErrorInfo2? {
var result = ILanguageExceptionErrorInfo2Projection.abiDefaultValue
defer { ILanguageExceptionErrorInfo2Projection.release(&result) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.GetPreviousLanguageExceptionErrorInfo(this, &result))
return ILanguageExceptionErrorInfo2Projection.toSwift(consuming: &result)
}

public func capturePropagationContext(_ languageException: IUnknown?) throws {
var languageException = try IUnknownProjection.toABI(languageException)
defer { IUnknownProjection.release(&languageException) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.CapturePropagationContext(this, languageException))
}

public func getPropagationContextHead() throws -> ILanguageExceptionErrorInfo2? {
var result = ILanguageExceptionErrorInfo2Projection.abiDefaultValue
defer { ILanguageExceptionErrorInfo2Projection.release(&result) }
try COMError.fromABI(this.pointee.VirtualTable.pointee.GetPropagationContextHead(this, &result))
return ILanguageExceptionErrorInfo2Projection.toSwift(consuming: &result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ public protocol IRestrictedErrorInfoProtocol: IUnknownProtocol {

import WindowsRuntime_ABI

public enum IRestrictedErrorInfoProjection: COMTwoWayProjection {
public enum IRestrictedErrorInfoProjection: COMProjection {
public typealias SwiftObject = IRestrictedErrorInfo
public typealias ABIStruct = WindowsRuntime_ABI.SWRT_IRestrictedErrorInfo

public static var interfaceID: COMInterfaceID { uuidof(ABIStruct.self) }
public static var virtualTablePointer: UnsafeRawPointer { .init(withUnsafePointer(to: &virtualTable) { $0 }) }

public static func _wrap(_ reference: consuming ABIReference) -> SwiftObject {
Import(_wrapping: reference)
Expand All @@ -38,28 +37,6 @@ public enum IRestrictedErrorInfoProjection: COMTwoWayProjection {

public var reference: String? { get throws { try _interop.getReference() } }
}

private static var virtualTable: WindowsRuntime_ABI.SWRT_IRestrictedErrorInfo_VirtualTable = .init(
QueryInterface: { IUnknownVirtualTable.QueryInterface($0, $1, $2) },
AddRef: { IUnknownVirtualTable.AddRef($0) },
Release: { IUnknownVirtualTable.Release($0) },
GetErrorDetails: { this, description, error, restrictedDescription, capabilitySid in _implement(this) {
var description_: String? = nil
var error_: HResult = .ok
var restrictedDescription_: String? = nil
var capabilitySid_: String? = nil
try $0.getErrorDetails(description: &description_, error: &error_, restrictedDescription: &restrictedDescription_, capabilitySid: &capabilitySid_)
var _success = false
if let description { description.pointee = try BStrProjection.toABI(description_) }
defer { if !_success, let description { BStrProjection.release(&description.pointee) } }
if let error { error.pointee = HResultProjection.toABI(error_) }
if let restrictedDescription { restrictedDescription.pointee = try BStrProjection.toABI(restrictedDescription_) }
defer { if !_success, let restrictedDescription { BStrProjection.release(&restrictedDescription.pointee) } }
if let capabilitySid { capabilitySid.pointee = try BStrProjection.toABI(capabilitySid_) }
defer { if !_success, let capabilitySid { BStrProjection.release(&capabilitySid.pointee) } }
_success = true
} },
GetReference: { this, reference in _implement(this) { try _set(reference, BStrProjection.toABI($0.reference)) } })
}

public func uuidof(_: WindowsRuntime_ABI.SWRT_IRestrictedErrorInfo.Type) -> COMInterfaceID {
Expand Down
27 changes: 24 additions & 3 deletions Support/Sources/WindowsRuntime/WinRTError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,34 @@ public struct WinRTError: COMErrorProtocol, CustomStringConvertible {
public static func fromABI(captureErrorInfo: Bool = true, _ hresult: WindowsRuntime_ABI.SWRT_HResult) throws -> HResult {
let hresult = HResult(hresult)
guard hresult.isFailure else { return hresult }

// Check for an associated IRestrictedErrorInfo
guard captureErrorInfo, let restrictedErrorInfo = try? Self.getRestrictedErrorInfo(matching: hresult) else {
throw WinRTError(hresult: hresult)
}

if let languageExceptionErrorInfo = try? restrictedErrorInfo.queryInterface(ILanguageExceptionErrorInfoProjection.self),
let languageException = try? languageExceptionErrorInfo.languageException as? LanguageException {
throw languageException.error
// Ensure we didn't get a stale IRestrictedErrorInfo
var description: String? = nil
var error: HResult = .ok
var restrictedDescription: String? = nil
var capabilitySid: String? = nil
try? restrictedErrorInfo.getErrorDetails(
description: &description, error: &error,
restrictedDescription: &restrictedDescription, capabilitySid: &capabilitySid)
guard error == hresult else { throw WinRTError(hresult: hresult) }

// Append to the propagation context, if available.
// See https://learn.microsoft.com/en-us/windows/win32/api/restrictederrorinfo/nf-restrictederrorinfo-ilanguageexceptionerrorinfo2-capturepropagationcontext
if let languageExceptionErrorInfo = try? restrictedErrorInfo.queryInterface(ILanguageExceptionErrorInfoProjection.self) {
let languageException = try? languageExceptionErrorInfo.languageException as? LanguageException

if let languageExceptionErrorInfo2 = try? languageExceptionErrorInfo.queryInterface(ILanguageExceptionErrorInfo2Projection.self) {
try languageExceptionErrorInfo2.capturePropagationContext(nil) // No new language exception to provide
}

if let languageException {
throw languageException.error
}
}

throw WinRTError(hresult: hresult, errorInfo: restrictedErrorInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ struct SWRT_ILanguageExceptionErrorInfo_VirtualTable {
SWRT_HResult (__stdcall *GetLanguageException)(SWRT_ILanguageExceptionErrorInfo* _this, SWRT_IUnknown** languageException);
};

typedef struct SWRT_ILanguageExceptionErrorInfo2 {
struct SWRT_ILanguageExceptionErrorInfo2_VirtualTable* VirtualTable;
} SWRT_ILanguageExceptionErrorInfo2;

struct SWRT_ILanguageExceptionErrorInfo2_VirtualTable {
SWRT_HResult (__stdcall *QueryInterface)(SWRT_ILanguageExceptionErrorInfo2* _this, SWRT_Guid* riid, void** ppvObject);
uint32_t (__stdcall *AddRef)(SWRT_ILanguageExceptionErrorInfo2* _this);
uint32_t (__stdcall *Release)(SWRT_ILanguageExceptionErrorInfo2* _this);
SWRT_HResult (__stdcall *GetLanguageException)(SWRT_ILanguageExceptionErrorInfo2* _this, SWRT_IUnknown** languageException);
SWRT_HResult (__stdcall *GetPreviousLanguageExceptionErrorInfo)(SWRT_ILanguageExceptionErrorInfo2* _this, SWRT_ILanguageExceptionErrorInfo2** previousLanguageExceptionErrorInfo);
SWRT_HResult (__stdcall *CapturePropagationContext)(SWRT_ILanguageExceptionErrorInfo2* _this, SWRT_IUnknown* languageException);
SWRT_HResult (__stdcall *GetPropagationContextHead)(SWRT_ILanguageExceptionErrorInfo2* _this, SWRT_ILanguageExceptionErrorInfo2** propagatedLanguageExceptionErrorInfoHead);
};

typedef struct SWRT_IRestrictedErrorInfo {
struct SWRT_IRestrictedErrorInfo_VirtualTable* VirtualTable;
} SWRT_IRestrictedErrorInfo;
Expand Down

0 comments on commit b871cf3

Please sign in to comment.