Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions InteropGenerator.Tests/Generator/GenerateInteropAttributeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,41 @@ await VerifyIG.VerifyGeneratorAsync(
code,
("TestStruct.InnerStruct.InteropGenerator.g.cs", result));
}

[Fact]
public async Task StructWithGeneric() {
const string code = """
[global::System.Runtime.InteropServices.StructLayout(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size=4)]
[GenerateInterop]
public partial struct TestStruct<T> where T : unmanaged
{
[VirtualFunction(0)]
public partial void TestFunction();
}
""";

const string result = """
// <auto-generated/>
unsafe partial struct TestStruct<T>
{
public const int StructSize = 4;
[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]
public unsafe partial struct TestStructVirtualTable
{
[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public delegate* unmanaged <TestStruct<T>*, void> TestFunction;
}
[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public TestStructVirtualTable* VirtualTable;
public static partial class Delegates
{
public delegate void TestFunction(TestStruct<T>* thisPtr);
}
[global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
public partial void TestFunction() => VirtualTable->TestFunction((TestStruct<T>*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this));
}
""";

await VerifyIG.VerifyGeneratorAsync(
code,
("TestStruct.InteropGenerator.g.cs", result));
}
}
68 changes: 68 additions & 0 deletions InteropGenerator.Tests/Generator/InheritsAttributeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2968,4 +2968,72 @@ await VerifyIG.VerifyGeneratorAsync(
code,
("ChildStruct.Inheritance.InteropGenerator.g.cs", childStructInheritanceCode));
}

[Fact]
public async Task VirtualFunctionGenericInheritance() {
const string code = """
[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size = 4)]
[GenerateInterop(true)]
public unsafe partial struct BaseStruct
{
[VirtualFunction(5)]
public partial int TestFunction(int argOne, void * argTwo);
}

[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size = 8)]
[GenerateInterop]
[Inherits<BaseStruct>]
public unsafe partial struct ChildStruct<T> where T : unmanaged
{
}
""";

const string baseStructGeneratedCode = """
// <auto-generated/>
unsafe partial struct BaseStruct
{
public const int StructSize = 4;
[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]
public unsafe partial struct BaseStructVirtualTable
{
[global::System.Runtime.InteropServices.FieldOffsetAttribute(40)] public delegate* unmanaged <BaseStruct*, int, void*, int> TestFunction;
}
[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public BaseStructVirtualTable* VirtualTable;
public static partial class Delegates
{
public delegate int TestFunction(BaseStruct* thisPtr, int argOne, void* argTwo);
}
[global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
public partial int TestFunction(int argOne, void* argTwo) => VirtualTable->TestFunction((BaseStruct*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), argOne, argTwo);
}
""";

const string childStructInheritanceGeneratedCode = """
// <auto-generated/>
unsafe partial struct ChildStruct<T>
{
/// <summary>Inherited parent class accessor for <see cref="BaseStruct">BaseStruct</see></summary>
[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public BaseStruct BaseStruct;
[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]
public unsafe partial struct ChildStructVirtualTable
{
[global::System.Runtime.InteropServices.FieldOffsetAttribute(40)] public delegate* unmanaged <ChildStruct<T>*, int, void*, int> TestFunction;
}
[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public ChildStructVirtualTable* VirtualTable;
public static partial class Delegates
{
public delegate int TestFunction(ChildStruct<T>* thisPtr, int argOne, void* argTwo);
}
/// <inheritdoc cref="BaseStruct.TestFunction(int, void*)" />
/// <remarks>Method inherited from parent class <see cref="BaseStruct">BaseStruct</see>.</remarks>
[global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
public int TestFunction(int argOne, void* argTwo) => VirtualTable->TestFunction((ChildStruct<T>*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), argOne, argTwo);
}
""";

await VerifyIG.VerifyGeneratorAsync(
code,
("BaseStruct.InteropGenerator.g.cs", baseStructGeneratedCode),
("ChildStruct.Inheritance.InteropGenerator.g.cs", childStructInheritanceGeneratedCode));
}
}
3 changes: 2 additions & 1 deletion InteropGenerator/Generator/InteropGenerator.Parsing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ private static StructInfo ParseStructInfo(INamedTypeSymbol structSymbol, Attribu
fixedSizeArrays,
inheritanceInfoBuilder.ToImmutable(),
structSize,
extraInheritedStructInfo);
extraInheritedStructInfo,
structSymbol.IsGenericType);
}

private static void ParseMethods(INamedTypeSymbol structSymbol, CancellationToken token, bool isInherited,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,44 @@ private static void RenderInheritedMemberFunctions(StructInfo inheritedStruct, s
}

private static void RenderInheritedVirtualTable(StructInfo structInfo, ImmutableArray<(StructInfo inheritedStruct, string path, int offset)> resolvedInheritanceOrder, IndentedTextWriter writer) {
// write virtual function pointers from inherited structs using the child struct type as the "this" pointer
// StructLayout can't be duplicated so only write it if it hasnt been written before
if (!structInfo.HasVirtualTable())
writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]");
writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable");
using (writer.WriteBlock()) {
foreach ((StructInfo inheritedStruct, _, int offset) in resolvedInheritanceOrder) {
// only inherited structs at offset 0 are the primary inheritance chain that make up the main virtual table
if (offset != 0)
continue;
foreach (VirtualFunctionInfo virtualFunctionInfo in inheritedStruct.VirtualFunctions) {
var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {virtualFunctionInfo.MethodInfo.GetParameterTypeStringWithTrailingType()}{virtualFunctionInfo.MethodInfo.ReturnType}>";
foreach (string inheritedAttribute in virtualFunctionInfo.MethodInfo.InheritableAttributes)
writer.WriteLine(inheritedAttribute);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({virtualFunctionInfo.Index * 8})] public {functionPointerType} {virtualFunctionInfo.MethodInfo.Name};");
}
if (structInfo.IsGeneric) {
// write virtual function pointers from inherited structs using the child struct type as the "this" pointer
// StructLayout can't be duplicated so only write it if it hasnt been written before
if (!structInfo.HasVirtualTable())
writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]");
writer.WriteLine($"public unsafe partial struct {structInfo.Name[..^3]}VirtualTable");
WriteVirtualFunctions(writer);
// if the only virtual functions were inherited we need to write the vtable accessor
if (!structInfo.HasVirtualTable()) {
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name[..^3]}VirtualTable* VirtualTable;");
}
} else {
// write virtual function pointers from inherited structs using the child struct type as the "this" pointer
// StructLayout can't be duplicated so only write it if it hasnt been written before
if (!structInfo.HasVirtualTable())
writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]");
writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable");
WriteVirtualFunctions(writer);
// if the only virtual functions were inherited we need to write the vtable accessor
if (!structInfo.HasVirtualTable()) {
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;");
}
}
// if the only virtual functions were inherited we need to write the vtable accessor
if (!structInfo.HasVirtualTable()) {
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;");

void WriteVirtualFunctions(IndentedTextWriter writer) {
using (writer.WriteBlock()) {
foreach ((StructInfo inheritedStruct, _, int offset) in resolvedInheritanceOrder) {
// only inherited structs at offset 0 are the primary inheritance chain that make up the main virtual table
if (offset != 0)
continue;
foreach (VirtualFunctionInfo virtualFunctionInfo in inheritedStruct.VirtualFunctions) {
var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {virtualFunctionInfo.MethodInfo.GetParameterTypeStringWithTrailingType()}{virtualFunctionInfo.MethodInfo.ReturnType}>";
foreach (string inheritedAttribute in virtualFunctionInfo.MethodInfo.InheritableAttributes)
writer.WriteLine(inheritedAttribute);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({virtualFunctionInfo.Index * 8})] public {functionPointerType} {virtualFunctionInfo.MethodInfo.Name};");
}
}
}
}
}

Expand Down
33 changes: 23 additions & 10 deletions InteropGenerator/Generator/InteropGenerator.Rendering.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,31 @@ private static void RenderVirtualTable(StructInfo structInfo, IndentedTextWriter
} else {
writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]");
}
writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable");
using (writer.WriteBlock()) {
foreach (VirtualFunctionInfo vfi in structInfo.VirtualFunctions) {
var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {vfi.MethodInfo.GetParameterTypeStringWithTrailingType()}{vfi.MethodInfo.ReturnType}>";
foreach (string attr in vfi.MethodInfo.InheritableAttributes)
writer.WriteLine(attr);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({vfi.Index * 8})] public {functionPointerType} {vfi.MethodInfo.Name};");
if (structInfo.IsGeneric) {
writer.WriteLine($"public unsafe partial struct {structInfo.Name[..^3]}VirtualTable");
WriteVirtualFunctions(writer);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name[..^3]}VirtualTable* VirtualTable;");
if (structInfo.StaticVirtualTableSignature is not null) {
writer.WriteLine($"public static {structInfo.Name[..^3]}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name[..^3]}VirtualTable*)Addresses.StaticVirtualTable.Value;");
}
} else {
writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable");
WriteVirtualFunctions(writer);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;");
if (structInfo.StaticVirtualTableSignature is not null) {
writer.WriteLine($"public static {structInfo.Name}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name}VirtualTable*)Addresses.StaticVirtualTable.Value;");
}
}
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;");
if (structInfo.StaticVirtualTableSignature is not null) {
writer.WriteLine($"public static {structInfo.Name}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name}VirtualTable*)Addresses.StaticVirtualTable.Value;");

void WriteVirtualFunctions(IndentedTextWriter writer) {
using (writer.WriteBlock()) {
foreach (VirtualFunctionInfo vfi in structInfo.VirtualFunctions) {
var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {vfi.MethodInfo.GetParameterTypeStringWithTrailingType()}{vfi.MethodInfo.ReturnType}>";
foreach (string attr in vfi.MethodInfo.InheritableAttributes)
writer.WriteLine(attr);
writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({vfi.Index * 8})] public {functionPointerType} {vfi.MethodInfo.Name};");
}
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion InteropGenerator/Models/StructInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ internal sealed record StructInfo(
EquatableArray<FixedSizeArrayInfo> FixedSizeArrays,
EquatableArray<InheritanceInfo> InheritedStructs,
int? Size,
ExtraInheritedStructInfo? ExtraInheritedStructInfo) {
ExtraInheritedStructInfo? ExtraInheritedStructInfo,
bool IsGeneric) {
public string Name => Hierarchy[0];
public bool HasSignatures() => !MemberFunctions.IsEmpty || !StaticAddresses.IsEmpty || StaticVirtualTableSignature is not null;
public bool HasVirtualTable() => !VirtualFunctions.IsEmpty || StaticVirtualTableSignature is not null;
Expand Down