Skip to content

Commit

Permalink
Fix error pattern to capture and set the COM error info when crossing…
Browse files Browse the repository at this point in the history
… ABI boundaries (#250)

Rename `COMError` to `ErrorWithHResult` and `HResult.Error` to
`COMError`.
Rename `throwIfFailed` and `catch` to `fromABI` and `toABI`, the former
capturing the error info and the latter setting it back.
Added `COMError` and `WinRTError` initializers taking a message and
constructing an error info internally.
  • Loading branch information
tristanlabelle committed Sep 3, 2024
1 parent db7cd44 commit 571f96f
Show file tree
Hide file tree
Showing 54 changed files with 601 additions and 244 deletions.
2 changes: 1 addition & 1 deletion Docs/How it works.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ extension COMInterop when T == SWRT_IFoo {
func getName() throws -> String {
var name: BSTR? = nil
defer { BStrProjection.release(&name) }
try HResult.throwIfFailed(pointer.pointee.vtable.pointee.GetName(pointer, &name))
try COMError.fromABI(pointer.pointee.vtable.pointee.GetName(pointer, &name))
return BStrProjection.toSwift(name)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ internal class SequenceIterator<I: IteratorProtocol>: WinRTPrimaryExport<IInspec
func _hasCurrent() throws -> Bool { current != nil }

func _current() throws -> T {
guard let current else { throw HResult.Error.illegalMethodCall }
guard let current else { throw COMError.illegalMethodCall }
return current
}

Expand All @@ -47,6 +47,6 @@ internal class SequenceIterator<I: IteratorProtocol>: WinRTPrimaryExport<IInspec
}

func getMany(_ items: [I.Element]) throws -> UInt32 {
throw HResult.Error.notImpl // TODO(#31): Implement out arrays
throw COMError.notImpl // TODO(#31): Implement out arrays
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ fileprivate class CollectionVectorView<C: Collection>: WinRTPrimaryExport<IInspe
}

public func getMany(_ startIndex: UInt32, _ items: [C.Element]) throws -> UInt32 {
throw HResult.Error.notImpl // TODO(#31): Implement out arrays
throw COMError.notImpl // TODO(#31): Implement out arrays
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class ArrayVector<T>: WinRTPrimaryExport<IInspectableProjection>,
}

public func getMany(_ startIndex: UInt32, _ items: [T]) throws -> UInt32 {
throw HResult.Error.notImpl // TODO(#31): Implement out arrays
throw COMError.notImpl // TODO(#31): Implement out arrays
}

public func replaceAll(_ items: [T]) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extension WindowsFoundation_IAsyncActionWithProgressProtocol {
public func get() async throws {
if try _status() == .started {
// We can't await if the completed handler is already set
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.HResult.Error.illegalMethodCall }
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.COMError.illegalMethodCall }
let awaiter = WindowsRuntime.AsyncAwaiter()
try _completed({ _, _ in _Concurrency.Task { await awaiter.signal() } })
await awaiter.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extension WindowsFoundation_IAsyncActionProtocol {
public func get() async throws {
if try _status() == .started {
// We can't await if the completed handler is already set
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.HResult.Error.illegalMethodCall }
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.COMError.illegalMethodCall }
let awaiter = WindowsRuntime.AsyncAwaiter()
try _completed({ _, _ in _Concurrency.Task { await awaiter.signal() } })
await awaiter.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extension WindowsFoundation_IAsyncOperationWithProgressProtocol {
public func get() async throws -> TResult {
if try _status() == .started {
// We can't await if the completed handler is already set
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.HResult.Error.illegalMethodCall }
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.COMError.illegalMethodCall }
let awaiter = WindowsRuntime.AsyncAwaiter()
try _completed({ _, _ in _Concurrency.Task { await awaiter.signal() } })
await awaiter.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extension WindowsFoundation_IAsyncOperationProtocol {
public func get() async throws -> TResult {
if try _status() == .started {
// We can't await if the completed handler is already set
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.HResult.Error.illegalMethodCall }
guard try COM.NullResult.catch(_completed()) == nil else { throw COM.COMError.illegalMethodCall }
let awaiter = WindowsRuntime.AsyncAwaiter()
try _completed({ _, _ in _Concurrency.Task { await awaiter.signal() } })
await awaiter.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import WindowsRuntime
extension Array where Element == UInt8 {
public init(_ buffer: WindowsFoundation_IMemoryBuffer) throws {
let reference = try buffer.createReference()
guard let bufferPointer = try reference.bytes else { throw HResult.Error.fail }
guard let bufferPointer = try reference.bytes else { throw COMError.fail }
self.init(bufferPointer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extension WindowsFoundation_MemoryBuffer {
public convenience init(_ bytes: [UInt8]) throws {
try self.init(UInt32(bytes.count))
let reference = try self.createReference()
guard let bufferPointer = try reference.bytes else { throw HResult.Error.fail }
guard let bufferPointer = try reference.bytes else { throw COMError.fail }
_ = bufferPointer.update(fromContentsOf: bytes)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fileprivate func writeSwiftToABICall(
}

func writeCall() throws {
writer.writeStatement("try WinRTError.throwIfFailed("
writer.writeStatement("try WinRTError.fromABI("
+ "this.pointee.VirtualTable.pointee.\(abiMethodName)("
+ "\(abiArgs.joined(separator: ", "))))")
}
Expand Down Expand Up @@ -181,7 +181,7 @@ fileprivate func writeSwiftToABICall(
let returnValue: String
switch returnTypeProjection.kind {
case .identity where returnCOMReference:
writer.writeStatement("guard let \(returnParam.name) else { throw HResult.Error.pointer }")
writer.writeStatement("guard let \(returnParam.name) else { throw COMError.pointer }")
returnValue = "\(SupportModules.COM.comReference)(transferringRef: \(returnParam.name))"
case .identity where !returnCOMReference:
returnValue = returnParam.name
Expand Down
4 changes: 2 additions & 2 deletions Generator/Sources/SwiftWinRT/Writing/VirtualTable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ fileprivate func writeVirtualTableFunc(
// Ensure non-optional by reference params are non-null pointers
for param in params {
guard case .reference(in: _, out: _, optional: false) = param.passBy else { continue }
output.writeFullLine("guard let \(param.name) else { throw COM.HResult.Error.pointer }")
output.writeFullLine("guard let \(param.name) else { throw COM.COMError.pointer }")
}
if let returnParam {
output.writeFullLine("guard let \(returnParam.name) else { throw COM.HResult.Error.pointer }")
output.writeFullLine("guard let \(returnParam.name) else { throw COM.COMError.pointer }")
}

// Declare the Swift representation of params
Expand Down
6 changes: 3 additions & 3 deletions InteropTests/Tests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class AsyncTests : XCTestCase {
let _ = try await asyncOperation.get()
XCTFail("Expected an exception to be thrown")
}
catch let error as COMError {
catch let error as COMErrorProtocol {
XCTAssertEqual(error.hresult, HResult.outOfMemory)
}
}
Expand All @@ -43,7 +43,7 @@ class AsyncTests : XCTestCase {
let _ = try await asyncOperation.get()
XCTFail("Expected an exception to be thrown")
}
catch let error as COMError {
catch let error as COMErrorProtocol {
XCTAssertEqual(try asyncOperation._status(), .error)
XCTAssertEqual(error.hresult, HResult.outOfMemory)
}
Expand All @@ -56,7 +56,7 @@ class AsyncTests : XCTestCase {
let _ = try await asyncOperation.get()
XCTFail("Expected an exception to be thrown")
}
catch let error as COMError {
catch let error as COMErrorProtocol {
XCTAssertEqual(try asyncOperation._status(), .started)
XCTAssertEqual(error.hresult, HResult.illegalMethodCall)
}
Expand Down
29 changes: 22 additions & 7 deletions InteropTests/Tests/ErrorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ErrorTests: WinRTTestCase {
do {
try Errors.failWith(hresult, "")
XCTFail("Expected an error")
} catch let error as COMError {
} catch let error as COMErrorProtocol {
XCTAssertEqual(error.hresult, hresult)
}
}
Expand All @@ -31,14 +31,29 @@ class ErrorTests: WinRTTestCase {
}

func testThrowWithHResult() throws {
let hresult = HResult(unsigned: 0xCAFEBABE)
let error = try XCTUnwrap(HResult.Error(hresult: hresult))
XCTAssertEqual(
try Errors.catchHResult { throw error },
hresult)
struct TestError: ErrorWithHResult {
public var hresult: HResult { .init(unsigned: 0xCAFEBABE) }
}
let error = TestError()
XCTAssertEqual(try Errors.catchHResult { throw error }, error.hresult)
}

func testThrowWithMessage() throws {
throw XCTSkip("Not implemented: RoOriginateError")
struct TestError: Error, CustomStringConvertible {
public var description: String { "test" }
}
let error = TestError()
XCTAssertEqual(try Errors.catchMessage { throw error }, error.description)
}

func testSwiftErrorPreserved() throws {
try XCTSkipIf(true, "TODO(#248): Fix preserving Swift error objects across the WinRT boundary.")

struct SwiftError: Error {}
do {
try Errors.call { throw SwiftError() }
XCTFail("Expected an error")
}
catch _ as SwiftError {} // Success
}
}
4 changes: 4 additions & 0 deletions InteropTests/WinRTComponent/Errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ namespace winrt::WinRTComponent::implementation
{
throw winrt::hresult_not_implemented();
}
void Errors::Call(winrt::WinRTComponent::MinimalDelegate const& callee)
{
callee();
}
winrt::hresult Errors::CatchHResult(winrt::WinRTComponent::MinimalDelegate const& callee)
{
try { callee(); }
Expand Down
1 change: 1 addition & 0 deletions InteropTests/WinRTComponent/Errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace winrt::WinRTComponent::implementation
static void FailWith(winrt::hresult const& hr, winrt::hstring const& message);
static hstring NotImplementedProperty();
static void NotImplementedProperty(hstring const& value);
static void Call(winrt::WinRTComponent::MinimalDelegate const& callee);
static winrt::hresult CatchHResult(winrt::WinRTComponent::MinimalDelegate const& callee);
static winrt::hstring CatchMessage(winrt::WinRTComponent::MinimalDelegate const& callee);
};
Expand Down
1 change: 1 addition & 0 deletions InteropTests/WinRTComponent/Errors.idl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace WinRTComponent
{
static void FailWith(Windows.Foundation.HResult hr, String message);
static String NotImplementedProperty;
static void Call(MinimalDelegate callee);
static Windows.Foundation.HResult CatchHResult(MinimalDelegate callee);
static String CatchMessage(MinimalDelegate callee);
};
Expand Down
6 changes: 4 additions & 2 deletions Support/Sources/COM/COMEmbedding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ public enum IUnknownVirtualTable {
_ this: UnsafeMutablePointer<ABIStruct>?,
_ iid: UnsafePointer<COM_ABI.SWRT_Guid>?,
_ ppvObject: UnsafeMutablePointer<UnsafeMutableRawPointer?>?) -> COM_ABI.SWRT_HResult {
guard let this, let iid, let ppvObject else { return HResult.invalidArg.value }
guard let this, let iid, let ppvObject else { return COMError.toABI(hresult: HResult.invalidArg) }
ppvObject.pointee = nil

return HResult.catchValue {
// Avoid setting the error info upon failure since QueryInterface is called
// by RoOriginateError, which is trying to set the error info itself.
return COMError.toABI(setErrorInfo: false) {
let id = GUIDProjection.toSwift(iid.pointee)
let this = IUnknownPointer(OpaquePointer(this))
let reference = id == uuidof(SWRT_SwiftCOMObject.self)
Expand Down
115 changes: 113 additions & 2 deletions Support/Sources/COM/COMError.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,115 @@
/// Protocol for errors which result from a COM error HRESULT.
public protocol COMError: Error {
import COM_ABI

/// Protocol for errors with an associated HResult value.
public protocol ErrorWithHResult: Error {
var hresult: HResult { get }
}

public protocol COMErrorProtocol: ErrorWithHResult {
/// Gets the error info for
var errorInfo: IErrorInfo? { get }

/// Converts this COM error to its ABI representation, including thread-local COM error information.
func toABI(setErrorInfo: Bool) -> HResult.Value
}

extension COMErrorProtocol {
/// Converts this COM error to its ABI representation, including thread-local COM error information.
public func toABI() -> HResult.Value { toABI(setErrorInfo: true) }
}

/// Captures a failure from a COM API invocation (HRESULT + optional IErrorInfo).
public struct COMError: COMErrorProtocol, CustomStringConvertible {
public static let fail = Self(hresult: HResult.fail)
public static let illegalMethodCall = Self(hresult: HResult.illegalMethodCall)
public static let invalidArg = Self(hresult: HResult.invalidArg)
public static let notImpl = Self(hresult: HResult.notImpl)
public static let noInterface = Self(hresult: HResult.noInterface)
public static let pointer = Self(hresult: HResult.pointer)
public static let outOfMemory = Self(hresult: HResult.outOfMemory)

public let hresult: HResult // Invariant: isFailure
public let errorInfo: IErrorInfo?

public init(hresult: HResult, errorInfo: IErrorInfo? = nil) {
assert(hresult.isFailure)
self.hresult = hresult
self.errorInfo = errorInfo
}

public init(hresult: HResult, description: String?) {
self.init(hresult: hresult, errorInfo: description.map { DescriptiveErrorInfo(description: $0) })
}

public var description: String {
if let errorInfo, let description = try? errorInfo.description { return description }
return hresult.description
}

public func toABI(setErrorInfo: Bool = true) -> HResult.Value {
if setErrorInfo { try? Self.setErrorInfo(errorInfo) }
return hresult.value
}

/// Throws any failure HRESULTs as COMErrors, optionally capturing the COM thread error info.
@discardableResult
public static func fromABI(captureErrorInfo: Bool = true, _ hresult: HResult.Value) throws -> HResult {
let hresult = HResult(hresult)
guard hresult.isFailure else { return hresult }
guard captureErrorInfo else { throw COMError(hresult: hresult) }

let errorInfo = try? Self.getErrorInfo()
if let swiftErrorInfo = errorInfo as? SwiftErrorInfo, swiftErrorInfo.hresult == hresult {
// This was originally a Swift error, throw it as such.
throw swiftErrorInfo.error
}

throw COMError(hresult: hresult, errorInfo: errorInfo)
}

/// Catches any thrown errors from a provided closure, converting it to an HRESULT and optionally setting the COM thread error info state.
public static func toABI(setErrorInfo: Bool = true, _ body: () throws -> Void) -> HResult.Value {
do { try body() }
catch { return toABI(error: error, setErrorInfo: setErrorInfo) }
return HResult.ok.value
}

public static func toABI(error: Error, setErrorInfo: Bool = true) -> HResult.Value {
// If the error already came from COM/WinRT, propagate it
if let comError = error as? any COMErrorProtocol { return comError.toABI(setErrorInfo: setErrorInfo) }

// Otherwise, create a new error info and set it
return SwiftErrorInfo(error: error).toABI(setErrorInfo: setErrorInfo)
}

public static func toABI(hresult: HResult, description: String? = nil) -> HResult.Value {
guard hresult.isFailure else { return hresult.value }
try? Self.setErrorInfo(description.map { DescriptiveErrorInfo(description: $0) })
return hresult.value
}

public static func getErrorInfo() throws -> IErrorInfo? {
var errorInfo: UnsafeMutablePointer<SWRT_IErrorInfo>?
defer { IErrorInfoProjection.release(&errorInfo) }
try fromABI(captureErrorInfo: false, COM_ABI.SWRT_GetErrorInfo(/* dwReserved: */ 0, &errorInfo))
return IErrorInfoProjection.toSwift(consuming: &errorInfo)
}

public static func setErrorInfo(_ errorInfo: IErrorInfo?) throws {
var errorInfo = try IErrorInfoProjection.toABI(errorInfo)
defer { IErrorInfoProjection.release(&errorInfo) }
try fromABI(captureErrorInfo: false, COM_ABI.SWRT_SetErrorInfo(/* dwReserved: */ 0, errorInfo))
}

private final class DescriptiveErrorInfo: COMPrimaryExport<IErrorInfoProjection>, IErrorInfoProtocol {
private let _description: String
public init(description: String) { self._description = description }

// IErrorInfo
public var guid: GUID { get throws { throw COMError.fail } }
public var source: String? { get throws { throw COMError.fail } }
public var description: String? { self._description }
public var helpFile: String? { get throws { throw COMError.fail } }
public var helpContext: UInt32 { get throws { throw COMError.fail } }
}
}
5 changes: 3 additions & 2 deletions Support/Sources/COM/COMInterop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ public struct COMInterop<ABIStruct> {
public func queryInterface(_ id: COMInterfaceID) throws -> IUnknownReference {
var iid = GUIDProjection.toABI(id)
var rawPointer: UnsafeMutableRawPointer? = nil
try HResult.throwIfFailed(unknown.pointee.VirtualTable.pointee.QueryInterface(unknown, &iid, &rawPointer))
// Avoid calling GetErrorInfo since RoOriginateError causes QueryInterface calls
try COMError.fromABI(captureErrorInfo: false, unknown.pointee.VirtualTable.pointee.QueryInterface(unknown, &iid, &rawPointer))
guard let rawPointer else {
assertionFailure("QueryInterface succeeded but returned a null pointer")
throw HResult.Error.noInterface
throw COMError.noInterface
}

let pointer = rawPointer.bindMemory(to: COM_ABI.SWRT_IUnknown.self, capacity: 1)
Expand Down
2 changes: 1 addition & 1 deletion Support/Sources/COM/COMPrimaryExport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ open class COMPrimaryExport<Projection: COMTwoWayProjection>: COMExportBase<Proj
if let interface = Self.implements.first(where: { $0.id == id }) {
return interface.createCOM(identity: self)
}
throw HResult.Error.noInterface
throw COMError.noInterface
}
}
}
Expand Down
Loading

0 comments on commit 571f96f

Please sign in to comment.