From a80f341912c49f74b1f28627a1b4a841b4da3fb2 Mon Sep 17 00:00:00 2001 From: wolfcomp <4028289+wolfcomp@users.noreply.github.com> Date: Thu, 4 Sep 2025 01:19:31 +0200 Subject: [PATCH] add generic virtual functions --- .../GenerateInteropAttributeTests.cs | 37 ++++++++++ .../Generator/InheritsAttributeTests.cs | 68 +++++++++++++++++++ .../Generator/InteropGenerator.Parsing.cs | 3 +- .../InteropGenerator.Rendering.Inheritance.cs | 55 +++++++++------ .../Generator/InteropGenerator.Rendering.cs | 33 ++++++--- InteropGenerator/Models/StructInfo.cs | 3 +- 6 files changed, 168 insertions(+), 31 deletions(-) diff --git a/InteropGenerator.Tests/Generator/GenerateInteropAttributeTests.cs b/InteropGenerator.Tests/Generator/GenerateInteropAttributeTests.cs index d0545f718f..0e65dc2cbf 100644 --- a/InteropGenerator.Tests/Generator/GenerateInteropAttributeTests.cs +++ b/InteropGenerator.Tests/Generator/GenerateInteropAttributeTests.cs @@ -149,4 +149,41 @@ await VerifyIG.VerifyGeneratorAsync( code, ("TestStruct.InnerStruct.InteropGenerator.g.cs", result)); } + + [Fact] + public async Task StructWithGeneric() { + const string code = """ + [global::System.Runtime.InteropServices.StructLayout(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size=4)] + [GenerateInterop] + public partial struct TestStruct where T : unmanaged + { + [VirtualFunction(0)] + public partial void TestFunction(); + } + """; + + const string result = """ + // + unsafe partial struct TestStruct + { + public const int StructSize = 4; + [global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)] + public unsafe partial struct TestStructVirtualTable + { + [global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public delegate* unmanaged *, void> TestFunction; + } + [global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public TestStructVirtualTable* VirtualTable; + public static partial class Delegates + { + public delegate void TestFunction(TestStruct* thisPtr); + } + [global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)] + public partial void TestFunction() => VirtualTable->TestFunction((TestStruct*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this)); + } + """; + + await VerifyIG.VerifyGeneratorAsync( + code, + ("TestStruct.InteropGenerator.g.cs", result)); + } } diff --git a/InteropGenerator.Tests/Generator/InheritsAttributeTests.cs b/InteropGenerator.Tests/Generator/InheritsAttributeTests.cs index 5c704cf4d3..f962fb76be 100644 --- a/InteropGenerator.Tests/Generator/InheritsAttributeTests.cs +++ b/InteropGenerator.Tests/Generator/InheritsAttributeTests.cs @@ -2968,4 +2968,72 @@ await VerifyIG.VerifyGeneratorAsync( code, ("ChildStruct.Inheritance.InteropGenerator.g.cs", childStructInheritanceCode)); } + + [Fact] + public async Task VirtualFunctionGenericInheritance() { + const string code = """ + [global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size = 4)] + [GenerateInterop(true)] + public unsafe partial struct BaseStruct + { + [VirtualFunction(5)] + public partial int TestFunction(int argOne, void * argTwo); + } + + [global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size = 8)] + [GenerateInterop] + [Inherits] + public unsafe partial struct ChildStruct where T : unmanaged + { + } + """; + + const string baseStructGeneratedCode = """ + // + unsafe partial struct BaseStruct + { + public const int StructSize = 4; + [global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)] + public unsafe partial struct BaseStructVirtualTable + { + [global::System.Runtime.InteropServices.FieldOffsetAttribute(40)] public delegate* unmanaged TestFunction; + } + [global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public BaseStructVirtualTable* VirtualTable; + public static partial class Delegates + { + public delegate int TestFunction(BaseStruct* thisPtr, int argOne, void* argTwo); + } + [global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)] + public partial int TestFunction(int argOne, void* argTwo) => VirtualTable->TestFunction((BaseStruct*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), argOne, argTwo); + } + """; + + const string childStructInheritanceGeneratedCode = """ + // + unsafe partial struct ChildStruct + { + /// Inherited parent class accessor for BaseStruct + [global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public BaseStruct BaseStruct; + [global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)] + public unsafe partial struct ChildStructVirtualTable + { + [global::System.Runtime.InteropServices.FieldOffsetAttribute(40)] public delegate* unmanaged *, int, void*, int> TestFunction; + } + [global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public ChildStructVirtualTable* VirtualTable; + public static partial class Delegates + { + public delegate int TestFunction(ChildStruct* thisPtr, int argOne, void* argTwo); + } + /// + /// Method inherited from parent class BaseStruct. + [global::System.Runtime.CompilerServices.MethodImplAttribute(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)] + public int TestFunction(int argOne, void* argTwo) => VirtualTable->TestFunction((ChildStruct*)global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref this), argOne, argTwo); + } + """; + + await VerifyIG.VerifyGeneratorAsync( + code, + ("BaseStruct.InteropGenerator.g.cs", baseStructGeneratedCode), + ("ChildStruct.Inheritance.InteropGenerator.g.cs", childStructInheritanceGeneratedCode)); + } } diff --git a/InteropGenerator/Generator/InteropGenerator.Parsing.cs b/InteropGenerator/Generator/InteropGenerator.Parsing.cs index 2ceaa44c79..7b8a688c81 100644 --- a/InteropGenerator/Generator/InteropGenerator.Parsing.cs +++ b/InteropGenerator/Generator/InteropGenerator.Parsing.cs @@ -103,7 +103,8 @@ private static StructInfo ParseStructInfo(INamedTypeSymbol structSymbol, Attribu fixedSizeArrays, inheritanceInfoBuilder.ToImmutable(), structSize, - extraInheritedStructInfo); + extraInheritedStructInfo, + structSymbol.IsGenericType); } private static void ParseMethods(INamedTypeSymbol structSymbol, CancellationToken token, bool isInherited, diff --git a/InteropGenerator/Generator/InteropGenerator.Rendering.Inheritance.cs b/InteropGenerator/Generator/InteropGenerator.Rendering.Inheritance.cs index 3fb9fccb3a..26843e8c16 100644 --- a/InteropGenerator/Generator/InteropGenerator.Rendering.Inheritance.cs +++ b/InteropGenerator/Generator/InteropGenerator.Rendering.Inheritance.cs @@ -186,27 +186,44 @@ private static void RenderInheritedMemberFunctions(StructInfo inheritedStruct, s } private static void RenderInheritedVirtualTable(StructInfo structInfo, ImmutableArray<(StructInfo inheritedStruct, string path, int offset)> resolvedInheritanceOrder, IndentedTextWriter writer) { - // write virtual function pointers from inherited structs using the child struct type as the "this" pointer - // StructLayout can't be duplicated so only write it if it hasnt been written before - if (!structInfo.HasVirtualTable()) - writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]"); - writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable"); - using (writer.WriteBlock()) { - foreach ((StructInfo inheritedStruct, _, int offset) in resolvedInheritanceOrder) { - // only inherited structs at offset 0 are the primary inheritance chain that make up the main virtual table - if (offset != 0) - continue; - foreach (VirtualFunctionInfo virtualFunctionInfo in inheritedStruct.VirtualFunctions) { - var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {virtualFunctionInfo.MethodInfo.GetParameterTypeStringWithTrailingType()}{virtualFunctionInfo.MethodInfo.ReturnType}>"; - foreach (string inheritedAttribute in virtualFunctionInfo.MethodInfo.InheritableAttributes) - writer.WriteLine(inheritedAttribute); - writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({virtualFunctionInfo.Index * 8})] public {functionPointerType} {virtualFunctionInfo.MethodInfo.Name};"); - } + if (structInfo.IsGeneric) { + // write virtual function pointers from inherited structs using the child struct type as the "this" pointer + // StructLayout can't be duplicated so only write it if it hasnt been written before + if (!structInfo.HasVirtualTable()) + writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]"); + writer.WriteLine($"public unsafe partial struct {structInfo.Name[..^3]}VirtualTable"); + WriteVirtualFunctions(writer); + // if the only virtual functions were inherited we need to write the vtable accessor + if (!structInfo.HasVirtualTable()) { + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name[..^3]}VirtualTable* VirtualTable;"); + } + } else { + // write virtual function pointers from inherited structs using the child struct type as the "this" pointer + // StructLayout can't be duplicated so only write it if it hasnt been written before + if (!structInfo.HasVirtualTable()) + writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]"); + writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable"); + WriteVirtualFunctions(writer); + // if the only virtual functions were inherited we need to write the vtable accessor + if (!structInfo.HasVirtualTable()) { + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;"); } } - // if the only virtual functions were inherited we need to write the vtable accessor - if (!structInfo.HasVirtualTable()) { - writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;"); + + void WriteVirtualFunctions(IndentedTextWriter writer) { + using (writer.WriteBlock()) { + foreach ((StructInfo inheritedStruct, _, int offset) in resolvedInheritanceOrder) { + // only inherited structs at offset 0 are the primary inheritance chain that make up the main virtual table + if (offset != 0) + continue; + foreach (VirtualFunctionInfo virtualFunctionInfo in inheritedStruct.VirtualFunctions) { + var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {virtualFunctionInfo.MethodInfo.GetParameterTypeStringWithTrailingType()}{virtualFunctionInfo.MethodInfo.ReturnType}>"; + foreach (string inheritedAttribute in virtualFunctionInfo.MethodInfo.InheritableAttributes) + writer.WriteLine(inheritedAttribute); + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({virtualFunctionInfo.Index * 8})] public {functionPointerType} {virtualFunctionInfo.MethodInfo.Name};"); + } + } + } } } diff --git a/InteropGenerator/Generator/InteropGenerator.Rendering.cs b/InteropGenerator/Generator/InteropGenerator.Rendering.cs index 02377cc9c5..1d7a25e163 100644 --- a/InteropGenerator/Generator/InteropGenerator.Rendering.cs +++ b/InteropGenerator/Generator/InteropGenerator.Rendering.cs @@ -132,18 +132,31 @@ private static void RenderVirtualTable(StructInfo structInfo, IndentedTextWriter } else { writer.WriteLine("[global::System.Runtime.InteropServices.StructLayoutAttribute(global::System.Runtime.InteropServices.LayoutKind.Explicit)]"); } - writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable"); - using (writer.WriteBlock()) { - foreach (VirtualFunctionInfo vfi in structInfo.VirtualFunctions) { - var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {vfi.MethodInfo.GetParameterTypeStringWithTrailingType()}{vfi.MethodInfo.ReturnType}>"; - foreach (string attr in vfi.MethodInfo.InheritableAttributes) - writer.WriteLine(attr); - writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({vfi.Index * 8})] public {functionPointerType} {vfi.MethodInfo.Name};"); + if (structInfo.IsGeneric) { + writer.WriteLine($"public unsafe partial struct {structInfo.Name[..^3]}VirtualTable"); + WriteVirtualFunctions(writer); + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name[..^3]}VirtualTable* VirtualTable;"); + if (structInfo.StaticVirtualTableSignature is not null) { + writer.WriteLine($"public static {structInfo.Name[..^3]}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name[..^3]}VirtualTable*)Addresses.StaticVirtualTable.Value;"); + } + } else { + writer.WriteLine($"public unsafe partial struct {structInfo.Name}VirtualTable"); + WriteVirtualFunctions(writer); + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;"); + if (structInfo.StaticVirtualTableSignature is not null) { + writer.WriteLine($"public static {structInfo.Name}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name}VirtualTable*)Addresses.StaticVirtualTable.Value;"); } } - writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute(0)] public {structInfo.Name}VirtualTable* VirtualTable;"); - if (structInfo.StaticVirtualTableSignature is not null) { - writer.WriteLine($"public static {structInfo.Name}VirtualTable* StaticVirtualTablePointer => ({structInfo.Name}VirtualTable*)Addresses.StaticVirtualTable.Value;"); + + void WriteVirtualFunctions(IndentedTextWriter writer) { + using (writer.WriteBlock()) { + foreach (VirtualFunctionInfo vfi in structInfo.VirtualFunctions) { + var functionPointerType = $"delegate* unmanaged <{structInfo.Name}*, {vfi.MethodInfo.GetParameterTypeStringWithTrailingType()}{vfi.MethodInfo.ReturnType}>"; + foreach (string attr in vfi.MethodInfo.InheritableAttributes) + writer.WriteLine(attr); + writer.WriteLine($"[global::System.Runtime.InteropServices.FieldOffsetAttribute({vfi.Index * 8})] public {functionPointerType} {vfi.MethodInfo.Name};"); + } + } } } diff --git a/InteropGenerator/Models/StructInfo.cs b/InteropGenerator/Models/StructInfo.cs index a69a51d54a..face3e178c 100644 --- a/InteropGenerator/Models/StructInfo.cs +++ b/InteropGenerator/Models/StructInfo.cs @@ -15,7 +15,8 @@ internal sealed record StructInfo( EquatableArray FixedSizeArrays, EquatableArray InheritedStructs, int? Size, - ExtraInheritedStructInfo? ExtraInheritedStructInfo) { + ExtraInheritedStructInfo? ExtraInheritedStructInfo, + bool IsGeneric) { public string Name => Hierarchy[0]; public bool HasSignatures() => !MemberFunctions.IsEmpty || !StaticAddresses.IsEmpty || StaticVirtualTableSignature is not null; public bool HasVirtualTable() => !VirtualFunctions.IsEmpty || StaticVirtualTableSignature is not null;