Skip to content

Commit

Permalink
Use static-abstracts + DIMs to remove the extra shape validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky committed Aug 25, 2022
1 parent 7182f4b commit 9ca69dd
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 62 deletions.
37 changes: 24 additions & 13 deletions docs/design/libraries/ComInterfaceGenerator/VTableStubs.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class VirtualMethodIndexAttribute : Attribute

```

A new interface will be defined and used by the source generator to fetch the native `this` pointer and the vtable that the function pointer is stored in. This interface is designed to provide an API that various native platforms, like COM, WinRT, or Swift, could use to provide support for multiple managed interface wrappers from a single native object. In particular, this interface was designed to ensure it is possible support a managed gesture to do an unmanaged "type cast" (i.e., `QueryInterface` in the COM and WinRT worlds).
New interfaces will be defined and used by the source generator to fetch the native `this` pointer and the vtable that the function pointer is stored in. These interfaces are designed to provide an API that various native platforms, like COM, WinRT, or Swift, could use to provide support for multiple managed interface wrappers from a single native object. In particular, these interfaces are designed to ensure it is possible support a managed gesture to do an unmanaged "type cast" (i.e., `QueryInterface` in the COM and WinRT worlds).

```csharp
namespace System.Runtime.InteropServices;
Expand All @@ -82,13 +82,24 @@ public readonly ref struct VirtualMethodTableInfo

public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
{
VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);

public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
{
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
}
}

public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
{
public abstract static T TypeKey { get; }
}
```

## Required API Shapes

In addition to the provided APIs above, users will be required to add a `readonly static` field or `get`-able property to their user-defined interface type named `TypeKey`. The type of this member will be used as the `T` in `IUnmanagedVirtualMethodTableProvider<T>` and the value will be passed to `GetVirtualMethodTableInfoForKey`. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.
The user will be required to implement `IUnmanagedVirtualMethodTableProvider<T>` on the type that provides the method tables, and `IUnmanagedInterfaceType<T>` on the type that defines the unmanaged interface. The `T` types must match between the two interfaces. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.

## Example Usage

Expand Down Expand Up @@ -149,11 +160,11 @@ using System.Runtime.InteropServices;
[assembly:DisableRuntimeMarshalling]
// Define the interface of the native API
partial interface INativeAPI
partial interface INativeAPI : IUnmanagedInterfaceType<NoCasting>
{
// There is no concept of casting for this API, but providing a type key is still required by the generator.
// Use an empty readonly record struct to provide a type that implements IEquatable<T> but contains no data.
readonly static NoCasting TypeKey = default;
static NoCasting IUnmanagedInterfaceType.TypeKey => default;
[VirtualMethodIndex(0, ImplicitThisParameter = false, Direction = CustomTypeMarshallerDirection.In)]
int GetVersion();
Expand Down Expand Up @@ -218,7 +229,7 @@ partial interface INativeAPI
{
int INativeAPI.GetVersion()
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int>)vtable[0])();
return retVal;
Expand All @@ -231,7 +242,7 @@ partial interface INativeAPI
{
int INativeAPI.Add(int x, int y)
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int, int, int>)vtable[1])(x, y);
return retVal;
Expand All @@ -244,7 +255,7 @@ partial interface INativeAPI
{
int INativeAPI.Multiply(int x, int y)
{
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey(INativeAPI.TypeKey);
var (_, vtable) = ((IUnmanagedVirtualMethodTableProvider<NoCasting>)this).GetVirtualMethodTableInfoForKey<INativeAPI>();
int retVal;
retVal = ((delegate* unmanaged<int, int, int>)vtable[2])(x, y);
return retVal;
Expand Down Expand Up @@ -279,9 +290,9 @@ struct IUnknown
using System;
using System.Runtime.InteropServices;

interface IUnknown
interface IUnknown: IUnmanagedInterfaceType<Guid>
{
public static readonly Guid TypeKey = Guid.Parse("00000000-0000-0000-C000-000000000046");
static Guid IUnmanagedTypeInterfaceType<Guid>.TypeKey => Guid.Parse("00000000-0000-0000-C000-000000000046");

[UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall), typeof(CallConvMemberFunction) })]
[VirtualMethodIndex(0)]
Expand Down Expand Up @@ -347,7 +358,7 @@ partial interface IUnknown
{
int IUnknown.QueryInterface(in Guid riid, out IntPtr ppvObject)
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
int retVal;
fixed (Guid* riid__gen_native = &riid)
fixed (IntPtr* ppvObject__gen_native = &ppvObject)
Expand All @@ -364,7 +375,7 @@ partial interface IUnknown
{
uint IUnknown.AddRef()
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
uint retVal;
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[1])(thisPtr);
return retVal;
Expand All @@ -377,7 +388,7 @@ partial interface IUnknown
{
uint IUnknown.Release()
{
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey(IUnknown.TypeKey);
var (thisPtr, vtable) = ((IUnmanagedVirtualMethodTableProvider<Guid>)this).GetVirtualMethodTableInfoForKey<IUnknown>();
uint retVal;
retVal = ((delegate* unmanaged[Stdcall, MemberFunction]<IntPtr, uint>)vtable[2])(thisPtr);
return retVal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
{
var setupStatements = new List<StatementSyntax>
{
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey(<containingTypeName>.TypeKey)
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
Expand All @@ -141,15 +141,12 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
TypeArgumentList(
SingletonSeparatedList(typeKeyType.Syntax))),
ThisExpression())),
IdentifierName("GetVirtualMethodTableInfoForKey")))
GenericName(
Identifier("GetVirtualMethodTableInfoForKey"),
TypeArgumentList(
SingletonSeparatedList(containingTypeName)))))
.WithArgumentList(
ArgumentList(
SingletonSeparatedList(
Argument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
containingTypeName,
IdentifierName("TypeKey"))))))))
ArgumentList())))
};

GeneratedStatements statements = GeneratedStatements.Create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);
INamedTypeSymbol iUnmanagedInterfaceTypeType = environment.Compilation.GetTypeByMetadataName(TypeNames.IUnmanagedInterfaceType_Metadata)!;
// Get any attributes of interest on the method
AttributeData? virtualMethodIndexAttr = null;
AttributeData? lcidConversionAttr = null;
Expand Down Expand Up @@ -310,14 +311,14 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
var typeKeyOwner = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
ManagedTypeInfo typeKeyType = SpecialTypeInfo.Byte;

IFieldSymbol? typeKeyField = symbol.ContainingType.GetMembers("TypeKey").OfType<IFieldSymbol>().FirstOrDefault(f => f.IsStatic);
if (typeKeyField is null)
INamedTypeSymbol? iUnmanagedInterfaceTypeInstantiation = symbol.ContainingType.AllInterfaces.FirstOrDefault(iface => SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, iUnmanagedInterfaceTypeType));
if (iUnmanagedInterfaceTypeInstantiation is null)
{
// Report invalid configuration
}
else
{
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(typeKeyField.Type);
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iUnmanagedInterfaceTypeInstantiation.TypeArguments[0]);
}

return new IncrementalStubGenerationContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public static class TypeNames

public const string IUnmanagedVirtualMethodTableProvider = "System.Runtime.InteropServices.IUnmanagedVirtualMethodTableProvider";

public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`1";

public const string System_Span_Metadata = "System.Span`1";
public const string System_Span = "System.Span";
public const string System_ReadOnlySpan_Metadata = "System.ReadOnlySpan`1";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@ public void Deconstruct(out IntPtr thisPointer, out ReadOnlySpan<IntPtr> virtual

public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
{
VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);

public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
{
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
}
}


public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
{
public abstract static T TypeKey { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ internal partial class ImplicitThis
{
public readonly record struct NoCasting;

internal partial interface INativeObject
internal partial interface INativeObject : IUnmanagedInterfaceType<NoCasting>
{
public static readonly NoCasting TypeKey = default;
static NoCasting IUnmanagedInterfaceType<NoCasting>.TypeKey => default;

[VirtualMethodIndex(0, ImplicitThisParameter = true)]
int GetData();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ internal partial class NoImplicitThis
{
public readonly record struct NoCasting;

internal partial interface IStaticMethodTable
internal partial interface IStaticMethodTable : IUnmanagedInterfaceType<NoCasting>
{
public static readonly NoCasting TypeKey = default;
static NoCasting IUnmanagedInterfaceType<NoCasting>.TypeKey => default;

[VirtualMethodIndex(0, ImplicitThisParameter = false)]
int Add(int x, int y);
Expand Down
Loading

0 comments on commit 9ca69dd

Please sign in to comment.