Skip to content

Commit

Permalink
Fix nested generic source generation (#1595)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomhurst authored Jan 17, 2025
1 parent 05080a3 commit a970eda
Show file tree
Hide file tree
Showing 6 changed files with 445 additions and 5 deletions.
17 changes: 17 additions & 0 deletions TUnit.Core.SourceGenerator.Tests/Bugs/1594/Hooks1594.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using TUnit.Core.SourceGenerator.CodeGenerators;

namespace TUnit.Core.SourceGenerator.Tests.Bugs._1594;

internal class Hooks1594 : TestsBase<TestHooksGenerator>
{
[Test]
public Task Test() => RunTest(Path.Combine(Git.RootDirectory.FullName,
"TUnit.TestProject",
"Bugs",
"1594",
"MyTests.cs"),
async generatedFiles =>
{
await Assert.That(generatedFiles.Length).IsEqualTo(3);
});
}
17 changes: 17 additions & 0 deletions TUnit.Core.SourceGenerator.Tests/Bugs/1594/Tests1594.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using TUnit.Core.SourceGenerator.CodeGenerators;

namespace TUnit.Core.SourceGenerator.Tests.Bugs._1594;

internal class Tests1594 : TestsBase<TestsGenerator>
{
[Test]
public Task Test() => RunTest(Path.Combine(Git.RootDirectory.FullName,
"TUnit.TestProject",
"Bugs",
"1594",
"MyTests.cs"),
async generatedFiles =>
{
await Assert.That(generatedFiles.Length).IsEqualTo(1);
});
}
187 changes: 187 additions & 0 deletions TUnit.Core.SourceGenerator.Tests/Hooks1594.Test.verified.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
[
// <auto-generated/>
#pragma warning disable
using global::System.Linq;
using global::System.Reflection;
using global::System.Runtime.CompilerServices;
using global::TUnit.Core;
using global::TUnit.Core.Hooks;
using global::TUnit.Core.Interfaces;

namespace TUnit.SourceGenerated;

[global::System.Diagnostics.StackTraceHidden]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
file partial class Hooks_MyTests : global::TUnit.Core.Interfaces.SourceGenerator.ITestHookSource
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
public static void Initialise()
{
var instance = new Hooks_MyTests();
SourceRegistrar.RegisterTestHookSource(instance);
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectBeforeEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectAfterEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectBeforeTestHooks(string sessionId)
{
return
[
new global::TUnit.Core.Hooks.InstanceHookMethod<global::TUnit.TestProject.Bugs._1594.MyTests>
{
MethodInfo = global::TUnit.Core.Helpers.MethodInfoRetriever.GetMethodInfo(typeof(global::TUnit.TestProject.Bugs._1594.MyTests), "SetupMyTests", 0, []),
Body = (classInstance, context, cancellationToken) => classInstance.SetupMyTests(),
HookExecutor = DefaultExecutor.Instance,
Order = 0,
MethodAttributes = [ new global::TUnit.Core.BeforeAttribute(global::TUnit.Core.HookType.Test)
{

} ],
ClassAttributes = [ new global::TUnit.Core.ClassDataSourceAttribute<global::TUnit.TestProject.Bugs._1594.MyFixture>()
{
Shared = global::TUnit.Core.SharedType.None,
} ],
AssemblyAttributes = [ ],
},
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectAfterTestHooks(string sessionId)
{
return
[
];
}
}


// <auto-generated/>
#pragma warning disable
using global::System.Linq;
using global::System.Reflection;
using global::System.Runtime.CompilerServices;
using global::TUnit.Core;
using global::TUnit.Core.Hooks;
using global::TUnit.Core.Interfaces;

namespace TUnit.SourceGenerated;

[global::System.Diagnostics.StackTraceHidden]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
file partial class Hooks_ParentTests : global::TUnit.Core.Interfaces.SourceGenerator.ITestHookSource
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
public static void Initialise()
{
var instance = new Hooks_ParentTests();
SourceRegistrar.RegisterTestHookSource(instance);
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectBeforeEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectAfterEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectBeforeTestHooks(string sessionId)
{
return
[
new global::TUnit.Core.Hooks.InstanceHookMethod<global::TUnit.TestProject.Bugs._1594.ParentTests<global::TUnit.TestProject.Bugs._1594.MyFixture>>
{
MethodInfo = global::TUnit.Core.Helpers.MethodInfoRetriever.GetMethodInfo(typeof(global::TUnit.TestProject.Bugs._1594.ParentTests<global::TUnit.TestProject.Bugs._1594.MyFixture>), "SetupParentTests", 0, []),
Body = (classInstance, context, cancellationToken) => classInstance.SetupParentTests(),
HookExecutor = DefaultExecutor.Instance,
Order = 0,
MethodAttributes = [ new global::TUnit.Core.BeforeAttribute(global::TUnit.Core.HookType.Test)
{

} ],
ClassAttributes = [ ],
AssemblyAttributes = [ ],
},
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectAfterTestHooks(string sessionId)
{
return
[
];
}
}


// <auto-generated/>
#pragma warning disable
using global::System.Linq;
using global::System.Reflection;
using global::System.Runtime.CompilerServices;
using global::TUnit.Core;
using global::TUnit.Core.Hooks;
using global::TUnit.Core.Interfaces;

namespace TUnit.SourceGenerated;

[global::System.Diagnostics.StackTraceHidden]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
file partial class Hooks_GrandParentTests : global::TUnit.Core.Interfaces.SourceGenerator.ITestHookSource
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
public static void Initialise()
{
var instance = new Hooks_GrandParentTests();
SourceRegistrar.RegisterTestHookSource(instance);
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectBeforeEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.StaticHookMethod<global::TUnit.Core.TestContext>> CollectAfterEveryTestHooks(string sessionId)
{
return
[
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectBeforeTestHooks(string sessionId)
{
return
[
new global::TUnit.Core.Hooks.InstanceHookMethod<global::TUnit.TestProject.Bugs._1594.GrandParentTests<global::TUnit.TestProject.Bugs._1594.MyFixture>>
{
MethodInfo = global::TUnit.Core.Helpers.MethodInfoRetriever.GetMethodInfo(typeof(global::TUnit.TestProject.Bugs._1594.GrandParentTests<global::TUnit.TestProject.Bugs._1594.MyFixture>), "SetupBase", 0, []),
Body = (classInstance, context, cancellationToken) => classInstance.SetupBase(),
HookExecutor = DefaultExecutor.Instance,
Order = 0,
MethodAttributes = [ new global::TUnit.Core.BeforeAttribute(global::TUnit.Core.HookType.Test)
{

} ],
ClassAttributes = [ ],
AssemblyAttributes = [ ],
},
];
}
public global::System.Collections.Generic.IReadOnlyList<global::TUnit.Core.Hooks.InstanceHookMethod> CollectAfterTestHooks(string sessionId)
{
return
[
];
}
}

]
111 changes: 111 additions & 0 deletions TUnit.Core.SourceGenerator.Tests/Tests1594.Test.verified.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
[
// <auto-generated/>
#pragma warning disable
using global::System.Linq;
using global::System.Reflection;
using global::TUnit.Core;
using global::TUnit.Core.Extensions;

namespace TUnit.SourceGenerated;

[global::System.Diagnostics.StackTraceHidden]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
file partial class MyTests : global::TUnit.Core.Interfaces.SourceGenerator.ITestSource
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
public static void Initialise()
{
global::TUnit.Core.SourceRegistrar.Register(new MyTests());
}
public global::System.Collections.Generic.IReadOnlyList<SourceGeneratedTestNode> CollectTests(string sessionId)
{
return Tests0(sessionId);
}
private global::System.Collections.Generic.List<SourceGeneratedTestNode> Tests0(string sessionId)
{
global::System.Collections.Generic.List<SourceGeneratedTestNode> nodes = [];
var classDataIndex = 0;
var testMethodDataIndex = 0;
try
{
var testClassType = typeof(global::TUnit.TestProject.Bugs._1594.MyTests);
var methodInfo = global::TUnit.Core.Helpers.MethodInfoRetriever.GetMethodInfo(typeof(global::TUnit.TestProject.Bugs._1594.MyTests), "Test1", 0, []);

var testBuilderContext = new global::TUnit.Core.TestBuilderContext();
var testBuilderContextAccessor = new global::TUnit.Core.TestBuilderContextAccessor(testBuilderContext);
var classArgDataGeneratorMetadata = new DataGeneratorMetadata
{
Type = TUnit.Core.Enums.DataGeneratorType.Parameters,
TestClassType = testClassType,
ParameterInfos = typeof(global::TUnit.TestProject.Bugs._1594.MyTests).GetConstructors().First().GetParameters(),
PropertyInfo = null,
TestBuilderContext = testBuilderContextAccessor,
TestSessionId = sessionId,
};
var classDataAttribute = new global::TUnit.Core.ClassDataSourceAttribute<global::TUnit.TestProject.Bugs._1594.MyFixture>()
{
Shared = global::TUnit.Core.SharedType.None,
};

var classArgGeneratedDataArray = classDataAttribute.GenerateDataSources(classArgDataGeneratorMetadata);

foreach (var classArgGeneratedDataAccessor in classArgGeneratedDataArray)
{
classDataIndex++;
var classArgGeneratedData = classArgGeneratedDataAccessor();

var resettableClassFactoryDelegate = () => new ResettableLazy<global::TUnit.TestProject.Bugs._1594.MyTests>(() =>
new global::TUnit.TestProject.Bugs._1594.MyTests(classArgGeneratedData)
, sessionId, testBuilderContext);

var resettableClassFactory = resettableClassFactoryDelegate();

nodes.Add(new TestMetadata<global::TUnit.TestProject.Bugs._1594.MyTests>
{
TestId = $"global::TUnit.Core.ClassDataSourceAttribute<global::TUnit.TestProject.Bugs._1594.MyFixture>:{classDataIndex}:CL-GAC0:TUnit.TestProject.Bugs._1594.MyTests(TUnit.TestProject.Bugs._1594.MyFixture).Test1:0",
TestClassArguments = [classArgGeneratedData],
TestMethodArguments = [],
TestClassProperties = [],
CurrentRepeatAttempt = 0,
RepeatLimit = 0,
MethodInfo = methodInfo,
ResettableClassFactory = resettableClassFactory,
TestMethodFactory = (classInstance, cancellationToken) => AsyncConvert.Convert(() => classInstance.Test1()),
TestFilePath = @"",
TestLineNumber = 15,
TestAttributes = [ new global::TUnit.Core.TestAttribute()
{

} ],
ClassAttributes = [ new global::TUnit.Core.ClassDataSourceAttribute<global::TUnit.TestProject.Bugs._1594.MyFixture>()
{
Shared = global::TUnit.Core.SharedType.None,
} ],
AssemblyAttributes = [ ],
DataAttributes = [ classDataAttribute ],
TestBuilderContext = testBuilderContext,
});
resettableClassFactory = resettableClassFactoryDelegate();
testBuilderContext = new();
testBuilderContextAccessor.Current = testBuilderContext;
}
}
catch (global::System.Exception exception)
{
nodes.Add(new global::TUnit.Core.FailedInitializationTest
{
TestId = $"global::TUnit.Core.ClassDataSourceAttribute<global::TUnit.TestProject.Bugs._1594.MyFixture>:{classDataIndex}:CL-GAC0:TUnit.TestProject.Bugs._1594.MyTests(TUnit.TestProject.Bugs._1594.MyFixture).Test1:0",
TestClass = typeof(global::TUnit.TestProject.Bugs._1594.MyTests),
ReturnType = global::TUnit.Core.Helpers.MethodInfoRetriever.GetMethodInfo(typeof(global::TUnit.TestProject.Bugs._1594.MyTests), "Test1", 0, []).ReturnType,
ParameterTypeFullNames = [],
TestName = "Test1",
TestFilePath = @"",
TestLineNumber = 15,
Exception = exception,
});
}
return nodes;
}
}

]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.CodeAnalysis;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.Extensions;

Expand All @@ -22,14 +23,43 @@ public static IEnumerable<INamedTypeSymbol> GetConstructedTypes(

foreach (var typeNode in typeNodes)
{
if (semanticModel.GetTypeInfo(typeNode).Type
is INamedTypeSymbol { IsGenericType: true } typeSymbol
&& !typeSymbol.IsGenericDefinition()
&& SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, originalGenericDefinition))
if (semanticModel.GetTypeInfo(typeNode).Type is not INamedTypeSymbol { IsGenericType: true } typeSymbol
|| typeSymbol.IsGenericDefinition())
{
continue;
}

if (SymbolEqualityComparer.Default.Equals(typeSymbol.OriginalDefinition, originalGenericDefinition))
{
yield return typeSymbol;
continue;
}

if (IsAncestor(originalGenericDefinition, typeSymbol, out var matchingTypeParameter))
{
yield return matchingTypeParameter;
}
}
}
}

private static bool IsAncestor(INamedTypeSymbol genericTypeDefinition, INamedTypeSymbol typeSymbol, [NotNullWhen(true)] out INamedTypeSymbol? foundMatch)
{
if (typeSymbol.GetBaseTypes()
.FirstOrDefault(x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, genericTypeDefinition))
is not {} matchingType)
{
foundMatch = null;
return false;
}

if (matchingType is not INamedTypeSymbol namedMatchingType)
{
foundMatch = null;
return false;
}

foundMatch = namedMatchingType;
return true;
}
}
Loading

0 comments on commit a970eda

Please sign in to comment.